#!/usr/bin/env python

# Code: cosmic-pop.py
# Version: 1
# Version changes: SAMPLE FIXED POPULATION OF BINARIES AND EVOLVE WITH BSE;
#                  COMPUTE RATES AND NUMBERS FOR EACH POPULATION ACCORDING
#                  TO FLAGS SET BY USER
#
# Edited on:  8 SEP 2015


##############################################################################
#  IMPORT ALL NECESSARY PYTHON PACKAGES
##############################################################################
from collections import OrderedDict
import warnings
import argparse

import math
import random
import time
from time import sleep
import string
import os.path

import numpy as np
import scipy.special as ss
import pandas as pd
import warnings

from cosmic.sample.initialbinarytable import InitialBinaryTable
from cosmic import Match, utils
from cosmic.evolve import Evolve

###############################################################################
# DEFINE COMMANDLINE ARGUMENTS
###############################################################################
def parse_commandline(argv=None):
    """Parse and validate the arguments given on the command-line.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse.  When ``None`` (the default) argparse
        reads them from ``sys.argv[1:]``, which is what the script does.

    Returns
    -------
    argparse.Namespace
        Parsed arguments with attributes ``inifile``, ``final_kstar1``,
        ``final_kstar2`` (both lists of one or two ints), ``Niter``,
        ``Nstep``, ``nproc`` and ``verbose``.

    Raises
    ------
    SystemExit
        Raised by argparse (via ``parser.error``) when a required argument
        is missing or when the kstar arguments fail validation.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--inifile",
                        help="Name of ini file of params",
                        required=True)
    parser.add_argument("--final-kstar1",
                        help="Specify the final condition of kstar1 "
                        ", you want systems to end at for your samples",
                        required=True, type=int, nargs='+')
    parser.add_argument("--final-kstar2",
                        help="Specify the final condition of kstar2, you want "
                        "systems to end at for your samples",
                        required=True, type=int, nargs='+')
    parser.add_argument("--Niter",
                        help="Number of iterations of binaries "
                        "to try, will check ever Nstep for convergence",
                        type=int, default=10000000)
    parser.add_argument("--Nstep",
                        help="Number of binaries to try before checking for "
                        "convergence, it will check ever Nstep binaries until "
                        "it reach Niter binaries", type=int, default=10000)
    parser.add_argument("-n", "--nproc",
                        help="number of processors", type=int, default=1)
    parser.add_argument("--verbose", action="store_true", default=False,
                        help="Run in Verbose Mode")

    args = parser.parse_args(argv)
    kstar1 = args.final_kstar1
    kstar2 = args.final_kstar2

    # parser.error() prints the usage message and exits the process itself,
    # so it is called directly rather than being used as a raise target.
    if len(kstar1) > 2 or len(kstar2) > 2:
        parser.error('final kstar1 and final kstar2 '
                     'must be either a single value or '
                     'a range between two values.')

    # A two-value argument is a range and must be strictly increasing.
    if len(kstar1) == 2 and kstar1[0] >= kstar1[1]:
        parser.error('Range provided for final-kstar1 invalid')

    if len(kstar2) == 2 and kstar2[0] >= kstar2[1]:
        parser.error('Range provided for final-kstar2 invalid')

    # With two single values, compare the values themselves (not the lists
    # that argparse wraps them in).
    if len(kstar1) == 1 and len(kstar2) == 1 and kstar2[0] > kstar1[0]:
        parser.error('final-kstar1 must be greater than or equal to '
                     'final-kstar2.')

    return args

###############################################################################
# BEGIN MAIN FUNCTION
###############################################################################
if __name__ == '__main__':

    # READ COMMANDLINE ARGUMENTS
    ###########################################################################
    args = parse_commandline()

    # SET TIME TO TRACK COMPUTATION TIME
    ###########################################################################
    start_time = time.time()

    # READ AND PARSE INIFILE
    ###########################################################################
    BSEDict, seed_int, filters, convergence, sampling = utils.parse_inifile(args.inifile)

    # Check that the values in BSEDict, filters, and convergence are valid
    utils.error_check(BSEDict, filters, convergence, sampling)

    if seed_int != 0:
        np.random.seed(seed_int)
    else:
        np.random.seed(0)

    # Set up final_kstar1 and final_kstar2 strings for saved data files
    if len(args.final_kstar1) == 2:
        kstar1_range = np.arange(args.final_kstar1[0], args.final_kstar1[1]+1)
        kstar1_range_string = str(int(args.final_kstar1[0]))+'_'+str(int(args.final_kstar1[1]))
    else:
        kstar1_range = args.final_kstar1
        kstar1_range_string = str(int(args.final_kstar1[0]))

    if len(args.final_kstar2) == 2:
        kstar2_range = np.arange(args.final_kstar2[0], args.final_kstar2[1]+1)
        kstar2_range_string = str(int(args.final_kstar2[0]))+'_'+str(int(args.final_kstar2[1]))
    else:
        kstar2_range = args.final_kstar2
        kstar2_range_string = str(int(args.final_kstar2[0]))

    # Open the hdf5 file to store the fixed population data
    try:
        dat_store = pd.HDFStore('dat_'+sampling['galaxy_component']+'_'+kstar1_range_string+'_'+kstar2_range_string+'.h5')
        conv_save = pd.read_hdf(dat_store, 'conv')
        log_file = open('log_'+sampling['galaxy_component']+'_'+kstar1_range_string+'_'+kstar2_range_string+'.txt', 'a')
        log_file.write('There are already: '+str(conv_save.shape[0])+' '+kstar1_range_string+'_'+kstar2_range_string+' binaries evolved\n')
        log_file.write('\n')
        total_mass_singles = np.max(pd.read_hdf(dat_store, 'mass_singles'))[0]
        total_mass_binaries = np.max(pd.read_hdf(dat_store, 'mass_binaries'))[0]
        total_mass_stars = np.max(pd.read_hdf(dat_store, 'mass_stars'))[0]
        total_n_singles = np.max(pd.read_hdf(dat_store, 'n_singles'))[0]
        total_n_binaries = np.max(pd.read_hdf(dat_store, 'n_binaries'))[0]
        total_n_stars = np.max(pd.read_hdf(dat_store, 'n_stars'))[0]
        idx = int(np.max(pd.read_hdf(dat_store, 'idx'))[0])
    except:
        conv_save = pd.DataFrame()
        dat_store = pd.HDFStore('dat_'+sampling['galaxy_component']+'_'+kstar1_range_string+'_'+kstar2_range_string+'.h5')
        total_mass_singles = 0  
        total_mass_binaries = 0
        total_mass_stars = 0
        total_n_singles = 0
        total_n_binaries = 0
        total_n_stars = 0
        idx = 0
        log_file = open('log_'+sampling['galaxy_component']+'_'+kstar1_range_string+'_'+kstar2_range_string+'.txt', 'w')

    # Initialize the step counter and convergence array/list
    Nstep = 0
    match = np.zeros(len(convergence['convergence_params']))

    # Select the Galactic component from user input
    if sampling['galaxy_component'] == 'ThinDisk':
        SFH_model='const'
        component_age=10000.0
    elif sampling['galaxy_component'] == 'Bulge':
        SFH_model='burst'
        component_age=10000.0
    elif sampling['galaxy_component'] == 'ThickDisk':
        SFH_model='burst'
        component_age=11000.0
    elif sampling['galaxy_component'] == 'DeltaBurst':
        SFH_model='delta_burst'
        component_age=13700.0

    # Simulate the fixed population
    # This process is illustrated in Fig 1 of Breivik & Larson (2018)
    steps = 0
    bcm_filter_match = []
    bpp_filter_match = []
    initC_filter_match = []

    while (Nstep < args.Niter) & (np.max(match) > convergence['match']):
        # Set random seed such that each iteration gets a unique, determinable seed
        rand_seed = seed_int + Nstep
        np.random.seed(rand_seed)

        # Select the initial binary sample method from user input
        if sampling['sampling_method'] == 'independent':
            init_samp_list = InitialBinaryTable.sampler(sampling['sampling_method'],
                                                        final_kstar1 = kstar1_range,
                                                        final_kstar2 = kstar2_range,
                                                        binfrac_model = 0.5,
                                                        primary_model = 'kroupa93',
                                                        ecc_model = 'thermal',
                                                        SFH_model = SFH_model,
                                                        component_age = component_age,
                                                        met = sampling['metallicity'],
                                                        size = args.Nstep)
            IBT, mass_singles, mass_binaries, n_singles, n_binaries = init_samp_list

        if sampling['sampling_method'] == 'multidim':
            init_samp_list = InitialBinaryTable.sampler(sampling['sampling_method'],
                                                        final_kstar1 = kstar1_range,
                                                        final_kstar2 = kstar2_range,
                                                        rand_seed = rand_seed,
                                                        nproc = args.nproc,
                                                        SFH_model = SFH_model,
                                                        component_age = component_age,
                                                        met = sampling['metallicity'],
                                                        size = args.Nstep)
            IBT, mass_singles, mass_binaries, n_singles, n_binaries = init_samp_list

        # Log the total sampled mass from the initial binary sample
        # for future Galactic occurence rate calculation
        total_mass_singles += mass_singles
        total_mass_binaries += mass_binaries
        total_mass_stars += mass_singles + mass_binaries
        total_n_singles += n_singles
        total_n_binaries += n_binaries
        total_n_stars += n_singles + 2*n_binaries

        # Now that we have all these initial conditions
        # let's create an Evolve class and evolve these systems
        # see if users specified a smaple rate for the bcm array if not set it
        # tphysf
        try:
            dtp = dictionary['bse']['dtp']
        except:
            dtp = IBT['tphysf'].values

        bpp, bcm, initCond = Evolve.evolve(initialbinarytable=IBT, BSEDict=BSEDict, nproc=args.nproc, idx=idx, dtp=dtp)
        # Keep track of the index
        idx = int(bcm.bin_num.max()+1)

        bcm_filter, bin_state_nums = utils.filter_bpp_bcm(bcm, bpp, filters, kstar1_range, kstar2_range)
        if bcm_filter.empty:
            warnings.warn("After filtering the bcm array for desired systems there were no systems matching your request. It is possible you should up to the number of binaries provessed in each iteration, i.e. Nstep")
            log_file.write("After filtering the bcm array for desired systems there were no systems matching your request. It is possible you should up to the number of binaries provessed in each iteration, i.e. Nstep")
            continue

        initC_filter = initCond.loc[initCond.bin_num.isin(bcm_filter.bin_num)]
        bpp_filter = bpp.loc[bpp.bin_num.isin(bcm_filter.bin_num)]
        # Now get the converging population
        conv_filter = utils.conv_select(bcm_filter, bpp_filter,
                                        kstar1_range, kstar2_range,
                                        convergence['convergence_filter'],
                                        convergence['convergence_limits'])
        if conv_filter.empty:
            warnings.warn("After filtering for desired convegence systems there were no systems matching your request. It is possible you are suggesting incompatible bin_state choices and convergence_filters, e.g. bin_state=[0,1], convergence_filter='disruption'")
            log_file.write("After filtering for desired convegence systems there were no systems matching your request. It is possible you are suggesting incompatible bin_state choices and convergence_filters, e.g. bin_state=[0,1], convergence_filter='disruption'")
            continue

        if convergence['bcm_bpp_initCond_filter']:
            bcm_filter = bcm_filter.loc[bcm_filter.bin_num.isin(conv_filter.bin_num)]
            bpp_filter = bpp_filter.loc[bpp_filter.bin_num.isin(conv_filter.bin_num)]
            initC_filter = initC_filter.loc[initC_filter.bin_num.isin(conv_filter.bin_num)]


        # Filter the bcm and bpp arrays according to user specified filters
        if len(bcm_filter_match) == 0:
            bcm_filter_match = bcm_filter.copy()
            bpp_filter_match = bpp_filter.copy()
            initC_filter_match = initC_filter.copy()
            conv_filter_match = conv_filter.copy()
        else:
            bcm_filter_match = bcm_filter_match.append(bcm_filter)
            bpp_filter_match = bpp_filter_match.append(bpp_filter)
            initC_filter_match = initC_filter_match.append(initC_filter)
            conv_filter_match = conv_filter_match.append(conv_filter)

        if len(conv_filter_match) >= np.min([50, args.Niter]):
            conv_save = conv_save.append(conv_filter_match)

            # perform the convergence
            if len(conv_save) == len(conv_filter_match):
                match = Match.perform_convergence(convergence['convergence_params'],
                                                  conv_save,
                                                  conv_filter_match,
                                                  log_file)
            else:
                match = Match.perform_convergence(convergence['convergence_params'],
                                                  conv_save,
                                                  conv_save.loc[~conv_save.bin_num.isin(conv_filter_match.bin_num)],
                                                  log_file)

            match_save = pd.DataFrame(np.atleast_2d(match), columns = convergence['convergence_params'])

            # write the data and the logs!
            mass_list = [total_mass_singles, total_mass_binaries, total_mass_stars]
            n_list = [total_n_singles, total_n_binaries, total_n_stars]
            utils.pop_write(dat_store, log_file, mass_list, n_list, bcm_filter_match,
                            bpp_filter_match, initC_filter_match, conv_filter_match,
                            bin_state_nums, match_save, idx)

            # reset the bcm_filter DataFrame
            bcm_filter_match = []
            bpp_filter_match = [] 
            initC_filter_match = []
            conv_filter_match = []
            log_file.write('\n')
        Nstep += args.Nstep
        log_file.flush()
    # Close the data storage file
    dat_store.close()

    log_file.write('All done friend!')
    log_file.close()

