Coverage for trnbl\loggers\wandb.py: 0%
46 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-17 02:23 -0700
1import datetime
2import json
3import logging
4from typing import Any
5from pathlib import Path
6import sys
8import wandb
9from wandb.sdk.wandb_run import Run, Artifact
11from trnbl.loggers.base import TrainingLoggerBase
class WandbLogger(TrainingLoggerBase):
    """wrapper around wandb logging for `TrainingLoggerBase`. create using `WandbLogger.create(config, project, job_type)`

    Text output (`debug`, `message`) goes through the standard-library `logging`
    module; metrics and artifacts are sent to the wrapped wandb `Run`.
    """

    def __init__(self, run: Run):
        """Wrap an already-initialized wandb run.

        # Parameters:
         - `run : Run`
           run (as returned by `wandb.init`) that metrics and artifacts are logged to
        """
        self._run: Run = run

    @classmethod
    def create(
        cls,
        config: dict,
        project: str | None = None,
        job_type: str | None = None,
        logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
        logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
        wandb_kwargs: dict | None = None,
    ) -> "WandbLogger":
        """Configure stdout logging, start a wandb run, and return a logger wrapping it.

        # Parameters:
         - `config : dict`
           run configuration, passed to `wandb.init(config=...)` and echoed to the log
         - `project : str | None`
           wandb project name, passed through to `wandb.init`
         - `job_type : str | None`
           wandb job type, passed through to `wandb.init`
         - `logging_fmt : str`
           format string for `logging.basicConfig`
         - `logging_datefmt : str`
           date format for `logging.basicConfig`
         - `wandb_kwargs : dict | None`
           extra keyword arguments forwarded to `wandb.init`

        # Returns:
         - `WandbLogger` (an instance of `cls`)

        # Raises:
         - `RuntimeError` if `wandb.init` returns `None`
        """
        # configures the root logger at INFO level on stdout. Note that
        # `basicConfig` does nothing if the root logger already has handlers.
        logging.basicConfig(
            stream=sys.stdout,
            level=logging.INFO,
            format=logging_fmt,
            datefmt=logging_datefmt,
        )

        run: Run  # type: ignore[return-value]
        run = wandb.init(
            config=config,
            project=project,
            job_type=job_type,
            **(wandb_kwargs or {}),
        )

        # explicit check rather than `assert`, which is stripped under `python -O`
        if run is None:
            raise RuntimeError(f"wandb.init returned None: {wandb_kwargs}")

        # `cls` (not a hard-coded `WandbLogger`) so subclasses construct themselves
        logger: WandbLogger = cls(run)
        # this used to call `logger.progress(...)`, which is not defined on this
        # class (the `type: ignore[attr-defined]` masked it); `message` is the
        # info-level text logger defined below
        logger.message(f"{config =}")
        return logger

    def debug(self, message: str, **kwargs) -> None:
        """Log `message` (plus any `kwargs`) at DEBUG level via the root logger."""
        if kwargs:
            message += f" {kwargs =}"
        logging.debug(message)

    def message(self, message: str, **kwargs) -> None:
        """Log `message` (plus any `kwargs`) at INFO level via the root logger."""
        if kwargs:
            message += f" {kwargs =}"
        logging.info(message)

    def metrics(self, data: dict[str, Any]) -> None:
        """Send a dict of metric values to the run via `Run.log`."""
        self._run.log(data)

    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Upload the file at `path` as a wandb artifact named after the run id.

        A JSON description (timestamp, path, type, aliases, metadata) is always
        attached to the artifact.

        # Parameters:
         - `path : Path`
           file to upload (a `str` path is also accepted)
         - `type : str`
           wandb artifact type
         - `aliases : list[str] | None`
           aliases passed to `Run.log_artifact`
         - `metadata : dict | None`
           extra info embedded in the description; must be JSON-serializable
        """
        # accept plain strings as well, since `.as_posix()` below needs a Path
        path = Path(path)
        artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
        artifact.add_file(str(path))
        # previously the description was only written when `metadata` was truthy,
        # which made the `else {}` fallback unreachable; provenance is useful either way
        artifact.description = json.dumps(
            dict(
                timestamp=datetime.datetime.now().isoformat(),
                path=path.as_posix(),
                type=type,
                aliases=aliases,
                metadata=metadata if metadata else {},
            )
        )
        self._run.log_artifact(artifact, aliases=aliases)

    @property
    def url(self) -> str:
        """URL of the run, as a string."""
        # TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
        return str(self._run.get_url())

    @property
    def run_path(self) -> Path:
        """The run's wandb path (`entity/project/run_id`) wrapped in a `Path`."""
        # public `Run.path` instead of the private `Run._get_path()`
        return Path(self._run.path)

    def flush(self) -> None:
        """Ask the run to persist its state."""
        # NOTE(review): in recent wandb versions `Run.save()` with no arguments is
        # documented as a deprecated no-op -- confirm whether this flushes anything
        self._run.save()

    def finish(self) -> None:
        """Finish logging"""
        self._run.finish()