#!/usr/bin/env python
# coding: utf-8

# Force Python version
from payu import reversion
reversion.repython('2.7.5', __file__)

# Standard library
import argparse
import os
import re
import sys
import shlex
import socket
import subprocess as sp

# External dependencies
import yaml

# Local
from payu.modelindex import index as model_index

# TODO: Move to some sort of config file
# Name of the configuration file that payu_parse looks for in the current
# directory when no -c/--config path is given on the command line.
default_config_fname = 'config.yaml'

#---
def payu_parse():
    """Parse the payu command line, validate it and dispatch the subcommand.

    Subcommands are ``sweep``, ``list``, ``init``, ``run``, ``archive`` and
    ``collate``.  The model name is taken from ``-m``; failing that, from the
    ``model`` entry of the configuration file; failing that, from the name of
    the parent directory, which must then be a known model.

    The configuration file is ``-c`` if given, else ``config.yaml`` in the
    current directory if it exists.  ``run`` and ``collate`` require one.

    Exits the interpreter (via ``sys.exit`` or ``parser.error``) on invalid
    input, after printing help, and after ``list``.
    """
    parser = _build_parser()

    # Display help if no arguments are provided
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit()

    args = parser.parse_args()

    # Python 3's argparse does not require a subcommand; report it the way
    # Python 2's argparse does instead of failing on ``args.cmd``.
    if getattr(args, 'cmd', None) is None:
        parser.error('too few arguments')

    # Skip validation when listing supported models
    if args.cmd == call_list:
        call_list()
        sys.exit()

    config_path = args.config_path

    # Validate the configuration file path
    if config_path and not os.path.isfile(config_path):
        sys.exit('payu: error: Configuration file {} does not exist.'
                 ''.format(config_path))

    # Assign the default config path if it exists
    if not config_path and os.path.isfile(default_config_fname):
        config_path = os.path.join(os.curdir, default_config_fname)

    model_name = _resolve_model_name(args.model_name, config_path)

    if args.cmd in (call_run, call_collate):
        init_run = _parse_counter(parser, args.init_run, '--initial', 0)
        n_runs = _parse_counter(parser, args.n_runs, '--nruns', 1)

        # Job submission reads its PBS settings from the config file
        if not config_path:
            sys.exit('payu: error: No configuration file found; use -c or '
                     'create {}.'.format(default_config_fname))

        pbs_config = _load_config(config_path)
        pbs_vars = _pbs_job_vars(init_run, n_runs)

        args.cmd(pbs_config, pbs_vars, init_run, n_runs)

    elif args.cmd == call_sweep:
        call_sweep(config_path, args.hard_sweep)

    elif args.cmd == call_init:
        call_init(model_name, config_path)

    elif args.cmd == call_archive:
        # Previously registered but never dispatched, so it silently did nothing
        call_archive(model_name, config_path)


#---
def _build_parser():
    """Build the top-level argument parser and its subcommand parsers."""
    parser = argparse.ArgumentParser()

    parser.add_argument('-m', '--model',
                        action='store',
                        dest='model_name',
                        default=None,
                        help='Select model type')

    parser.add_argument('-c', '--config',
                        action='store',
                        dest='config_path',
                        default=None,
                        help='Configuration path')

    subcmd = parser.add_subparsers()

    sweep_cmd = subcmd.add_parser('sweep')
    sweep_cmd.set_defaults(cmd=call_sweep)
    sweep_cmd.add_argument('--hard',
                           action='store_true',
                           dest='hard_sweep')

    list_cmd = subcmd.add_parser('list')
    list_cmd.set_defaults(cmd=call_list)

    init_cmd = subcmd.add_parser('init')
    init_cmd.set_defaults(cmd=call_init)

    run_cmd = subcmd.add_parser('run')
    run_cmd.set_defaults(cmd=call_run)

    archive_cmd = subcmd.add_parser('archive')
    archive_cmd.set_defaults(cmd=call_archive)

    collate_cmd = subcmd.add_parser('collate')
    collate_cmd.set_defaults(cmd=call_collate)

    # The job-submitting subcommands share the run counter options
    for job_cmd in (run_cmd, collate_cmd):
        job_cmd.add_argument('--initial', '-i',
                             action='store',
                             dest='init_run')
        job_cmd.add_argument('--nruns', '-n',
                             action='store',
                             dest='n_runs')

    return parser


#---
def _resolve_model_name(model_name, config_path):
    """Return the model name from -m, the config file, or the parent dir."""
    # If no model name is specified, then check the config path
    if not model_name and config_path and os.path.isfile(config_path):
        model_name = _load_config(config_path).get('model')

    # If there is still no defined model_name, try the parent directory
    if not model_name:
        model_name = os.path.basename(os.path.abspath(os.pardir))
        if model_name not in model_index:
            sys.exit('payu: error: Unknown model {}.'.format(model_name))
        print('payu: warning: Assuming model is {model} based on '
              'parent directory name.'.format(model=model_name))

    return model_name


#---
def _load_config(config_path):
    """Read a YAML config file and return its contents ({} if it is empty)."""
    with open(config_path, 'r') as config_file:
        # safe_load: the config is plain data.  yaml.load without an explicit
        # Loader can construct arbitrary Python objects and is rejected
        # outright by PyYAML >= 6.
        config = yaml.safe_load(config_file)
    return config or {}


#---
def _parse_counter(parser, value, flag, minimum):
    """Convert an optional counter argument to an int no smaller than minimum.

    Returns None when the argument was not given.  Invalid values are
    reported through ``parser.error``, which exits.
    """
    if not value:
        return None

    try:
        count = int(value)
    except ValueError:
        parser.error('{} expects an integer, got {!r}'.format(flag, value))

    if count < minimum:
        parser.error('{} must be at least {}, got {}'
                     ''.format(flag, minimum, count))

    return count


#---
def _pbs_job_vars(init_run, n_runs):
    """Build the environment variables exported to the PBS job with -v."""
    # Determine the payu module name from the loaded environment modules
    loaded_modules = os.environ.get('LOADEDMODULES', '')
    module_match = re.search(r'[^:]*payu[^:]*', loaded_modules)
    if not module_match:
        sys.exit('payu: error: Cannot find a payu module in $LOADEDMODULES.')
    payu_modulename = module_match.group()

    # The module path is the part of a $_LMFILES_ entry preceding
    # "/<module name>".  Escape the name: it may contain '.' or '+'.
    lm_files = os.environ.get('_LMFILES_', '')
    path_match = re.search(r'[^:]*(?=/{0})'.format(re.escape(payu_modulename)),
                           lm_files)
    if not path_match:
        sys.exit('payu: error: Cannot find the {} module file in $_LMFILES_.'
                 ''.format(payu_modulename))
    payu_modulepath = path_match.group()

    pbs_vars = {'PAYU_MODULENAME': payu_modulename,
                'PAYU_MODULEPATH': payu_modulepath,
                }

    # Forward the submitting environment's PYTHONPATH when it is set
    if 'PYTHONPATH' in os.environ:
        pbs_vars['PYTHONPATH'] = os.environ['PYTHONPATH']

    if init_run is not None:
        pbs_vars['PAYU_CURRENT_RUN'] = init_run
    if n_runs is not None:
        pbs_vars['PAYU_N_RUNS'] = n_runs

    return pbs_vars


#---
def call_run(pbs_config, pbs_vars, current_run=None, n_runs=None):
    """Submit the ``payu-run`` script as a PBS job.

    ``pbs_config`` is updated in place before submission:
    - a default queue is set;
    - CPU requests larger than one node are rounded up to whole nodes;
    - a whole-node memory request is added when none is given.

    :param pbs_config: Job settings read from the configuration file.
    :param pbs_vars: Environment variables exported to the job.
    :param current_run: Unused here.
    :param n_runs: Unused here.
    """
    # Set the queue
    # NOTE: Maybe force all jobs on the normal queue
    pbs_config.setdefault('queue', 'normal')

    # TODO: Query this instead of hard-coding it; ``qstat -Qf normal-node``
    # (resources_max.ncpus) works, but PBS Pro qstat is too slow.
    max_cpus_per_node = 16

    n_cpus = pbs_config.get('ncpus', 1)
    n_cpus_per_node = pbs_config.get('npernode', max_cpus_per_node)

    # Validate explicitly: an assert vanishes under ``python -O``, and a
    # non-positive npernode would otherwise break the node count below.
    if not 0 < n_cpus_per_node <= max_cpus_per_node:
        sys.exit('payu: error: npernode must be between 1 and {}, got {}.'
                 ''.format(max_cpus_per_node, n_cpus_per_node))

    node_misalignment = n_cpus % max_cpus_per_node != 0
    node_increase = n_cpus_per_node < max_cpus_per_node

    # For requests larger than one node, round up to whole nodes when the
    # total is not a multiple of the node size or when fewer than a full
    # node's CPUs are to be used on each node.
    if n_cpus > max_cpus_per_node and (node_increase or node_misalignment):

        # Ceiling division: nodes needed at n_cpus_per_node CPUs per node
        n_nodes = 1 + (n_cpus - 1) // n_cpus_per_node
        n_cpu_request = max_cpus_per_node * n_nodes

        print('payu: warning: Job request includes {} unused CPUs.'
              ''.format(n_cpu_request - n_cpus))

        # Under this branch's condition the padded request always exceeds the
        # configured ncpus, so the config is updated unconditionally.
        print('payu: warning: CPU request increased from {} to {}'
              ''.format(n_cpus, n_cpu_request))
        n_cpus = n_cpu_request
        pbs_config['ncpus'] = n_cpus

    # Set memory to use the complete nodes if unspecified
    # TODO: Move RAM per node as variable (31 GB per node is assumed)
    if not pbs_config.get('mem') and n_cpus > max_cpus_per_node:
        n_full_nodes = n_cpus // max_cpus_per_node
        pbs_config['mem'] = '{}GB'.format(n_full_nodes * 31)

    # NOTE: Using ``__file__`` is slightly volatile
    run_script = os.path.join(os.path.dirname(__file__), 'payu-run')
    submit_job(run_script, pbs_config, pbs_vars)


#---
def call_collate(pbs_config, pbs_vars, current_run=None, n_runs=None):
    """Submit the ``payu-collate`` script as a serial PBS job.

    ``pbs_config`` is rewritten in place for collation:
    - the queue becomes ``collate_queue`` (default ``copyq``);
    - ncpus is set to 1;
    - the job name gets a ``_c`` suffix;
    - walltime and mem are replaced by ``collate_walltime`` / ``collate_mem``,
      or removed when those are not set.

    :param pbs_config: Job settings read from the configuration file.
    :param pbs_vars: Environment variables exported to the job.
    :param current_run: Unused here.
    :param n_runs: Unused here.
    """
    # TODO: Add run counter support

    pbs_config['queue'] = pbs_config.get('collate_queue', 'copyq')

    # Collation jobs are (currently) serial
    pbs_config['ncpus'] = 1

    # Tag the job name with '_c', keeping it within 15 characters (13 + 2),
    # the length submit_job truncates job names to.  The job name is
    # optional, so only rename when one is configured.
    jobname = pbs_config.get('jobname')
    if jobname is not None:
        pbs_config['jobname'] = jobname[:13] + '_c'

    # Replace (or remove) the model's walltime and memory requests
    for resource in ('walltime', 'mem'):
        collate_value = pbs_config.get('collate_' + resource)
        if collate_value:
            pbs_config[resource] = collate_value
        else:
            pbs_config.pop(resource, None)

    # NOTE: Using ``__file__`` is slightly volatile
    collate_script = os.path.join(os.path.dirname(__file__), 'payu-collate')
    submit_job(collate_script, pbs_config, pbs_vars)


#---
def call_sweep(config_path, hard_sweep=False):
    """Run the ``payu-sweep`` command.

    :param config_path: Configuration file passed as ``-c``.  It is omitted
        when not set, rather than passing the literal string ``None``.
    :param hard_sweep: If true, pass ``--hard`` to ``payu-sweep``.

    Exits with an error message if the command returns a non-zero status.
    """
    # Build the argument list directly so paths with spaces stay intact
    cmd = ['payu-sweep']
    if config_path:
        cmd.extend(['-c', config_path])
    if hard_sweep:
        cmd.append('--hard')

    rc = sp.call(cmd)
    if rc != 0:
        sys.exit('payu: error: payu-sweep failed with exit code {}.'
                 ''.format(rc))


#---
def call_init(model_name, config_path):
    """Run the ``payu-init`` command.

    :param model_name: Model name passed as ``-m``.  Omitted if not set.
    :param config_path: Configuration file passed as ``-c``.  It is omitted
        when not set, rather than passing the literal string ``None``.

    Exits with an error message if the command returns a non-zero status.
    """
    # Build the argument list directly so paths with spaces stay intact
    cmd = ['payu-init']
    if model_name:
        cmd.extend(['-m', model_name])
    if config_path:
        cmd.extend(['-c', config_path])

    rc = sp.call(cmd)
    if rc != 0:
        sys.exit('payu: error: payu-init failed with exit code {}.'
                 ''.format(rc))


#---
def call_archive(model_name, config_path):
    """Run the ``payu-archive`` command.

    :param model_name: Model name passed as ``-m``.  Omitted if not set.
    :param config_path: Configuration file passed as ``-c``.  It is omitted
        when not set, rather than passing the literal string ``None``.

    Exits with an error message if the command returns a non-zero status.
    """
    # Build the argument list directly so paths with spaces stay intact
    cmd = ['payu-archive']
    if model_name:
        cmd.extend(['-m', model_name])
    if config_path:
        cmd.extend(['-c', config_path])

    rc = sp.call(cmd)
    if rc != 0:
        sys.exit('payu: error: payu-archive failed with exit code {}.'
                 ''.format(rc))


#---
def call_list():
    """Print the names registered in ``model_index`` on one line."""
    supported = model_index.keys()
    message = 'Supported models: {0}'.format(' '.join(supported))
    print(message)


#---
def submit_job(pbs_script, pbs_config, pbs_vars=None):
    """Submit ``pbs_script`` with ``qsub`` using settings from ``pbs_config``.

    Recognised ``pbs_config`` keys:
    - queue (default ``normal``);
    - project (default ``$PROJECT``);
    - walltime, ncpus, mem, jobname and priority;
    - qsub_flags: extra qsub flags as a single shell-style string.

    Each ``pbs_vars`` entry is exported to the job via ``-v``.

    Exits with an error message if ``qsub`` returns a non-zero status.
    """
    # Strip trailing digits from the host name (e.g. 'vayu2' -> 'vayu') so it
    # can be compared with a cluster name.
    hostname = re.sub(r'\d+$', '', socket.gethostname())
    on_vayu = hostname == 'vayu'

    # Build the argument list directly, so values containing spaces (paths
    # in PYTHONPATH, for example) are not split apart.
    cmd = ['qsub', '-q', '{}'.format(pbs_config.get('queue', 'normal'))]

    # Raijin doesn't read $PROJECT, which is required at login.  Only consult
    # the environment when the config has no project (the old eager
    # os.environ['PROJECT'] lookup failed even when one was configured), and
    # omit -P if neither provides one.
    pbs_project = pbs_config.get('project', os.environ.get('PROJECT'))
    if pbs_project:
        cmd += ['-P', '{}'.format(pbs_project)]

    pbs_walltime = pbs_config.get('walltime')
    if pbs_walltime:
        cmd += ['-l', 'walltime={}'.format(pbs_walltime)]

    pbs_ncpus = pbs_config.get('ncpus')
    if pbs_ncpus:
        cmd += ['-l', 'ncpus={}'.format(pbs_ncpus)]

    pbs_mem = pbs_config.get('mem')
    if pbs_mem:
        # The memory resource is requested as 'vmem' on vayu, 'mem' elsewhere
        mem_rname = 'vmem' if on_vayu else 'mem'
        cmd += ['-l', '{}={}'.format(mem_rname, pbs_mem)]

    pbs_jobname = pbs_config.get('jobname')
    if pbs_jobname:
        # TODO: Only truncate when using PBSPro
        cmd += ['-N', pbs_jobname[:15]]

    pbs_priority = pbs_config.get('priority')
    if pbs_priority:
        cmd += ['-p', '{}'.format(pbs_priority)]

    # Working-directory flag: '-wd' on vayu, the 'wd' resource elsewhere
    cmd += ['-wd'] if on_vayu else ['-l', 'wd']

    # TODO: Make this optional
    cmd += ['-j', 'oe']

    if pbs_vars:
        pbs_vstring = ','.join('{}={}'.format(k, v)
                               for k, v in pbs_vars.items())
        cmd += ['-v', pbs_vstring]

    # Append any additional qsub flags (given as one shell-style string)
    pbs_flags_extend = pbs_config.get('qsub_flags')
    if pbs_flags_extend:
        cmd += shlex.split(pbs_flags_extend)

    cmd.append(pbs_script)

    rc = sp.call(cmd)
    if rc != 0:
        sys.exit('payu: error: qsub failed with exit code {}.'.format(rc))


#-------------------------
if __name__ == '__main__':

    # Command-line entry point: parse the arguments and run the subcommand
    payu_parse()
