Metadata-Version: 2.1
Name: fau-tools
Version: 1.6.1
Summary: A python module. The main function is for pytorch training.
Home-page: https://github.com/Fau818/fau-tools
License: MIT
Author: Fau
Author-email: Fau818@qq.com
Requires-Python: >=3.8
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Project-URL: Repository, https://github.com/Fau818/fau-tools
Description-Content-Type: text/markdown

## Introduction

This is a standalone module intended mainly for training **PyTorch CNN** models.

It also supports some handy features: saving the model, saving the training process, plotting figures, and so on.

## Install

`pip install fau-tools`

## Usage

### Import

The following imports are recommended.

```python
import fau_tools
from fau_tools import torch_tools
```

### Quick start

This tutorial uses a simple example to help you get started quickly.

**The following example uses Fau-tools to train a model on the MNIST handwritten digits dataset.**

```python
import torch
import torch.utils.data as tdata
import torchvision
from torch import nn

import fau_tools
from fau_tools import torch_tools

# A simple CNN network
class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Sequential(
      nn.Conv2d(1, 16, 3, 1, 1),  # -> (16, 28, 28)
      nn.ReLU(),
      nn.MaxPool2d(2),  # -> (16, 14, 14)

      nn.Conv2d(16, 32, 3, 1, 1),  # -> (32, 14, 14)
      nn.ReLU(),
      nn.MaxPool2d(2)  # -> (32, 7, 7)
    )
    self.output = nn.Linear(32 * 7 * 7, 10)

  def forward(self, x):
    x = self.conv(x)
    x = x.flatten(1)  # same as x = x.view(x.size(0), -1)
    return self.output(x)


# Hyperparameters
total_epoch = 10
lr = 1E-3
batch_size = 1024

# Load dataset
train_data      = torchvision.datasets.MNIST('Datasets', True, torchvision.transforms.ToTensor(), download=True)
test_data       = torchvision.datasets.MNIST('Datasets', False, torchvision.transforms.ToTensor())
train_data.data = train_data.data[:6000]  # keep only the first 6000 training images to speed up the example
test_data.data  = test_data.data[:2000]  # keep only the first 2000 test images

# Get data loader
train_loader = tdata.DataLoader(train_data, batch_size, True)
test_loader  = tdata.DataLoader(test_data, batch_size)

# Initialize model, optimizer and loss function
model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr)
loss_function = nn.CrossEntropyLoss()

# Train!
# `name` is the name used when saving the model and the training process.
torch_tools.torch_train(model, train_loader, test_loader, optimizer, loss_function, total_epoch=total_epoch, name="MNIST")
```
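
The shape comments in the `CNN` class above can be checked by hand with the standard output-size formula for convolution and pooling layers, `floor((size + 2 * padding - kernel) / stride) + 1`. The snippet below is a minimal sketch of that arithmetic; the helper `conv2d_out` is a hypothetical name used only for illustration and is not part of Fau-tools or PyTorch.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
  """Output side length of a square Conv2d/MaxPool2d layer (dilation 1)."""
  return (size + 2 * padding - kernel) // stride + 1


side = 28  # MNIST images are 28 x 28
side = conv2d_out(side, kernel=3, stride=1, padding=1)  # Conv2d(1, 16, 3, 1, 1)
side = conv2d_out(side, kernel=2, stride=2)  # MaxPool2d(2): stride defaults to the kernel size
side = conv2d_out(side, kernel=3, stride=1, padding=1)  # Conv2d(16, 32, 3, 1, 1)
side = conv2d_out(side, kernel=2, stride=2)  # MaxPool2d(2)

flat_features = 32 * side * side  # input size of the final nn.Linear layer
print(side, flat_features)  # 7 1568
```

The result matches the `32 * 7 * 7` input size passed to `nn.Linear` in the model.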

Running the quick-start script above visualizes the training process, as in the following picture.

![training_visualization](github_attachment/training_visualization.png)

> Three files named `MNIST_9846.pth`, `MNIST_9846.csv` and `MNIST_9846.txt` will be saved.
>
> The first file is the trained model.
>
> The second file records the training process, which you can visualize with matplotlib.
>
> The third file saves some hyperparameters of the training run.
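
As a rough sketch of how the CSV file could be visualized, the snippet below reads any CSV into columns and plots every numeric column against its row number. The column layout of the file is not documented here, so the code makes no assumption about column names; treating each row as one epoch is an assumption. `load_training_log` and `plot_training_log` are hypothetical helper names, not part of Fau-tools.

```python
import csv


def load_training_log(csv_path):
  """Read a CSV file into {column name: list of values}; non-numeric cells stay as strings."""
  columns = {}
  with open(csv_path, newline="") as f:
    for row in csv.DictReader(f):
      for key, value in row.items():
        try:
          value = float(value)
        except (TypeError, ValueError):
          pass  # keep the raw value when it is not a number
        columns.setdefault(key, []).append(value)
  return columns


def plot_training_log(columns, out_path="training_log.png"):
  """Plot every fully numeric column against its row number and save the figure."""
  import matplotlib.pyplot as plt  # imported here so loading works without matplotlib

  fig, ax = plt.subplots()
  for name, values in columns.items():
    if values and all(isinstance(v, float) for v in values):
      ax.plot(range(1, len(values) + 1), values, label=name)
  ax.set_xlabel("row (assumed: one row per epoch)")
  ax.legend()
  fig.savefig(out_path)
  plt.close(fig)
```

For the run above, this could be called as `plot_training_log(load_training_log("MNIST_9846.csv"))`.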

---

The above is the primary usage of this tool. It also has some other snazzy features, which will be introduced later.

## END

Hope you like it! Issues and pull requests are welcome.

