#!/usr/bin/env python

import os
import sys
import subprocess
import argparse
import numpy as np
from argparse import RawTextHelpFormatter  # to go next line in the help text
from manage_crystal.file_parser import parse_from_filepath, parse_dcd_snapshot
from manage_crystal.file_parser import parse_dcd_header
from manage_crystal.file_writer import write_to_filepath

####################################################################### Argparse
# Command-line interface of the script. The file is a flat script, so the
# arguments are parsed right here and the resulting `args` namespace is read
# by every section below.
parser = argparse.ArgumentParser(description='Tool to split, append and convert trajectories. (by D.Ongari)',
                                 formatter_class=RawTextHelpFormatter)

# Positional argument: the trajectory to read. The script only has a 'pdb'
# and a 'dcd' branch, so those are the formats advertised in the help text.
parser.add_argument(type=str, dest="inputfile", help="path to the input file to read\n" + "IMPLEMENTED: pdb, dcd")

# Mandatory output specification: either "name.ext" or just "ext".
# No default is given because argparse refuses to run without this option.
parser.add_argument("-o",
                    "--output",
                    required=True,
                    type=str,
                    dest="output",
                    help="Output filename.extension or just the extension\n" + "IMPLEMENTED: xyz, axsf, dcd")

# NOTE(review): this option consumes exactly three values (nargs=3) although
# the help text talks about "more than one file", and the parsed value is not
# read anywhere in the visible script - confirm whether appending is still
# planned before relying on it.
parser.add_argument("-a",
                    "--append",
                    type=str,
                    nargs=3,
                    dest="append",
                    default=None,
                    help="Trajectories (of the same format) to append to the input one.\n" +
                    "More than one file can be specified and appended.")

# Flag: keep the per-snapshot files instead of deleting them at the end.
parser.add_argument("-sss",
                    "-savesnapshots",
                    action="store_true",
                    dest="sss",
                    default=False,
                    help="Save also the single snapshots of the trajectory")

# Companion structure file providing the atom types for dcd input
# (the dcd branch exits if it is missing).
parser.add_argument("-dt",
                    "-dcdtypes",
                    type=str,
                    dest="dcdtypesfile",
                    default=None,
                    help="Specify another file containing the atomic types for dcd\n" +
                    "Any format readable by manage_crystal is valid.")

args = parser.parse_args()

################################################## Manage input/output filenames
# Validate the input path and split it into stem and dot-less extension.
if not os.path.isfile(args.inputfile):
    sys.exit("ERROR: The file %s doesn't exist!" % args.inputfile)
inpfilename, inpext = os.path.splitext(args.inputfile)
inpformat = inpext[1:]
# Use the path exactly as it was given. Rebuilding it as "stem.ext" turns an
# extension-less name into "stem.", which is not the file that was checked.
inpfile = args.inputfile

# Only these two formats have a reader below. Stopping here avoids running
# into an undefined snapshot counter later in the script.
if inpformat not in ("pdb", "dcd"):
    sys.exit("ERROR: The input format '%s' is not implemented (use pdb or dcd)." % inpformat)

# The pdb branch reads (and closes) this text handle; the dcd branch opens the
# file itself in binary mode, so no text handle is created for it.
file = open(inpfile, 'r') if inpformat == 'pdb' else None

# Parse output filename
outroot, outext = os.path.splitext(args.output)
if outext:  #output defined as name.format
    outfilename = outroot
    outformat = outext[1:]
else:  #output defined as format (a leading dot, as in ".xyz", is tolerated)
    outfilename = inpfilename
    outformat = args.output.lstrip(".")
outfile = "{}.{}".format(outfilename, outformat)

################################ Split and convert snapshots from the trajectory
if inpformat == 'pdb':
    ss = 0  #snapshot count
    # Handle of the snapshot currently being written; None whenever the reader
    # is outside a MODEL ... ENDMDL block.
    ssfilew = None
    # Split snapshots: every MODEL ... ENDMDL block goes to its own pdb file.
    try:
        for line in file:
            data = line.split()
            record = data[0] if data else ""
            if record == "MODEL":
                if ssfilew is not None:
                    # Previous MODEL was never terminated: release its handle
                    # before the same snapshot file is reopened.
                    ssfilew.close()
                ssfile = "{0}-ss{1:05d}.pdb".format(outfilename, ss)
                ssfilew = open(ssfile, 'w')
            elif record == "ENDMDL":
                if ssfilew is not None:
                    print("END", end="", file=ssfilew)
                    ssfilew.close()
                    ssfilew = None
                    ss += 1
            elif ssfilew is not None:
                print(line.strip(), end="\n", file=ssfilew)
            # Lines outside a MODEL block (header records before the first
            # MODEL, a trailing END, blank lines) are skipped: there is no
            # open snapshot to receive them.
            # NOTE(review): a file-level CRYST1 placed before the first MODEL
            # is therefore not copied into the snapshots - confirm whether
            # such trajectories need to be supported.
    finally:
        if ssfilew is not None:
            ssfilew.close()
        file.close()
    print("Extracted %d snapshots" % ss)
    # Convert snapshots
    print("Converting %s to %s " % (inpformat, outformat), end="")
    for iss in range(ss):
        print(".", end="")
        sys.stdout.flush()  # show progress dot immediately
        ssinp = "{0}-ss{1:05d}.pdb".format(outfilename, iss)
        ssout = "{0}-ss{1:05d}.{2}".format(outfilename, iss, outformat)
        crys = parse_from_filepath(ssinp, 0)
        crys.afterparse_basic()
        # NOTE(review): six positional arguments here, while the dcd branch
        # passes seven to the same function - verify against the
        # manage_crystal.file_writer signature.
        write_to_filepath(crys, ssout, 0, None, None, None)
    print(" all done!")

if inpformat == 'dcd':
    # A second structure file is mandatory for dcd input: it is parsed below
    # to obtain the atom types.
    if args.dcdtypesfile is None:
        sys.exit("WARNING: to read .dcd trajectories, you need to specify " +
                 "another file that contains the atomic types. Use -dt. EXIT.")
    # The trajectory itself was already checked at the top of the script; the
    # file that still needs checking here is the atom-types one.
    if not os.path.isfile(args.dcdtypesfile):
        sys.exit("ERROR: The file %s doesn't exist!" % args.dcdtypesfile)
    # Read the atom types and number
    crys = parse_from_filepath(args.dcdtypesfile, 0)
    crys.compute_atom_count()
    # Read the dcd trajectory
    ss = 0
    with open(inpfile, 'rb') as inpfileread:
        # The header is consumed to position the stream; its return value is
        # not used by this script.
        parse_dcd_header(inpfileread)
        # parse_dcd_snapshot returns a truthy value once the end of the file
        # is reached; otherwise crys is written out as the next snapshot
        # (presumably after being updated in place - confirm in file_parser).
        while not parse_dcd_snapshot(inpfileread, crys):
            crys.afterparse_basic()
            ssout = "{}-ss{:05d}.{}".format(outfilename, ss, outformat)
            # NOTE(review): seven positional arguments here, while the pdb
            # branch passes six to the same function - verify against the
            # manage_crystal.file_writer signature.
            write_to_filepath(crys, ssout, 0, None, None, None, None)
            ss += 1
    print("Extracted %d snapshots" % ss)

############################################ Finally combine the temporary files
# Only xyz and axsf outputs are merged into a single trajectory file; for any
# other output format this section does nothing.

# For xyz files simply append ss
if outformat == "xyz":
    with open(outfile, 'w') as outfilew:
        for iss in range(ss):
            ssfileout = "{0}-ss{1:05d}.{2}".format(outfilename, iss, outformat)
            with open(ssfileout) as ssfiler:
                outfilew.write(ssfiler.read())

# For axsf files write the header and append the cut snapshots
if outformat == "axsf":
    # A separate name is used for the handle so that `outfile` keeps holding
    # the output path.
    with open(outfile, 'w') as outfilew:
        print("ANIMSTEPS %s" % ss, file=outfilew)
        # The XSF format uses the bare CRYSTAL keyword; the number of frames
        # belongs to ANIMSTEPS only.
        print("CRYSTAL", file=outfilew)
        for iss in range(ss):
            fname = "{0}-ss{1:05d}.{2}".format(outfilename, iss, outformat)
            # Read the snapshot once and slice it instead of scanning it twice.
            with open(fname) as ssfiler:
                sslines = ssfiler.readlines()
            print("PRIMVEC %s" % (iss + 1), file=outfilew)
            # Lines 3-5 (0-based) of the snapshot are copied under PRIMVEC
            # (presumably the three cell vectors - depends on the layout
            # written by manage_crystal).
            for row in sslines[3:6]:
                print(row.strip(), file=outfilew)
            print("PRIMCOORD %s" % (iss + 1), file=outfilew)
            # Everything after line 6 (0-based) is copied under PRIMCOORD.
            for row in sslines[7:]:
                print(row.strip(), file=outfilew)

############################################################### Delete snapshots
# Only xyz and axsf are merged into a single file by the combine section
# above. For every other output format the converted snapshots are the only
# result of the run, so they must not be removed.
combined = outformat in ("xyz", "axsf")
if not args.sss:
    if combined:
        print("Deleting snaphsots (use -sss to save them)")
    else:
        print("No combined file is written for %s: keeping the converted snapshots" % outformat)
    for iss in range(ss):
        ssinp = "{0}-ss{1:05d}.{2}".format(outfilename, iss, inpformat)
        ssout = "{0}-ss{1:05d}.{2}".format(outfilename, iss, outformat)
        # Intermediate input-format snapshot (only exists for pdb input).
        # When input and output formats coincide it is the converted file
        # itself and is handled by the second test.
        if ssinp != ssout and os.path.exists(ssinp):
            os.remove(ssinp)
        if combined and os.path.exists(ssout):
            os.remove(ssout)
else:
    print("Saved snaphsots")
