Metadata-Version: 2.1
Name: flax
Version: 0.2.2
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page: https://github.com/google/flax
Author: Flax team
Author-email: flax-dev@google.com
License: UNKNOWN
Platform: UNKNOWN
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3.7
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Description-Content-Type: text/markdown
Requires-Dist: numpy (>=1.12)
Requires-Dist: jax (>=0.1.59)
Requires-Dist: matplotlib
Requires-Dist: msgpack
Requires-Dist: dataclasses ; python_version < "3.7"
Provides-Extra: testing
Requires-Dist: atari-py ; extra == 'testing'
Requires-Dist: gym ; extra == 'testing'
Requires-Dist: jaxlib ; extra == 'testing'
Requires-Dist: ml-collections ; extra == 'testing'
Requires-Dist: opencv-python ; extra == 'testing'
Requires-Dist: pytest ; extra == 'testing'
Requires-Dist: pytest-cov ; extra == 'testing'
Requires-Dist: pytest-xdist (==1.34.0) ; extra == 'testing'
Requires-Dist: svn ; extra == 'testing'
Requires-Dist: tensorflow ; extra == 'testing'
Requires-Dist: tensorflow-datasets ; extra == 'testing'

# Flax: A neural network library for JAX designed for flexibility

[![coverage](https://badgen.net/codecov/c/github/google/flax)](https://codecov.io/github/google/flax)

**NOTE**: Flax is being actively improved and has a growing community
of researchers and engineers at Google who happily use Flax for their
daily research. Flax is in "early release stage" -- if that's your style,
now could be a good time to start using it.
We want to smooth out any rough edges, so please report
any issues, questions or concerns in our
[discussion forum](https://github.com/google/flax/discussions), or just let us know
what you're working on!

Expect changes to the
API, but we'll use deprecation warnings when we can, and keep
track of them in our [Changelog](CHANGELOG.md).

In case you need to reach us directly, we're at flax-dev@google.com.

## Quickstart

**⟶ [Full documentation and API reference](https://flax.readthedocs.io/)**

**⟶ [Annotated full end-to-end MNIST example](https://flax.readthedocs.io/en/latest/annotated_mnist.html)**

**⟶ [The Flax Guide](https://flax.readthedocs.io/en/latest/notebooks/flax_guided_tour.html)** -- a guided walkthrough of the parts of Flax

## Background: JAX

[JAX](https://github.com/google/jax) is NumPy + autodiff + GPU/TPU.

It allows for fast scientific computing and machine learning
with the normal NumPy API
(+ additional APIs for special accelerator ops when needed).

JAX comes with powerful primitives, which you can compose arbitrarily:

* Autodiff (`jax.grad`): Efficient any-order gradients with respect to any variables
* JIT compilation (`jax.jit`): Trace any function ⟶ fused accelerator ops
* Vectorization (`jax.vmap`): Automatically batch code written for individual samples
* Parallelization (`jax.pmap`): Automatically parallelize code across multiple accelerators (including across hosts, e.g. for TPU pods)
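The primitives above compose freely. As a small illustration in plain JAX (no Flax; the loss function and data are made up for the example), the snippet below computes per-example gradients of a squared-error loss by stacking `jax.grad`, `jax.vmap`, and `jax.jit`, and takes a second derivative by nesting `jax.grad`. `jax.pmap` is left out because it needs more than one device.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Squared error of a linear model on a single example.
    return (jnp.dot(w, x) - y) ** 2


# jax.grad differentiates with respect to the first argument, w.
per_example_grad = jax.grad(loss)

# jax.vmap maps over axis 0 of x and y while sharing w (in_axes=None).
batched_grad = jax.vmap(per_example_grad, in_axes=(None, 0, 0))

# jax.jit compiles the composed function with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.zeros(3)

grads = fast_batched_grad(w, xs, ys)  # one gradient row per example, shape (3, 2)

# Higher-order derivatives come from nesting jax.grad: d²/dt² t³ = 6t.
second_derivative = jax.grad(jax.grad(lambda t: t ** 3))(2.0)
```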

## What is Flax?

Flax is a high-performance neural network library for
JAX that is **designed for flexibility**:
Try new forms of training by forking an example and by modifying the training
loop, not by adding features to a framework.
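To make "modify the training loop" concrete, here is a minimal sketch in plain JAX (not Flax's API) of a loop that you write and own: a linear model fit with gradient descent on synthetic data. The names `predict`, `mse`, and `train_step` are illustrative. Changing how training works (a different update rule, gradient clipping, a custom schedule) means editing these few lines rather than configuring a framework.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    # Linear model: x has shape (batch, 1), params["w"] has shape (1,).
    return x @ params["w"] + params["b"]


def mse(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)


@jax.jit
def train_step(params, x, y, lr):
    loss, grads = jax.value_and_grad(mse)(params, x, y)
    # Plain SGD, applied leaf by leaf over the params dict.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


# Synthetic data from y = 3x + 1.
x = jnp.linspace(-1.0, 1.0, 50)[:, None]
y = 3.0 * x[:, 0] + 1.0

params = {"w": jnp.zeros((1,)), "b": jnp.array(0.0)}
for step in range(500):
    params, loss = train_step(params, x, y, 0.1)
```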

Flax is being developed in close collaboration with the JAX team and 
comes with everything you need to start your research, including:

* **Common layers** (`flax.nn`): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout

* **Optimizers** (`flax.optim`): SGD, Momentum, Adam, LARS, Adagrad, LAMB, RMSprop

* **Utilities and patterns**: replicated training, serialization and checkpointing, metrics, prefetching on device

* **Educational examples** that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging

* **HOWTO guides**: diffs that add functionality to educational base examples

* **Fast, tuned large-scale end-to-end examples**: CIFAR10, ResNet on ImageNet, Transformer LM1b
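Optimizers in this style are pure functions over the parameter tree, and the optimizer state has the same nested structure as the parameters. The following sketch in plain JAX implements the classic heavy-ball momentum update (v ← βv + g, then p ← p − ηv); the helper names and the example parameter tree are hypothetical, and `flax.optim`'s Momentum optimizer may differ in details such as Nesterov support or weight decay.

```python
import jax
import jax.numpy as jnp


def momentum_init(params):
    # One velocity buffer per parameter leaf, initialized to zero.
    return jax.tree_util.tree_map(jnp.zeros_like, params)


def momentum_update(params, velocity, grads, lr=0.1, beta=0.9):
    # v <- beta * v + g, then p <- p - lr * v, leaf by leaf.
    velocity = jax.tree_util.tree_map(lambda v, g: beta * v + g, velocity, grads)
    params = jax.tree_util.tree_map(lambda p, v: p - lr * v, params, velocity)
    return params, velocity


# Hypothetical nested parameter tree, shaped like a single dense layer.
params = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))}}
grads = jax.tree_util.tree_map(jnp.ones_like, params)

velocity = momentum_init(params)
for _ in range(2):
    params, velocity = momentum_update(params, velocity, grads)
```

After two steps with constant unit gradients, the velocity is 1 + 0.9 = 1.9, so every parameter has moved by 0.1 + 0.19 = 0.29.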

## Try Flax now by forking one of our starter examples

Here we keep a limited list of canonical examples maintained by the Flax team. If you are looking for more examples, including ones built by the community, please check the [examples folder](examples/README.md) for further guidance.

### Image Classification
⟶ [MNIST](examples/mnist) (also see [annotated version](https://flax.readthedocs.io/en/latest/annotated_mnist.html))

⟶ [CIFAR-10](examples/cifar10) (Wide ResNet w/ and w/o Shake-Shake, PyramidNet w/ShakeDrop)

⟶ [ResNet50 on ImageNet](examples/imagenet)

### Transformer Models
⟶ [Sequence tagging on Universal Dependencies](examples/nlp_seq)

⟶ [LM1b language modeling](examples/lm1b) **([try on a TPU in Colab](https://colab.research.google.com/github/google/flax/blob/master/examples/lm1b/Colab_Language_Model.ipynb))**

⟶ [WMT translation](examples/wmt)

### RNNs
⟶ [LSTM text classifier on SST-2](examples/sst2)

⟶ [LSTM seq2seq on number addition](examples/seq2seq)


### Generative Models
⟶ [Basic VAE](examples/vae)

### Graph Neural Networks
⟶ [Semi-supervised node classification on Zachary's karate club](examples/graph)

## The Flax Module abstraction in a nutshell

The core of Flax is the Module abstraction. Modules allow you to write parameterized functions just as if you were writing a normal NumPy function with JAX. The Module API allows you to declare parameters and use them directly with the JAX APIs.

Modules are the one part of Flax with "magic" -- the magic is constrained, and enables a very ergonomic model construction style, where modules are defined in a single function with minimal boilerplate.

A few things to know about Modules:

1. Create a new module by subclassing `flax.nn.Module` and implementing the `apply` method.

2. Within `apply`, call `self.param(name, shape, init_func)` to register a new parameter; it returns the parameter's initial value.

3. Apply submodules with `MySubModule(name=..., ...)` within `MyModule.apply`. Parameters of `MySubModule` are stored
as a dictionary under the parameters of `MyModule` and are accessible via `self.get_param(name=...)`. This applies `MySubModule` once;
to re-use parameters, use [`Module.shared`](https://flax.readthedocs.io/en/latest/notebooks/flax_intro.html#Parameter-sharing).

4. `MyModule.init(rng, ...)` is a pure function that calls `apply` in "init mode" and returns a pair: the module's output and a nested Python dict of initialized parameter values.

5. `MyModule.call(params, ...)` is a pure function that calls `apply` in "call mode" and returns the output of the module.
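The init/call split is what keeps Modules compatible with JAX transformations: parameters are an explicit nested dict that flows into a pure function. The sketch below imitates that pattern in plain JAX, without Flax, for a two-layer MLP; `dense_init`, `dense_call`, `mlp_init`, `mlp_call`, and the dict keys are hypothetical names chosen for the example. As in point 3, each sublayer's parameters sit under their own key in the parent's dict, and because the parameters are an ordinary pytree, `jax.grad` returns gradients with the same nested structure.

```python
import jax
import jax.numpy as jnp


def dense_init(rng, in_features, out_features):
    # Parameters of one dense sublayer, as a flat dict.
    return {
        "kernel": jax.random.normal(rng, (in_features, out_features)) * 0.1,
        "bias": jnp.zeros((out_features,)),
    }


def dense_call(params, x):
    return x @ params["kernel"] + params["bias"]


def mlp_init(rng, in_features, hidden, out):
    # Pure: the same rng always yields the same nested parameter dict.
    rng_hidden, rng_out = jax.random.split(rng)
    return {
        "hidden": dense_init(rng_hidden, in_features, hidden),
        "out": dense_init(rng_out, hidden, out),
    }


def mlp_call(params, x):
    # Pure: the output depends only on params and x.
    h = jax.nn.relu(dense_call(params["hidden"], x))
    return dense_call(params["out"], h)


params = mlp_init(jax.random.PRNGKey(0), in_features=4, hidden=8, out=2)
x = jnp.ones((3, 4))
y = mlp_call(params, x)

# Gradients come back with the same nested structure as params.
grads = jax.grad(lambda p: jnp.sum(mlp_call(p, x)))(params)
```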

For example, you can define a learned linear transformation as follows:

```py
from flax import nn
import jax.numpy as jnp

class Linear(nn.Module):
  def apply(self, x, num_features, kernel_init_fn):
    input_features = x.shape[-1]
    W = self.param('W', (input_features, num_features), kernel_init_fn)
    return jnp.dot(x, W)
```
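`kernel_init_fn` is an ordinary initializer function. Assuming it follows the convention of `jax.nn.initializers` (called as `init(key, shape)` and returning an array), the computation `Linear` performs can be reproduced by hand in plain JAX, which also makes the shapes explicit: `W` is `(input_features, num_features)` and the output keeps every leading axis of `x`. The choice of `lecun_normal` is only an example.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

# Any callable of the form init(key, shape) -> array; lecun_normal is one choice.
kernel_init_fn = initializers.lecun_normal()

key = jax.random.PRNGKey(0)
x = jnp.ones((5, 3))  # batch of 5 examples, 3 input features
num_features = 4

input_features = x.shape[-1]
W = kernel_init_fn(key, (input_features, num_features))
y = jnp.dot(x, W)  # same contraction as Linear.apply
```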

You can also use `nn.module` as a function decorator to create a new module, as
long as you don't need access to `self` for creating parameters directly:

```py
from flax import nn


@nn.module
def DenseLayer(x, features):
  # A Dense layer followed by a ReLU, declared as a module via the decorator.
  x = nn.Dense(x, features)
  x = nn.relu(x)
  return x
```

**⟶ Read more about Modules in the [Flax Guide](https://flax.readthedocs.io/en/latest/notebooks/flax_guided_tour.html#Simplifying-Neural-Networks-in-JAX:-Flax-Modules)**

## A full ResNet implementation

(from [examples/imagenet/resnet_v1.py](examples/imagenet/resnet_v1.py))

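```py
import jax.numpy as jnp
from flax import nn

# Residual blocks per stage for the standard bottleneck ResNet depths.
_block_size_options = {
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
}


class ResidualBlock(nn.Module):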
  def apply(self, x, filters, strides=(1, 1), train=True, dtype=jnp.float32):
    needs_projection = x.shape[-1] != filters * 4 or strides != (1, 1)
    batch_norm = nn.BatchNorm.partial(
        use_running_average=not train, momentum=0.9, epsilon=1e-5, dtype=dtype)
    conv = nn.Conv.partial(bias=False, dtype=dtype)

    residual = x
    if needs_projection:
      residual = conv(residual, filters * 4, (1, 1), strides, name='proj_conv')
      residual = batch_norm(residual, name='proj_bn')

    y = conv(x, filters, (1, 1), name='conv1')
    y = batch_norm(y, name='bn1')
    y = nn.relu(y)
    y = conv(y, filters, (3, 3), strides, name='conv2')
    y = batch_norm(y, name='bn2')
    y = nn.relu(y)
    y = conv(y, filters * 4, (1, 1), name='conv3')

    y = batch_norm(y, name='bn3', scale_init=nn.initializers.zeros)
    y = nn.relu(residual + y)
    return y


class ResNet(nn.Module):
  def apply(self, x, num_classes, num_filters=64, num_layers=50,
            train=True, dtype=jnp.float32):
    if num_layers not in _block_size_options:
      raise ValueError('Please provide a valid number of layers')
    block_sizes = _block_size_options[num_layers]
    x = nn.Conv(
        x, num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)],
        bias=False, dtype=dtype, name='init_conv')
    x = nn.BatchNorm(
        x, use_running_average=not train, momentum=0.9,
        epsilon=1e-5, dtype=dtype, name='init_bn')
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    for i, block_size in enumerate(block_sizes):
      for j in range(block_size):
        strides = (2, 2) if i > 0 and j == 0 else (1, 1)
        x = ResidualBlock(
            x, num_filters * 2 ** i, strides=strides,
            train=train, dtype=dtype)
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(x, num_classes)
    x = nn.log_softmax(x)
    return x
```
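The nested loop in `ResNet.apply` decides each block's width and stride. The plain-Python sketch below replays that loop, and the projection test from `ResidualBlock.apply`, without building the network, so the stage layout can be checked at a glance. It assumes the standard bottleneck block counts per stage ([3, 4, 6, 3] for ResNet-50, and so on); `block_layout` is a hypothetical helper, not part of the example code.

```python
# Standard bottleneck ResNet block counts per stage (an assumption here).
BLOCK_SIZES = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}


def block_layout(num_layers=50, num_filters=64):
    """Return (filters, out_channels, strides, needs_projection) per ResidualBlock."""
    layout = []
    in_channels = num_filters  # channels produced by init_conv
    for i, block_size in enumerate(BLOCK_SIZES[num_layers]):
        for j in range(block_size):
            # Same rule as ResNet.apply: downsample at the first block of stages 2 to 4.
            strides = (2, 2) if i > 0 and j == 0 else (1, 1)
            filters = num_filters * 2 ** i
            # Same test as ResidualBlock.apply for adding a projection shortcut.
            needs_projection = in_channels != filters * 4 or strides != (1, 1)
            layout.append((filters, filters * 4, strides, needs_projection))
            in_channels = filters * 4
    return layout


layout = block_layout(50)
# 3 convs per block, plus init_conv and the final Dense layer.
depth = 3 * len(layout) + 2
```

For ResNet-50 this gives 16 blocks, 3 × 16 + 2 = 50 weighted layers, and projection shortcuts only at the first block of each stage.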

## Installation

You will need Python 3.6 or later.

For GPU support, first install `jaxlib` by following the
instructions in the [JAX
readme](https://github.com/google/jax/blob/master/README.md). If they
are not already installed, you will also need the
[CUDA](https://developer.nvidia.com/cuda-downloads) and
[cuDNN](https://developer.nvidia.com/cudnn) runtimes.

Then install `flax` from PyPI:

```
> pip install flax
```
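After installing, a quick sanity check is to ask JAX which devices it can see. This uses only `jax.devices()`; with a CUDA-enabled `jaxlib` the list should contain GPU devices, and otherwise JAX falls back to the CPU.

```python
import jax

devices = jax.devices()
# Each device reports its platform, e.g. "cpu", "gpu", or "tpu".
platforms = sorted({d.platform for d in devices})
print(platforms, devices)
```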

## TPU support

We currently have a tuned [LM1b/Wikitext-2 language model with a Transformer architecture](https://colab.research.google.com/github/google/flax/blob/master/examples/lm1b/Colab_Language_Model.ipynb)
that you can run directly in Colab.

At present, Cloud TPUs are network-attached, and Flax users typically feed in data from one or more additional VMs.

When working with large-scale input data, it is important to create VMs that are large enough and have sufficient network bandwidth, so that the TPUs are not bottlenecked waiting for input.
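On the host side, one common way to keep accelerators from idling is to stage a few batches on the device ahead of time, so the next transfer overlaps with the current step. The sketch below does this for a single device with `jax.device_put`, which, like most JAX operations, is dispatched asynchronously. `prefetch_to_device` here is a hypothetical helper; Flax's own on-device prefetching utility may work differently (for example, by sharding batches across devices). It only hides host-to-device transfer time and does not remove the need for enough VM network bandwidth.

```python
import collections
import itertools

import jax
import numpy as np


def prefetch_to_device(batches, size=2):
    """Yield batches from `batches`, keeping up to `size` already sent to the device."""
    batches = iter(batches)  # islice below must consume a single shared iterator
    queue = collections.deque()

    def enqueue(n):
        for batch in itertools.islice(batches, n):
            queue.append(jax.device_put(batch))  # starts the host-to-device copy

    enqueue(size)
    while queue:
        yield queue.popleft()
        enqueue(1)


# Hypothetical host-side input pipeline: five batches of shape (4, 3).
host_batches = (np.full((4, 3), i, dtype=np.float32) for i in range(5))
sums = [float(batch.sum()) for batch in prefetch_to_device(host_batches, size=2)]
```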

TODO: Add an example for running on Google Cloud.

## Getting involved
We welcome pull requests, in particular for those issues [marked as PR-ready](https://github.com/google/flax/issues?q=is%3Aopen+is%3Aissue+label%3A%22Status%3A+pull+requests+welcome%22). For other proposals, we ask that you first open an Issue to discuss your planned contribution.

## Note

This is not an official Google product.


