Metadata-Version: 2.4
Name: pinnx
Version: 0.0.3
Summary: Physics-Informed Neural Networks for Scientific Machine Learning in JAX.
Author-email: PINNx Developers <chao.brain@qq.com>
Project-URL: Homepage, https://github.com/chaobrain/pinnx
Project-URL: Documentation, https://pinnx.readthedocs.io/
Project-URL: Source Code, https://github.com/chaobrain/pinnx
Project-URL: Bug Tracker, https://github.com/chaobrain/pinnx/issues
Keywords: computational neuroscience,brain-inspired computation,brain dynamics programming
Classifier: Natural Language :: English
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: jax
Requires-Dist: brainunit
Requires-Dist: brainstate
Requires-Dist: braintools
Provides-Extra: cpu
Requires-Dist: jaxlib; extra == "cpu"
Provides-Extra: cuda12
Requires-Dist: jaxlib[cuda12]; extra == "cuda12"
Provides-Extra: cuda13
Requires-Dist: jaxlib[cuda13]; extra == "cuda13"
Provides-Extra: tpu
Requires-Dist: jaxlib[tpu]; extra == "tpu"
Provides-Extra: testing
Requires-Dist: pytest; extra == "testing"
Dynamic: license-file

# PINNx: Physics-Informed Neural Networks for Scientific Machine Learning in JAX


<p align="center">
  <img alt="Header image of pinnx." src="https://raw.githubusercontent.com/chaobrain/pinnx/main/docs/_static/pinnx.png" width=40%>
</p>


[![Build Status](https://github.com/chaobrain/pinnx/actions/workflows/build.yml/badge.svg)](https://github.com/chaobrain/pinnx/actions/workflows/build.yml)
[![Documentation Status](https://readthedocs.org/projects/pinnx/badge/?version=latest)](https://pinnx.readthedocs.io/en/latest/?badge=latest)
[![PyPI Version](https://badge.fury.io/py/pinnx.svg)](https://badge.fury.io/py/pinnx)
[![License](https://img.shields.io/github/license/chaobrain/pinnx)](https://github.com/chaobrain/pinnx/blob/master/LICENSE)

``PINNx`` is a library for scientific machine learning and physics-informed learning in JAX.
It is a rewrite of [DeepXDE](https://github.com/lululxvi/deepxde), enhanced with tools from our
[brain modeling ecosystem](https://brainmodeling.readthedocs.io/).
For example, it uses:

- [brainstate](https://brainstate.readthedocs.io/) for just-in-time compilation,
- [brainunit](https://brainunit.readthedocs.io/) for dimensional analysis, 
- [braintools](https://braintools.readthedocs.io/) for checkpointing, loss functions, and other utilities.
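
As an illustration of the kind of check that unit-aware quantities make possible, consider the viscous Burgers equation solved in the Quickstart, with the velocity `y` in m/s and the viscosity `v` (written ν below) in m²/s. Every term of the residual should then carry the same unit, m/s². With brainunit, adding terms whose units disagree should raise a dimension error instead of silently producing a wrong number. The derivation below is a hand check, not output produced by the library.

```math
\frac{\partial y}{\partial t} + y\,\frac{\partial y}{\partial x} - \nu\,\frac{\partial^2 y}{\partial x^2} = 0,
\qquad [y] = \mathrm{m\,s^{-1}},\quad [\nu] = \mathrm{m^2\,s^{-1}},
```

```math
\left[\frac{\partial y}{\partial t}\right] = \frac{\mathrm{m\,s^{-1}}}{\mathrm{s}} = \mathrm{m\,s^{-2}},\qquad
\left[y\,\frac{\partial y}{\partial x}\right] = \mathrm{m\,s^{-1}}\cdot\frac{\mathrm{m\,s^{-1}}}{\mathrm{m}} = \mathrm{m\,s^{-2}},\qquad
\left[\nu\,\frac{\partial^2 y}{\partial x^2}\right] = \mathrm{m^2\,s^{-1}}\cdot\frac{\mathrm{m\,s^{-1}}}{\mathrm{m^2}} = \mathrm{m\,s^{-2}}.
```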


## Quickstart


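Define a PINN with explicit variables and physical units. The example below solves the viscous Burgers equation on x ∈ [-1, 1] m and t ∈ [0, 0.99] s, with y = 0 on the spatial boundary and y(x, 0) = -sin(πx) as the initial condition.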

```python
import braintools
import brainunit as u
import pinnx

# geometry
geometry = pinnx.geometry.GeometryXTime(
    geometry=pinnx.geometry.Interval(-1, 1.),
    timedomain=pinnx.geometry.TimeDomain(0, 0.99)
).to_dict_point(x=u.meter, t=u.second)

uy = u.meter / u.second
v = 0.01 / u.math.pi * u.meter ** 2 / u.second

# boundary condition (y = 0 at x = ±1) and initial condition (y = -sin(pi x) at t = 0)
bc = pinnx.icbc.DirichletBC(lambda x: {'y': 0. * uy})
ic = pinnx.icbc.IC(lambda x: {'y': -u.math.sin(u.math.pi * x['x'] / u.meter) * uy})

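# PDE residual of the viscous Burgers equation: dy/dt + y * dy/dx - v * d2y/dx2.
# `approximator` is defined below; the name is resolved when `pde` is called during training.
def pde(x, y):
    jacobian = approximator.jacobian(x)  # first derivatives of each output w.r.t. each input
    hessian = approximator.hessian(x)    # second derivatives of each output
    dy_x = jacobian['y']['x']
    dy_t = jacobian['y']['t']
    dy_xx = hessian['y']['x']['x']
    residual = dy_t + y['y'] * dy_x - v * dy_xx
    return residual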

# neural network
approximator = pinnx.nn.Model(
    pinnx.nn.DictToArray(x=u.meter, t=u.second),
    pinnx.nn.FNN(
        [geometry.dim] + [20] * 3 + [1],
        "tanh",
        braintools.init.KaimingUniform()
    ),
    pinnx.nn.ArrayToDict(y=uy)
)

# problem
problem = pinnx.problem.TimePDE(
    geometry,
    pde,
    [bc, ic],
    approximator,
    num_domain=2540,
    num_boundary=80,
    num_initial=160,
)

# training
trainer = pinnx.Trainer(problem)
trainer.compile(braintools.optim.Adam(1e-3)).train(iterations=15000)
trainer.compile(braintools.optim.LBFGS(1e-3)).train(2000, display_every=500)
trainer.saveplot(issave=True, isplot=True)
```
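
The `pde` function above relies on `approximator.jacobian` and `approximator.hessian` to differentiate the network output automatically. The sketch below reproduces the same residual with plain `jax.grad` on a hand-written scalar function instead of a network, so it can be checked against a known exact solution. It drops physical units, and `burgers_residual` and `mse_pde_loss` are hypothetical helpers, not part of the pinnx API; the plain mean-squared reduction is an assumption, and pinnx's own loss may weight or combine terms differently.

```python
import jax
import jax.numpy as jnp

# Viscosity from the Quickstart with units stripped (assumption: unitless sketch).
NU = 0.01 / jnp.pi


def burgers_residual(u_fn, x, t, nu=NU):
    """Pointwise residual u_t + u * u_x - nu * u_xx for a scalar field u_fn(x, t)."""
    u = u_fn(x, t)
    u_x = jax.grad(u_fn, argnums=0)(x, t)                        # du/dx
    u_t = jax.grad(u_fn, argnums=1)(x, t)                        # du/dt
    u_xx = jax.grad(jax.grad(u_fn, argnums=0), argnums=0)(x, t)  # d2u/dx2
    return u_t + u * u_x - nu * u_xx


def mse_pde_loss(u_fn, xs, ts, nu=NU):
    """Hypothetical helper: mean squared residual over a batch of collocation points."""
    residuals = jax.vmap(lambda x, t: burgers_residual(u_fn, x, t, nu))(xs, ts)
    return jnp.mean(residuals ** 2)


def exact(x, t):
    # u = x / (1 + t) solves Burgers' equation exactly: u_xx = 0 and u_t = -u * u_x.
    return x / (1.0 + t)


def frozen_ic(x, t):
    # The Quickstart's initial profile held constant in time; NOT a solution.
    return -jnp.sin(jnp.pi * x)


xs = jnp.linspace(-1.0, 1.0, 64)
ts = jnp.linspace(0.0, 0.99, 64)
print(mse_pde_loss(exact, xs, ts))            # close to 0 (floating-point error only)
print(burgers_residual(frozen_ic, 0.5, 0.0))  # close to -nu * pi**2, about -0.0314
```

A trained network plays the role of `exact` here: training drives this kind of residual loss, together with boundary- and initial-condition terms, toward zero at the sampled points.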



## Installation

- Install the stable version with `pip`:

```shell
pip install pinnx --upgrade
```


- Install ``pinnx`` together with the JAX backend for your hardware by choosing one of the extras below (the quotes keep shells such as zsh from expanding the brackets):

```shell
pip install "pinnx[cpu]"     # for CPU
pip install "pinnx[cuda12]"  # for NVIDIA GPUs with CUDA 12
pip install "pinnx[cuda13]"  # for NVIDIA GPUs with CUDA 13
pip install "pinnx[tpu]"     # for Google TPUs
```
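
After installing one of the extras, a quick way to confirm which accelerator JAX picked up is to query it directly. This sketch uses only JAX's own `jax.default_backend()` and `jax.devices()`, nothing pinnx-specific; on a CPU-only install the backend should be reported as `cpu`.

```python
import jax

# Backend JAX selected at startup: typically "cpu", "gpu", or "tpu".
backend = jax.default_backend()
# Devices visible to that backend.
devices = jax.devices()

print("backend:", backend)
for d in devices:
    print(d.platform, d.id)
```

If a GPU extra was installed but the backend still reports `cpu`, JAX most likely could not load the CUDA libraries.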


## Documentation

The official documentation is hosted on Read the Docs: [https://pinnx.readthedocs.io/](https://pinnx.readthedocs.io/)


## See also

``pinnx`` is one part of our [brain modeling ecosystem](https://brainmodeling.readthedocs.io/).

