#!/usr/bin/env python3

import os
import sys
import pprint

import radical.utils       as ru
import radical.pilot       as rp
import radical.pilot.utils as rpu


# for graphing events, we assign numerical pseudo values to each event.
# Note: make sure that those are translated back into event tags via
# 'set [xy]tics'
#
#   set xtics("lbl1" 1, "lbl2" 2, "lbl3" 3, "lbl4" 4)
#
# Layout: the outer keys are entity types ('session', 'pmgr', 'pilot', 'tmgr',
# 'task'); each inner dict maps an event / state name to the integer used as
# its plot coordinate.  Within each entity type the values are assigned in
# increasing order, with the final states (DONE, CANCELED, FAILED) last.
#
# NOTE(review): no code in the visible part of this file reads this table --
#               confirm whether it is still used before changing the values.
_EVENT_ENCODING = {
        'session': {
            'created'                  :  1
        },
        # no events are encoded for pilot managers
        'pmgr': {
        },
        'pilot' : {
            rp.PMGR_LAUNCHING_PENDING  :  1,
            rp.PMGR_LAUNCHING          :  2,
            rp.PMGR_ACTIVE_PENDING     :  3,
            rp.PMGR_ACTIVE             :  4,
            rp.DONE                    :  5,
            rp.CANCELED                :  6,
            rp.FAILED                  :  7
        },
        # NOTE(review): numbering starts at 0 here but at 1 for all other
        #               entity types, and FAILED / CANCELED are ordered
        #               differently than for 'pilot' and 'task' -- confirm
        #               that this is intended.
        'tmgr': {
            rp.NEW                          :  0,
            rp.TMGR_SCHEDULING_PENDING      :  1,
            rp.TMGR_SCHEDULING              :  2,
            rp.TMGR_STAGING_INPUT_PENDING   :  3,
            rp.TMGR_STAGING_INPUT           :  4,
            rp.TMGR_STAGING_OUTPUT_PENDING  :  5,
            rp.TMGR_STAGING_OUTPUT          :  6,
            rp.DONE                         :  7,
            rp.FAILED                       :  8,
            rp.CANCELED                     :  9
        },
        # task states, including the agent side states which are not listed
        # for 'tmgr'
        'task': {
            rp.NEW                          :  1,
            rp.TMGR_SCHEDULING_PENDING      :  2,
            rp.TMGR_SCHEDULING              :  3,
            rp.TMGR_STAGING_INPUT_PENDING   :  4,
            rp.TMGR_STAGING_INPUT           :  5,
            rp.AGENT_STAGING_INPUT_PENDING  :  6,
            rp.AGENT_STAGING_INPUT          :  7,
            rp.AGENT_SCHEDULING_PENDING     :  8,
            rp.AGENT_SCHEDULING             :  9,
            rp.AGENT_EXECUTING_PENDING      : 10,
            rp.AGENT_EXECUTING              : 11,
            rp.AGENT_STAGING_OUTPUT_PENDING : 12,
            rp.AGENT_STAGING_OUTPUT         : 13,
            rp.TMGR_STAGING_OUTPUT_PENDING  : 14,
            rp.TMGR_STAGING_OUTPUT          : 15,
            rp.DONE                         : 16,
            rp.CANCELED                     : 17,
            rp.FAILED                       : 18
        }
    }

# # common states
# DONE
# CANCELED
# FAILED
#
# # pilot states
# PMGR_LAUNCHING_PENDING
# PMGR_LAUNCHING
# PMGR_ACTIVE_PENDING
# PMGR_ACTIVE
#
# # Task States
# NEW
# PENDING_INPUT_TRANSFER
# TRANSFERRING_INPUT
#
# # TMGR states
# TMGR_SCHEDULING
#
# # PMGR states
# PMGR_LAUNCHING
# PMGR_ACTIVE
#
# PENDING_OUTPUT_TRANSFER
# TRANSFERRING_OUTPUT


# ------------------------------------------------------------------------------
#
def usage(msg=None, noexit=False):
    """
    Print the usage text of this tool and (usually) terminate the process.

    msg:    optional error message, printed ahead of the usage text.  When
            set, the process always exits with status 1.
    noexit: when true (and no 'msg' is given), return to the caller instead
            of exiting with status 0.
    """

    if msg:
        print("\n      Error: %s" % msg)

    # the template references the program name four times
    prog     = sys.argv[0]
    template = """
      usage     : %s -m <mode>[,<mode>] [-d <dburl>] [-s <session>[,<session>]]

      examples  : %s -m stat -d mongodb://user:password@localhost/radicalpilot -s 536afe101d41c83696ea0135
                  %s -m list -d mongodb://user:password@localhost/radicalpilot
                  %s -m tree -s 536afe101d41c83696ea0135,536afe101d41c83696ea0136

      mode(s):

        help    : show this message
        list    : show  a  list   of sessions in the database
        tree    : show  a  tree   of session objects
        dump    : show  a  tree   of session objects, with full details
        sort    : show  a  list   of session objects, sorted by type
        stat    : show statistics of session


      arguments:

        -d      : database URL
        -s      : session id(s)
        -c      : cachedir where <sid>.json caches are kept
        -p      : profile directory where <sid-pid>.prof files are kept
        -f      : filter for listings
        -t      : terminal type for plotting (pdf and/or png, default is both)


      filters:

        The following filters are supported at the moment:

        'pilots = n' : only list sessions with exactly    'n' pilots
        'pilots > n' : only list sessions with more  than 'n' pilots
        'pilots < n' : only list sessions with fewer than 'n' pilots

      Notes:

        The default mode is 'list'.
        Multiple modes can be specified as <mode>,<mode>,...
        For mode 'prof', multiple session ids can be specified as <session>,<session>,...

"""
    print(template % ((prog,) * 4))

    # an error message always ends the process with a failure status, even
    # if 'noexit' was requested
    if msg:
        sys.exit(1)

    if noexit:
        return

    sys.exit(0)


# ------------------------------------------------------------------------------
#
def dump_session(db, dbname, session):
    """
    Print the session id, then hand the session to handle_session() in
    'dump' mode without a document name filter.
    """

    print("session : %s" % session)
    handle_session(db, mode='dump', dbname=dbname, session=session, pname=None)


# ------------------------------------------------------------------------------
#
def tree_session(db, dbname, session):
    """
    Hand the session to handle_session() in 'tree' mode without a document
    name filter.
    """

    handle_session(db, mode='tree', dbname=dbname, session=session, pname=None)


# ------------------------------------------------------------------------------
#
def list_sessions(db, cachedir, filters, url=None):
    """
    Print the ids of the sessions found in the database.

    db:       database handle, passed to rpu.get_session_ids()
    cachedir: cache directory, passed to rpu.get_session_docs()
    filters:  optional filter string.  Supported is 'pilots <op> <n>' with
              <op> being one of '>', '=', '<' (blanks are ignored); any other
              filter is reported and ignored.
    url:      database URL, only used for the 'no matching sessions' message.
              When not given, the module-level 'url' (set when this file is
              run as a script) is used, if it exists.

    Raises ValueError if a 'pilots' filter cannot be parsed.  The filter is
    checked before the database is queried, so a typo fails fast.
    """

    import operator

    compare = None   # comparison to apply to the pilot count, if any
    n       = None   # reference pilot count of the filter

    if filters:

        spec = filters.replace(' ', '')

        if spec.startswith('pilots'):

            ops = {'>': operator.gt,
                   '=': operator.eq,
                   '<': operator.lt}
            try:
                compare = ops[spec[6]]
                n       = int(spec[7:])
            except (IndexError, KeyError, ValueError):
                raise ValueError("invalid filter '%s': expected 'pilots <op> "
                                 "<n>' with <op> one of '>', '=', '<'"
                                 % filters) from None

        else:
            print("ignoring unsupported filter '%s'" % filters)

    sids = rpu.get_session_ids(db)

    if compare:

        f_sids = list()

        for sid in sids:

            docs     = rpu.get_session_docs(sid, cachedir=cachedir)
            n_pilots = len(docs['pilot'])

            if compare(n_pilots, n):
                f_sids.append(sid)

            # progress: session id, its pilot count, matches found so far
            print("%s: %3s (%s)" % (sid, n_pilots, len(f_sids)))

        sids = f_sids


    if not sids:
        if url is None:
            # fall back to the global set by the script entry point
            url = globals().get('url', '<unknown>')
        print('no matching sessions recorded in database at %s' % url)

    else:
        print("Session IDs:")
        for sid in sids:
            print("  %s" % sid)


# ------------------------------------------------------------------------------
def sort_session(session, cachedir):
    """
    Print the '_id' of all documents of a session, grouped by entity type
    (pilot managers, pilots, task managers, tasks).
    """

    docs = rpu.get_session_docs(session, cachedir=cachedir)

    # (headline, key into the session docs) -- printed in this order
    sections = (("pilot managers :", 'pmgr' ),
                ("pilots :",         'pilot'),
                ("task manager",     'tmgr' ),
                ("tasks",            'task' ))

    for headline, etype in sections:
        print(headline)
        for entry in docs[etype]:
            print("  %s" %  entry['_id'])


# ------------------------------------------------------------------------------
def get_stats(session):
    """
    Collect per-pilot statistics of a session via radical.analytics.

    session: passed unchanged to `ra.Session(session, 'radical.pilot')`

    Returns a dict with the keys:

      'pilots'          : dict, pilot uid -> pilot info (see below)
      'n_pilots'        : number of pilots found in the session
      's_pilot'         : the session, filtered for entities of type 'pilot'
      's_tasks'         : the session, filtered for entities of type 'task'
      'session_created' : first timestamp of state 'NEW' in the session

    Each pilot info dict has the keys 'resource', 'cores', 'n_tasks',
    'pilot_states' (list of {'state', 'duration'} dicts), 'task_states'
    (state -> {'num', 'mean', 'std', 'dur'}), 'started', 'finished',
    'cpu_burned', 'cpu_used' and 'utilization'.

    The process exits with status 1 if a task uid is not known.
    """

    # only this function needs these packages, so import them on demand
    import numpy
    import radical.analytics as ra

    n_pilots              = 0
    pilot_stats           = dict()
    pilot_stats['pilots'] = dict()

    s_session = ra.Session(session, 'radical.pilot')
    s_pilots  = s_session.filter(etype='pilot', inplace=False)
    s_tasks   = s_session.filter(etype='task' , inplace=False)

    pilot_stats['s_pilot'] = s_pilots
    pilot_stats['s_tasks'] = s_tasks
    pilot_stats['session_created'] = s_session.timestamps(state='NEW')[0]

    pdd = ra.utils.tabulate_durations(rp.utils.PILOT_DURATIONS_DEBUG)
    tdd = ra.utils.tabulate_durations(rp.utils.TASK_DURATIONS_DEBUG)

    # index of all known tasks by uid
    tasks = {task.uid: task for task in s_tasks.get()}

    for pilot in s_pilots.get():

        n_pilots   += 1
        pid         = pilot.uid
        pilot_info  = dict()

        pilot_info['resource']     = pilot.resources
        pilot_info['cores']        = pilot.description['cores']
        pilot_info['n_tasks']      = 0
        pilot_info['task_states']  = dict()
        pilot_info['pilot_states'] = list()

        for state in pilot.states:
            try:
                for s in pdd:
                    if state == s['Start Timestamp']:
                        dur = pilot.duration(event=rp.utils.PILOT_DURATIONS_DEBUG[s['Duration Name']])
                        pilot_info['pilot_states'].append({'state'   : state,
                                                           'duration': dur})
            except ValueError:
                pass

        # NOTE(review): this iterates over *all* tasks of the session for
        #               every pilot -- tasks are not filtered by the pilot
        #               they ran on.  Confirm whether that is intended.
        for task in s_tasks.get():
            uid                    = str(task.uid)
            pilot_info['n_tasks'] += 1

            if uid not in tasks:
                print('unknonwn task %s' % uid)
                sys.exit(1)

            for state in task.states:
                try:
                    for s in tdd:
                        if s['Start Timestamp'] == state:
                            dur = task.duration(event=rp.utils.TASK_DURATIONS_DEBUG[s['Duration Name']])
                            # accumulate the durations over all tasks: the
                            # entry must only be created once per state, or
                            # the statistics below would only ever see the
                            # duration of the last task
                            entry = pilot_info['task_states'].setdefault(
                                                        state, {'dur': list()})
                            entry['dur'].append(dur)
                except ValueError:
                    pass

        started  = pilot.timestamps(state='NEW')[0]
        finished = pilot.timestamps(state=['DONE', 'FAILED', 'CANCELED'])[0]
        cores    = pilot.description['cores']

        # NOTE(review): a timestamp of exactly 0 counts as 'missing' here
        if not started or not finished or not cores:
            pilot_runtime = 1
        else:
            pilot_runtime = (finished - started) * cores

        pilot_busy = pilot.duration(event=rp.utils.PILOT_DURATIONS_DEBUG['p_agent_runtime'])

        pilot_info['started']     = started
        pilot_info['finished']    = finished
        pilot_info['cpu_burned']  = pilot_runtime
        pilot_info['cpu_used']    = pilot_busy
        pilot_info['utilization'] = pilot_busy * 100 / pilot_runtime

        for s in pilot_info['task_states']:

            array = numpy.array(pilot_info['task_states'][s]['dur'])
            pilot_info['task_states'][s]['num' ] = len       (array)
            pilot_info['task_states'][s]['mean'] = numpy.mean(array)
            pilot_info['task_states'][s]['std' ] = numpy.std (array)
            # the raw durations are dropped once they are summarized
            pilot_info['task_states'][s]['dur' ] = list()


        pilot_stats['pilots'][pid] = pilot_info

    pilot_stats['n_pilots'] = n_pilots

    return pilot_stats


# ------------------------------------------------------------------------------
def stat_session(session, cachedir):
    """
    Print timing and utilization statistics for a session.

    session:  session id, passed to rpu.get_session_docs() and get_stats()
    cachedir: cache directory, passed to rpu.get_session_docs()

    Printed times:
      t_0  : session creation timestamp
      t_h  : session creation until the first task starts executing
      t_x  : first task execution start until the last task finished
      t_xy : first task execution start until the last task execution start

    Values which cannot be computed (no task ever executed or finished, or
    all tasks started at the same time for the task rate) are shown as '-'.
    """

    def _secs(val):
        # format a duration, or a placeholder if it is not available
        if val is None:
            return "%12s" % '-'
        return "%12.1fs" % val

    docs       = rpu.get_session_docs(session, cachedir=cachedir)
    stats      = get_stats(session)

    t_x_start  = None    # first task started executing
    t_y_start  = None    # last  task started executing
    t_x_stop   = None    # last task finished executing

    task_cores = dict()  # sum of core counts from all tasks per pilot

    for task in stats['s_tasks'].get():

        pid = task.cfg.get('pilot', None)
        task_cores[pid] = task_cores.get(pid, 0) \
                        + task.description.get('cores', 1)

        # tasks which never reached execution do not contribute to timings
        exec_stamps = sorted(task.timestamps(state="AGENT_EXECUTING"))
        if not exec_stamps:
            continue

        # compare against None: 0.0 is a valid timestamp
        started = exec_stamps[0]
        if t_x_start is None : t_x_start =     started
        else                 : t_x_start = min(started,  t_x_start)
        if t_y_start is None : t_y_start =     started
        else                 : t_y_start = max(started,  t_y_start)

        final_stamps = sorted(task.timestamps(state=["DONE", "FAILED", "CANCELED"]))
        if final_stamps:
            finished = final_stamps[0]
            if t_x_stop is None : t_x_stop =     finished
            else                : t_x_stop = max(finished, t_x_stop)

    t_0  = stats['session_created']       # session started (absolute time)
    t_h  = None                           # session start 'til first execution
    t_x  = None                           # first start 'til last stop  of task execution
    t_xy = None                           # first start 'til last start of task execution

    if t_x_start is not None:
        t_h  = t_x_start - t_0
        t_xy = t_y_start - t_x_start
        if t_x_stop is not None:
            t_x = t_x_stop - t_x_start

    session_name = docs['session'][0]['_id']

    print("Session Statistics")
    print("------------------")
    print("")
    print("  session: %s"    % session_name)
    print("  pilots : %s"    % stats['n_pilots'])
    print("  t_0    : %12.1fs" % t_0)
    print("  t_h    : %s"    % _secs(t_h))
    print("  t_x    : %s"    % _secs(t_x))
    print("  t_xy   : %s"    % _secs(t_xy))

    for pid in stats['pilots']:

        pilot = stats['pilots'][pid]

        print("  pilot [%s] [%s]"       % (pid, pilot['resource']))
        print("      cores       : %6d" % pilot['cores'])
        print("      tasks       : %6d" % pilot['n_tasks'])
        print("      tasks*cores : %6s" % task_cores.get(pid, '-'))

        # t_xy is zero if all tasks started at the same time (e.g., for a
        # single task): no rate can be derived then
        if t_xy:
            print("      task-rate   : %7.2f/sec" % (pilot['n_tasks'] / t_xy))
        else:
            print("      task-rate   : %7s/sec" % '-')

        print("      utilization : %7.2f%%" % pilot['utilization'])
        print("      pilot states:")
        for ps in pilot['pilot_states']:
            print("        state %-30s : %10.2fs" % (ps['state'], ps['duration']))
        print("      task states :")
        for us  in pilot['task_states']:
            data = pilot['task_states'][us]
            print("        state %-30s : %10.2fs  +/- %8.2fs" % (us, data['mean'], data['std']))


# ------------------------------------------------------------------------------
def handle_session(db, mode, dbname, session, pname):
    """
    For the given db, traverse the collections of a session.

    db:      database handle; must support `drop_collection(name)` for mode
             'remove', and is passed on to handle_coll() otherwise
    mode:    'list', 'remove', or any mode understood by handle_coll()
    dbname:  database name, only used for output
    session: session id, from which the collection names are derived
    pname:   optional document name; when set, the collections are always
             traversed via handle_coll() to act on that document only

    Without 'pname', mode 'list' only prints the collection names and mode
    'remove' drops the collections.  A failure to drop a collection is
    reported and the remaining collections are still handled.
    """

    # FIXME: cache(dir) is not used

    print(" +-- db   %s" % dbname)

    # the session's own collection, plus one collection per entity type
    suffixes = ('', '.pm', '.p', '.um', '.t')
    cnames   = ["%s%s" % (session, suffix) for suffix in suffixes]

    for name in cnames:

        # this used to test the name 'cname', which is not defined in this
        # function -- use 'pname' like the 'remove' branch and handle_coll()
        if mode == 'list' and not pname:
            print(" | +-- coll %s" % name)

        elif mode == 'remove' and not pname:
            try:
                db.drop_collection(name)
            except Exception as e:
                # best effort: report, then continue with the next collection
                print("  could not remove collection %s: %s" % (name, e))
            else:
                print("  removed collection %s" % name)

        else:
            handle_coll(db, mode, name, pname)


# ------------------------------------------------------------------------------
def handle_coll(db, mode, cname, pname):
    """
    For a given collection, traverse all documents.

    db:    database handle, indexed by collection name
    mode:  'list', 'remove', or any mode understood by handle_doc()
    cname: name of the collection; names containing 'indexes' are skipped
    pname: optional document name (compared as string to the document's
           '_id'); when set, only that document is acted upon

    Without 'pname', mode 'list' only prints the document names.  A failure
    to remove a document is reported and the remaining documents are still
    handled.
    """

    if 'indexes' in cname:
        return

    collection = db[cname]
    print(" | +-- coll %s" % cname)

    for doc in collection.find():

        name     = doc['_id']
        selected = (not pname) or (str(name) == str(pname))

        if mode == 'list' and not pname:
            print(" | | +-- doc  %s" % name)

        elif mode == 'remove':
            if selected:
                try:
                    # NOTE(review): newer PyMongo releases dropped
                    #               `Collection.remove` -- confirm against
                    #               the PyMongo version in use.
                    collection.remove(name)
                except Exception as e:
                    # best effort: report, then continue with the next doc
                    print("  could not remove document %s: %s" % (name, e))
                else:
                    print("  removed document %s" % name)

        elif selected:
            handle_doc(mode, doc)


# ------------------------------------------------------------------------------
def handle_doc(mode, doc):
    """
    Print a single document in the layout selected by 'mode':

      'list': the document's keys
      'tree': the document's '_id', followed by its keys
      'dump': the document's '_id', followed by each key with its pretty
              printed value; continuation lines are aligned with the value

    Any other mode prints nothing.
    """

    doc_id = doc['_id']

    if mode in ('tree', 'dump'):
        print(" | | +-- doc  %s" % (doc_id))

    if mode in ('list', 'tree'):
        for key in doc:
            print(" | | | +-- %s" % (key))
        return

    if mode == 'dump':
        # all lines but the first are prefixed to line up with the value
        joiner = '\n' + ' | | | |                '
        for key in doc:
            value_lines = pprint.pformat(doc[key]).split('\n')
            print(" | | | +-- %-10s : %s" % (key, joiner.join(value_lines)))


# ------------------------------------------------------------------------------
#
def parse_commandline(argv=None):
    """
    Parse the command line options of this tool.

    argv: list of argument strings to parse.  When None, argparse uses the
          arguments of the running process (sys.argv[1:]).

    Returns the argparse namespace with the attributes 'session', 'url',
    'mode', 'cachedir', 'profdir', 'filters' and 'term'.  Options which are
    not given are None.
    """

    import argparse
    parser = argparse.ArgumentParser(description='')

    parser.add_argument('-s', '--session',   dest='session')
    parser.add_argument('-d', '--dburl',     dest='url')
    parser.add_argument('-m', '--mode',      dest='mode')
    parser.add_argument('-c', '--cachedir',  dest='cachedir')
    parser.add_argument('-p', '--profdir',   dest='profdir')
    parser.add_argument('-f', '--filter',    dest='filters')
    parser.add_argument('-t', '--terminal',  dest='term')

    return parser.parse_args(argv)


# ------------------------------------------------------------------------------
#
if __name__ == '__main__':

    # NOTE: the names assigned below ('args', 'url', 'cname', ...) are
    #       module-level globals.  Keep them at this level: other functions
    #       in this file may read them as globals.
    args = parse_commandline()

    if args.mode in ['help']:
        usage()

    if not args.mode:
        usage("No mode specified")

    mode     = args.mode
    url      = args.url
    session  = args.session
    filters  = args.filters
    term     = args.term
    cachedir = args.cachedir
    profdir  = args.profdir

    if not url:
        url = os.environ.get('RADICAL_PILOT_DBURL')

    if not url:
        raise Exception('RADICAL_PILOT_DBURL is not set')

    if not term:
        term = "pdf,png"

    if not cachedir:
        cachedir = os.getcwd()

    if not os.path.isdir(cachedir):
        usage("%s is no valid cachedir" % cachedir)

    # validate all requested modes before anything is connected or run, so
    # that a typo in a later mode does not abort a half-finished run
    modes = mode.split(',')
    for m in modes:
        if m not in ['list', 'dump', 'tree', 'sort', 'stat', 'help']:
            usage("Unsupported mode '%s'" % m)

    print("modes   : %s" % mode)
    print("db url  : %s" % url)
    print("cachedir: %s" % cachedir)

    # 'list' and 'help' work on the database as a whole, all other modes
    # need a session.  Check each mode individually: 'mode' may be a comma
    # separated list of modes.
    session_modes = [m for m in modes if m not in ['list', 'help']]

    if session_modes and not session:
        usage("mode %s needs a session id specified" % ','.join(session_modes))
    elif session and not session_modes:
        usage("invalid session parameter '%s' on mode '%s'" % (session, mode))
    else:
        print("session : %s" % session)

    mongo, db, dbname, cname, pname = ru.mongodb_connect(str(url))

    # make sure the connection gets closed, also on errors and on exit
    try:
        for m in modes:

            if   m == 'list' : list_sessions(db, cachedir, filters)
            elif m == 'tree' : tree_session (db, dbname, session)
            elif m == 'dump' : dump_session (db, dbname, session)
            elif m == 'sort' : sort_session (session, cachedir)
            elif m == 'stat' : stat_session (session, cachedir)
            elif m == 'help' : usage(noexit=True)

    finally:
        mongo.close()

# ------------------------------------------------------------------------------

