#!/usr/bin/env python

# Code: runFixedPop.py
# Version: 1
# Version changes: SAMPLE FIXED POPULATION OF BINARIES AND EVOLVE WITH BSE;
#                  COMPUTE RATES AND NUMBERS FOR EACH POPULATION ACCORDING
#                  TO FLAGS SET BY USER
#
# Edited on:  8 SEP 2015


##############################################################################
#  IMPORT ALL NECESSARY PYTHON PACKAGES
##############################################################################
from collections import OrderedDict
import warnings
import argparse
from configparser import ConfigParser

import math
import random
import time
from time import sleep
import string
import os.path

import numpy as np
import scipy.special as ss
import pandas as pd


from cosmic.sample.initialbinarytable import InitialBinaryTable
from cosmic import Match
from cosmic.evolve import Evolve
from cosmic.utils import mass_min_max_select
from cosmic.utils import param_transform

###############################################################################
# DEFINE COMMANDLINE ARGUMENTS
###############################################################################
def parse_commandline(argv=None):
    """Parse and validate the command-line arguments of the script.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When ``None`` (the default) argparse
        falls back to ``sys.argv[1:]``, which is the behavior callers that
        pass no argument have always had.

    Returns
    -------
    argparse.Namespace
        The parsed arguments. ``final_kstar1`` and ``final_kstar2`` are
        lists holding either one value or a strictly increasing
        ``[low, high]`` pair.

    Raises
    ------
    SystemExit
        Raised by argparse (through ``parser.error``) when a required
        option is missing or one of the checks below fails.
    """
    # The module has no docstring, so fall back to an explicit description
    # instead of passing None to argparse.
    parser = argparse.ArgumentParser(
        description=__doc__ or "Sample a fixed population of binaries and "
                               "evolve it with BSE until it converges.")
    parser.add_argument("--inifile",
                        help="Name of ini file of params",
                        required=True)
    parser.add_argument("--final_kstar1",
                        help="Specify the final condition of kstar1 "
                        ", you want systems to end at for your samples",
                        required=True, type=int, nargs='+')
    parser.add_argument("--final_kstar2",
                        help="Specify the final condition of kstar2, you want "
                        "systems to end at for your samples",
                        required=True, type=int, nargs='+')
    parser.add_argument("--convergence-params",
                        help="A space separated list of parameters you would "
                        "like to varify you have simulated enough binaries for"
                        , nargs='+',
                        default=['mass_1', 'mass_2', 'porb', 'ecc'])
    parser.add_argument("--initial_samp",
                        help="Specify if independent binary initial "
                        "conditions: independent, or following "
                        "Moe & Di Stefano (2017): multidim",
                        default="multidim")
    # The choices mirror the components the main script knows how to map
    # onto a star formation history; anything else used to fail much later
    # with a NameError.
    parser.add_argument("--galaxy_component",
                        help="Galaxy Components. Options include "
                        "Bulge ThinDisk ThickDisk and DeltaBurst",
                        choices=['Bulge', 'ThinDisk', 'ThickDisk',
                                 'DeltaBurst'],
                        required=True)
    parser.add_argument("--metallicity", help="Metallicity of the population; "
                        "default: 0.02 (solar)", default=0.02, type=float)
    parser.add_argument("--porb_cut",
                        help="Specify an orbital period cut "
                        "in log10(sec); default: 12",
                        type=float, default=12.0)
    parser.add_argument("--Niter",
                        help="Number of iterations of binaries "
                        "to try, will check ever Nstep for convergence",
                        type=int, default=10000000)
    parser.add_argument("--Nstep",
                        help="Number of binaries to try before checking for "
                        "convergence, it will check ever Nstep binaries until "
                        "it reach Niter binaries", type=int, default=10000)
    parser.add_argument("-n", "--nproc",
                        help="number of processors", type=int, default=1)
    parser.add_argument("--verbose", action="store_true", default=False,
                        help="Run in Verbose Mode")

    args = parser.parse_args(argv)

    # parser.error() prints the message and exits via SystemExit, so it is
    # called directly rather than "raised".
    if len(args.final_kstar1) > 2 or len(args.final_kstar2) > 2:
        parser.error('final kstar1 and final kstar2 '
                     'must be either a single value or '
                     'a range between two values.')

    # A two-element kstar argument is a range and must be strictly increasing.
    for label, kstar in (('kstar1', args.final_kstar1),
                         ('kstar2', args.final_kstar2)):
        if len(kstar) == 2 and kstar[0] >= kstar[1]:
            parser.error('Range provided for {0} invalid'.format(label))

    if args.initial_samp not in ['independent', 'multidim']:
        parser.error('Initial sample must either be '
                     'independent or multidim')

    return args

###############################################################################
# BEGIN MAIN FUNCTION
###############################################################################
if __name__ == '__main__':
    # Script driver: repeatedly sample a batch of initial binaries, evolve
    # them with BSE, keep the systems whose final state matches the requested
    # kstar types, append them to an HDF5 store and stop once the Match
    # statistic of every convergence parameter is high enough.

    # READ COMMANDLINE ARGUMENTS
    ###########################################################################
    args = parse_commandline()

    # CONSTANTS
    ###########################################################################
    # Seconds per year: the only physical constant this script needs. The
    # BSE orbital period is multiplied by it before taking log10 below.
    sec_in_year = 3.15569*10**7.0

    # GALACTIC COMPONENT
    ###########################################################################
    # Star formation history model and component age handed to the sampler.
    # The choice does not change between batches, so it is resolved once
    # here, and an unknown component fails with a clear message before any
    # output file is created.
    galaxy_models = {
        'ThinDisk': ('const', 10000.0),
        'Bulge': ('burst', 10000.0),
        'ThickDisk': ('burst', 11000.0),
        'DeltaBurst': ('delta_burst', 13700.0),
    }
    if args.galaxy_component not in galaxy_models:
        raise ValueError('Unknown galaxy component: {0}; options are {1}'.format(
            args.galaxy_component, ', '.join(sorted(galaxy_models))))
    SFH_model, component_age = galaxy_models[args.galaxy_component]

    # ---- Create configuration-file-parser object and read parameters file.
    cp = ConfigParser()
    cp.optionxform = str  # keep option names case sensitive (e.g. 'CK')
    # ConfigParser.read() silently skips unreadable files; fail loudly
    # instead of hitting a confusing missing-section error further down.
    if not cp.read(args.inifile):
        raise IOError('Could not read ini file: {0}'.format(args.inifile))

    # ---- Read needed variables from [bse] section.
    # Each option is paired with the getter matching its type; the order is
    # the order in which the options end up in BSEDict.
    bse_options = [
        ('neta', cp.getfloat),
        ('bwind', cp.getfloat),
        ('hewind', cp.getfloat),
        ('alpha1', cp.getfloat),
        ('lambdaf', cp.getfloat),
        ('ceflag', cp.getint),
        ('tflag', cp.getint),
        ('ifflag', cp.getint),
        ('wdflag', cp.getint),
        ('bhflag', cp.getint),
        ('nsflag', cp.getint),
        ('mxns', cp.getfloat),
        ('pts1', cp.getfloat),
        ('pts2', cp.getfloat),
        ('pts3', cp.getfloat),
        ('sigma', cp.getfloat),
        ('beta', cp.getfloat),
        ('xi', cp.getfloat),
        ('acc2', cp.getfloat),
        ('epsnov', cp.getfloat),
        ('eddfac', cp.getfloat),
        ('gamma', cp.getfloat),
        ('bconst', cp.getint),
        ('CK', cp.getint),
        ('merger', cp.getint),
        ('windflag', cp.getint),
    ]

    # Dictionary of all BSE parameters, passed to Evolve.evolve below
    BSEDict = {option: getter('bse', option) for option, getter in bse_options}

    # ---- Read needed variables from [rand_seed] section.
    # A seed of 0 means "let numpy pick a fresh seed".
    seed_int = cp.getint('rand_seed', 'seed')
    if seed_int != 0:
        np.random.seed(seed_int)
    else:
        np.random.seed()

    # Based on the final_kstar1 and final_kstar2, select primary and secondary
    # mass ranges to evolve
    # NOTE(review): the four returned values are not used anywhere below;
    # the call is kept unchanged in case it validates the kstar input.
    primary_min, primary_max, secondary_min, secondary_max = mass_min_max_select(args.final_kstar1, args.final_kstar2)

    # Set up final_kstar1 and final_kstar2 ranges and the strings used in the
    # names of the saved data files ("13" for a single type, "10_12" for a
    # range). A two-element argument is expanded to the inclusive range.
    if len(args.final_kstar1) == 2:
        kstar1_range = np.arange(args.final_kstar1[0], args.final_kstar1[1]+1)
    else:
        kstar1_range = args.final_kstar1
    kstar1_range_string = '_'.join(str(int(k)) for k in args.final_kstar1)

    if len(args.final_kstar2) == 2:
        kstar2_range = np.arange(args.final_kstar2[0], args.final_kstar2[1]+1)
    else:
        kstar2_range = args.final_kstar2
    kstar2_range_string = '_'.join(str(int(k)) for k in args.final_kstar2)

    run_tag = args.galaxy_component+'_'+kstar1_range_string+'_'+kstar2_range_string

    # Open the hdf5 file to store the fixed population data
    dat_store = pd.HDFStore('dat_'+run_tag+'.h5')
    log_file = None
    try:
        # Resume from an earlier run when the store already holds data.
        # HDFStore.select raises KeyError for a missing table and leaves the
        # store open, so a fresh run needs no second HDFStore handle.
        try:
            bcm_save = dat_store.select('bcm')
            bpp_save = dat_store.select('bpp')
            # One row is appended per batch; the largest entry is the
            # running total / the next free binary index.
            total_sampled_mass = dat_store.select('totalMass').values.max()
            idx = dat_store.select('idx').values.max()
            resumed = True
        except KeyError:
            bcm_save = pd.DataFrame()
            bpp_save = pd.DataFrame()
            total_sampled_mass = 0.0
            idx = 0
            resumed = False

        # Append to the log of a resumed run, start a new log otherwise
        log_file = open('log_'+run_tag+'.txt', 'a' if resumed else 'w')
        if resumed:
            log_file.write('There are already: '+str(bcm_save.shape[0])+' '+kstar1_range_string+'_'+kstar2_range_string+' binaries evolved\n')

        bcm_len = 0
        match_all = np.zeros(len(args.convergence_params))
        Nstep = args.Nstep

        # Simulate the fixed population
        # This process is illustrated in Fig 1 of Breivik & Larson (2018)
        # NOTE(review): as in the original condition, the loop keeps running
        # while fewer than 10000 binaries are saved even after Nstep has
        # passed Niter, so Niter does not bound the run in that case --
        # confirm this is intended.
        while ((Nstep < args.Niter) and (np.min(match_all) < 0.999999)) or (bcm_len < 10000):
            # Select the initial binary sample method from user input
            if args.initial_samp == 'independent':
                IBT, sampled_mass, n_samp = InitialBinaryTable.sampler(args.initial_samp, kstar1_range, kstar2_range, 'kroupa93', 'thermal', SFH_model, component_age, args.metallicity, args.Nstep)
            elif args.initial_samp == 'multidim':
                IBT, sampled_mass, n_samp = InitialBinaryTable.sampler(args.initial_samp, kstar1_range, kstar2_range, np.random.randint(0,1e6,1), args.nproc, SFH_model, component_age, args.metallicity, args.Nstep)

            # Log the total sampled mass from the initial binary sample
            # for future Galactic occurence rate calculation
            total_sampled_mass += sampled_mass
            log_file.write("The total mass sampled so far is: {0}\n".format(total_sampled_mass))

            # save the total_sampled_mass so far (once per batch)
            dat_store.append('totalMass', pd.DataFrame([total_sampled_mass]))

            # Now that we have all these initial conditions
            # let's create an Evolve class and evolve these systems
            bpp, bcm, initCond = Evolve.evolve(initialbinarytable=IBT, BSEDict=BSEDict, nproc=args.nproc, idx=idx)

            # Replace the orbital period by log10 of the period in seconds,
            # which is the unit --porb_cut is given in.
            # NOTE(review): the factor sec_in_year treats bcm['porb'] as
            # years -- confirm against the BSE output units.
            bcm['porb'] = np.log10(bcm['porb']*sec_in_year)

            # Keep track of the index: the next batch starts numbering its
            # binaries after the largest bin_num evolved so far
            idx = bcm.bin_num.max()+1
            dat_store.append('idx', pd.DataFrame([idx]))

            # Filter out any binaries that have had mass transfer from a WD
            # onto a compact object
            idx_MT = bpp.loc[(bpp.kstar_1.isin([10,11,12,13,14])) &
                             (bpp.kstar_2.isin([10,11,12])) &
                             (bpp.evol_type == 3.0)].bin_num
            bcm = bcm.loc[~bcm.bin_num.isin(idx_MT)]

            # Select the state of the binary today (tphys > 1), drop
            # disrupted binaries (sep <= 0) and long period binaries
            bcm_filtered = bcm.loc[(bcm.tphys > 1.0) & (bcm.sep > 0.0) & (bcm.porb < args.porb_cut)]
            # Keep only the systems that match the final kstars supplied
            bcm_save_filtered = bcm_filtered.loc[(bcm_filtered.kstar_1.isin(kstar1_range)) & (bcm_filtered.kstar_2.isin(kstar2_range))]

            # Run the match on short period binaries, such that we get a
            # better sampling for the highest signal systems
            if bcm_save_filtered.shape[0] > 0:
                # Save the bcm dataframe
                dat_store.append('bcm', bcm_save_filtered)
                bcm_save = pd.concat([bcm_save, bcm_save_filtered])
                bcm_len = bcm_save.shape[0]

                # Reference sample for the convergence test: the first half
                # of the data for the very first batch, otherwise everything
                # saved before this batch.
                if idx <= args.Nstep:
                    bcm_save_test_conv1 = bcm_save.iloc[0:int(bcm_save.shape[0]/2.0)]
                else:
                    bcm_save_test_conv1 = bcm_save.iloc[:bcm_save.shape[0]-bcm_save_filtered.shape[0]]

                # bpp and initCond only hold binaries of the current batch,
                # so the binaries saved from this batch are enough to select
                # their rows (no need to test against all saved bin_nums).
                batch_bin_nums = bcm_save_filtered.bin_num
                bpp_save_filtered = bpp[bpp.bin_num.isin(batch_bin_nums)]

                # Save the bpp dataframe
                dat_store.append('bpp', bpp_save_filtered)
                bpp_save = pd.concat([bpp_save, bpp_save_filtered])

                # Now let's select out the initial binaries that produce our
                # data and save them
                initC = initCond.loc[initCond.bin_num.isin(batch_bin_nums)]
                dat_store.append('initCond', initC)

                # Perform the Match calculations for all interested
                # parameters supplied by user in convergence_params
                if bcm_save_test_conv1.shape[0] > 3:
                    match_all = []
                    for conv_param in args.convergence_params:
                        # NOTE(review): this skips the Match (recording 1.0)
                        # as soon as ANY saved value is exactly zero, which
                        # is what the original `.all() == 0.0` evaluated to.
                        # If "all values are zero" was intended, confirm that
                        # Match.match copes with partly-zero data before
                        # changing it.
                        if (bcm_save[conv_param] == 0.0).any():
                            match_all.append(np.array([1.0]))
                        else:
                            match, binwidth = Match.match([param_transform(bcm_save[conv_param]).tolist(),
                                                           param_transform(bcm_save_test_conv1[conv_param]).tolist()], 2)
                            match_all.append(match)
                    match_all = np.array(match_all)
                    log_file.write('matches: {0}\n'.format(np.array(match_all)))
                    log_file.write('Length of bcm array: {0}\n'.format(len(bcm_save)))
                    match_save = pd.DataFrame(np.array(match_all).T, columns = args.convergence_params)
                    dat_store.append('match', match_save)

            Nstep += args.Nstep
            log_file.flush()

        log_file.write('All done friend!')
    finally:
        # Close the log and the data storage file even if a batch failed
        if log_file is not None:
            log_file.close()
        dat_store.close()

