Metadata-Version: 2.4
Name: jaxflow-lib
Version: 0.1.0
Summary: High-performance, PyTree-native data loading and processing for JAX/Flax
Author-email: Prabhnoor Singh <prabhnoors093@gmail.com>
License: MIT
Project-URL: Homepage, https://github.com/prabhnoors12/jaxflow
Project-URL: Documentation, https://github.com/prabhnoors12/jaxflow#readme
Project-URL: Repository, https://github.com/prabhnoors12/jaxflow
Project-URL: Issues, https://github.com/prabhnoors12/jaxflow/issues
Keywords: jax,flax,dataloader,dataset,machine-learning,deep-learning,neural-networks,pytree,performance
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Classifier: Operating System :: OS Independent
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: jax>=0.4.0
Requires-Dist: flax>=0.7.0
Requires-Dist: numpy
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Requires-Dist: black; extra == "dev"
Requires-Dist: isort; extra == "dev"
Requires-Dist: mypy; extra == "dev"
Requires-Dist: mkdocs; extra == "dev"
Provides-Extra: viz
Requires-Dist: matplotlib; extra == "viz"
Requires-Dist: seaborn; extra == "viz"
Provides-Extra: all
Requires-Dist: jaxflow-lib[dev,viz]; extra == "all"
Dynamic: license-file

# JaxFlow

[![PyPI version](https://badge.fury.io/py/jaxflow-lib.svg)](https://badge.fury.io/py/jaxflow-lib)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)

**JaxFlow** is a high-performance, PyTree-native data loading and processing library designed specifically for the JAX and Flax ecosystem.

Unlike generic data loaders, JaxFlow is built from the ground up for JAX's specific needs: handling arbitrary PyTrees, efficiently prefetching batches to devices (GPU/TPU), and integrating seamlessly with `jax.jit` and `jax.pmap`.
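To make "PyTree-native" concrete: a loader of this kind has to shuffle sample indices, group them into fixed-size batches, and stack each leaf of the per-sample PyTrees along a new leading axis. The hypothetical `iterate_batches` helper below sketches those steps with NumPy and `jax.tree_util`. It illustrates the idea under that assumption and is not JaxFlow's actual `Loader` implementation.

```python
import jax
import numpy as np


def iterate_batches(samples, batch_size, shuffle=True, drop_last=True, seed=0):
    """Yield batched PyTrees from a list of per-sample PyTrees.

    Hypothetical helper for illustration only; not part of jaxflow.
    """
    indices = np.arange(len(samples))
    if shuffle:
        # A seeded generator makes the epoch order reproducible.
        np.random.default_rng(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        if drop_last and len(batch_idx) < batch_size:
            break  # Skip the final partial batch so every batch has the same shape.
        chunk = [samples[i] for i in batch_idx]
        # tree_map zips matching leaves across samples and stacks them on a new
        # leading batch axis, preserving the dict/tuple/list structure.
        yield jax.tree_util.tree_map(lambda *leaves: np.stack(leaves), *chunk)


samples = [{"image": np.zeros((28, 28, 1), np.float32), "label": i % 10} for i in range(100)]
batches = list(iterate_batches(samples, batch_size=32))
print(len(batches), batches[0]["image"].shape, batches[0]["label"].shape)
```

With 100 samples and `batch_size=32`, `drop_last=True` yields three full batches and skips the last four samples. Keeping every batch the same shape means a `jax.jit`-compiled step is not recompiled for a smaller final batch.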

## 🚀 Key Features

*   **PyTree Native**: Data loaders yield PyTrees (dicts, tuples, lists, custom classes) directly, ready for `jax.tree_util.tree_map`.
*   **JAX Device Prefetching**: Automatically prefetches upcoming batches to the target device (GPU/TPU) so that host-to-device transfers overlap with computation instead of stalling the training loop.
*   **Torch-like API**: Familiar `Dataset` and `Loader` API for those coming from PyTorch, but optimized for JAX.
*   **Multiprocessing**: Robust multiprocessing workers for parallel data loading and augmentation.
*   **Composability**: Flexible `transforms` module for composing image and data augmentations.
*   **Visualization**: Built-in tools in `jaxflow.viz` to quickly inspect batches and training curves.
*   **CLI Tools**: Includes a command-line interface for benchmarking system performance (`python -m jaxflow.cli benchmark`).
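Device prefetching generally means keeping a small buffer of batches already copied to the accelerator while the current step runs. The sketch below shows that pattern with `jax.device_put` and a `collections.deque`. `prefetch_to_device` here is a hypothetical helper (a related, `pmap`-oriented utility exists in Flax as `flax.jax_utils.prefetch_to_device`), and JaxFlow's own implementation may differ.

```python
import collections
import itertools

import jax
import numpy as np


def prefetch_to_device(batches, size=2, device=None):
    """Keep up to `size` batches resident on `device` ahead of the consumer.

    Hypothetical helper for illustration; not part of jaxflow's API.
    """
    device = device or jax.devices()[0]
    queue = collections.deque()
    batches = iter(batches)

    def enqueue(n):
        # device_put accepts a whole PyTree; JAX dispatches the copy
        # asynchronously, so it can overlap with device work already queued.
        for batch in itertools.islice(batches, n):
            queue.append(jax.device_put(batch, device))

    enqueue(size)  # Fill the buffer before the first batch is consumed.
    while queue:
        yield queue.popleft()
        enqueue(1)  # Refill one slot each time a batch is handed out.


# Usage: wrap any iterator of host-side (NumPy) batches.
host_batches = ({"x": np.full((4, 3), i, np.float32)} for i in range(5))
for batch in prefetch_to_device(host_batches, size=2):
    print(float(batch["x"].sum()))
```

A buffer of two is usually enough: one batch is being consumed while the next is already on its way to the device.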

## 📦 Installation

```bash
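# Install from PyPI (distribution name: jaxflow-lib; import name: jaxflow)
pip install jaxflow-lib

# Install with visualization support (quoted so shells such as zsh don't expand the brackets)
pip install "jaxflow-lib[viz]"

# Install from source (run from the repository root)
pip install .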
```

## ⚡ Quick Start

Here's how to create a simple dataset and iterate over it:

```python
import jax.numpy as jnp
import numpy as np
from jaxflow import Dataset, Loader, transforms

# 1. Define a custom dataset
class RandomDataset(Dataset):
    def __init__(self, length=1000):
        self.length = length
        self.transform = transforms.Compose([
            transforms.ToArray(dtype=jnp.float32),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Return a dict (PyTree)
        image = np.random.rand(28, 28, 1)
        label = np.random.randint(0, 10)
        return {
            "image": self.transform(image),
            "label": label
        }

# 2. Create a loader. The __main__ guard is required when num_workers > 0 on
#    platforms that start worker processes with "spawn" (macOS, Windows).
if __name__ == "__main__":
    dataset = RandomDataset()
    loader = Loader(
        dataset,
        batch_size=32,
        shuffle=True,
        num_workers=2,
        drop_last=True,
    )

    # 3. Iterate (batches are automatically prefetched to device if available)
    print("Starting training loop...")
    for batch in loader:
        images = batch["image"]  # Shape: (32, 28, 28, 1)
        labels = batch["label"]  # Shape: (32,)

        # Your JAX training step here...
        # params = train_step(params, images, labels)

        print(f"Batch shape: {images.shape}")
        break
```
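The quick-start loop leaves `train_step` undefined. As one possible shape for it, here is a minimal jitted SGD step for a linear softmax classifier in plain JAX. `init_params`, `loss_fn`, and `train_step` are hypothetical names, the model is an assumption chosen for brevity, and this version returns the loss alongside the updated parameters. Random NumPy arrays with the same shapes as the loader's batches stand in for real data.

```python
import jax
import jax.numpy as jnp
import numpy as np


def init_params(key, num_features=28 * 28, num_classes=10):
    # Hypothetical linear classifier: a weight matrix and a bias vector.
    return {
        "w": 0.01 * jax.random.normal(key, (num_features, num_classes)),
        "b": jnp.zeros((num_classes,)),
    }


def loss_fn(params, images, labels):
    # Flatten (batch, 28, 28, 1) images to (batch, 784) and compute logits.
    logits = images.reshape(images.shape[0], -1) @ params["w"] + params["b"]
    log_probs = jax.nn.log_softmax(logits)
    # Mean negative log-likelihood of the correct class.
    return -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=1))


@jax.jit
def train_step(params, images, labels, lr=0.01):
    loss, grads = jax.value_and_grad(loss_fn)(params, images, labels)
    # Plain SGD applied leaf by leaf across the params PyTree.
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss


rng = np.random.default_rng(0)
images = rng.random((32, 28, 28, 1), dtype=np.float32)
labels = rng.integers(0, 10, size=(32,)).astype(np.int32)
params = init_params(jax.random.PRNGKey(0))
for step in range(3):
    params, loss = train_step(params, images, labels)
    print(step, float(loss))
```

Because `params` is a PyTree, the update is a single `jax.tree_util.tree_map` over parameters and gradients; in practice an optimizer library such as Optax would usually take its place.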

## 🛠️ CLI Usage

JaxFlow comes with a handy CLI to check your environment and run benchmarks.

```bash
# Check system info
jaxflow info --json

# Run a matrix multiplication benchmark to test device performance
jaxflow benchmark --device gpu --size 4096 --iters 100
```
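The README does not describe how `jaxflow benchmark` measures performance. Assuming it follows the usual pattern, a matrix multiplication benchmark in JAX looks roughly like the sketch below: compile once, time repeated calls while blocking on each result (JAX dispatches work asynchronously), and convert the time to throughput using the roughly `2 * n**3` floating-point operations of an `n x n` matmul. `benchmark_matmul` is a hypothetical function, not part of JaxFlow; `jax.default_backend()` and `jax.devices()` report the kind of device information an `info` command might show.

```python
import time

import jax
import jax.numpy as jnp


def benchmark_matmul(size=1024, iters=10, dtype=jnp.float32):
    """Time a jitted square matmul on JAX's default device.

    Hypothetical sketch; not the implementation of `jaxflow benchmark`.
    """
    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    a = jax.random.normal(key_a, (size, size), dtype=dtype)
    b = jax.random.normal(key_b, (size, size), dtype=dtype)
    matmul = jax.jit(jnp.matmul)

    # Warm-up call: triggers compilation so it is excluded from the timing.
    matmul(a, b).block_until_ready()

    start = time.perf_counter()
    for _ in range(iters):
        # Block on each result so the timing covers the computation itself,
        # not just the asynchronous dispatch.
        matmul(a, b).block_until_ready()
    seconds_per_iter = (time.perf_counter() - start) / iters

    # An n x n by n x n matmul performs about 2 * n**3 floating-point operations.
    gflops = 2 * size**3 / seconds_per_iter / 1e9
    return seconds_per_iter, gflops


print(jax.default_backend(), jax.devices())
secs, gflops = benchmark_matmul(size=512, iters=5)
print(f"{secs * 1e3:.2f} ms/iter, {gflops:.1f} GFLOP/s")
```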

## 🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
