#!/usr/bin/python

import argparse

import numpy as np

import matplotlib.patches as patches
import matplotlib.pyplot as plt

from happyml import plot
from happyml.utils import get_f
from happyml.utils import central_difference


# Build the command-line parser.
# RawTextHelpFormatter keeps the description and help strings exactly as
# written (no re-wrapping).
parser = argparse.ArgumentParser(description="""
    Visualize the gradient of 1-dimensional or 2-dimensional functions.
""", formatter_class=argparse.RawTextHelpFormatter)
# One or more expressions; each one is plotted with its own color.
parser.add_argument('expressions', nargs='+', type=str,
                    metavar='EXPRESSION',
                    help='Expression to plot. Use "x" and "y" as variables. You can use any '
                         'math functions in the math module (e.g. "sqrt(x)+log(y)").')
parser.add_argument('-l', '--limits', dest='limits', nargs=4, type=float, default=(-1, 1, -1, 1),
                    metavar=('XMIN', 'XMAX', 'YMIN', 'YMAX'),
                    help='Axis limits (default: -1 1 -1 1)')
parser.add_argument('-c', '--levels-curves', dest='levels_curves', nargs='+', type=int, default=(10,),
                    metavar='N_LEVELS_CURVES',
                    help='Number of level curves to plot. Can be a number or a list of numbers '
                         '(at least one per function).')
# The default is None so the caller can tell "not given" apart from an
# explicit point and derive the start point from the axis limits instead.
parser.add_argument('-s', '--start', dest='start', nargs=2, type=float, default=None,
                    metavar=('X', 'Y'),
                    help='Start point where the gradients will be plotted '
                         '(default: center of the axis limits).')
parser.add_argument('-f', '--field', dest='gradient_field', action='store_true',
                    help='Plot the gradient field.')
parser.add_argument('-n', '--no-scaled', dest='no_scaled', action='store_true',
                    help='Avoid scaling axes. Useful when the axes scales are quite different. '
                         'If the axes are not scaled the gradient may seem no perpendicular to the '
                         'level curves. Default: Scaled axes.')
parser.add_argument('-g', '--grid', dest='grid', action='store_true',
                    help='Show grid.')
args = parser.parse_args()

# Prepare arguments.
if args.start is None:
    # No start point given: use the center of the plotted region.
    args.start = [(args.limits[0] + args.limits[1]) / 2,
                  (args.limits[2] + args.limits[3]) / 2]
# Stored as an array so the gradient can be added to it element-wise.
args.start = np.array(args.start)
if len(args.levels_curves) == 1:
    # A single value applies to every expression. Repeat it once per
    # expression (instead of a fixed 10 times) so that indexing it by the
    # expression position never runs past the end, however many expressions
    # were given.
    args.levels_curves = list(args.levels_curves) * len(args.expressions)

# Plot all expressions.
fig = plt.figure()
ax = fig.gca()
n_expr = len(args.expressions)
# One slot per expression. Slots of skipped (empty) expressions stay None so
# the indices keep matching the positions in args.expressions.
gradients_arrows = [None] * n_expr
functions = [None] * n_expr

for i, expr in enumerate(args.expressions):
    # Skip empty expressions.
    if not expr:
        continue

    color = plot.get_class_color(i)

    # Check if it is an expression with a constraint ("LHS=RHS"): only the
    # level curve LHS == RHS is drawn. An empty RHS means level 0.
    constraint = None
    if "=" in expr:
        # Split on the first "=" only, so the unpacking cannot fail when the
        # expression contains more than one "=".
        lhs, rhs = expr.split("=", 1)
        rhs = rhs.strip()
        # NOTE(review): eval() executes arbitrary Python. Acceptable only
        # because the text comes from the user's own command line; do not
        # feed it untrusted input.
        constraint = eval(rhs) if rhs else 0
        expr = lhs

    # Compute grid.
    functions[i] = get_f(expr)
    X, Y, Z = plot.grid_function_slow(functions[i], bounds=args.limits)

    # Plot contours.
    if constraint is not None:
        # Plot expression with constraint, i.e. plot only one level curve.
        ax.contour(X, Y, Z, [constraint], linewidths=3, colors=color, label=expr)
    else:
        # Plot all level curves (dashed lines for negative levels).
        ax.contour(X, Y, Z, args.levels_curves[i], linewidths=1.5, colors=color, label=expr)
        # Thicker curve at level 0. ("headwidth" is an arrow option, not a
        # contour one, so it is not passed here.)
        ax.contour(X, Y, Z, [0], linewidths=3, colors=color)

    # Plot gradients: an arrow from the start point to start + gradient.
    start_gradient = central_difference(functions[i], args.start[0], args.start[1])
    gradients_arrows[i] = patches.FancyArrowPatch(args.start,
                                                  args.start + start_gradient,
                                                  fc=color, mutation_scale=20, zorder=10)
    ax.add_patch(gradients_arrows[i])

    # Plot gradient fields.
    if args.gradient_field:
        # np.gradient returns one array per axis of Z, axis 0 first.
        # Assumes rows of Z run along Y and columns along X -- TODO confirm
        # against plot.grid_function_slow.
        G = np.gradient(Z)
        # Every second grid point is used to keep the field readable.
        ax.quiver(X[::2, ::2], Y[::2, ::2], G[1][::2, ::2], G[0][::2, ::2], zorder=10)

# Axes-wide settings are applied once, after all expressions are drawn.
# Calling ax.grid() with no arguments toggles the grid, so doing it once per
# expression switched the grid back off for an even number of expressions.
if args.grid:
    ax.grid(True)
if not args.no_scaled:
    ax.set_aspect('equal', adjustable='box')


def onmove(event):
    """Mouse-motion callback: re-anchor every gradient arrow at the cursor.

    For each expression that has an arrow, the arrow is redrawn from the
    cursor position to the cursor position plus the value returned by
    ``central_difference`` for that expression's function at that point.

    Args:
        event: Object exposing ``xdata`` and ``ydata``; either is ``None``
            when the pointer is outside the axes, in which case nothing
            happens.
    """
    x = event.xdata
    y = event.ydata
    # Pointer is outside the axes: there is no data position to follow.
    if x is None or y is None:
        return

    cursor = np.array([x, y])
    for idx, arrow in enumerate(gradients_arrows):
        # Skipped expressions have no arrow.
        if not arrow:
            continue
        step = central_difference(functions[idx], x, y)
        arrow.set_positions(cursor, cursor + step)
    plt.draw()


# Call onmove on every mouse-motion event of the figure, so the gradient
# arrows follow the cursor. The returned connection id is kept in `cid`; it
# is not used again in the visible code.
cid = fig.canvas.mpl_connect('motion_notify_event', onmove)

# Display the figure.
plt.show()


def main():
    """Do nothing.

    All of the script's work is done by the module-level statements above,
    which run as soon as the module is executed or imported.
    NOTE(review): presumably kept so an entry point can import the module and
    call ``main`` -- confirm against the packaging configuration.
    """
    return None


if __name__ == "__main__":
    main()
