Metadata-Version: 2.1
Name: cfnet
Version: 0.0.10
Summary: A counterfactual explanation library using Jax
Home-page: https://github.com/birkhoffg/cfnet/tree/master/
Author: BirkhoffG
Author-email: 26811230+BirkhoffG@users.noreply.github.com
License: Apache Software License 2.0
Keywords: some keywords
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Natural Language :: English
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: License :: OSI Approved :: Apache Software License
Requires-Python: >=3.7
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: matplotlib
Requires-Dist: scikit-learn
Requires-Dist: pandas
Requires-Dist: nbdev
Requires-Dist: jupyter
Requires-Dist: dm-haiku
Requires-Dist: test-tube
Requires-Dist: jax[cpu]
Requires-Dist: tqdm
Requires-Dist: optax
Requires-Dist: pydantic (<2,>=1.9.0)
Requires-Dist: deprecation
Provides-Extra: dev
Requires-Dist: torch (>=1.7.0) ; extra == 'dev'

Welcome to cfnet
================


## Key Features

- **fast**: code runs significantly faster than existing counterfactual
  (CF) explanation libraries.
- **scalable**: code can be accelerated on *CPU*, *GPU*, and *TPU*.
- **flexible**: we provide a flexible API that allows researchers full
  customization.

TODO: implement various methods of CF explanations.

## Install

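`cfnet` is built on top of
[Jax](https://jax.readthedocs.io/en/latest/index.html). It also uses
[PyTorch](https://pytorch.org/) to load data. PyTorch (`torch>=1.7.0`)
is declared only in the optional `dev` extra, so install it separately
or via `pip install "cfnet[dev]"`.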

### Running on CPU

If you only need to run `cfnet` on CPU, you can install it via `pip` or
clone the GitHub repository.

Installation via PyPI:

``` bash
pip install cfnet
```

Editable install:

``` bash
git clone https://github.com/BirkhoffG/cfnet.git
pip install -e cfnet
```

### Running on GPU or TPU

To run `cfnet` on GPU or TPU, first install the library via
`pip install cfnet`.

Then install the GPU or TPU build of Jax that matches your hardware by
following the [installation
guide](https://github.com/google/jax#installation).
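After installing an accelerator build of Jax, a quick sanity check like the one below can confirm that Jax actually sees your GPU or TPU. It is a minimal sketch that uses only standard Jax calls (`jax.default_backend`, `jax.devices`) and does not depend on `cfnet`; the backend and device names it prints depend on your hardware and Jax version.

```python
import jax
import jax.numpy as jnp

# Name of the backend Jax selected by default (e.g. "cpu" on a CPU-only install).
print("default backend:", jax.default_backend())

# Every device Jax can see; on a working GPU/TPU install these should not all be CPU devices.
print("devices:", jax.devices())

# Run a tiny computation on the default device to confirm the backend works end to end.
x = jnp.ones((3, 3))
y = (x @ x).block_until_ready()
print(y)
```

If the reported backend is still the CPU after installing the accelerator build, the Jax wheel most likely does not match your CUDA or TPU setup.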

## A Minimal Example

``` python
from jax import random
from cfnet.datasets import TabularDataModule
# `cfnet.training_module.PredictiveTrainingModule` is deprecated since 0.0.7
from cfnet.module import PredictiveTrainingModule
from cfnet.train import train_model
from cfnet.methods import VanillaCF
from cfnet.evaluate import generate_cf_results_local_exp, benchmark_cfs

# dataset location, feature columns, and features the CF must not change
data_configs = {
    "data_dir": "assets/data/s_adult.csv",
    "data_name": "adult",
    "batch_size": 256,
    "continous_cols": ["age","hours_per_week"],
    "discret_cols": ["workclass","education","marital_status","occupation","race","gender"],
    "imutable_cols": ["race","gender"]
}
# predictive model: learning rate, layer sizes, dropout rate
m_configs = {
    'lr': 0.003,
    "sizes": [50, 10, 50],
    "dropout_rate": 0.3
}
# training: epochs, monitored metric, logger name, random seed, batch size
t_configs = {
    'n_epochs': 10,
    'monitor_metrics': 'val/val_loss',
    'logger_name': 'pred',
    'seed': 42,
    "batch_size": 256
}
# CF search: number of optimization steps and learning rate
cf_configs = {
    'n_steps': 1000,
    'lr': 0.001
}

# load data
dm = TabularDataModule(data_configs)

# specify the ML model
training_module = PredictiveTrainingModule(m_configs)

# train ML model
params, opt_state = train_model(
    training_module, dm, t_configs
)

# define CF explanation module
pred_fn = lambda x: training_module.forward(
    params, random.PRNGKey(0), x, is_training=False)
cf_exp = VanillaCF(cf_configs)

# generate CF explanations
cf_results = generate_cf_results_local_exp(cf_exp, dm, pred_fn)

# benchmark CF explanation methods (here, only VanillaCF)
benchmark_cfs([cf_results])
```

    Epoch 9: 100%|██████████| 96/96 [00:01<00:00, 57.03batch/s, train/train_loss_1=0.0485]
    100%|██████████| 1000/1000 [00:08<00:00, 124.53it/s]

| dataset | method    | acc      | validity | proximity |
|---------|-----------|----------|----------|-----------|
| adult   | VanillaCF | 0.826188 | 0.883675 | 7.05637   |
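The benchmark reports `acc` (presumably the predictive model's accuracy) alongside `validity` and `proximity` of the generated counterfactuals. To make those two CF columns concrete, here is a toy, NumPy-only sketch of the same generate-then-evaluate pipeline. It assumes that `VanillaCF` follows the common gradient-based recipe: start from the input and take `n_steps` gradient steps with learning rate `lr` on a prediction loss plus a distance penalty. It also uses common textbook definitions: validity as the fraction of CFs whose predicted class differs from the original, and proximity as the mean L1 distance to the input. The linear classifier, the penalty weight `lam`, and the helper names `toy_vanilla_cf`, `validity`, and `proximity` are hypothetical. cfnet's own implementation and metrics may differ, for example in the distance used or in how features are encoded.

```python
import numpy as np

# Hypothetical linear classifier standing in for `pred_fn` in the example above.
W = np.array([2.0, -1.0])
B = 0.0


def pred_fn(x: np.ndarray) -> np.ndarray:
    """Return the probability of class 1 for each row of `x`."""
    return 1.0 / (1.0 + np.exp(-(x @ W + B)))


def toy_vanilla_cf(x, n_steps=1000, lr=0.1, lam=0.1):
    """Hypothetical gradient-based CF search (not cfnet's implementation).

    Minimizes BCE(pred_fn(cf), flipped label) + lam * ||cf - x||_2^2,
    starting from cf = x.
    """
    y_target = 1.0 - np.round(pred_fn(x))  # ask for the opposite class
    cf = x.copy()
    for _ in range(n_steps):
        p = pred_fn(cf)
        # Gradient of binary cross-entropy w.r.t. the input of a logistic model.
        grad_pred = (p - y_target)[:, None] * W[None, :]
        # Gradient of the squared-L2 distance penalty.
        grad_prox = 2.0 * lam * (cf - x)
        cf = cf - lr * (grad_pred + grad_prox)
    return cf


def validity(x, cf):
    """Fraction of CFs whose predicted class differs from the original's."""
    return float(np.mean(np.round(pred_fn(x)) != np.round(pred_fn(cf))))


def proximity(x, cf):
    """Mean L1 distance between each input and its CF."""
    return float(np.mean(np.abs(x - cf).sum(axis=1)))


X = np.array([[1.0, 1.0], [-1.0, 0.5]])
CF = toy_vanilla_cf(X)
print(f"validity={validity(X, CF):.3f}, proximity={proximity(X, CF):.3f}")
```

Because the distance penalty pulls each counterfactual back toward its input, increasing `lam` in this sketch generally trades validity for proximity.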
