#!python
import mpi4py.MPI

import argparse
import sys

import numpy

from expectra.exafs import exafs_first_shell, exafs_multiple_scattering
from expectra.io import read_xdatcar, read_con

# World communicator shared by the whole script: used below for rank checks,
# for broadcasting the trajectory and for aborting the job on error.
COMM_WORLD = mpi4py.MPI.COMM_WORLD

def mpiexcepthook(type, value, traceback):
    """Report an uncaught exception and abort the MPI job.

    Installed as ``sys.excepthook`` below. The arguments follow the
    ``sys.excepthook`` signature: exception class, exception instance and
    traceback object.

    The default hook prints the traceback first, then the failing rank is
    written to stderr and ``COMM_WORLD.Abort`` is called with a non-zero
    error code so the failure is visible to whatever launched the job.
    """
    sys.__excepthook__(type, value, traceback)
    sys.stderr.write("exception occurred on rank %i\n" % COMM_WORLD.rank)
    # Abort() bypasses Python's normal exit/cleanup handling, so make sure
    # the diagnostics above are actually written out first.
    sys.stderr.flush()
    # The default error code is 0, which would make a crash look like success.
    COMM_WORLD.Abort(1)


sys.excepthook = mpiexcepthook

def main():
    """Parse the command line, load the trajectories and save chi(k).

    Rank 0 reads every trajectory file named on the command line
    (files whose name ends in ``con`` via ``read_con``, everything else via
    ``read_xdatcar``) and the combined list of configurations is broadcast
    to all ranks. The absorber is then resolved with
    ``get_default_absorber``, the spectrum is computed with
    ``exafs_trajectory`` and written out by ``save_result``.

    Exits with status 2 (via argparse) when ``--every`` is smaller than 1,
    and with status 2 when no configuration could be read.
    """
    parser = argparse.ArgumentParser()

    # NOTE: with store_true and default=True this option is always on; it is
    # only switched off further down when --multiple-scattering is given.
    parser.add_argument('--first-shell', action='store_true', default=True,
            help='a single scattering calculation that uses an '
            'automatically calculated reference path '
            '(default: %(default)s)')

    parser.add_argument('--neighbor-cutoff', type=float, metavar='DISTANCE',
            help='1st neighbor cutoff distance (default: %(default)s)',
            default=3.4)

    parser.add_argument('--multiple-scattering', action='store_true')
    parser.add_argument('--rmax', type=float, metavar='DISTANCE',
            default=6.0, help='maximum scattering half-path length')

    parser.add_argument('--S02', type=float, metavar='FACTOR',
            default=1.0, help='amplitude reduction factor')
    parser.add_argument('--energy-shift', type=float, metavar='ENERGY',
            default=0.0, help='energy shift to apply in eV')
    parser.add_argument('--absorber', type=str, metavar='ELEMENT',
            help='atomic symbol of the xray absorber')
    parser.add_argument('--ignore-elements', type=str, metavar='ELEMENTS',
            help='comma delimited list of elements to ignore in the '
            'scattering calculation')
    parser.add_argument('--edge', type=str, help='one of K, L1, L2, L3',
                        default='K')

    parser.add_argument('--skip', type=int, default=0,
            help='number of frames to skip at the beginning')
    parser.add_argument('--every', type=int, default=1,
            help='number of frames between each step')
    parser.add_argument('trajectories', metavar='TRAJ', nargs='+',
            help='trajectory file (POSCAR, con, xyz)')

    args = parser.parse_args()

    # A stride of 0 makes the slice below raise "slice step cannot be zero";
    # reject it (and negative strides) up front with a usage error instead.
    if args.every < 1:
        parser.error('--every must be a positive integer')

    if args.ignore_elements:
        args.ignore_elements = args.ignore_elements.split(',')

    # Multiple scattering replaces the (always-on) first-shell mode.
    if args.multiple_scattering:
        args.first_shell = False

    # Only rank 0 touches the files; the other ranks get the data by bcast.
    trajectory = []
    if COMM_WORLD.rank == 0:
        for filename in args.trajectories:
            print('reading', filename)
            if filename.endswith('con'):
                trajectory_part = read_con(filename)[args.skip::args.every]
            else:
                trajectory_part = read_xdatcar(filename, args.skip, args.every)
            n = len(trajectory_part)
            print('read %4i configurations from %s' % (n, filename))
            trajectory += trajectory_part
    trajectory = COMM_WORLD.bcast(trajectory)

    # Without this check trajectory[0] below fails with a bare IndexError
    # on every rank (e.g. when --skip is larger than the number of frames).
    if not trajectory:
        if COMM_WORLD.rank == 0:
            sys.stderr.write('ERROR: no configurations were read from the '
                             'given trajectories\n')
        sys.exit(2)

    args.absorber = get_default_absorber(trajectory[0], args)

    k, chi = exafs_trajectory(args, trajectory)
    save_result(k, chi)

def save_result(k, chi):
    """Write the spectrum to ``chi.dat``.

    Only rank 0 writes; every other rank returns immediately. One line is
    written per entry of ``k``, holding ``k[i]`` and ``chi[i]`` formatted as
    ``"%6.3f %16.8e"``.
    """
    if COMM_WORLD.rank != 0:
        return
    print('saving result to chi.dat')
    # The context manager closes the file even if formatting a row fails.
    with open('chi.dat', 'w') as f:
        for i, k_i in enumerate(k):
            f.write("%6.3f %16.8e\n" % (k_i, chi[i]))

def get_default_absorber(atoms, args):
    """Return the chemical symbol of the absorbing element.

    If ``args.absorber`` is set it must be one of the symbols returned by
    ``atoms.get_chemical_symbols()`` and is returned unchanged. Otherwise
    the elements listed in ``args.ignore_elements`` (if any) are removed and
    the single remaining symbol is returned.

    On any failure an ``ERROR:`` message is written to stderr and the
    process exits with status 2. This happens when the requested absorber is
    absent, when no element is left after ignoring, or when more than one
    candidate remains.
    """
    symbols = set(atoms.get_chemical_symbols())

    if args.absorber:
        if args.absorber not in symbols:
            sys.stderr.write('ERROR: --absorber %s is not in the system\n'
                             % args.absorber)
            sys.exit(2)
        return args.absorber

    if args.ignore_elements:
        symbols -= set(args.ignore_elements)

    # Ignoring every element leaves nothing to pick from; report that case
    # on its own instead of claiming there is more than one species.
    if not symbols:
        sys.stderr.write('ERROR: no elements left to choose an absorber '
                         'from; check --ignore-elements\n')
        sys.exit(2)

    if len(symbols) > 1:
        sys.stderr.write('ERROR: must specify --absorber if more than one '
                         'chemical specie\n')
        sys.exit(2)

    return next(iter(symbols))

def exafs_trajectory(args, trajectory):
    """Run the EXAFS calculation selected by ``args`` on ``trajectory``.

    ``args.multiple_scattering`` is checked first and dispatches to
    ``exafs_multiple_scattering`` (using ``args.rmax``); otherwise
    ``args.first_shell`` dispatches to ``exafs_first_shell`` (using
    ``args.neighbor_cutoff``). Both receive S02, the energy shift, the
    absorber, the ignored elements and the edge from ``args``.

    Returns the ``(k, chi)`` pair produced by the selected routine.

    Raises:
        ValueError: if neither calculation mode is enabled.
    """
    if args.multiple_scattering:
        k, chi = exafs_multiple_scattering(args.S02, args.energy_shift,
                args.absorber, args.ignore_elements, args.edge, args.rmax,
                trajectory)
    elif args.first_shell:
        k, chi = exafs_first_shell(args.S02, args.energy_shift,
                args.absorber, args.ignore_elements, args.edge,
                args.neighbor_cutoff, trajectory)
    else:
        # Previously this fell through to the return statement and failed
        # with an UnboundLocalError on k.
        raise ValueError('no EXAFS calculation selected: enable '
                         'first_shell or multiple_scattering')

    return k, chi

# Run the command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()
