Metadata-Version: 2.1
Name: vistabnet
Version: 0.1.1
Summary: 
Author: wwydmanski
Author-email: wwydmanski@gmail.com
Requires-Python: >=3.9,<4.0
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Requires-Dist: focal-loss-torch (>=0.1.2,<0.2.0)
Requires-Dist: torch (>=2,<3)
Requires-Dist: torchvision (>=0.15.0,<0.16.0)
Requires-Dist: tqdm (>=4.65.0,<5.0.0)
Description-Content-Type: text/markdown

# VisTabNet
This package provides VisTabNet, a Vision Transformer-based classifier for tabular data.

## Usage
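```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split
from vistabnet import VisTabNetClassifier

# Load your data here; a scikit-learn toy dataset is used as a stand-in.
# y should be label encoded (integers 0..n_classes-1), not one-hot encoded.
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)

# input_features: number of columns; classes: number of distinct labels.
# device="cuda:1" selects the second GPU; adjust it to your hardware.
model = VisTabNetClassifier(input_features=X_train.shape[1], classes=len(np.unique(y_train)), device="cuda:1")
# eval_X / eval_y pass held-out data to fit for evaluation during training.
model.fit(X_train, y_train, eval_X=X_test, eval_y=y_test)

y_pred = model.predict(X_test)
acc = balanced_accuracy_score(y_test, y_pred)
print(f"Balanced accuracy: {acc}")
```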
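Because the classifier expects integer-encoded labels rather than one-hot vectors, here is a minimal sketch of preparing `y` with NumPy alone. The example labels are made up for illustration, and this preprocessing is not part of the vistabnet API. Note that `len(classes)` also gives the value for the `classes` argument.

```python
import numpy as np

# Hypothetical string labels, standing in for your own target column.
raw_labels = np.array(["cat", "dog", "cat", "bird"])

# np.unique returns the sorted distinct labels and, with return_inverse=True,
# the index of each element in that sorted array: an integer encoding 0..n_classes-1.
classes, y_encoded = np.unique(raw_labels, return_inverse=True)
print(classes)    # ['bird' 'cat' 'dog']
print(y_encoded)  # [1 2 1 0]

# If your labels are already one-hot encoded, take the column index of the 1 in each row.
y_onehot = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
y_from_onehot = y_onehot.argmax(axis=1)
print(y_from_onehot)  # [1 2 0]

# Integer predictions can be mapped back to the original names by indexing.
print(classes[y_encoded])  # ['cat' 'dog' 'cat' 'bird']
```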

## Installation
```bash
pip install vistabnet
```
