Metadata-Version: 2.1
Name: stitchnet
Version: 0.2.1
Summary: 
Author: Surat Teerapittayanon
Author-email: steerapi@gmail.com
Requires-Python: >=3.9,<3.12
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.9
Requires-Dist: datasets (>=2.13.1,<3.0.0)
Requires-Dist: evaluate (>=0.4.0,<0.5.0)
Requires-Dist: graphviz (>=0.20.1,<0.21.0)
Requires-Dist: netron (>=7.0.4,<8.0.0)
Requires-Dist: numpy (>=1.23.1,<2.0.0)
Requires-Dist: onnx (==1.13.1)
Requires-Dist: onnx-tool (==0.7.3)
Requires-Dist: onnx2torch (>=1.4.1,<2.0.0)
Requires-Dist: onnxoptimizer (>=0.3.0,<0.4.0)
Requires-Dist: onnxruntime (>=1.12.1,<2.0.0)
Requires-Dist: pyppeteer (>=1.0.2,<2.0.0)
Requires-Dist: pyre-extensions (>=0.0.30,<0.0.31)
Requires-Dist: scipy (>=1.7.3,<2.0.0)
Requires-Dist: seaborn (>=0.11.2,<0.12.0)
Requires-Dist: skl2onnx (>=1.12.0,<1.13.0)
Requires-Dist: torch (>=1.12.1,<2.0.0)
Requires-Dist: torchvision (>=0.13.1,<0.14.0)
Requires-Dist: transformers (>=4.25.1,<5.0.0)
Description-Content-Type: text/markdown

StitchNet: Composing Neural Networks from Pre-Trained Fragments
=============


Installation
=============

    pip install stitchnet
    
Usage
=============
    
    import stitchnet
    
    # load the beans dataset from Hugging Face
    from stitchnet import load_hf_dataset
    dataset_train, dataset_val = load_hf_dataset('beans', train_split='validation', val_split='test', label_column='labels', seed=47)

    # prepare the stitching dataset from the first 32 training examples
    import numpy as np
    from tqdm import tqdm
    stitching_dataset = np.vstack([x['pixel_values'] for x in tqdm(dataset_train.select(range(32)))])

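    # generate a single stitchnet (sample=True returns one (score, net) pair)
    from stitchnet import generate  # assumed to be exported at package level, like load_hf_dataset
    score, net = generate(stitching_dataset, threshold=0, totalThreshold=0, maxDepth=10, K=2, sample=True)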
    
    # generate multiple stitchnets (sample=False returns a generator of (score, net) pairs)
    generator = generate(stitching_dataset, threshold=0.8, totalThreshold=0.8, maxDepth=10, K=2, sample=False)
    for score, net in generator:
        print(score, net)
    
    # count MACs and parameters
    net.get_macs_params() # {'macs': 4488343528.0, 'params': 25653096}
    
    # save as ONNX
    net.save_onnx('./_results/net')  # saves to ./_results/net.onnx

    # draw the stitchnet
    net.draw_svg('./_results/net')  # saves to ./_results/net.svg
    
    # train a classifier
    net.fit(dataset_train, label_column="labels")
    
    # use it for prediction
    net.predict_files(['./_results/test.jpg']) # [{'score': [0.8, 0.2, 0.0], 'label': 0}]
    
    # evaluate with validation dataset
    net.evaluate_dataset(dataset_val, label_column='labels') # {'accuracy': 0.7421875}

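With `sample=False`, `generate` yields `(score, net)` pairs one at a time, as in the loop above. If only the best few candidates are wanted, a bounded min-heap keeps them without holding every generated net in memory. The helper below is a minimal sketch, not part of the stitchnet API: `top_k_stitchnets` is a hypothetical name, and it assumes a higher score means a better stitch (which the `threshold` and `totalThreshold` parameters suggest but this README does not state).

```python
import heapq
import itertools
from typing import Any, Iterable, List, Tuple


def top_k_stitchnets(pairs: Iterable[Tuple[float, Any]], k: int) -> List[Tuple[float, Any]]:
    """Return the k highest-scoring (score, net) pairs, best first.

    Hypothetical helper: assumes higher scores are better. Works on any
    iterable of (score, net) pairs, e.g. the generator from generate(..., sample=False).
    """
    if k <= 0:
        return []
    heap = []  # min-heap of (score, tie_breaker, net); heap[0] is the worst entry kept so far
    tie_breaker = itertools.count()  # unique ints, so net objects are never compared
    for score, net in pairs:
        entry = (score, next(tie_breaker), net)
        if len(heap) < k:
            heapq.heappush(heap, entry)
        elif score > heap[0][0]:
            heapq.heapreplace(heap, entry)  # drop the current worst, add the new entry
    ranked = sorted(heap, key=lambda e: e[0], reverse=True)
    return [(score, net) for score, _, net in ranked]
```

For example, `top_k_stitchnets(generator, k=5)` on a freshly created generator would keep the five best candidates. The generator is exhausted afterwards, so create a new one if you need to iterate again.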
CUDA
=============
See https://pytorch.org/get-started/previous-versions/ to install the PyTorch build that matches your CUDA version (this package requires torch <2.0). For example:

    # CUDA 11.6
    pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116

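The `+cu116` suffix in these pins is a PEP 440 local version label; it does not change whether the install satisfies this package's `torch (>=1.12.1,<2.0.0)` and `torchvision (>=0.13.1,<0.14.0)` requirements. As a rough, stdlib-only sketch (the helper names are hypothetical, and it ignores pre-release and other PEP 440 details that `packaging.version.Version` handles properly), you can check the installed versions against those ranges:

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# Ranges copied from this package's Requires-Dist metadata.
PINS = {"torch": ("1.12.1", "2.0.0"), "torchvision": ("0.13.1", "0.14.0")}


def release_tuple(version: str) -> Tuple[int, ...]:
    """'1.12.1+cu116' -> (1, 12, 1): drop the local label and keep leading digits."""
    public = version.split("+", 1)[0]  # strip the local label, e.g. '+cu116'
    parts = []
    for piece in public.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as 'rc1' (pre-releases are not handled)
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def in_range(version: str, lower: str, upper: str) -> bool:
    """True if lower <= version < upper, comparing release numbers only."""
    v, lo, hi = release_tuple(version), release_tuple(lower), release_tuple(upper)
    width = max(len(v), len(lo), len(hi))
    v, lo, hi = (t + (0,) * (width - len(t)) for t in (v, lo, hi))  # pad to equal length
    return lo <= v < hi


def check_installed(pins: Dict[str, Tuple[str, str]] = PINS) -> Dict[str, Optional[bool]]:
    """Map each package to True/False (in range or not), or None if it is not installed."""
    report: Dict[str, Optional[bool]] = {}
    for name, (lower, upper) in pins.items():
        try:
            report[name] = in_range(metadata.version(name), lower, upper)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    print(check_installed())
```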

Experiment Notebooks
=============

1. Download the Dogs vs. Cats dataset from https://www.kaggle.com/c/dogs-vs-cats/data and put the training images in the `_data/dogs_cats/raw/train` folder
2. See 00_prepare_data.ipynb to split the images into cats and dogs folders
3. See 01_download_networks.ipynb to download the pretrained networks from Torchvision
4. See 02_generate_fragments.ipynb to generate fragments from the pretrained networks
5. See 03_stitchnet.ipynb to generate stitchnets
6. See 04_render_graph.ipynb to create SVG images of the network graphs using netron
7. See 05_eval_original_networks.ipynb to evaluate the original pretrained networks
8. See 06_finetuning.ipynb to generate the finetuning results
9. See 07_ensemble.ipynb to generate the ensemble results
10. See 08_number_of_samples_for_stitching.ipynb to experiment with varying the number of samples used for stitching
11. See 09_plot_results.ipynb to plot the figures of the results for the paper

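Step 2 relies on the Kaggle training images being named by class, e.g. `cat.0.jpg` and `dog.0.jpg`. The sketch below shows that kind of split using only the standard library. It is not the code in 00_prepare_data.ipynb: the function name `split_by_label` and the output layout (`<out_dir>/cats` and `<out_dir>/dogs`) are assumptions.

```python
import shutil
from pathlib import Path
from typing import Dict

# Kaggle dogs-vs-cats training files are named '<label>.<id>.jpg'.
LABEL_DIRS = {"cat": "cats", "dog": "dogs"}


def split_by_label(raw_dir, out_dir, move: bool = False) -> Dict[str, int]:
    """Copy (or move) each image into <out_dir>/cats or <out_dir>/dogs by filename prefix."""
    raw_dir, out_dir = Path(raw_dir), Path(out_dir)
    counts = {folder: 0 for folder in LABEL_DIRS.values()}
    for src in sorted(raw_dir.glob("*.jpg")):
        folder = LABEL_DIRS.get(src.name.split(".", 1)[0])
        if folder is None:
            continue  # skip files that do not follow the '<label>.<id>.jpg' pattern
        dest_dir = out_dir / folder
        dest_dir.mkdir(parents=True, exist_ok=True)
        if move:
            shutil.move(str(src), str(dest_dir / src.name))
        else:
            shutil.copy2(src, dest_dir / src.name)  # copy, keeping file metadata
        counts[folder] += 1
    return counts
```

For example, `split_by_label('_data/dogs_cats/raw/train', '_data/dogs_cats/split')` would write into a hypothetical `_data/dogs_cats/split` folder and return the number of images per class. Copying is the default so the raw folder stays intact.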

Installation using conda
=============

Create a new conda environment

    conda create -n stitchnet python=3.10

Activate the stitchnet conda environment

    conda activate stitchnet

With conda and an NVIDIA GPU, also install cuDNN so that ONNX Runtime can use the CUDA runtime

    conda install -c conda-forge cudnn
    
Install Poetry

    curl -sSL https://install.python-poetry.org | python3 -

Install the dependencies with Poetry from the repository root

    poetry install

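The conda environment above uses Python 3.10, which falls inside this package's `Requires-Python: >=3.9,<3.12` range. A quick check that the active interpreter is in that range might look like the sketch below; `python_supported` is a hypothetical helper, not part of stitchnet.

```python
import sys
from typing import Sequence

# Bounds taken from the Requires-Python metadata: >=3.9,<3.12
MIN_PY, MAX_PY_EXCLUSIVE = (3, 9), (3, 12)


def python_supported(version_info: Sequence[int] = sys.version_info) -> bool:
    """True if the (major, minor) version lies in [3.9, 3.12)."""
    major_minor = tuple(version_info[:2])
    return MIN_PY <= major_minor < MAX_PY_EXCLUSIVE


if __name__ == "__main__":
    print(sys.version.split()[0], "supported" if python_supported() else "NOT supported")
```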

