Metadata-Version: 2.1
Name: jaxkan
Version: 0.1.3
Summary: A JAX-based implementation of Kolmogorov-Arnold Networks
Home-page: https://github.com/srigas/jaxkan
Author: Spyros Rigas, Michalis Papachristou
Author-email: rigassp@gmail.com
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.6
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy==1.26.4
Requires-Dist: flax==0.8.3
Requires-Dist: jax[cpu]==0.4.28
Requires-Dist: optax==0.2.2
Provides-Extra: gpu
Requires-Dist: jax[cuda12]; extra == "gpu"

# jaxKAN

A JAX implementation of the original Kolmogorov-Arnold Networks (KANs), using the Flax and Optax frameworks for neural networks and optimization, respectively. Our adaptation is based on the original [pykan](https://github.com/KindXiaoming/pykan), with one addition: a built-in grid extension routine that not only adapts the grid to the inputs but also increases its size.
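To make the idea of adapting and extending a grid concrete, here is a minimal sketch for a single input dimension. It is not jaxKAN's actual routine: the helper name `adapt_and_extend_grid`, its arguments and the blending factor `eps` are assumptions made for illustration, loosely following the sample-based grid update used in pykan.

```python
import jax.numpy as jnp

def adapt_and_extend_grid(x, new_intervals, eps=0.02):
    """Hypothetical helper (not part of jaxKAN's API).

    x: (batch,) samples of one input dimension.
    Returns new_intervals + 1 knots spanning the range of x.
    """
    x_sorted = jnp.sort(x)
    # Adaptive part: knots placed at evenly spaced quantiles of the samples.
    idx = jnp.linspace(0, x.shape[0] - 1, new_intervals + 1).round().astype(int)
    adaptive = x_sorted[idx]
    # Uniform part: evenly spaced knots over the same range.
    uniform = jnp.linspace(x_sorted[0], x_sorted[-1], new_intervals + 1)
    # Blend the two: eps=1 is purely uniform, eps=0 is purely adaptive.
    return eps * uniform + (1.0 - eps) * adaptive

x = jnp.linspace(-1.0, 1.0, 200) ** 3              # samples clustered around zero
coarse = adapt_and_extend_grid(x, new_intervals=3)  # 4 knots
fine = adapt_and_extend_grid(x, new_intervals=10)   # 11 knots, same input range
```

Calling the helper with a larger `new_intervals` is the "extension" step. A complete routine would also have to refit the spline coefficients on the finer grid so that the learned functions are preserved, which this sketch leaves out.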


## Why not more efficient?

Although KANs show promise across Deep Learning in general, their authors emphasized their performance in scientific computing, in tasks such as Symbolic Regression or solving PDEs. This is why we chose to preserve their original form, even though it is less computationally efficient: it lets the user apply the full regularization terms presented in the [arXiv pre-print](https://arxiv.org/abs/2404.19756) rather than the "mock" regularization terms used, for instance, in the [efficient-kan](https://github.com/Blealtan/efficient-kan/tree/master) implementation.
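For reference, the regularization in the pre-print combines an L1 term and an entropy term, both computed from the per-edge activations averaged over the batch. The sketch below shows one way to write this for a single layer in JAX. The function name, the `(batch, n_out, n_in)` layout and the `tiny` stabilizer are assumptions made for illustration, not jaxKAN's API.

```python
import jax.numpy as jnp

def layer_regularization(acts, mu1=1.0, mu2=1.0, tiny=1e-12):
    """Hypothetical helper (not part of jaxKAN's API).

    acts: (batch, n_out, n_in) array, where acts[s, j, i] is the output of
    the activation on edge (i -> j) for sample s.
    """
    # L1 norm of each edge: mean absolute activation over the batch.
    l1_per_edge = jnp.mean(jnp.abs(acts), axis=0)
    # L1 norm of the layer: sum over all edges.
    l1_layer = jnp.sum(l1_per_edge)
    # Entropy of the normalized edge norms.
    p = l1_per_edge / (l1_layer + tiny)
    entropy = -jnp.sum(p * jnp.log(p + tiny))
    return mu1 * l1_layer + mu2 * entropy

# Six edges with unit L1 norm each: the result is 6 + log(6).
acts = jnp.ones((5, 2, 3))
reg = layer_regularization(acts)
```

Both terms need the activation of every edge for every sample, so the full `(batch, n_out, n_in)` tensor has to be materialized. This is the computational cost that comes with keeping the original form.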


## Why JAX?

Because speed + scientific computing. Need we say more? Plus, even though all tests were performed on CPU, switching to GPU in JAX is straightforward.
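As a rough illustration of why the switch is simple: JAX code does not name a device explicitly, so the same script runs on whichever backend the installed `jaxlib` provides. The package metadata declares a `gpu` extra that pulls in `jax[cuda12]`, so installing with something like `pip install "jaxkan[gpu]"` should be enough on a CUDA 12 machine, although this has not been verified here. The function `f` below is a hypothetical stand-in for any jitted computation.

```python
import jax
import jax.numpy as jnp

# Report which backend JAX picked up and which devices it can see.
print(jax.default_backend())
print(jax.devices())

@jax.jit
def f(x):
    # Compiled for the default device; no device-specific code is needed.
    return jnp.sum(jnp.tanh(x) ** 2)

x = jnp.ones((256, 256))
y = f(x)
```

If a GPU-enabled install is present, the array and the compiled function are placed on the GPU by default, without changes to the code.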
