Metadata-Version: 2.3
Name: cmonge
Version: 0.1.1
Summary: Extension of the Monge Gap to learn conditional optimal transport maps
Keywords: Machine Learning,Optimal Transport,Neural OT,Monge Gap,Conditional Distribution Learning
Author: Alice Driessen
Author-email: adr@zurich.ibm.com
Requires-Python: >=3.10,<3.12
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Dist: anndata (>=0.10.5.post1,<0.11.0)
Requires-Dist: black (>=24.4.2,<25.0.0)
Requires-Dist: chex (>=0.1.85,<0.2.0)
Requires-Dist: dotmap (>=1.3.30,<2.0.0)
Requires-Dist: flax (>=0.10.2,<0.11.0)
Requires-Dist: isort (>=5.13.2,<6.0.0)
Requires-Dist: jax (>=0.4.36,<0.5.0)
Requires-Dist: jaxlib (>=0.4.36,<0.5.0)
Requires-Dist: loguru (>=0.7.2,<0.8.0)
Requires-Dist: optax (>=0.2.4,<0.3.0)
Requires-Dist: optuna (>=3.5.0,<4.0.0)
Requires-Dist: ott-jax (>=0.5.0,<0.6.0)
Requires-Dist: pandas (>=2.0.0,<3.0.0)
Requires-Dist: rdkit (>=2023.9.5,<2024.0.0)
Requires-Dist: ruff (>=0.5.4,<0.6.0)
Requires-Dist: scanpy (>=1.9.8,<2.0.0)
Requires-Dist: scikit-learn (>=1.4.0,<2.0.0)
Requires-Dist: scipy (==1.12.0)
Requires-Dist: seaborn (>=0.13.2,<0.14.0)
Requires-Dist: typer (>=0.9.0,<0.10.0)
Requires-Dist: types-pyyaml (>=6.0.12.20240311,<7.0.0.0)
Requires-Dist: umap-learn (>=0.5.5,<0.6.0)
Project-URL: Homepage, https://github.com/AI4SCR/conditional-monge
Project-URL: Repository, https://github.com/AI4SCR/conditional-monge
Description-Content-Type: text/markdown

# Conditional Monge Gap

[![CI](https://github.com/AI4SCR/conditional-monge/actions/workflows/ci.yml/badge.svg)](https://github.com/AI4SCR/conditional-monge/actions/workflows/ci.yml)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

![Overview of the Conditional Monge Gap](assets/overview.jpg)

An extension of the [Monge Gap](https://proceedings.mlr.press/v202/uscidda23a.html) that learns optimal transport maps conditioned on arbitrary context vectors. It relies on a two-step training procedure that combines an encoder-decoder architecture with an OT estimator. The model is applied to [4i](https://pubmed.ncbi.nlm.nih.gov/30072512/) and [scRNA-seq](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7289078/) datasets.
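
The Monge gap of a map $T$ on a source measure $\mu$ is the average cost of moving each point $x$ to $T(x)$ minus the optimal transport cost between $\mu$ and its push-forward $T_\sharp\mu$. It is non-negative and vanishes exactly when $T$ is a $c$-optimal map from $\mu$ to $T_\sharp\mu$. The NumPy sketch below estimates it from samples, assuming a squared-Euclidean cost, uniform weights, and a log-domain Sinkhorn solver whose plan cost stands in for the exact OT cost. It is an illustrative re-implementation, not the cmonge or ott-jax code; `sinkhorn_plan`, `monge_gap_from_samples` and the default `epsilon` are hypothetical.

```py
import numpy as np


def _logsumexp(a: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log-sum-exp along one axis."""
    a_max = np.max(a, axis=axis, keepdims=True)
    out = a_max + np.log(np.sum(np.exp(a - a_max), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def sinkhorn_plan(cost: np.ndarray, epsilon: float = 0.1, n_iter: int = 1000) -> np.ndarray:
    """Entropic OT plan between two uniform point clouds (log-domain Sinkhorn)."""
    n, m = cost.shape
    log_a = np.full(n, -np.log(n))  # log of uniform source weights
    log_b = np.full(m, -np.log(m))  # log of uniform target weights
    f, g = np.zeros(n), np.zeros(m)  # dual potentials
    for _ in range(n_iter):
        # Alternately update the potentials so the plan matches each marginal.
        f = epsilon * (log_a - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon)


def monge_gap_from_samples(source: np.ndarray, mapped: np.ndarray, epsilon: float = 0.1) -> float:
    """Monge gap of a map T, given samples x_i and their images T(x_i)."""
    # Pairwise squared-Euclidean costs c(x_i, T(x_j)).
    cost = np.sum((source[:, None, :] - mapped[None, :, :]) ** 2, axis=-1)
    displacement = np.mean(np.diag(cost))  # average c(x_i, T(x_i))
    plan = sinkhorn_plan(cost, epsilon)
    ot_cost = np.sum(plan * cost)  # entropic estimate of W_c(mu, T#mu)
    return float(displacement - ot_cost)


# Usage: a 4x4 grid of source points centered at the origin.
grid = np.array([[i, j] for i in (-1.5, -0.5, 0.5, 1.5) for j in (-1.5, -0.5, 0.5, 1.5)])
print(monge_gap_from_samples(grid, grid))   # identity map: gap close to 0
print(monge_gap_from_samples(grid, -grid))  # point reflection: large positive gap
```

On this toy grid the identity map, which is trivially optimal, gets a gap near zero, while the reflection $x \mapsto -x$ maps the grid onto itself at a high displacement cost and therefore gets a gap of about 10.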

## Installation from PyPI

You can install this package from PyPI with pip:
```sh
pip install cmonge
```

## Development setup & installation
The package environment is managed by [poetry](https://python-poetry.org/docs/managing-environments/).
The code was tested with Python 3.10; the package metadata declares support for Python 3.10 and 3.11.
```sh
pip install poetry
git clone git@github.com:AI4SCR/conditional-monge.git
cd conditional-monge
poetry install -v
```

If the installation was successful, you can run the tests with pytest:
```sh
poetry run pytest  # run the test suite inside the Poetry environment
```

## Data

Preprocessed versions of the Sciplex3 and 4i datasets can be downloaded from the [ETH Zurich Research Collection](https://www.research-collection.ethz.ch/handle/20.500.11850/609681).


## Example usage

An example configuration is provided in `configs/conditional-monge-sciplex.yml`.
To train an autoencoder model:
```py
from pathlib import Path

from cmonge.datasets.conditional_loader import ConditionalDataModule
from cmonge.trainers.ae_trainer import AETrainerModule
from cmonge.utils import load_config

config_path = Path("configs/conditional-monge-sciplex.yml")
config = load_config(config_path)
config.data.ae = True  # flag the data config for autoencoder training

datamodule = ConditionalDataModule(config.data, config.condition)
ae_trainer = AETrainerModule(config.ae)

ae_trainer.train(datamodule)
ae_trainer.evaluate(datamodule)
```

To train a conditional Monge model:

```py
from pathlib import Path

from cmonge.datasets.conditional_loader import ConditionalDataModule
from cmonge.trainers.conditional_monge_trainer import ConditionalMongeTrainer
from cmonge.utils import load_config

config_path = Path("configs/conditional-monge-sciplex.yml")
logger_path = Path("logs")  # directory for training logs
config = load_config(config_path)

datamodule = ConditionalDataModule(config.data, config.condition)
trainer = ConditionalMongeTrainer(
    jobid=1, logger_path=logger_path, config=config.model, datamodule=datamodule
)

trainer.train(datamodule)
trainer.evaluate(datamodule)
```
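
Schematically, and following the unconditional formulation of Uscidda & Cuturi (2023), a Monge-gap-regularized conditional map can be read as solving the problem below. This is a sketch in our own notation: the condition index $k$, the context vector $z_k$, the fitting loss $\Delta$ (for example a Sinkhorn divergence) and the weight $\lambda$ are assumptions, and the exact loss implemented by `ConditionalMongeTrainer` may differ.

```math
\mathcal{M}_{\mu}(T) = \int c\big(x, T(x)\big)\,\mathrm{d}\mu(x) - W_c\big(\mu, T_\sharp\mu\big),
\qquad
\min_{\theta}\; \mathbb{E}_{k}\Big[\Delta\big(T_\theta(\cdot, z_k)_\sharp\mu_k,\, \nu_k\big) + \lambda\, \mathcal{M}_{\mu_k}\big(T_\theta(\cdot, z_k)\big)\Big]
```

Here $\mu_k$ and $\nu_k$ are the source (e.g. control) and target (e.g. perturbed) cell populations for condition $k$, and $c$ is the ground cost. The first term pulls the pushed-forward source towards the target, while the Monge gap, which is non-negative because $(\mathrm{Id}, T)_\sharp\mu$ is a valid coupling, favors maps that move mass optimally with respect to $c$.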

## Loading older checkpoints
To load model weights from older checkpoints (cmonge-{moa, rdkit}-ood or cmonge-{moa, rdkit}-homogeneous), check out the tag `cmonge_checkpoint_loading`:

```sh
git checkout cmonge_checkpoint_loading
```

## Citation
If you use the package, please cite:
```bib
@inproceedings{
  harsanyi2024learning,
  title={Learning Drug Perturbations via Conditional Map Estimators},
  author={Benedek Harsanyi and Marianna Rapsomaniki and Jannis Born},
  booktitle={ICLR 2024 Workshop on Machine Learning for Genomics Explorations},
  year={2024},
  url={https://openreview.net/forum?id=FE7lRuwmfI}
}
```

