#!/usr/bin/env python

# Code: gxRealization
# Version: 1
# Version changes: GENERATE KDE FROM FIXED POPULATION TO MONTE CARLO
#                  A GALACTIC REALIZATION OF THE POPULATION ACCORDING
#                  TO FLAGS SET BY USER
#
# Edited on:  13 FEB 2018


##############################################################################
#  IMPORT ALL NECESSARY PYTHON PACKAGES
##############################################################################
from collections import OrderedDict
import warnings
import argparse

import math
import random
import time
from time import sleep
import string
import os.path

import numpy as np
import scipy.special as ss
import scipy.stats as stats
import pandas as pd
import multiprocessing as mp

import cosmic.gxreal as gxreal
import cosmic.MC_samp as MC_sample
import cosmic.GW_calcs as GW_calcs
import cosmic.utils as utils
##################################################################################
# DEFINE COMMANDLINE ARGUMENTS
##################################################################################
def parse_commandline(argv=None):
    """Parse and validate the arguments given on the command-line.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When None (the default) argparse reads
        them from ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        The parsed arguments. ``final_kstar1`` and ``final_kstar2`` are lists
        of one float (a single value) or two floats (a strictly increasing
        range).

    Raises
    ------
    SystemExit
        Raised by argparse when a required option is missing, or through
        ``parser.error`` when a kstar option has more than two values or a
        two-value range is not strictly increasing.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--galaxy_component", help="Galaxy Components. Options include Bulge ThinDisk and ThickDisk", required=True)
    # The option used to be required *and* carry a default, which made the
    # default unreachable; it is now optional so 'McMillan' actually applies.
    parser.add_argument("--dist_model", help="Galaxy distribution model. fiducial is McMillan (2010)", default='McMillan')
    parser.add_argument("--final_kstar1", help="Specify the final condition of kstar1, you want systems to end at for your samples", required=True, type=float, nargs='+')
    parser.add_argument("--final_kstar2", help="Specify the final condition of kstar2, you want systems to end at for your samples", required=True, type=float, nargs='+')
    parser.add_argument("--N_realizations", help="Number of Galactic realziations to sample", type=int, default=100)
    parser.add_argument("-n", "--nproc", help="number of processors", type=int, default=1)
    parser.add_argument("--gx_save", help="Save the galaxy realizations; MAY GENERATE LARGE DATA SETS", type=int, default=0)
    parser.add_argument("--HG_save", help="Save systems which undergo common envelope w/ HG star secondary", type=int, default=0)
    parser.add_argument("--LISA_calc", help="Save the PSD of the population in the LISA band", type=int, default=0)
    parser.add_argument("--verbose", action="store_true", default=False, help="Run in Verbose Mode")
    args = parser.parse_args(argv)

    # parser.error() prints the message and exits by itself, so no `raise`
    # is needed in front of it.
    if len(args.final_kstar1) > 2 or len(args.final_kstar2) > 2:
        parser.error('final kstar1 and final kstar2 '
                     'must be either a single value or '
                     'a range between two values.')

    # A two-value specification is a range and must be strictly increasing.
    for label, values in (('kstar1', args.final_kstar1),
                          ('kstar2', args.final_kstar2)):
        if len(values) == 2 and values[0] >= values[1]:
            parser.error('Range provided for {0} invalid'.format(label))

    return args


##################################################################################
# BEGIN MAIN FUNCTION
##################################################################################
if __name__ == '__main__':

    # LOOP OVER GALAXIES IN SAMPLE
    ###########################################################################
    def _gx_realization(gx, fixed_pop, fixed_mass, dat_list, gx_component, gx_model, LISA_calc, gx_save, Tobs, kstar1_string, kstar2_string):
        """Sample one Galactic realization and write it, plus a log, to disk.

        Two files are written relative to the current working directory:
        ``gx_real_<gx>_<gx_component>_<kstar1_string>_<kstar2_string>.h5``
        (an HDF5 store) and a progress log with the same stem, prefixed
        with ``log_`` and ending in ``.txt``.

        Parameters
        ----------
        gx : int
            Index of the realization; used as the numpy random seed and in
            the output file names.
        fixed_pop, fixed_mass, dat_list
            Forwarded unchanged to ``gxreal.GxReal``.
        gx_component, gx_model
            Forwarded to ``gxreal.GxReal``; ``gx_component`` also appears in
            the output file names.
        LISA_calc : int
            When greater than zero, ``gx_real.LISA_obs(Tobs)`` is evaluated
            and appended to the store under the key ``'PSD'``.
        gx_save
            Accepted for interface compatibility; not used by this function.
        Tobs
            Passed to ``gx_real.LISA_obs``.
        kstar1_string, kstar2_string : str
            Labels embedded in the output file names.

        Returns
        -------
        None
        """
        np.random.seed(gx)
        file_stem = 'gx_real_'+str(gx)+'_'+str(gx_component)+'_'+kstar1_string+'_'+kstar2_string

        # Context managers guarantee that both files are closed even when
        # sampling or the LISA calculation raises part-way through.
        with pd.HDFStore(file_stem+'.h5', format='table') as gx_file:
            with open('log_'+file_stem+'.txt', 'w') as gx_log_file:
                gx_real = gxreal.GxReal(fixed_pop, fixed_mass, gx_model, gx_component, dat_list)

                # compute the number of systems in the sample
                gx_real.n_samp = gx_real.compute_n_sample()
                gx_log_file.write('The number of systems in a Galactic realization are {0}\n'.format(gx_real.n_samp))

                if gx_real.n_samp >= 1e6:
                    # Large populations are drawn in n_loop equally sized
                    # chunks and appended one at a time. int() truncates, so
                    # up to n_loop - 1 systems of the original count are not
                    # drawn.
                    n_loop = 50
                    gx_real.n_samp = int(gx_real.n_samp/n_loop)
                    gx_log_file.write('Sampling full population in {0} chunks\n'.format(n_loop))
                    for ii in range(n_loop):
                        # Monte Carlo sample the population
                        gx_real.realization = gx_real.sample_population()
                        gx_file.append('gx_dat', gx_real.realization)
                        gx_log_file.write('Chunk {0} complete\n'.format(ii))
                        # flush so progress is visible while the run is ongoing
                        gx_log_file.flush()
                else:
                    # Monte Carlo sample the population
                    gx_real.realization = gx_real.sample_population()
                    gx_file.append('gx_dat', gx_real.realization)
                gx_log_file.write('Galaxy sampled\n')

                if LISA_calc > 0:
                    gx_real.PSD = gx_real.LISA_obs(Tobs)
                    gx_file.append('PSD', gx_real.PSD)

                    gx_log_file.write('LISA calcs finished\n')

    # READ COMMANDLINE ARGUMENTS
    ##############################################################################
    args = parse_commandline()

    # CONSTANTS
    ##############################################################################
    # Of the values below only Tobs is consumed by the code in this section;
    # it is handed to _gx_realization for the LISA calculation.
    G = 6.67384*math.pow(10, -11.0)
    c = 2.99792458*math.pow(10, 8.0)
    parsec = 3.08567758*math.pow(10, 16)
    Rsun = 6.955*math.pow(10, 8)
    Msun = 1.9891*math.pow(10,30)
    day = 86400.0
    rsun_in_au = 215.0954
    day_in_year = 365.242
    sec_in_day = 86400.0
    sec_in_hour = 3600.0
    hrs_in_day = 24.0
    sec_in_year = 3.15569*10**7.0
    Tobs = 4 * sec_in_year
    geo_mass = G/c**2

    # UTILS???
    # Handle the file structure for different final kstars
    ###########################################################################
    # A two-value specification is a half-open range [first, second); the
    # *_range_string labels are what the input and output file names use.
    if len(args.final_kstar1) == 2:
        kstar1_range = np.arange(args.final_kstar1[0], args.final_kstar1[1])
        kstar1_range_string = str(int(args.final_kstar1[0]))+'_'+str(int(args.final_kstar1[1]))
    else:
        kstar1_range = args.final_kstar1
        kstar1_range_string = str(int(args.final_kstar1[0]))

    if len(args.final_kstar2) == 2:
        kstar2_range = np.arange(args.final_kstar2[0], args.final_kstar2[1])
        kstar2_range_string = str(int(args.final_kstar2[0]))+'_'+str(int(args.final_kstar2[1]))
    else:
        kstar2_range = args.final_kstar2
        kstar2_range_string = str(int(args.final_kstar2[0]))

    # READ IN DATA
    ###########################################################################
    # All three tables come from one HDF5 file in the working directory:
    # dat_<galaxy_component>_<kstar1 label>_<kstar2 label>.h5
    dat_path = 'dat_'+args.galaxy_component+'_'+\
                kstar1_range_string+'_'+kstar2_range_string+'.h5'
    total_sampled_mass = np.max(pd.read_hdf(dat_path, key='totalMass'))
    fixed_pop = pd.read_hdf(dat_path, key='bcm')

    bpp = pd.read_hdf(dat_path, key='bpp')

    # Filter out systems which undergo Common Envelope w/ HG star secondary
    ###########################################################################
    # The bpp row immediately before each evol_type == 7 row is inspected;
    # rows whose kstar_2 is 2.0 or 8.0 are discarded and fixed_pop is reduced
    # to the bin_nums of the rows that remain.
    # NOTE(review): whenever at least one evol_type == 7 row exists, systems
    # with no such row are removed from fixed_pop as well -- confirm this is
    # intended. comenv_index-1 also wraps around to the last bpp row if the
    # very first row has evol_type == 7.
    if args.HG_save == 0:
        comenv_index, = np.where(bpp.evol_type==7.0)
        if len(comenv_index) > 0:
            bpp_pre_ce = bpp.iloc[comenv_index-1]
            bpp_ce_save = bpp_pre_ce.loc[bpp_pre_ce.kstar_2 != 2.0]
            bpp_ce_save = bpp_ce_save.loc[bpp_ce_save.kstar_2 != 8.0]
            index_save = bpp_ce_save.bin_num

            fixed_pop = fixed_pop.loc[fixed_pop.bin_num.isin(index_save)]

    # Filter out systems which undergo mass transfer as compact objects
    ###################################################################
    index_MT = bpp.loc[(bpp.kstar_1.isin([10,11,12,13,14])) &
               (bpp.kstar_2.isin([10,11,12])) &
               (bpp.evol_type == 3.0)].bin_num
    fixed_pop = fixed_pop.loc[~fixed_pop.bin_num.isin(index_MT)]

    # Leave eccentricity out of the sampled parameters only when *every*
    # system has ecc <= 1e-6. The comparison has to be applied element-wise
    # before reducing with all(); reducing first yields a single boolean and
    # compares that boolean to the threshold instead.
    if (fixed_pop.ecc <= 1e-6).all():
        dat_list = ['mass_1', 'mass_2', 'porb']
    else:
        dat_list = ['mass_1', 'mass_2', 'porb', 'ecc']

    # Sample the Milky Way realizations and perform observability calculations
    ###########################################################################
    if args.nproc > 1:
        # Run the realizations in batches of at most nproc processes. Stepping
        # through the indices (rather than dividing N_realizations by nproc)
        # keeps the loop bound an integer and also runs the final partial
        # batch when N_realizations is not a multiple of nproc.
        # NOTE(review): _gx_realization is defined under the __main__ guard,
        # so this relies on the 'fork' start method; under 'spawn' the child
        # processes would not find it -- confirm the target platforms.
        for batch_start in range(0, args.N_realizations, args.nproc):
            batch_stop = min(batch_start + args.nproc, args.N_realizations)
            # Positional order must match the _gx_realization signature:
            # (gx, fixed_pop, fixed_mass, dat_list, ...).
            processes = [mp.Process(target = _gx_realization,\
                                    args = (gx, fixed_pop,
                                            total_sampled_mass, dat_list, args.galaxy_component,
                                            args.dist_model, args.LISA_calc, args.gx_save, Tobs,
                                            kstar1_range_string, kstar2_range_string))
                                    for gx in range(batch_start, batch_stop)]

            for p in processes:
                p.daemon = True
                p.start()

            # wait for the whole batch before launching the next one
            for p in processes:
                p.join()
    else:
        for n in range(args.N_realizations):
            _gx_realization(n, fixed_pop, total_sampled_mass, dat_list, args.galaxy_component, args.dist_model, args.LISA_calc, args.gx_save, Tobs, kstar1_range_string, kstar2_range_string)

