#!/usr/bin/env python3

###########################################################################################
#  package:   Gtools
#  file:      sw_snapshot2ICs
#  brief:     Convert a Swift snapshot into a Swift initial-conditions file
#  copyright: GPLv3
#             Copyright (C) 2019 EPFL (Ecole Polytechnique Federale de Lausanne)
#             LASTRO - Laboratory of Astrophysics of EPFL
#  author:    Yves Revaz <yves.revaz@epfl.ch>
#
# This file is part of Gtools.
###########################################################################################


import argparse
from h5py import File
import shutil

def parse_options():
    """Parse the command-line arguments of the script.

    The arguments are read from ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        Namespace with the attributes ``filename`` (input file name),
        ``output`` (value of ``-o``, ``None`` when the option is omitted)
        and ``rename_fields_only`` (``True`` when the flag is given).
    """
    # Local import: only needed to strip the source indentation of the
    # help texts below.
    import textwrap

    description = textwrap.dedent("""
    This script allows to convert Swift snapshot to initial conditions by
    applying the relevant changes in the hdf5 file.

    When to use this script? If for whatever reason you need to start with a
    Swift output snapshot, you need to convert them to proper Swift ICs. Some
    convention changes between output snapshots and ICs are explained in
    Swift documentation.

    When to rename the fields? An example, if one generates ICs with MUSIC,
    there is no output format for Swift. So, one sets the output format to
    gadget. Then, the gadget output must be converted to Swift format with
    `gad2swift` (script in Gtools package). However, this conversion leaves
    the hdf5 fields as if they were the output of Swift snapshots. To finally
    get a proper IC file, one can convert to Swift ICs with the
    sw_snapshot2ICs script and the option `--rename_fields_only`. It will
    only rename fields such as `SmoothingLengths` to `SmoothingLength` and
    does not apply the scale factor.

    An alternative script is the built-in Swift `convert_snapshot_to_ICs.py`
    in the `tools` directory of Swift (check out this:
    https://gitlab.cosma.dur.ac.uk/swift/swiftsim/-/merge_requests/1855).
    """)

    epilog = textwrap.dedent("""
    Examples:
    --------
    python3 sw_snapshot2ICs  snapshot_01000.hdf -o ICs.hdf5
    python3 sw_snapshot2ICs  snapshot_01000.hdf -o ICs.hdf5 --rename_fields_only
    """)

    # The default formatter re-flows description and epilog into a single
    # paragraph; the raw formatter keeps the paragraphs and the example
    # listing exactly as written above.
    parser = argparse.ArgumentParser(
        description=description,
        epilog=epilog,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument("filename",
                        action="store",
                        type=str,
                        default=None,
                        help="Name of the input file")

    parser.add_argument("-o",
                        action="store",
                        dest="output",
                        type=str,
                        default=None,
                        help="output file name",
                        metavar=" STRING")

    parser.add_argument("--rename_fields_only",
                        action="store_true",
                        help="Rename the fields in the hdf5 file without applying the scale factor on them.")

    return parser.parse_args()



########################################################################
# MAIN
########################################################################


def _read_scale_factor(f):
    """Return the ``Scale-factor`` attribute of the ``Cosmology`` group as a float.

    Parameters
    ----------
    f : h5py.File
        Open file containing a ``Cosmology`` group.
    """
    value = f["Cosmology"].attrs["Scale-factor"]
    # The attribute may come back as a one-element array rather than a
    # scalar; calling float() directly on such an array is deprecated in
    # recent NumPy versions, so extract the single element explicitly.
    if hasattr(value, "item"):
        return float(value.item())
    return float(value)


def _convert_fields(path, rename_fields_only):
    """Rename the snapshot fields of the hdf5 file ``path`` in place.

    For each group ``PartType0`` ... ``PartType5`` present in the file:

    * ``SmoothingLengths`` is renamed to ``SmoothingLength``;
    * ``InternalEnergies`` is replaced by ``InternalEnergy``, with the
      values divided by ``a**2``.

    ``a`` is the scale factor read from the ``Cosmology`` group, or 1 when
    that group is absent or ``rename_fields_only`` is true.

    Parameters
    ----------
    path : str
        Name of the hdf5 file to modify (opened in append mode).
    rename_fields_only : bool
        If true, the scale factor is not applied.
    """
    with File(path, "a") as f:

        # check if it is a cosmo run
        if "Cosmology" in f and not rename_fields_only:
            a = _read_scale_factor(f)
        else:
            a = 1.0

        # loop over particle types
        # NOTE(review): only PartType0..PartType5 are visited; confirm that
        # no further particle type needs the same renaming.
        for i in range(6):
            name = "PartType%i" % i
            if name not in f:
                continue

            part = f[name]

            # SmoothingLengths -> SmoothingLength
            if "SmoothingLengths" in part:
                part.move("SmoothingLengths", "SmoothingLength")

            # InternalEnergies -> InternalEnergy
            if "InternalEnergies" in part:
                if a == 1.0:
                    # Nothing to rescale: a pure rename avoids reading the
                    # whole array and keeps the dataset untouched.
                    part.move("InternalEnergies", "InternalEnergy")
                else:
                    # NOTE(review): the 1/a**2 factor is kept from the
                    # original implementation; confirm it against the Swift
                    # conventions for ICs.
                    # The rescaled values go into a newly created dataset,
                    # which does not carry over the attributes of the
                    # original one.
                    part["InternalEnergy"] = part["InternalEnergies"][:] / a**2
                    del part["InternalEnergies"]


def main():
    """Copy the input snapshot to the output file and convert the copy.

    Raises
    ------
    ValueError
        If no output file name was given with ``-o``.
    """
    opt = parse_options()

    if opt.output is None:
        raise ValueError("No output files has been provided...")

    # work on a copy: the input snapshot itself is never opened for writing
    shutil.copyfile(opt.filename, opt.output)

    _convert_fields(opt.output, opt.rename_fields_only)


if __name__ == '__main__':
    main()
