#!/usr/bin/env python
""" This script computes, plots and optionally saves PXRD patterns from cif input files. """
import argparse
import time
from traceback import print_exc
import matplotlib.pyplot as plt
from matador import script_epilog
from matador.scrapers import cif2dict
from matador.plotting.pxrd_plotting import plot_pxrd
from matador.crystal import Crystal
from matador.utils.print_utils import print_failure


def compute_pxrd(**kwargs):
    """Load structures from CIF files, compute their PXRD patterns, then plot/export.

    Every seed is parsed with ``cif2dict`` and wrapped in a ``Crystal``. The
    first file that fails either step aborts the whole run with a non-zero
    exit status, so no partial output is produced.

    Keyword arguments:
        seeds (str or list of str): filename(s) of the structures to load.
        wavelength: forwarded to ``calculate_pxrd`` as ``wavelength``.
        theta_m: forwarded to ``calculate_pxrd`` as ``theta_m``.
        two_theta_range: forwarded to ``calculate_pxrd`` as ``two_theta_bounds``.
        broadening_width: forwarded to ``calculate_pxrd`` as ``lorentzian_width``.
        plot (bool): plot the computed patterns.
        savefig (str): if set, the plot is saved to this filename instead of
            being shown interactively.
        spg_labels (bool): if set, ``labels=None`` is passed to ``plot_pxrd``
            instead of the seed filenames.
        rugplot (bool): forwarded to ``plot_pxrd`` as ``rug``.
        save_patterns (bool): write ``<root_source>_pxrd_pattern.dat`` for
            each structure.
        save_peaks (bool): write ``<root_source>_pxrd_peaks.dat`` for each
            structure.
        save_res (bool): write ``<root_source>.res`` for each structure.

    Raises:
        SystemExit: with status 1 if a file cannot be parsed or cannot be
            converted into a ``Crystal``.

    """
    seeds = kwargs.get("seeds")
    if isinstance(seeds, str):
        seeds = [seeds]

    strucs = []
    for _file in seeds:
        start = time.time()
        struc, success = cif2dict(_file, fail_fast=False, verbosity=0)
        elapsed = time.time() - start
        if not success:
            print_failure(f"Error parsing {_file}")
            # on failure, whatever the scraper returned in place of the
            # structure is printed as the error report
            print(struc)
            # SystemExit(1) rather than the `site` helper `exit()`: always
            # available, and signals the failure to the calling shell
            raise SystemExit(1)
        print(f"Loaded CIF {_file} in {elapsed:.3f} s")

        # only the construction itself is guarded; the traceback is printed
        # here, so the exit is raised without chaining
        try:
            crystal = Crystal(struc)
        except Exception:
            print(f"Error loading {_file} as Crystal object:")
            print_exc()
            raise SystemExit(1) from None
        strucs.append(crystal)

    # strucs was built in the same order as seeds, so they can be paired up
    for seed, doc in zip(seeds, strucs):
        start = time.time()
        doc.calculate_pxrd(
            wavelength=kwargs.get("wavelength"),
            theta_m=kwargs.get("theta_m"),
            two_theta_bounds=kwargs.get("two_theta_range"),
            lorentzian_width=kwargs.get("broadening_width"),
            progress=doc.num_atoms > 100,  # show progress bar if cell is large
        )
        elapsed = time.time() - start
        print(f"Computed PXRD for {seed} in {elapsed:.3f} s")

    savefig = kwargs.get("savefig")
    if kwargs.get("plot") or savefig:
        # labels=None hands the labelling over to plot_pxrd
        labels = None if kwargs.get("spg_labels") else seeds
        plot_pxrd(
            [doc.pxrd for doc in strucs], rug=kwargs.get("rugplot"), labels=labels
        )
        if savefig:
            plt.savefig(savefig)
        else:
            plt.show()

    if kwargs.get("save_patterns"):
        for doc in strucs:
            doc.pxrd.save_pattern(doc.root_source + "_pxrd_pattern.dat")

    if kwargs.get("save_peaks"):
        for doc in strucs:
            doc.pxrd.save_peaks(doc.root_source + "_pxrd_peaks.dat")

    if kwargs.get("save_res"):
        from matador.export import doc2res

        for doc in strucs:
            doc2res(doc, doc.root_source + ".res", info=False)


if __name__ == "__main__":
    # Command-line entry point: every parsed option is forwarded verbatim to
    # compute_pxrd() as a keyword argument, so the option names below must
    # match the keys that function reads.
    parser = argparse.ArgumentParser(
        description="Compute, plot and export PXRD patterns from CIF file inputs.",
        epilog=script_epilog,
    )
    parser.add_argument(
        "-l",
        "--wavelength",
        type=float,
        default=1.5406,
        help="the incident X-ray wavelength in Angstrom (DEFAULT: 1.5406 Å, i.e. CuKα)",
    )
    parser.add_argument(
        "-bw",
        "--broadening_width",
        type=float,
        default=0.03,
        help="the width of broadening to apply to each peak",
    )
    parser.add_argument(
        "-tm",
        "--theta_m",
        type=float,
        default=0.0,
        help="the 2θ monochromator angle in degrees (DEFAULT: 0 degrees)",
    )
    parser.add_argument(
        "--plot", action="store_true", help="show a plot of the PXRD patterns"
    )
    parser.add_argument(
        "--savefig", type=str, help='save a plot to this file, e.g. "pxrd.pdf"'
    )
    # NOTE(review): no default is set here, so None is forwarded when the flag
    # is omitted; the "0 90" in the help text presumably describes the
    # library-side default -- confirm against calculate_pxrd.
    parser.add_argument(
        "-t",
        "--two_theta_range",
        nargs=2,
        type=float,
        help="the 2θ range to use for plotting/calculating the pattern (DEFAULT: 0 90)",
    )
    parser.add_argument(
        "--spg_labels",
        action="store_true",
        help="label with computed spacegroup-formula instead of filename",
    )
    parser.add_argument(
        "--save_res",
        action="store_true",
        help="save a res file with a closer interpretation of the structure used",
    )
    parser.add_argument(
        "--save_patterns",
        action="store_true",
        help="save a .dat file with the xy pattern for each structure",
    )
    # the peaks are written to "<root_source>_pxrd_peaks.dat", so the help
    # text advertises .dat
    parser.add_argument(
        "--save_peaks",
        action="store_true",
        help="save a .dat file per structure with a list of peaks",
    )
    parser.add_argument(
        "--rugplot",
        action="store_true",
        help="additionally plot peak positions as a rug plot",
    )
    parser.add_argument(
        "seeds", nargs="+", type=str, help="list of structures to compute"
    )

    parsed_kwargs = vars(parser.parse_args())
    compute_pxrd(**parsed_kwargs)
    print("Done!")
