Metadata-Version: 2.2
Name: openml-pytorch
Version: 0.1.2
Summary: Pytorch extension for OpenML
Author-email: SubhadityaMukherjee <msubhaditya@gmail.com>, Taniya Das <t.das@tue.nl>
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE.md
Requires-Dist: absl-py==2.1.0
Requires-Dist: argon2-cffi==23.1.0
Requires-Dist: argon2-cffi-bindings==21.2.0
Requires-Dist: certifi==2025.1.31
Requires-Dist: cffi==1.17.1
Requires-Dist: charset-normalizer==3.4.1
Requires-Dist: contourpy==1.3.1
Requires-Dist: cycler==0.12.1
Requires-Dist: filelock==3.17.0
Requires-Dist: fonttools==4.56.0
Requires-Dist: fsspec==2025.3.0
Requires-Dist: grpcio==1.71.0
Requires-Dist: idna==3.10
Requires-Dist: jinja2==3.1.6
Requires-Dist: joblib==1.4.2
Requires-Dist: kiwisolver==1.4.8
Requires-Dist: liac-arff==2.5.0
Requires-Dist: markdown==3.7
Requires-Dist: markupsafe==3.0.2
Requires-Dist: matplotlib==3.10.1
Requires-Dist: minio==7.2.15
Requires-Dist: mpmath==1.3.0
Requires-Dist: netron==8.2.0
Requires-Dist: networkx==3.4.2
Requires-Dist: numpy==2.2.3
Requires-Dist: onnx==1.17.0
Requires-Dist: openml==0.15.1
Requires-Dist: packaging==24.2
Requires-Dist: pandas==2.2.3
Requires-Dist: pillow==11.1.0
Requires-Dist: protobuf==6.30.0
Requires-Dist: pyarrow==19.0.1
Requires-Dist: pycparser==2.22
Requires-Dist: pycryptodome==3.21.0
Requires-Dist: pyparsing==3.2.1
Requires-Dist: python-dateutil==2.9.0.post0
Requires-Dist: pytz==2025.1
Requires-Dist: requests==2.32.3
Requires-Dist: scikit-learn==1.6.1
Requires-Dist: scipy==1.15.2
Requires-Dist: six==1.17.0
Requires-Dist: sympy==1.13.1
Requires-Dist: tensorboard==2.19.0
Requires-Dist: tensorboard-data-server==0.7.2
Requires-Dist: threadpoolctl==3.5.0
Requires-Dist: torch==2.6.0
Requires-Dist: torchvision==0.21.0
Requires-Dist: tqdm==4.67.1
Requires-Dist: typing-extensions==4.12.2
Requires-Dist: tzdata==2025.1
Requires-Dist: urllib3==2.3.0
Requires-Dist: werkzeug==3.1.3
Requires-Dist: xmltodict==0.14.2
Provides-Extra: dev
Requires-Dist: absl-py==2.1.0; extra == "dev"
Requires-Dist: appnope==0.1.4; extra == "dev"
Requires-Dist: argon2-cffi==23.1.0; extra == "dev"
Requires-Dist: argon2-cffi-bindings==21.2.0; extra == "dev"
Requires-Dist: asttokens==3.0.0; extra == "dev"
Requires-Dist: attrs==25.3.0; extra == "dev"
Requires-Dist: babel==2.17.0; extra == "dev"
Requires-Dist: backrefs==5.8; extra == "dev"
Requires-Dist: beautifulsoup4==4.13.3; extra == "dev"
Requires-Dist: bleach[css]==6.2.0; extra == "dev"
Requires-Dist: bracex==2.5.post1; extra == "dev"
Requires-Dist: certifi==2025.1.31; extra == "dev"
Requires-Dist: cffi==1.17.1; extra == "dev"
Requires-Dist: charset-normalizer==3.4.1; extra == "dev"
Requires-Dist: click==8.1.8; extra == "dev"
Requires-Dist: colorama==0.4.6; extra == "dev"
Requires-Dist: comm==0.2.2; extra == "dev"
Requires-Dist: contourpy==1.3.1; extra == "dev"
Requires-Dist: cycler==0.12.1; extra == "dev"
Requires-Dist: debugpy==1.8.13; extra == "dev"
Requires-Dist: decorator==5.2.1; extra == "dev"
Requires-Dist: defusedxml==0.7.1; extra == "dev"
Requires-Dist: executing==2.2.0; extra == "dev"
Requires-Dist: fastjsonschema==2.21.1; extra == "dev"
Requires-Dist: filelock==3.17.0; extra == "dev"
Requires-Dist: fonttools==4.56.0; extra == "dev"
Requires-Dist: fsspec==2025.3.0; extra == "dev"
Requires-Dist: ghp-import==2.1.0; extra == "dev"
Requires-Dist: gitdb==4.0.12; extra == "dev"
Requires-Dist: gitpython==3.1.44; extra == "dev"
Requires-Dist: griffe==1.6.0; extra == "dev"
Requires-Dist: grpcio==1.71.0; extra == "dev"
Requires-Dist: idna==3.10; extra == "dev"
Requires-Dist: ipykernel==6.29.5; extra == "dev"
Requires-Dist: ipython==9.0.2; extra == "dev"
Requires-Dist: ipython-pygments-lexers==1.1.1; extra == "dev"
Requires-Dist: jedi==0.19.2; extra == "dev"
Requires-Dist: jinja2==3.1.6; extra == "dev"
Requires-Dist: joblib==1.4.2; extra == "dev"
Requires-Dist: jsonschema==4.23.0; extra == "dev"
Requires-Dist: jsonschema-specifications==2024.10.1; extra == "dev"
Requires-Dist: jupyter-client==8.6.3; extra == "dev"
Requires-Dist: jupyter-core==5.7.2; extra == "dev"
Requires-Dist: jupyterlab-pygments==0.3.0; extra == "dev"
Requires-Dist: jupytext==1.16.7; extra == "dev"
Requires-Dist: kiwisolver==1.4.8; extra == "dev"
Requires-Dist: liac-arff==2.5.0; extra == "dev"
Requires-Dist: markdown==3.7; extra == "dev"
Requires-Dist: markdown-it-py==3.0.0; extra == "dev"
Requires-Dist: markupsafe==3.0.2; extra == "dev"
Requires-Dist: matplotlib==3.10.1; extra == "dev"
Requires-Dist: matplotlib-inline==0.1.7; extra == "dev"
Requires-Dist: mdit-py-plugins==0.4.2; extra == "dev"
Requires-Dist: mdurl==0.1.2; extra == "dev"
Requires-Dist: mergedeep==1.3.4; extra == "dev"
Requires-Dist: minio==7.2.15; extra == "dev"
Requires-Dist: mistune==3.1.2; extra == "dev"
Requires-Dist: mkdocs==1.6.1; extra == "dev"
Requires-Dist: mkdocs-autorefs==1.4.1; extra == "dev"
Requires-Dist: mkdocs-awesome-pages-plugin==2.10.1; extra == "dev"
Requires-Dist: mkdocs-get-deps==0.2.0; extra == "dev"
Requires-Dist: mkdocs-jupyter==0.25.1; extra == "dev"
Requires-Dist: mkdocs-material==9.6.8; extra == "dev"
Requires-Dist: mkdocs-material-extensions==1.3.1; extra == "dev"
Requires-Dist: mkdocs-redirects==1.2.2; extra == "dev"
Requires-Dist: mkdocstrings==0.29.0; extra == "dev"
Requires-Dist: mkdocstrings-python==1.16.5; extra == "dev"
Requires-Dist: mknotebooks==0.8.0; extra == "dev"
Requires-Dist: mpmath==1.3.0; extra == "dev"
Requires-Dist: natsort==8.4.0; extra == "dev"
Requires-Dist: nbclient==0.10.2; extra == "dev"
Requires-Dist: nbconvert==7.16.6; extra == "dev"
Requires-Dist: nbformat==5.10.4; extra == "dev"
Requires-Dist: nest-asyncio==1.6.0; extra == "dev"
Requires-Dist: netron==8.2.0; extra == "dev"
Requires-Dist: networkx==3.4.2; extra == "dev"
Requires-Dist: numpy==2.2.3; extra == "dev"
Requires-Dist: onnx==1.17.0; extra == "dev"
Requires-Dist: openml==0.15.1; extra == "dev"
Requires-Dist: packaging==24.2; extra == "dev"
Requires-Dist: paginate==0.5.7; extra == "dev"
Requires-Dist: pandas==2.2.3; extra == "dev"
Requires-Dist: pandocfilters==1.5.1; extra == "dev"
Requires-Dist: parso==0.8.4; extra == "dev"
Requires-Dist: pathspec==0.12.1; extra == "dev"
Requires-Dist: pexpect==4.9.0; extra == "dev"
Requires-Dist: pillow==11.1.0; extra == "dev"
Requires-Dist: platformdirs==4.3.6; extra == "dev"
Requires-Dist: prompt-toolkit==3.0.50; extra == "dev"
Requires-Dist: protobuf==6.30.0; extra == "dev"
Requires-Dist: psutil==7.0.0; extra == "dev"
Requires-Dist: ptyprocess==0.7.0; extra == "dev"
Requires-Dist: pure-eval==0.2.3; extra == "dev"
Requires-Dist: pyarrow==19.0.1; extra == "dev"
Requires-Dist: pycparser==2.22; extra == "dev"
Requires-Dist: pycryptodome==3.21.0; extra == "dev"
Requires-Dist: pygments==2.19.1; extra == "dev"
Requires-Dist: pymdown-extensions==10.14.3; extra == "dev"
Requires-Dist: pyparsing==3.2.1; extra == "dev"
Requires-Dist: python-dateutil==2.9.0.post0; extra == "dev"
Requires-Dist: pytz==2025.1; extra == "dev"
Requires-Dist: pyyaml==6.0.2; extra == "dev"
Requires-Dist: pyyaml-env-tag==0.1; extra == "dev"
Requires-Dist: pyzmq==26.3.0; extra == "dev"
Requires-Dist: referencing==0.36.2; extra == "dev"
Requires-Dist: requests==2.32.3; extra == "dev"
Requires-Dist: rpds-py==0.23.1; extra == "dev"
Requires-Dist: scikit-learn==1.6.1; extra == "dev"
Requires-Dist: scipy==1.15.2; extra == "dev"
Requires-Dist: six==1.17.0; extra == "dev"
Requires-Dist: smmap==5.0.2; extra == "dev"
Requires-Dist: soupsieve==2.6; extra == "dev"
Requires-Dist: stack-data==0.6.3; extra == "dev"
Requires-Dist: sympy==1.13.1; extra == "dev"
Requires-Dist: tensorboard==2.19.0; extra == "dev"
Requires-Dist: tensorboard-data-server==0.7.2; extra == "dev"
Requires-Dist: threadpoolctl==3.5.0; extra == "dev"
Requires-Dist: tinycss2==1.4.0; extra == "dev"
Requires-Dist: torch==2.6.0; extra == "dev"
Requires-Dist: torchvision==0.21.0; extra == "dev"
Requires-Dist: tornado==6.4.2; extra == "dev"
Requires-Dist: tqdm==4.67.1; extra == "dev"
Requires-Dist: traitlets==5.14.3; extra == "dev"
Requires-Dist: typing-extensions==4.12.2; extra == "dev"
Requires-Dist: tzdata==2025.1; extra == "dev"
Requires-Dist: urllib3==2.3.0; extra == "dev"
Requires-Dist: watchdog==6.0.0; extra == "dev"
Requires-Dist: wcmatch==10.0; extra == "dev"
Requires-Dist: wcwidth==0.2.13; extra == "dev"
Requires-Dist: webencodings==0.5.1; extra == "dev"
Requires-Dist: werkzeug==3.1.3; extra == "dev"
Requires-Dist: xmltodict==0.14.2; extra == "dev"

# PyTorch extension for OpenML Python

PyTorch extension for the [openml-python API](https://github.com/openml/openml-python). This library provides a simple way to run your PyTorch models on OpenML tasks.

For a more native experience, PyTorch itself provides OpenML integrations for some tasks. You can find more information [here](<Integrations of OpenML in PyTorch.md>).

## Installation

`pip install openml-pytorch`

PyPI: https://pypi.org/project/openml-pytorch/

Set the API key for OpenML from the command line:
```bash
openml configure apikey <your API key>
```

## Usage
### Load Data from OpenML and Train a Model
```python
# Import libraries
import numpy as np
import openml
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

from openml_pytorch import GenericDataset
from openml_pytorch.trainer import BasicTrainer

# Get dataset by ID and split into train and test
dataset = openml.datasets.get_dataset(20)
X, y, _, _ = dataset.get_data(target=dataset.default_target_attribute)
X = X.to_numpy(dtype=np.float32)
# Encode class labels as integers 0..n_classes-1, as CrossEntropyLoss expects
y = LabelEncoder().fit_transform(y).astype(np.int64)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1, stratify=y)

# Dataloaders
ds_train = GenericDataset(X_train, y_train)
ds_test = GenericDataset(X_test, y_test)
dataloader_train = torch.utils.data.DataLoader(ds_train, batch_size=64, shuffle=True)
dataloader_test = torch.utils.data.DataLoader(ds_test, batch_size=64, shuffle=False)

# Model definition
class TabularClassificationModel(torch.nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, 128)
        self.fc2 = torch.nn.Linear(128, 64)
        self.fc3 = torch.nn.Linear(64, output_size)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        # Return raw logits: CrossEntropyLoss applies log-softmax internally
        return self.fc3(x)

# Use a GPU (CUDA or Apple MPS) if one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Train the model
trainer = BasicTrainer(
    model=TabularClassificationModel(X_train.shape[1], len(np.unique(y_train))),
    loss_fn=torch.nn.CrossEntropyLoss(),
    opt=torch.optim.Adam,
    dataloader_train=dataloader_train,
    dataloader_test=dataloader_test,
    device=device,
)
trainer.fit(10)
```
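`BasicTrainer` runs the training loop for you. As a rough mental model only, the sketch below shows what one epoch of such a loop typically looks like in plain PyTorch; it is a hypothetical helper (`train_one_epoch`), not the library's implementation. It also shows why the model should return raw logits: `torch.nn.CrossEntropyLoss` applies log-softmax itself. Synthetic data stands in for the OpenML dataset.

```python
import torch


def train_one_epoch(model, dataloader, loss_fn, optimizer, device):
    """Hypothetical helper: one pass over the data, returns the mean batch loss."""
    model.train()
    total_loss, n_batches = 0.0, 0
    for X_batch, y_batch in dataloader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        logits = model(X_batch)  # raw scores, no softmax
        loss = loss_fn(logits, y_batch)  # expects logits and integer class indices
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    return total_loss / n_batches


# Synthetic stand-in for an OpenML dataset: 256 samples, 20 features, 3 classes
torch.manual_seed(0)
X = torch.randn(256, 20)
y = (X[:, 0] > 0).long() + (X[:, 1] > 0).long()  # labels in {0, 1, 2}
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X, y), batch_size=64, shuffle=True
)

model = torch.nn.Sequential(
    torch.nn.Linear(20, 64), torch.nn.ReLU(), torch.nn.Linear(64, 3)
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = torch.nn.CrossEntropyLoss()
device = torch.device("cpu")

losses = [train_one_epoch(model, loader, loss_fn, optimizer, device) for _ in range(5)]
print(losses)
```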
### More Complex Example: Image Classification

Import the required libraries:
```python
import openml
import torchvision
from torchvision.transforms import Compose, Lambda, Resize, ToPILImage, ToTensor

import openml_pytorch.config
from openml_pytorch.trainer import OpenMLDataModule, OpenMLTrainerModule, convert_to_rgb
```
Create a PyTorch model and download the task from OpenML:
```python
model = torchvision.models.efficientnet_b0(num_classes=200)
# Download the OpenML task for tiniest imagenet
task = openml.tasks.get_task(362128)
```
Define the data and trainer configuration:
```python
# Preprocessing applied to every image
transform = Compose(
    [
        ToPILImage(),  # Convert the tensor to a PIL Image so PIL operations can be applied
        Lambda(convert_to_rgb),  # Convert the image to RGB if it is not already
        Resize((64, 64)),  # Resize the image to 64x64 pixels
        ToTensor(),  # Convert the PIL Image back to a tensor
    ]
)
data_module = OpenMLDataModule(
    type_of_data="image",
    file_dir="datasets",
    filename_col="image_path",
    target_mode="categorical",
    target_column="label",
    batch_size=64,
    transform=transform,
)
trainer = OpenMLTrainerModule(
    data_module=data_module,
    verbose=True,
    epoch_count=1,
)
# Make this the trainer used by the OpenML PyTorch extension
openml_pytorch.config.trainer = trainer
```
Run the model on the task and publish the run to OpenML:
```python
run = openml.runs.run_model_on_task(model, task, avoid_duplicate_runs=False)
run.publish()
print(f"URL for run: {run.openml_url}")
```
Note: The input layer of the network must be compatible with the shape of the data that the OpenML data module produces. See the [examples](/examples/) for more information.

Additionally, if you want to publish the run together with its ONNX model file, you must call `openml_pytorch.add_experiment_info_to_run()` immediately before `run.publish()`:

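```python
run = openml.runs.run_model_on_task(model, task, avoid_duplicate_runs=False)
# Attach the experiment information (such as the ONNX model) to the run before publishing
run = openml_pytorch.add_experiment_info_to_run(run=run, trainer=trainer)
run.publish()
print(f"URL for run: {run.openml_url}")
```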
