Metadata-Version: 2.1
Name: flaxsr
Version: 0.0.5
Summary: Super Resolution models with Jax/Flax
Home-page: https://github.com/dslisleedh/FlaxSR
Author: dslisleedh
Author-email: dslisleedh@gmail.com
Project-URL: Bug Tracker, https://github.com/dslisleedh/FlaxSR/issues
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Image Processing
Requires-Python: >=3.6
Description-Content-Type: text/markdown
License-File: LICENSE

# FlaxSR

Super-resolution models implemented with JAX/Flax.

## HOW TO USE

### Install
```shell
pip install flaxsr
```

### Usage
<b>You can easily load models and losses by name and train models with custom train states.</b>

 - Training example
```python
import flaxsr
import jax
import jax.numpy as jnp
import numpy as np
import optax

# Build a VDSR model by name; equivalent to flaxsr.models.VDSR(**model_kwargs)
model_kwargs = {
    'n_filters': 64, 'n_blocks': 8, 'scale': 4
}
model = flaxsr.get("models", "vdsr", **model_kwargs)

# Losses can be instantiated directly or looked up by name with flaxsr.get
losses = [
    flaxsr.losses.L1Loss(reduce='sum'),
    flaxsr.get('losses', 'vgg', feats_from=(6, 8, 14,), before_act=False, reduce='mean')
]
loss_weights = (.1, 1.)

# Initialize the model variables with a dummy low-resolution batch (NHWC layout)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8, 8, 3), dtype=jnp.float32))
tx = optax.adam(1e-3)

state = flaxsr.training.TrainState.create(
    apply_fn=model.apply, params=params, tx=tx, losses=losses
)

# With scale=4, an 8x8 low-resolution input pairs with a 32x32 high-resolution target
hr = jnp.ones((1, 32, 32, 3), dtype=jnp.float32)
lr = jnp.ones((1, 8, 8, 3), dtype=jnp.float32)
batch = (lr, hr)

# Run one training step; returns the updated state and the loss
state_new, loss = flaxsr.training.discriminative_train_step(state, batch)

# The step counter advanced and the first convolution's kernel was updated
assert state_new.step == 1
assert not np.array_equal(
    state_new.params['params']['Conv_0']['kernel'], state.params['params']['Conv_0']['kernel']
)
```
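The example builds two losses with different `reduce` settings and defines `loss_weights = (.1, 1.)`. The plain NumPy sketch below shows one common way such pieces fit together: each loss is reduced to a scalar and the total is a weighted sum. This is an illustration under that assumption, not FlaxSR's implementation; `l1_loss` and `weighted_total` are hypothetical helpers, and the second term is only a stand-in for the VGG perceptual loss.

```python
import numpy as np


def l1_loss(pred: np.ndarray, target: np.ndarray, reduce: str = "mean") -> float:
    """Hypothetical L1 loss: mean or sum of absolute per-element errors."""
    abs_err = np.abs(pred - target)
    if reduce == "mean":
        return float(abs_err.mean())
    if reduce == "sum":
        return float(abs_err.sum())
    raise ValueError(f"unknown reduce mode: {reduce!r}")


def weighted_total(loss_values, weights) -> float:
    """Assumed combination rule: total = sum_i weights[i] * loss_values[i]."""
    if len(loss_values) != len(weights):
        raise ValueError("need exactly one weight per loss")
    return sum(w * v for w, v in zip(weights, loss_values))


hr = np.ones((1, 32, 32, 3), dtype=np.float32)
sr = np.zeros_like(hr)  # a deliberately poor super-resolved prediction

values = [
    l1_loss(sr, hr, reduce="sum"),   # 32 * 32 * 3 = 3072.0
    l1_loss(sr, hr, reduce="mean"),  # stand-in for the VGG term: 1.0
]
total = weighted_total(values, (.1, 1.))
print(values, total)
```

A sum-reduced loss grows with the number of pixels and the batch size, so its magnitude can dwarf a mean-reduced term; a smaller weight on it is one way to keep the two terms on a comparable scale.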


## Models implemented
 - SRCNN
 - FSRCNN
 - ESPCN
 - VDSR
 - EDSR, MDSR
 - NCNet
 - SRResNet (SRGAN will be implemented in the future)
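Several of these architectures (ESPCN, EDSR and SRResNet, as described in their original papers) upsample with a sub-pixel convolution: a convolution produces `scale**2` times as many channels, and a pixel-shuffle (depth-to-space) step rearranges them into a `scale`-times larger image. The NumPy sketch below shows that rearrangement for NHWC arrays using TensorFlow's `depth_to_space` channel ordering. `pixel_shuffle` is a hypothetical helper, and FlaxSR's own upsampling layers may order channels differently.

```python
import numpy as np


def pixel_shuffle(x: np.ndarray, scale: int) -> np.ndarray:
    """Rearrange (N, H, W, C * scale**2) into (N, H * scale, W * scale, C)."""
    n, h, w, c = x.shape
    if c % (scale * scale) != 0:
        raise ValueError("channel count must be divisible by scale**2")
    out_c = c // (scale * scale)
    # Split channels into (row offset i, column offset j, output channel)
    x = x.reshape(n, h, w, scale, scale, out_c)
    # Interleave the offsets with the spatial axes: (N, H, i, W, j, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # Merge (H, i) into H * scale rows and (W, j) into W * scale columns
    return x.reshape(n, h * scale, w * scale, out_c)


# scale=4 turns an 8x8 feature map with 3 * 4**2 = 48 channels into a 32x32 RGB image
features = np.random.default_rng(0).standard_normal((1, 8, 8, 48)).astype(np.float32)
print(pixel_shuffle(features, 4).shape)  # (1, 32, 32, 3)

# With one pixel and four channels, the channels fill a 2x2 block in row-major order
tiny = np.arange(4, dtype=np.float32).reshape(1, 1, 1, 4)
print(pixel_shuffle(tiny, 2)[0, :, :, 0])
```

The same reshape/transpose sequence works unchanged with `jax.numpy`, since `jnp.reshape` and `jnp.transpose` mirror their NumPy counterparts.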
