#!/home/travis/miniconda/bin/python

import numpy as np
from phonopy.phonon.tetrahedron_mesh import TetrahedronMesh
from phonopy.harmonic.force_constants import similarity_transformation
from phono3py.phonon3.triplets import (get_ir_grid_points,
                                       get_grid_points_by_rotations)

# Small tolerance: pads the upper bound of the sampling range in KappaDOS and
# is the threshold used when matching --temperature to a stored temperature.
epsilon = 1.0e-8

def fracval(frac):
    """Convert a numeric string, optionally written as ``a/b``, to a float.

    Without a slash the whole string is parsed with ``float``. With a
    slash, the first two slash-separated fields are parsed as floats and
    divided (any further fields are ignored).
    """
    if '/' not in frac:
        return float(frac)

    fields = frac.split('/')
    numerator = float(fields[0])
    denominator = float(fields[1])
    return numerator / denominator

class KappaDOS(object):
    """Distribute mode kappa over sampling points with the tetrahedron method.

    The result array ``kdos`` has shape
    ``(len(mode_kappa), num_sampling_points, 2, 6)`` and is indexed as
    ``kdos[temp, sampling_point, IJ, tensor_elem]``, where index 0 of the
    ``IJ`` axis holds the values accumulated with ``TetrahedronMesh`` value
    ``'J'`` and index 1 those accumulated with ``'I'`` (presumably the
    cumulative quantity and its derivative -- confirm against phonopy).

    Parameters
    ----------
    mode_kappa : ndarray
        Indexed as ``mode_kappa[temp, ir_gp, band, tensor_elem]``.
    cell, grid_address, grid_mapping_table, ir_grid_points :
        Passed through unchanged to ``TetrahedronMesh``.
    frequencies : ndarray
        Quantity defining the sampling axis; its minimum and maximum set the
        range of the sampling points. Also passed to ``TetrahedronMesh``.
    mesh : array_like
        Mesh numbers; passed to ``TetrahedronMesh`` and its product is used
        to scale the final result.
    grid_order : optional
        Unused; kept for interface compatibility.
    num_sampling_points : int
        Number of evenly spaced sampling points.
    """

    def __init__(self,
                 mode_kappa,
                 cell,
                 frequencies,
                 mesh,
                 grid_address,
                 grid_mapping_table,
                 ir_grid_points,
                 grid_order=None,
                 num_sampling_points=100):
        self._mode_kappa = mode_kappa
        # Keep the mesh on the instance: the final scaling must not depend
        # on a module-level ``mesh`` that exists only when run as a script.
        self._mesh = mesh
        self._tetrahedron_mesh = TetrahedronMesh(
            cell,
            frequencies,
            mesh,
            grid_address,
            grid_mapping_table,
            ir_grid_points)

        # NumPy reductions instead of the builtin min/max over a flat copy.
        min_freq = np.min(frequencies)
        # The upper bound is padded by epsilon so the last sampling point
        # lies just above the largest value.
        max_freq = np.max(frequencies) + epsilon
        self._frequency_points = np.linspace(min_freq,
                                             max_freq,
                                             num_sampling_points)
        self._kdos = np.zeros(
            (len(mode_kappa), len(self._frequency_points), 2, 6),
            dtype='double')
        self._run_tetrahedron_method()

    def get_kdos(self):
        """Return the tuple ``(sampling_points, kdos)``."""
        return self._frequency_points, self._kdos

    def _run_tetrahedron_method(self):
        """Fill ``self._kdos`` by summing weighted mode kappa over grid points."""
        thm = self._tetrahedron_mesh
        for j, value in enumerate(('J', 'I')):
            thm.set(value=value, frequency_points=self._frequency_points)
            for i, iw in enumerate(thm):
                # kdos[temp, freq_points, IJ, tensor_elem]
                # iw[freq_points, band]
                # mode_kappa[temp, ir_gp, band, tensor_elem]
                # dot() contracts the band axis and yields
                # [freq_points, temp, tensor_elem]; the transpose moves the
                # temperature axis to the front to match kdos.
                self._kdos[:, :, j] += np.transpose(
                    np.dot(iw, self._mode_kappa[:, i]), axes=(1, 0, 2))
        self._kdos *= np.prod(self._mesh)

def get_gv_by_gv(gv,
                 symmetry,
                 primitive,
                 mesh,
                 grid_points,
                 grid_address):
    """Return symmetry-summed group-velocity outer products per grid point.

    For the i-th entry of ``grid_points`` the band velocities ``gv[i]`` are
    transformed by every rotation obtained from the reciprocal operations
    of ``symmetry``, the 3x3 outer products of each transformed velocity
    with itself are summed, and the sum is divided by the integer ratio of
    the number of rotated grid points to the number of distinct ones.

    The returned array has shape ``(gv.shape[0], gv.shape[1], 6)``; the last
    axis holds tensor elements (0,0), (1,1), (2,2), (1,2), (0,2), (0,1).
    """
    reciprocal_ops = symmetry.get_reciprocal_operations()
    inv_cell = np.linalg.inv(primitive.get_cell())
    cart_rotations = np.array(
        [similarity_transformation(inv_cell, op) for op in reciprocal_ops],
        dtype='double')

    # Row/column indices picking the six unique elements of a symmetric
    # 3x3 tensor in the order (0,0), (1,1), (2,2), (1,2), (0,2), (0,1).
    rows = [0, 1, 2, 1, 0, 0]
    cols = [0, 1, 2, 2, 2, 1]

    n_band = gv.shape[1]
    result = np.zeros((gv.shape[0], n_band, 6), dtype='double')
    for idx, gp in enumerate(grid_points):
        mapped = get_grid_points_by_rotations(
            grid_address[gp],
            reciprocal_ops,
            mesh)
        # How many rotations map onto each distinct rotated grid point.
        multiplicity = len(mapped) // len(np.unique(mapped))

        tensor_sum = np.zeros((n_band, 3, 3), dtype='double')
        for rot in cart_rotations:
            rotated = np.dot(gv[idx], rot.T)
            # Per-band outer product v (x) v via broadcasting.
            tensor_sum += rotated[:, :, None] * rotated[:, None, :]
        tensor_sum /= multiplicity

        result[idx] = tensor_sum[:, rows, cols]

    return result


if __name__ == '__main__':
    # Incremental kappa with respect to frequency and the derivative.
    #
    # Positional arguments: a VASP structure file followed by an HDF5 file
    # holding the datasets read below ('temperature', 'weight', and
    # depending on the options 'mode_kappa', 'frequency', 'gamma',
    # 'group_velocity', 'gv_by_gv', 'mesh').
    #
    # This is intentionally kept as a flat script: ``mesh`` must remain a
    # module-level name.

    import argparse
    import sys

    import h5py
    from phonopy.interface.vasp import read_vasp
    from phonopy.structure.cells import get_primitive
    from phonopy.structure.symmetry import Symmetry
    from phonopy.structure.grid_points import GridPoints

    parser = argparse.ArgumentParser(
        description=("Incremental kappa with respect to frequency "
                     "and the derivative"))
    parser.add_argument("--pa", dest="primitive_axis",
                        default="1 0 0 0 1 0 0 0 1", help="Primitive matrix")
    parser.add_argument("--mesh", dest="mesh", default="1 1 1",
                        help="Mesh numbers")
    parser.add_argument('--gv', action='store_true',
                        help='Calculate gv_x_gv instead of kappa')
    parser.add_argument('--mfp', action='store_true',
                        help='Mean free path is used instead of frequency')
    parser.add_argument('--temperature', type=float, dest='temperature',
                        help='Temperature to output data at')
    parser.add_argument('--nsp', '--num_sampling_points', type=int,
                        dest='num_sampling_points',
                        default=100,
                        help='Number of sampling points in MFP axis')
    parser.add_argument('--average', action='store_true',
                        help=("Output the traces of the tensors divided by 3 "
                              "rather than the unique elements"))
    parser.add_argument('--trace', action='store_true',
                        help=("Output the traces of the tensors "
                              "rather than the unique elements"))
    parser.add_argument('filenames', nargs='*')
    args = parser.parse_args()

    # Both positional files are indexed below; fail with a usage message
    # instead of an IndexError traceback when one is missing.
    if len(args.filenames) < 2:
        parser.error("two files are required: "
                     "a VASP structure file and an HDF5 file")

    cell = read_vasp(args.filenames[0])
    primitive = get_primitive(cell, np.reshape(
        [fracval(x) for x in args.primitive_axis.split()], (3, 3)))

    # Open read-only and close the file as soon as all datasets are loaded.
    with h5py.File(args.filenames[1], 'r') as f:
        if 'mesh' in f:
            mesh = np.array(f['mesh'][:], dtype='intc')
        else:
            mesh = np.array([int(x) for x in args.mesh.split()],
                            dtype='intc')

        temperatures = f['temperature'][:]
        weights = f['weight'][:]

        symmetry = Symmetry(primitive)
        rotations = symmetry.get_pointgroup_operations()
        (ir_grid_points,
         weights_for_check,
         grid_address,
         grid_mapping_table) = get_ir_grid_points(mesh, rotations)

        # array_equal also handles differing shapes (e.g. a wrong --pa or
        # --mesh changes the number of irreducible points), where an
        # element-wise ``!=`` cannot be evaluated.
        if not np.array_equal(weights, weights_for_check):
            print("*** kaccum exited ***")
            print("Something wrong in crystal symmetry. "
                  "Please check --pa or --mesh")
            sys.exit(1)

        if args.gv:
            if 'gv_by_gv' in f:
                gv_sum2 = f['gv_by_gv'][:]
            else:  # For backward compatibility. This will be removed someday.
                gv = f['group_velocity'][:]
                gv_sum2 = get_gv_by_gv(gv,
                                       symmetry,
                                       primitive,
                                       mesh,
                                       ir_grid_points,
                                       grid_address)
            unit_conversion = primitive.get_volume() * np.prod(mesh)
            # Add a leading axis of length 1 in place of the temperature axis.
            mode_kappa = (gv_sum2.reshape((1,) + gv_sum2.shape)
                          / unit_conversion)
        else:
            mode_kappa = f['mode_kappa'][:]

        if args.mfp:
            g = f['gamma'][:]
            # Non-positive gamma is replaced by -1 so the division below is
            # never by zero; those entries get a mean free path of 0.
            g = np.where(g > 0, g, -1)
            gv_norm = np.sqrt((f['group_velocity'][:] ** 2).sum(axis=2))
            # mean free path though the variable name is frequency...
            frequencies = np.where(g > 0, gv_norm / (2 * 2 * np.pi * g), 0)
        else:
            frequencies = f['frequency'][:]

    if args.temperature is not None and mode_kappa.shape[0] > 1:
        for i, t in enumerate(temperatures):
            if np.abs(t - args.temperature) < epsilon:
                temperatures = temperatures[i:i+1]
                mode_kappa = mode_kappa[i:i+1, :, :]
                if len(frequencies.shape) == 3:  # For mean free path
                    frequencies = frequencies[i:i+1]
                # Stop at the first match; the arrays were just re-sliced.
                break

    if args.mfp:
        # NOTE(review): with --gv, mode_kappa has a single leading entry
        # while ``frequencies`` has one per temperature, so mode_kappa[i:i+1]
        # looks empty for i > 0 -- confirm whether --gv --mfp is supported.
        kdos = []
        sampling_points = []
        for i, mfp in enumerate(frequencies):
            kappa_dos = KappaDOS(mode_kappa[i:i+1, :, :],
                                 primitive,
                                 mfp,
                                 mesh,
                                 grid_address,
                                 grid_mapping_table,
                                 ir_grid_points,
                                 num_sampling_points=args.num_sampling_points)
            sampling_points_at_T, kdos_at_T = kappa_dos.get_kdos()
            kdos.append(kdos_at_T[0])
            sampling_points.append(sampling_points_at_T)
        kdos = np.array(kdos)
        sampling_points = np.array(sampling_points)
    else:
        # The option is stored under dest='num_sampling_points'; ``args.nsp``
        # does not exist.
        kappa_dos = KappaDOS(mode_kappa,
                             primitive,
                             frequencies,
                             mesh,
                             grid_address,
                             grid_mapping_table,
                             ir_grid_points,
                             num_sampling_points=args.num_sampling_points)
        freq_points, kdos = kappa_dos.get_kdos()
        sampling_points = np.tile(freq_points, (len(kdos), 1))

    for i, kdos_t in enumerate(kdos):
        if not args.gv:
            print("# %d K" % temperatures[i])

        # k[0] and k[1] are the two entries of the IJ axis of kdos; the
        # first three tensor elements are summed for the trace.
        for x, k in zip(sampling_points[i], kdos_t):
            if args.average:
                print(("%13.5f " * 3) %
                      (x, k[0][:3].sum() / 3, k[1][:3].sum() / 3))
            elif args.trace:
                print(("%13.5f " * 3) % (x, k[0][:3].sum(), k[1][:3].sum()))
            else:
                print(("%f " * 13) % ((x,) + tuple(k[0]) + tuple(k[1])))

        print('')
        print('')
