#!/usr/bin/env python3

import argparse
import illustris_python as il
from pNbody import *
import h5py
import numpy as np
from tqdm import tqdm

####################################################################
# option parser
####################################################################

description = "extract stars inside an illustris subhalo"
epilog = """
Examples:
--------
il_getStars --basePath /path/to/TNG50-1/output --snapNum 99 --ID 513845 -o TNG50_513845_stars.hdf5
il_getStars --basePath /path/to/TNG50-1/output --snapNum 99 --ID 513845 --addSubhaloes -o TNG50_513845_stars.hdf5

"""

# RawDescriptionHelpFormatter keeps the line layout of the epilog examples.
parser = argparse.ArgumentParser(
    description=description,
    epilog=epilog,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)

parser.add_argument(
    "--ID",
    action="store",
    dest="ID",
    metavar="INT",
    type=int,
    default=0,
    help="subhalo ID",
)

# --basePath and --snapNum have no sensible default: the snapshot (and the
# group catalogue) cannot be located without them, so let argparse reject
# the call up front instead of failing later while building file paths.
parser.add_argument(
    "--basePath",
    action="store",
    dest="basePath",
    metavar="STR",
    type=str,
    required=True,
    help="base path",
)

parser.add_argument(
    "--snapNum",
    action="store",
    dest="snapNum",
    metavar="INT",
    type=int,
    required=True,
    help="snapshot number",
)

parser.add_argument(
    "-o",
    action="store",
    type=str,
    dest="outputfilename",
    default=None,
    help="Name of the output file (nothing is written if omitted)",
)

parser.add_argument(
    "--addSubhaloes",
    action="store_true",
    dest="addSubhaloes",
    default=False,
    help="also extract the stars of all other subhaloes of the same group",
)

####################################################################
# main
####################################################################

# Star-particle fields read for each subhalo: exactly the ones used below,
# so the other (unused) star fields are not read from disk.
STAR_FIELDS = ["Coordinates", "Masses", "GFM_InitialMass",
               "GFM_StellarFormationTime", "GFM_Metals", "GFM_Metallicity"]


def _set_header_info(nb, header):
  """Copy time, unit, cosmology and box information from a snapshot header onto *nb*.

  nb     : pNbody Nbody object to update in place
  header : dict of the snapshot 'Header' HDF5 attributes
  """
  nb.atime                = header["Time"]                         # 0.9999999999999998
  nb.redshift             = header["Redshift"]                     # 2.220446049250313e-16

  nb.UnitLength_in_cm     = header["UnitLength_in_cm"]             # 3.085678e+21
  nb.UnitMass_in_g        = header["UnitMass_in_g"]                # 1.989e+43
  nb.UnitVelocity_in_cgs  = header["UnitVelocity_in_cm_per_s"]     # 100000.0
  nb.Unit_time_in_cgs     = nb.UnitLength_in_cm/nb.UnitVelocity_in_cgs

  nb.hubbleparam          = header["HubbleParam"]                  # 0.6774
  nb.omegabaryon          = header["OmegaBaryon"]                  # 0.0486
  nb.omega0               = header["Omega0"]                       # 0.3089
  nb.omegalambda          = header["OmegaLambda"]                  # 0.6911

  nb.cosmorun             = 1
  nb.boxsize              = header["BoxSize"]

  nb.git_commit           = header["Git_commit"] # just to tell its an arepo format


opt = parser.parse_args()


basePath = opt.basePath
snapNum  = opt.snapNum
haloID   = opt.ID


# get snapshot path and read header (units/cosmology copied onto every Nbody)
with h5py.File(il.snapshot.snapPath(basePath, snapNum), 'r') as f:
  header = dict(f['Header'].attrs.items())


if opt.addSubhaloes:
  # get all haloes of the same group
  print("reading subhalo catalog...")
  # only the group numbers are needed: do not load the whole subhalo catalogue
  grnr = il.groupcat.loadSubhalos(basePath, snapNum, fields=["SubhaloGrNr"])
  # depending on the illustris_python version, a single requested field comes
  # back either as a bare array or wrapped in a dict
  if isinstance(grnr, dict):
    grnr = grnr["SubhaloGrNr"]

  # group number of the requested subhalo, then every subhalo sharing it
  GrNr = grnr[haloID]
  IDs = np.flatnonzero(grnr == GrNr)

else:
  IDs = [haloID]


nbtot = None   # accumulated stars of all processed subhaloes

for ID in tqdm(IDs):

  # load the star particles of a subfind subhalo from its ID
  stars = il.snapshot.loadSubhalo(basePath, snapNum, ID, 'stars', fields=STAR_FIELDS)

  # subhaloes without star particles carry no "Coordinates" entry: skip them
  if "Coordinates" not in stars:
    continue

  nb = Nbody(pos=stars["Coordinates"], mass=stars["Masses"], status='new', ftype="arepo")
  nb.set_tpe(4)

  nb.minit       = stars["GFM_InitialMass"]
  nb.tstar       = stars["GFM_StellarFormationTime"]
  nb.metals      = stars["GFM_Metals"]
  nb.metallicity = stars["GFM_Metallicity"]

  _set_header_info(nb, header)

  # The first subhalo that actually has stars starts the accumulator. Testing
  # `ID == IDs[0]` instead would leave nbtot undefined whenever that first
  # subhalo is star-less.
  nbtot = nb if nbtot is None else nbtot + nb


if opt.outputfilename is not None:
  if nbtot is None:
    msg = "no star particles found in subhalo %d" % haloID
    if opt.addSubhaloes:
      msg += " or in the other subhaloes of its group"
    raise SystemExit(msg)
  nbtot.rename(opt.outputfilename)
  nbtot.write()


