Metadata-Version: 2.4
Name: factrainer
Version: 0.1.24
Summary: Framework Agnostic Cross-validation Trainer
Project-URL: Homepage, https://github.com/ritsuki1227/factrainer/
Project-URL: Documentation, https://ritsuki1227.github.io/factrainer/stable/
Project-URL: Repository, https://github.com/ritsuki1227/factrainer/
Project-URL: Bug Tracker, https://github.com/ritsuki1227/factrainer/issues/
Author-email: ritsuki1227 <ritsuki1227@gmail.com>
License-Expression: MIT
License-File: LICENSE
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Natural Language :: English
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.12
Requires-Dist: factrainer-core
Provides-Extra: all
Requires-Dist: factrainer-catboost; extra == 'all'
Requires-Dist: factrainer-lightgbm; extra == 'all'
Requires-Dist: factrainer-sklearn; extra == 'all'
Requires-Dist: factrainer-xgboost; extra == 'all'
Provides-Extra: catboost
Requires-Dist: factrainer-catboost; extra == 'catboost'
Provides-Extra: lightgbm
Requires-Dist: factrainer-lightgbm; extra == 'lightgbm'
Provides-Extra: sklearn
Requires-Dist: factrainer-sklearn; extra == 'sklearn'
Provides-Extra: xgboost
Requires-Dist: factrainer-xgboost; extra == 'xgboost'
Description-Content-Type: text/markdown

# Factrainer

![CI](https://github.com/ritsuki1227/factrainer/actions/workflows/ci.yaml/badge.svg)
[![codecov](https://codecov.io/gh/ritsuki1227/factrainer/branch/main/graph/badge.svg)](https://codecov.io/gh/ritsuki1227/factrainer)
[![PyPI](https://img.shields.io/pypi/v/factrainer.svg)](https://pypi.python.org/project/factrainer)
[![image](https://img.shields.io/pypi/pyversions/factrainer.svg)](https://pypi.python.org/pypi/factrainer)
![License](https://img.shields.io/github/license/ritsuki1227/factrainer.svg)
![Stars](https://img.shields.io/github/stars/ritsuki1227/factrainer.svg?style=social)

**Factrainer** (Framework Agnostic Cross-validation Trainer) is a machine learning tool that provides a flexible cross-validation training framework. It addresses the limitations of existing cross-validation utilities in popular ML libraries by offering a unified, parallelized approach that retains models and yields out-of-fold (OOF) predictions.

**Documentation**: For detailed documentation, please visit [https://ritsuki1227.github.io/factrainer/](https://ritsuki1227.github.io/factrainer/)

## Why Use Factrainer?

Various ML frameworks (e.g., Scikit-learn, LightGBM) offer cross-validation functions. However, each has different features and interfaces. The table below highlights some widely used cross-validation APIs and which capabilities they support:

| Framework    | API                                                                                                                     | OOF prediction | return trained models | parallel training |
| ------------ | ----------------------------------------------------------------------------------------------------------------------- | :------------: | :-------------------: | :---------------: |
| LightGBM     | [`lgb.cv`](https://lightgbm.readthedocs.io/en/stable/pythonapi/lightgbm.cv.html)                                        |       🚫       |          ✅️          |        🚫         |
| Scikit-learn | [`GridSearchCV`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html)           |       🚫       |          🚫           |        ✅️        |
| Scikit-learn | [`cross_val_score`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_score.html)     |       🚫       |          🚫           |        ✅️        |
| Scikit-learn | [`cross_val_predict`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html) |      ✅️       |          🚫           |        ✅️        |
| Scikit-learn | [`cross_validate`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_validate.html)       |       🚫       |          ✅️          |        ✅️        |

No built-in API combines OOF predictions, trained-model access, and parallelized training—Factrainer does.
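The three capabilities are easier to see in a concrete loop. This is a dependency-free sketch of the job a cross-validation trainer performs: split the rows into K folds, fit one model per fold concurrently, keep every fitted model, and assemble out-of-fold predictions. It is not Factrainer's implementation. `LineModel`, `k_fold_indices`, and `cross_validate_oof` are hypothetical names, and the one-feature least-squares model stands in for a real LightGBM or Scikit-learn estimator.

```python
import random
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass


@dataclass
class LineModel:
    """Hypothetical one-feature least-squares model (stand-in for a real estimator)."""
    slope: float = 0.0
    intercept: float = 0.0

    def fit(self, xs: list[float], ys: list[float]) -> "LineModel":
        mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
        sxx = sum((x - mx) ** 2 for x in xs)
        sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
        self.slope = sxy / sxx if sxx else 0.0
        self.intercept = my - self.slope * mx
        return self

    def predict(self, xs: list[float]) -> list[float]:
        return [self.slope * x + self.intercept for x in xs]


def k_fold_indices(n: int, k: int, seed: int = 1) -> list[list[int]]:
    """Shuffle row indices, then deal them into k disjoint validation folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]


def cross_validate_oof(
    xs: list[float], ys: list[float], k: int = 4, n_jobs: int = 4
) -> tuple[list[LineModel], list[float]]:
    folds = k_fold_indices(len(xs), k)

    def fit_fold(val_idx: list[int]) -> LineModel:
        held_out = set(val_idx)
        train_idx = [i for i in range(len(xs)) if i not in held_out]
        return LineModel().fit([xs[i] for i in train_idx], [ys[i] for i in train_idx])

    # Folds are independent, so they can be fitted concurrently. Threads keep the
    # sketch simple; pure-Python fitting gains little from them because of the GIL.
    with ThreadPoolExecutor(max_workers=n_jobs) as pool:
        models = list(pool.map(fit_fold, folds))  # one retained model per fold

    # Out-of-fold predictions: each row is predicted only by the model
    # that never saw it during training.
    oof = [0.0] * len(xs)
    for model, val_idx in zip(models, folds):
        for i, pred in zip(val_idx, model.predict([xs[i] for i in val_idx])):
            oof[i] = pred
    return models, oof


xs = [float(i) for i in range(20)]
ys = [2.0 * x + 1.0 for x in xs]
models, oof = cross_validate_oof(xs, ys, k=4)
print(len(models), [round(p, 3) for p in oof[:3]])  # → 4 [1.0, 3.0, 5.0]
```

Each built-in API in the table covers only part of this loop. A cross-validation trainer's value lies in running all of it behind one call.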

## Key Features

- **Unified Cross-Validation API** – Provides a single, consistent interface for K-fold (or any other cross-validation scheme) training, acting as a meta-framework that wraps multiple ML libraries.
- **Parallelized Training** – Run cross-validation folds in parallel to fully utilize multi-core CPUs and speed up model training.
- **Mutable Model Container** – Access each fold’s trained model as a mutable object. This makes it easy to analyze models or create ensembles from the fold models.
- **Out-of-Fold Predictions** – Retrieve out-of-fold predictions for every training instance through a simple API.
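Keeping every fold's model makes a simple ensemble easy to build: average the fold models' predictions on unseen data. The sketch below shows only that averaging step in plain Python. It does not use Factrainer's API. `Regressor`, `average_fold_predictions`, and `ConstantModel` are hypothetical names, and the constant models are toy stand-ins for real fitted fold models.

```python
from statistics import fmean
from typing import Protocol, Sequence


class Regressor(Protocol):
    """Anything with a predict method; stands in for a fitted fold model."""
    def predict(self, xs: Sequence[float]) -> list[float]: ...


def average_fold_predictions(models: Sequence[Regressor], xs: Sequence[float]) -> list[float]:
    """Mean ensemble: average each row's prediction across all fold models."""
    per_model = [m.predict(xs) for m in models]
    return [fmean(row) for row in zip(*per_model)]


class ConstantModel:
    """Hypothetical toy model that always predicts the same value."""
    def __init__(self, value: float) -> None:
        self.value = value

    def predict(self, xs: Sequence[float]) -> list[float]:
        return [self.value for _ in xs]


fold_models = [ConstantModel(v) for v in (1.0, 2.0, 3.0, 6.0)]
print(average_fold_predictions(fold_models, [0.0, 10.0]))  # → [3.0, 3.0]
```

A plain mean is the simplest choice. Weighting folds, for example by their validation scores, is a common variation.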

## Installation

To install with LightGBM and Scikit-learn support:

```sh
pip install "factrainer[lightgbm,sklearn]"
```

To install with all supported backends (LightGBM, Scikit-learn, XGBoost, and CatBoost):

```sh
pip install "factrainer[all]"
```

At present, LightGBM and Scikit-learn are the primary supported backends. Support for other frameworks will be expanded as the project evolves.

## Get Started

Code example: 4-fold cross-validated LightGBM regression on the **California Housing dataset**, scored on its OOF predictions.

```python
import lightgbm as lgb
from sklearn.datasets import fetch_california_housing
from sklearn.metrics import r2_score
from sklearn.model_selection import KFold
from factrainer.core import CvModelContainer
from factrainer.lightgbm import LgbDataset, LgbModelConfig, LgbTrainConfig

# Load the data and wrap it in Factrainer's LightGBM dataset type
data = fetch_california_housing()
dataset = LgbDataset(
    dataset=lgb.Dataset(data.data, label=data.target)
)

# LightGBM training settings applied to every fold
config = LgbModelConfig.create(
    train_config=LgbTrainConfig(
        params={"objective": "regression"},
        callbacks=[lgb.early_stopping(100, verbose=False)],
    ),
)

# 4-fold CV, training the folds in parallel
k_fold = KFold(n_splits=4, shuffle=True, random_state=1)
model = CvModelContainer(config, k_fold)
model.train(dataset, n_jobs=4)

# Trained per-fold models
model.raw_model

# OOF predictions: each row is predicted by the model that did not train on it
y_pred = model.predict(dataset, n_jobs=4)
print(r2_score(data.target, y_pred))
```

## Project Status

Factrainer is in active development. The goal is to expand support to more frameworks and make the tool even more robust. Contributions, issues, and feedback are welcome to help shape the future of Factrainer.
