Metadata-Version: 2.1
Name: jaxpole
Version: 0.0.2
Summary: A differentiable implementation of an all-pole filter in JAX
Home-page: https://github.com/rodrigodzf/jaxpole
Author: Rodrigo Diaz
Author-email: rodrigodzf@gmail.com
License: Apache Software License 2.0
Keywords: nbdev jupyter notebook python
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Natural Language :: English
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: License :: OSI Approved :: Apache Software License
Requires-Python: >=3.7
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: jax[cpu]
Requires-Dist: numpy
Provides-Extra: dev
Requires-Dist: pytest; extra == "dev"
Requires-Dist: nbdev; extra == "dev"

# jaxpole


jaxpole is a differentiable Direct-Form I implementation of a time-varying
all-pole filter in JAX, based on [torchlpc](https://github.com/yoyololicon/torchlpc).
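For reference, a time-varying all-pole filter of order `P` computes `y[n] = x[n] - sum_{k=1..P} a_k[n] * y[n-k]`, with the sign convention implied by the coefficients in the usage example below. The following is a minimal sketch of that recursion as a plain `jax.lax.scan`, not jaxpole's own code. The name `allpole_reference` is hypothetical, and the ordering of the initial conditions (`zi[:, k]` holds `y[-(k+1)]`) is an assumption. It also shows that gradients with respect to the per-sample coefficients come directly from `jax.grad`.

```python
import jax
import jax.numpy as jnp


def allpole_reference(x, A, zi):
    """Hypothetical reference all-pole filter (not jaxpole's implementation).

    x:  (B, T) input signal
    A:  (B, T, P) per-sample coefficients a_1[n], ..., a_P[n]
    zi: (B, P) initial outputs; ASSUMED ordering zi[:, k] = y[-(k+1)]
    Computes y[n] = x[n] - sum_k A[:, n, k] * y[n-k-1].
    """

    def step(state, inputs):
        # state: (B, P) past outputs, most recent first
        x_n, a_n = inputs  # (B,), (B, P)
        y_n = x_n - jnp.sum(a_n * state, axis=-1)
        # shift the history: the newest output goes in front, the oldest drops out
        state = jnp.concatenate([y_n[:, None], state[:, :-1]], axis=-1)
        return state, y_n

    # lax.scan iterates over the leading axis, so move time to the front
    _, y = jax.lax.scan(step, zi, (x.T, jnp.moveaxis(A, 1, 0)))
    return y.T  # back to (B, T)


# Same setup as the usage example: one resonant pole pair at radius 0.99
pole = 0.99 * jnp.exp(1j * jnp.pi / 4)
coeffs = jnp.array([-2 * pole.real, jnp.abs(pole) ** 2])
x = jax.random.normal(jax.random.PRNGKey(0), (1, 1000))  # (B, T)
A = jnp.tile(coeffs, (1, x.shape[-1], 1))  # (B, T, P)
zi = jnp.zeros((1, 2))  # (B, P)

y = allpole_reference(x, A, zi)


def loss(coeffs_t):
    # mean output energy as a function of the time-varying coefficients
    return jnp.mean(allpole_reference(x, coeffs_t, zi) ** 2)


grad_A = jax.grad(loss)(A)  # one gradient per sample and coefficient
print(y.shape, grad_A.shape)  # (1, 1000) (1, 1000, 2)
```

A scan processes one sample per step, so it is simple but strictly sequential; it is mainly useful as a readable baseline for checking outputs and gradients.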

## Install

``` sh
pip install jaxpole
```

Or install locally from source, including the `dev` extras (pytest, nbdev):

``` sh
pip install -e '.[dev]'
```

## How to use

``` python
import jax
import jax.numpy as jnp

from jaxpole.core import allpole  # assumed module path; adjust to where allpole is exported

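# a complex pole at radius 0.99 and angle pi/4 (its conjugate is implied)
pole = 0.99 * jnp.exp(1j * jnp.pi / 4)
# denominator coefficients [a1, a2] of (1 - p z^-1)(1 - conj(p) z^-1)
coeffs = jnp.array([-2 * pole.real, pole.real**2 + pole.imag**2])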

x = jax.random.normal(jax.random.PRNGKey(0), (1, 1000))  # (B, T): batch, time
A = jnp.tile(coeffs, (1, x.shape[-1], 1))  # (B, T, P): per-sample coefficients, P = order
zi = jnp.zeros((1, 2))  # (B, P): initial conditions

# filter the signal
y = allpole(x, A, zi)
print(y.shape)
```

    (1, 1000)
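The coefficients in the example come from expanding the second-order denominator with a complex-conjugate pole pair: `(1 - p z^-1)(1 - conj(p) z^-1) = 1 - 2 Re(p) z^-1 + |p|^2 z^-2`. As a quick sanity check, independent of jaxpole, the following NumPy sketch recovers the poles from `coeffs` and confirms they sit inside the unit circle at angles of plus and minus pi/4:

```python
import numpy as np

pole = 0.99 * np.exp(1j * np.pi / 4)
coeffs = np.array([-2 * pole.real, abs(pole) ** 2])

# Roots of z^2 + a1 z + a2, i.e. the denominator 1 + a1 z^-1 + a2 z^-2 multiplied by z^2
poles = np.roots(np.concatenate([[1.0], coeffs]))

print(np.abs(poles))  # both magnitudes are approximately 0.99, so the filter is stable
print(np.angle(poles) / np.pi)  # approximately +0.25 and -0.25, i.e. +/- pi/4
```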
