Metadata-Version: 2.1
Name: jax-toolkit
Version: 0.1.2
Summary: A collection of jax functions to help with common machine/deep learning related functionality.
Home-page: https://github.com/asmith26/jax_toolkit.git
Author: asmith26
License: Apache-2.0
Platform: UNKNOWN
Classifier: Development Status :: 2 - Pre-Alpha
Classifier: Intended Audience :: Science/Research
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.7
Description-Content-Type: text/markdown
Requires-Dist: jax
Requires-Dist: jaxlib
Provides-Extra: losses_utils
Requires-Dist: dm-haiku ; extra == 'losses_utils'

# jax_toolkit

A collection of JAX functions to help with common machine/deep-learning-related functionality.

[Documentation](https://asmith26.github.io/jax_toolkit/), [PyPI](https://pypi.org/project/jax-toolkit/)

This library currently contains the basics for a number of losses and metrics. We intend to add more complexity and functionality as and when it is needed. Contributions, pull requests, and bug reports are very welcome if you discover a problem or need something that is currently missing.

## Installation

```bash
pip install jax_toolkit
```

Or, to also install the dependencies (`dm-haiku`) for the additional loss function [utils](https://asmith26.github.io/jax_toolkit/losses_and_metrics/#utils):

```bash
pip install "jax_toolkit[losses_utils]"
```

