Metadata-Version: 2.1
Name: symtorch
Version: 0.0.0
Summary: Symbolic Expressions in PyTorch
Author-email: John Gardner <gardner.john97@gmail.com>
Project-URL: Homepage, https://github.com/jla-gardner/symtorch
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.8
Description-Content-Type: text/markdown
Requires-Dist: torch
Requires-Dist: sympy
Requires-Dist: numpy
Provides-Extra: dev
Requires-Dist: notebook ; extra == 'dev'
Requires-Dist: pytest ; extra == 'dev'
Requires-Dist: pytest-cov ; extra == 'dev'
Requires-Dist: bumpver ; extra == 'dev'
Requires-Dist: ruff ; extra == 'dev'
Provides-Extra: publish
Requires-Dist: build ; extra == 'publish'
Requires-Dist: twine ; extra == 'publish'

<div align="center">
<img src="icon-with-text.svg" style="width: min(100%, 400px); height: auto;"/>
</div>

---

Fast, optimisable, symbolic expressions in PyTorch.

```python-repl
>>> from symtorch import symtorchify
>>> f = symtorchify("x**2 + 2.5*x + 1.7")
>>> f
x²+2.5x+1.7
>>> len(list(f.parameters()))
2
>>> import torch
>>> f.evalf({"x": torch.tensor(2.0)})
tensor([10.7000], grad_fn=<AddBackward0>)
```
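To illustrate what "optimisable" means, the following plain-PyTorch sketch hand-writes a module equivalent to the expression above. Its two float constants are trainable `nn.Parameter`s, matching the parameter count shown above, and they are fitted to synthetic data by gradient descent. `Quadratic`, the target curve and the optimiser settings are illustrative assumptions, not symtorch's implementation. With symtorch itself you would presumably pass `f.parameters()` to an optimiser in the same way.

```python
import torch
from torch import nn


class Quadratic(nn.Module):
    """Hypothetical hand-written stand-in for symtorchify("x**2 + 2.5*x + 1.7").

    The float constants 2.5 and 1.7 are trainable parameters; the integer
    exponent is kept fixed, giving two parameters in total.
    """

    def __init__(self) -> None:
        super().__init__()
        self.b = nn.Parameter(torch.tensor(2.5))
        self.c = nn.Parameter(torch.tensor(1.7))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x**2 + self.b * x + self.c


f = Quadratic()

# Synthetic (assumed) target curve: y = x**2 - 3x + 0.5
x = torch.linspace(-2.0, 2.0, 50)
y = x**2 - 3.0 * x + 0.5

# Plain gradient descent on the mean squared error
optimiser = torch.optim.SGD(f.parameters(), lr=0.1)
for _ in range(200):
    optimiser.zero_grad()
    loss = torch.mean((f(x) - y) ** 2)
    loss.backward()
    optimiser.step()

print(f.b.item(), f.c.item())  # close to -3.0 and 0.5
```

Because the loss is quadratic in the two constants, plain SGD with a small learning rate converges steadily here; a real fitting problem may call for a different optimiser.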

## Installation

```bash
pip install symtorch
```

## Features and Documentation


## What about [SymPyTorch](https://github.com/patrick-kidger/sympytorch)?

This package attempts to supersede [Patrick Kidger](https://github.com/patrick-kidger)'s excellent original SymPyTorch.
Useful improvements here include:

- implementations of `state_dict` and `load_state_dict` for all `SymTorch` objects, enabling saving and loading via the native PyTorch mechanisms
- compatibility with TorchScript, allowing integration into C++ code
- a `SymbolAssignment` helper class that enables "drag-and-drop" replacement of existing NN components with symbolic ones:
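```python-repl
>>> import torch
>>> from torch import nn
>>> from symtorch import SymbolAssignment, symtorchify
>>> model = nn.Sequential(
...     SymbolAssignment(["a", "b"]),
...     symtorchify("3*a + b"),
... )
>>> model(torch.tensor([[1, 2], [3, 4]]))
tensor([[ 5.],
        [13.]], grad_fn=<AddBackward0>)
```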
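The sketch below shows, in plain PyTorch and without symtorch, how such a pipeline might fit together and how its trainable constant round-trips through `state_dict`. It assumes that `SymbolAssignment` maps column `i` of a `(batch, n)` input to the `i`-th symbol name; `ColumnsToSymbols` and `LinearExpr` are hypothetical stand-ins, and symtorch's actual internals may differ.

```python
import io

import torch
from torch import nn


class ColumnsToSymbols(nn.Module):
    """Hypothetical stand-in for SymbolAssignment: column i -> names[i]."""

    def __init__(self, names: list) -> None:
        super().__init__()
        self.names = names

    def forward(self, x: torch.Tensor) -> dict:
        # Slice with i:i+1 to keep a trailing dimension, so each symbol is (batch, 1)
        return {name: x[:, i : i + 1] for i, name in enumerate(self.names)}


class LinearExpr(nn.Module):
    """Hypothetical stand-in for symtorchify("3*a + b"); the 3 is trainable."""

    def __init__(self) -> None:
        super().__init__()
        self.coeff = nn.Parameter(torch.tensor(3.0))

    def forward(self, symbols: dict) -> torch.Tensor:
        return self.coeff * symbols["a"] + symbols["b"]


model = nn.Sequential(ColumnsToSymbols(["a", "b"]), LinearExpr())
out = model(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
print(out)  # shape (2, 1): 3*1 + 2 = 5 and 3*3 + 4 = 13

# Save the trainable constant with the standard state_dict mechanism
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Build a fresh model, clobber its constant, then restore it from the buffer
restored = nn.Sequential(ColumnsToSymbols(["a", "b"]), LinearExpr())
with torch.no_grad():
    restored[1].coeff.fill_(0.0)
restored.load_state_dict(torch.load(buffer))
print(restored[1].coeff.item())  # 3.0 again after loading
```

`nn.Sequential` simply feeds each module's output into the next, so an intermediate dictionary of named symbols works in eager mode.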
