docs for trnbl v0.1.0
View Source on GitHub

trnbl.loggers.wandb


  1import datetime
  2import json
  3import logging
  4from typing import Any
  5from pathlib import Path
  6import sys
  7
  8import wandb
  9from wandb.sdk.wandb_run import Run, Artifact
 10
 11from trnbl.loggers.base import TrainingLoggerBase
 12
 13
 14class WandbLogger(TrainingLoggerBase):
 15	"""wrapper around wandb logging for `TrainingLoggerBase`. create using `WandbLogger.create(config, project, job_type)`"""
 16
 17	def __init__(self, run: Run):
 18		self._run: Run = run
 19
 20	@classmethod
 21	def create(
 22		cls,
 23		config: dict,
 24		project: str | None = None,
 25		job_type: str | None = None,
 26		logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
 27		logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
 28		wandb_kwargs: dict | None = None,
 29	) -> "WandbLogger":
 30		logging.basicConfig(
 31			stream=sys.stdout,
 32			level=logging.INFO,
 33			format=logging_fmt,
 34			datefmt=logging_datefmt,
 35		)
 36
 37		run: Run  # type: ignore[return-value]
 38		run = wandb.init(
 39			config=config,
 40			project=project,
 41			job_type=job_type,
 42			**(wandb_kwargs if wandb_kwargs else {}),
 43		)
 44
 45		assert run is not None, f"wandb.init returned None: {wandb_kwargs}"
 46
 47		logger: WandbLogger = WandbLogger(run)
 48		# TODO: why are we ignoring type checking here?
 49		logger.progress(f"{config =}")  # type: ignore[attr-defined]
 50		return logger
 51
 52	def debug(self, message: str, **kwargs) -> None:
 53		if kwargs:
 54			message += f" {kwargs =}"
 55		logging.debug(message)
 56
 57	def message(self, message: str, **kwargs) -> None:
 58		if kwargs:
 59			message += f" {kwargs =}"
 60		logging.info(message)
 61
 62	def metrics(self, data: dict[str, Any]) -> None:
 63		self._run.log(data)
 64
 65	def artifact(
 66		self,
 67		path: Path,
 68		type: str,
 69		aliases: list[str] | None = None,
 70		metadata: dict | None = None,
 71	) -> None:
 72		artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
 73		artifact.add_file(str(path))
 74		if metadata:
 75			artifact.description = json.dumps(
 76				dict(
 77					timestamp=datetime.datetime.now().isoformat(),
 78					path=path.as_posix(),
 79					type=type,
 80					aliases=aliases,
 81					metadata=metadata if metadata else {},
 82				)
 83			)
 84		self._run.log_artifact(artifact, aliases=aliases)
 85
 86	@property
 87	def url(self) -> str:
 88		# TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
 89		return str(self._run.get_url())
 90
 91	@property
 92	def run_path(self) -> Path:
 93		return Path(self._run._get_path())
 94
 95	def flush(self) -> None:
 96		self._run.save()
 97
 98	def finish(self) -> None:
 99		"""Finish logging"""
100		self._run.finish()

 15class WandbLogger(TrainingLoggerBase):
 16	"""wrapper around wandb logging for `TrainingLoggerBase`. create using `WandbLogger.create(config, project, job_type)`"""
 17
 18	def __init__(self, run: Run):
 19		self._run: Run = run
 20
 21	@classmethod
 22	def create(
 23		cls,
 24		config: dict,
 25		project: str | None = None,
 26		job_type: str | None = None,
 27		logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
 28		logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
 29		wandb_kwargs: dict | None = None,
 30	) -> "WandbLogger":
 31		logging.basicConfig(
 32			stream=sys.stdout,
 33			level=logging.INFO,
 34			format=logging_fmt,
 35			datefmt=logging_datefmt,
 36		)
 37
 38		run: Run  # type: ignore[return-value]
 39		run = wandb.init(
 40			config=config,
 41			project=project,
 42			job_type=job_type,
 43			**(wandb_kwargs if wandb_kwargs else {}),
 44		)
 45
 46		assert run is not None, f"wandb.init returned None: {wandb_kwargs}"
 47
 48		logger: WandbLogger = WandbLogger(run)
 49		# TODO: why are we ignoring type checking here?
 50		logger.progress(f"{config =}")  # type: ignore[attr-defined]
 51		return logger
 52
 53	def debug(self, message: str, **kwargs) -> None:
 54		if kwargs:
 55			message += f" {kwargs =}"
 56		logging.debug(message)
 57
 58	def message(self, message: str, **kwargs) -> None:
 59		if kwargs:
 60			message += f" {kwargs =}"
 61		logging.info(message)
 62
 63	def metrics(self, data: dict[str, Any]) -> None:
 64		self._run.log(data)
 65
 66	def artifact(
 67		self,
 68		path: Path,
 69		type: str,
 70		aliases: list[str] | None = None,
 71		metadata: dict | None = None,
 72	) -> None:
 73		artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
 74		artifact.add_file(str(path))
 75		if metadata:
 76			artifact.description = json.dumps(
 77				dict(
 78					timestamp=datetime.datetime.now().isoformat(),
 79					path=path.as_posix(),
 80					type=type,
 81					aliases=aliases,
 82					metadata=metadata if metadata else {},
 83				)
 84			)
 85		self._run.log_artifact(artifact, aliases=aliases)
 86
 87	@property
 88	def url(self) -> str:
 89		# TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
 90		return str(self._run.get_url())
 91
 92	@property
 93	def run_path(self) -> Path:
 94		return Path(self._run._get_path())
 95
 96	def flush(self) -> None:
 97		self._run.save()
 98
 99	def finish(self) -> None:
100		"""Finish logging"""
101		self._run.finish()

Wrapper around wandb logging for TrainingLoggerBase. Create it using WandbLogger.create(config, project, job_type).

WandbLogger(run: wandb.sdk.wandb_run.Run)
18	def __init__(self, run: Run):
19		self._run: Run = run
@classmethod
def create( cls, config: dict, project: str | None = None, job_type: str | None = None, logging_fmt: str = '%(asctime)s %(levelname)s %(message)s', logging_datefmt: str = '%Y-%m-%d %H:%M:%S', wandb_kwargs: dict | None = None) -> WandbLogger:
21	@classmethod
22	def create(
23		cls,
24		config: dict,
25		project: str | None = None,
26		job_type: str | None = None,
27		logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
28		logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
29		wandb_kwargs: dict | None = None,
30	) -> "WandbLogger":
31		logging.basicConfig(
32			stream=sys.stdout,
33			level=logging.INFO,
34			format=logging_fmt,
35			datefmt=logging_datefmt,
36		)
37
38		run: Run  # type: ignore[return-value]
39		run = wandb.init(
40			config=config,
41			project=project,
42			job_type=job_type,
43			**(wandb_kwargs if wandb_kwargs else {}),
44		)
45
46		assert run is not None, f"wandb.init returned None: {wandb_kwargs}"
47
48		logger: WandbLogger = WandbLogger(run)
49		# TODO: why are we ignoring type checking here?
50		logger.progress(f"{config =}")  # type: ignore[attr-defined]
51		return logger
def debug(self, message: str, **kwargs) -> None:
53	def debug(self, message: str, **kwargs) -> None:
54		if kwargs:
55			message += f" {kwargs =}"
56		logging.debug(message)

Log a debug message, which will be saved but not printed.

def message(self, message: str, **kwargs) -> None:
58	def message(self, message: str, **kwargs) -> None:
59		if kwargs:
60			message += f" {kwargs =}"
61		logging.info(message)

Log a progress message, which will be printed to stdout.

def metrics(self, data: dict[str, typing.Any]) -> None:
63	def metrics(self, data: dict[str, Any]) -> None:
64		self._run.log(data)

Log a dictionary of metrics

def artifact( self, path: pathlib.Path, type: str, aliases: list[str] | None = None, metadata: dict | None = None) -> None:
66	def artifact(
67		self,
68		path: Path,
69		type: str,
70		aliases: list[str] | None = None,
71		metadata: dict | None = None,
72	) -> None:
73		artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
74		artifact.add_file(str(path))
75		if metadata:
76			artifact.description = json.dumps(
77				dict(
78					timestamp=datetime.datetime.now().isoformat(),
79					path=path.as_posix(),
80					type=type,
81					aliases=aliases,
82					metadata=metadata if metadata else {},
83				)
84			)
85		self._run.log_artifact(artifact, aliases=aliases)

Log an artifact from a file

url: str
87	@property
88	def url(self) -> str:
89		# TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
90		return str(self._run.get_url())

Get the URL for the current logging run

run_path: pathlib.Path
92	@property
93	def run_path(self) -> Path:
94		return Path(self._run._get_path())

Get the path to the current logging run

def flush(self) -> None:
96	def flush(self) -> None:
97		self._run.save()

Flush the logger

def finish(self) -> None:
	def finish(self) -> None:
		"""Finish logging by calling `finish()` on the wrapped wandb run."""
		active_run = self._run
		active_run.finish()

Finish logging