#!/usr/bin/env python
""" I hate this script and it should be rewritten, but
this should read data from a run3 convergence test and plot it.
"""
from matador.scrapers.castep_scrapers import castep2dict
from matador.plotting.plotting import plotting_function
from matador.utils.chem_utils import get_formula_from_stoich
from os import walk, chdir
from os.path import isdir
from decimal import Decimal, ROUND_UP
from collections import defaultdict
from traceback import print_exc
import numpy as np


def round(n, prec):
    """Round `n` away from zero to two decimal places, unless `prec` is None.

    This definition shadows the builtin ``round`` for the rest of the module.

    Parameters:
        n: number to round.
        prec: when None, `n` is handed back untouched. Any other value only
            switches the rounding on; its magnitude is not used.

    Returns:
        `n` itself if `prec` is None, otherwise a float with at most two
        decimal places.

    """
    if prec is not None:
        # quantize() only takes the exponent from its argument, so "0.05"
        # means "two decimal places"; ROUND_UP rounds away from zero.
        two_places = Decimal("0.05")
        as_decimal = Decimal(str(n))
        return float(as_decimal.quantize(two_places, rounding=ROUND_UP))
    return n


def get_files(path, only=None):
    """Scrape every ``.castep`` file found below `path` and group the results.

    The working directory is switched to `path` while walking (so the paths
    handed to the scraper are relative to it) and is always restored to
    wherever the caller started, even if scraping raises.

    Parameters:
        path (str): directory to search recursively (symlinks are followed).

    Keyword arguments:
        only (str): if given, only files whose name contains this substring
            are scraped.

    Returns:
        defaultdict(list): scraped dictionaries grouped by a label derived
            from each file name: the ".castep" suffix is removed, the last
            "_"-separated token is dropped and the remaining tokens are
            concatenated without a separator.

    """
    from os import getcwd

    structure_files = defaultdict(list)
    start_dir = getcwd()
    chdir(path)
    try:
        for root, _dirs, files in walk(".", topdown=True, followlinks=True):
            for file in files:
                if not file.endswith(".castep"):
                    continue
                if only is not None and only not in file:
                    continue
                castep_dict, success = castep2dict(root + "/" + file, db=False)
                if not success:
                    print("Failure to read castep")
                    continue
                source = castep_dict["source"][0].split("/")[-1]
                source = source.replace(".castep", "")
                source = "".join(source.split("_")[:-1])
                structure_files[source].append(castep_dict)
    finally:
        # chdir("..") is only correct for a single-level relative path, so go
        # back to the directory we actually started from.
        chdir(start_dir)
    return structure_files


def _convergence_value(doc, conv_field, rounding):
    """Return the (optionally rounded) convergence parameter of one calculation."""
    if conv_field == "kpoints_mp_spacing":
        # For k-point runs the spacing is taken from the file name: the text
        # before the first "A" in the last "_"-separated token of the basename.
        basename = doc["source"][0].split("/")[-1]
        raw_value = float(basename.split("_")[-1].split("A")[0])
    else:
        raw_value = doc[conv_field]
    # `round` is this module's two-argument helper, not the builtin: it hands
    # the value back untouched when `rounding` is None.
    return round(raw_value, rounding)


def _energy_per_atom(doc):
    """Return the total energy per atom of `doc`, else its enthalpy per atom."""
    if "total_energy_per_atom" in doc:
        return doc["total_energy_per_atom"]
    return doc["enthalpy_per_atom"]


def get_data(structure_files, conv_field="cut_off_energy"):
    """Collect energies and forces against a convergence parameter.

    If the first calculation of every structure has the same formula, the
    energies per atom are used directly ("single stoichiometry mode").
    Otherwise, calculations with a single-entry stoichiometry are treated as
    chemical potentials and subtracted from the remaining structures to give
    formation energies per atom.

    Parameters:
        structure_files (dict): maps a structure label to the list of scraped
            calculation dictionaries for that structure. The dictionaries are
            updated in place with a "formation_energy_per_atom" entry.

    Keyword arguments:
        conv_field (str): field for convergence parameter. For
            "kpoints_mp_spacing" the value is parsed from the file name
            rather than read from the dictionary.

    Returns:
        dict: with keys
            "form": label -> array of [parameter, energy per atom] rows,
                sorted by parameter (descending for k-point spacings,
                ascending otherwise);
            "forces": label -> list of [parameter, per-atom force magnitudes]
                in the same order;
            "stoich_list": label -> TeX-formatted formula;
            "chempots": label -> array of [parameter, energy per atom] rows
                sorted by ascending parameter (chemical potential mode only);
            "chempot_list": label -> first item of the first stoichiometry
                entry (chemical potential mode only).

    Raises:
        SystemExit: if a chemical potential calculation cannot be parsed.

    """
    formulae = {
        get_formula_from_stoich(docs[0]["stoichiometry"])
        for docs in structure_files.values()
    }
    chempot_mode = len(formulae) != 1
    if chempot_mode:
        print("Searching for chemical potentials")
    else:
        print("Working in single stoichiometry mode..")

    # Rounding of the convergence parameter is currently switched off for
    # every field; the value is kept so it can be re-enabled in one place.
    rounding = None

    chempots_dict = defaultdict(dict)  # str(parameter) -> {element: energy}
    chempots = defaultdict(list)
    chempot_list = dict()

    if chempot_mode:
        for key, docs in structure_files.items():
            for doc in docs:
                # Bound before the try so the error message below can always
                # be formatted, whichever line fails.
                rounded_field = None
                try:
                    if len(doc["stoichiometry"]) != 1:
                        continue
                    doc["formation_energy_per_atom"] = 0
                    rounded_field = _convergence_value(doc, conv_field, rounding)
                    energy = _energy_per_atom(doc)
                    chempots_dict[str(rounded_field)][doc["atom_types"][0]] = energy
                    chempots[key].append([rounded_field, energy])
                    chempot_list[key] = doc["stoichiometry"][0][0]
                except Exception:
                    print(
                        "Error with {} and {} = {}".format(
                            key, conv_field, rounded_field
                        )
                    )
                    print_exc()
                    raise SystemExit(1)

        elems = {elem for per_elem in chempots_dict.values() for elem in per_elem}
        for value, per_elem in chempots_dict.items():
            for elem in sorted(elems):
                if elem not in per_elem:
                    print(
                        "WARNING: {} chemical potential missing at {} = {} eV, skipping this value.".format(
                            elem, conv_field, value
                        )
                    )

    form = defaultdict(list)
    forces = defaultdict(list)
    stoich_list = dict()

    for key, docs in structure_files.items():
        for doc in docs:
            # Chemical potentials themselves are not plotted as structures.
            if chempot_mode and len(doc["stoichiometry"]) == 1:
                continue
            try:
                formation_energy = _energy_per_atom(doc)
                rounded_field = _convergence_value(doc, conv_field, rounding)
                if chempot_mode:
                    atom_types = doc["atom_types"]
                    # A missing entry raises KeyError below, which skips this
                    # calculation (see the warning printed above).
                    reference = chempots_dict.get(str(rounded_field), {})
                    for atom in atom_types:
                        formation_energy -= reference[atom] / len(atom_types)
                doc["formation_energy_per_atom"] = formation_energy
                form[key].append([rounded_field, formation_energy])
                stoich_list[key] = get_formula_from_stoich(
                    doc["stoichiometry"], tex=True
                )
                if "forces" in doc:
                    force_norms = np.sqrt(
                        np.sum(np.asarray(doc["forces"]) ** 2, axis=-1)
                    )
                    # Validate before storing so a malformed force block is
                    # never handed on to the plotting code.
                    if len(force_norms) != len(doc["atom_types"]):
                        raise ValueError(
                            "Number of forces does not match number of atoms"
                        )
                    forces[key].append([rounded_field, force_norms])
            except Exception:
                # Best effort: report the problem and carry on with the rest.
                print_exc()

    descending = conv_field == "kpoints_mp_spacing"
    for key in form:
        form[key] = np.asarray(
            sorted(form[key], key=lambda x: x[0], reverse=descending)
        )
        forces[key] = sorted(forces[key], key=lambda x: x[0], reverse=descending)

    if chempot_mode:
        for key in chempots:
            # Sort the chemical potentials themselves; indexing `form` here
            # would replace them with empty arrays.
            chempots[key] = np.asarray(sorted(chempots[key], key=lambda x: x[0]))

    data = dict()
    data["form"] = form
    data["forces"] = forces
    data["stoich_list"] = stoich_list
    if chempot_mode:
        data["chempots"] = chempots
        data["chempot_list"] = chempot_list

    return data


def _relative_energies(series, log, ref_index=-1):
    """Return 1000 * |E - E_ref| for column 1 of `series`, optionally as log10."""
    relative = 1000 * np.abs(series[:, 1] - series[ref_index, 1])
    return np.log10(relative) if log else relative


def _plot_energy_series(ax, x, y, label):
    """Draw one structure as markers plus a faint connecting line; return the markers."""
    (markers,) = ax.plot(
        x, y, "o", markersize=5, alpha=1, label=label, lw=0, zorder=1000
    )
    ax.plot(
        x, y, "-", alpha=0.2, label=label, lw=1, zorder=1000, c=markers.get_color()
    )
    return markers


def _plot_force_series(ax, force_series, x_of, colour):
    """Plot per-atom force changes relative to the last entry of `force_series`."""
    if not force_series:
        return
    reference = np.asarray(force_series[-1][1])
    for value, norms in force_series:
        relative_forces = np.abs(np.asarray(norms) - reference)
        x = x_of(value)
        ax.plot(len(norms) * [x], relative_forces, alpha=0.2, c=colour)
        ax.scatter(x, np.mean(relative_forces), alpha=0.5, c=colour)


def _plot_chempots(ax, chempots, x_of, log, ref_index):
    """Draw every chemical potential series in grey on the energy axis."""
    for label, series in chempots.items():
        try:
            x = x_of(series[:, 0])
            y = _relative_energies(series, log, ref_index)
        except (IndexError, TypeError):
            # Empty or malformed series: report it and keep plotting the rest.
            print("Issue with {}: {}".format(label, series))
            continue
        ax.plot(x, y, "o-", markersize=5, alpha=1, label=label, lw=1, c="grey")


def _plot_convergence_panel(
    ax, ax_forces, data, x_of, log, show_chempots, chempot_ref_index
):
    """Plot all structures of one data set; return (lines, labels, xpoints, ypoints)."""
    lines, labels, xpoints, ypoints = [], [], [], []
    form = data["form"]
    for key in form:
        series = form[key]
        try:
            x = x_of(series[:, 0])
            y = _relative_energies(series, log)
        except Exception:
            # Best effort: a structure with unusable data is skipped.
            print("Issue with {}: {}".format(key, series))
            continue

        line = _plot_energy_series(ax, x, y, key)
        xpoints.append(x)
        ypoints.append(y)
        lines.append(line)
        labels.append(key)

        if ax_forces is not None:
            try:
                _plot_force_series(
                    ax_forces, data["forces"][key], x_of, line.get_color()
                )
            except Exception:
                print("Issue with {}: {}".format(key, series))

    # Drawn once, after the structures, with its own loop variable so the
    # structure key above is never clobbered.
    if show_chempots and "chempots" in data:
        _plot_chempots(ax, data["chempots"], x_of, log, chempot_ref_index)

    return lines, labels, xpoints, ypoints


def _decorate_force_axis(ax, force_range, force_target):
    """Label a force axis and apply the optional y-range and target line."""
    ax.set_ylabel("Force eV/A")
    if force_range:
        ax.set_ylim(0, force_range)
    if force_target:
        ax.axhline(force_target, lw=0.5, c="r", ls="--")


def _label_cutoff_axis(ax, cutoff_form):
    """Put "1/cutoff" tick labels, 100 apart, on the inverse-cutoff x axis."""
    try:
        min_ = 1e10
        max_ = 0
        for key in cutoff_form:
            try:
                cutoffs = cutoff_form[key][:, 0]
                min_ = min(min_, np.min(cutoffs))
                max_ = max(max_, np.max(cutoffs))
            except (IndexError, TypeError, ValueError):
                # Empty or malformed series contribute nothing to the range.
                continue
        xlabels = np.arange(min_, 2 * max_, step=100)
        # Only every other tick is labelled, and none beyond the largest cutoff.
        xlabels_str = [
            "1/{:3.0f}".format(val) if i % 2 == 0 and val <= max_ else ""
            for i, val in enumerate(xlabels)
        ]
        ax.set_xticks(-1 / xlabels)
        ax.set_xticklabels(xlabels_str)
        ax.set_xlim(-1.1 * 1 / min_, -0.9 * 1 / max_)
    except Exception:
        print_exc()
        print("No cutoff.conv file found, axis labels may be ugly...")


def _shade_envelope(ax, xpoints, ypoints, energy_target):
    """Shade the area under the largest y value found at each x position."""
    max_y = dict()
    for xs, ys in zip(xpoints, ypoints):
        for x, y in zip(xs, ys):
            if x not in max_y or y > max_y[x]:
                max_y[x] = y

    max_x_vals = np.asarray(sorted(max_y))
    max_y_vals = np.asarray([max_y[x] for x in max_x_vals])
    if energy_target:
        below = max_y_vals <= energy_target
        above = max_y_vals > energy_target
        ax.fill_between(
            max_x_vals[below], max_y_vals[below], alpha=0.2, color="green"
        )
        ax.fill_between(max_x_vals[above], max_y_vals[above], alpha=0.2, color="red")
    else:
        ax.fill_between(max_x_vals, max_y_vals, alpha=0.2, color="grey")


@plotting_function
def plot_both(
    plot_cutoff=False,
    plot_kpt=False,
    cutoff_data={},
    kpt_data={},
    log=True,
    show_chempots=False,
    **kwargs
):
    """Plot convergence of either cutoff/kpts or both.

    Energies are plotted as absolute differences (multiplied by 1000) from
    the last point of each series. Force panels are only drawn when both
    data sets are plotted and both contain a "forces" entry. The figure is
    written to "conv.png" and then shown.

    Keyword arguments:
        plot_cutoff (bool): plot the cutoff data against -1/cutoff.
        plot_kpt (bool): plot the k-point data against the spacing.
        cutoff_data (dict): needs "form"; "forces", "chempots" and
            "stoich_list" are used if present. Not modified.
        kpt_data (dict): as `cutoff_data`, for the k-point runs.
        log (bool): plot log10 of the energy differences.
        show_chempots (bool): also draw the "chempots" series, in grey.
        energy_range (float): upper y limit of the energy axes.
        force_range (float): upper y limit of the force axes.
        energy_target (float): threshold used to colour the shaded envelope
            (and drawn as a line on the k-point axis).
        force_target (float): drawn as a line on the force axes.

    Raises:
        ValueError: if neither `plot_cutoff` nor `plot_kpt` is set.

    """
    import matplotlib.pyplot as plt

    if not plot_cutoff and not plot_kpt:
        raise ValueError("Nothing to plot: set plot_cutoff and/or plot_kpt.")

    num_structures = max(
        len(cutoff_data.get("stoich_list", [1])),
        len(kpt_data.get("stoich_list", [1])),
    )
    print("Number of structures: {}".format(num_structures))

    # Always bound, so the single-panel layouts below no longer fail on an
    # undefined name.
    forces = bool(
        plot_cutoff and plot_kpt and "forces" in cutoff_data and "forces" in kpt_data
    )
    ax_cutoff = ax_kpt = ax_cutoff_forces = ax_kpt_forces = None
    if forces:
        fig = plt.figure()
        ax_cutoff = fig.add_subplot(221)
        ax_cutoff_forces = fig.add_subplot(223, sharex=ax_cutoff)
        ax_kpt = fig.add_subplot(222)
        ax_kpt_forces = fig.add_subplot(224, sharex=ax_kpt)
    elif plot_cutoff and plot_kpt:
        fig = plt.figure()
        ax_cutoff = fig.add_subplot(211)
        ax_kpt = fig.add_subplot(212)
    elif plot_cutoff:
        fig, ax_cutoff = plt.subplots(1, 1)
    else:
        fig, ax_kpt = plt.subplots(1, 1)

    # Optional settings: absent keys are treated like unset ones.
    energy_range = kwargs.get("energy_range")
    energy_target = kwargs.get("energy_target")
    force_range = kwargs.get("force_range")
    force_target = kwargs.get("force_target")

    if log:
        energy_label = "log(Relative energy difference (meV/atom))"
    else:
        energy_label = "Relative energy difference (meV/atom)"

    if plot_cutoff:
        lines, labels, xpoints, ypoints = _plot_convergence_panel(
            ax_cutoff,
            ax_cutoff_forces,
            cutoff_data,
            lambda value: -1 / value,
            log,
            show_chempots,
            -1,  # chempots are sorted ascending: reference is the largest cutoff
        )
        if forces:
            _decorate_force_axis(ax_cutoff_forces, force_range, force_target)

        ax_cutoff.set_ylabel(energy_label)
        ax_cutoff.set_xlabel("1 / plane wave cutoff (eV)")
        _label_cutoff_axis(ax_cutoff, cutoff_data["form"])
        if energy_range:
            ax_cutoff.set_ylim(0, energy_range)
        _shade_envelope(ax_cutoff, xpoints, ypoints, energy_target)

    if plot_kpt:
        # NOTE: when both panels are drawn, the legend below is built from
        # the k-point lines, which replace the cutoff ones here.
        lines, labels, xpoints, ypoints = _plot_convergence_panel(
            ax_kpt,
            ax_kpt_forces,
            kpt_data,
            lambda value: value,
            log,
            show_chempots,
            0,  # chempots are sorted ascending: reference is the finest spacing
        )
        if energy_target:
            ax_kpt.axhline(energy_target, lw=0.5, c="r", ls="--")
        _shade_envelope(ax_kpt, xpoints, ypoints, energy_target)

        if forces:
            _decorate_force_axis(ax_kpt_forces, force_range, force_target)

        if plot_cutoff:
            # Next to the cutoff panel: flip the x axis and move the y ticks
            # and label to the right-hand side.
            left, right = ax_kpt.get_xlim()
            ax_kpt.set_xlim(right, left)
            ax_kpt.yaxis.tick_right()
            ax_kpt.yaxis.set_label_position("right")
            if forces:
                ax_kpt_forces.yaxis.tick_right()
                ax_kpt_forces.yaxis.set_label_position("right")

        ax_kpt.set_ylabel(energy_label)
        if energy_range:
            ax_kpt.set_ylim(0, energy_range)

    plt.figlegend(
        lines,
        labels,
        loc="upper center",
        fontsize=10,
        ncol=2,
        frameon=True,
        fancybox=True,
        shadow=True,
    )
    plt.savefig("conv.png", bbox_inches="tight")
    plt.show()


def main():
    """Parse the command line, read the convergence runs and plot them.

    Looks for "completed_cutoff" and "completed_kpts" in the current
    directory; whichever exist are parsed and plotted. Exits with an error
    if neither is present, or with status 1 if parsing or plotting fails.
    """
    import argparse

    parser = argparse.ArgumentParser(prog="plot_convergence")
    parser.add_argument("--log", action="store_true", help="Plot log energies")
    parser.add_argument(
        "--show_chempots", action="store_true", help="Include chempots in plot"
    )
    parser.add_argument(
        "--only", type=str, help="Show only convergence of this seedname"
    )
    parser.add_argument(
        "--energy-range",
        "--energy_range",
        type=float,
        help="Plot up to this energy value",
    )
    parser.add_argument(
        "--force-range", "--force_range", type=float, help="Plot up to this force value"
    )
    parser.add_argument(
        "--energy-target",
        "--energy_target",
        type=float,
        help="Mark this convergence target on the plot",
    )
    parser.add_argument(
        "--force-target",
        "--force_target",
        type=float,
        help="Mark this convergence target on the plot",
    )
    args = parser.parse_args()
    kwargs = vars(args)
    if kwargs.get("only") is not None:
        # Anything from the first "." onwards (e.g. a file extension) is dropped.
        kwargs["only"] = kwargs["only"].split(".")[0]

    cutoff = isdir("completed_cutoff")
    kpts = isdir("completed_kpts")
    if not cutoff:
        print("Did not find completed_cutoff folder, skipping cutoffs...")
    if not kpts:
        print("Did not find completed_kpts folder, skipping kpts...")
    if not cutoff and not kpts:
        # Raised outside the try below so it is reported as a plain message.
        raise SystemExit("Could not find any completed_$x folders!")

    cutoff_data = {}
    kpt_data = {}
    try:
        if cutoff:
            print("Parsing cutoffs...")
            cutoff_structure_files = get_files(
                "completed_cutoff", only=kwargs.get("only")
            )
            cutoff_data = get_data(cutoff_structure_files, conv_field="cut_off_energy")
        if kpts:
            print("Parsing kpts...")
            kpt_structure_files = get_files("completed_kpts", only=kwargs.get("only"))
            kpt_data = get_data(kpt_structure_files, conv_field="kpoints_mp_spacing")
        plot_both(
            plot_cutoff=cutoff,
            plot_kpt=kpts,
            cutoff_data=cutoff_data,
            kpt_data=kpt_data,
            **kwargs
        )
    except Exception:
        # Top-level boundary: report the failure and signal it to the shell.
        print_exc()
        print(
            "This script is rubbish, please contact me388@cam.ac.uk and tell him to fix it."
        )
        raise SystemExit(1)


if __name__ == "__main__":
    main()
