  docs for trnbl v0.1.1

trnbl – Training Butler

If you train a lot of models, you may often find yourself annoyed at
switching between different loggers and fiddling with a bunch of
if batch_idx % some_number == 0 statements. This package aims to fix
that problem.

Firstly, trnbl provides a universal interface to wandb, tensorboard, and
a minimal local logging solution (live demo).

-   This interface handles logging, error messages, metrics, and
    artifacts.
-   Swapping from one logger to another requires no modifications beyond
    initializing the new logger you want and passing it in instead.
-   You can even log to multiple loggers at once!
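
Logging to several backends at once boils down to fanning every call out to each wrapped logger, while the training code only ever talks to one shared interface. A minimal sketch of that pattern follows; ListLogger, PrintLogger, FanOutLogger and train_step are hypothetical toy names, not trnbl classes (in trnbl itself, trnbl.loggers.multi.MultiLogger fills the fan-out role).

```python
from typing import Any


class ListLogger:
    """Hypothetical toy logger that keeps metrics in memory."""

    def __init__(self) -> None:
        self.records: list[dict[str, Any]] = []

    def metrics(self, data: dict[str, Any]) -> None:
        self.records.append(dict(data))


class PrintLogger:
    """Hypothetical toy logger that prints metrics to stdout."""

    def metrics(self, data: dict[str, Any]) -> None:
        print(" ".join(f"{key}={value}" for key, value in data.items()))


class FanOutLogger:
    """Forward each call to every wrapped logger (a sketch, not trnbl's MultiLogger)."""

    def __init__(self, loggers: list) -> None:
        self.loggers = loggers

    def metrics(self, data: dict[str, Any]) -> None:
        for logger in self.loggers:
            logger.metrics(data)


def train_step(logger, step: int) -> None:
    # the training code depends only on the shared interface,
    # so any of the loggers above (or a fan-out of them) can be passed in
    logger.metrics({"train/loss": 1.0 / (step + 1), "step": step})


memory = ListLogger()
logger = FanOutLogger([memory, PrintLogger()])
for step in range(3):
    train_step(logger, step)
```

Swapping backends only changes the line that builds logger; train_step stays the same.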

Secondly, trnbl provides a TrainingManager class which handles logging,
artifacts, checkpointing, evaluations, exceptions, and more, at
flexibly customizable intervals.

-   Rather than specifying every interval in batches and then changing
    everything by hand whenever the batch size, dataset size, or number
    of epochs changes, you specify an interval in samples, batches,
    epochs, or runs. It is then converted into the correct number of
    batches based on the current dataset size, batch size, and number
    of epochs.

    -   "1/10 runs" – 10 times a run
    -   "2.5 epochs" – every 2 & 1/2 epochs
    -   (100, "batches") – every 100 batches
    -   "10k samples" – every 10,000 samples

-   an evaluation function is passed in a tuple with an interval, takes
    the model as an argument, and returns the metrics as a dictionary

-   checkpointing is handled automatically, with an interval specified
    in the same way as for evaluations

-   the model is saved as model.final.pt at the end of the run, or as
    model.exception.pt if an exception is raised

Installation

    pip install trnbl

Usage

Also see the notebooks/ folder:

-   demo_minimal.py for a minimal example with dummy data
-   demo.ipynb for an example with all options on the iris dataset

    import torch
    from torch.utils.data import DataLoader
    from trnbl.loggers.local import LocalLogger
    from trnbl.training_manager import TrainingManager

    # set up your dataset, model, optimizer, etc as usual
    # (my_dataset and MyModel are placeholders for your own code)
    dataloader: DataLoader = DataLoader(my_dataset, batch_size=32)
    model: torch.nn.Module = MyModel()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # set up a logger -- swap seamlessly between wandb, tensorboard, and local logging
    logger: LocalLogger = LocalLogger(
        project="iris-demo",
        metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
        train_config=dict(
            model=str(model), optimizer=str(optimizer), criterion=str(criterion)
        ),
    )

    with TrainingManager(
        # pass your model and logger
        model=model,
        logger=logger,
        evals={
            # pass evaluation functions which take a model and return a dict of metrics
            # (my_evaluation_function and my_other_eval_function are placeholders)
            "1k samples": my_evaluation_function,
            "0.5 epochs": lambda model: logger.get_mem_usage(),
            "100 batches": my_other_eval_function,
        }.items(),
        checkpoint_interval="1/10 runs",  # will save a checkpoint 10 times per run
    ) as tr:

        # wrap the loops, and length will be automatically calculated
        # and used to figure out when to run evals, checkpoint, etc
        for epoch in tr.epoch_loop(range(120)):
            for inputs, targets in tr.batch_loop(dataloader):
                # your normal training code
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                # compute whatever you want every batch
                accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)
                
                # log the metrics
                tr.batch_update(
                    samples=len(targets),
                    **{"train/loss": loss.item(), "train/acc": accuracy},
                )

        # a `model.final.pt` checkpoint will be saved at the end of the run,
        # or a `model.exception.pt` if something crashes inside the context

LocalLogger

Intended as a minimal logging solution for local runs, when you’re too
lazy to set up a new wandb project for a quick test, and want to be able
to easily read the logs. It logs everything as json or jsonl files, and
provides a simple web interface for viewing the data. The web interface
lets you:

-   enable or disable the visibility of individual runs
-   filter and sort runs by various stats via an interactive table
-   smooth the data and change axes scales
-   move and resize all plots and tables

You can view a live demo of the web interface here.

TODOs:

-   ☐ BUG: minifying the html/js code causes things to break?

-   frontend:

    -   ☐ batch/epoch size to table in config column group
    -   ☐ box to add aliases to runs
    -   ☐ customizable grid snap size?
    -   ☐ display the grid on the background?

-   deployment:

    -   ☐ demo website for local logger
    -   ☐ CI/CD for website, minification, tests, etc
    -   ☐ migrate to typescript

trnbl

Submodules

-   loggers
-   training_interval
-   training_manager

API Documentation

-   TrainingInterval
-   TrainingIntervalUnit
-   TrainingLoggerBase
-   TrainingManager

class TrainingInterval:

A training interval, which can be specified in a few different units.

Attributes:

-   quantity: int|float - the quantity of the interval
-   unit: TrainingIntervalUnit - the unit of the interval, one of
    “runs”, “epochs”, “batches”, or “samples”

Methods:

-   TrainingInterval.from_str(raw: str) -> TrainingInterval -
    parse a string into a TrainingInterval object
-   TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int -
    convert the interval to a raw number of batches
-   TrainingInterval.process_to_batches(interval: str|tuple|TrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int -
    convert any representation directly to a number of batches
-   TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> TrainingInterval -
    return the interval with its units converted to batches

Provides methods for reading from a string or tuple, and normalizing to
batches.

TrainingInterval

    (
        quantity: int | float,
        unit: Literal['runs', 'epochs', 'batches', 'samples']
    )

-   quantity: int | float

-   unit: Literal['runs', 'epochs', 'batches', 'samples']

def as_batch_count

    (
        self,
        batchsize: int,
        batches_per_epoch: int,
        epochs: int | None = None
    ) -> int

given the batchsize, number of batches per epoch, and number of epochs,
return the interval as a number of batches

Parameters:

-   batchsize: int the size of a batch
-   batches_per_epoch: int the number of batches in an epoch
-   epochs: int|None the number of epochs to run (only required if the
    interval is in “runs”)

Returns:

-   int the interval as a number of batches

Raises:

-   ValueError if the interval is less than 1 batch and
    trnbl.training_interval.WhenIntervalLessThanBatch is set to
    muutils.errormode.ErrorMode.ERROR; otherwise, a warning is issued or
    the case is ignored, and the interval is set to 1 batch
-   ValueError if the unit is not one of “runs”, “epochs”, “batches”, or
    “samples”
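
The conversion described above can be sketched as a standalone function. This is an illustration under assumptions, not trnbl's code: interval_to_batches is a hypothetical name, rounding to the nearest whole batch is an assumption (the real method may round differently), and only the "warn and clamp to 1 batch" behavior is shown out of the configurable error modes.

```python
import warnings
from typing import Literal, Optional

Unit = Literal["runs", "epochs", "batches", "samples"]


def interval_to_batches(
    quantity: float,
    unit: Unit,
    batchsize: int,
    batches_per_epoch: int,
    epochs: Optional[int] = None,
) -> int:
    """Hypothetical sketch of converting an interval to a number of batches."""
    if unit == "batches":
        raw = quantity
    elif unit == "samples":
        raw = quantity / batchsize
    elif unit == "epochs":
        raw = quantity * batches_per_epoch
    elif unit == "runs":
        # a "run" is every epoch, so the total epoch count is required here
        if epochs is None:
            raise ValueError("epochs is required for intervals in 'runs'")
        raw = quantity * batches_per_epoch * epochs
    else:
        raise ValueError(f"unknown unit: {unit!r}")

    batches = round(raw)  # assumption: round to the nearest whole batch
    if batches < 1:
        # one possible error mode: warn and clamp to a single batch
        warnings.warn(f"interval of {raw} batches is less than 1, using 1 batch")
        batches = 1
    return batches


# 32 samples per batch, 50 batches per epoch, 10 epochs in the run
print(interval_to_batches(2.5, "epochs", 32, 50))            # 125
print(interval_to_batches(0.1, "runs", 32, 50, epochs=10))   # 50
print(interval_to_batches(3200, "samples", 32, 50))          # 100
```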

def normalized

    (
        self,
        batchsize: int,
        batches_per_epoch: int,
        epochs: int | None = None
    ) -> trnbl.training_interval.TrainingInterval

convert the units of the interval to batches by calling as_batch_count
and setting the unit to “batches”

def from_str

    (cls, raw: str) -> trnbl.training_interval.TrainingInterval

parse a string into a TrainingInterval object

Examples:

    >>> TrainingInterval.from_str("5 epochs")
    TrainingInterval(5, 'epochs')
    >>> TrainingInterval.from_str("100 batches")
    TrainingInterval(100, 'batches')
    >>> TrainingInterval.from_str("0.1 runs")
    TrainingInterval(0.1, 'runs')
    >>> TrainingInterval.from_str("1/5 runs")
    TrainingInterval(0.2, 'runs')
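
The string forms in these examples, plus the "10k samples" shorthand from the README, can be parsed with the standard library alone. A minimal sketch, assuming the input is always "<quantity> <unit>" with a plural unit name; parse_interval is a hypothetical helper that returns a plain tuple rather than a TrainingInterval, and it is not trnbl's implementation.

```python
from fractions import Fraction

UNITS = ("runs", "epochs", "batches", "samples")
SUFFIXES = {"k": 1_000}  # only the "k" shorthand shown in the README


def parse_interval(raw: str) -> tuple[float, str]:
    """Parse e.g. "5 epochs", "1/5 runs" or "10k samples" into (quantity, unit)."""
    quantity_str, unit = raw.strip().split()
    if unit not in UNITS:
        raise ValueError(f"unknown unit {unit!r} in {raw!r}")
    multiplier = 1
    if quantity_str[-1].lower() in SUFFIXES:
        multiplier = SUFFIXES[quantity_str[-1].lower()]
        quantity_str = quantity_str[:-1]
    # Fraction accepts "5", "0.1" and "1/5" alike
    return float(Fraction(quantity_str)) * multiplier, unit


print(parse_interval("1/5 runs"))     # (0.2, 'runs')
print(parse_interval("10k samples"))  # (10000.0, 'samples')
```

Using Fraction avoids writing a separate branch for "a/b" forms.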

def from_any

    (cls, *args, **kwargs) -> trnbl.training_interval.TrainingInterval

parse a string or tuple into a TrainingInterval object

def process_to_batches

    (
        cls,
        interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval],
        batchsize: int,
        batches_per_epoch: int,
        epochs: int | None = None
    ) -> int

convert directly from any representation (a string, a (quantity, unit)
tuple, or a TrainingInterval) to a number of batches

-   TrainingIntervalUnit = typing.Literal['runs', 'epochs', 'batches', 'samples']

class TrainingLoggerBase(abc.ABC):

Base class for training loggers

def debug

    (self, message: str, **kwargs) -> None

log a debug message which will be saved, but not printed

def message

    (self, message: str, **kwargs) -> None

log a progress message, which will be printed to stdout

def warning

    (self, message: str, **kwargs) -> None

log a warning message, which will be printed to stderr

def error

    (self, message: str, **kwargs) -> None

log an error message

def metrics

    (self, data: dict[str, typing.Any]) -> None

Log a dictionary of metrics

def artifact

    (
        self,
        path: pathlib.Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None
    ) -> None

log an artifact from a file

-   url: str | list[str]

Get the URL for the current logging run

-   run_path: pathlib.Path | list[pathlib.Path]

Get the path to the current logging run

def flush

    (self) -> None

Flush the logger

def finish

    (self) -> None

Finish logging

def get_mem_usage

    (self) -> dict

def spinner_task

    (self, **kwargs) -> trnbl.loggers.base.LoggerSpinner

Create a spinner task. kwargs are passed to Spinner.

class TrainingManager(typing.Generic[~TLogger]):

context manager for training a model, with logging, evals, and
checkpoints

Parameters:

-   model : torch.nn.Module ref to model being trained - used for saving
    checkpoints
-   dataloader : torch.utils.data.DataLoader | None ref to dataloader
    being used - used for calculating training progress (defaults to
    None)
-   logger : TrainingLoggerBase logger, which can be local or interface
    with wandb.
-   epochs_total : int | None total number of epochs to train for
    (defaults to None)
-   evals : Iterable[tuple[TrainingInterval | str, EvalFunction]] | None
    list of pairs of (interval, eval_fn) to run evals on the model. See
    TrainingInterval for interval options. (defaults to None)
-   checkpoint_interval : TrainingInterval | str interval at which to
    save model checkpoints (defaults to TrainingInterval(1, "epochs"))
-   print_metrics_interval : TrainingInterval | str interval at which to
    print metrics (defaults to TrainingInterval(0.1, "runs"))
-   save_model : Callable[[torch.nn.Module, Path], None] function to
    save the model (defaults to torch.save)
-   model_save_path : str format string for saving model checkpoints.
    uses _get_format_kwargs for formatting, along with an alias kwarg
    (defaults to
    "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt")
-   model_save_path_special : str format string for saving special model
    checkpoints (final, exception, etc.). uses _get_format_kwargs for
    formatting, along with an alias kwarg (defaults to
    "{run_path}/model.{alias}.pt")

Usage:

    EPOCHS: int = 500

    with TrainingManager(
        model=model, dataloader=TRAIN_LOADER, logger=logger, epochs_total=EPOCHS,
        evals={
            "1 epochs": eval_func,
            "0.1 runs": lambda model: logger.get_mem_usage(),
        }.items(),
        checkpoint_interval="50 epochs",
    ) as tp:

        # Training loop
        model.train()
        for epoch in range(EPOCHS):
            for inputs, targets in TRAIN_LOADER:
                # the usual
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                # compute accuracy
                accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

                # log metrics
                tp.batch_update(
                    # pass in number of samples in your batch (or it will be inferred from the batch size)
                    samples=len(targets),
                    # any other metrics you compute every loop
                    **{"train/loss": loss.item(), "train/acc": accuracy},
                )
                # batch_update will automatically run evals and save checkpoints as needed

            tp.epoch_update()
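
The model_save_path and model_save_path_special parameters listed above are ordinary Python format strings. The sketch below shows how the defaults expand. It assumes the format kwargs include run_path, latest_checkpoint and alias; the concrete values are hypothetical stand-ins for what _get_format_kwargs supplies.

```python
from pathlib import Path

# default templates, copied from the parameter list above
MODEL_SAVE_PATH = "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt"
MODEL_SAVE_PATH_SPECIAL = "{run_path}/model.{alias}.pt"

# hypothetical values standing in for what _get_format_kwargs would provide
format_kwargs = {"run_path": "trnbl-logs/iris-demo/run-abc", "latest_checkpoint": 3}

# str.format ignores unused keyword arguments, so one kwargs dict serves both templates
checkpoint_path = Path(MODEL_SAVE_PATH.format(**format_kwargs))
final_path = Path(MODEL_SAVE_PATH_SPECIAL.format(**format_kwargs, alias="final"))
exception_path = Path(MODEL_SAVE_PATH_SPECIAL.format(**format_kwargs, alias="exception"))

print(checkpoint_path.as_posix())  # trnbl-logs/iris-demo/run-abc/checkpoints/model.checkpoint-3.pt
print(final_path.as_posix())       # trnbl-logs/iris-demo/run-abc/model.final.pt
```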

TrainingManager

    (
        model: torch.nn.modules.module.Module,
        logger: ~TLogger,
        dataloader: torch.utils.data.dataloader.DataLoader | None = None,
        epochs_total: int | None = None,
        save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType] = <function save>,
        evals: Optional[Iterable[tuple[Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval], Callable[[torch.nn.modules.module.Module], dict]]]] = None,
        checkpoint_interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval] = TrainingInterval(quantity=1, unit='epochs'),
        print_metrics_interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval] = TrainingInterval(quantity=0.1, unit='runs'),
        model_save_path: str = '{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt',
        model_save_path_special: str = '{run_path}/model.{alias}.pt'
    )

-   start_time: float

-   model: torch.nn.modules.module.Module

-   logger: ~TLogger

-   save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType]

-   model_save_path: str

-   model_save_path_special: str

-   evals: list[tuple[int, typing.Callable[[torch.nn.modules.module.Module], dict]]]

-   checkpoint_interval: int | None

-   print_metrics_interval: int | None

-   epochs: int

-   batches: int

-   samples: int

-   checkpoints: int

-   epochs_total: int | None

-   batches_per_epoch: int | None

-   batch_size: int | None

-   samples_per_epoch: int | None

-   batches_total: int | None

-   samples_total: int | None

-   init_complete: bool

def try_compute_counters

    (self) -> None

def epoch_loop

    (
        self,
        epochs: Sequence[int],
        use_tqdm: bool = True,
        **tqdm_kwargs
    ) -> Generator[int, NoneType, NoneType]

def batch_loop

    (
        self,
        batches: Sequence[int],
        use_tqdm: bool = False,
        **tqdm_kwargs
    ) -> Generator[int, NoneType, NoneType]

def check_is_initialized

    (self)

def get_elapsed_time

    (self) -> float

return the elapsed time in seconds since the start of training

def training_status

    (self) -> dict[str, int | float]

status of elapsed time, samples, batches, epochs, and checkpoints

def batch_update

    (self, samples: int | None, metrics: dict | None = None, **kwargs)

call this at the end of every batch. Pass samples (the number of samples
in the batch), or None to have it inferred from the batch size, and pass
any other metrics as kwargs or as a dict via metrics

This function will:

-   update internal counters
-   run evals as needed (based on the intervals passed)
-   log all metrics and training status
-   save a checkpoint as needed (based on the checkpoint interval)
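
The per-batch bookkeeping above is what replaces hand-written batch_idx % n == 0 checks. A minimal sketch, assuming every interval has already been converted to a whole number of batches; IntervalScheduler and its fields are hypothetical names, and trnbl's real batch_update also logs metrics and the training status, which is omitted here.

```python
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class IntervalScheduler:
    """Hypothetical per-batch bookkeeping; a sketch, not trnbl's TrainingManager."""

    batch_size: int
    # (interval in batches, callback) pairs, e.g. produced by converting "0.5 epochs"
    tasks: list[tuple[int, Callable[[int], None]]] = field(default_factory=list)
    batches: int = 0
    samples: int = 0

    def batch_update(self, samples: Optional[int] = None) -> None:
        # update counters, inferring the sample count from the batch size if needed
        self.batches += 1
        self.samples += samples if samples is not None else self.batch_size
        # run every task whose interval evenly divides the batch count
        for interval, callback in self.tasks:
            if self.batches % interval == 0:
                callback(self.batches)


eval_batches: list[int] = []
checkpoint_batches: list[int] = []
scheduler = IntervalScheduler(
    batch_size=32,
    tasks=[(25, eval_batches.append), (100, checkpoint_batches.append)],
)
for _ in range(200):
    scheduler.batch_update()

print(eval_batches)        # [25, 50, 75, 100, 125, 150, 175, 200]
print(checkpoint_batches)  # [100, 200]
```

Converting intervals to batches once, up front, keeps the per-batch check down to a single modulo per task.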

def epoch_update

    (self)

call this at the end of every epoch. This function will log the
completion of the epoch and update the epoch counter

trnbl.loggers

Submodules

-   local
-   base
-   multi
-   tensorboard
-   wandb

trnbl.loggers.base

API Documentation

-   GPU_UTILS_AVAILABLE
-   PSUTIL_AVAILABLE
-   VOWELS
-   CONSONANTS
-   rand_syllabic_string
-   LoggerSpinner
-   TrainingLoggerBase

-   GPU_UTILS_AVAILABLE: bool = True

-   PSUTIL_AVAILABLE: bool = True

-   VOWELS: str = 'aeiou'

-   CONSONANTS: str = 'bcdfghjklmnpqrstvwxyz'

def rand_syllabic_string

    (length: int = 6) -> str

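Generate a random string of alternating consonants and vowels to use as
a unique identifier

With 21 consonants and 5 vowels, each consonant-vowel pair has
21 × 5 = 105 ≈ 10^2 possibilities, so a string of length 2n has about
10^{2n} possible values

The default length of 6 gives about 10^6 possible strings
(105^3 = 1,157,625)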
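
A sketch of the same idea using the VOWELS and CONSONANTS constants listed above. syllabic_string is a hypothetical stand-in, not trnbl's rand_syllabic_string. Starting with a consonant is an assumption, and an explicit random.Random is accepted so the output can be reproduced.

```python
import random
from typing import Optional

VOWELS: str = "aeiou"
CONSONANTS: str = "bcdfghjklmnpqrstvwxyz"


def syllabic_string(length: int = 6, rng: Optional[random.Random] = None) -> str:
    """Alternate consonants and vowels, starting with a consonant (an assumption)."""
    rng = rng or random.Random()
    return "".join(
        rng.choice(CONSONANTS if i % 2 == 0 else VOWELS) for i in range(length)
    )


# number of distinct strings of length 6: three consonant-vowel pairs
n_possible = (len(CONSONANTS) * len(VOWELS)) ** 3
print(n_possible)  # 1157625
print(syllabic_string(6, random.Random(0)))
```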

class LoggerSpinner(muutils.spinner.Spinner):

See Spinner for parameters. Catches update_value and passes the value
to the attached logger (any TrainingLoggerBase)

LoggerSpinner

    (*args, logger: trnbl.loggers.base.TrainingLoggerBase, **kwargs)

-   logger: trnbl.loggers.base.TrainingLoggerBase

def update_value

    (self, value: Any) -> None

update the value of the spinner and log it

Inherited Members

-   config
-   format_string_when_updated
-   update_interval
-   message
-   current_value
-   format_string
-   output_stream
-   start_time
-   stop_spinner
-   spinner_thread
-   value_changed
-   term_width
-   state
-   spin
-   start
-   stop

trnbl.loggers.local

Submodules

-   build_dist
-   html_frontend
-   locallogger
-   start_server

API Documentation

-   FilePaths
-   LocalLogger

class FilePaths:

-   TRAIN_CONFIG: pathlib.Path = Path('config.json')

-   LOGGER_META: pathlib.Path = Path('meta.json')

-   TRAIN_CONFIG_YML: pathlib.Path = Path('config.yml')

-   LOGGER_META_YML: pathlib.Path = Path('meta.yml')

-   ARTIFACTS: pathlib.Path = Path('artifacts.jsonl')

-   METRICS: pathlib.Path = Path('metrics.jsonl')

-   LOG: pathlib.Path = Path('log.jsonl')

-   ERROR_FILE: pathlib.Path = Path('ERROR.txt')

-   RUNS_MANIFEST: pathlib.Path = Path('runs.jsonl')

-   RUNS_DIR: pathlib.Path = Path('runs')

-   HTML_INDEX: pathlib.Path = Path('index.html')

-   START_SERVER: pathlib.Path = Path('start_server.py')
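
Because the metrics and log files are JSONL (one JSON object per line), they can be read back with the standard library alone. The sketch below assumes a run directory that contains the METRICS file named above. The directory layout under base_path and the keys inside each record are assumptions, so the example writes its own sample file into a temporary directory.

```python
import json
import tempfile
from pathlib import Path

METRICS = Path("metrics.jsonl")  # file name taken from FilePaths.METRICS


def read_jsonl(path: Path) -> list[dict]:
    """Parse a JSON-lines file into a list of dicts, skipping blank lines."""
    with path.open("r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    run_dir = Path(tmp)  # hypothetical stand-in for a real run directory
    # write two sample records, one JSON object per line
    with (run_dir / METRICS).open("w", encoding="utf-8") as f:
        for record in [{"train/loss": 0.9, "batches": 1}, {"train/loss": 0.7, "batches": 2}]:
            f.write(json.dumps(record) + "\n")
    records = read_jsonl(run_dir / METRICS)

losses = [r["train/loss"] for r in records]
print(losses)  # [0.9, 0.7]
```

An append-only, line-per-record format also means a crashed run still leaves every completed line readable.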

class LocalLogger(trnbl.loggers.base.TrainingLoggerBase):

Logger that writes logs, metrics, and artifacts to local json and jsonl
files, which can be viewed in a simple web interface

LocalLogger

    (
        project: str,
        metric_names: list[str],
        train_config: dict,
        group: str = '',
        base_path: str | pathlib.Path = Path('trnbl-logs'),
        memusage_as_metrics: bool = True,
        console_msg_prefix: str = '# '
    )

-   log_list: list[dict]

-   metrics_list: list[dict]

-   artifacts_list: list[dict]

-   train_config: dict

-   project: str

-   group: str

-   group_str: str

-   base_path: pathlib.Path

-   console_msg_prefix: str

-   run_init_timestamp: datetime.datetime

-   run_id: str

-   project_path: pathlib.Path

-   log_file: _io.TextIOWrapper

-   metrics_file: _io.TextIOWrapper

-   artifacts_file: _io.TextIOWrapper

-   metric_names: list[str]

-   logger_meta: dict

-   syllabic_id: str

def get_timestamp

    (self) -> str

def debug

    (self, message: str, **kwargs) -> None

log a debug message

def message

    (self, message: str, **kwargs) -> None

log a progress message

def warning

    (self, message: str, **kwargs) -> None

log a warning message

def error

    (self, message: str, **kwargs) -> None

log an error message

def metrics

    (self, data: dict[str, typing.Any]) -> None

log a dictionary of metrics

def artifact

    (
        self,
        path: pathlib.Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None
    ) -> None

log an artifact from a file

-   url: str

Get the URL for the current logging run

-   run_path: pathlib.Path

Get the path to the current logging run

def flush

    (self) -> None

Flush the logger

def finish

    (self) -> None

Finish logging

Inherited Members

-   get_mem_usage
-   spinner_task

trnbl.loggers.local.build_dist

API Documentation

-   get_remote
-   build_dist
-   main

def get_remote

    (
        path_or_url: str,
        download_remote: bool = False,
        get_bytes: bool = False,
        allow_remote_fail: bool = True
    ) -> str | bytes | None

gets a resource from a path or url

-   returns a string by default, or bytes if get_bytes is True
-   returns None if it's from the web and download_remote is False

Parameters:

-   path_or_url : str location of the resource. if it starts with http,
    it is considered a url
-   download_remote : bool whether to download the resource if it is a
    url (defaults to False)
-   get_bytes : bool whether to return the resource as bytes (defaults
    to False)
-   allow_remote_fail : bool if a remote resource fails to download,
    return None. if this is False, raise an exception (defaults to True)

Raises:

-   requests.HTTPError if the remote resource returns an error, and
    allow_remote_fail is False

Returns:

-   str|bytes|None

def build_dist

    (
        path: pathlib.Path,
        minify: bool = True,
        download_remote: bool = True
    ) -> str

Build a single-file HTML page from a folder

partially from
https://stackoverflow.com/questions/44646481/merging-js-css-html-into-single-html
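
The core of merging a folder into one HTML file is replacing local script and stylesheet references with inline copies of those files. The sketch below is a simplified illustration, not build_dist itself: inline_local_assets is a hypothetical name, the regular expressions assume double-quoted attributes in a fixed order, and remote (http...) references are left untouched instead of downloaded. It also skips minification.

```python
import re
import tempfile
from pathlib import Path

SCRIPT_TAG = re.compile(r'<script\s+src="([^"]+)"\s*>\s*</script>')
STYLE_TAG = re.compile(r'<link\s+rel="stylesheet"\s+href="([^"]+)"\s*/?>')


def inline_local_assets(index_html: Path) -> str:
    """Replace local <script src> and stylesheet <link> tags with inline copies."""
    base = index_html.parent
    html = index_html.read_text(encoding="utf-8")

    def inline_script(match: re.Match) -> str:
        src = match.group(1)
        if src.startswith("http"):
            return match.group(0)  # leave remote scripts as they are
        return "<script>\n" + (base / src).read_text(encoding="utf-8") + "\n</script>"

    def inline_style(match: re.Match) -> str:
        href = match.group(1)
        if href.startswith("http"):
            return match.group(0)  # leave remote stylesheets as they are
        return "<style>\n" + (base / href).read_text(encoding="utf-8") + "\n</style>"

    html = SCRIPT_TAG.sub(inline_script, html)
    return STYLE_TAG.sub(inline_style, html)


with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp)
    (folder / "app.js").write_text("console.log('hi');", encoding="utf-8")
    (folder / "style.css").write_text("body { margin: 0; }", encoding="utf-8")
    (folder / "index.html").write_text(
        '<html><head><link rel="stylesheet" href="style.css">'
        '<script src="app.js"></script></head><body></body></html>',
        encoding="utf-8",
    )
    single_file = inline_local_assets(folder / "index.html")

print(single_file)
```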

def main

    () -> None

trnbl.loggers.local.html_frontend

API Documentation

-   get_html_frontend

def get_html_frontend

    () -> str

trnbl.loggers.local.locallogger

API Documentation

-   FilePaths
-   LocalLogger

class FilePaths:

View Source on GitHub

-   TRAIN_CONFIG: pathlib.Path = WindowsPath('config.json')

-   LOGGER_META: pathlib.Path = WindowsPath('meta.json')

-   TRAIN_CONFIG_YML: pathlib.Path = WindowsPath('config.yml')

-   LOGGER_META_YML: pathlib.Path = WindowsPath('meta.yml')

-   ARTIFACTS: pathlib.Path = WindowsPath('artifacts.jsonl')

-   METRICS: pathlib.Path = WindowsPath('metrics.jsonl')

-   LOG: pathlib.Path = WindowsPath('log.jsonl')

-   ERROR_FILE: pathlib.Path = WindowsPath('ERROR.txt')

-   RUNS_MANIFEST: pathlib.Path = WindowsPath('runs.jsonl')

-   RUNS_DIR: pathlib.Path = WindowsPath('runs')

-   HTML_INDEX: pathlib.Path = WindowsPath('index.html')

-   START_SERVER: pathlib.Path = WindowsPath('start_server.py')

class LocalLogger(trnbl.loggers.base.TrainingLoggerBase):

View Source on GitHub

Base class for training loggers

LocalLogger

    (
        project: str,
        metric_names: list[str],
        train_config: dict,
        group: str = '',
        base_path: str | pathlib.Path = WindowsPath('trnbl-logs'),
        memusage_as_metrics: bool = True,
        console_msg_prefix: str = '# '
    )

View Source on GitHub

-   log_list: list[dict]

-   metrics_list: list[dict]

-   artifacts_list: list[dict]

-   train_config: dict

-   project: str

-   group: str

-   group_str: str

-   base_path: pathlib.Path

-   console_msg_prefix: str

-   run_init_timestamp: datetime.datetime

-   run_id: str

-   project_path: pathlib.Path

-   log_file: _io.TextIOWrapper

-   metrics_file: _io.TextIOWrapper

-   artifacts_file: _io.TextIOWrapper

-   metric_names: list[str]

-   logger_meta: dict

-   syllabic_id: str

View Source on GitHub

def get_timestamp

    (self) -> str

View Source on GitHub

def debug

    (self, message: str, **kwargs) -> None

View Source on GitHub

log a debug message

def message

    (self, message: str, **kwargs) -> None

View Source on GitHub

log a progress message

def warning(self, message: str, **kwargs) -> None

log a warning message

def error(self, message: str, **kwargs) -> None

log an error message

def metrics(self, data: dict[str, typing.Any]) -> None

log a dictionary of metrics
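The `.jsonl` files listed in FilePaths (such as METRICS, `metrics.jsonl`) suggest the JSON Lines format: one JSON object per line. A minimal standard-library sketch of appending a metrics dictionary in that format and reading it back. The helpers `append_jsonl` and `read_jsonl` are hypothetical, and the exact fields LocalLogger writes per line (for example, whether it adds a timestamp) are an assumption not covered on this page.

```python
import json
import tempfile
from pathlib import Path
from typing import Any


def append_jsonl(path: Path, record: dict[str, Any]) -> None:
    """Append one record as a single JSON line (hypothetical helper)."""
    with path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")


def read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Read every non-empty line of a JSON Lines file into a list of dicts."""
    with path.open("r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    metrics_path = Path(tmp) / "metrics.jsonl"
    append_jsonl(metrics_path, {"train/loss": 0.91, "batches": 1})
    append_jsonl(metrics_path, {"train/loss": 0.74, "batches": 2})
    records = read_jsonl(metrics_path)

print(records[-1]["train/loss"])  # prints 0.74
```

Appending one line per call keeps each write independent, so a crash mid-run loses at most the last partial line.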

def artifact(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None

log an artifact from a file

-   url: str

Get the URL for the current logging run

-   run_path: pathlib.Path

Get the path to the current logging run

def flush(self) -> None

Flush the logger

def finish(self) -> None

Finish logging

Inherited Members

-   get_mem_usage
-   spinner_task

trnbl.loggers.local.start_server

Usage: python start_server.py path/to/directory [port]

def start_server(path: str, port: int = 8000) -> None

Starts a server to serve the files in the given path.
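start_server serves the files under a directory over HTTP, presumably so that the local logger's `index.html` (FilePaths.HTML_INDEX) can be opened in a browser. A standard-library sketch of the same idea. `make_file_server` is hypothetical, and binding to 127.0.0.1 is an assumption, not a description of trnbl's start_server.

```python
import functools
import http.server
import tempfile
import threading
import urllib.request
from pathlib import Path


def make_file_server(path: str, port: int = 8000) -> http.server.ThreadingHTTPServer:
    """Build (but do not start) an HTTP server that serves files from `path`."""
    handler = functools.partial(http.server.SimpleHTTPRequestHandler, directory=path)
    return http.server.ThreadingHTTPServer(("127.0.0.1", port), handler)


# Demo: serve a temporary directory on a free port (port=0) from a background thread.
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "index.html").write_text("<h1>trnbl runs</h1>", encoding="utf-8")
    server = make_file_server(tmp, port=0)
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    host, port = server.server_address[:2]
    with urllib.request.urlopen(f"http://{host}:{port}/index.html") as resp:
        body = resp.read().decode("utf-8")
    server.shutdown()  # stop serve_forever in the background thread
    server.server_close()  # release the socket

print(body)  # prints <h1>trnbl runs</h1>
```

For interactive use you would instead call `server.serve_forever()` in the main thread and stop it with Ctrl+C.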

trnbl.loggers.multi

def maybe_flatten(lst: list[typing.Union[~T, list[~T]]]) -> list[~T]

flatten a list if it is nested
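Judging from its signature, maybe_flatten splices nested lists in one level deep and keeps other elements as they are. This is plausibly how MultiLogger.url and MultiLogger.run_path combine per-logger results into one flat list, though that is an inference. A sketch of the behaviour; `flatten_one_level` is a hypothetical name, not trnbl's code.

```python
from typing import TypeVar, Union

T = TypeVar("T")


def flatten_one_level(lst: list[Union[T, list[T]]]) -> list[T]:
    """Flatten one level of nesting: list elements are spliced in, others kept."""
    out: list[T] = []
    for item in lst:
        if isinstance(item, list):
            out.extend(item)
        else:
            out.append(item)
    return out


# e.g. one logger returns a single URL, a nested multi-logger returns a list of them
print(flatten_one_level(["run-a", ["run-b", "run-c"]]))
# prints ['run-a', 'run-b', 'run-c']
```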

class MultiLogger(trnbl.loggers.base.TrainingLoggerBase):

use multiple loggers at once

MultiLogger(loggers: list[trnbl.loggers.base.TrainingLoggerBase])

-   loggers: list[trnbl.loggers.base.TrainingLoggerBase]

def debug(self, message: str, **kwargs) -> None

log a debug message which will be saved, but not printed

def message(self, message: str, **kwargs) -> None

log a progress message

def metrics(self, data: dict[str, typing.Any]) -> None

Log a dictionary of metrics

def artifact(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None

log an artifact from a file

-   url: list[str]

Get the URLs for the current logging runs

-   run_path: list[pathlib.Path]

Get the paths to the current logging runs

def flush(self) -> None

Flush the logger

def finish(self) -> None

Finish logging

Inherited Members

-   warning
-   error
-   get_mem_usage
-   spinner_task
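MultiLogger's docstring only says that it uses multiple loggers at once. A minimal sketch of that fan-out idea, assuming each call is simply forwarded to every wrapped logger in order. `RecordingLogger` and `FanOutLogger` are hypothetical stand-ins, not trnbl classes.

```python
from typing import Any


class RecordingLogger:
    """Hypothetical stand-in for a TrainingLoggerBase subclass; keeps metrics in memory."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.logged: list[dict[str, Any]] = []

    def metrics(self, data: dict[str, Any]) -> None:
        self.logged.append(dict(data))

    def message(self, message: str, **kwargs: Any) -> None:
        print(f"[{self.name}] {message}")


class FanOutLogger:
    """Forward every call to each wrapped logger, in order (sketch of the MultiLogger idea)."""

    def __init__(self, loggers: list[Any]) -> None:
        self.loggers = loggers

    def metrics(self, data: dict[str, Any]) -> None:
        for logger in self.loggers:
            logger.metrics(data)

    def message(self, message: str, **kwargs: Any) -> None:
        for logger in self.loggers:
            logger.message(message, **kwargs)


local, remote = RecordingLogger("local"), RecordingLogger("remote")
multi = FanOutLogger([local, remote])
multi.metrics({"train/loss": 0.5})
multi.message("epoch 1 done")  # printed once per wrapped logger
```

With trnbl itself, the list passed to MultiLogger would hold TrainingLoggerBase instances such as LocalLogger, TensorBoardLogger, or WandbLogger.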

trnbl.loggers.tensorboard

class TensorBoardLogger(trnbl.loggers.base.TrainingLoggerBase):

Training logger that writes to TensorBoard

TensorBoardLogger(
    log_dir: str | pathlib.Path,
    train_config: dict | None = None,
    name: str | None = None,
    **kwargs
)

def debug(self, message: str, **kwargs) -> None

log a debug message which will be saved, but not printed

def message(self, message: str, **kwargs) -> None

log a progress message, which will be printed to stdout

def metrics(self, data: dict[str, typing.Any]) -> None

Log a dictionary of metrics

def artifact(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None

log an artifact from a file

-   url: str

Get the URL for the current logging run

-   run_path: pathlib.Path

Get the path to the current logging run

def flush(self) -> None

Flush the logger

def finish(self) -> None

Finish logging

Inherited Members

-   warning
-   error
-   get_mem_usage
-   spinner_task

trnbl.loggers.wandb

class WandbLogger(trnbl.loggers.base.TrainingLoggerBase):

wrapper around wandb logging for TrainingLoggerBase. Create using
WandbLogger.create(config, project, job_type)

WandbLogger(run: wandb.sdk.wandb_run.Run)

def create(
    cls,
    config: dict,
    project: str | None = None,
    job_type: str | None = None,
    logging_fmt: str = '%(asctime)s %(levelname)s %(message)s',
    logging_datefmt: str = '%Y-%m-%d %H:%M:%S',
    wandb_kwargs: dict | None = None
) -> trnbl.loggers.wandb.WandbLogger
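The logging_fmt and logging_datefmt defaults are standard Python `logging` format strings. Assuming WandbLogger.create hands them to the standard logging module (an assumption; this page does not say how they are used), a formatted log line would look like this:

```python
import logging

# Default values of WandbLogger.create's logging_fmt and logging_datefmt parameters
LOGGING_FMT = "%(asctime)s %(levelname)s %(message)s"
LOGGING_DATEFMT = "%Y-%m-%d %H:%M:%S"

formatter = logging.Formatter(fmt=LOGGING_FMT, datefmt=LOGGING_DATEFMT)
record = logging.LogRecord(
    name="trnbl-demo",
    level=logging.INFO,
    pathname="demo.py",
    lineno=0,
    msg="starting run",
    args=(),
    exc_info=None,
)
line = formatter.format(record)
print(line)  # prints '<YYYY-MM-DD> <HH:MM:SS> INFO starting run' with the current time
```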

def debug(self, message: str, **kwargs) -> None

log a debug message which will be saved, but not printed

def message(self, message: str, **kwargs) -> None

log a progress message, which will be printed to stdout

def metrics(self, data: dict[str, typing.Any]) -> None

Log a dictionary of metrics

def artifact(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None

log an artifact from a file

-   url: str

Get the URL for the current logging run

-   run_path: pathlib.Path

Get the path to the current logging run

def flush(self) -> None

Flush the logger

def finish(self) -> None

Finish logging

Inherited Members

-   warning
-   error
-   get_mem_usage
-   spinner_task

trnbl.training_interval

-   TrainingIntervalUnit = typing.Literal['runs', 'epochs', 'batches', 'samples']

-   WhenIntervalLessThanBatch: muutils.errormode.ErrorMode = ErrorMode.Warn

class IntervalValueError(builtins.UserWarning):

Error for when the interval is less than 1 batch

Inherited Members

-   UserWarning

-   with_traceback

-   add_note

-   args

class TrainingInterval:

A training interval, which can be specified in a few different units.

Attributes:

-   quantity: int|float - the quantity of the interval
-   unit: TrainingIntervalUnit - the unit of the interval, one of
    “runs”, “epochs”, “batches”, or “samples”

Methods:

-   TrainingInterval.from_str(raw: str) -> TrainingInterval -
    parse a string into a TrainingInterval object
-   TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int -
    convert the interval to a raw number of batches
-   TrainingInterval.process_to_batches(interval: CastableToTrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int -
    convert any representation directly to a number of batches
-   TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> TrainingInterval -
    return the current interval with its units switched to batches

Provides methods for reading from a string or tuple, and normalizing to
batches.

TrainingInterval(
    quantity: int | float,
    unit: Literal['runs', 'epochs', 'batches', 'samples']
)

-   quantity: int | float

-   unit: Literal['runs', 'epochs', 'batches', 'samples']

def as_batch_count(
    self,
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None
) -> int

given the batchsize, number of batches per epoch, and number of epochs,
return the interval as a number of batches

Parameters:

-   batchsize: int the size of a batch
-   batches_per_epoch: int the number of batches in an epoch
-   epochs: int|None the number of epochs to run (only required if the
    interval is in “runs”)

Returns:

-   int the interval as a number of batches

Raises:

-   ValueError if the interval is less than 1 batch and
    WhenIntervalLessThanBatch is set to muutils.errormode.ErrorMode.ERROR;
    otherwise, warns or ignores and sets the interval to 1 batch
-   ValueError if the unit is not one of “runs”, “epochs”, “batches”, or
    “samples”
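The unit conversion described above reduces to simple arithmetic. A sketch using a hypothetical helper, `interval_to_batches`, which is not trnbl's implementation: it assumes rounding to the nearest batch and implements only the default warn-and-use-1-batch behaviour, not the ErrorMode switch.

```python
from __future__ import annotations

import warnings

VALID_UNITS = ("runs", "epochs", "batches", "samples")


def interval_to_batches(
    quantity: float,
    unit: str,
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None,
) -> int:
    """Convert an interval in any supported unit to a whole number of batches."""
    if unit == "runs":
        if epochs is None:
            raise ValueError("epochs is required for intervals given in 'runs'")
        raw = quantity * epochs * batches_per_epoch  # fraction of the whole run
    elif unit == "epochs":
        raw = quantity * batches_per_epoch
    elif unit == "batches":
        raw = quantity
    elif unit == "samples":
        raw = quantity / batchsize
    else:
        raise ValueError(f"unit must be one of {VALID_UNITS}, got {unit!r}")

    if raw < 1:
        # default behaviour described above: warn and use 1 batch instead
        warnings.warn(f"interval '{quantity} {unit}' is less than 1 batch; using 1")
        return 1
    return round(raw)  # rounding to the nearest batch is an assumption


print(interval_to_batches(2.5, "epochs", batchsize=32, batches_per_epoch=100))  # 250
print(interval_to_batches(10_000, "samples", batchsize=50, batches_per_epoch=100))  # 200
print(interval_to_batches(0.1, "runs", batchsize=32, batches_per_epoch=100, epochs=500))  # 5000
```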

def normalized(
    self,
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None
) -> trnbl.training_interval.TrainingInterval

convert the units of the interval to batches by calling as_batch_count
and setting the unit to “batches”

def from_str(cls, raw: str) -> trnbl.training_interval.TrainingInterval

parse a string into a TrainingInterval object

Examples:

    >>> TrainingInterval.from_str("5 epochs")
    TrainingInterval(5, 'epochs')
    >>> TrainingInterval.from_str("100 batches")
    TrainingInterval(100, 'batches')
    >>> TrainingInterval.from_str("0.1 runs")
    TrainingInterval(0.1, 'runs')
    >>> TrainingInterval.from_str("1/5 runs")
    TrainingInterval(0.2, 'runs')
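The string formats shown above, plus the "10k samples" shorthand from the trnbl overview, can be parsed with a few lines of code. `parse_interval` is a hypothetical sketch: it returns a plain (quantity, unit) tuple instead of a TrainingInterval, and its handling of fractions and the `k` suffix is an assumption about from_str, not a copy of it.

```python
from __future__ import annotations

VALID_UNITS = ("runs", "epochs", "batches", "samples")


def parse_interval(raw: str) -> tuple[int | float, str]:
    """Parse strings like '5 epochs', '1/5 runs' or '10k samples' (sketch)."""
    parts = raw.strip().split()
    if len(parts) != 2:
        raise ValueError(f"expected '<quantity> <unit>', got {raw!r}")
    quantity_str, unit = parts
    if unit not in VALID_UNITS:
        raise ValueError(f"unit must be one of {VALID_UNITS}, got {unit!r}")

    multiplier = 1
    if quantity_str.lower().endswith("k"):  # '10k' means 10 * 1000
        multiplier = 1_000
        quantity_str = quantity_str[:-1]

    quantity: int | float
    if "/" in quantity_str:  # fractions such as '1/5'
        numerator, denominator = quantity_str.split("/")
        quantity = float(numerator) / float(denominator)
    elif any(ch in quantity_str for ch in ".eE"):  # decimals or scientific notation
        quantity = float(quantity_str)
    else:
        quantity = int(quantity_str)
    return quantity * multiplier, unit


print(parse_interval("1/5 runs"))  # (0.2, 'runs')
print(parse_interval("10k samples"))  # (10000, 'samples')
```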

def from_any(cls, *args, **kwargs) -> trnbl.training_interval.TrainingInterval

parse a string or tuple into a TrainingInterval object

def process_to_batches(
    cls,
    interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval],
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None
) -> int

convert directly from any representation (string, tuple, or
TrainingInterval) to a number of batches

-   CastableToTrainingInterval = typing.Union[str, tuple[typing.Union[int, float, str], str], trnbl.training_interval.TrainingInterval]

trnbl.training_manager

-   EvalFunction = typing.Callable[[ForwardRef('torch.nn.Module')], dict]

class TrainingManagerInitError(builtins.Exception):

Exception raised for errors related to TrainingManager initialization.

Inherited Members

-   Exception

-   with_traceback

-   add_note

-   args

def wrapped_iterable(
    sequence: Sequence[~T],
    manager: trnbl.training_manager.TrainingManager,
    is_epoch: bool = False,
    use_tqdm: bool | None = None,
    tqdm_kwargs: dict[str, typing.Any] | None = None
) -> Generator[~T, NoneType, NoneType]

class TrainingManager(typing.Generic[~TLogger]):

context manager for training a model, with logging, evals, and
checkpoints

Parameters:

-   model : torch.nn.Module ref to model being trained - used for saving
    checkpoints
-   dataloader : torch.utils.data.DataLoader ref to dataloader being
    used - used for calculating training progress
-   logger : TrainingLoggerBase logger, which can be local or interface
    with TensorBoard or wandb.
-   epochs_total : int | None number of epochs to train for (defaults
    to None)
-   evals : Iterable[tuple[TrainingInterval | str, EvalFunction]] | None
    list of pairs of (interval, eval_fn) to run evals on the model. See
    TrainingInterval for interval options. (defaults to None)
-   checkpoint_interval : TrainingInterval | str interval at which to
    save model checkpoints (defaults to TrainingInterval(1, "epochs"))
-   print_metrics_interval : TrainingInterval | str interval at which to
    print metrics (defaults to TrainingInterval(0.1, "runs"))
-   save_model : Callable[[torch.nn.Module, Path], None] function to
    save the model (defaults to torch.save)
-   model_save_path : str format string for saving model checkpoints.
    uses _get_format_kwargs for formatting, along with an alias kwarg
    (defaults to
    "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt")
-   model_save_path_special : str format string for saving special model
    checkpoints (final, exception, etc.). uses _get_format_kwargs for
    formatting, along with an alias kwarg (defaults to
    "{run_path}/model.{alias}.pt")

Usage:

    import torch
    from trnbl.training_manager import TrainingManager

    # assumed to be defined elsewhere: model, TRAIN_LOADER, logger,
    # eval_func, optimizer, criterion
    EPOCHS = 500

    with TrainingManager(
        model=model, dataloader=TRAIN_LOADER, logger=logger, epochs_total=EPOCHS,
        evals={
            "1 epochs": eval_func,
            "0.1 runs": lambda model: logger.get_mem_usage(),
        }.items(),
        checkpoint_interval="50 epochs",
    ) as tp:

        # Training loop
        model.train()
        for epoch in range(EPOCHS):
            for inputs, targets in TRAIN_LOADER:
                # the usual
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                # compute accuracy
                accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

                # log metrics
                tp.batch_update(
                    # pass in number of samples in your batch (or it will be inferred from the batch size)
                    samples=len(targets),
                    # any other metrics you compute every loop
                    **{"train/loss": loss.item(), "train/acc": accuracy},
                )
                # batch_update will automatically run evals and save checkpoints as needed

            tp.epoch_update()
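TrainingManager saves the model at the end of the run, or saves a `model.exception.pt` if training raises. That exception-safe behaviour can be sketched with `contextlib`, using the default path format strings from the signature. `save_on_exit` and `fake_save` are hypothetical: the real TrainingManager also handles logging, evals, and checkpoints, and its `_get_format_kwargs` may supply more keys than the `run_path`, `latest_checkpoint`, and `alias` used here.

```python
from __future__ import annotations

import contextlib
from pathlib import Path
from typing import Any, Callable, Iterator

# default format strings from the TrainingManager signature
MODEL_SAVE_PATH = "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt"
MODEL_SAVE_PATH_SPECIAL = "{run_path}/model.{alias}.pt"


@contextlib.contextmanager
def save_on_exit(
    model: Any, run_path: str, save_model: Callable[[Any, Path], None]
) -> Iterator[None]:
    """Save the model as 'final' on success or 'exception' on error, then re-raise."""
    try:
        yield
    except BaseException:
        save_model(model, Path(MODEL_SAVE_PATH_SPECIAL.format(run_path=run_path, alias="exception")))
        raise
    else:
        save_model(model, Path(MODEL_SAVE_PATH_SPECIAL.format(run_path=run_path, alias="final")))


saved: list[Path] = []


def fake_save(model: Any, path: Path) -> None:
    saved.append(path)  # stand-in for torch.save; only records the target path


with save_on_exit(model="dummy", run_path="runs/demo", save_model=fake_save):
    pass  # the training loop would go here

try:
    with save_on_exit(model="dummy", run_path="runs/demo", save_model=fake_save):
        raise RuntimeError("simulated crash")
except RuntimeError:
    pass

print([p.name for p in saved])  # ['model.final.pt', 'model.exception.pt']
print(MODEL_SAVE_PATH.format(run_path="runs/demo", latest_checkpoint=3))
# runs/demo/checkpoints/model.checkpoint-3.pt
```

Re-raising after the save keeps the original traceback visible while still preserving the model state at the point of failure.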

TrainingManager(
    model: torch.nn.modules.module.Module,
    logger: ~TLogger,
    dataloader: torch.utils.data.dataloader.DataLoader | None = None,
    epochs_total: int | None = None,
    save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType] = torch.save,
    evals: Optional[Iterable[tuple[Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval], Callable[[torch.nn.modules.module.Module], dict]]]] = None,
    checkpoint_interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval] = TrainingInterval(quantity=1, unit='epochs'),
    print_metrics_interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval] = TrainingInterval(quantity=0.1, unit='runs'),
    model_save_path: str = '{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt',
    model_save_path_special: str = '{run_path}/model.{alias}.pt'
)

-   start_time: float

-   model: torch.nn.modules.module.Module

-   logger: ~TLogger

-   save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType]

-   model_save_path: str

-   model_save_path_special: str

-   evals: list[tuple[int, typing.Callable[[torch.nn.modules.module.Module], dict]]]

-   checkpoint_interval: int | None

-   print_metrics_interval: int | None

-   epochs: int

-   batches: int

-   samples: int

-   checkpoints: int

-   epochs_total: int | None

-   batches_per_epoch: int | None

-   batch_size: int | None

-   samples_per_epoch: int | None

-   batches_total: int | None

-   samples_total: int | None

-   init_complete: bool

def try_compute_counters(self) -> None

def epoch_loop(
    self,
    epochs: Sequence[int],
    use_tqdm: bool = True,
    **tqdm_kwargs
) -> Generator[int, NoneType, NoneType]

def batch_loop(
    self,
    batches: Sequence[int],
    use_tqdm: bool = False,
    **tqdm_kwargs
) -> Generator[int, NoneType, NoneType]

def check_is_initialized(self)

def get_elapsed_time(self) -> float

return the elapsed time in seconds since the start of training

def training_status(self) -> dict[str, int | float]

status of elapsed time, samples, batches, epochs, and checkpoints

def batch_update(self, samples: int | None, metrics: dict | None = None, **kwargs)

call this at the end of every batch. Pass samples or it will be inferred
from the batch size, and any other metrics as kwargs

This function will:

-   update internal counters
-   run evals as needed (based on the intervals passed)
-   log all metrics and training status
-   save a checkpoint as needed (based on the checkpoint interval)
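The TrainingManager.evals attribute stores each interval already converted to an int number of batches. A sketch of how batch_update might then decide which evals to run, assuming it checks whether the batch counter is a multiple of each interval. `due_evals` is a hypothetical helper, not part of trnbl.

```python
from __future__ import annotations

from typing import Callable

EvalFn = Callable[[object], dict]


def due_evals(batches_done: int, evals: list[tuple[int, EvalFn]]) -> list[EvalFn]:
    """Return the eval functions whose interval (in batches) divides the batch count."""
    return [fn for interval, fn in evals if batches_done % interval == 0]


# intervals already normalized to batches, e.g. "1 epochs" -> 100 with 100 batches/epoch
evals: list[tuple[int, EvalFn]] = [
    (100, lambda model: {"val/acc": 0.9}),
    (25, lambda model: {"mem": 1.0}),
]

fired: dict[int, list[dict]] = {}
for batch in range(1, 201):  # 200 batches = 2 epochs
    results = [fn(None) for fn in due_evals(batch, evals)]
    if results:
        fired[batch] = results

print(sorted(fired))  # [25, 50, 75, 100, 125, 150, 175, 200]
```

Normalizing every interval to batches once, up front, is what lets this per-batch check stay a cheap modulo test regardless of the unit the user originally specified.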

def epoch_update(self)

call this at the end of every epoch. This function will log the
completion of the epoch and update the epoch counter
