Metadata-Version: 2.1
Name: torch_rim
Version: 0.2.3
Summary: A torch implementation of the Recurrent Inference Machine
Home-page: https://github.com/AlexandreAdam/torch_rim
Author: Alexandre Adam
Author-email: alexande.adam@umontreal.ca
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3.8
Requires-Python: >=3.8
License-File: LICENSE.txt

=======================================
RIM: Recurrent Inference Machines
=======================================

.. image:: https://badge.fury.io/py/torch-rim.svg
    :target: https://badge.fury.io/py/torch-rim

.. image:: https://codecov.io/gh/AlexandreAdam/torch_rim/branch/master/graph/badge.svg
    :target: https://codecov.io/gh/AlexandreAdam/torch-rim

This is an implementation of the Recurrent Inference Machine (RIM; see `Putzky & Welling (2017) <https://arxiv.org/abs/1706.04008>`_),
together with standard neural network architectures (``Hourglass`` and ``Unet``) that can serve as its backbone for the inverse problems a RIM is designed to solve.
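
A RIM solves an inverse problem iteratively. At each step a recurrent network receives the current estimate ``x_t`` and the likelihood gradient ``∇ log p(y | x_t)``, and returns an update, ``x_{t+1} = x_t + Δx_t``. The following is a minimal sketch of that loop in plain PyTorch for a linear Gaussian model ``y = A x + noise`` with flat vectors. It is not the torch_rim implementation: the hypothetical ``ToyRIM`` class, its single ``GRUCell``, the zero initial estimate and the training loss averaged uniformly over all iterates are illustrative assumptions.

.. code-block:: python

    import torch
    from torch import nn


    class ToyRIM(nn.Module):
        """Hypothetical minimal RIM for flat vectors; not the torch_rim API."""

        def __init__(self, n: int, hidden: int = 32, steps: int = 5):
            super().__init__()
            self.steps = steps
            self.hidden = hidden
            # The recurrent cell reads the current estimate and the likelihood gradient
            self.cell = nn.GRUCell(2 * n, hidden)
            # A linear head maps the hidden state to an additive update Delta x
            self.head = nn.Linear(hidden, n)

        def forward(self, y, A, Sigma_inv):
            batch, n = y.shape[0], A.shape[1]
            x = torch.zeros(batch, n, dtype=y.dtype)  # initial estimate x_0 = 0
            h = torch.zeros(batch, self.hidden, dtype=y.dtype)
            estimates = []
            for _ in range(self.steps):
                # Gaussian likelihood gradient A^T Sigma^{-1} (y - A x), one row per example
                grad = (y - x @ A.T) @ Sigma_inv @ A
                h = self.cell(torch.cat([x, grad], dim=-1), h)
                x = x + self.head(h)
                estimates.append(x)
            return torch.stack(estimates)  # shape [steps, batch, n]


    # Usage on a random linear inverse problem with unit-variance white noise
    torch.manual_seed(0)
    m, n, batch = 8, 4, 3
    A = torch.randn(m, n)
    Sigma_inv = torch.eye(m)
    x_true = torch.randn(batch, n)
    y = x_true @ A.T + torch.randn(batch, m)

    rim = ToyRIM(n)
    estimates = rim(y, A, Sigma_inv)
    # Supervise every iterate (uniform weights assumed), then backpropagate through time
    loss = ((estimates - x_true) ** 2).mean()
    loss.backward()

Putzky & Welling (2017) train on a weighted sum of per-step losses; taking the mean over all iterates corresponds to uniform weights, which is a simplifying choice here.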

Installation
------------

To install the package, you can use pip:

.. code-block:: bash

    pip install torch-rim

Usage
-----

.. code-block:: python

    from functools import partial

    from torch.func import vmap
    from torch_rim import RIM, Hourglass, Unet

    # B is the batch size (the leading dimension of x and y)
    C = 1                  # number of input channels
    dimensions = [28, 28]  # spatial dimensions, e.g. [28, 28] for MNIST
    # dataset, checkpoints_directory, y, A (or F) and Sigma are provided by you

    # vmap over the batch dimension of x and y only; A (or F) and Sigma are shared
    # by every example. torch.func.vmap cannot map over a Python callable such as F,
    # so its in_dim must be None.
    batched = partial(vmap, in_dims=(0, 0, None, None))

    # Create a score_fn, e.g. the score of a Gaussian likelihood
    @batched
    def score_fn(x, y, A, Sigma):  # must respect the signature (x, y, *args)
        # A is a linear forward model, Sigma is the noise covariance;
        # x is flattened so that A @ x is a matrix-vector product
        r = y - A @ x.flatten()
        return (r @ Sigma.inverse() @ A).reshape(x.shape)

    # ... or a Gaussian energy function (unnormalized negative log-likelihood)
    @batched
    def energy_fn(x, y, F, Sigma):
        # F is a general forward model
        r = y - F(x)
        return 0.5 * r @ Sigma.inverse() @ r

    # Create a RIM instance with the Hourglass backbone (or Unet) and the score function
    net = Hourglass(C, dimensions=len(dimensions))
    rim = RIM(dimensions, net, score_fn=score_fn)

    # ... or with the energy function
    rim = RIM(dimensions, net, energy_fn=energy_fn)

    # Train the RIM and save its weights in checkpoints_directory
    rim.fit(dataset, epochs=100, learning_rate=1e-4, checkpoints_directory=checkpoints_directory)

    # Make a prediction on an observation y
    x_hat = rim.predict(y, A, Sigma)  # or rim.predict(y, F, Sigma) when using energy_fn
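
For a Gaussian likelihood the two options describe the same data-fidelity term: with energy ``E(x) = ½ (y − A x)ᵀ Σ⁻¹ (y − A x)``, the score is the negative gradient, ``−∇E(x) = Aᵀ Σ⁻¹ (y − A x)``. The self-contained check below uses only ``torch.func`` and does not call torch_rim; the helper names ``gaussian_energy`` and ``gaussian_score`` are hypothetical. Whether torch_rim derives its score from ``energy_fn`` in exactly this way is an assumption, not something stated here.

.. code-block:: python

    from functools import partial

    import torch
    from torch.func import grad, vmap


    def gaussian_energy(x, y, A, Sigma):
        # Hypothetical helper: E(x) = 1/2 (y - A x)^T Sigma^{-1} (y - A x), one scalar per example
        r = y - A @ x
        return 0.5 * r @ torch.linalg.solve(Sigma, r)


    def gaussian_score(x, y, A, Sigma):
        # Hypothetical helper: closed-form score A^T Sigma^{-1} (y - A x)
        return A.T @ torch.linalg.solve(Sigma, y - A @ x)


    # Batch over x and y only; A and Sigma are shared by every example
    batched = partial(vmap, in_dims=(0, 0, None, None))
    autodiff_score = batched(lambda x, y, A, Sigma: -grad(gaussian_energy)(x, y, A, Sigma))
    closed_form_score = batched(gaussian_score)

    torch.manual_seed(0)
    B, n, m = 4, 5, 6
    x = torch.randn(B, n, dtype=torch.float64)
    y = torch.randn(B, m, dtype=torch.float64)
    A = torch.randn(m, n, dtype=torch.float64)
    L = torch.randn(m, m, dtype=torch.float64)
    Sigma = L @ L.T + m * torch.eye(m, dtype=torch.float64)  # symmetric positive definite

    match = torch.allclose(autodiff_score(x, y, A, Sigma), closed_form_score(x, y, A, Sigma))
    print(match)

Using ``torch.linalg.solve`` instead of forming ``Σ⁻¹`` explicitly is a common numerical choice; both give the same result for a well-conditioned covariance.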


