docs for trnbl v0.1.0

trnbl – Training Butler

If you train a lot of models, you might often find yourself annoyed at swapping between different loggers and fiddling with a bunch of `if batch_idx % some_number == 0` statements. This package aims to fix that problem.
Firstly, a universal interface to wandb, tensorboard,
and a minimal local logging solution (live demo) is
provided.
Secondly, a TrainingManager class is provided which
handles logging, artifacts, checkpointing, evaluations, exceptions, and
more, with flexibly customizable intervals.
Rather than having to specify all intervals in batches and then change everything manually when you change the batch size, dataset size, or number of epochs, you specify an interval in samples, batches, epochs, or runs. This is computed into the correct number of batches or epochs based on the current dataset and batch size.

- "1/10 runs": 10 times per run
- "2.5 epochs": every 2.5 epochs
- (100, "batches"): every 100 batches
- "10k samples": every 10,000 samples

An evaluation function is passed in a tuple with an interval, takes the model as an argument, and returns the metrics as a dictionary.
Checkpointing is handled automatically, with the interval specified in the same way as for evaluations.
Models are saved at the end of the run; if an exception is raised, a model.exception.pt is saved.
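The conversion from an interval to a batch count can be sketched in a few lines. This is a minimal illustration and not the package's own code: the function name `interval_to_batches` is hypothetical, and both the use of `round` and the clamping to one batch are assumptions about how fractional results are handled.

```python
from typing import Optional, Union


def interval_to_batches(
    quantity: Union[int, float],
    unit: str,
    batchsize: int,
    batches_per_epoch: int,
    epochs: Optional[int] = None,
) -> int:
    """Convert an interval in any supported unit to a whole number of batches."""
    if unit == "batches":
        batches = quantity
    elif unit == "samples":
        batches = quantity / batchsize
    elif unit == "epochs":
        batches = quantity * batches_per_epoch
    elif unit == "runs":
        if epochs is None:
            raise ValueError("an interval in runs needs the total number of epochs")
        batches = quantity * epochs * batches_per_epoch
    else:
        raise ValueError(f"unknown unit: {unit!r}")
    # assumption: round to the nearest batch, and never go below one batch
    return max(1, round(batches))


# 3200 samples at batch size 32 gives 100 batches per epoch; train for 120 epochs
common = dict(batchsize=32, batches_per_epoch=100, epochs=120)
every_tenth_of_run = interval_to_batches(0.1, "runs", **common)
every_2_5_epochs = interval_to_batches(2.5, "epochs", **common)
every_1k_samples = interval_to_batches(1000, "samples", **common)
tiny_interval = interval_to_batches(1, "samples", **common)
```

Changing the batch size or the number of epochs only changes the arguments passed in; the interval specification itself stays the same.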
pip install trnbl

Also see the notebooks/ folder:
- demo_minimal.py for a minimal example with dummy data
- demo.ipynb for an example with all options on the iris dataset
import torch
from torch.utils.data import DataLoader
from trnbl.loggers.local import LocalLogger
from trnbl.training_manager import TrainingManager

# set up your dataset, model, optimizer, etc. as usual
dataloader: DataLoader = DataLoader(my_dataset, batch_size=32)
model: torch.nn.Module = MyModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# set up a logger -- swap seamlessly between wandb, tensorboard, and local logging
logger: LocalLogger = LocalLogger(
    project="iris-demo",
    metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
    train_config=dict(
        model=str(model), optimizer=str(optimizer), criterion=str(criterion)
    ),
)

with TrainingManager(
    # pass your model and logger
    model=model,
    logger=logger,
    evals={
        # pass evaluation functions which take a model, and return a dict of metrics
        "1k samples": my_evaluation_function,
        "0.5 epochs": lambda model: logger.get_mem_usage(),
        "100 batches": my_other_eval_function,
    }.items(),
    checkpoint_interval="1/10 runs",  # will save a checkpoint 10 times per run
) as tr:
    # wrap the loops, and length will be automatically calculated
    # and used to figure out when to run evals, checkpoint, etc
    for epoch in tr.epoch_loop(range(120)):
        for inputs, targets in tr.batch_loop(dataloader):
            # your normal training code
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # compute whatever you want every batch
            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

            # log the metrics
            tr.batch_update(
                samples=len(targets),
                **{"train/loss": loss.item(), "train/acc": accuracy},
            )

    # a `model.final.pt` checkpoint will be saved at the end of the run,
    # or a `model.exception.pt` if something crashes inside the context

LocalLogger

Intended as a minimal logging solution for local runs, when you're too lazy to set up a new wandb project for a quick test, and want to be able to easily read the logs. It logs everything as json or jsonl files, and provides a simple web interface for viewing the data. The web interface allows:
You can view a live demo of the web interface here.
frontend:
deployment:
trnbl
class TrainingInterval:
    A training interval, which can be specified in a few different units.

Attributes:
- quantity: int|float - the quantity of the interval
- unit: TrainingIntervalUnit - the unit of the interval, one of "runs", "epochs", "batches", or "samples"

Methods:
- TrainingInterval.from_str(raw: str) -> TrainingInterval - parse a string into a TrainingInterval object
- TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int - convert the interval to a raw number of batches
- TrainingInterval.process_to_batches(interval: str|TrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int - any representation to a number of batches
- TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> TrainingInterval - current interval, with units switched to batches

Provides methods for reading from a string or tuple, and normalizing to batches.

TrainingInterval(
    quantity: int | float,
    unit: Literal['runs', 'epochs', 'batches', 'samples']
)

quantity: int | float
unit: Literal['runs', 'epochs', 'batches', 'samples']

def as_batch_count(
    self,
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None
) -> int
    given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches

    Parameters:
    - batchsize: int - the size of a batch
    - batches_per_epoch: int - the number of batches in an epoch
    - epochs: int|None - the number of epochs to run (only required if the interval is in "runs")

    Returns:
    - int - the interval as a number of batches

    Raises:
    - ValueError - if the interval is less than 1 batch and trnbl.training_interval.WhenIntervalLessThanBatch is set to muutils.errormode.ErrorMode.ERROR; otherwise, will warn or ignore and set the interval to 1 batch
    - ValueError - if the unit is not one of "runs", "epochs", "batches", or "samples"

def normalized(
    self,
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None
) -> trnbl.training_interval.TrainingInterval
    convert the units of the interval to batches, by calling as_batch_count and setting the unit to "batches"

def from_str(cls, raw: str) -> trnbl.training_interval.TrainingInterval
    parse a string into a TrainingInterval object

    >>> TrainingInterval.from_str("5 epochs")
    TrainingInterval(5, 'epochs')
    >>> TrainingInterval.from_str("100 batches")
    TrainingInterval(100, 'batches')
    >>> TrainingInterval.from_str("0.1 runs")
    TrainingInterval(0.1, 'runs')
    >>> TrainingInterval.from_str("1/5 runs")
    TrainingInterval(0.2, 'runs')

def from_any(cls, *args, **kwargs) -> trnbl.training_interval.TrainingInterval
    parse a string or tuple into a TrainingInterval object

def process_to_batches(
    cls,
    interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval],
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None
) -> int
    directly from any representation to a number of batches
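To make the string forms concrete, here is a rough sketch of how strings such as "1/5 runs" or "10k samples" could be turned into a (quantity, unit) pair. It is not the implementation behind `from_str`: the name `parse_interval` is hypothetical, and treating a trailing "k" as a factor of 1000 is an assumption based on the "10k samples" example in the README.

```python
from fractions import Fraction
from typing import Tuple, Union

UNITS = ("runs", "epochs", "batches", "samples")


def parse_interval(raw: str) -> Tuple[Union[int, float], str]:
    """Parse strings like '5 epochs', '1/5 runs', or '10k samples'."""
    quantity_str, unit = raw.strip().split()
    if unit not in UNITS:
        raise ValueError(f"unknown unit {unit!r}, expected one of {UNITS}")

    # assumption: a trailing 'k' multiplies the quantity by 1000
    multiplier = 1
    if quantity_str.lower().endswith("k"):
        multiplier = 1000
        quantity_str = quantity_str[:-1]

    # Fraction accepts integers, decimals, and 'a/b' fractions alike
    value = Fraction(quantity_str) * multiplier
    quantity = int(value) if value.denominator == 1 else float(value)
    return quantity, unit


parsed = [parse_interval(s) for s in ("5 epochs", "0.1 runs", "1/5 runs", "10k samples")]
```

Using `Fraction` keeps the three numeric spellings on one code path and only converts to `float` when the value is not a whole number.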
TrainingIntervalUnit = typing.Literal['runs', 'epochs', 'batches', 'samples']

class TrainingLoggerBase(abc.ABC):
    Base class for training loggers

def debug(self, message: str, **kwargs) -> None
    log a debug message which will be saved, but not printed

def message(self, message: str, **kwargs) -> None
    log a progress message, which will be printed to stdout

def warning(self, message: str, **kwargs) -> None
    log a warning message, which will be printed to stderr

def error(self, message: str, **kwargs) -> None
    log an error message

def metrics(self, data: dict[str, typing.Any]) -> None
    Log a dictionary of metrics

def artifact(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None
    log an artifact from a file

url: str | list[str]
    Get the URL for the current logging run

run_path: pathlib.Path | list[pathlib.Path]
    Get the path to the current logging run

def flush(self) -> None
    Flush the logger

def finish(self) -> None
    Finish logging

def get_mem_usage(self) -> dict

def spinner_task(self, **kwargs) -> trnbl.loggers.base.LoggerSpinner
    Create a spinner task. kwargs are passed to Spinner.
class TrainingManager(typing.Generic[~TLogger]):
    context manager for training a model, with logging, evals, and checkpoints

    Parameters:
    - model : torch.nn.Module
        ref to model being trained - used for saving checkpoints
    - dataloader : torch.utils.data.DataLoader
        ref to dataloader being used - used for calculating training progress
    - logger : TrainingLoggerBase
        logger, which can be local or interface with wandb.
    - epochs_total : int | None
        number of epochs to train for (defaults to None)
    - evals : Iterable[tuple[TrainingInterval | str, EvalFunction]] | None
        list of pairs of (interval, eval_fn) to run evals on the model. See TrainingInterval for interval options. (defaults to None)
    - checkpoint_interval : TrainingInterval | str
        interval at which to save model checkpoints (defaults to TrainingInterval(1, "epochs"))
    - print_metrics_interval : TrainingInterval | str
        interval at which to print metrics (defaults to TrainingInterval(0.1, "runs"))
    - save_model : Callable[[torch.nn.Module, Path], None]
        function to save the model (defaults to torch.save)
    - model_save_path : str
        format string for saving model checkpoints. uses _get_format_kwargs for formatting, along with an alias kwarg (defaults to "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt")
    - model_save_path_special : str
        format string for saving special model checkpoints (final, exception, etc.). uses _get_format_kwargs for formatting, along with an alias kwarg (defaults to "{run_path}/model.{alias}.pt")

    Usage:

    with TrainingManager(
        model=model, dataloader=TRAIN_LOADER, logger=logger, epochs_total=500,
        evals={
            "1 epochs": eval_func,
            "0.1 runs": lambda model: logger.get_mem_usage(),
        }.items(),
        checkpoint_interval="50 epochs",
    ) as tp:
        # Training loop
        model.train()
        for epoch in range(500):
            for inputs, targets in TRAIN_LOADER:
                # the usual
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                # compute accuracy
                accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

                # log metrics
                tp.batch_update(
                    # pass in number of samples in your batch (or it will be inferred from the batch size)
                    samples=len(targets),
                    # any other metrics you compute every loop
                    **{"train/loss": loss.item(), "train/acc": accuracy},
                )
                # batch_update will automatically run evals and save checkpoints as needed

            tp.epoch_update()

TrainingManager(
    model: torch.nn.modules.module.Module,
    logger: ~TLogger,
    dataloader: torch.utils.data.dataloader.DataLoader | None = None,
    epochs_total: int | None = None,
    save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType] = <function save>,
    evals: Optional[Iterable[tuple[Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval], Callable[[torch.nn.modules.module.Module], dict]]]] = None,
    checkpoint_interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval] = TrainingInterval(quantity=1, unit='epochs'),
    print_metrics_interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval] = TrainingInterval(quantity=0.1, unit='runs'),
    model_save_path: str = '{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt',
    model_save_path_special: str = '{run_path}/model.{alias}.pt'
)

start_time: float
model: torch.nn.modules.module.Module
logger: ~TLogger
save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType]
model_save_path: str
model_save_path_special: str
evals: list[tuple[int, typing.Callable[[torch.nn.modules.module.Module], dict]]]
checkpoint_interval: int | None
print_metrics_interval: int | None
epochs: int
batches: int
samples: int
checkpoints: int
epochs_total: int | None
batches_per_epoch: int | None
batch_size: int | None
samples_per_epoch: int | None
batches_total: int | None
samples_total: int | None
init_complete: bool
def try_compute_counters(self) -> None

def epoch_loop(
    self,
    epochs: Sequence[int],
    use_tqdm: bool = True,
    **tqdm_kwargs
) -> Generator[int, NoneType, NoneType]

def batch_loop(
    self,
    batches: Sequence[int],
    use_tqdm: bool = False,
    **tqdm_kwargs
) -> Generator[int, NoneType, NoneType]

def check_is_initialized(self)

def get_elapsed_time(self) -> float
    return the elapsed time in seconds since the start of training

def training_status(self) -> dict[str, int | float]
    status of elapsed time, samples, batches, epochs, and checkpoints

def batch_update(self, samples: int | None, metrics: dict | None = None, **kwargs)
    call this at the end of every batch. Pass samples or it will be inferred from the batch size, and any other metrics as kwargs

    This function will:
    - update internal counters
    - run evals as needed (based on the intervals passed)
    - log all metrics and training status
    - save a checkpoint as needed (based on the checkpoint interval)

def epoch_update(self)
    call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter
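The default values of model_save_path and model_save_path_special are ordinary Python format strings, so the resulting file names can be previewed with `str.format`. The sketch below copies the two defaults from the constructor signature; the `run_path` value is a made-up example, and the real keyword arguments come from `_get_format_kwargs`, which is not documented here.

```python
from pathlib import Path

# default templates, copied from the TrainingManager signature
MODEL_SAVE_PATH = "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt"
MODEL_SAVE_PATH_SPECIAL = "{run_path}/model.{alias}.pt"

# hypothetical run directory, for illustration only
run_path = "trnbl-logs/iris-demo/runs/example-run"

# a periodic checkpoint, numbered by the checkpoint counter
periodic = Path(MODEL_SAVE_PATH.format(run_path=run_path, latest_checkpoint=3))

# the special checkpoints mentioned in the README
final = Path(MODEL_SAVE_PATH_SPECIAL.format(run_path=run_path, alias="final"))
crashed = Path(MODEL_SAVE_PATH_SPECIAL.format(run_path=run_path, alias="exception"))

print(periodic.name)  # model.checkpoint-3.pt
print(final.name)  # model.final.pt
print(crashed.name)  # model.exception.pt
```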
trnbl.loggers

trnbl.loggers.base

GPU_UTILS_AVAILABLE: bool = True
PSUTIL_AVAILABLE: bool = True
VOWELS: str = 'aeiou'
CONSONANTS: str = 'bcdfghjklmnpqrstvwxyz'

def rand_syllabic_string(length: int = 6) -> str
    Generate a random string of alternating consonants and vowels to use as a unique identifier

    for a length of 2n, there are about 10^{2n} possible strings
    default is 6 characters, which gives about 10^6 possible strings
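The count quoted above follows from the alphabet sizes: 21 consonants times 5 vowels gives 105 possibilities per consonant-vowel pair, which is roughly 10^2. A small sketch of such a generator follows; the name `syllabic_string` and the `rng` parameter are hypothetical, and starting with a consonant is an assumption, since the documentation only says the two classes alternate.

```python
import random
from typing import Optional

VOWELS = "aeiou"
CONSONANTS = "bcdfghjklmnpqrstvwxyz"


def syllabic_string(length: int = 6, rng: Optional[random.Random] = None) -> str:
    """Build a pronounceable id by alternating consonants and vowels."""
    rng = rng or random.Random()
    return "".join(
        # assumption: even positions are consonants, odd positions are vowels
        rng.choice(CONSONANTS if i % 2 == 0 else VOWELS)
        for i in range(length)
    )


example_id = syllabic_string(6, random.Random(0))

# each consonant-vowel pair has 21 * 5 = 105 possibilities
pairs = len(CONSONANTS) * len(VOWELS)
ids_of_length_6 = pairs**3
```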
class LoggerSpinner(muutils.spinner.Spinner):
    see Spinner for parameters. catches update_value and passes it to the LocalLogger

LoggerSpinner(*args, logger: trnbl.loggers.base.TrainingLoggerBase, **kwargs)

logger: trnbl.loggers.base.TrainingLoggerBase

def update_value(self, value: Any) -> None
    update the value of the spinner and log it
trnbl.loggers.local

class FilePaths:

TRAIN_CONFIG: pathlib.Path = WindowsPath('config.json')
LOGGER_META: pathlib.Path = WindowsPath('meta.json')
TRAIN_CONFIG_YML: pathlib.Path = WindowsPath('config.yml')
LOGGER_META_YML: pathlib.Path = WindowsPath('meta.yml')
ARTIFACTS: pathlib.Path = WindowsPath('artifacts.jsonl')
METRICS: pathlib.Path = WindowsPath('metrics.jsonl')
LOG: pathlib.Path = WindowsPath('log.jsonl')
ERROR_FILE: pathlib.Path = WindowsPath('ERROR.txt')
RUNS_MANIFEST: pathlib.Path = WindowsPath('runs.jsonl')
RUNS_DIR: pathlib.Path = WindowsPath('runs')
HTML_INDEX: pathlib.Path = WindowsPath('index.html')
START_SERVER: pathlib.Path = WindowsPath('start_server.py')

class LocalLogger(trnbl.loggers.base.TrainingLoggerBase):
    Base class for training loggers

LocalLogger(
    project: str,
    metric_names: list[str],
    train_config: dict,
    group: str = '',
    base_path: str | pathlib.Path = WindowsPath('trnbl-logs'),
    memusage_as_metrics: bool = True,
    console_msg_prefix: str = '# '
)

log_list: list[dict]
metrics_list: list[dict]
artifacts_list: list[dict]
train_config: dict
project: str
group: str
group_str: str
base_path: pathlib.Path
console_msg_prefix: str
run_init_timestamp: datetime.datetime
run_id: str
project_path: pathlib.Path
log_file: _io.TextIOWrapper
metrics_file: _io.TextIOWrapper
artifacts_file: _io.TextIOWrapper
metric_names: list[str]
logger_meta: dict
syllabic_id: str

def get_timestamp(self) -> str

def debug(self, message: str, **kwargs) -> None
    log a debug message

def message(self, message: str, **kwargs) -> None
    log a progress message

def warning(self, message: str, **kwargs) -> None
    log a warning message

def error(self, message: str, **kwargs) -> None
    log an error message

def metrics(self, data: dict[str, typing.Any]) -> None
    log a dictionary of metrics

def artifact(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None
    log an artifact from a file

url: str
    Get the URL for the current logging run

run_path: pathlib.Path
    Get the path to the current logging run

def flush(self) -> None
    Flush the logger

def finish(self) -> None
    Finish logging
trnbl.loggers.local.build_dist

def get_remote(
    path_or_url: str,
    download_remote: bool = False,
    get_bytes: bool = False,
    allow_remote_fail: bool = True
) -> str | bytes | None
    gets a resource from a path or url

    - returns bytes if get_bytes is True
    - returns None if it is from the web and download_remote is False

    Parameters:
    - path_or_url : str
        location of the resource. if it starts with http, it is considered a url
    - download_remote : bool
        whether to download the resource if it is a url (defaults to False)
    - get_bytes : bool
        whether to return the resource as bytes (defaults to False)
    - allow_remote_fail : bool
        if a remote resource fails to download, return None. if this is False, raise an exception (defaults to True)

    Raises:
    - requests.HTTPError
        if the remote resource returns an error, and allow_remote_fail is False

    Returns:
    - str|bytes|None

def build_dist(
    path: pathlib.Path,
    minify: bool = True,
    download_remote: bool = True
) -> str
    Build a single file html from a folder

    partially from https://stackoverflow.com/questions/44646481/merging-js-css-html-into-single-html

def main() -> None
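The path-or-URL behaviour described for get_remote can be illustrated with a simplified stand-in. This sketch is not the package's function: `get_resource` is a hypothetical name, it uses `urllib.request` from the standard library where the documented function raises `requests.HTTPError`, and it leaves out the allow_remote_fail handling.

```python
import tempfile
import urllib.request
from pathlib import Path
from typing import Optional, Union


def get_resource(
    path_or_url: str,
    download_remote: bool = False,
    get_bytes: bool = False,
) -> Optional[Union[str, bytes]]:
    """Read a local file, or optionally download a URL."""
    if path_or_url.startswith("http"):
        if not download_remote:
            # remote resources are skipped unless downloading is requested
            return None
        with urllib.request.urlopen(path_or_url) as response:
            data = response.read()
        return data if get_bytes else data.decode("utf-8")

    path = Path(path_or_url)
    return path.read_bytes() if get_bytes else path.read_text(encoding="utf-8")


# local files are read as text or bytes; nothing here touches the network
with tempfile.TemporaryDirectory() as tmp:
    demo = Path(tmp) / "style.css"
    demo.write_text("body { margin: 0; }", encoding="utf-8")
    as_text = get_resource(str(demo))
    as_bytes = get_resource(str(demo), get_bytes=True)

skipped = get_resource("https://example.com/app.js")
```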
trnbl.loggers.local.html_frontend

def get_html_frontend() -> str

trnbl.loggers.local.locallogger

This module provides the same FilePaths and LocalLogger classes documented under trnbl.loggers.local.
trnbl.loggers.local.start_server

Usage: python start_server.py path/to/directory [port]

def start_server(path: str, port: int = 8000) -> None
    Starts a server to serve the files in the given path.
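A static file server of this kind can be put together from the Python standard library alone. The sketch below is one possible approach and an assumption about the general technique, not the contents of start_server.py; `make_static_server` and `serve` are hypothetical names.

```python
import functools
import http.server
import socketserver


def make_static_server(path: str, port: int = 8000) -> socketserver.TCPServer:
    """Create (but do not start) a server for the files under `path`."""
    handler = functools.partial(
        http.server.SimpleHTTPRequestHandler, directory=path
    )
    # bind to localhost only; port 0 asks the OS for any free port
    return socketserver.TCPServer(("127.0.0.1", port), handler)


def serve(path: str, port: int = 8000) -> None:
    """Serve files until interrupted with Ctrl+C."""
    with make_static_server(path, port) as httpd:
        print(f"serving {path} at http://127.0.0.1:{httpd.server_address[1]}")
        httpd.serve_forever()
```

Calling `serve("trnbl-logs")` would block and serve that directory; separating construction from `serve_forever` keeps the server easy to create and close in tests.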
trnbl.loggers.multi

def maybe_flatten(lst: list[typing.Union[~T, list[~T]]]) -> list[~T]
    flatten a list if it is nested

class MultiLogger(trnbl.loggers.base.TrainingLoggerBase):
    use multiple loggers at once

MultiLogger(loggers: list[trnbl.loggers.base.TrainingLoggerBase])

loggers: list[trnbl.loggers.base.TrainingLoggerBase]

def debug(self, message: str, **kwargs) -> None
    log a debug message which will be saved, but not printed

def message(self, message: str, **kwargs) -> None
    log a progress message

def metrics(self, data: dict[str, typing.Any]) -> None
    Log a dictionary of metrics

def artifact(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None
    log an artifact from a file

url: list[str]
    Get the URL for the current logging run

run_path: list[pathlib.Path]
    Get the paths to the current logging run

def flush(self) -> None
    Flush the logger

def finish(self) -> None
    Finish logging
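The idea behind using several loggers at once is a plain fan-out: every call is forwarded to each wrapped logger, and properties such as url come back as lists, which matches the list-typed url and run_path above. The classes below are hypothetical stand-ins written for illustration; they do not import or reproduce the package's MultiLogger.

```python
from typing import Any, Dict, List


class ListLogger:
    """Hypothetical stand-in for a real logger: keeps metrics in memory."""

    def __init__(self, url: str) -> None:
        self.url = url
        self.records: List[Dict[str, Any]] = []

    def metrics(self, data: Dict[str, Any]) -> None:
        self.records.append(data)


class FanOutLogger:
    """Forward every call to all wrapped loggers."""

    def __init__(self, loggers: List[ListLogger]) -> None:
        self.loggers = loggers

    def metrics(self, data: Dict[str, Any]) -> None:
        for logger in self.loggers:
            logger.metrics(data)

    @property
    def url(self) -> List[str]:
        # one URL per wrapped logger, in the order they were passed in
        return [logger.url for logger in self.loggers]


first = ListLogger("file://local-run")
second = ListLogger("https://example.invalid/remote-run")
fan = FanOutLogger([first, second])
fan.metrics({"train/loss": 0.5})
```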
trnbl.loggers.tensorboard

class TensorBoardLogger(trnbl.loggers.base.TrainingLoggerBase):
    Base class for training loggers

TensorBoardLogger(
    log_dir: str | pathlib.Path,
    train_config: dict | None = None,
    name: str | None = None,
    **kwargs
)

def debug(self, message: str, **kwargs) -> None
    log a debug message which will be saved, but not printed

def message(self, message: str, **kwargs) -> None
    log a progress message, which will be printed to stdout

def metrics(self, data: dict[str, typing.Any]) -> None
    Log a dictionary of metrics

def artifact(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None
    log an artifact from a file

url: str
    Get the URL for the current logging run

run_path: pathlib.Path
    Get the path to the current logging run

def flush(self) -> None
    Flush the logger

def finish(self) -> None
    Finish logging
trnbl.loggers.wandb

class WandbLogger(trnbl.loggers.base.TrainingLoggerBase):
    wrapper around wandb logging for TrainingLoggerBase. create using WandbLogger.create(config, project, job_type)

WandbLogger(run: wandb.sdk.wandb_run.Run)

def create(
    cls,
    config: dict,
    project: str | None = None,
    job_type: str | None = None,
    logging_fmt: str = '%(asctime)s %(levelname)s %(message)s',
    logging_datefmt: str = '%Y-%m-%d %H:%M:%S',
    wandb_kwargs: dict | None = None
) -> trnbl.loggers.wandb.WandbLogger

def debug(self, message: str, **kwargs) -> None
    log a debug message which will be saved, but not printed

def message(self, message: str, **kwargs) -> None
    log a progress message, which will be printed to stdout

def metrics(self, data: dict[str, typing.Any]) -> None
    Log a dictionary of metrics

def artifact(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None
    log an artifact from a file

url: str
    Get the URL for the current logging run

run_path: pathlib.Path
    Get the path to the current logging run

def flush(self) -> None
    Flush the logger

def finish(self) -> None
    Finish logging
trnbl.training_interval

TrainingIntervalUnit = typing.Literal['runs', 'epochs', 'batches', 'samples']
WhenIntervalLessThanBatch: muutils.errormode.ErrorMode = ErrorMode.Warn

class IntervalValueError(builtins.UserWarning):
    Error for when the interval is less than 1 batch
CastableToTrainingInterval = typing.Union[str, tuple[typing.Union[int, float, str], str], trnbl.training_interval.TrainingInterval]

trnbl.training_manager

EvalFunction = typing.Callable[[ForwardRef('torch.nn.Module')], dict]

class TrainingManagerInitError(builtins.Exception):
    Common base class for all non-exit exceptions.

def wrapped_iterable(
    sequence: Sequence[~T],
    manager: trnbl.training_manager.TrainingManager,
    is_epoch: bool = False,
    use_tqdm: bool | None = None,
    tqdm_kwargs: dict[str, typing.Any] | None = None
) -> Generator[~T, NoneType, NoneType]

class TrainingManager(typing.Generic[~TLogger]):
    context manager for training a model, with logging, evals, and checkpoints
- model : torch.nn.Module
  ref to model being trained, used for saving checkpoints
- dataloader : torch.utils.data.DataLoader
  ref to dataloader being used, used for calculating training progress
- logger : TrainingLoggerBase
  logger, which can be local or interface with wandb.
- epochs_total : int | None
  number of epochs to train for (defaults to None)
- evals : Iterable[tuple[TrainingInterval | str, EvalFunction]] | None
  list of pairs of (interval, eval_fn) to run evals on the model. See TrainingInterval for interval options. (defaults to None)
- checkpoint_interval : TrainingInterval | str
  interval at which to save model checkpoints (defaults to TrainingInterval(1, "epochs"))
- print_metrics_interval : TrainingInterval | str
  interval at which to print metrics (defaults to TrainingInterval(0.1, "runs"))
- save_model : Callable[[torch.nn.Module, Path], None]
  function to save the model (defaults to torch.save)
- model_save_path : str
  format string for saving model checkpoints. Uses _get_format_kwargs for formatting, along with an alias kwarg (defaults to "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt")
- model_save_path_special : str
  format string for saving special model checkpoints (final, exception, etc.). Uses _get_format_kwargs for formatting, along with an alias kwarg (defaults to "{run_path}/model.{alias}.pt")
with TrainingManager(
    model=model, dataloader=TRAIN_LOADER, logger=logger, epochs_total=500,
    evals={
        "1 epochs": eval_func,
        "0.1 runs": lambda model: logger.get_mem_usage(),
    }.items(),
    checkpoint_interval="50 epochs",
) as tp:
    # Training loop
    model.train()
    # number of epochs must match epochs_total above
    for epoch in range(500):
        for inputs, targets in TRAIN_LOADER:
            # the usual
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            # compute accuracy
            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)
            # log metrics
            tp.batch_update(
                # pass in number of samples in your batch (or it will be inferred from the batch size)
                samples=len(targets),
                # any other metrics you compute every loop
                **{"train/loss": loss.item(), "train/acc": accuracy},
            )
            # batch_update will automatically run evals and save checkpoints as needed
        tp.epoch_update()
TrainingManager(
    model: torch.nn.modules.module.Module,
    logger: ~TLogger,
    dataloader: torch.utils.data.dataloader.DataLoader | None = None,
    epochs_total: int | None = None,
    save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType] = <function save>,
    evals: Optional[Iterable[tuple[Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval], Callable[[torch.nn.modules.module.Module], dict]]]] = None,
    checkpoint_interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval] = TrainingInterval(quantity=1, unit='epochs'),
    print_metrics_interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval] = TrainingInterval(quantity=0.1, unit='runs'),
    model_save_path: str = '{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt',
    model_save_path_special: str = '{run_path}/model.{alias}.pt'
)
start_time: float
model: torch.nn.modules.module.Module
logger: ~TLogger
save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType]
model_save_path: str
model_save_path_special: str
evals: list[tuple[int, typing.Callable[[torch.nn.modules.module.Module], dict]]]
checkpoint_interval: int | None
print_metrics_interval: int | None
epochs: int
batches: int
samples: int
checkpoints: int
epochs_total: int | None
batches_per_epoch: int | None
batch_size: int | None
samples_per_epoch: int | None
batches_total: int | None
samples_total: int | None
init_complete: bool
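The `model_save_path` and `model_save_path_special` attributes are ordinary Python format strings, so their expansion can be illustrated with `str.format`. This is only a sketch: the values in `format_kwargs` are made up, and the full set of keys that `_get_format_kwargs` supplies is not listed in these docs.

```python
from pathlib import Path

# default templates, copied from the constructor signature
MODEL_SAVE_PATH = "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt"
MODEL_SAVE_PATH_SPECIAL = "{run_path}/model.{alias}.pt"

# assumed values; in the library these would come from _get_format_kwargs
format_kwargs = {"run_path": "runs/demo", "latest_checkpoint": 3}

checkpoint_path = Path(MODEL_SAVE_PATH.format(**format_kwargs))
# the special template additionally receives an alias such as "final" or "exception"
final_path = Path(MODEL_SAVE_PATH_SPECIAL.format(**format_kwargs, alias="final"))
exception_path = Path(MODEL_SAVE_PATH_SPECIAL.format(**format_kwargs, alias="exception"))

print(checkpoint_path.as_posix())  # runs/demo/checkpoints/model.checkpoint-3.pt
print(final_path.as_posix())       # runs/demo/model.final.pt
print(exception_path.as_posix())   # runs/demo/model.exception.pt
```

Because `str.format` ignores unused keyword arguments, the same dictionary can feed both templates.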
def try_compute_counters(self) -> None
def epoch_loop(
    self,
    epochs: Sequence[int],
    use_tqdm: bool = True,
    **tqdm_kwargs
) -> Generator[int, NoneType, NoneType]
def batch_loop(
    self,
    batches: Sequence[int],
    use_tqdm: bool = False,
    **tqdm_kwargs
) -> Generator[int, NoneType, NoneType]
def check_is_initialized(self)
def get_elapsed_time(self) -> float
return the elapsed time in seconds since the start of training
def training_status(self) -> dict[str, int | float]
status of elapsed time, samples, batches, epochs, and checkpoints
def batch_update(self, samples: int | None, metrics: dict | None = None, **kwargs)
Call this at the end of every batch. Pass samples, or pass None to have it inferred from the batch size, and pass any other metrics as kwargs.
This function will:
- update internal counters
- run evals as needed (based on the intervals passed)
- log all metrics and training status
- save a checkpoint as needed (based on the checkpoint interval)
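The four steps listed above can be sketched with plain counters and modulo checks. `ToyManager` is a hypothetical, much-reduced stand-in written for illustration; it assumes all intervals have already been converted to whole batches and it only counts checkpoints instead of writing files.

```python
from typing import Callable, Optional


class ToyManager:
    """Hypothetical stand-in for TrainingManager (illustrative only)."""

    def __init__(
        self,
        model: object,
        batch_size: int,
        evals: list[tuple[int, Callable[[object], dict]]],
        checkpoint_interval: Optional[int] = None,
    ) -> None:
        self.model = model
        self.batch_size = batch_size
        self.evals = evals  # (interval in batches, eval function) pairs
        self.checkpoint_interval = checkpoint_interval
        self.batches = 0
        self.samples = 0
        self.checkpoints = 0
        self.log: list[dict] = []

    def batch_update(self, samples: Optional[int] = None, **metrics) -> None:
        # 1. update internal counters, inferring samples from the batch size
        self.batches += 1
        self.samples += self.batch_size if samples is None else samples
        # 2. run every eval whose interval divides the batch count
        for interval, eval_fn in self.evals:
            if self.batches % interval == 0:
                metrics.update(eval_fn(self.model))
        # 3. log metrics together with the training status
        self.log.append({"batches": self.batches, "samples": self.samples, **metrics})
        # 4. "save" a checkpoint when the checkpoint interval is reached
        if self.checkpoint_interval and self.batches % self.checkpoint_interval == 0:
            self.checkpoints += 1


manager = ToyManager(
    model=None,
    batch_size=32,
    evals=[(2, lambda model: {"eval/acc": 0.5})],
    checkpoint_interval=3,
)
for step in range(6):
    manager.batch_update(**{"train/loss": 1.0 / (step + 1)})

print(manager.samples, manager.checkpoints)  # 192 2
```

Keeping every interval in batches means the per-batch check is a single integer comparison, which is presumably why the `evals`, `checkpoint_interval`, and `print_metrics_interval` attributes are typed as `int` after initialization.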
def epoch_update(self)
Call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter.
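An end-of-epoch hook like this pairs naturally with a generator that wraps the loop, in the spirit of the `epoch_loop` and `wrapped_iterable` signatures. Whether the library's loops call `epoch_update` internally is not stated here, so the sketch below only illustrates the general pattern; `wrapped` and `after_each` are hypothetical names.

```python
from typing import Callable, Generator, Iterable, TypeVar

T = TypeVar("T")


def wrapped(items: Iterable[T], after_each: Callable[[], None]) -> Generator[T, None, None]:
    """Yield each item, then call a hook once the loop body has finished with it."""
    for item in items:
        yield item
        # resumes when the caller asks for the next item, i.e. after the loop body ran
        after_each()


completed_epochs: list[int] = []

for epoch in wrapped(range(3), after_each=lambda: completed_epochs.append(len(completed_epochs))):
    pass  # training for one epoch would go here

print(completed_epochs)  # [0, 1, 2]
```

One caveat of this pattern: if the caller breaks out of the loop early, the hook for the interrupted item never runs.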