Metadata-Version: 2.1
Name: cjm-pytorch-utils
Version: 0.0.2
Summary: Some utility functions for working with PyTorch.
Home-page: https://github.com/cj-mills/cjm-pytorch-utils
Author: cj-mills
Author-email: millscj.mills2@gmail.com
License: Apache Software License 2.0
Keywords: nbdev jupyter notebook python
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Natural Language :: English
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: License :: OSI Approved :: Apache Software License
Requires-Python: >=3.7
Description-Content-Type: text/markdown
Provides-Extra: dev
License-File: LICENSE

cjm-pytorch-utils
================

<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

## Install

``` sh
pip install cjm_pytorch_utils
```

## How to use

### pil_to_tensor

``` python
from cjm_pytorch_utils.core import pil_to_tensor
from PIL import Image
from torchvision import transforms
```

``` python
img_path = '../images/cat.jpg'
src_img = Image.open(img_path).convert('RGB')
print(f"Source Image Size: {src_img.size}")

img_tensor = pil_to_tensor(src_img, [0.5], [0.5])
img_tensor.shape, img_tensor.min(), img_tensor.max()
```

    Source Image Size: (768, 512)

    (torch.Size([1, 3, 512, 768]), tensor(-1.), tensor(1.))
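The output shows two things worth spelling out. PIL reports size as (width, height), while the tensor is laid out as (batch, channels, height, width). The `[0.5]`, `[0.5]` arguments also appear to map pixel values from [0, 1] to [-1, 1]. The sketch below reproduces that behavior with plain PyTorch and NumPy. It assumes `pil_to_tensor` scales pixels to [0, 1], applies `(x - mean) / std` per channel, and prepends a batch dimension. `pil_to_normalized_batch` is a hypothetical helper for illustration, not the library's implementation.

```python
import numpy as np
import torch
from PIL import Image


def pil_to_normalized_batch(img, mean, std):
    """Hypothetical stand-in for pil_to_tensor (assumed behavior, not the library code)."""
    # PIL image -> HxWxC uint8 array -> float tensor scaled to [0, 1]
    x = torch.from_numpy(np.asarray(img, dtype=np.uint8).copy()).float().div(255)
    # Reorder to CxHxW, the layout PyTorch vision models expect
    x = x.permute(2, 0, 1)
    # Per-channel normalization; a length-1 list broadcasts across all channels
    mean_t = torch.tensor(mean, dtype=x.dtype).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=x.dtype).view(-1, 1, 1)
    x = (x - mean_t) / std_t
    # Prepend a batch dimension: CxHxW -> 1xCxHxW
    return x.unsqueeze(0)


# Synthetic 768x512 RGB image: left half black, right half white
arr = np.zeros((512, 768, 3), dtype=np.uint8)
arr[:, 384:] = 255
img = Image.fromarray(arr)

t = pil_to_normalized_batch(img, [0.5], [0.5])
print(img.size)                             # (768, 512)
print(t.shape, t.min().item(), t.max().item())  # torch.Size([1, 3, 512, 768]) -1.0 1.0
```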

### tensor_to_pil

``` python
from cjm_pytorch_utils.core import tensor_to_pil
```

``` python
# ToTensor() gives a CxHxW float tensor in [0, 1]; tensor_to_pil turns it back into a PIL image
pil_img = tensor_to_pil(transforms.ToTensor()(src_img))
pil_img
```

![Image returned by tensor_to_pil](index_files/figure-commonmark/cell-5-output-1.png)
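`tensor_to_pil` goes the other way. This sketch assumes it takes a CxHxW float tensor in [0, 1], which is the layout `transforms.ToTensor()` produces in the example above, and returns an RGB PIL image. Under that assumption, the conversion amounts to scaling to 0-255, casting to `uint8`, and moving channels last. The helper `chw_tensor_to_pil` is hypothetical, not the library's code. It uses NumPy in place of torchvision so the round trip is easy to check.

```python
import numpy as np
import torch
from PIL import Image


def chw_tensor_to_pil(t):
    """Hypothetical stand-in for tensor_to_pil: CxHxW float in [0, 1] -> RGB PIL image."""
    # Clamp to the valid range, scale to 0-255, and round before casting to uint8
    arr = (t.clamp(0, 1) * 255).round().to(torch.uint8)
    # CxHxW -> HxWxC, the layout PIL expects; contiguous() gives a C-ordered array
    arr = arr.permute(1, 2, 0).contiguous().numpy()
    return Image.fromarray(arr)


# Random 96x64 RGB source image
src = Image.fromarray(np.random.randint(0, 256, (64, 96, 3), dtype=np.uint8))
# Same result as transforms.ToTensor() for a uint8 RGB image: CxHxW float in [0, 1]
t = torch.from_numpy(np.asarray(src).copy()).permute(2, 0, 1).float() / 255

out = chw_tensor_to_pil(t)
print(out.size, out.mode)  # (96, 64) RGB
```

Rounding before the `uint8` cast matters. Plain truncation could turn a value like 0.99999 × 255 into 254 instead of 255, so a round trip might drift by one intensity level.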

### iterate_modules

``` python
from cjm_pytorch_utils.core import iterate_modules
import torch
from torchvision import models
```

``` python
vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features

# Print the position of every ReLU activation in the feature extractor
for index, module in enumerate(iterate_modules(vgg)):
    if isinstance(module, torch.nn.ReLU):
        print(f"{index}: {module}")
```

    1: ReLU(inplace=True)
    3: ReLU(inplace=True)
    6: ReLU(inplace=True)
    8: ReLU(inplace=True)
    11: ReLU(inplace=True)
    13: ReLU(inplace=True)
    15: ReLU(inplace=True)
    18: ReLU(inplace=True)
    20: ReLU(inplace=True)
    22: ReLU(inplace=True)
    25: ReLU(inplace=True)
    27: ReLU(inplace=True)
    29: ReLU(inplace=True)
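The printed indices line up with the layer positions in VGG16's `features` block, which is a flat `nn.Sequential`. For nested models, the open question is what counts as a module. The sketch below assumes `iterate_modules` performs a depth-first walk that yields only leaf modules (modules with no children). `iter_leaf_modules` is a hypothetical name used here for illustration.

```python
import torch.nn as nn


def iter_leaf_modules(module):
    """Hypothetical stand-in for iterate_modules: depth-first walk yielding childless modules."""
    children = list(module.children())
    if not children:
        # A module with no children is a leaf (Conv2d, ReLU, MaxPool2d, ...)
        yield module
        return
    for child in children:
        # Recurse into containers such as nested nn.Sequential blocks
        yield from iter_leaf_modules(child)


net = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.ReLU(inplace=True),
    nn.Sequential(nn.Conv2d(8, 8, 3), nn.ReLU(inplace=True)),  # nested container
    nn.MaxPool2d(2),
)

relu_idx = [i for i, m in enumerate(iter_leaf_modules(net)) if isinstance(m, nn.ReLU)]
print(relu_idx)  # [1, 3]
```

PyTorch's built-in `nn.Module.modules()` is similar, but it also yields every container, including the root module. Its indices would therefore not match a leaf-only enumeration.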

### tensor_stats_df

``` python
from cjm_pytorch_utils.core import tensor_stats_df
```

``` python
tensor_stats_df(torch.randn(1, 3, 256, 256))
```

<div>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>0</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>mean</th>
      <td>0.000952</td>
    </tr>
    <tr>
      <th>std</th>
      <td>0.998587</td>
    </tr>
    <tr>
      <th>min</th>
      <td>-4.616786</td>
    </tr>
    <tr>
      <th>max</th>
      <td>5.122179</td>
    </tr>
    <tr>
      <th>shape</th>
      <td>(1, 3, 256, 256)</td>
    </tr>
  </tbody>
</table>
</div>
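The table reports the mean, standard deviation, min, max, and shape in a single column. Here is a minimal sketch of building a similar DataFrame with pandas. It assumes `tensor_stats_df` collects these five values into one column labeled `0`. `tensor_stats` is a hypothetical helper, and whether the library uses the sample or the population standard deviation is an assumption.

```python
import pandas as pd
import torch


def tensor_stats(t):
    """Hypothetical stand-in for tensor_stats_df: summary statistics as a one-column DataFrame."""
    stats = {
        "mean": t.mean().item(),
        # torch.std applies Bessel's correction (sample std) by default;
        # whether tensor_stats_df does the same is an assumption
        "std": t.std().item(),
        "min": t.min().item(),
        "max": t.max().item(),
        "shape": tuple(t.shape),
    }
    # Mixed floats and a tuple give an object-dtype Series; to_frame() names the column 0
    return pd.Series(stats).to_frame()


df = tensor_stats(torch.arange(4.0))
print(df)
```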
