Metadata-Version: 2.1
Name: jaxnet
Version: 0.1.1
Summary: Neural Nets for JAX
Home-page: http://github.com/JuliusKunze/jaxnet
License: UNKNOWN
Description: # JAXnet [![Build Status](https://travis-ci.org/JuliusKunze/jaxnet.svg?branch=master)](https://travis-ci.org/JuliusKunze/jaxnet) [![PyPI](https://img.shields.io/pypi/v/jaxnet.svg)](https://pypi.python.org/pypi/jaxnet/#history)
        
        JAXnet is a neural net library for [JAX](https://github.com/google/jax).
        Different from popular neural net libraries, it is completely functional:
        - No mutable weights in modules
        - No global compute graph
        - No global random key
        
        This is an early version. Expect bugs, sharp edges and breaking changes! Install with
        ```
        pip install jaxnet
        ```
        
        ## Overview
        
        Defining networks looks similar to the [TensorFlow2 / Keras functional API](https://www.tensorflow.org/beta/guide/keras/functional):
        ```python
        from jax import numpy as np, random, jit
        from jaxnet import *
        
        net = Sequential([Dense(10), relu, Dense(4)])
        ```
        
        `Sequential`, `Dense`, `Conv` and `RNN` (with `GRUCell`) are already supported.
        
        To initialize parameter values for a network, call `init_params` on any module (with example inputs and a random key):
        ```python
        batch = np.zeros((3, 2))
        params = net.init_params(random.PRNGKey(0), batch)
        ```
        
        It initializes and returns all parameters, accessible via attributes:
        ```python
        print(params.layers[0].bias) # [0.00212132 0.01169001 0.00331698 0.00460713]
        ```
        
        Invoke the network with:
        ```python
        output = net(params, batch)
        ```
        
        For acceleration use `jit`:
        
        ```python
        output = jit(net)(params, batch)
        ```
        
        Full examples are given in these interactive demos: [Mnist VAE](https://colab.research.google.com/drive/19web5SnmIFglLcnpXE34phiTY03v39-g) and [OCR with RNNs](https://colab.research.google.com/drive/1YuI6GUtMgnMiWtqoaPznwAiSCe9hMR1E).
        
        ## Defining modules
        
        Modules are functions decorated with `@parameterized`, with parameters defined through default values:
        ```python
        def Dense(out_dim, kernel_init=glorot(), bias_init=randn()):
            @parameterized
            def dense(inputs,
                      kernel=Param(lambda inputs: (inputs.shape[-1], out_dim), kernel_init),
                      bias=Param(lambda _: (out_dim,), bias_init)):
                return np.dot(inputs, kernel) + bias
        
            return dense
        ```
        
        `Param` specifies parameter shape and initialization function. 
        `@parameterized` transforms this function to allow usage as above.
        
        ## Nesting modules
        
        Modules can be used in other modules through default arguments:
        
        ```python
        @parameterized
        def net(inputs, layer1=Dense(10), layer2=Dense(20))
            inputs = layer1(inputs)
            return layer2(inputs)
        ```
        
        Use many modules at once with collections:
        ```python
        def Sequential(layers):
            @parameterized
            def sequential(inputs, layers=layers):
                for module in layers:
                    inputs = module(inputs)
                return inputs
        
            return sequential
        ```
        
        Nested `tuples`/`list`/`dicts` of modules work. The same is true for `Params`.
        
        Using parameter-free functions is seamless:
        ```python
        def relu(input):
            return np.maximum(input, 0)
        
        layer = Sequential([Dense(10), relu])
        ```
        
        ## Parameter sharing
        
        Parameters can be shared by using module or parameter objects multiple times (not yet implemented):
        
        ```python
        shared_net=Sequential([layer, layer])
        ```
        
        This is equivalent to (already implemented):
        
        ```python
        @parameterized
        def shared_net(input, layer=layer):
            return layer(layer(input))
        ```
        
        ## Parameter reuse
        
        You can reuse parameters of submodules:
        
        ```python
        inputs = np.zeros((1, 2))
        
        layer = Dense(5)
        net1 = Sequential([layer, Dense(2)])
        net2 = Sequential([layer, Dense(3)])
        
        layer_params = layer.init_params(random.PRNGKey(0), inputs)
        net1_params = net1.init_params(random.PRNGKey(0), inputs, reuse={layer: layer_params})
        net2_params = net2.init_params(random.PRNGKey(1), inputs, reuse={layer: layer_params})
        
        # Now net1_params.layers[0] equals net2_params.layers[0] equals layer_params
        ```
        
        If all parameters are reused, you can use `join_params` instead of `init_params`:
        
        ```python
        inputs = np.zeros((1, 2))
        
        net = Dense(5)
        prediction = Sequential([net, softmax])
        
        net_params = net.init_params(random.PRNGKey(0), inputs)
        prediction_params = prediction.join_params({net: layer_params})
        
        # net_params.layers[0] is now equal to net_params
        
        output = jit(prediction)(prediction_params, inputs)
        ```
        
        If you just want to call the network with these joined parameters, you can use the shorthand:
        
        ```python
        output = prediction.apply_joined({net: layer_params}, inputs, jit=True)
        ```
        
        ## What about [stax](https://github.com/google/jax/blob/master/jax/experimental/stax.py)?
        JAXnet is independent of stax.
        The main motivation over stax is to simplify nesting modules:
         - Automating `init_params`: delegation to submodules, `output_shape` inference, `rng` passing
         - Seamless use of parameter-free functions as modules
         - Allowing streamlined module/parameter-sharing
        
        Alternative design ideas are discussed [here](DESIGN.md).
Platform: UNKNOWN
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Description-Content-Type: text/markdown
