trnbl.loggers.wandb
import datetime
import json
import logging
from typing import Any
from pathlib import Path
import sys

import wandb
from wandb.sdk.wandb_run import Run, Artifact

from trnbl.loggers.base import TrainingLoggerBase


class WandbLogger(TrainingLoggerBase):
    """wrapper around wandb logging for `TrainingLoggerBase`. create using `WandbLogger.create(config, project, job_type)`

    Text messages go through the standard `logging` module (root logger),
    while metrics and artifacts are sent to the wrapped wandb `Run`.
    """

    def __init__(self, run: Run):
        """Wrap an already-initialized wandb run.

        # Parameters:
         - `run : Run`
            the wandb run that metrics and artifacts are logged to
        """
        self._run: Run = run

    @classmethod
    def create(
        cls,
        config: dict,
        project: str | None = None,
        job_type: str | None = None,
        logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
        logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
        wandb_kwargs: dict | None = None,
    ) -> "WandbLogger":
        """Configure stdout logging, start a wandb run, and wrap it in a `WandbLogger`.

        # Parameters:
         - `config : dict`
            run configuration, passed to `wandb.init` and logged as a message
         - `project : str | None`
            passed to `wandb.init` as `project`
         - `job_type : str | None`
            passed to `wandb.init` as `job_type`
         - `logging_fmt : str`
            `format` argument for `logging.basicConfig`
         - `logging_datefmt : str`
            `datefmt` argument for `logging.basicConfig`
         - `wandb_kwargs : dict | None`
            extra keyword arguments forwarded to `wandb.init`

        # Returns:
         - `WandbLogger`
            logger wrapping the newly created run

        # Raises:
         - `AssertionError` : if `wandb.init` returns `None`
        """
        logging.basicConfig(
            stream=sys.stdout,
            level=logging.INFO,
            format=logging_fmt,
            datefmt=logging_datefmt,
        )

        run: Run  # type: ignore[return-value]
        run = wandb.init(
            config=config,
            project=project,
            job_type=job_type,
            **(wandb_kwargs or {}),
        )

        assert run is not None, f"wandb.init returned None: {wandb_kwargs}"

        logger: WandbLogger = WandbLogger(run)
        # this class has no `progress` method (hence the old `attr-defined`
        # ignore); `message` is the method that logs at INFO level
        logger.message(f"{config =}")
        return logger

    def debug(self, message: str, **kwargs) -> None:
        """Log `message` at DEBUG level via `logging`, appending `kwargs` if any were given."""
        if kwargs:
            message += f" {kwargs =}"
        logging.debug(message)

    def message(self, message: str, **kwargs) -> None:
        """Log `message` at INFO level via `logging`, appending `kwargs` if any were given."""
        if kwargs:
            message += f" {kwargs =}"
        logging.info(message)

    def metrics(self, data: dict[str, Any]) -> None:
        """Send a dictionary of metrics to the wandb run via `Run.log`."""
        self._run.log(data)

    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Log the file at `path` as a wandb artifact named after the run id.

        # Parameters:
         - `path : Path`
            file to add to the artifact
         - `type : str`
            artifact type, passed to `wandb.Artifact`
         - `aliases : list[str] | None`
            aliases passed to `Run.log_artifact`
         - `metadata : dict | None`
            if non-empty, a JSON description (timestamp, path, type, aliases,
            metadata) is attached to the artifact; otherwise no description is set
        """
        artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
        artifact.add_file(str(path))
        if metadata:
            artifact.description = json.dumps(
                dict(
                    timestamp=datetime.datetime.now().isoformat(),
                    path=path.as_posix(),
                    type=type,
                    aliases=aliases,
                    # `metadata` is known to be truthy inside this branch
                    metadata=metadata,
                )
            )
        self._run.log_artifact(artifact, aliases=aliases)

    @property
    def url(self) -> str:
        """URL of the wandb run, converted with `str`."""
        # TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
        # NOTE: until then, an offline run yields the string "None"
        return str(self._run.get_url())

    @property
    def run_path(self) -> Path:
        """Path of the run as reported by the run's `_get_path()`, wrapped in a `Path`."""
        return Path(self._run._get_path())

    def flush(self) -> None:
        """Call `save()` on the underlying wandb run."""
        self._run.save()

    def finish(self) -> None:
        """Finish logging"""
        self._run.finish()
class WandbLogger(TrainingLoggerBase):
    """wrapper around wandb logging for `TrainingLoggerBase`. create using `WandbLogger.create(config, project, job_type)`

    Text messages go through the standard `logging` module (root logger),
    while metrics and artifacts are sent to the wrapped wandb `Run`.
    """

    def __init__(self, run: Run):
        """Wrap an already-initialized wandb run.

        # Parameters:
         - `run : Run`
            the wandb run that metrics and artifacts are logged to
        """
        self._run: Run = run

    @classmethod
    def create(
        cls,
        config: dict,
        project: str | None = None,
        job_type: str | None = None,
        logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
        logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
        wandb_kwargs: dict | None = None,
    ) -> "WandbLogger":
        """Configure stdout logging, start a wandb run, and wrap it in a `WandbLogger`.

        # Parameters:
         - `config : dict`
            run configuration, passed to `wandb.init` and logged as a message
         - `project : str | None`
            passed to `wandb.init` as `project`
         - `job_type : str | None`
            passed to `wandb.init` as `job_type`
         - `logging_fmt : str`
            `format` argument for `logging.basicConfig`
         - `logging_datefmt : str`
            `datefmt` argument for `logging.basicConfig`
         - `wandb_kwargs : dict | None`
            extra keyword arguments forwarded to `wandb.init`

        # Returns:
         - `WandbLogger`
            logger wrapping the newly created run

        # Raises:
         - `AssertionError` : if `wandb.init` returns `None`
        """
        logging.basicConfig(
            stream=sys.stdout,
            level=logging.INFO,
            format=logging_fmt,
            datefmt=logging_datefmt,
        )

        run: Run  # type: ignore[return-value]
        run = wandb.init(
            config=config,
            project=project,
            job_type=job_type,
            **(wandb_kwargs or {}),
        )

        assert run is not None, f"wandb.init returned None: {wandb_kwargs}"

        logger: WandbLogger = WandbLogger(run)
        # this class has no `progress` method (hence the old `attr-defined`
        # ignore); `message` is the method that logs at INFO level
        logger.message(f"{config =}")
        return logger

    def debug(self, message: str, **kwargs) -> None:
        """Log `message` at DEBUG level via `logging`, appending `kwargs` if any were given."""
        if kwargs:
            message += f" {kwargs =}"
        logging.debug(message)

    def message(self, message: str, **kwargs) -> None:
        """Log `message` at INFO level via `logging`, appending `kwargs` if any were given."""
        if kwargs:
            message += f" {kwargs =}"
        logging.info(message)

    def metrics(self, data: dict[str, Any]) -> None:
        """Send a dictionary of metrics to the wandb run via `Run.log`."""
        self._run.log(data)

    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Log the file at `path` as a wandb artifact named after the run id.

        # Parameters:
         - `path : Path`
            file to add to the artifact
         - `type : str`
            artifact type, passed to `wandb.Artifact`
         - `aliases : list[str] | None`
            aliases passed to `Run.log_artifact`
         - `metadata : dict | None`
            if non-empty, a JSON description (timestamp, path, type, aliases,
            metadata) is attached to the artifact; otherwise no description is set
        """
        artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
        artifact.add_file(str(path))
        if metadata:
            artifact.description = json.dumps(
                dict(
                    timestamp=datetime.datetime.now().isoformat(),
                    path=path.as_posix(),
                    type=type,
                    aliases=aliases,
                    # `metadata` is known to be truthy inside this branch
                    metadata=metadata,
                )
            )
        self._run.log_artifact(artifact, aliases=aliases)

    @property
    def url(self) -> str:
        """URL of the wandb run, converted with `str`."""
        # TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
        # NOTE: until then, an offline run yields the string "None"
        return str(self._run.get_url())

    @property
    def run_path(self) -> Path:
        """Path of the run as reported by the run's `_get_path()`, wrapped in a `Path`."""
        return Path(self._run._get_path())

    def flush(self) -> None:
        """Call `save()` on the underlying wandb run."""
        self._run.save()

    def finish(self) -> None:
        """Finish logging"""
        self._run.finish()
wrapper around wandb logging for TrainingLoggerBase. create using WandbLogger.create(config, project, job_type)
@classmethod
def
create( cls, config: dict, project: str | None = None, job_type: str | None = None, logging_fmt: str = '%(asctime)s %(levelname)s %(message)s', logging_datefmt: str = '%Y-%m-%d %H:%M:%S', wandb_kwargs: dict | None = None) -> WandbLogger:
@classmethod
def create(
    cls,
    config: dict,
    project: str | None = None,
    job_type: str | None = None,
    logging_fmt: str = "%(asctime)s %(levelname)s %(message)s",
    logging_datefmt: str = "%Y-%m-%d %H:%M:%S",
    wandb_kwargs: dict | None = None,
) -> "WandbLogger":
    """Configure stdout logging, start a wandb run, and wrap it in a `WandbLogger`.

    # Parameters:
     - `config : dict`
        run configuration, passed to `wandb.init` and logged as a message
     - `project : str | None`
        passed to `wandb.init` as `project`
     - `job_type : str | None`
        passed to `wandb.init` as `job_type`
     - `logging_fmt : str`
        `format` argument for `logging.basicConfig`
     - `logging_datefmt : str`
        `datefmt` argument for `logging.basicConfig`
     - `wandb_kwargs : dict | None`
        extra keyword arguments forwarded to `wandb.init`

    # Returns:
     - `WandbLogger`
        logger wrapping the newly created run

    # Raises:
     - `AssertionError` : if `wandb.init` returns `None`
    """
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format=logging_fmt,
        datefmt=logging_datefmt,
    )

    run: Run  # type: ignore[return-value]
    run = wandb.init(
        config=config,
        project=project,
        job_type=job_type,
        **(wandb_kwargs or {}),
    )

    assert run is not None, f"wandb.init returned None: {wandb_kwargs}"

    logger: WandbLogger = WandbLogger(run)
    # this class has no `progress` method (hence the old `attr-defined`
    # ignore); `message` is the method that logs at INFO level
    logger.message(f"{config =}")
    return logger
def
debug(self, message: str, **kwargs) -> None:
def debug(self, message: str, **kwargs) -> None:
    """Log `message` at DEBUG level through the `logging` module.

    Any keyword arguments are appended to the message text, rendered with
    the f-string `=` specifier.
    """
    # empty suffix when no extra keyword arguments were supplied
    suffix: str = f" {kwargs =}" if kwargs else ""
    logging.debug(message + suffix)
log a debug message which will be saved, but not printed
def
message(self, message: str, **kwargs) -> None:
def message(self, message: str, **kwargs) -> None:
    """Log `message` at INFO level through the `logging` module.

    Any keyword arguments are appended to the message text, rendered with
    the f-string `=` specifier.
    """
    # empty suffix when no extra keyword arguments were supplied
    suffix: str = f" {kwargs =}" if kwargs else ""
    logging.info(message + suffix)
log a progress message, which will be printed to stdout
def
artifact( self, path: pathlib.Path, type: str, aliases: list[str] | None = None, metadata: dict | None = None) -> None:
def artifact(
    self,
    path: Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None,
) -> None:
    """Log the file at `path` as a wandb artifact named after the run id.

    # Parameters:
     - `path : Path`
        file to add to the artifact
     - `type : str`
        artifact type, passed to `wandb.Artifact`
     - `aliases : list[str] | None`
        aliases passed to `Run.log_artifact`
     - `metadata : dict | None`
        if non-empty, a JSON description (timestamp, path, type, aliases,
        metadata) is attached to the artifact; otherwise no description is set
    """
    artifact: Artifact = wandb.Artifact(name=self._run.id, type=type)
    artifact.add_file(str(path))
    if metadata:
        artifact.description = json.dumps(
            dict(
                timestamp=datetime.datetime.now().isoformat(),
                path=path.as_posix(),
                type=type,
                aliases=aliases,
                # `metadata` is known to be truthy inside this branch, so the
                # former `metadata if metadata else {}` fallback was unreachable
                metadata=metadata,
            )
        )
    self._run.log_artifact(artifact, aliases=aliases)
log an artifact from a file
url: str
@property
def url(self) -> str:
    """URL of the wandb run, converted with `str`."""
    # TODO: get_url returns `None` for offline runs. need to adjust allowed return types in superclass
    raw_url = self._run.get_url()
    # `str()` is applied unconditionally, so a `None` URL becomes the text "None"
    return str(raw_url)
Get the URL for the current logging run