#!/usr/bin/env python3
#
# Plot residuals, massflow, loads, etc.

# One-line summary of this script; passed to the argparse parser built in main()
__DOC__ = 'Plot residuals, massflow, loads, etc.'

import os
import numpy as np
import argparse

from treelab import cgns
import mola.naming_conventions as names
from mola.logging import mola_logger, MolaException, MolaUserError, CYAN, GREEN, ENDC

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

def plot_residuals(signals):
    """Plot every residual of the ``Residuals`` zone against ``Iteration``.

    All fields of the zone except ``Iteration`` are drawn on a single figure
    with a logarithmic y-axis. The figure is saved to
    ``<DIRECTORY_OUTPUT>/residuals.png``. If the zone is absent, a warning is
    logged and nothing is plotted.

    Parameters
    ----------
    signals : treelab cgns tree
        Tree of 1D results.
    """
    zone = signals.get(Name='Residuals', Type='Zone')
    if not zone:
        mola_logger.warning('No residuals found')
        return

    fields = zone.allFields(include_coordinates=False)
    # 'Iteration' is the abscissa; every other field is one residual curve.
    # list.remove raises ValueError when the zone has no 'Iteration' field.
    curve_names = list(fields)
    curve_names.remove('Iteration')
    iterations = fields['Iteration']

    plt.figure()
    for curve_name in curve_names:
        plt.plot(iterations, fields[curve_name], label=curve_name)
    plt.yscale('log')
    plt.xlabel('Iterations')
    plt.ylabel('Residuals')
    plt.legend(loc='best')
    plt.grid()
    save_figure_if_required(f'{names.DIRECTORY_OUTPUT}/residuals.png')

def plot_massflow(signals):
    """Plot the absolute mass flow history of every extraction providing it.

    Every ``FlowSolution`` node of ``signals`` that holds both ``Iteration``
    and ``MassFlow`` contributes one curve of ``|MassFlow|`` versus
    ``Iteration``, labelled with the name of the FlowSolution's parent node.
    Curves are drawn in alphabetical order of their label, and the figure is
    saved to ``<DIRECTORY_OUTPUT>/massflow.png``. If no such node exists, a
    warning is logged and nothing is plotted.

    Parameters
    ----------
    signals : treelab cgns tree
        Tree of 1D results.
    """
    # massflows[parent_name] = {'Iteration': ..., 'MassFlow': ...}
    # NOTE(review): two FlowSolutions whose parents share a name would overwrite
    # each other here — confirm extraction names are unique in the tree.
    massflows = dict()

    for node in signals.group(Type='FlowSolution'):
        try:
            extraction = node.Parent.name()
            iterations = node.get(Name='Iteration').value()
            massflow = np.absolute(node.get(Name='MassFlow').value())
        except (AttributeError, TypeError):
            # A missing child makes ``get`` return None (AttributeError on
            # .value()); unusable data makes np.absolute raise TypeError.
            # Either way this FlowSolution carries no massflow signal: skip it.
            continue
        massflows[extraction] = dict(Iteration=iterations, MassFlow=massflow)

    if not massflows:
        mola_logger.warning('No massflow results found')
        return

    plt.figure()
    for extraction in sorted(massflows):
        data = massflows[extraction]
        plt.plot(data['Iteration'], data['MassFlow'], label=extraction)
    plt.xlabel('Iterations')
    plt.ylabel('MassFlow (kg/s)')
    plt.legend(loc='best')
    plt.grid()
    save_figure_if_required(f'{names.DIRECTORY_OUTPUT}/massflow.png')

def plot_loads(signals):
    """Plot Thrust, Torque and Power histories found in the ``Integral`` base(s).

    For each quantity, one figure is drawn with one solid line per extraction
    (the instantaneous value). When an ``avg-<quantity>`` field exists, a
    dashed line of the same colour shows its average. When several extractions
    provide the quantity and their histories have the same length, their sum
    is drawn as a black ``Total`` line. Each figure is saved to
    ``<DIRECTORY_OUTPUT>/<quantity>.png``. Quantities absent from every
    extraction are skipped. If no quantity is found at all, a warning is
    logged.

    Parameters
    ----------
    signals : treelab cgns tree
        Tree of 1D results.
    """
    quantities = ['Thrust', 'Torque', 'Power']

    # loads[quantity][extraction] = {'Iteration': ..., 'current': ..., plus
    # 'avg' / 'std' / 'rsd' when the corresponding statistic is available}
    loads = dict()
    # NOTE(review): extractions are named after the matched base and only the
    # first FlowSolution found in each base is read; confirm this covers every
    # integral extraction stored under 'Integral'.
    for integral_base in signals.group(Type='CGNSBase', Name='Integral', Depth=1):
        extraction = integral_base.name()
        fs = integral_base.get(Type='FlowSolution')
        if fs is None:
            continue

        for quantity in quantities:
            quantity_node = fs.get(Name=quantity)
            if not quantity_node:
                continue

            data = dict(
                Iteration=fs.get(Name='Iteration').value(),
                current=quantity_node.value(),
            )
            # Statistics of the quantity are optional
            for prefix in ['avg', 'std', 'rsd']:
                statistic_node = fs.get(Name=f'{prefix}-{quantity}')
                if statistic_node is not None:
                    data[prefix] = statistic_node.value()

            loads.setdefault(quantity, dict())[extraction] = data

    if not loads:
        mola_logger.warning('No loads results found')
        return

    for quantity in quantities:
        if quantity not in loads:
            # Not provided by any extraction (indexing would raise KeyError)
            continue

        extractions = sorted(loads[quantity])

        plt.figure()
        legend_elements = []
        has_avg = False
        for i, extraction in enumerate(extractions):
            array = loads[quantity][extraction]
            current_line, = plt.plot(array['Iteration'], array['current'], '-', color=f'C{i}', label=extraction)
            legend_elements.append(current_line)
            if 'avg' in array:
                plt.plot(array['Iteration'], array['avg'], '--', color=f'C{i}', label=None)
                has_avg = True

        if len(extractions) > 1:
            # Plot total on all extractions, using the iterations of the last one
            currents = [loads[quantity][ext]['current'] for ext in extractions]
            iterations = array['Iteration']
            if all(np.shape(current) == np.shape(iterations) for current in currents):
                total_line, = plt.plot(iterations, np.sum(currents, axis=0), 'k-', label='Total')
                legend_elements.append(total_line)
            else:
                # np.sum over arrays of different lengths would fail
                mola_logger.warning(f'Total {quantity} not plotted: extractions have different numbers of iterations')

        plt.xlabel('Iterations')
        plt.ylabel(quantity)

        if has_avg:
            # Proxy artist: a single generic legend entry for all dashed average lines
            legend_elements.append(Line2D([0], [0], linestyle='--', color='darkgrey', label='avg'))

        plt.legend(handles=legend_elements, loc='best')
        plt.grid()
        save_figure_if_required(f'{names.DIRECTORY_OUTPUT}/{quantity}.png')

        # TODO: 'std' and 'rsd' statistics are collected in `loads` but not plotted yet.

def plot_probes(signals, abscissa='Iteration'):
    """Plot probe signals, one figure per variable.

    Each zone of the ``Probes`` base is a probe. Its signals are the
    ``DataArray`` nodes of the FlowSolution found directly under the zone. For
    every variable, all probes holding both that variable and the abscissa are
    drawn on one figure, saved to ``<DIRECTORY_OUTPUT>/probes_<variable>.png``.
    If there is no ``Probes`` base or no readable probe, a warning is logged
    and nothing is plotted.

    Parameters
    ----------
    signals : treelab cgns tree
        Tree of 1D results.
    abscissa : str
        Name of the probe signal used as x-axis. ``Iteration``, ``Time`` and
        the abscissa itself are never plotted as variables.
    """

    def _get_variables(probes):
        # Union of the signal names of all probes, minus the abscissa candidates.
        # Sorted so that figures are always produced in the same order.
        variables = set()
        for probe in probes.values():
            variables.update(probe)
        variables -= {'Iteration', 'Time', abscissa}
        return sorted(variables)

    probes_base = signals.get(Type='CGNSBase', Name='Probes', Depth=1)
    if probes_base is None:
        # Return now: probes_base.zones() below would raise AttributeError
        mola_logger.warning('No probes found')
        return

    # probes[probe_name][signal_name] = signal values
    probes = dict()
    for probe_zone in probes_base.zones():
        fs = probe_zone.get(Type='FlowSolution', Depth=1)
        if fs is None:
            continue
        probes[probe_zone.name()] = {
            data_node.name(): data_node.value() for data_node in fs.group(Type='DataArray')
        }

    if not probes:
        mola_logger.warning('No probes found')
        return

    # One plot per variable, with all probes containing that variable
    for variable in _get_variables(probes):
        plt.figure()
        for name, probe in probes.items():
            if variable in probe and abscissa in probe:
                plt.plot(probe[abscissa], probe[variable], label=name)
        plt.xlabel(abscissa)
        plt.ylabel(variable)
        plt.legend(loc='best')
        plt.grid()
        save_figure_if_required(f'{names.DIRECTORY_OUTPUT}/probes_{variable}.png')

def plot_criteria(signals):
    """Plot the monitored variable of every convergence criterion.

    Criteria are read from ``WorkflowParameters/ConvergenceCriteria`` in the
    solver input file (``names.FILE_INPUT_SOLVER``). Each criterion gives an
    extraction name, a variable and a threshold. For each criterion, the
    variable is plotted against ``Iteration`` of that extraction, using a log
    scale for ``rsd-`` variables. The threshold is drawn as a dashed red line
    when it is finite. Figures are saved to
    ``<DIRECTORY_OUTPUT>/criterion_<extraction>_<variable>.png``. Criteria
    whose extraction or data is missing from ``signals`` are skipped with a
    warning.

    Parameters
    ----------
    signals : treelab cgns tree
        Tree of 1D results.
    """

    def _read_criteria_from_cgns():
        # Returns a list of dicts with keys ExtractionName, Variable, Threshold,
        # or [] when the criteria node cannot be read.
        try:
            criteria_node = cgns.load_from_path(names.FILE_INPUT_SOLVER, 'WorkflowParameters/ConvergenceCriteria')
        except Exception:
            # Best effort: missing file or missing node simply means "no criteria"
            criteria_node = None
        if criteria_node is None:
            mola_logger.warning("No ConvergenceCriteria node found in the CGNS file")
            return []

        return [
            dict(
                ExtractionName=criterion_node.get(Name='ExtractionName').value(),
                Variable=criterion_node.get(Name='Variable').value(),
                Threshold=criterion_node.get(Name='Threshold').value(),
            )
            for criterion_node in criteria_node.children()
        ]

    criteria = _read_criteria_from_cgns()
    if not criteria:
        mola_logger.warning('No convergence criteria found')

    for criterion in criteria:
        extraction_name = criterion['ExtractionName']
        variable_name = criterion['Variable']
        threshold = criterion['Threshold']

        extraction = signals.get(Name=extraction_name)
        if extraction is None:
            mola_logger.warning(f'Extraction {extraction_name} of a convergence criterion not found')
            continue

        variable_node = extraction.get(Name=variable_name, Type='DataArray')
        iteration_node = extraction.get(Name='Iteration', Type='DataArray')
        if variable_node is None or iteration_node is None:
            mola_logger.warning(f'Cannot plot criterion on {variable_name} for {extraction_name}: missing data')
            continue
        criterion_variable = variable_node.value()
        iterations = iteration_node.value()

        plt.figure()
        plt.title(f"Criterion on {variable_name} for {extraction_name}")
        plt.xlabel('Iterations')
        plt.ylabel(variable_name)
        plt.plot(iterations, criterion_variable)

        # If criterion is on a relative standard deviation, use a log scale
        if variable_name.startswith('rsd-'):
            plt.yscale('log')

        # Plot convergence threshold (iterations[0] anchors the label, so it must exist)
        if np.isfinite(threshold) and np.size(iterations) > 0:
            plt.plot(iterations, np.ones(np.shape(iterations)) * threshold, '--r')
            plt.text(iterations[0], threshold, 'Convergence Threshold', color='red')

        plt.grid()
        save_figure_if_required(f"{names.DIRECTORY_OUTPUT}/criterion_{extraction_name}_{variable_name}.png")

def save_figure_if_required(figname):
    """Save the current matplotlib figure to ``figname``, if one is given.

    Parameters
    ----------
    figname : str or None
        Destination path of the image (saved at 150 dpi with a tight bounding
        box). Nothing is done when it is ``None``. Missing parent directories
        are created.
    """
    if figname is None:
        return
    directory = os.path.dirname(figname)
    if directory:
        # savefig does not create missing directories, e.g. when the 1D file
        # given with -f lives outside the default output directory
        os.makedirs(directory, exist_ok=True)
    print(f'Saving {CYAN}{figname}{ENDC} ...', end=' ')
    plt.savefig(figname, dpi=150, bbox_inches='tight')
    print(f'{GREEN}OK{ENDC}')

def main():
    """Command-line entry point of the plotting script.

    Parses the arguments, loads the file of 1D results, calls the plotting
    function of every requested kind of output, then shows the figures
    unless ``--no-show`` is given.

    Raises
    ------
    MolaUserError
        If an unknown kind of output is requested or if the results file does
        not exist.
    """
    # Single source of truth for the supported outputs and their plotting function
    operations = dict(
        residuals = plot_residuals,
        massflow = plot_massflow,
        loads = plot_loads,
        probes = plot_probes,
        criteria = plot_criteria,
    )
    available_operations = list(operations)

    # description= (not usage=) keeps argparse's auto-generated usage line
    parser = argparse.ArgumentParser(description=__DOC__)
    parser.add_argument('-f', '--file', type=str, help='path of the file with 1D results',
                        default=os.path.join(names.DIRECTORY_OUTPUT, names.FILE_OUTPUT_1D))
    parser.add_argument('-o', '--output', type=str, help=f'Kinds of outputs to plot. Available choices are: {", ".join(available_operations)}',
                        nargs='*', default=['residuals', 'massflow', 'loads', 'criteria'])
    parser.add_argument('--no-show', action='store_true', help='Do not show the figures on screen')
    args = parser.parse_args()

    # Validate every requested output before loading the file or plotting anything
    for name in args.output:
        if name not in operations:
            raise MolaUserError(f'Unknown kind of output "{name}". See documentation of this function with --help')

    if not os.path.isfile(args.file):
        raise MolaUserError(f'Cannot use mola_plot because file {args.file} does not exist.')

    signals = cgns.load(args.file)

    for name in args.output:
        operations[name](signals)

    if not args.no_show:
        plt.show()


if __name__ == '__main__':
    main()
