#!/usr/bin/env python3

###########################################################################################
#  package:   Gtools
#  file:      sw_gad2swift
#  brief:     convert gadget binary IC files to swift hdf5 IC file.
#  copyright: GPLv3
#             Copyright (C) 2019 EPFL (Ecole Polytechnique Federale de Lausanne)
#             LASTRO - Laboratory of Astrophysics of EPFL
#  author:    Loic Hausammann, Mladen Ivkovic <mladen.ivkovic@hotmail.com>
#
# This file is part of Gtools.
###########################################################################################
"""
This script converts a gadget2 (non-hdf5) type initial condition file
(e.g. made with MUSIC) to a swift type IC file.


Based on scripts written by Loic Hausammann.
Put together in this form by Mladen Ivkovic, Dec 2018
"""


from pNbody import *
from h5py import File
from sys import argv
import numpy as np
from os import path
import argparse


# number of particle types (PartType0 ... PartType5)
N_type = 6


# Help text shown by ``parse_options`` (raw formatting is preserved).
description = """
This script converts a gadget2 (non-hdf5) type initial condition file
(e.g. made with MUSIC) to a swift type IC file.
Needs pNbody to be installed.

In this current form, it will only convert gas, dm and bndry particles.
Stars are not considered.

To guarantee that the file will be used as IC,
the internal energy and SPH smoothing length will be stored as:
InternalEnergy
SmoothingLength

"""


# Usage examples appended after the option list in --help.
epilog = """
Examples:
--------
sw_gad2swift snap.dat -o snap_swift.hdf5
sw_gad2swift snap.dat
sw_gad2swift snap.dat -o snap_swift.hdf5 --disable_interactive_mode
sw_gad2swift snap.dat -o snap_swift.hdf5 --disable_interactive_mode --verbose 10
"""


def parse_options():
    """
    Parse the command line of the script.

    Returns
    -------
    argparse.Namespace
        Attributes ``filename`` (str), ``output`` (str or None),
        ``disable_interactive_mode`` (bool) and ``verbose`` (int).
        ``disable_interactive_mode`` uses ``store_false``: it is True by
        default and becomes False when the flag is given, so its value
        actually means "run interactively".
    """
    cli = argparse.ArgumentParser(
        description=description,
        epilog=epilog,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # (flags, keyword arguments), registered in the order shown by --help.
    option_table = [
        (("filename",),
         dict(action="store", type=str, default=None,
              help="Name of the input file")),
        (("-o",),
         dict(action="store", dest="output", type=str, default=None,
              help="output file name", metavar=" STRING")),
        (("--disable_interactive_mode",),
         dict(action="store_false", help="Disable interactive mode")),
        (("--verbose",),
         dict(action="store", type=int, default=0, help="Verbose mode")),
    ]
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)

    return cli.parse_args()



# ====================================
def fix_particle_types(file_out, ptype_translation, interactive=True, verbose=0):
# ====================================
    """
    Move particle groups of a SWIFT hdf5 file to other particle types.

    The file is modified in place. For every entry of ``ptype_translation``
    whose old and new types differ, all datasets of group ``PartType<old>``
    are moved into ``PartType<new>``. The moved particles are placed before
    any data already present there, and the old group is then deleted.
    Finally the particle counts in the ``Header`` attributes are recomputed
    from the ``Masses`` datasets.

    Parameters
    ----------
    file_out : str
        Path of the SWIFT hdf5 file to modify.
    ptype_translation : dict
        Maps a particle-type label (only used for printing) to a tuple
        ``(old_type, new_type)`` of integer particle types.
    interactive : bool
        Currently unused; kept for interface compatibility.
    verbose : int
        Print details about every moved dataset when > 0.
    """
    verbose = verbose > 0

    def group_name(part_type):
        """Return the hdf5 group name of integer particle type ``part_type``."""
        return "PartType%i" % part_type

    def move_type(f, ptype, old, new):
        """Move every dataset of particle type ``old`` to type ``new`` in ``f``."""
        old_group = group_name(old)

        if old_group not in f:
            if verbose:
                print("Changing particle types - cannot find group '%s'. Skipping." % old)
            return

        print("%6s : %d -> %d" % (ptype, old, new))

        new_group = group_name(new)
        src = f[old_group]
        dst = f.require_group(new_group)

        # Snapshot the member names first: deleting members of an h5py group
        # while iterating over it is not safe.
        for name in list(src.keys()):
            if verbose:
                print("Moving '%s' from '%s' to '%s'" % (name, old_group, new_group))

            data = src[name][:]
            del src[name]
            if name in dst:
                existing = dst[name][:]
                if verbose:
                    print("Found previous data:", data.shape, existing.shape)
                data = np.concatenate((data, existing), axis=0)
                del dst[name]

            if verbose:
                print("With new shape:", data.shape)

            # Pass data= so the dataset keeps the array's dtype; creating it
            # from a shape alone would store everything (e.g. integer
            # ParticleIDs, float64 coordinates) as h5py's default float32.
            dst.create_dataset(name, data=data)

        del f[old_group]

    def count_particles(f):
        """Recompute the per-type particle counts stored in the Header."""
        npart = []
        for i in range(N_type):
            name = group_name(i)
            npart.append(f[name]["Masses"].shape[0] if name in f else 0)

        header = f["Header"].attrs
        header["NumPart_ThisFile"] = npart
        header["NumPart_Total"] = npart
        header["NumPart_Total_HighWord"] = [0] * N_type

    print("Change particle type in %s" % file_out)

    # Open in append mode ("a"): the file must be writable, and recent h5py
    # versions open read-only when no mode is given.
    with File(file_out, "a") as f:
        for ptype, (old, new) in ptype_translation.items():
            if old == new:
                continue
            move_type(f, ptype, old, new)

        # Re-count particles properly
        count_particles(f)

    print("Finished changing SWIFT-type file appropriately for use.")
    return


def _ask_yes_no(question):
    """Prompt with ``question`` until the answer is y/Y or n/N; return True for yes."""
    while True:
        ans = input(question)
        if ans in ('y', 'Y'):
            return True
        if ans in ('n', 'N'):
            return False


def _crude_smoothing_lengths(nb, skip_types=(2, 3)):
    """
    Guess SPH smoothing lengths from the mean interparticle distance.

    Every particle gets ``2 * nb.boxsize / nb.nbody_tot**(1/3)``, except
    particles whose type is in ``skip_types``, which keep 0.

    NOTE(review): like the original slicing, this assumes the particles of
    ``nb`` are stored grouped by type in type order, with ``nb.npart[i]``
    particles of type i. Confirm against pNbody.
    """
    hsml = np.zeros(nb.nbody_tot)
    if nb.nbody_tot == 0:
        return hsml

    # average number of particles in 1 direction
    partx = nb.nbody_tot ** (1. / 3)
    hest = 2 * nb.boxsize / partx  # factor 2: just guessing here. Shouldn't matter

    start = 0
    for ptype, npart in enumerate(nb.npart):
        stop = start + npart
        # Always advance ``start``, even for skipped types: their particles
        # still occupy slots in the array, and not advancing would shift every
        # later type's guess onto the wrong particles.
        if ptype not in skip_types:
            hsml[start:stop] = hest
        start = stop

    return hsml


# ===================================================
def convert_gadget_to_swift(file_in, file_out, interactive=True, verbose=0):
# ===================================================
    """
    Converts a non-hdf5 type gadget2 IC file to the SWIFT format.
    (it should also work with hdf5-type gadget files,
    pNbody should figure out the initial file type by itself)

    Parameters
    ----------
    file_in : str
        Gadget IC file to read.
    file_out : str
        Name of the SWIFT file to write.
    interactive : bool
        If True, ask the user whether to change the units and the boxsize,
        and whether to estimate smoothing lengths with pNbody's tree code.
        If False, keep the values read from the file and use the crude
        smoothing-length estimate.
    verbose : int
        Verbosity passed on to pNbody's ``Nbody``.

    Returns
    -------
    dict
        Maps a particle-type label to ``(gadget_type, swift_type)``. It is
        meant to be passed to ``fix_particle_types`` to move the particle
        groups of the written file.
    """
    nb = Nbody(file_in, ftype="gadget", verbose=verbose)

    # change ftype
    nb = nb.set_ftype("swift")

    # Explicit gadget -> swift type translation. Gas and halo keep their
    # type; boundary particles (gadget type 5) become SWIFT type 2. Stars are
    # not handled.
    ptype_translation = {
        "gas": (0, 0),
        "halo": (1, 1),
        "bndry": (5, 2),
    }

    ############################################
    # Set units if necessary
    ############################################

    units = ["UnitLength_in_cm", "UnitVelocity_in_cm_per_s", "UnitMass_in_g"]
    unitd = nb.unitsparameters.get_dic()

    print("This script assumes gadget default units, which are:")
    for u in units:
        print("{0:30}{1:12.4E}".format(u, unitd[u]))

    if interactive and _ask_yes_no("Do you wish to change them manually? [y/n] "):
        for u in units:
            while True:
                inp = input("Enter a value for " + u + ": [leave empty to keep] ")
                if inp == "":
                    break
                try:
                    val = float(inp)
                except ValueError:
                    print("Didn't understand input. Try again.")
                    continue
                nb.unitsparameters.set(u, val)
                break

        print("Units are now:")
        unitd = nb.unitsparameters.get_dic()
        for u in units:
            print("{0:30}{1:12.4E}".format(u, unitd[u]))

    ############################################
    # set boxsize
    ############################################

    boxsizeguess = nb.boxsize

    print("This script made a guess for the boxsize, which is %s (in code units)" % boxsizeguess)

    if interactive:
        print("If that is not correct, the computed density parameters Omega ")
        print("will be different from the ones in your MUSIC config file, and SWIFT")
        print("might not run if you specified the wrong density parameters in your SWIFT parameter file.")

        if _ask_yes_no("Do you wish to change it manually? [y/n] "):
            while True:
                try:
                    val = float(input("Enter a value for the boxsize: "))
                    break
                except ValueError:
                    print("Didn't understand input. Try again.")
            nb.boxsize = val
            print("The boxsize is %s (in code units)" % val)
        else:
            nb.boxsize = boxsizeguess

    nb.periodic = 0
    nb.flag_entropy_ics = 0

    ############################################
    # set SPH smoothing length
    ############################################

    do_tree = False
    if interactive:
        print("SWIFT needs an initial guess for the particle smoothing lengths.")
        print("They don't need to be exact, SWIFT will fix them up, but they shouldn't")
        print("be zero. You can either do a proper estimate of the smoothing length")
        print("using the tree construction from pNbody, but that might fail for 'large' (>1GB)")
        print("IC file sizes. Or you can do a crude estimate over mean interparticle distances.")
        do_tree = _ask_yes_no("Should I try the tree code? (Might crash) [y/n] ")

    if do_tree:
        hsml = nb.get_rsp_approximation()
    else:
        print("Determine the smoothing length with a crude estimate over mean interparticle distances.")
        hsml = _crude_smoothing_lengths(nb)

    # NOTE(review): presumably pNbody's swift writer stores the *_init fields
    # as the IC SmoothingLength / InternalEnergy datasets, and the plain
    # fields are deleted so they are not written as well. Confirm in pNbody.
    nb.rsp_init = hsml
    nb.u_init = nb.u

    del nb.rsp
    del nb.u

    # write new file
    nb.rename(file_out)
    nb.write()
    print("Written SWIFT-type file ", file_out)

    return ptype_translation


def default_output_filename(file_in, suffix="-SWIFT.hdf5"):
    """
    Derive an output file name from the input file name.

    The extension of ``file_in`` is replaced by ``suffix``. The extension is
    taken from the last path component, as ``os.path.splitext`` defines it.
    Without an extension, ``suffix`` is simply appended. Dots in directory
    names and the leading dot of a hidden file are not treated as
    extensions.

    >>> default_output_filename("snap.dat")
    'snap-SWIFT.hdf5'
    """
    root, _ext = path.splitext(file_in)
    return root + suffix


# ==================================
if __name__ == "__main__":
# ==================================
    opt = parse_options()
    file_in = opt.filename
    # store_false option: True unless --disable_interactive_mode was given,
    # so this value means "run interactively".
    interactive = opt.disable_interactive_mode
    verbose = opt.verbose

    # Set output file; without -o derive it from the input file name
    if opt.output is not None:
        file_out = opt.output
    else:
        file_out = default_output_filename(file_in)

    # Convert to Swift now, then move the particle groups to SWIFT's types
    ptype_translation = convert_gadget_to_swift(file_in, file_out, interactive, verbose)
    fix_particle_types(file_out, ptype_translation, interactive, verbose)
