Metadata-Version: 2.1
Name: torch-runner
Version: 0.0.2
Summary: Trainer for Pytorch
Home-page: https://github.com/grohith327
Author: Rohith Gandhi G
Author-email: grohith327@gmail.com
License: MIT
Platform: UNKNOWN
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: Topic :: Software Development :: Build Tools
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Operating System :: OS Independent
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.4
Classifier: Programming Language :: Python :: 3.5
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Description-Content-Type: text/markdown
Requires-Dist: torch
Requires-Dist: tqdm

# Torch Runner
### A minimal wrapper that removes some of the overhead code in training pytorch models
Note: If you are looking for something more extensive, checkout [Pytorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning). This is mostly designed for my personal use. 

## Requirements
- torch
- tqdm

## Installation

```
pip install torch-runner
```

## Features
- seed all variables
- text logger
- early stopping
- save hyperparameters

## Example
Checkout the [examples](https://github.com/grohith327/Torch-Runner/blob/main/examples/Torch_Runner_CIFAR10.ipynb) folder which contains a jupyter notebook to train a resnet50 using torch_runner. 

```python
import torch
import torch_runner as T


class myTrainer(T.TrainerModule):
    def __init__(self, model, optimizer):
        super(myTrainer, self).__init__(model, optimizer)

    def calc_metric(self, preds, target):
        ## Calc metrics such as accuracy etc.

    def loss_fct(self, preds, target):
        ## Calc loss

    def train_one_step(self, batch, batch_id):
        ## Get batch data from dataloader and perform one update

    def valid_one_step(self, batch, batch_id):
        ## Perform validation step

model = myModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

train_dataloader = ## pytorch dataloader
val_dataloader = ## pytorch dataloader

Trainer = myTrainer(model, optimizer)
Trainer.fit(train_dataloader, val_dataloader, epochs=10, batch_size=32)
```


