Metadata-Version: 2.1
Name: surecr
Version: 0.1.2
Summary: This package enables tractable evaluation of Stein's Unbiased Risk Estimate on convexly regularized estimators.
Home-page: https://github.com/cvxgrp/SURE-CR
Author: Parth Nobel
Author-email: ptnobel@stanford.edu
License: UNKNOWN
Platform: UNKNOWN
Classifier: Programming Language :: Python :: 3
Description-Content-Type: text/markdown
Requires-Dist: cvxpy
Requires-Dist: cvxpylayers
Requires-Dist: numpy (>=1.17.5)
Requires-Dist: scipy
Requires-Dist: torch
Requires-Dist: torch-linops

# SURE-CR

This package enables tractable evaluation of Stein's Unbiased Risk Estimate on
convexly regularized estimators.

For an estimator of the mean of a normally distributed random vector $y \in \mathbb R^d$ with known covariance matrix $\sigma^2 I$, given by

$$
    \hat\mu(y) = \mathcal A \left( \underset{b \in \mathbb R^p}{\text{argmin}} ~ \frac{1}{2} \lVert\mathcal A b - y\rVert_2^2 + r(b) \right)
$$

where $r: \mathbb R^p \to \mathbb R$ is a convex function and
$\mathcal A: \mathbb R^p \to \mathbb R^d$ is a linear operator, this package
provides methods to compute Stein's Unbiased Risk Estimate (SURE) of $\hat\mu$:

$$
    SURE(\hat\mu, y) = -d \sigma^2 + \lVert\hat\mu(y) - y\rVert_2^2 + 2 \sigma^2 \nabla \cdot \hat\mu(y).
$$

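Here $\nabla \cdot \hat\mu(y)$ is the divergence of $\hat\mu$ at $y$. Under mild
regularity conditions on $\hat\mu$, $SURE(\hat\mu, y)$ is an unbiased estimate of
the $\ell_2$ risk $\mathbb E \lVert\hat\mu(y) - \mu\rVert_2^2$, where $\mu$ is the
true mean of $y$, and it tends to be accurate for high-dimensional problems.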
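As a standalone illustration of the formula (plain NumPy, not this package), consider the special case $\mathcal A = I$ and $r(b) = \lambda \lVert b\rVert_1$. Then $\hat\mu$ is soft thresholding, whose divergence is the number of entries with $\lvert y_i\rvert > \lambda$, so SURE has a closed form. The sketch below compares it with the actual squared error on simulated data; the sparse mean, $\lambda$, and $\sigma$ are arbitrary choices made for the demo.

```python
import numpy as np

def soft_threshold(y, lam):
    # Closed-form minimizer of 1/2 ||b - y||_2^2 + lam ||b||_1 (the case A = I).
    return np.sign(y) * np.maximum(np.abs(y) - lam, 0.0)

def sure_soft_threshold(y, lam, sigma):
    d = y.size
    mu_hat = soft_threshold(y, lam)
    # For soft thresholding, the divergence is the number of entries with |y_i| > lam.
    divergence = np.count_nonzero(np.abs(y) > lam)
    return -d * sigma**2 + np.sum((mu_hat - y) ** 2) + 2 * sigma**2 * divergence

rng = np.random.default_rng(0)
d, sigma, lam = 1000, 1.0, 1.5
mu = np.zeros(d)
mu[:50] = 5.0  # sparse true mean (demo choice)

sure_vals, errors = [], []
for _ in range(200):
    y = mu + sigma * rng.standard_normal(d)
    sure_vals.append(sure_soft_threshold(y, lam, sigma))
    errors.append(np.sum((soft_threshold(y, lam) - mu) ** 2))

print(np.mean(sure_vals), np.mean(errors))  # the two averages should be close
```

Averaged over many draws, the SURE values track the average squared error closely, even though SURE never looks at the true mean `mu`.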

## Installation

SURE-CR can be installed with pip or conda. By default, the conda installation does not support using the `CVXPYSolver`; to use it, the `cvxpylayers` library must be installed separately.

To install with pip:

```bash
$ pip install surecr
```

To install with conda:

```bash
$ conda install -c stanfordcvxgrp -c conda-forge sure-cr
```
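If you installed with conda, one way to check whether the optional dependency of `CVXPYSolver` is present is to look for the modules before constructing the solver. This is a small sketch that uses only the standard library; the helper name `cvxpy_solver_available` is hypothetical.

```python
import importlib.util

def cvxpy_solver_available() -> bool:
    # CVXPYSolver relies on cvxpy and cvxpylayers; the conda package does not
    # install cvxpylayers by default, so check that both can be imported.
    return all(importlib.util.find_spec(name) is not None
               for name in ("cvxpy", "cvxpylayers"))

if not cvxpy_solver_available():
    print("cvxpylayers not found; install it separately to use CVXPYSolver")
```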

## Examples

The easiest way to get started with SURE-CR is to read the examples:

 * [LASSO with `FISTASolver`](https://github.com/cvxgrp/SURE-CR/blob/main/examples/lasso/lasso.py)
    - Shows how to define a proximal operator.
 * [LASSO with `CVXPYSolver`](https://github.com/cvxgrp/SURE-CR/blob/main/examples/lasso/lasso_cvxpy.py)
 * [Matrix completion with `ADMMSolver`](https://github.com/cvxgrp/SURE-CR/blob/main/examples/mat_compl/mat_compl.py)
    - Shows how to define a custom linear operator.
 * [Robust PCA with `ADMMSolver`](https://github.com/cvxgrp/SURE-CR/blob/main/examples/rpca/rpca.py)
    - Shows how to use separable regularizers.


## Usage

The package has three main components:

 * The `SURE` class
 * The `Solver` class and its subclasses `CVXPYSolver`, `FISTASolver`, and `ADMMSolver`
 * The `prox_lib` helper library

### The `SURE` Class

The `SURE` class has the following API:
```python
class SURE:
    def __init__(self, variance: float, solver: Solver): ...

    def compute(self, y: torch.Tensor, divergence_parameters={}) -> float:
        """
        Computes and returns SURE for the estimator computed by the solver
        at the point y.

        Currently, divergence_parameters can contain the key "m" to indicate
        how many samples to use during the divergence estimation (which
        dominates the runtime at high dimensions). The default is for m to be
        102.

        In the future we may switch to A-Hutch++ and may change what options
        the divergence_parameters specifies.
        """

    @property
    def solution(self) -> torch.Tensor:
        """
        Returns solver.solve(y) from the last compute call.
        """

    def runtimes(self) -> TypedDict('Runtimes', solver=float, divergence=float):
        """
        Returns how long it took for the solver to run and how long it took
        the divergence estimator to run during the last compute call.
        """
```
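The divergence term $\nabla \cdot \hat\mu(y)$ is the trace of the Jacobian of $\hat\mu$ at $y$, which is expensive to form explicitly in high dimensions; the `m` parameter of `compute` sets how many samples are used to estimate it. As a rough illustration of that idea, and not this package's implementation, the NumPy sketch below uses Hutchinson's estimator $\mathbb E[z^T J z]$ with Rademacher vectors $z$, approximating each Jacobian-vector product with a central finite difference where the package relies on torch backpropagation instead. The helpers `hutchinson_divergence` and `sure` are hypothetical, and the ridge estimator ($r(b) = \frac{\lambda}{2}\lVert b\rVert_2^2$) is chosen only because its exact divergence, $\operatorname{tr} H$, is easy to check.

```python
import numpy as np

def hutchinson_divergence(f, y, m=102, eps=1e-4, rng=None):
    # Hypothetical helper: estimates div f(y) = trace(J_f(y)) as the average of
    # z^T J z over m Rademacher vectors z. J z is approximated by a central
    # finite difference rather than by automatic differentiation.
    rng = np.random.default_rng() if rng is None else rng
    total = 0.0
    for _ in range(m):
        z = rng.choice([-1.0, 1.0], size=y.shape)
        jz = (f(y + eps * z) - f(y - eps * z)) / (2 * eps)
        total += z @ jz
    return total / m

def sure(f, y, sigma, **div_kwargs):
    # SURE = -d sigma^2 + ||f(y) - y||_2^2 + 2 sigma^2 div f(y)
    d = y.size
    residual = f(y) - y
    div = hutchinson_divergence(f, y, **div_kwargs)
    return -d * sigma**2 + residual @ residual + 2 * sigma**2 * div

# Ridge regression: mu_hat(y) = A (A^T A + lam I)^{-1} A^T y = H y.
rng = np.random.default_rng(1)
d, p, lam, sigma = 200, 50, 10.0, 1.0
A = rng.standard_normal((d, p))
H = A @ np.linalg.solve(A.T @ A + lam * np.eye(p), A.T)
mu_hat = lambda v: H @ v
y = A @ rng.standard_normal(p) + sigma * rng.standard_normal(d)

est = hutchinson_divergence(mu_hat, y, rng=rng)
print(est, np.trace(H))  # the estimate should be close to the exact divergence
print(sure(mu_hat, y, sigma, rng=rng))
```

More samples reduce the variance of the divergence estimate at the cost of more evaluations of $\hat\mu$, which is why this step dominates the runtime in high dimensions.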


### The `Solver` class

Most users should use one of the three `Solver` subclasses provided by this
library, which have the following constructors:
```python
class FISTASolver(Solver):
    def __init__(self, A: linops.LinearOperator,
                       prox_R: Callable[[torch.Tensor, float | torch.Tensor], torch.Tensor],
                       x0: torch.Tensor,
                       device=None,
                       lipschitz_iterations=20,
                       lipschitz_vec=None,
                       *, max_iters=5000, eps=1e-3):
        """
        This solver uses a variant of FISTA to solve problems of the form
              min. 1/2 ||A b - y||_2^2 + r(b)
        and estimates the mean of y with A b^* where b^* is the optimal b.

        A is a linear operator defined using <https://github.com/cvxgrp/torch_linops>

        prox_R must be differentiable with respect to its first argument and,
            for a pair (v, t), return the optimal point b of
              min. t r(b) + 1/2 ||b - v||_2^2

        x0 is the point where iterations begin; it must be chosen
            independently of y.

        lipschitz_iterations is how many iterations of the power method to use
        to approximate the largest eigenvalue of A^T A.

        lipschitz_vec is the vector that starts the power method. By default, a
        vector of all 1s is used. If this vector is orthogonal to the eigenvector
        of A^T A associated with its largest eigenvalue, this argument is mandatory.

        max_iters, eps control when iterations stop.

        """

class ADMMSolver(Solver):
    def __init__(self, A: linops.LinearOperator,
                       prox_R: Callable[[torch.Tensor, float | torch.Tensor], torch.Tensor],
                       x0: torch.Tensor,
                       device=None,
                       *, max_iters=1000, eps_rel=1e-3, eps_abs=1e-6):
        """
        This solver uses a variant of ADMM to solve problems of the form
              min. 1/2 ||A b - y||_2^2 + r(b)
        and estimates the mean of y with A b^* where b^* is the optimal b.

        A is a linear operator defined using <https://github.com/cvxgrp/torch_linops>

        prox_R must be differentiable with respect to its first argument and,
            for a pair (v, t), return the optimal point b of
              min. t r(b) + 1/2 ||b - v||_2^2

        x0 is the point where iterations begin; it must be chosen
            independently of y.

        max_iters, eps_rel, eps_abs control when iterations stop.
        """

class CVXPYSolver(Solver):
    def __init__(self, problem: cp.Problem,
                       y_parameter: cp.Parameter, 
                       variables: list[cp.Variable], 
                       estimate: Callable[[list[torch.Tensor]], torch.Tensor]):
        """
        problem must be a CVXPY problem whose only parameter is y_parameter
            and whose variables are those listed in variables.

        estimate must be a function that takes tensors with the values of each
            variable and returns the estimate.

        WARNING: This solver has poor performance on large problems, and can
        have undetected poor accuracy on some moderately-sized problems.
        """
```
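The `lipschitz_iterations` and `lipschitz_vec` arguments of `FISTASolver` refer to the power method, which FISTA needs because its step size depends on $L = \lambda_{\max}(A^T A)$. A minimal NumPy sketch of the idea (the function `estimate_lipschitz` is hypothetical, not the library's code) shows why a starting vector orthogonal to the top eigenvector is a problem: the iteration never sees the largest eigenvalue.

```python
import numpy as np

def estimate_lipschitz(A, iterations=20, v0=None):
    # Power iteration on A^T A. L = lambda_max(A^T A) is the Lipschitz constant
    # of the gradient of 1/2 ||A b - y||_2^2, which sets FISTA's step size.
    v = np.ones(A.shape[1]) if v0 is None else np.asarray(v0, dtype=float)
    for _ in range(iterations):
        w = A.T @ (A @ v)
        v = w / np.linalg.norm(w)
    # Rayleigh quotient ||A v||^2 of the final unit vector estimates lambda_max(A^T A).
    Av = A @ v
    return float(Av @ Av)

# A^T A has eigenvalues 9 (eigenvector [1, -1]) and 1 (eigenvector [1, 1]).
A = np.array([[2.0, -1.0], [-1.0, 2.0]])
print(estimate_lipschitz(A))                 # all-1s start is orthogonal to [1, -1]: about 1.0
print(estimate_lipschitz(A, v0=[1.0, 0.0]))  # a different start recovers about 9.0
```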

If you wish to implement your own `Solver`, it must provide the following API,
where `T` is any type of the implementation's choice:
```python
from typing import TypeVar

import torch

T = TypeVar("T")  # any type of the implementation's choice

class Solver:

    def solve(self, y: torch.Tensor) -> T:
        """
        Returns an intermediate value used to estimate the mean of the
        distribution y is sampled from.
        """

    def estimate(self, beta: T) -> torch.Tensor:
        """
        Given the output of a solve call, returns the estimate of the mean of the
        distribution y was sampled from.
        """
```

Note that for a given instance `s` of a solver class, `s.estimate(s.solve(y))` must
be differentiable with respect to `y` via torch's backpropagation.
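A minimal sketch of a custom solver, assuming a ridge penalty $r(b) = \frac{\lambda}{2}\lVert b\rVert_2^2$ so that the problem has a closed-form solution. The class name `RidgeSolver` is hypothetical; in practice it would subclass this package's `Solver`, which is omitted here so the snippet runs on its own. The last lines check differentiability by forming the full Jacobian with autograd, which is only practical for small problems.

```python
import torch

class RidgeSolver:  # hypothetical; in practice this would subclass surecr's Solver
    """mu_hat(y) = A b*, where b* = argmin_b 1/2 ||A b - y||_2^2 + (lam / 2) ||b||_2^2."""

    def __init__(self, A: torch.Tensor, lam: float):
        self.A = A
        p = A.shape[1]
        # Factor A^T A + lam I once; it does not depend on y.
        self.L = torch.linalg.cholesky(A.T @ A + lam * torch.eye(p, dtype=A.dtype))

    def solve(self, y: torch.Tensor) -> torch.Tensor:
        # Here T is torch.Tensor: the optimal coefficients b*.
        rhs = (self.A.T @ y).unsqueeze(-1)
        return torch.cholesky_solve(rhs, self.L).squeeze(-1)

    def estimate(self, beta: torch.Tensor) -> torch.Tensor:
        return self.A @ beta

torch.manual_seed(0)
A = torch.randn(30, 5, dtype=torch.float64)
s = RidgeSolver(A, lam=2.0)
y = torch.randn(30, dtype=torch.float64)

# Both steps use differentiable torch ops, so autograd can differentiate
# s.estimate(s.solve(y)) with respect to y.
J = torch.autograd.functional.jacobian(lambda v: s.estimate(s.solve(v)), y)
print(torch.trace(J))  # exact divergence of the estimator at y
```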


### The `prox_lib` library
Since `FISTASolver` and `ADMMSolver` both require a proximal operator for the
regularizer, `surecr.prox_lib` provides several helpers for constructing
proximal operators:

 * `prox_l1_norm(v, t)`: the proximal operator of the $\ell_1$ norm, scaled by $t$ and evaluated at $v$.
 * `prox_l2_norm(v, t)`: the proximal operator of the $\ell_2$ norm, scaled by $t$ and evaluated at $v$.
 * `make_scaled_prox_nuc_norm(shape: tuple[int, int], t_scale: float)`: generates the proximal operator
    $\text{prox}_{r}: \mathbb R^{\mathtt{shape}} \to \mathbb R^{\mathtt{shape}}$
    of the scaled nuclear norm $r(b) = \tau \sum_i \sigma_i(b)$, where $\tau$ is `t_scale`
    and $\sigma_i(b)$ are the singular values of $b$.
 * `combine_proxs(shape: list[int], proxs: list)`: if there are two regularizers
    $r_1$, $r_2$ such that the regularizer for the problem is given by
    $r(b, b') = r_1(b) + r_2(b')$, then this function should be called with
    `([dim(b), dim(b')], [prox_r_1, prox_r_2])`.
 * `scale_prox(prox, t_scale)`: takes a proximal operator of $r$ and returns the
    proximal operator of $\tau r$, where $\tau$ is `t_scale`.
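To make the conventions above concrete, here is a torch sketch of the textbook closed forms these helpers correspond to, using the same `(v, t)` convention as `prox_R` (return the minimizer of $t\,r(b) + \frac{1}{2}\lVert b - v\rVert_2^2$). It is not the `prox_lib` source: the function names are hypothetical, and reading `prox_l2_norm` as the proximal operator of the unsquared $\ell_2$ norm is an assumption.

```python
import torch

def soft_threshold(v, t):
    # argmin_b t ||b||_1 + 1/2 ||b - v||_2^2 (prox of the l1 norm)
    return torch.sign(v) * torch.clamp(torch.abs(v) - t, min=0)

def block_soft_threshold(v, t):
    # argmin_b t ||b||_2 + 1/2 ||b - v||_2^2 (prox of the unsquared l2 norm)
    norm = torch.clamp(torch.linalg.vector_norm(v), min=1e-12)
    return torch.clamp(1 - t / norm, min=0) * v

def make_svt_prox(shape, t_scale):
    # Prox of b -> t_scale * (sum of singular values of b), with b stored flat:
    # shrink the singular values by t_scale * t (singular value thresholding).
    def prox(v, t):
        U, S, Vh = torch.linalg.svd(v.reshape(shape), full_matrices=False)
        S = torch.clamp(S - t_scale * t, min=0)
        return ((U * S) @ Vh).reshape(-1)
    return prox

def scaled(prox, t_scale):
    # The prox of t_scale * r with step t is the prox of r with step t_scale * t.
    return lambda v, t: prox(v, t_scale * t)

def separable(sizes, proxs):
    # For r(b, b') = r_1(b) + r_2(b'), the prox acts blockwise on the stacked vector.
    def prox(v, t):
        blocks = torch.split(v, sizes)
        return torch.cat([p(block, t) for p, block in zip(proxs, blocks)])
    return prox
```

For example, `separable([3, 2], [soft_threshold, block_soft_threshold])` plays the role described for `combine_proxs`: it soft-thresholds the first 3 entries and block-soft-thresholds the last 2.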


## Citing
If you use this code in a research project, please cite the associated paper. 
```
@article{nobel2022tractable,
    title={Tractable evaluation of {S}tein's {U}nbiased {R}isk {E}stimate with convex regularizers},
    author={Parth Nobel and Emmanuel Cand{\`e}s and Stephen Boyd},
    publisher = {arXiv},
    year = {2022},
    note = {arXiv:2211.05947 [math.ST]},
    url = {https://arxiv.org/abs/2211.05947},
}
```


