#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#

# Copyright (C) 2011 Mattias Slabanja <slabanja@chalmers.se>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.

import os
import sys
import optparse
import logging
import numpy as np

from itertools import count, islice
from functools import partial

import dsf.filon as filon
from dsf.binner import fixed_bin_averager
from dsf.handythread import foreach
from dsf.index import section_index
from dsf.output import create_mfile, create_pfile
from dsf.reciprocal import reciprocal_isotropic, reciprocal_line
from dsf.trajectory import get_itraj, iwindow

from multiprocessing import cpu_count

try:
    from scipy.interpolate import interp1d
except ImportError:
    def interp1d(xp, yp):
        """Minimal NumPy-only stand-in used when SciPy cannot be imported.

        Returns a callable performing linear interpolation (``np.interp``)
        of the samples ``yp`` given at positions ``xp``.  For a 2d ``yp``
        every row is interpolated separately and the results are stacked
        into one array.

        Raises RuntimeError if ``yp`` is neither 1d nor 2d.
        """
        n_dims = len(yp.shape)

        if n_dims == 1:
            def interpolate(x):
                return np.interp(x, xp, yp)
            return interpolate

        if n_dims == 2:
            def interpolate_rows(x):
                return np.array([np.interp(x, xp, row) for row in yp])
            return interpolate_rows

        raise RuntimeError("yp can only be 1d or 2d")


class averager:
    """Small accumulate-and-average helper used by dynasor.

    The object holds ``N_slots`` independent accumulators, each starting
    out as a private copy of ``initial``, together with one sample counter
    per slot.  Every item assignment (and therefore every augmented
    assignment such as ``av[i] += x``) increments the counter of that slot,
    so the mean of a slot is its accumulated data divided by its counter.

    Example:
    av = averager(2)
    av[0] += 10
    av[0] += 2
    av[1] += 3
    av.get_av() ->
    array([[6.], [3.]])
    """

    def __init__(self, N_slots, initial=np.zeros(1)):
        assert(N_slots >= 1)
        self._N = N_slots
        # np.array() copies, so the slots never share storage with each
        # other nor with the (shared) default value of `initial`.
        slots = []
        for _ in range(N_slots):
            slots.append(np.array(initial))
        self._data = slots
        self._samples = np.zeros(N_slots)

    def __getitem__(self, key):
        """Return the accumulated (not averaged) data stored in slot `key`."""
        return self._data[key]

    def __setitem__(self, key, val):
        """Store `val` in slot `key` and count it as one more sample."""
        self._data[key] = val
        self._samples[key] += 1

    def add(self, array, slot):
        """Accumulate `array` into `slot`; equivalent to self[slot] += array."""
        self[slot] += array

    def get_single_av(self, slot):
        """Return the data of `slot` divided by its number of samples."""
        inverse_count = 1.0 / self._samples[slot]
        return inverse_count * self._data[slot]

    def get_av(self):
        """Return the averages of all slots stacked into one array."""
        return np.array(list(map(self.get_single_av, range(self._N))))


# Reduced Planck constant expressed in eV fs.
hbar = 0.658211928
# Shorthands for pi and 2*pi used throughout the script.
pi = np.pi
two_pi = pi * 2.0

if __name__ == '__main__':
    # Script entry point.  The steps are, in order:
    #   1. parse the command line and configure logging / threading,
    #   2. read two frames to obtain reference data (box, time step, ...),
    #   3. set up the k-space sampling and the trajectory window iterator,
    #   4. accumulate time correlations over all trajectory windows,
    #   5. bin the results, transform them and write the output files.

    # ---- 1. Command line ---------------------------------------------------
    parser = optparse.OptionParser()

    iogroup = optparse.OptionGroup(parser,
                                   'Input/output options',
                                   'Options controlling input and output, '
                                   'files and fileformats.')
    iogroup.add_option('-f', '--trajectory', metavar='TRAJECTORY_FILE',
                       help='Molecular dynamics TRAJECTORY_FILE to be '
                       'analyzed. '
                       'Supported formats depends on VMD\'s molfile plugin '
                       'or gmxlib. As a fallback, a lammps-trajectory parser '
                       'implemented in Python is also available.')
    iogroup.add_option('-n', '--index', metavar='INDEX_FILE',
                       help='Optional index file (think Gromacs INI-style) for'
                       ' specifying atom types. Atoms are indexed from 1 up to'
                       ' N (total number of atoms). '
                       ' It is possible to index only a sub set of all atoms,'
                       ' and atoms can be indexed in more than one group.'
                       ' If no INDEX_FILE is provided, all atoms will be'
                       ' considered identical.')
    iogroup.add_option('', '--om', metavar='FILE',
                       help='Write output to FILE as a Matlab style m-file.')
    iogroup.add_option('', '--op', metavar='FILE',
                       help='Write output to FILE as a Python pickled file.')
    parser.add_option_group(iogroup)

    kspace = optparse.OptionGroup(parser,
                                  'General k-space options',
                                  'Options controlling general aspects for'
                                  ' how k-space should be sampled and'
                                  ' collected.')
    kspace.add_option('', '--k-sampling',
                      metavar='STYLE', default='isotropic',
                      help='Possible values are "isotropic" (default) '
                      'for sampling isotropic systems (as liquids), '
                      '"line" to sample uniformely along a certain direction '
                      'in k-space, '
                      '"explicit" for sampling on an explicit set of k-points')
    kspace.add_option('', '--k-bins',
                      metavar='BINS', type='int', default=80,
                      help='Number of "radial" bins to use (between 0 and '
                      'largest |k|-value) when collecting resulting '
                      'average. Default value is 80.')
    parser.add_option_group(kspace)

    kiso = optparse.OptionGroup(parser,
                                'Isotropic k-space sampling')
    kiso.add_option('', '--max-k-points', metavar='KPOINTS', type='int',
                    default=20000,
                    help='Maximum number of points used to sample k-space. '
                    'Puts an (approximate) upper limit by randomly selecting. '
                    'points. Default value is 20000.')
    kiso.add_option('', '--k-max', metavar='KMAX', type='float', default=60,
                    help='Largest k-value to consider (in "2*pi*nm^-1"). '
                    'Default value for KMAX is 60. ')
    parser.add_option_group(kiso)

    kline = optparse.OptionGroup(parser,
                                 'Line-style k-space sampling')
    kline.add_option('', '--k-direction', metavar='KDIRECTION',
                     help='Point in k-space to consider. '
                     'KPOINTS points will be evenly placed between '
                     '0,0,0 and KDIRECTION. Given as three '
                     'comma separated values.')
    kline.add_option('', '--k-points', metavar='KPOINTS', type='int',
                     default=100,
                     help='Numbers of uniformly distributed points between '
                     '0,0,0 and KDIRECTION to consider. Default value is 100.')
    parser.add_option_group(kline)

    # kexpl = optparse.OptionGroup(parser,
    #                             'Explicit k-space sampling')
    # kexpl.add_option('','--k-points-file', metavar='KPOINTS-FILE',
    #                  help='KPOINTS-FILE should contain each kpoint to '
    #                  'consider. ')
    # parser.add_option_group(kexpl)

    tgroup = optparse.OptionGroup(parser, 'Time-related options',
                                  'Options controlling timestep,'
                                  ' length and shape of trajectory frame'
                                  ' window, etc.')
    tgroup.add_option('', '--nt', metavar='TIME_CORR_STEPS', type='int',
                      help='Determines the length of the trajectory frame '
                      'window to use for time correlation calculation. '
                      'TIME_CORR_STEPS is expressed in "number of frames" and '
                      'e.g. determines the smallest frequency resolved. '
                      'If no TIME_CORR_STEPS is provided, only static (t=0) '
                      'correlations will be calculated')
    tgroup.add_option('', '--max-frames', metavar='FRAMES', type='int',
                      default=100,
                      help='Limits the total number of trajectory frames read '
                      'to FRAMES. '
                      'The default value is set to 100.')
    tgroup.add_option('', '--step', metavar='STEP', type='int', default=1,
                      help='Only use every (STEP)th trajectory frame.'
                      ' Default STEP is 1, meaning every frame is processed.'
                      ' STEP affects dt and hence the smallest time resolved.')
    tgroup.add_option('', '--stride', metavar='STRIDE', type='int', default=1,
                      help='Stride STRIDE frames between consecutive'
                      ' trajectory windows. This does not affect dt.'
                      ' If e.g. STRIDE > TIME_CORR_STEPS, some frames will'
                      ' be completely unused.')
    tgroup.add_option('', '--dt', metavar='DELTATIME',
                      help='Explicitly sets the time difference between two '
                      'consecutively processed trajectory frames to DELTATIME'
                      ' (femtoseconds). Useful when no time step information'
                      ' can be extracted from the trajectory file (e.g. when'
                      ' using molfileplugin).')
    parser.add_option_group(tgroup)

    # Named `general` (not `options`) so that it is not confused with the
    # parsed option values returned by parse_args() below.
    general = optparse.OptionGroup(parser, 'General processing options')
    general.add_option('', '--calculate-self',
                       action='store_true', default=False,
                       help='Calculate the self-part, F_s, ...')
    parser.add_option_group(general)

    parser.add_option('', '--threads', type='int', default=0,
                      help='Number of threads to use. '
                      'The default value is taken from OMP_NUM_THREADS if it'
                      ' is set,  otherwise it is set to the number of'
                      ' available cpus.')
    parser.add_option('-q', '--quiet', action='count', default=0,
                      help='Increase quietness (opposite of verbosity).')
    parser.add_option('-v', '--verbose', action='count', default=0,
                      help='Increase verbosity (opposite of quietness).')

    options, args = parser.parse_args()

    # ---- Logging -----------------------------------------------------------
    # Each -q raises and each -v lowers the log threshold by one level.
    quietness = options.quiet - options.verbose

    logger = logging.getLogger('dynasor')
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(r'%(levelname)s: %(message)s'))
    logger.addHandler(handler)
    if quietness < 0:
        logger.setLevel(logging.DEBUG)
    elif quietness == 0:
        logger.setLevel(logging.INFO)
    elif quietness == 1:
        logger.setLevel(logging.WARNING)
    elif quietness == 2:
        logger.setLevel(logging.ERROR)
    else:
        logger.setLevel(logging.CRITICAL)

    # ---- Validation of basic options ---------------------------------------
    if options.trajectory is None:
        logger.error('A trajectory must be specified. Use option -f.\n')
        parser.print_help()
        sys.exit(1)

    if not options.om and not options.op:
        logger.warning('No output file specified. Results will not be saved')

    # Thread count priority: --threads, then OMP_NUM_THREADS, then cpu count.
    if options.threads != 0:
        num_threads = options.threads
    elif 'OMP_NUM_THREADS' in os.environ:
        num_threads = int(os.environ['OMP_NUM_THREADS'])
    else:
        num_threads = cpu_count()
    if num_threads < 1:
        logger.error('Number of threads must be > 0')
        sys.exit(1)

    if options.stride < 1:
        logger.error('Stride must be > 0')
        sys.exit(1)

    # Affects rho_j_k
    os.environ['OMP_NUM_THREADS'] = str(num_threads)

    # ---- 2. Reference data from the two first frames -----------------------
    # Read the two first frames to set up references values
    # box size, number of different particles, time step length, etc
    try:
        f0, f1 = islice(get_itraj(options.trajectory, step=options.step),
                        2)
    except ValueError:
        # Unpacking fails when fewer than two frames could be read.
        logger.error('Failed to read two consecutive frames to determine '
                     'delta t. Is the trajectory long enough?')
        sys.exit(1)

    delta_t = f1['time'] - f0['time']

    if options.dt:
        # An explicitly given --dt overrides whatever the trajectory says.
        try:
            delta_t = float(options.dt)
        except ValueError:
            logger.error('Could not parse provided dt ({}) as float!'
                         .format(options.dt))
            sys.exit(1)
        logger.warning('-- Explicitly setting delta time to {} fs.'
                       .format(delta_t))

    elif delta_t == 0.0:
        logger.warning('Delta time is 0.0 (meaning, cannot read info'
                       ' from trajectory)!')
        logger.warning('-- Defaulting to 1 fs'
                       ' [this will affect the output scale]')
        delta_t = 1

    if options.nt:
        if options.nt < 1:
            logger.error('TIME_CORR_STEPS (--nt) must be >= 1')
            sys.exit(1)
        # Round the window width up to the nearest odd number of frames.
        N_tc = options.nt + (options.nt + 1) % 2
    else:
        # Static (t=0) correlations only.
        N_tc = 1

    if N_tc > 1:
        logger.info('-- delta_t found to be {:f} [fs], time window {} [fs]'
                    .format(delta_t, delta_t * (options.nt - 1)))
        dw = two_pi / (options.nt * delta_t)
        logger.info('-- delta_omega, max_omega = {}, {} [fs^-1]'
                    .format(dw, dw * options.nt))

    # Current correlations can only be calculated if the trajectory reader
    # provides velocities.
    calculate_current = 'v' in f0

    # Should the particle self correlations be calculated?
    calculate_self = bool(options.calculate_self)

    index = section_index(options.index, f0['N'])

    a, b, c = reference_box = f0['box']
    reference_volume = abs(np.dot(np.cross(a, b), c))
    particle_types = index.get_section_names()
    # A real list is needed; the counts are iterated and indexed repeatedly.
    particle_counts = [len(indices) for indices in index.get_section_indices()]
    particle_densities = [n / reference_volume for n in particle_counts]

    logger.info('Trajectory file: %s' % options.trajectory)
    logger.info('-- With a total of %i particles, %i types.' % (
            f0['N'], len(particle_types)))
    for i, t in enumerate(particle_types):
        logger.info('-- %d. %d %s' % (i + 1, particle_counts[i], t))

    logger.info('Simulation box is\n%s' % str(reference_box))

    # ---- 3. k-space sampling and trajectory window -------------------------
    style = options.k_sampling
    if style not in ('isotropic', 'line'):
        logger.error('Unknown style %s' % style)
        sys.exit(1)

    if style == 'line':
        # Sample on points along a line in k-space
        if options.k_direction is None:
            logger.error('k-direction must be specified in line mode.')
            sys.exit(1)

        try:
            k_direction = np.array(
                [float(component)
                 for component in options.k_direction.split(',')])
            if k_direction.shape != (3,):
                raise ValueError('expected exactly three components')
        except ValueError:
            logger.error('k-direction must be specified as a valid comma '
                         'separated three-tuple.')
            sys.exit(1)

        rec = reciprocal_line(points=options.k_points,
                              k_direction=k_direction)

    elif style == 'isotropic':
        # Sample k-space without preference to direction
        rec = reciprocal_isotropic(reference_box,
                                   max_points=options.max_k_points,
                                   max_k=options.k_max)

    if len(rec.k_distance) > 1:
        logger.info('N kpoints = %i' % len(rec.k_distance))
    else:
        logger.warning('N kpoints = %i' % len(rec.k_distance))
    logger.info('k_max = {} --> x_min = {} [nm]'
                .format(rec.max_k, two_pi / rec.max_k))

    # function to use to "calculate rho(k)"
    process_frame = rec.get_frame_process_function()
    # function to split particles into different index groups (types);
    # prerequisite for process_frame
    split_sections = index.get_section_split_function()

    def element_processor(frame):
        # Applied to each frame considered.
        return process_frame(split_sections(frame))

    # The trajectory window iterator
    itraj_window = iwindow(get_itraj(options.trajectory,
                                     step=options.step,
                                     max_frames=options.max_frames),
                           width=N_tc,
                           stride=options.stride,
                           element_processor=element_processor)

    # TODO....
    # * Assert box is not changed during consecutive frames

    Ntypes = len(particle_types)

    # All unordered pairs of particle types (i <= j), each tagged with a
    # running pair number m.
    type_pairs = [(i, j) for i in range(Ntypes) for j in range(i, Ntypes)]
    pair_list = [(m, i, j) for m, (i, j) in enumerate(type_pairs)]
    pair_types = [particle_types[i] + '-' + particle_types[j]
                  for _, i, j in pair_list]

    # ---- 4. Accumulate correlations ----------------------------------------
    # One accumulator per k-point; the k-point axis is the one that is later
    # binned against rec.k_distance.
    z = np.zeros(len(rec.k_distance))
    F_k_t_avs = [averager(N_tc, z) for _ in pair_list]
    if calculate_current:
        Cl_k_t_avs = [averager(N_tc, z) for _ in pair_list]
        Ct_k_t_avs = [averager(N_tc, z) for _ in pair_list]
    if calculate_self:
        F_s_k_t_avs = [averager(N_tc, z) for _ in particle_types]

    def calc_corr(window, time_i):
        # Calculate correlations between the first frame of the window and
        # the frame `time_i` steps later.  Each call only touches averager
        # slot `time_i`.
        frame_0 = window[0]
        frame_i = window[time_i]
        for m, i, j in pair_list:
            F_k_t_avs[m][time_i] += np.real(
                frame_0['rho_ks'][i] * frame_i['rho_ks'][j].conjugate())

        if calculate_current:
            for m, i, j in pair_list:
                Cl_k_t_avs[m][time_i] += np.real(
                    frame_0['jz_ks'][i] * frame_i['jz_ks'][j].conjugate())
                Ct_k_t_avs[m][time_i] += 0.5 * \
                    np.real(np.sum(frame_0['jper_ks'][i]
                                   * frame_i['jper_ks'][j].conjugate(),
                                   axis=0))

        if calculate_self:
            displacements = [(xi - x0)
                             for xi, x0 in zip(frame_i['xs'], frame_0['xs'])]
            for i, F_s in enumerate(rec.process_specific_xs(displacements)):
                F_s_k_t_avs[i][time_i] += np.real(F_s)

    # This is the "main loop"
    for window in itraj_window:
        logger.debug("processing window step %i to %i" % (window[0]['index'],
                                                          window[-1]['index']))
        # Have num_threads threads concurrently process the window
        foreach(partial(calc_corr, window),
                range(len(window)), threads=num_threads)

    # ---- 5. Post-processing and output -------------------------------------
    # Extract correlation (all k-point) averages
    # and calculate average per 'radial' bin
    k_bins = fixed_bin_averager(rec.max_k, options.k_bins, rec.k_distance)
    k_bin_averager = partial(k_bins.bin, axis=1)

    # Lists (not lazy map objects): the results are indexed and iterated
    # several times below.
    F_k_t = [k_bin_averager(F.get_av()) for F in F_k_t_avs]

    if calculate_current:
        Cl_k_t = [k_bin_averager(C.get_av()) for C in Cl_k_t_avs]
        Ct_k_t = [k_bin_averager(C.get_av()) for C in Ct_k_t_avs]

    if calculate_self:
        F_s_k_t = [k_bin_averager(C.get_av()) for C in F_s_k_t_avs]
        for i, N in enumerate(particle_counts):
            F_s_k_t[i] *= (1.0 / np.sqrt(N))

    t = delta_t * np.arange(N_tc)
    k = k_bins.x.copy()
    k_bin_count = k_bins.bin_count.copy()

    # `output` is a list of (data, name, description) tuples handed to the
    # file writers.
    output = []
    output += [(k, 'k', 'k-values (technically, bin centers) [nm^-1]'),
               (t, 't', 'time values [fs]'),
               (k_bin_count, 'k_bin_count', 'Number of k-points per bin')]
    output += [(F_k_t[m], 'F_k_t_%i_%i' % (i, j),
                'Partial intermediate scattering function [time, k] ({})'
                .format(pair_types[m]))
               for m, i, j in pair_list]

    if calculate_current:
        output += [(Cl_k_t[m], 'Cl_k_t_%i_%i' % (i, j),
                    'Longitudinal current correlation [time, k] ({})'
                    .format(pair_types[m]))
                   for m, i, j in pair_list]
        output += [(Ct_k_t[m], 'Ct_k_t_%i_%i' % (i, j),
                    'Transversal current correlation [time, k] ({})'
                    .format(pair_types[m]))
                   for m, i, j in pair_list]

    if calculate_self:
        output += [(F_s_k_t[i], 'F_s_k_t_%i' % i,
                    'Self part of intermediate scattring function [time, k]')
                   for i in range(index.N_sections())]

    if len(k) > 1:
        # Create an odd number of linearly spaced k-points, ranging from
        # the "distance" of the smallest non-empty bin and up.
        k_ = k_bins.x_linspace
        k_ = k_[k_ >= k[1]]
        k_ = k_[k_ <= k[-1]]
        if not len(k_) % 2:
            k_ = k_[:-1]

        dr = two_pi / k[-1]
        r = np.arange(5 * dr, pi / k[1], dr)

        def F_to_G(F, pair_index):
            # Sine transform of k * (F - 1), evaluated on the r grid and
            # scaled with the density of the second type of the pair.
            _, _, j = pair_list[pair_index]
            f = 1 / (r * 2 * pi ** 2 * particle_densities[j])
            kF_ = k_ * interp1d(k, F - 1)(k_)
            return f * filon.sin_integral(kF_, k_[1] - k_[0], r,
                                          k_[0], axis=1) + 1
        G_r_t = [F_to_G(F, i) for i, F in enumerate(F_k_t)]

        output += [(r, 'r', 'r-values for calculated G(r,t) [nm]')]
        output += [(G_r_t[m], 'G_r_t_%i_%i' % (i, j),
                    'Calculated partial van Hove function [time, r] ({})'
                    .format(pair_types[m]))
                   for m, i, j in pair_list]

    if len(t) > 2:
        # Time -> frequency transforms; every transform shares the same
        # omega grid, so only the first one is kept.
        w, S_k_w = zip(*[filon.fourier_cos(F, delta_t) for F in F_k_t])
        w = w[0]
        output += [(w, 'w', 'omega [fs^-1]')]
        output += [(S_k_w[m], 'S_k_w_%i_%i' % (i, j),
                    'Partial dynamical structure factor [omega, k] ({})'
                    .format(pair_types[m]))
                   for m, i, j in pair_list]

        if calculate_current:
            _, Cl_k_w = zip(*[filon.fourier_cos(C, delta_t) for C in Cl_k_t])
            _, Ct_k_w = zip(*[filon.fourier_cos(C, delta_t) for C in Ct_k_t])
            output += [(Cl_k_w[m], 'Cl_k_w_%i_%i' % (i, j),
                        'Longitudinal partial current correlation'
                        ' [omega, k] ({})'.format(pair_types[m]))
                       for m, i, j in pair_list]
            output += [(Ct_k_w[m], 'Ct_k_w_%i_%i' % (i, j),
                        'Transversal partial current correlation'
                        ' [omega, k] ({})'.format(pair_types[m]))
                       for m, i, j in pair_list]

        if calculate_self:
            _, S_s_k_w = zip(*[filon.fourier_cos(F, delta_t) for F in F_s_k_t])
            output += [(S_s_k_w[i], 'S_s_k_w_%i' % i,
                        'Self part of partial dynamical structure factor'
                        ' [omega, k]')
                       for i, _ in enumerate(particle_types)]

    # Write whichever output files were requested.
    comment = 'Command line: ' + ' '.join(sys.argv)
    if options.om:
        create_mfile(options.om, output, comment=comment)
    if options.op:
        create_pfile(options.op, output)
