#!/home/travis/miniconda/bin/python

import math
import numpy as np
from phonopy.cui.settings import fracval
from phonopy.phonon.tetrahedron_mesh import TetrahedronMesh
from phonopy.phonon.dos import NormalDistribution
from anharmonic.phonon3.triplets import get_ir_grid_points

# Small tolerance shared by this module: it is added to the maximum
# frequency when the frequency sampling points are generated, and it is the
# threshold used when matching a requested temperature against stored ones.
epsilon = 1.0e-8

class GammaDOS(object):
    """Common state for a distribution of gamma sampled over frequency.

    The constructor stores its inputs, builds an evenly spaced set of
    frequency points and allocates a zero-filled result array of shape
    ``(len(gamma), num_fpoints, 2)``. The subclasses in this file fill
    that array in; this base class performs no accumulation itself.
    """

    def __init__(self,
                 gamma,
                 frequencies,
                 ir_grid_weights,
                 num_fpoints=200):
        """Store the inputs and allocate the result array.

        Parameters
        ----------
        gamma : sequence
            Quantity to be distributed; only its length is used here, as
            the size of the first axis of the result array.
        frequencies : array_like
            Frequencies whose minimum and maximum bound the sampling range.
        ir_grid_weights : array_like
            Stored for use by subclasses; not read in this class.
        num_fpoints : int, optional
            Number of frequency sampling points.
        """
        self._gamma = gamma
        self._frequencies = frequencies
        self._ir_grid_weights = ir_grid_weights
        self._num_fpoints = num_fpoints
        self._set_frequency_points()
        result_shape = (len(gamma), len(self._frequency_points), 2)
        self._gdos = np.zeros(result_shape, dtype='double')

    def get_gdos(self):
        """Return the tuple ``(frequency_points, gdos)``."""
        return self._frequency_points, self._gdos

    def _set_frequency_points(self):
        """Sample the frequency range uniformly with ``num_fpoints`` points.

        The upper bound is the largest frequency plus ``epsilon``, so the
        last sampling point lies slightly above the maximum frequency.
        """
        lower = np.min(self._frequencies)
        upper = np.max(self._frequencies) + epsilon
        self._frequency_points = np.linspace(
            lower, upper, self._num_fpoints)

class GammaDOSsmearing(GammaDOS):
    """Distribution of gamma over frequency computed by a smearing method.

    For every frequency point, the frequencies are passed through a
    ``NormalDistribution`` of width ``sigma`` centred on that point. The
    result multiplies gamma, is contracted with the irreducible grid-point
    weights, summed, and divided by the total weight.

    Only index 1 of the last axis of the result array is written; index 0
    is left at its initial zero by this class.
    """

    def __init__(self,
                 gamma,
                 frequencies,
                 ir_grid_weights,
                 sigma=None,
                 num_fpoints=200):
        """Set up the smearing function and run the calculation.

        Parameters
        ----------
        gamma : sequence of array_like
            Iterated along its first axis; each item is multiplied
            element-wise with an array shaped like ``frequencies``.
            Presumably laid out as gamma[temp, ir_gp, band] -- confirm
            against the caller.
        frequencies : ndarray
            Frequencies; a scalar is subtracted from this array, so it
            must support array arithmetic.
        ir_grid_weights : array_like
            Weights of the irreducible grid points.
        sigma : float, optional
            Smearing width handed to ``NormalDistribution``. When None,
            one hundredth of the frequency sampling range is used.
        num_fpoints : int, optional
            Number of frequency sampling points.
        """
        GammaDOS.__init__(self,
                          gamma,
                          frequencies,
                          ir_grid_weights,
                          num_fpoints=num_fpoints)
        if sigma is None:
            self._sigma = (max(self._frequency_points) -
                           min(self._frequency_points)) / 100
        else:
            # Honour the caller's width (previously a hard-coded 0.1
            # replaced whatever value was passed in).
            self._sigma = sigma
        self._smearing_function = NormalDistribution(self._sigma)
        self._run_smearing_method()

    def _run_smearing_method(self):
        """Fill ``self._gdos[:, :, 1]`` with the smeared, weighted gamma."""
        weights = self._ir_grid_weights
        num_gp = np.sum(weights)
        for i, f in enumerate(self._frequency_points):
            # Smearing weight of every stored frequency relative to f.
            dos = self._smearing_function.calc(self._frequencies - f)
            for j, g_t in enumerate(self._gamma):
                self._gdos[j, i, 1] = np.sum(
                    np.dot(weights, dos * g_t)) / num_gp

class GammaDOStetrahedron(GammaDOS):
    """Distribution of gamma over frequency using a tetrahedron mesh.

    A ``TetrahedronMesh`` is built from the cell, frequencies and grid
    information. It is then evaluated twice, with ``value='J'`` and with
    ``value='I'``, and the results are accumulated into index 0 and
    index 1 of the last axis of the result array respectively.
    """

    def __init__(self,
                 gamma,
                 cell,
                 frequencies,
                 mesh,
                 grid_address,
                 grid_mapping_table,
                 ir_grid_points,
                 ir_grid_weights,
                 num_fpoints=200):
        """Store the inputs, build the tetrahedron mesh and accumulate.

        ``cell``, ``frequencies``, ``mesh``, ``grid_address``,
        ``grid_mapping_table`` and ``ir_grid_points`` are forwarded
        unchanged to ``TetrahedronMesh``. ``gamma`` and
        ``ir_grid_weights`` are combined with the weights that the mesh
        yields for each grid point.
        """
        super(GammaDOStetrahedron, self).__init__(gamma,
                                                  frequencies,
                                                  ir_grid_weights,
                                                  num_fpoints=num_fpoints)
        self._cell = cell
        self._mesh = mesh
        self._grid_address = grid_address
        self._grid_mapping_table = grid_mapping_table
        self._ir_grid_points = ir_grid_points

        self._set_tetrahedron_method()
        self._run_tetrahedron_method()

    def _set_tetrahedron_method(self):
        """Create the ``TetrahedronMesh`` used for the accumulation."""
        self._tetrahedron_mesh = TetrahedronMesh(self._cell,
                                                 self._frequencies,
                                                 self._mesh,
                                                 self._grid_address,
                                                 self._grid_mapping_table,
                                                 self._ir_grid_points)

    def _run_tetrahedron_method(self):
        """Accumulate weighted gamma for the 'J' and then the 'I' setting."""
        tetra = self._tetrahedron_mesh
        for column, kind in enumerate(('J', 'I')):
            tetra.set(value=kind, frequency_points=self._frequency_points)
            for gp, int_weights in enumerate(tetra):
                # Array layouts, as noted by the original author:
                #   gdos[temp, freq_points, IJ]
                #   int_weights[freq_points, band]
                #   gamma[temp, ir_gp, band]
                weighted_gamma = (self._gamma[:, gp] *
                                  self._ir_grid_weights[gp])
                self._gdos[:, :, column] += np.dot(weighted_gamma,
                                                   int_weights.T)

if __name__ == '__main__':
    # Command-line driver. Reads a crystal structure and an hdf5 data file,
    # selects gamma (or another per-mode quantity chosen by an option),
    # distributes it over frequency with GammaDOStetrahedron and prints,
    # for every frequency point, the two columns of the result.

    import argparse
    import sys

    import h5py
    from phonopy.interface.vasp import read_vasp
    from phonopy.structure.cells import get_primitive
    from phonopy.structure.symmetry import Symmetry

    parser = argparse.ArgumentParser(
        description="Accumulate gamma (or another per-mode quantity) over "
                    "frequency with the tetrahedron method")
    parser.add_argument("--pa", dest="primitive_axis",
                        default="1 0 0 0 1 0 0 0 1", help="Primitive matrix")
    parser.add_argument("--mesh", dest="mesh", default="1 1 1",
                        help="Mesh numbers")
    parser.add_argument('--pqj', action='store_true',
                        help='Calculate Pqj instead of Gamma')
    parser.add_argument('--cv', action='store_true',
                        help='Calculate Cv instead of Gamma')
    parser.add_argument('--tau', action='store_true',
                        help='Calculate lifetimes (tau) instead of Gamma')
    parser.add_argument('--gv_norm', action='store_true',
                        help='Calculate |g_v| instead of Gamma')
    parser.add_argument('--temperature', type=float, dest='temperature',
                        help='Temperature to output data at')
    parser.add_argument('--nsp', '--num_sampling_points', type=int,
                        dest='num_sampling_points',
                        default=100,
                        help='Number of sampling points on the frequency axis')
    parser.add_argument('filenames', nargs='*',
                        help='Crystal structure file in VASP format '
                             'followed by the hdf5 data file')
    args = parser.parse_args()

    # Both positional files are indexed below; fail with a usage message
    # instead of an IndexError when one is missing.
    if len(args.filenames) < 2:
        parser.error("two file names are required: "
                     "a crystal structure file and an hdf5 file")

    cell = read_vasp(args.filenames[0])
    primitive = get_primitive(cell, np.reshape(
        [fracval(x) for x in args.primitive_axis.split()], (3, 3)))

    # Open read-only and copy everything needed into memory so the file is
    # closed before the calculation starts.
    with h5py.File(args.filenames[1], 'r') as hdf5_file:
        if 'mesh' in hdf5_file:
            mesh = np.array(hdf5_file['mesh'][:], dtype='intc')
        else:
            mesh = np.array([int(x) for x in args.mesh.split()],
                            dtype='intc')
        frequencies = hdf5_file['frequency'][:]
        temperatures = hdf5_file['temperature'][:]
        weights = hdf5_file['weight'][:]

        if args.pqj:
            ave_pp = hdf5_file['ave_pp'][:]
            # Add a leading axis of length one in place of temperature.
            gamma = ave_pp.reshape((1,) + ave_pp.shape)
        elif args.cv:
            gamma = hdf5_file['heat_capacity'][:]
        elif args.tau:
            g = hdf5_file['gamma'][:]
            # Replace non-positive gamma by -1 so the division below never
            # divides by zero; those entries are then mapped to tau = 0.
            g = np.where(g > 0, g, -1)
            gamma = np.where(g > 0, 1.0 / (2.0 * g), 0)  # tau
        elif args.gv_norm:
            gamma = np.sqrt(
                (hdf5_file['group_velocity'][:, :, :] ** 2).sum(axis=2))
            gamma = gamma.reshape((1,) + gamma.shape)
        else:
            gamma = hdf5_file['gamma'][:]

    symmetry = Symmetry(primitive)
    rotations = symmetry.get_pointgroup_operations()
    (ir_grid_points,
     weights_for_check,
     grid_address,
     grid_mapping_table) = get_ir_grid_points(mesh, rotations)

    # array_equal also reports a mismatch when the two arrays differ in
    # shape, where an element-wise comparison would not be well defined.
    if not np.array_equal(weights, weights_for_check):
        print("*** gaccum exited ***")
        print("Something wrong in crystal symmetry. "
              "Please check --pa or --mesh")
        sys.exit(1)

    if args.temperature is not None and gamma.shape[0] > 1:
        for i, t in enumerate(temperatures):
            if np.abs(t - args.temperature) < epsilon:
                temperatures = temperatures[i:i+1]
                gamma = gamma[i:i+1, :, :]
                break
        else:
            # Keep the previous behaviour (all temperatures are printed)
            # but tell the user that the request could not be honoured.
            sys.stderr.write(
                "Warning: temperature %f was not found in the data; "
                "all temperatures are shown.\n" % args.temperature)

    gamma_dos = GammaDOStetrahedron(gamma,
                                    primitive,
                                    frequencies,
                                    mesh,
                                    grid_address,
                                    grid_mapping_table,
                                    ir_grid_points,
                                    weights,
                                    num_fpoints=args.num_sampling_points)

    # GammaDOSsmearing(gamma, frequencies, weights, num_fpoints=200) is the
    # smearing-based alternative to the tetrahedron method used above.

    freq_points, gdos = gamma_dos.get_gdos()

    # Columns printed per frequency point: the 'J' result (its last value
    # is checked against the weighted average below) and the 'I' result
    # (presumably the corresponding density -- confirm with
    # TetrahedronMesh).
    if args.pqj:
        for freq, g in zip(freq_points, gdos[0]):
            print("%f %e %e" % (freq, g[0], g[1]))
    else:
        for i, gdos_t in enumerate(gdos):
            total = np.dot(weights, gamma[i]).sum() / weights.sum()
            assert np.isclose(gdos_t[-1][0], total), (
                "accumulated value at the last frequency point does not "
                "match the weighted average")
            print("# %d K" % temperatures[i])
            for freq, g in zip(freq_points, gdos_t):
                print("%f %f %f" % (freq, g[0], g[1]))
            print('')
            print('')
