Metadata-Version: 2.4
Name: deem
Version: 0.1.1
Summary: Deep Ensemble Energy Models - RBM-based ensemble aggregation for crowd learning and classifier combination
Author-email: Maymona Albadri <maymona3@campus.technion.ac.il>
License: MIT
Project-URL: Homepage, https://github.com/Rem4rkable/rbm_python
Project-URL: Repository, https://github.com/Rem4rkable/rbm_python
Project-URL: Issues, https://github.com/Rem4rkable/rbm_python/issues
Project-URL: Changelog, https://github.com/Rem4rkable/rbm_python/blob/main/CHANGELOG.md
Keywords: ensemble learning,restricted boltzmann machine,crowd learning,multi-annotator,energy-based models,deep learning,pytorch,machine learning
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Information Analysis
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=1.9.0
Requires-Dist: numpy>=1.19.0
Requires-Dist: scipy>=1.7.0
Requires-Dist: entmax>=1.0
Provides-Extra: automl
Requires-Dist: scikit-learn>=0.24.0; extra == "automl"
Requires-Dist: pandas>=1.3.0; extra == "automl"
Requires-Dist: joblib>=1.0.0; extra == "automl"
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Requires-Dist: ruff>=0.1.0; extra == "dev"
Provides-Extra: all
Requires-Dist: deem[automl,dev]; extra == "all"
Dynamic: license-file

# DEEM - Deep Ensemble Energy Models

[![PyPI version](https://badge.fury.io/py/deem.svg)](https://badge.fury.io/py/deem)
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

**DEEM** is a Python library for training Restricted Boltzmann Machines (RBMs) on ensemble predictions from multiple classifiers. It provides a scikit-learn-compatible API for unsupervised ensemble aggregation, crowd learning, and model combination.

## Features

- 🚀 **Simple 3-line API** - Fit and predict in just a few lines of code
- 🔬 **Unsupervised Learning** - No labels required for training (though they can be used for evaluation)
- 🧮 **Energy-Based Models** - Uses RBMs to learn the joint distribution of classifier predictions
- 🎯 **Hungarian Alignment** - Automatic handling of the label permutation problem via the Hungarian algorithm
- ⚡ **GPU Acceleration** - Full PyTorch backend with CUDA support
- 🔧 **Scikit-learn Compatible** - Standard `.fit()`, `.predict()`, `.score()` interface
- 📊 **Automatic Hyperparameters** - Optional meta-learning for hyperparameter selection

## Installation

```bash
pip install deem
```

### From Source

```bash
git clone https://github.com/Rem4rkable/rbm_python.git
cd rbm_python
pip install -e .
```

## Quick Start

```python
import numpy as np
from deem import DEEM

# Ensemble predictions from 15 classifiers on 100 samples with 3 classes
predictions = np.random.randint(0, 3, (100, 15))

# Train and predict in 3 lines!
model = DEEM()
model.fit(predictions)
consensus = model.predict(predictions)
```

### With Evaluation

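```python
import numpy as np
from deem import DEEM

# Placeholder data; substitute your own ensemble outputs and ground-truth labels
rng = np.random.default_rng(0)
train_predictions = rng.integers(0, 3, size=(200, 15))
test_predictions = rng.integers(0, 3, size=(50, 15))
test_labels = rng.integers(0, 3, size=50)

# Training is unsupervised: labels are only needed for evaluation
model = DEEM(n_classes=3, epochs=50)
model.fit(train_predictions)

# score() aligns predicted class IDs to the true labels (Hungarian algorithm) before computing accuracy
accuracy = model.score(test_predictions, test_labels)
print(f"Consensus accuracy: {accuracy:.2%}")
```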
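
Because DEEM is trained without labels, its output class IDs are only defined up to a permutation. The sketch below illustrates the general idea behind Hungarian-aligned accuracy using SciPy's `linear_sum_assignment`. The helper `hungarian_accuracy` is hypothetical and is not necessarily how `score()` is implemented internally.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def hungarian_accuracy(pred, true, n_classes):
    """Accuracy after the best one-to-one relabeling of predicted classes.

    Hypothetical helper illustrating Hungarian alignment; not part of the DEEM API.
    """
    pred = np.asarray(pred)
    true = np.asarray(true)
    # Contingency table: counts[i, j] = number of samples predicted i with true label j
    counts = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(counts, (pred, true), 1)
    # Find the permutation that maximizes the number of agreeing samples
    rows, cols = linear_sum_assignment(counts, maximize=True)
    lookup = np.empty(n_classes, dtype=np.int64)
    lookup[rows] = cols
    aligned = lookup[pred]
    return float(np.mean(aligned == true)), aligned


# Predictions are a pure relabeling of the truth (true 0 -> 2, 1 -> 0, 2 -> 1)
true = np.array([0, 0, 1, 1, 2, 2])
pred = np.array([2, 2, 0, 0, 1, 1])
acc, aligned = hungarian_accuracy(pred, true, n_classes=3)
print(acc)  # 1.0
```

Raw accuracy on this example would be 0, even though the predictions carry all of the label information; alignment recovers that.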

### Custom Configuration

```python
model = DEEM(
    n_classes=5,
    hidden_dim=2,           # Number of hidden units
    learning_rate=0.01,
    epochs=100,
    batch_size=64,
    cd_k=10,                # Contrastive divergence steps
    deterministic=True,     # Use probabilities (more stable)
    device='cuda'           # Use GPU
)
model.fit(predictions)
```

## Use Cases

### 1. Crowd Learning / Multi-Annotator Aggregation
Aggregate noisy labels from multiple human annotators:
```python
# annotator_labels: (n_samples, n_annotators) with values 0 to k-1
model = DEEM(n_classes=k)
model.fit(annotator_labels)
consensus_labels = model.predict(annotator_labels)
```

### 2. Ensemble Model Aggregation
Combine predictions from multiple trained classifiers:
```python
# Get predictions from multiple models
predictions = np.column_stack([
    model1.predict(X),
    model2.predict(X),
    model3.predict(X),
    # ... more models
])

# Learn an unsupervised aggregation (no labels needed)
ensemble = DEEM()
ensemble.fit(predictions)
final_predictions = ensemble.predict(predictions)
```
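
When judging a learned aggregation, plain majority voting is a natural baseline. A minimal NumPy sketch, assuming integer labels `0..n_classes-1` and the `-1` convention for missing votes used by DEEM; `majority_vote` is a hypothetical helper, not part of DEEM.

```python
import numpy as np


def majority_vote(predictions, n_classes):
    """Per-sample plurality vote over classifiers; -1 entries are ignored.

    Hypothetical baseline for comparison with DEEM. Ties (and rows with no
    votes at all) resolve to the lowest class ID.
    """
    predictions = np.asarray(predictions)
    votes = np.zeros((predictions.shape[0], n_classes), dtype=np.int64)
    for c in range(n_classes):
        # Count how many classifiers voted for class c on each sample
        votes[:, c] = (predictions == c).sum(axis=1)
    # argmax returns the first maximum, so ties go to the smallest label
    return votes.argmax(axis=1)


preds = np.array([
    [0, 1, 1, 2],
    [2, 2, -1, 0],
    [1, 0, -1, -1],
])
print(majority_vote(preds, n_classes=3))  # [1 2 0]
```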

### 3. Missing Predictions
DEEM automatically handles cases where some classifiers don't provide predictions (use `-1` for missing):
```python
predictions = np.array([
    [0, 1, -1, 2, 1],  # Classifier 3 missing
    [1, 1, 1, -1, 1],  # Classifier 4 missing
    # ...
])
model = DEEM(n_classes=3)
model.fit(predictions)  # Missing values handled automatically
```
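
One common way to present hard labels with gaps to an RBM is to one-hot encode each classifier's vote and leave the block of a missing vote all zeros. The sketch below assumes that encoding and infers `n_classes` as the largest observed label plus one when it is not given; DEEM's internal encoding may differ, and `one_hot_encode` is a hypothetical helper.

```python
import numpy as np


def one_hot_encode(predictions, n_classes=None):
    """Map (n_samples, n_classifiers) labels to (n_samples, n_classifiers, n_classes).

    Entries equal to -1 (missing) become all-zero vectors. Hypothetical helper,
    not part of the DEEM API.
    """
    predictions = np.asarray(predictions, dtype=np.int64)
    if n_classes is None:
        # Infer the number of classes from the largest observed label
        n_classes = int(predictions.max()) + 1
    observed = predictions >= 0
    encoded = np.zeros(predictions.shape + (n_classes,), dtype=np.float32)
    # Set a 1 only at observed entries; missing entries stay zero
    rows, cols = np.nonzero(observed)
    encoded[rows, cols, predictions[observed]] = 1.0
    return encoded


preds = np.array([
    [0, 1, -1, 2, 1],
    [1, 1, 1, -1, 1],
])
enc = one_hot_encode(preds)
print(enc.shape)        # (2, 5, 3)
print(enc[0, 2].sum())  # 0.0 (classifier 3 is missing on the first sample)
```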

## How It Works

DEEM uses **Restricted Boltzmann Machines** (RBMs) - energy-based models that learn the joint probability distribution over classifier predictions and hidden representations. The key insight is that multiple weak classifiers contain complementary information that can be combined through unsupervised learning.
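
For intuition, here is one standard formulation, which is an assumption rather than DEEM's documented parameterization: each classifier's vote is one-hot encoded and concatenated into a visible vector `v`, the hidden units `h` are binary, and the energy is `E(v, h) = -vᵀWh - bᵀv - cᵀh`. Summing out the hidden units gives the free energy `F(v) = -bᵀv - Σⱼ softplus(cⱼ + (Wᵀv)ⱼ)`, so `p(v) ∝ exp(-F(v))`. The NumPy sketch computes both and checks the closed form against brute-force enumeration; `energy` and `free_energy` are illustrative names, not DEEM functions.

```python
import numpy as np


def energy(v, h, W, b, c):
    """Standard RBM energy E(v, h) = -v^T W h - b^T v - c^T h."""
    return -(v @ W @ h) - b @ v - c @ h


def free_energy(v, W, b, c):
    """F(v) = -log sum_h exp(-E(v, h)) for binary hidden units h in {0, 1}^H.

    Summing out each hidden unit gives F(v) = -b^T v - sum_j softplus(c_j + (W^T v)_j).
    """
    return -b @ v - np.sum(np.logaddexp(0.0, c + v @ W))


rng = np.random.default_rng(0)
n_classifiers, n_classes, n_hidden = 4, 3, 2
D = n_classifiers * n_classes  # length of the flattened one-hot visible vector
W = rng.normal(scale=0.1, size=(D, n_hidden))
b = rng.normal(scale=0.1, size=D)
c = rng.normal(scale=0.1, size=n_hidden)

# One sample: the four classifiers voted [0, 2, 2, 1]; one one-hot block per classifier
v = np.eye(n_classes)[[0, 2, 2, 1]].ravel()

# Brute-force check: enumerate all 2^H hidden configurations
hs = [np.array([i, j], dtype=float) for i in (0, 1) for j in (0, 1)]
brute = -np.log(sum(np.exp(-energy(v, h, W, b, c)) for h in hs))
print(np.isclose(brute, free_energy(v, W, b, c)))  # True
```

The closed-form free energy is what makes RBMs practical here: it scores a vote pattern without enumerating hidden states.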

### Key Components

1. **Energy Function**: Models compatibility between visible (predictions) and hidden (consensus) states
2. **Contrastive Divergence**: Trains the RBM using Discrete Langevin Proposal (DLP) or Gibbs-with-Gradients (GWG) sampling
3. **Hungarian Algorithm**: Solves the label permutation problem during evaluation
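
To show the shape of a contrastive divergence update, the sketch below estimates CD-k gradients for an RBM with one-hot visible blocks (one per classifier) and binary hidden units. It deliberately swaps in plain block Gibbs sampling for the negative phase; it does not reproduce the DLP/GWG samplers DEEM uses, and `cd_k_gradients` is a hypothetical name.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def cd_k_gradients(v0, W, b, c, k, n_classifiers, n_classes, rng):
    """CD-k gradient estimate (log-likelihood ascent direction).

    Negative phase uses plain block Gibbs sampling, NOT DLP/GWG.
    v0: (batch, n_classifiers * n_classes) flattened one-hot data.
    """
    if k < 1:
        raise ValueError("k must be at least 1")

    def sample_h(v):
        p = sigmoid(c + v @ W)
        return p, (rng.random(p.shape) < p).astype(float)

    def sample_v(h):
        # Per-classifier logits, shape (batch, n_classifiers, n_classes)
        logits = (b + h @ W.T).reshape(-1, n_classifiers, n_classes)
        logits -= logits.max(axis=2, keepdims=True)
        probs = np.exp(logits)
        probs /= probs.sum(axis=2, keepdims=True)
        # Draw one class per block by inverse-CDF sampling
        u = rng.random(probs.shape[:2] + (1,))
        idx = np.minimum((probs.cumsum(axis=2) < u).sum(axis=2), n_classes - 1)
        return np.eye(n_classes)[idx].reshape(len(h), -1)

    ph0, h = sample_h(v0)
    for _ in range(k):
        vk = sample_v(h)
        phk, h = sample_h(vk)
    n = len(v0)
    # Positive phase (data statistics) minus negative phase (model samples)
    dW = (v0.T @ ph0 - vk.T @ phk) / n
    db = (v0 - vk).mean(axis=0)
    dc = (ph0 - phk).mean(axis=0)
    return dW, db, dc


rng = np.random.default_rng(0)
n_classifiers, n_classes, n_hidden = 5, 3, 1
labels = rng.integers(0, n_classes, size=(64, n_classifiers))
v0 = np.eye(n_classes)[labels].reshape(64, -1)
W = rng.normal(scale=0.01, size=(n_classifiers * n_classes, n_hidden))
b = np.zeros(n_classifiers * n_classes)
c = np.zeros(n_hidden)
for epoch in range(10):
    dW, db, dc = cd_k_gradients(v0, W, b, c, k=10, n_classifiers=n_classifiers,
                                n_classes=n_classes, rng=rng)
    # Plain gradient ascent; DEEM's optimizer also applies momentum
    W += 0.01 * dW
    b += 0.01 * db
    c += 0.01 * dc
```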

### Architecture

```
Classifier Predictions → RBM → Hidden Representation → Consensus Label
     (visible layer)           (hidden layer)
```

## API Reference

### `DEEM`

Main class for ensemble aggregation.

**Parameters:**
- `n_classes` (int, optional): Number of classes. Auto-detected if not specified.
- `hidden_dim` (int, default=1): Number of hidden units.
- `cd_k` (int, default=10): Contrastive divergence steps.
- `deterministic` (bool, default=True): Use probabilities instead of sampling.
- `learning_rate` (float, default=0.001): Learning rate for SGD.
- `momentum` (float, default=0.9): SGD momentum.
- `epochs` (int, default=100): Training epochs.
- `batch_size` (int, default=128): Batch size.
- `device` (str, default='auto'): Device ('cpu', 'cuda', or 'auto').
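- `random_state` (int, optional): Random seed.
- `auto_hyperparameters` (bool, optional): Select hyperparameters automatically with a trained predictor (requires the `automl` extra).
- `model_dir` (str, optional): Path to the trained hyperparameter predictor used when `auto_hyperparameters=True`.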

**Methods:**
- `fit(predictions, labels=None, **kwargs)`: Train the model
- `predict(predictions, return_probs=False)`: Get consensus predictions
- `predict_with_hungarian(predictions, true_labels)`: Predict with label alignment
- `score(predictions, true_labels)`: Compute accuracy with Hungarian alignment
- `save(path)`: Save model to disk
- `load(path)`: Load model from disk
- `get_params()`: Get parameters (sklearn compatibility)
- `set_params(**params)`: Set parameters (sklearn compatibility)

## Advanced Features

### Automatic Hyperparameter Selection

This feature needs the optional `automl` dependencies: `pip install "deem[automl]"`.

```python
model = DEEM(
    auto_hyperparameters=True,
    model_dir='saved_hyp_models_v1'  # Path to trained predictor
)
model.fit(predictions)  # Hyperparameters automatically selected
```

### Save and Load Models

```python
# Save trained model
model.save('my_ensemble.pt')

# Load later
model = DEEM()
model.load('my_ensemble.pt')
predictions = model.predict(new_data)
```

### Soft Labels (Probability Distributions)

DEEM can also work with soft predictions (probability distributions):
```python
# soft_predictions: (n_samples, n_classes, n_classifiers)
model = DEEM(n_classes=3)
model.fit(soft_predictions)
```
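
The sketch below shows two ways to build an array in the `(n_samples, n_classes, n_classifiers)` layout described above: one-hot encoding hard labels, or stacking per-classifier probability matrices (for example, outputs of scikit-learn's `predict_proba`) along a new last axis. `hard_to_soft` is a hypothetical helper, not part of DEEM, and it assumes there are no missing (`-1`) entries.

```python
import numpy as np


def hard_to_soft(predictions, n_classes):
    """Turn (n_samples, n_classifiers) hard labels into one-hot soft
    predictions of shape (n_samples, n_classes, n_classifiers).

    Hypothetical helper, not part of DEEM; assumes no missing (-1) entries.
    """
    predictions = np.asarray(predictions)
    if predictions.min() < 0:
        raise ValueError("hard_to_soft expects labels in 0..n_classes-1 (no -1)")
    one_hot = np.eye(n_classes)[predictions]  # (n_samples, n_classifiers, n_classes)
    return one_hot.transpose(0, 2, 1)         # (n_samples, n_classes, n_classifiers)


# Two classifiers, each giving an (n_samples, n_classes) probability matrix
probs_a = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
probs_b = np.array([[0.6, 0.3, 0.1], [0.3, 0.3, 0.4]])
soft_predictions = np.stack([probs_a, probs_b], axis=2)
print(soft_predictions.shape)  # (2, 3, 2)

print(hard_to_soft([[0, 2], [1, 1]], n_classes=3).shape)  # (2, 3, 2)
```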

## Requirements

- Python >= 3.8
- PyTorch >= 1.9
- NumPy >= 1.19
- SciPy >= 1.7
- entmax >= 1.0

Optional (`pip install "deem[automl]"`):
- scikit-learn >= 0.24, pandas >= 1.3, joblib >= 1.0 (for automatic hyperparameter selection)

## Citation

If you use DEEM in your research, please cite:

```bibtex
@software{deem2026,
  title={DEEM: Deep Ensemble Energy Models for Classifier Aggregation},
  author={Albadri, Maymona},
  year={2026},
  url={https://github.com/Rem4rkable/rbm_python}
}
```

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## Acknowledgments

- Based on research in energy-based models and crowd learning
- Built with PyTorch and inspired by scikit-learn's API design

## Links

- **GitHub**: https://github.com/Rem4rkable/rbm_python
- **Documentation**: [Coming soon]
- **Issues**: https://github.com/Rem4rkable/rbm_python/issues

---

Made with ❤️ for the machine learning community
