#!/usr/bin/env python
"""
A command-line program to fit a StarModel using the isochrones package

The input argument is the name of a folder containing a file
called ``star.ini``, a config file listing all the observed
properties of the star on which the model should be conditioned.
Multiple folder names can also be passed.

"""
from __future__ import division, print_function

import matplotlib
matplotlib.use('agg')

import os, os.path, re, sys
import logging
import time

from configobj import ConfigObj
import argparse

from isochrones.starmodel import StarModel, BinaryStarModel, TripleStarModel

def initLogging(filename, logger):
    """Configure ``logger`` to write to ``filename`` and to stdout.

    Parameters
    ----------
    filename : str
        Path of the log file to write to (opened via ``logging.FileHandler``).
    logger : logging.Logger or None
        Logger to (re)configure.  If ``None``, the root logger is used and
        whatever handlers it already has are left in place.  Otherwise every
        handler currently attached to ``logger`` is removed and closed
        before the new ones are added.

    Returns
    -------
    logging.Logger
        The configured logger (the same object that was passed in, unless
        ``None`` was given).
    """
    if logger is None:
        logger = logging.getLogger()
    else:
        # Removing a handler does not release its resources; close it as
        # well so the previous log file is flushed and its descriptor is
        # freed rather than leaked once per call.
        for handler in logger.handlers[:]:  # iterate over a copy of the list
            logger.removeHandler(handler)
            handler.close()

    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')

    # Timestamped records go to the log file...
    fh = logging.FileHandler(filename)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    # ...while stdout gets the handler's default formatting (no formatter set).
    sh = logging.StreamHandler(sys.stdout)
    logger.addHandler(sh)
    return logger


def parseStarProperties(config):
    """Convert the entries of a ``star.ini`` config into model keyword args.

    Parameters
    ----------
    config : mapping
        Mapping of property name to either a single numeric string or a
        list/tuple whose first two items are numeric strings.

    Returns
    -------
    dict
        ``{name: float}`` for scalar entries and
        ``{name: (float, float)}`` for list/tuple entries.

    Raises
    ------
    ValueError
        If a value cannot be converted to float.
    IndexError
        If a list/tuple entry has fewer than two items.
    """
    props = {}
    for kw in config.keys():
        value = config[kw]
        # Decide on the container type instead of catching a failed float():
        # a malformed scalar string such as "57x0" must raise rather than be
        # indexed character by character into a bogus (5.0, 7.0) pair.
        if isinstance(value, (list, tuple)):
            props[kw] = (float(value[0]), float(value[1]))
        else:
            props[kw] = float(value)
    return props


def selectMultiplicities(args):
    """Return the list of multiplicity names requested on the command line.

    ``--all`` takes precedence over ``--binary``, which takes precedence
    over ``--triple``; with none of them set, only ``'single'`` is fit.
    """
    if args.all:
        return ['single', 'binary', 'triple']
    if args.binary:
        return ['binary']
    if args.triple:
        return ['triple']
    return ['single']


if __name__=='__main__':

    parser = argparse.ArgumentParser(description='Fit physical properties of a star conditioned on observed quantities.')

    parser.add_argument('folders', nargs='*', default=['.'])
    parser.add_argument('--binary', action='store_true')
    parser.add_argument('--triple', action='store_true')
    parser.add_argument('--all', action='store_true')
    parser.add_argument('--models', default='dartmouth')
    parser.add_argument('--threads', type=int, default=1)
    parser.add_argument('--no_local_fehprior', action='store_true')

    args = parser.parse_args()

    # Each model grid is imported inside its own branch, so only the
    # requested one is imported.
    if args.models=='dartmouth':
        from isochrones.dartmouth import Dartmouth_Isochrone
        ichrone = Dartmouth_Isochrone()
    elif args.models=='padova':
        from isochrones.padova import Padova_Isochrone
        ichrone = Padova_Isochrone()
    elif args.models=='basti':
        from isochrones.basti import Basti_Isochrone
        ichrone = Basti_Isochrone()
    else:
        raise ValueError('Unknown stellar models: {}'.format(args.models))

    multiplicities = selectMultiplicities(args)

    Models = {'single':StarModel,
              'binary':BinaryStarModel,
              'triple':TripleStarModel}

    use_local_fehprior = not args.no_local_fehprior
    n_folders = len(args.folders)

    logger = None  # created on the first pass, then re-pointed per folder

    for i, folder in enumerate(args.folders, 1):
        for mult in multiplicities:
            Model = Models[mult]
            model_filename = '{}_starmodel_{}.h5'.format(args.models, mult)

            print('{} of {}: {} ({})'.format(i, n_folders, folder, mult))
            # (Re)initialize the logger so records go to this folder's run.log
            logfile = os.path.join(folder, 'run.log')
            logger = initLogging(logfile, logger)

            try:
                start = time.time()
                ini_file = os.path.join(folder, 'star.ini')
                # Fail loudly on a missing file rather than fitting a model
                # from an empty config.
                if not os.path.isfile(ini_file):
                    raise IOError('No star.ini found in {}'.format(folder))
                config = ConfigObj(ini_file)
                props = parseStarProperties(config)

                mod = Model(ichrone, **props)
                mod.fit_mcmc(threads=args.threads,
                             loglike_kwargs={'use_local_fehprior':use_local_fehprior})
                triangle_base = os.path.join(folder, '{}_triangle_{}'.format(args.models, mult))
                mod.triangle_plots(triangle_base)
                mod.save_hdf(os.path.join(folder, model_filename))

                end = time.time()
                logger.info('{} starfit successful for {} in {:.1f} minutes.'.format(
                    mult, folder, (end-start)/60))
            except KeyboardInterrupt:
                logger.error('{} starfit calculation interrupted for {}.'.format(mult,folder))
                raise
            except Exception:
                # Log the traceback and move on to the next fit; SystemExit
                # and other non-Exception signals are allowed to propagate.
                logger.error('{} starfit calculation failed for {}.'.format(mult,folder),
                             exc_info=True)
