#!/usr/bin/env python

import string, sys
import numpy as np
from numpy.linalg import inv
import math
import subprocess
import argparse
from argparse import RawTextHelpFormatter  #needed to go next line in the help text
import os
import re  #re.split(r'(\d+)',"O23") = ['O', '23', '']
from manage_crystal import Crys
from manage_crystal.periodic_table import ptab_atnum_inv, ptab_vdw_uff
from manage_crystal.file_parser import parse_from_filepath
from manage_crystal.file_writer import write_to_filepath
from manage_crystal.utils import is_number
#from pprint import pprint #pprint(vars(your_object))

# Command-line interface. Every option below ends up as an attribute of the
# namespace returned by parser.parse_args(), named after its `dest`.
# RawTextHelpFormatter keeps the explicit "\n" line breaks of the help texts.
parser = argparse.ArgumentParser(description='Tool to read, extract info and convert crystal files. (by D.Ongari)',
                                 formatter_class=RawTextHelpFormatter)

# File input / output
# No option string is given here, so argparse treats "inputfile" as the
# (mandatory) positional argument.
parser.add_argument(
    type=str,
    dest="inputfile",
    help="path to the input file to read\n" + "IMPLEMENTED: xyz(w/CELL),pdb,cssr,pwi,pwo,cif,xsf,axsf,subsys(CP2K),\n" +
    "             restart(CP2K),inp(CP2K),cube,POSCAR(VASP),car\n" + "             [NEXT: gaussian, dcd+atoms]")

parser.add_argument("-o",
                    "--output",
                    type=str,
                    dest="output",
                    default=None,
                    help="Output filename.extension or just the extension\n" +
                    "IMPLEMENTED: cif,pdb,cssr,xyz(w/CELL),pwi,subsys(CP2K),axsf,POSCAR")

parser.add_argument("-resp",
                    type=str,
                    dest="resp",
                    default=None,
                    help="Read the charges from a cp2k RESP file\n" + "(also checking if the atoms are the same)\n" +
                    "BC1: it read the first set of charges\n" + "BC2: Also a cp2k output file with charges is fine!\n")

parser.add_argument("-readcharge",
                    type=str,
                    dest="readcharge",
                    default=None,
                    help="Read the charges from a simple list")

parser.add_argument("-readrepeatcharge",
                    type=str,
                    dest="readrepeatcharge",
                    default=None,
                    help="Read the charges from REPEAT output of QE")

# Read settings
parser.add_argument("-pseudopw", type=str, dest="pseudopw", default="pbe", help="Pseudo for the .pwi output")

parser.add_argument("-bscp2k",
                    "--basisset-cp2k",
                    type=str,
                    dest="bscp2k",
                    default="DZVP-MOLOPT-SR-GTH",
                    help="Gaussian Basis Set for CP2K")

parser.add_argument("-potcp2k",
                    "--potential-cp2k",
                    type=str,
                    dest="potcp2k",
                    default="GTH-PBE",
                    help="Pseudo potential for CP2K")

parser.add_argument("-fract",
                    "--fractional-coordinates",
                    action="store_true",
                    dest="fract",
                    default=False,
                    help="Force the writer to print fractional coordinates (if possible)")

parser.add_argument("-chargenull",
                    action="store_true",
                    dest="chargenull",
                    default=False,
                    help="Delete the charge of the atoms")

# Verbosity section
parser.add_argument("-silent", action="store_true", dest="silent", default=False, help="No output info on the screen")

parser.add_argument("-show",
                    action="store_true",
                    dest="show",
                    default=False,
                    help="Show all the info\n" + "[skip -silent]")

parser.add_argument("-showonly",
                    choices=[None, 'elements', 'cell', 'CELL', 'xyz', 'fract', 'charge'],
                    dest="showonly",
                    default=None,
                    help="Show only the required info\n" + "[skip -silent]")

# Find/compute stuff section
parser.add_argument("-cutoff",
                    type=float,
                    dest="cutoff",
                    default=None,
                    help="Automatically extend the UC so that the cutoff is respected\n" +
                    "(TIP: use -cutoff 0 to just know the perpendicular widths!)")

parser.add_argument("-cupw", action="store_true", dest="cupw", default=False, help="Look for a Copper PaddleWheel")

parser.add_argument("-ovlp",
                    action="store_true",
                    dest="ovlp",
                    default=False,
                    help="Look for an overlap and remove the overlapping atom")

parser.add_argument(
    "-ovlpthr",
    type=float,
    dest="ovlpthr",
    default=0.7,  #bond lengths: H-H = 0.74, C-H = 1.09, NH = 1.0, O-H = 0.96
    help="Threshold (Angstrom) for checking atomic overlaps.")

mod_coord_cell = parser.add_argument_group("Options to modify coord./cell")
mod_coord_cell.add_argument("-x",
                            type=int,
                            dest="multipl_x",
                            default=1,
                            help="Extend in the x direction, by the specified times")

mod_coord_cell.add_argument("-y",
                            type=int,
                            dest="multipl_y",
                            default=1,
                            help="Extend in the y direction, by the specified times")

mod_coord_cell.add_argument("-z",
                            type=int,
                            dest="multipl_z",
                            default=1,
                            help="Extend in the z direction, by the specified times")

mod_coord_cell.add_argument("-transl",
                            type=float,
                            nargs=3,
                            dest="transl",
                            default=None,
                            help="x y z translation in Angs")

mod_coord_cell.add_argument("-rotaxis",
                            choices=[None, "up", "down"],
                            dest="rotaxis",
                            default=None,
                            help="Rotate the axis. 'up': xyz>zxy, 'down': xyz>yzx")

mod_coord_cell.add_argument("-randomize",
                            type=float,
                            dest="randomize",
                            default=None,
                            help="Randomize the geometry by a gaussian\n" + "with the specified delta (angs)")

parser.add_argument("-chkmetalcharge",
                    action="store_true",
                    dest="chkmetalcharge",
                    default=False,
                    help="Check if the charge on a metal (see list) is neg.\n" + "[skip -silent]")

parser.add_argument("-normalizecharges",
                    action="store_true",
                    dest="normalizecharges",
                    default=False,
                    help="Normalize the charges to have a null total charge.")

parser.add_argument("-bv",
                    action="store_true",
                    dest="bondvalence",
                    default=False,
                    help="Apply the bond valence method to compute oxidation states.")

# NOTE: the two option strings must be separate arguments. Without the comma
# Python joins the adjacent literals into the single, unusable flag
# "-tm --tailormade" and "--tailormade" is never recognised.
parser.add_argument(
    "-tm",
    "--tailormade",
    type=int,
    dest="tm",
    default=0,
    help="Tailor made settings for parsing and writing:\n" + " 1 - Deprecated (parse DDEC-CoRE-MOF .cif)\n"
    " 2 - Deprecated (write .cif for EQeq)\n"
    " 3 - Parse .xyz for B. Wells Qeq program\n"
    " 4 - Write .xyz for B. Wells Qeq program\n"
    " 5 - Deprecated (.xyz for B. Wells Qeq program, w/ zero formal charges)\n"
    " 6 - Deprecated (parse .cif from EGULP)\n")

args = parser.parse_args()

########################################################################## INPUT
# Input path stripped of its extension. Any directory part is kept, so an
# output requested only by format ends up next to the input file.
inputfilename, _inputext = os.path.splitext(args.inputfile)
crys = parse_from_filepath(args.inputfile, args.tm)
######################################################################## Charges

# -resp: take the charges from the "RESP" lines of a file. A line is accepted
# only when its third field equals the type of the atom that is next in line,
# and reading stops once every atom got a charge.
if args.resp is not None:
    had_charges = any(q != 0 for q in crys.atom_charge)
    if had_charges and not args.silent:
        print(" ... THERE WERE ALREADY CHARGES BUT I'M OVERWRITING THEM!")
    with open(args.resp, "r") as resp_file:
        next_atom = 0  # index of the atom still waiting for its charge
        for resp_line in resp_file:
            fields = resp_line.split()
            is_resp_record = len(fields) > 3 and fields[0] == 'RESP'
            if is_resp_record and fields[2] == crys.atom_type[next_atom]:
                crys.atom_charge[next_atom] = float(fields[3])
                next_atom += 1
            if next_atom == crys.natom:
                break

# -readcharge: plain list of charges, one per line (first token of the line),
# assigned to the atoms in file order.
if args.readcharge is not None:
    if not args.silent and any(q != 0 for q in crys.atom_charge):
        print(" ... THERE WERE ALREADY CHARGES BUT I'M OVERWRITING THEM!")
    listed_charges = []
    with open(args.readcharge) as charge_file:
        for charge_line in charge_file:
            tokens = charge_line.split()
            if not tokens:
                # Blank lines (typically a trailing one) carry no charge.
                continue
            listed_charges.append(float(tokens[0]))
    # Fail with an explicit message instead of an IndexError half way through
    # the assignment when the file lists more values than there are atoms.
    if len(listed_charges) > len(crys.atom_charge):
        sys.exit("ERROR: %s contains %d charges but only %d atoms were parsed. EXIT." %
                 (args.readcharge, len(listed_charges), len(crys.atom_charge)))
    for atom_index, charge in enumerate(listed_charges):
        crys.atom_charge[atom_index] = charge

# -readrepeatcharge: charges from a REPEAT output file. The first 17 lines are
# treated as header; each following line provides, in its 7th field, the value
# for one atom, which is multiplied by -0.5 before being stored.
if args.readrepeatcharge is not None:
    if not args.silent and any(q != 0 for q in crys.atom_charge):
        print(" ... THERE WERE ALREADY CHARGES BUT I'M OVERWRITING THEM!")
    with open(args.readrepeatcharge) as repeat_file:
        if not args.silent:
            print("*** CHARGES from QE>REPEAT.out: multiplying by -0.5")
        header_lines = 17  # Header of REPEAT.out
        for line_index, repeat_line in enumerate(repeat_file):
            atom_index = line_index - header_lines
            if atom_index < 0:
                continue
            if atom_index >= crys.natom:
                # Every atom has been assigned: nothing else to read.
                break
            crys.atom_charge[atom_index] = float(repeat_line.split()[6]) * (-0.5)

# -chargenull: replace whatever charges are stored with one zero per atom.
if args.chargenull:
    if not args.silent:
        print("*** chargenull: DELETING ALL THE CHARGES! ***")
    crys.atom_charge = [0 for _atom in range(crys.natom)]

# -normalizecharges: rescale the charges so that they sum to zero. The excess
# charge is split between the positive and the negative atoms proportionally
# to their share of the total absolute charge, and within each group
# proportionally to the charge of every atom.
if args.normalizecharges:
    if not args.silent:
        print("")
        print("*** NORMALIZING CHARGES ***")

    # Snapshot of the current values: the totals below must not be affected by
    # the in-place updates done afterwards.
    old_charges = [crys.atom_charge[i] for i in range(crys.natom)]
    pos_charge = sum(q for q in old_charges if q > 0)
    neg_charge = sum(q for q in old_charges if not q > 0)

    tot_charge = pos_charge + neg_charge
    tot_abs = pos_charge - neg_charge

    if not args.silent:
        print("total charge: %f" % tot_charge)
        print("positive charges: %f" % pos_charge)
        print("negative charges: %f" % neg_charge)
        print("total absolute ch.: %f" % tot_abs)

    # tot_abs is zero only when every charge is zero: nothing to normalize,
    # and computing pos_fract would divide by zero.
    if tot_abs != 0:
        pos_fract = pos_charge / tot_abs
        for i, q in enumerate(old_charges):
            if q > 0:
                crys.atom_charge[i] = q - tot_charge * pos_fract * q / pos_charge
            elif q < 0:
                crys.atom_charge[i] = q - tot_charge * (1 - pos_fract) * q / neg_charge
            # Atoms with zero charge get a zero correction: they are left
            # untouched, which also avoids 0/0 when neg_charge is zero.

########################################################### Coordinates and cell
verbose = not args.silent

# Cell: run the conversion that matches the way the cell was parsed.
if crys.inp_matrix:
    if verbose:
        print("\n Cell parsed as matrix.")
    crys.compute_la_from_matrix()
elif crys.inp_lengths_angles:
    if verbose:
        print("\n Cell parsed as lengths and angles.")
    crys.compute_matrix_from_la()

# Atoms: same idea, cartesian input takes precedence over fractional input.
if crys.inp_xyz:
    if verbose:
        print(" \n Atomic coordinates parsed as cartersian.")
    crys.compute_fract_from_xyz()
elif crys.inp_fract:
    if verbose:
        print(" \n Atomic coordinates parsed as fractional.")
    crys.compute_xyz_from_fract()

# Cells flagged as not aligned are fixed before any further manipulation.
if not crys.matrix_alligned:
    if verbose:
        print(" \n Matrix not alligned to cartesian axis: fixed.")
    crys.fix_cell_notalligned()

################################################## APPLY TRANSLATION / RANDOMIZE
verbose = not args.silent

# -transl: three floats (x y z), forwarded as they are to the crystal.
if args.transl is not None:
    if verbose:
        print("\n*** TRANSLATING coordinates by %f %f %f Angs" % tuple(args.transl))
    crys.transl_coord(args.transl)

# -randomize: the given value is passed as the `delta` of the randomization.
if args.randomize is not None:
    if verbose:
        print("\n*** RANDOMIZING XYZ coordinates by a normal distribution with delta=%f Angs" % args.randomize)
    crys.randomize_coord(delta=args.randomize)

# -rotaxis: "up" maps to up=True, "down" to up=False.
if args.rotaxis is not None:
    if verbose:
        print("\n*** ROTATING axis, direction: %r" % args.rotaxis)
        print("*** Warning: the cell matrix is not rotated")
    rotate_up = args.rotaxis == "up"
    crys.rotate_axis(up=rotate_up)

################################################### CUTOFF TEST & CELL EXPANSION
# Number of replicas requested along each cell direction (-x, -y, -z).
multipl_length = [args.multipl_x, args.multipl_y, args.multipl_z]

if args.cutoff is not None:  #copied from raspa/framework.c/CellProperties(line:6184)
    # An explicit expansion and an automatic one are mutually exclusive.
    if any(x > 1 for x in multipl_length):
        sys.exit("WARNING: Why did you ask for both -cutoff and -x -y -z ? EXIT.")

    crys.compute_perp_width()
    print("\nCUTOFF_TEST | Cutoff: %.1f" % (args.cutoff))
    print("CUTOFF_TEST | Cell perpendicular widths: %.3f %.3f %.3f" %
          (crys.perp_width[0], crys.perp_width[1], crys.perp_width[2]))

    # compute how big the cell must be: every perpendicular width has to be at
    # least twice the cutoff. Never go below one replica, so that
    # "-cutoff 0" stores the same "1 1 1" that is reported below.
    multipl_length = [max(1, int(math.ceil(2 * args.cutoff / crys.perp_width[k]))) for k in range(3)]

    if any(x > 1 for x in multipl_length):
        print("CUTOFF_TEST | Expansion_done: %d %d %d for %s" %
              (multipl_length[0], multipl_length[1], multipl_length[2], args.inputfile))
    else:
        print("CUTOFF_TEST | Expansion_unnecesary: 1 1 1 for %s" % (args.inputfile))

# Replicate the cell only along the directions that actually need it.
for k in range(3):
    if multipl_length[k] > 1:
        crys.expand_k_dir(k, multipl_length[k])

################################################################### COMPUTE INFO
verbose = not args.silent

# Composition table and cell volume (the volume is computed only when it is
# going to be printed).
if verbose:
    print()
    for element in crys.element_list:
        print("{0:>5} {1:3} atoms".format(crys.element_count[element], element))
    print(" ---- --- ----- ")
    print('{0:>5} {1:3} atoms'.format(crys.natom, 'tot'))
    print("\nVolume: %.3f (Angtrom^3/u.c.)" % crys.compute_volume_from_la())

# Weight and density are always computed, also in silent mode.
weight, rho_kgm3 = crys.compute_weight_density()
if verbose:
    print("Density: %.5f (kg/m3), %.5f (g/cm3), %.5f (g/molUC)" % (rho_kgm3, rho_kgm3 / 1000, weight))
    print("Conversion: 1 molec./u.c. = %.5f (mol/kg)" % (1000 / weight))

# Net charge report: pick the message first, print it only if allowed.
if all(q == 0 for q in crys.atom_charge):
    net_charge_note = "\nNET_CHARGE: all the charges are zero or not provided."
elif abs(sum(crys.atom_charge)) < +0.001:
    net_charge_note = "\nNET_CHARGE: negligible (|sum|<0.001)."
else:
    net_charge_note = "\nNET_CHARGE: nonzero (%.3f). ***WARNING***" % sum(crys.atom_charge)
if verbose:
    print(net_charge_note)

# -chkmetalcharge: scan the atoms whose element is in the list below and
# report the first metal carrying a negative or not-a-number charge
# [printed also with -silent].
if args.chkmetalcharge:
    metals = ("Li", "Be", "Na", "Mg", "Al", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
              "Ga", "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Cs", "Ba",
              "La", "Hf", "Ta", "W")
    found_met = False
    found_met_nonzero = False
    found_met_bad = False  # negative or NaN charge on a metal
    for idx, symbol in enumerate(crys.atom_element):
        if symbol not in metals:
            continue
        found_met = True
        metal_charge = crys.atom_charge[idx]
        if math.isnan(metal_charge):
            print("CHK_METAL_CHARGE: not_number >>> %s=%s" % (symbol, metal_charge))
            found_met_bad = True
        elif metal_charge < 0:
            print("CHK_METAL_CHARGE: found_negative >>> %s=%.3f" % (symbol, metal_charge))
            found_met_bad = True
        elif metal_charge > 0:
            found_met_nonzero = True
        # The first offending metal ends the scan.
        if found_met_bad:
            break
    # Summary line: nothing more is printed when an offending metal was found.
    if not found_met:
        print("CHK_METAL_CHARGE: no_metals")
    elif not found_met_bad:
        if found_met_nonzero:
            print("CHK_METAL_CHARGE: ok_positive")
        else:
            print("CHK_METAL_CHARGE: all_zero")

# The electron count is computed only when it is going to be shown.
if not args.silent:
    print("Tot. electrons: %d" % crys.compute_nelectron())

# -showonly elements: comma separated list of the elements [skip -silent].
if args.showonly == 'elements':
    print(*crys.element_list, sep=', ')

# -ovlp: report every pair of atoms closer than -ovlpthr and remove the second
# atom of each pair. Atoms already marked for removal are not used as the
# first atom of a pair.
if args.ovlp:
    overlaps = 0
    ovlp_list = []  # indices to remove, in discovery order, each listed once
    ovlp_seen = set()  # same indices, for O(1) membership tests
    for i in range(crys.natom):
        if i in ovlp_seen:  # exclude atoms that I already know to overlap
            continue
        for j in range(i + 1, crys.natom):
            dist = crys.dist_ij(i, j)
            if dist < args.ovlpthr:
                print("Overlap found (dist=%.2f<%.2f) between:" % (dist, args.ovlpthr))
                print("%4d %3s %9.5f %9.5f %9.5f " %
                      (i, crys.atom_element[i], crys.atom_xyz[i][0], crys.atom_xyz[i][1], crys.atom_xyz[i][2]))
                print("%4d %3s %9.5f %9.5f %9.5f " %
                      (j, crys.atom_element[j], crys.atom_xyz[j][0], crys.atom_xyz[j][1], crys.atom_xyz[j][2]))
                overlaps += 1
                # An atom overlapping with several others is still removed
                # only once: never hand duplicated indices to remove_atoms.
                if j not in ovlp_seen:
                    ovlp_seen.add(j)
                    ovlp_list.append(j)
    crys.remove_atoms(ovlp_list)

    print("OVERLAPS FOUND: %d" % overlaps)

# -bv: delegate the bond valence report to the crystal object.
if args.bondvalence:
    crys.print_bondvalence()

################################################################ SHOW & SHOWONLY
# -show prints every section preceded by its title; -showonly prints the bare
# content of the selected section only.
show_all = args.show
only = args.showonly

if show_all or only == "cell":
    if show_all:
        print("\nCell matrix ----------------------------------------\n")
    for k in range(3):
        row = crys.matrix[k]
        print("     %10.5f %10.5f %10.5f" % (row[0], row[1], row[2]))

if show_all or only == "CELL":
    if show_all:
        print("\nCell lengths and angles ----------------------------\n")
    print(" %10.5f  %10.5f  %10.5f" % (crys.length[0], crys.length[1], crys.length[2]))
    print(" %10.5f  %10.5f  %10.5f" % (crys.angle_deg[0], crys.angle_deg[1], crys.angle_deg[2]))

if show_all or only == "xyz":
    if show_all:
        print("\nCartesian coordinates ------------------------------\n")
    for idx in range(crys.natom):
        cart = crys.atom_xyz[idx]
        print("%3s %10.5f %10.5f %10.5f " % (crys.atom_element[idx], cart[0], cart[1], cart[2]))

if show_all or only == "fract":
    if show_all:
        print("\nFractional coordinates -----------------------------\n")
    for idx in range(crys.natom):
        frac = crys.atom_fract[idx]
        print("%3s %10.5f %10.5f %10.5f " % (crys.atom_element[idx], frac[0], frac[1], frac[2]))

if show_all or only == "charge":
    if show_all:
        print("\nCharges --------------------------------------------\n")
    for idx in range(crys.natom):
        print("%3s %10.5f" % (crys.atom_element[idx], crys.atom_charge[idx]))

################################################################### WRITE OUTPUT
if args.output is None:
    # Placeholder used only by the final message: nothing is written.
    outputfile = 'NOTHING'
else:
    if "." in args.output:  #output defined as name.format
        outputfilename, dotted_format = os.path.splitext(args.output)
        outputformat = dotted_format[1:]
    else:  #output defined as format: reuse the name of the input file
        outputfilename, outputformat = inputfilename, args.output
    outputfile = outputfilename + "." + outputformat

    # Remind which settings are used by the formats that need them.
    if outputformat == "pwi":
        print("QE input using the pseudo: %s" % (args.pseudopw))
    elif outputformat == "subsys":
        print("CP2K input using the BASIS_SET: %s" % (args.bscp2k))
        print("CP2K input using the POTENTIAL: %s" % (args.potcp2k))

    write_to_filepath(crys, outputfile, args.tm, args.pseudopw, args.bscp2k, args.potcp2k, args.fract)

if not args.silent:
    print("\n Convert %s to %s\n" % (args.inputfile, outputfile))
