Metadata-Version: 2.4
Name: torchax
Version: 0.0.4
Summary: torchax is a library for running PyTorch on TPU
Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
License: BSD 3-Clause License
        
        Copyright (c) 2023, pytorch-tpu
        
        Redistribution and use in source and binary forms, with or without
        modification, are permitted provided that the following conditions are met:
        
        1. Redistributions of source code must retain the above copyright notice, this
           list of conditions and the following disclaimer.
        
        2. Redistributions in binary form must reproduce the above copyright notice,
           this list of conditions and the following disclaimer in the documentation
           and/or other materials provided with the distribution.
        
        3. Neither the name of the copyright holder nor the names of its
           contributors may be used to endorse or promote products derived from
           this software without specific prior written permission.
        
        THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
        AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
        IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
        DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
        FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
        DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
        SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
        CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
        OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
        OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
License-File: LICENSE
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: BSD License
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Topic :: Software Development
Classifier: Topic :: Software Development :: Libraries
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.10
Provides-Extra: cpu
Requires-Dist: jax[cpu]; extra == 'cpu'
Requires-Dist: jax[cpu]>=0.4.30; extra == 'cpu'
Requires-Dist: tensorflow-cpu; extra == 'cpu'
Provides-Extra: cuda
Requires-Dist: jax[cpu]>=0.4.30; extra == 'cuda'
Requires-Dist: jax[cuda12]; extra == 'cuda'
Requires-Dist: tensorflow-cpu; extra == 'cuda'
Provides-Extra: odml
Requires-Dist: jax[cpu]; extra == 'odml'
Requires-Dist: jax[cpu]>=0.4.30; extra == 'odml'
Provides-Extra: tpu
Requires-Dist: jax[cpu]>=0.4.30; extra == 'tpu'
Requires-Dist: jax[tpu]; extra == 'tpu'
Requires-Dist: tensorflow-cpu; extra == 'tpu'
Description-Content-Type: text/markdown

# torchax

## Install

Currently this is only source-installable. Requires Python version >= 3.10.

### NOTE:

Please don't install torch-xla using the instructions in
https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md.
In particular, the following are not needed:

* There is no need to build pytorch/pytorch from source.
* There is no need to clone pytorch/xla project inside of pytorch/pytorch
  git checkout.


torchax and torch-xla have different installation instructions. Please follow
the instructions below from scratch (in a fresh venv or conda environment).


### 1. Installing `torchax`

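First fork the repository on GitHub, then clone your fork and change into the
`torchax` directory. The remaining instructions assume you are in that directory:

```bash
# After forking https://github.com/pytorch/xla on GitHub:
$ git clone https://github.com/<github_username>/xla.git
$ cd xla/torchax
```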


#### 1.0 (recommended) Make a virtualenv / conda env

If you are using VSCode, then [you can create a new environment from
UI](https://code.visualstudio.com/docs/python/environments). Select the
`dev-requirements.txt` when asked to install project dependencies.

Otherwise create a new environment from the command line.

```bash
# Option 1: venv
python -m venv my_venv
source my_venv/bin/activate

# Option 2: conda
conda create --name <your_name> python=3.10
conda activate <your_name>

# Either way, install the dev requirements.
pip install -r dev-requirements.txt
```

Note: `dev-requirements.txt` will install the CPU-only version of PyTorch.

#### 1.1 Install this package

Install `torchax` from source, running only the command that matches your platform:
```bash
# CPU
pip install -e ".[cpu]"
# CUDA (GPU)
pip install -e ".[cuda]"
# TPU
pip install -e ".[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

#### 1.2 (optional) Verify the installation by running tests

```bash
pip install -r test-requirements.txt
pytest test
```

## Run a model

Now let's execute a model under torchax. We'll start with a simple model with two
hidden layers; in theory, it can be any instance of `torch.nn.Module`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

m = MyModel()

# Execute this model using torch
inputs = torch.randn(3, 3, 28, 28)
print(m(inputs))
```

This model `m` contains 2 parts: the weights that are stored inside the model,
and its submodules (`nn.Linear`).

To execute this model with `torchax`, we need to construct and run the model
under an environment that captures PyTorch ops and swaps them with their TPU equivalents.

To create this environment, use:

```python
import torchax

env = torchax.default_env()
```
Then instantiate and evaluate the model using `env` as a context manager:

```python
with env:
  inputs = torch.randn(3, 3, 28, 28)
  m = MyModel()
  res = m(inputs)
  print(type(res))  # outputs Tensor
```

You can also enable the environment globally:
```python
import torchax

torchax.enable_globally()
```

After this call, all subsequent PyTorch operations run with XLA.

## What is happening behind the scenes

When a torch op is executed inside the `env` context manager, we swap out the
implementation of that op with a version that runs on TPU.
When a model's constructor runs, it calls tensor constructors such as
`torch.rand`, `torch.ones`, or `torch.zeros` to create its weights. Those
ops are captured by `env` too, so the weights are placed directly on TPU.

See more at [how_it_works](docs/how_it_works.md) and [ops registry](docs/ops_registry.md).

### What if I created the model outside of `env`?

Suppose you run

```python
m = MyModel()
```
outside of `env`. Then regular torch ops run when the model is created, so the
model's weights will typically be on CPU (as instances of `torch.Tensor`).

To move this model to the XLA device, use the `env.to_xla()` function:

```python
m2 = env.to_xla(m)
inputs = env.to_xla(inputs)

with env:
  res = m2(inputs)
```

Note that the inputs must also be moved to XLA with `env.to_xla`.
`to_xla` works on any pytree of `torch.Tensor`s (for example, nested lists, tuples, and dicts of tensors).
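To illustrate what "works on a pytree" means, here is a plain-Python sketch that maps a conversion over every leaf of a nested structure while preserving the structure itself. `CpuTensor`, `XlaTensor`, `tree_map`, and `to_xla` below are hypothetical stand-ins, not torchax's implementation, and the rule that non-tensor leaves pass through unchanged is an assumption of this sketch.

```python
from dataclasses import dataclass


@dataclass
class CpuTensor:  # hypothetical stand-in for torch.Tensor
    data: list


@dataclass
class XlaTensor:  # hypothetical stand-in for a tensor living on XLA
    data: list


def tree_map(fn, tree):
    """Apply fn to every leaf of a nested dict/list/tuple, keeping the structure."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [tree_map(fn, v) for v in tree]
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, v) for v in tree)
    return fn(tree)  # anything else is a leaf


def to_xla_leaf(leaf):
    # Assumption: only tensor leaves are converted; other leaves pass through.
    return XlaTensor(leaf.data) if isinstance(leaf, CpuTensor) else leaf


def to_xla(tree):
    return tree_map(to_xla_leaf, tree)


batch = {"image": CpuTensor([0.1, 0.2]), "extras": (CpuTensor([1.0]), "label", 7)}
moved = to_xla(batch)
print(moved)
```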


### Executing with jax.jit

The code above executes the model with JAX in eager mode as the backend. This
does allow running torch models on TPU, but it is often slower than what we can
achieve with `jax.jit`.

`jax.jit` transforms a JAX function (i.e., a function that takes JAX arrays
and returns JAX arrays) into an equivalent function that is compiled with XLA and usually runs faster.
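A minimal `jax.jit` example, assuming JAX is installed (for example via the `cpu` extra above). The function `mlp_layer` is a hypothetical example; the point is that the jitted and eager versions compute the same values, with the jitted one compiled on its first call.

```python
import jax
import jax.numpy as jnp


def mlp_layer(w, b, x):
    # A plain JAX function: takes JAX arrays, returns a JAX array.
    return jax.nn.relu(x @ w + b)


mlp_layer_jitted = jax.jit(mlp_layer)

w = jnp.ones((4, 3))
b = jnp.zeros((3,))
x = jnp.arange(8.0).reshape(2, 4)

eager_out = mlp_layer(w, b, x)
# The first call traces and compiles; later calls with the same shapes and
# dtypes reuse the compiled version.
jit_out = mlp_layer_jitted(w, b, x)

print(bool(jnp.allclose(eager_out, jit_out)))  # True
```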

We provide a `jax_jit` decorator that accomplishes the same for functions
that take and return `torch.Tensor`s. To use it, first create
a functional version of the model: the parameters are passed in
as inputs instead of being attributes of the class:


```python
def model_func(param, inputs):
  return torch.func.functional_call(m, param, inputs)
```
Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
from PyTorch to replace the model weights with `param` and then call the model.
This is roughly equivalent to the following, except that `functional_call` does not modify `m`:

```python
def model_func(param, inputs):
  m.load_state_dict(param)  # overwrites m's weights in place
  return m(inputs)
```
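The equivalence can be checked with plain PyTorch; no torchax is needed. This sketch uses a hypothetical `nn.Linear(4, 2)` model and an all-zeros parameter dict. It shows that `functional_call` gives the same output as loading those parameters into a copy of the model, while leaving the original model's weights untouched.

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Linear(4, 2)  # hypothetical stand-in for MyModel
x = torch.randn(3, 4)

# Replacement parameters, keyed by the same names as model.named_parameters().
new_params = {name: torch.zeros_like(p) for name, p in model.named_parameters()}

# Functional route: run `model` with `new_params` swapped in for this call only.
out_functional = functional_call(model, new_params, (x,))

# Stateful route: load the same parameters into a copy of the model, then call it.
model_copy = copy.deepcopy(model)
model_copy.load_state_dict(new_params)
out_stateful = model_copy(x)

print(torch.equal(out_functional, out_stateful))  # True: both are all zeros
print(torch.count_nonzero(model.weight).item() > 0)  # True: original weights untouched
```

Passing parameters explicitly is what makes the function pure, which is the property `jax.jit`-style compilation relies on.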

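Now we can apply `jax_jit`:

```python
from torchax.interop import jax_jit

model_func_jitted = jax_jit(model_func)
# `new_state_dict` maps parameter names to tensors that already live on XLA,
# for example the state dict of a model created inside `env`.
print(model_func_jitted(new_state_dict, inputs))
```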

See more examples in [eager_mode.py](examples/eager_mode.py) and the [examples folder](examples/).