#! /usr/bin/env python

from pymlip.lrlambda import cos_sweep
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons

axis_color = 'lightgoldenrodyellow'

# Initial schedule parameters. They seed both the scheduler and the sliders
# below, so the plotted curve and the slider positions start out in agreement.
WARMUP_INIT = 10
LENGTH_INIT = 100
GAMMA_INIT = 0.97

fig = plt.figure()
ax = fig.add_subplot(111)

# Adjust the subplots region to leave some space for the sliders.
fig.subplots_adjust(bottom=0.30)


lr_lambda = cos_sweep(
    {
        "lambda_lr_warm_length": WARMUP_INIT,
        "lambda_lr_gamma": GAMMA_INIT,
        "lambda_lr_length": LENGTH_INIT,
    }
)
# Plot the schedule over 1000 epochs. Exactly one line is expected back, so
# that its y-data can be updated in place when a slider moves.
[line] = lr_lambda.plot_ax(ax, 1000)

# Three sliders, one per schedule parameter, each drawn in its own axes strip
# beneath the plot. The integer-valued parameters use valstep=1 so the value
# the callback reads is the same whole number that valfmt displays.
warmup_slider_ax = fig.add_axes([0.25, 0.15, 0.65, 0.03], facecolor=axis_color)
warmup_slider = Slider(
    warmup_slider_ax, 'Warmup', 1, 25,
    valinit=WARMUP_INIT, valfmt='%0.0f', valstep=1,
)

# Cycle length of the cosine schedule (larger value = slower oscillation).
length_slider_ax = fig.add_axes([0.25, 0.1, 0.65, 0.03], facecolor=axis_color)
length_slider = Slider(
    length_slider_ax, 'Length', 1, 200.0,
    valinit=LENGTH_INIT, valfmt='%0.0f', valstep=1,
)

# valinit must lie inside [0.95, 1] and match the scheduler's initial gamma.
# Three decimals are shown because the whole range spans only 0.05.
gamma_slider_ax = fig.add_axes([0.25, 0.05, 0.65, 0.03], facecolor=axis_color)
gamma_slider = Slider(
    gamma_slider_ax, 'Gamma', 0.95, 1,
    valinit=GAMMA_INIT, valfmt='%0.3f',
)

def sliders_on_changed(val):
    """Push the current slider values into the scheduler and redraw the curve.

    Registered below as the ``on_changed`` callback of all three sliders.

    Args:
        val: New value of whichever slider moved. It is not used directly,
            because all three sliders are re-read so the schedule always
            reflects the full current slider state.
    """
    # Round instead of truncating: the integer sliders display their value
    # with '%0.0f' (rounded), so int() could apply e.g. 10 while "11" is shown.
    lr_lambda.warmup_length = int(round(warmup_slider.val))
    lr_lambda.length = int(round(length_slider.val))
    lr_lambda.gamma = gamma_slider.val
    # Recompute over the same 1000 epochs used for the initial plot. Only the
    # y-data is replaced; the line keeps its existing x-data.
    _, lrs = lr_lambda.get_lrs(1000)
    line.set_ydata(lrs)
    # Request a redraw once control returns to the GUI event loop.
    fig.canvas.draw_idle()


for _slider in (warmup_slider, length_slider, gamma_slider):
    _slider.on_changed(sliders_on_changed)


plt.show()
