Metadata-Version: 2.3
Name: dinox
Version: 0.4.1
Summary: Derivative Informed Neural Operators in JAX and Equinox
Author: Joshua Chen, Michael Brennan, Lianghao Cao, Thomas O'Leary-Roseberry
Author-email: Joshua Chen <joshuawchen@icloud.com>, Michael Brennan <mcbrenn@mit.edu>, Lianghao Cao <lianghao@caltech.edu>, Thomas O'Leary-Roseberry <tom.olearyroseberry@utexas.edu>
License: MIT License
         
         Copyright (c) 2025 Joshua Chen
         
         Permission is hereby granted, free of charge, to any person obtaining a copy
         of this software and associated documentation files (the "Software"), to deal
         in the Software without restriction, including without limitation the rights
         to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
         copies of the Software, and to permit persons to whom the Software is
         furnished to do so, subject to the following conditions:
         
         The above copyright notice and this permission notice shall be included in all
         copies or substantial portions of the Software.
         
         THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
         IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
         FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
         AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
         LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
         OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
         SOFTWARE.
Requires-Dist: jax>=0.4.30
Requires-Dist: jaxlib>=0.4.30
Requires-Dist: jaxtyping
Requires-Dist: optax>=0.2.3
Requires-Dist: bayesflux>=0.7.5
Requires-Dist: equinox
Requires-Dist: hickle>=5.0.3
Requires-Dist: cupy-cuda117>=10.0.0,<11.0.0 ; extra == 'cupy'
Requires-Dist: pytest ; extra == 'dev'
Requires-Dist: black ; extra == 'dev'
Requires-Dist: isort ; extra == 'dev'
Requires-Dist: flake8 ; extra == 'dev'
Requires-Dist: flake8-pyproject ; extra == 'dev'
Requires-Python: >=3.9
Project-URL: Homepage, https://github.com/dinoSciML/dinox
Project-URL: Repository, https://github.com/dinoSciML/dinox
Provides-Extra: cupy
Provides-Extra: dev
Description-Content-Type: text/markdown

# dinox
Implementation of **Derivative Informed Neural Operators** in `jax`. Built for fast performance in single-GPU environments, specifically where _all data can fit in GPU memory_. In the future, this code will be generalized to take advantage of multiple GPUs, and to handle big data (where not all samples fit in GPU or CPU memory), probably via memory mapping.
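When the whole dataset fits in GPU memory, shuffling and minibatching can happen entirely on the device. The sketch below illustrates that idea in plain JAX. The helper name `make_epoch_batches` and the array sizes are hypothetical, and this is not necessarily how dinox batches data internally (it lists `cupy` for on-GPU permutation).

```python
import jax
import jax.numpy as jnp


def make_epoch_batches(key, X, Y, dYdX, batch_size):
    """Yield shuffled minibatches from arrays that already live on the device.

    Hypothetical helper: the permutation and the gathers both run on the
    accelerator, so the data is never copied back to the host.
    """
    n = X.shape[0]
    perm = jax.random.permutation(key, n)
    n_batches = n // batch_size  # drop the ragged final batch
    for i in range(n_batches):
        idx = perm[i * batch_size:(i + 1) * batch_size]
        yield X[idx], Y[idx], dYdX[idx]


# Hypothetical toy sizes: 1000 samples, input dim 8, output dim 4.
key = jax.random.PRNGKey(0)
kx, ky, kj, kp = jax.random.split(key, 4)
device = jax.devices()[0]  # the GPU if one is visible, otherwise the CPU
X = jax.device_put(jax.random.normal(kx, (1000, 8)), device)
Y = jax.device_put(jax.random.normal(ky, (1000, 4)), device)
dYdX = jax.device_put(jax.random.normal(kj, (1000, 4, 8)), device)

batches = list(make_epoch_batches(kp, X, Y, dYdX, batch_size=128))
```

Because `X`, `Y`, and `dYdX` are placed on the device once, each epoch only launches a permutation and a few gathers; there is no host-to-device copy of the data inside the loop.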

# Installation
Create a brand new environment. The instructions below assume that `conda` is already installed on your machine. If you can, use `mamba` in place of `conda` (the first line below installs it).

If you have access to an NVIDIA GPU, use `gpu_environment.yml`; otherwise use `cpu_environment.yml`. The CPU environment installs the same dependencies, but the code will be less performant, since dinox is a GPU-first library.
```
conda install -c conda-forge mamba

mamba env create -f gpu_environment.yml   # or: mamba env create -f cpu_environment.yml
poetry install
```
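A quick check (a hypothetical snippet, not part of dinox) that the new environment resolved a working JAX install, and whether JAX sees a GPU:

```python
import jax
import jax.numpy as jnp

# Report which backend JAX picked up: typically "gpu" after installing
# gpu_environment.yml on an NVIDIA machine, "cpu" otherwise.
backend = jax.default_backend()
print("JAX backend:", backend)
print("Devices:", jax.devices())


# Run a tiny jitted computation to confirm compilation works on that backend.
@jax.jit
def smoke_test(x):
    return jnp.sum(x ** 2)


result = smoke_test(jnp.arange(4.0))
print("Smoke test result:", float(result))  # 0 + 1 + 4 + 9 = 14.0
```

If the backend reports `cpu` on a machine with an NVIDIA GPU, the CUDA-enabled `jaxlib` was most likely not installed into the active environment.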
# Running dinox
```
python -m dinox -network_name "<name_to_save_network_as>" -data_dir "<location_of_jacobian_enriched_training_data>"
```
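The `-data_dir` flag points to Jacobian-enriched training data: input samples $`X`$, outputs $`Y`$, and Jacobians $`dY/dX`$. The sketch below only illustrates the shapes of such triples for a toy map, computing the Jacobians with `jax.jacfwd`. The function `toy_parameter_to_observable` is hypothetical, and the on-disk layout dinox expects inside `data_dir` is not documented here, so this sketch does not write any files.

```python
import jax
import jax.numpy as jnp


def toy_parameter_to_observable(x):
    # Hypothetical stand-in for an expensive forward model (e.g. a PDE solve), R^3 -> R^2.
    return jnp.array([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])


key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (500, 3))                         # inputs,    shape (N, d_in)
Y = jax.vmap(toy_parameter_to_observable)(X)                 # outputs,   shape (N, d_out)
dYdX = jax.vmap(jax.jacfwd(toy_parameter_to_observable))(X)  # Jacobians, shape (N, d_out, d_in)
```

Each Jacobian is stored as a `(d_out, d_in)` matrix, so row `i` holds the gradient of output `i` with respect to the input.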
# Examples


# Platform support
The codebase still needs to be generalized to run well on CPUs, and it does not yet fully work on Apple Silicon (`jax-metal` has limitations).
# Notes on why we require these packages:
- `cupy` - for rapid permuting of data on GPUs
- `kvikio` - for interfacing with NVIDIA GPU Direct Storage (GDS) for loading data directly to GPU, skipping the CPU
- `equinox` - dinox is built primarily on equinox and is therefore fully compatible with jax. Most of dinox consists of lightweight utilities for mean $`H^1`$ loss training of neural networks on data enriched with Jacobians ($`X, Y, dY/dX`$)
- `optax` - we use optax for optimization, though any neural network optimization library can be used. Our library choices are made primarily for speed.
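To make the mean $`H^1`$ objective concrete: for a network $`f`$ and Jacobian-enriched samples $`(x_i, y_i, J_i)`$, one common form is $`\frac{1}{N}\sum_{i=1}^{N} \left( \|f(x_i) - y_i\|_2^2 + \|\nabla f(x_i) - J_i\|_F^2 \right)`$. The sketch below implements that form in plain JAX, with a hand-rolled MLP and SGD so it runs without equinox or optax. The exact weighting and normalization dinox uses are assumptions here, and `init_mlp`, `mlp`, `mean_h1_loss`, and `sgd_step` are hypothetical names, not dinox's API.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Initialize a plain MLP as a list of (W, b) pairs (hypothetical helper)."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, (m, n) in zip(keys, zip(sizes[:-1], sizes[1:])):
        W = jax.random.normal(k, (n, m)) / jnp.sqrt(m)
        params.append((W, jnp.zeros(n)))
    return params


def mlp(params, x):
    """Map a single input vector x to an output vector."""
    for W, b in params[:-1]:
        x = jnp.tanh(W @ x + b)
    W, b = params[-1]
    return W @ x + b


def mean_h1_loss(params, X, Y, dYdX):
    """Mean over samples of ||f(x) - y||^2 + ||df/dx(x) - dY/dX||_F^2."""
    preds = jax.vmap(mlp, in_axes=(None, 0))(params, X)
    # Jacobian of the network output w.r.t. its input, one per sample.
    jacs = jax.vmap(jax.jacfwd(mlp, argnums=1), in_axes=(None, 0))(params, X)
    l2_term = jnp.sum((preds - Y) ** 2, axis=1)
    jac_term = jnp.sum((jacs - dYdX) ** 2, axis=(1, 2))
    return jnp.mean(l2_term + jac_term)


@jax.jit
def sgd_step(params, X, Y, dYdX, lr=1e-2):
    """One full-batch gradient descent step on the mean H1 loss."""
    loss, grads = jax.value_and_grad(mean_h1_loss)(params, X, Y, dYdX)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss


# Hypothetical toy target R^3 -> R^2 and its exact Jacobians.
key = jax.random.PRNGKey(0)
k_data, k_init = jax.random.split(key)
target = lambda x: jnp.sin(x[:2]) * x[2]
X = jax.random.normal(k_data, (256, 3))
Y = jax.vmap(target)(X)
dYdX = jax.vmap(jax.jacfwd(target))(X)

params = init_mlp(k_init, [3, 32, 32, 2])
losses = []
for _ in range(200):
    params, loss = sgd_step(params, X, Y, dYdX)
    losses.append(float(loss))
```

In dinox itself the network would be an equinox module and the update an optax optimizer; the Jacobian term is what distinguishes this objective from ordinary $`L^2`$ regression on $`(X, Y)`$ alone.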

## Dependency versions
The minimal required versions of CUDA, jax, and kvikio still need to be determined. The tricky part is finding versions of jax, kvikio, cudatoolkit, cuda-nvcc, and cudnn that work well together. For now, we only restrict to python>=3.10.

Let me know if you run into dependency resolution issues.
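When reporting a dependency resolution issue, it helps to include the exact installed versions. A minimal sketch using only the standard library; the package list is an assumption based on dinox's stated requirements, and a package installed under a different distribution name (for example a CUDA-suffixed wheel) will show as not installed.

```python
import platform
from importlib import metadata

# Distribution names from dinox's requirements, plus kvikio
# (mentioned above for GPU Direct Storage).
PACKAGES = ["jax", "jaxlib", "jaxtyping", "equinox", "optax", "bayesflux",
            "hickle", "cupy-cuda117", "kvikio"]


def installed_versions(packages):
    """Return {distribution name: installed version, or None if missing}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


report = {"python": platform.python_version(), **installed_versions(PACKAGES)}
for name, version in report.items():
    print(f"{name:>14}: {version or 'not installed'}")
```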
