#!/usr/bin/env python

# Code: runFixedPop.py
# Version: 1
# Version changes: SAMPLE FIXED POPULATION OF BINARIES AND EVOLVE WITH BSE;
#                  COMPUTE RATES AND NUMBERS FOR EACH POPULATION ACCORDING
#                  TO FLAGS SET BY USER
#
# Edited on:  8 SEP 2015


##############################################################################
#  IMPORT ALL NECESSARY PYTHON PACKAGES
##############################################################################
from collections import OrderedDict
import warnings
import argparse
from configparser import ConfigParser

import math
import random
import time
from time import sleep
import string
import os.path
import json

import numpy as np
import scipy.special as ss
import pandas as pd
import warnings

from cosmic.sample.initialbinarytable import InitialBinaryTable
from cosmic import Match
from cosmic.evolve import Evolve
from cosmic.utils import param_transform, filter_bpp_bcm, bcm_conv_select

###############################################################################
# DEFINE COMMANDLINE ARGUMENTS
###############################################################################
def parse_commandline():
    """Parse and validate the arguments given on the command-line.

    Arguments are read from ``sys.argv`` through ``argparse``.

    Returns
    -------
    args : argparse.Namespace
        The parsed arguments. ``final_kstar1`` and ``final_kstar2`` are
        lists of one or two ints; ``convergence_params`` is a list of
        strings.

    Raises
    ------
    SystemExit
        Raised by argparse (``parser.error`` prints the usage message and
        exits with status 2) when a required argument is missing, when a
        final kstar argument has more than two values or is not a strictly
        increasing pair, when ``--initial_samp`` is not ``independent`` or
        ``multidim``, or when ``--galaxy_component`` is not a supported
        component.
    """
    # Components for which the main script defines a star formation history
    galaxy_components = ('ThinDisk', 'Bulge', 'ThickDisk', 'DeltaBurst')

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--inifile",
                        help="Name of ini file of params",
                        required=True)
    parser.add_argument("--final_kstar1",
                        help="Specify the final condition of kstar1 "
                        ", you want systems to end at for your samples",
                        required=True, type=int, nargs='+')
    parser.add_argument("--final_kstar2",
                        help="Specify the final condition of kstar2, you want "
                        "systems to end at for your samples",
                        required=True, type=int, nargs='+')
    parser.add_argument("--convergence-params",
                        help="A space separated list of parameters you would "
                        "like to verify have converged to a single distribution shape"
                        , nargs='+',
                        default=['mass_1', 'mass_2', 'porb', 'ecc'])
    parser.add_argument("--initial_samp",
                        help="Specify if independent binary initial "
                        "conditions: independent, or following "
                        "Moe & Di Stefano (2017): multidim",
                        default="multidim")
    parser.add_argument("--galaxy_component",
                        help="Galaxy Components. Options include "
                        "Bulge ThinDisk ThickDisk and DeltaBurst",
                        required=True)
    parser.add_argument("--metallicity", help="Metallicity of the population; "
                        "default: 0.02 (solar)", default=0.02, type=float)
    parser.add_argument("--Niter",
                        help="Number of iterations of binaries "
                        "to try, will check ever Nstep for convergence",
                        type=int, default=10000000)
    parser.add_argument("--Nstep",
                        help="Number of binaries to try before checking for "
                        "convergence, it will check ever Nstep binaries until "
                        "it reach Niter binaries", type=int, default=10000)
    parser.add_argument("--match",
                        help="Match value for convergence tests",
                        type=float, default=-5.0)
    parser.add_argument("-n", "--nproc",
                        help="number of processors", type=int, default=1)
    parser.add_argument("--verbose", action="store_true", default=False,
                        help="Run in Verbose Mode")

    args = parser.parse_args()

    # parser.error() prints the message and exits; it never returns, so
    # there is nothing to ``raise``.
    if len(args.final_kstar1) > 2 or len(args.final_kstar2) > 2:
        parser.error('final kstar1 and final kstar2 '
                     'must be either a single value or '
                     'a range between two values.')

    # A two-value kstar argument is a range and must be strictly increasing
    for label, final_kstar in (('kstar1', args.final_kstar1),
                               ('kstar2', args.final_kstar2)):
        if len(final_kstar) == 2 and final_kstar[0] >= final_kstar[1]:
            parser.error('Range provided for {0} invalid'.format(label))

    if args.initial_samp not in ['independent', 'multidim']:
        parser.error('Initial sample must either be '
                     'independent or multidim')

    # Reject unknown components here rather than failing later, after the
    # output files have already been created.
    if args.galaxy_component not in galaxy_components:
        parser.error('Galaxy component must be one of: '
                     + ', '.join(galaxy_components))

    return args

###############################################################################
# BEGIN MAIN FUNCTION
###############################################################################
def _kstar_range(final_kstar):
    """Return the kstar values and file-name label for one final-kstar argument.

    A two-element argument is expanded to the inclusive integer range
    between its values and labelled ``'<first>_<last>'``; a one-element
    argument is returned unchanged and labelled with its single value.
    """
    if len(final_kstar) == 2:
        kstar_range = np.arange(final_kstar[0], final_kstar[1]+1)
        label = str(int(final_kstar[0]))+'_'+str(int(final_kstar[1]))
    else:
        kstar_range = final_kstar
        label = str(int(final_kstar[0]))
    return kstar_range, label


if __name__ == '__main__':

    # READ COMMANDLINE ARGUMENTS
    ###########################################################################
    args = parse_commandline()

    # CONSTANTS
    ###########################################################################
    # Used to convert the orbital periods returned by the evolution step
    sec_in_year = 3.15569*10**7.0

    # Star formation history model and component age for each supported
    # Galactic component: (SFH_model, component_age)
    galaxy_models = {'ThinDisk': ('const', 10000.0),
                     'Bulge': ('burst', 10000.0),
                     'ThickDisk': ('burst', 11000.0),
                     'DeltaBurst': ('delta_burst', 13700.0)}

    # ---- Create configuration-file-parser object and read parameters file.
    cp = ConfigParser()
    cp.optionxform = str   # keep option names case sensitive
    cp.read(args.inifile)

    # ---- Read needed variables from the inifile
    # 'True'/'False' become booleans, anything that parses as JSON is stored
    # as the parsed value, and everything else is kept as the raw string.
    dictionary = {}
    for section in cp.sections():
        dictionary[section] = {}
        for option in cp.options(section):
            opt = cp.get(section, option)
            if opt == 'False':
                value = False
            elif opt == 'True':
                value = True
            else:
                try:
                    value = json.loads(opt)
                except ValueError:
                    # not valid JSON (json.JSONDecodeError is a ValueError)
                    value = opt
            dictionary[section][option] = value

    BSEDict = dictionary['bse']
    seed_int = int(dictionary['rand_seed']['seed'])
    filters = dictionary['filters']
    convergence = dictionary['convergence']

    # A seed of 0 is used as given, so the run is always seeded
    np.random.seed(seed_int)

    # ---- Check that values in the BSEDict are viable
    if BSEDict['polar_kick_angle']<0 or BSEDict['polar_kick_angle']>90:
        raise ValueError("{0:0.1f} is outside the allowed range of [0,90] for specifying polar kicks".format(BSEDict['polar_kick_angle']))

    # Select the Galactic component from user input; fail before any output
    # file is created if the component is not supported
    try:
        SFH_model, component_age = galaxy_models[args.galaxy_component]
    except KeyError:
        raise ValueError("Unknown galaxy component '{0}'; options are: {1}".format(
            args.galaxy_component, ', '.join(galaxy_models)))

    # Set up final_kstar1 and final_kstar2 strings for saved data files
    kstar1_range, kstar1_range_string = _kstar_range(args.final_kstar1)
    kstar2_range, kstar2_range_string = _kstar_range(args.final_kstar2)

    file_tag = args.galaxy_component+'_'+kstar1_range_string+'_'+kstar2_range_string
    dat_name = 'dat_'+file_tag+'.h5'
    log_name = 'log_'+file_tag+'.txt'

    # Open the hdf5 file to store the fixed population data
    dat_store = pd.HDFStore(dat_name)
    log_file = None
    try:
        # Try to resume from a previous run stored in the same file.  If any
        # part of the saved state cannot be read, start from scratch.
        try:
            bcm_save = pd.read_hdf(dat_store, 'bcm')
            bpp_save = pd.read_hdf(dat_store, 'bpp')
            # One row is appended to 'totalMass' and 'idx' per iteration;
            # the largest stored value is the most recent one.
            total_sampled_mass = pd.read_hdf(dat_store, 'totalMass')[0].max()
            idx = int(pd.read_hdf(dat_store, 'idx')[0].max())
        except Exception:
            bcm_save = pd.DataFrame()
            bpp_save = pd.DataFrame()
            total_sampled_mass = 0.0
            idx = 0
            log_file = open(log_name, 'w')
        else:
            log_file = open(log_name, 'a')
            log_file.write('There are already: '+str(bcm_save.shape[0])+' '+kstar1_range_string+'_'+kstar2_range_string+' binaries evolved\n')
            log_file.write('\n')

        match_all = np.zeros(len(args.convergence_params))
        Nstep = 0

        # Simulate the fixed population
        # This process is illustrated in Fig 1 of Breivik & Larson (2018)
        bcm_filter = pd.DataFrame()
        initC_filter = pd.DataFrame()
        bpp_filter = pd.DataFrame()

        # Keep sampling until Niter binaries have been tried or every match
        # value has dropped to args.match or below
        while Nstep < args.Niter and np.max(match_all) > args.match:
            # Select the initial binary sample method from user input
            if args.initial_samp == 'independent':
                IBT, sampled_mass, n_samp = InitialBinaryTable.sampler(args.initial_samp, kstar1_range, kstar2_range, 'kroupa93', 'thermal', SFH_model, component_age, args.metallicity, args.Nstep)
            elif args.initial_samp == 'multidim':
                IBT, sampled_mass, n_samp = InitialBinaryTable.sampler(args.initial_samp, kstar1_range, kstar2_range, np.random.randint(0, int(1e6), 1), args.nproc, SFH_model, component_age, args.metallicity, args.Nstep)

            # Log the total sampled mass from the initial binary sample
            # for future Galactic occurence rate calculation
            total_sampled_mass += sampled_mass

            # save the total_sampled_mass so far
            dat_store.append('totalMass', pd.DataFrame([total_sampled_mass]))
            log_file.write("The total mass sampled so far is: {0}\n".format(total_sampled_mass))

            # Now that we have all these initial conditions
            # let's create an Evolve class and evolve these systems
            # see if users specified a sample rate for the bcm array; if not
            # set it to tphysf
            if 'dtp' in BSEDict:
                dtp = BSEDict['dtp']
            else:
                dtp = IBT['tphysf'].values

            bpp, bcm, initCond = Evolve.evolve(initialbinarytable=IBT, BSEDict=BSEDict, nproc=args.nproc, idx=idx, dtp=dtp)

            # Convert the orbital period from years to log seconds
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", message="divide by zero encountered in log10")
                bcm['porb'] = np.log10(bcm['porb']*sec_in_year)
                bpp['porb'] = np.log10(bpp['porb']*sec_in_year)

            # Tag every bcm row with the metallicity of its binary, taken
            # from the first initCond row carrying the same bin_num
            met_by_bin = initCond.drop_duplicates('bin_num').set_index('bin_num')['metallicity']
            bcm['metallicity'] = bcm['bin_num'].map(met_by_bin).values

            # Keep track of the index
            idx = int(bcm.bin_num.max()+1)
            dat_store.append('idx', pd.DataFrame([idx]))

            # Filter the bcm and bpp arrays according to user specified filters
            bcm_new = filter_bpp_bcm(bcm, bpp, filters, kstar1_range, kstar2_range)
            initC_new = initCond.loc[initCond.bin_num.isin(bcm_new.bin_num)]
            bpp_new = bpp.loc[bpp.bin_num.isin(bcm_new.bin_num)]

            # Accumulate onto anything carried over from iterations that did
            # not yet hold enough systems to be saved
            if len(bcm_filter) > 0:
                bcm_filter = pd.concat([bcm_filter, bcm_new])
                initC_filter = pd.concat([initC_filter, initC_new])
                bpp_filter = pd.concat([bpp_filter, bpp_new])
            else:
                bcm_filter = bcm_new
                initC_filter = initC_new
                bpp_filter = bpp_new

            # Run the match on short period binaries, such that we get a better sampling
            # for the highest signal systems
            n_bin_state = [len(bcm_filter.loc[bcm_filter.bin_state == bin_state])
                           for bin_state in filters['binary_state']]

            # Only save and test convergence once every requested binary
            # state holds at least 50 rows
            if all(n_state >= 50 for n_state in n_bin_state):
                # Save the bcm dataframe
                dat_store.append('bcm', bcm_filter)
                bcm_save = pd.concat([bcm_save, bcm_filter])

                # Save the bpp dataframe
                dat_store.append('bpp', bpp_filter)
                bpp_save = pd.concat([bpp_save, bpp_filter])

                # Save the initial binaries
                dat_store.append('initCond', initC_filter)

                # perform the convergence
                match = Match.perform_convergence(args.convergence_params, filters['binary_state'],
                                                  convergence, bcm_save, bcm_filter, bpp_save,
                                                  kstar1_range, kstar2_range, log_file)

                if len(match) > 1:
                    match_all = np.reshape(match, (-1, len(args.convergence_params)))
                    match_save = pd.DataFrame(np.atleast_2d(match_all), columns=args.convergence_params)
                    dat_store.append('match', match_save)

                # reset the bcm_filter DataFrame; the initC and bpp filters
                # are replaced on the next iteration because it is empty
                bcm_filter = pd.DataFrame()

            Nstep += args.Nstep
            log_file.flush()

        log_file.write('All done friend!')
    finally:
        # Close the data storage and log files even if the run failed
        dat_store.close()
        if log_file is not None:
            log_file.close()

