Coverage for trnbl\loggers\wandb.py: 0%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-17 02:23 -0700

1import datetime 

2import json 

3import logging 

4from typing import Any 

5from pathlib import Path 

6import sys 

7 

8import wandb 

9from wandb.sdk.wandb_run import Run, Artifact 

10 

11from trnbl.loggers.base import TrainingLoggerBase 

12 

13 

class WandbLogger(TrainingLoggerBase):
    """wrapper around wandb logging for `TrainingLoggerBase`. create using `WandbLogger.create(config, project, job_type)`

    Metrics and artifacts are sent to the wrapped wandb `Run`; text from
    `debug`/`message` goes through the standard-library root logger, which
    `create` configures to write to stdout.
    """

    def __init__(self, run: Run):
        """Wrap an already-initialized wandb run.

        Parameters:
            run: the wandb run that metrics and artifacts are logged to.
        """
        self._run: Run = run

    @classmethod
    def create(
        cls,
        config: dict,
        project: str | None = None,
        job_type: str | None = None,
        logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
        logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
        wandb_kwargs: dict | None = None,
    ) -> "WandbLogger":
        """Configure stdout logging, start a wandb run, and return a logger wrapping it.

        Parameters:
            config: run configuration, passed as `wandb.init(config=...)` and
                also reported once through `progress` after the run starts.
            project: wandb project name, forwarded to `wandb.init`.
            job_type: wandb job type, forwarded to `wandb.init`.
            logging_fmt: record format passed to `logging.basicConfig`.
            logging_datefmt: date format passed to `logging.basicConfig`.
            wandb_kwargs: extra keyword arguments forwarded to `wandb.init`.
                These must not repeat `config`, `project` or `job_type`,
                otherwise the call fails with a duplicate-keyword `TypeError`.

        Returns:
            a `WandbLogger` wrapping the newly created run.

        Raises:
            AssertionError: if `wandb.init` returns `None` (only while asserts
                are enabled, i.e. not under `python -O`).
        """
        # NOTE: `logging.basicConfig` does nothing if the root logger already
        # has handlers, so an earlier logging setup takes precedence.
        logging.basicConfig(
            stream=sys.stdout,
            level=logging.INFO,
            format=logging_fmt,
            datefmt=logging_datefmt,
        )

        run: Run  # type: ignore[return-value]
        run = wandb.init(
            config=config,
            project=project,
            job_type=job_type,
            **(wandb_kwargs or {}),
        )

        assert run is not None, f"wandb.init returned None: {wandb_kwargs}"

        logger: WandbLogger = WandbLogger(run)
        # TODO: why are we ignoring type checking here?
        # NOTE(review): `progress` is not defined in this class; presumably it
        # is provided by `TrainingLoggerBase` -- confirm and drop the ignore.
        logger.progress(f"{config =}")  # type: ignore[attr-defined]
        return logger

    def debug(self, message: str, **kwargs) -> None:
        """Log `message` at DEBUG level via the root logger.

        Any keyword arguments are appended to the message in their `repr` form.
        """
        if kwargs:
            message += f" {kwargs =}"
        logging.debug(message)

    def message(self, message: str, **kwargs) -> None:
        """Log `message` at INFO level via the root logger.

        Any keyword arguments are appended to the message in their `repr` form.
        """
        if kwargs:
            message += f" {kwargs =}"
        logging.info(message)

    def metrics(self, data: dict[str, Any]) -> None:
        """Send a dict of metric values to the wandb run via `Run.log`."""
        self._run.log(data)

    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Upload the file at `path` to the wandb run as an artifact.

        The artifact is named after the run id, so every artifact logged by
        one run shares the same name (presumably wandb stores them as
        successive versions -- confirm against the wandb artifact docs).

        Parameters:
            path: file to upload; a `str` path is accepted as well.
            type: wandb artifact type.
            aliases: optional aliases, forwarded to `Run.log_artifact` and
                recorded in the description.
            metadata: if non-empty, a JSON description is attached to the
                artifact. It holds the timestamp, path, type, aliases and this
                metadata.

        Raises:
            TypeError: from `json.dumps` if `metadata` or `aliases` contain
                values that are not JSON-serializable.
        """
        # normalise so `str` paths work too; `as_posix` below needs a Path
        path = Path(path)
        artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
        artifact.add_file(str(path))
        if metadata:
            # `metadata` is known to be truthy here, so it is stored directly.
            # NOTE: `datetime.now()` is naive local time (no UTC offset).
            artifact.description = json.dumps(
                dict(
                    timestamp=datetime.datetime.now().isoformat(),
                    path=path.as_posix(),
                    type=type,
                    aliases=aliases,
                    metadata=metadata,
                )
            )
        self._run.log_artifact(artifact, aliases=aliases)

    @property
    def url(self) -> str:
        """URL of the wandb run, as returned by `Run.get_url`, converted to `str`."""
        # TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
        # (until then an offline run yields the literal string "None")
        return str(self._run.get_url())

    @property
    def run_path(self) -> Path:
        """Path of the wandb run, as a `Path`."""
        # NOTE(review): relies on the private `Run._get_path`, which may change
        # between wandb releases.
        return Path(self._run._get_path())

    def flush(self) -> None:
        """Ask the wandb run to persist its state via `Run.save()`."""
        # NOTE(review): recent wandb versions treat `Run.save()` without a
        # glob argument as a deprecated no-op (it only emits a deprecation
        # warning), so this may not flush anything -- confirm for the pinned
        # wandb version.
        self._run.save()

    def finish(self) -> None:
        """Finish logging"""
        self._run.finish()