import numpy as np
import matplotlib.pyplot as plt
# import cvxpy as cvx
from scipy.signal import remez, minimum_phase, freqz, group_delay
from ellalgo.cutting_plane import Options, cutting_plane_optim
from ellalgo.ell import Ell
from ellalgo.oracles.spectral_fact import spectral_fact
# from problem import Problem
from ellalgo.oracles.lowpass_oracle import create_lowpass_case
# from ellalgo.tests.test_lowpass import create_lowpass_case


# def create_csdlowpass_case(ndim=48, nnz=8):
#     """[summary]

#     Keyword Arguments:
#         ndim (int): [description] (default: {48})
#         nnz (int): [description] (default: {8})

#     Returns:
#         [type]: [description]
#     """
#     P, Spsq = create_lowpass_case(ndim)
#     Pcsd = csdlowpass_oracle(nnz, P)
#     return Pcsd, Spsq


def plot_lowpass_result(r, Spsq_new):
    """Print and plot the designed filter next to a Remez reference design.

    The impulse response is obtained from ``r`` with ``spectral_fact``; it is
    printed together with the stopband level in dB derived from ``Spsq_new``.
    A 151-tap Remez lowpass filter, converted to minimum phase with the
    homomorphic method, is plotted in red as a reference and the designed
    filter in black.  The 2x2 figure shows impulse response, magnitude,
    magnitude in dB and group delay, and is displayed with ``plt.show()``.

    Args:
        r: Input handed to ``spectral_fact`` (presumably the autocorrelation
            coefficients returned by the optimizer -- confirm with caller).
        Spsq_new: Value converted to dB as ``20*log10(sqrt(Spsq_new))``
            (presumably the optimal squared stopband bound -- confirm).
    """
    # *********************************************************************
    # plotting routines
    # *********************************************************************
    # Impulse response of the designed filter (spectral factorization
    # routine originating from the CVX distribution, Examples subdirectory).
    # NOTE(review): the accuracy of spectral_fact has not been verified.
    h_sp = spectral_fact(r)
    print("h = ", h_sp)

    # Stopband level in dB, converted back to the original variables.
    Ustop = 20 * np.log10(np.sqrt(Spsq_new))
    print('Min attenuation in the stopband is ', Ustop, ' dB.')

    # Reference design.  Band edges are expressed with fs=2, i.e. 1.0 is the
    # Nyquist frequency.  `fs` replaces the legacy `Hz` keyword, which newer
    # SciPy releases no longer accept.
    freq = [0, 0.12, 0.2, 1.0]
    desired = [1, 0]
    h_linear = remez(151, freq, desired, fs=2.)
    h_min_hom = minimum_phase(h_linear, method='homomorphic')

    fig = plt.figure()
    axs = tuple(fig.add_subplot(2, 2, idx) for idx in range(1, 5))

    # Reference first (red), designed filter second (black); the legend
    # entries set further down rely on this order.
    for taps, color in ((h_min_hom, 'r'), (h_sp, 'k')):
        _, H = freqz(taps)
        w, gd = group_delay((taps, 1))
        w /= np.pi  # rad/sample -> fraction of Nyquist
        axs[0].plot(taps, color=color, linestyle='-')
        axs[1].plot(w, np.abs(H), color=color, linestyle='-')
        axs[2].plot(w,
                    20 * np.log10(np.abs(H)),
                    color=color,
                    linestyle='-')
        axs[3].plot(w, gd, color=color, linestyle='-')

    for ax in axs:
        ax.grid(True, color='0.5')
        # Shade the span between the passband and stopband edges.
        ax.fill_between(freq[1:3], *ax.get_ylim(), color='#ffeeaa', zorder=1)

    axs[0].set(xlim=[0, len(h_linear) - 1],
               ylabel='Amplitude',
               xlabel='Samples')
    axs[1].legend(['Min-Hom', 'Our'], title='Phase')
    for ax, ylim in zip(axs[1:], ([0, 1.1], [-80, 10], [-60, 60])):
        ax.set(xlim=[0, 1], ylim=ylim, xlabel='Frequency')
    axs[1].set(ylabel='Magnitude')
    axs[2].set(ylabel='Magnitude (dB)')
    axs[3].set(ylabel='Group delay')
    plt.tight_layout()
    plt.show()


# Modified from CVX code by Almir Mutapcic in 2006.
# Adapted in 2010 for impulse response peak-minimization by convex iteration
# by Christine Law.
#
# "FIR Filter Design via Spectral Factorization and Convex Optimization"
# by S.-P. Wu, S. Boyd, and L. Vandenberghe
#
# Designs an FIR lowpass filter using spectral factorization method with
# constraint on maximum passband ripple and stopband attenuation:
#
#   minimize   max |H(w)|                      for w in stopband
#       s.t.   1/delta <= |H(w)| <= delta      for w in passband
#
# We change variables via spectral factorization method and get:
#
#   minimize   max R(w)                          for w in stopband
#       s.t.   (1/delta)**2 <= R(w) <= delta**2  for w in passband
#              R(w) >= 0                         for all w
#
# where R(w) is squared magnitude frequency response
# (and Fourier transform of autocorrelation coefficients r).
# Variables are the coefficients r and G = hh', where h is the impulse response.
# delta is allowed passband ripple.
# This is a convex problem (can be formulated as an SDP after sampling).

# rand('twister',sum(100*clock))
# randn('state',sum(100*clock))

# def test_lowpass():
if __name__ == "__main__":
    # Script entry point: build the lowpass design problem, solve it with
    # the cutting-plane optimizer and plot the resulting filter.
    # tic = time.time()
    ndim = 48
    # Initial point x0: np.zeros already sets every entry (including r0[0])
    # to zero, so no further initialisation is needed.
    r0 = np.zeros(ndim)

    # ********************************************************************
    # optimization
    # ********************************************************************
    # Initial search ellipsoid built from the scalar 40. and r0
    # (presumably a sphere centred at r0 -- see ellalgo's Ell to confirm).
    E = Ell(40., r0)
    # E.use_parallel_cut = False
    P, Spsq = create_lowpass_case(ndim)
    options = Options()
    # options.max_it = 50000
    # options.tolerance = 1e-11
    r, Spsq_new, _ = cutting_plane_optim(P, E, Spsq, options)
    # Explicit check instead of `assert`, which is stripped under
    # `python -O` and would let a None result reach the plotting code.
    if r is None:
        raise RuntimeError(
            "cutting_plane_optim returned no solution (r is None)")
    print(r)
    plot_lowpass_result(r, Spsq_new)

    # toc = time.time()

    # print(num_iters)

    # x = r
    # m = length(x)
    # u = x(m:-1:1)'
    # u(m) = 0.5*x(1)
    # d = roots(u)
    # figure(3)
    # plot(1./d,'x')
    # axis('square')
    # grid on
    # hold on
    # elplot([1 0 0 1], [0 0])

    # E = ell(1,r0)
    # P = FIR_oracle2(Ap, As, Anr, Lpsq, Upsq)
    # [r, Spsq_new, iter, feasible, status] ...
    #  = ellipsoid_dc(@P.evaluate, E, Spsq, 100000, 1e-4)
    # toc
    # iter
