#!/usr/bin/env python

import os
import sys
import bson
import pprint
import datetime
import pymongo
import radical.utils       as ru
import radical.pilot       as rp
import radical.pilot.utils as rpu


# Database URL used when no '-d' option is given.  The environment variable
# RADICAL_PILOT_DBURL takes precedence over the built-in fallback.
#
# NOTE(review): an earlier, immediately overwritten assignment pointed at
#               'mongodb://user:password@localhost:27017/radicalpilot/' -- the
#               remote host below was the effective default and is kept here;
#               confirm which fallback is actually intended.
_DEFAULT_DBURL = os.environ.get (
    'RADICAL_PILOT_DBURL',
    'mongodb://user:password@ec2-184-72-89-141.compute-1.amazonaws.com:24242/radicalpilot/')

# make sure the URL names a database: an empty or root-only path is replaced
# by the default database name
_DEFAULT_DBURL = ru.Url (_DEFAULT_DBURL)
if  not _DEFAULT_DBURL.path or '/' == _DEFAULT_DBURL.path :
    _DEFAULT_DBURL.path = 'radicalpilot'

# the rest of the script expects the default URL as a plain string
_DEFAULT_DBURL = str(_DEFAULT_DBURL)

# ------------------------------------------------------------------------------
#
def usage (msg=None, noexit=False) :
    """
    Print the help text, optionally preceded by an error message.

    msg   : error message to print before the help text.  If given, the
            process exits with status 1 after printing, regardless of `noexit`.
    noexit: if `True` (and no `msg` is given), return to the caller instead of
            exiting with status 0.

    The help text lists only modes and options this script accepts on the
    command line.
    """

    if  msg :
        print("\n      Error: %s" % msg)

    print("""
      usage       : %s [-d <dburl>] [-m <mode>[,<mode>...]] [-p <pid>] [-u <uid>] [-v] [-h]
      example     : %s -m pid2sid -p 53e124edd1969c73e56f0eb2
                    find ID of the session which created that pilot

      modes :

        help      : show this message
        pid2sid   : find session ID for a pilot id (pid)
        uid2sid   : find session ID for a unit id  (uid)
        resources : list resource information

      options :
        -m <mode> : mode(s) to run, as comma separated list
        -p <pid>  : apply mode to pilot   with given ID
        -u <uid>  : apply mode to unit    with given ID
        -v        : verbose output
        -h        : show this message
        -d <url>  : use given database URL instead of default (%s).

      At least one mode must be specified via '-m'.

""" % (sys.argv[0], sys.argv[0], _DEFAULT_DBURL))

    # an error message always implies a non-zero exit code
    if  msg :
        sys.exit (1)

    if  not noexit :
        sys.exit (0)


# ------------------------------------------------------------------------------
#
def find_sid_by_pid (dbclient, dbname, pid) :
    """
    Search the sessions stored in database `dbname` for the pilot with ID
    `pid`, and print the ID of the session which contains it.

    dbclient: client object which yields the database as `dbclient[dbname]`
    dbname  : name of the database to search
    pid     : pilot ID to look for (compared against `str(doc['_id'])`)

    Progress is written to stdout as one character per session:
      '_' : session skipped, its ID compares greater than `pid`
      '?' : session documents could not be fetched
      '.' : session searched, pilot not found
      '!' : pilot found in this session

    Returns `None`; the result is only reported on stdout.
    """

    db   = dbclient[dbname]
    sids = rpu.get_session_ids (db)

    print("%d sessions" % len(sids))

    # IDs are globally time ordered.  We start looking from last session
    # (request is likely of recent run), but only look at session IDs which have
    # IDs older than the pilot ID

    for sid in sorted (sids, reverse=True) :
        if  sid > pid :
            sys.stdout.write ('_')
            sys.stdout.flush (   )
            continue

        # best effort: a session which cannot be read is marked and skipped.
        # Only catch `Exception`, so that Ctrl-C and sys.exit() still work.
        try :
            docs = rpu.get_session_docs (db, sid)
        except Exception :
            sys.stdout.write ('?')
            sys.stdout.flush (   )
            continue

        for doc in docs['pilot'] :
            if  str(doc['_id']) == pid :
                sys.stdout.write ('!')
                sys.stdout.flush (   )
                print("\nfound pilot %s in session %s" % (pid, sid))
                return

        sys.stdout.write ('.')
        sys.stdout.flush (   )

    print("\npilot %s not found in any session" % pid)


# ------------------------------------------------------------------------------
#
def find_sid_by_uid (dbclient, dbname, uid) :
    """
    Search the sessions stored in database `dbname` for the unit with ID
    `uid`, and print the ID of the session which contains it.

    dbclient: client object which yields the database as `dbclient[dbname]`
    dbname  : name of the database to search
    uid     : unit ID to look for (compared against `str(doc['_id'])`)

    Progress is written to stdout as one character per session:
      '_' : session skipped, its ID compares greater than `uid`
      '?' : session documents could not be fetched
      '.' : session searched, unit not found
      '!' : unit found in this session

    Returns `None`; the result is only reported on stdout.
    """

    db   = dbclient[dbname]
    sids = rpu.get_session_ids (db)

    print("%d sessions" % len(sids))

    # IDs are globally time ordered.  We start looking from last session
    # (request is likely of recent run), but only look at session IDs which have
    # IDs older than the unit ID

    for sid in sorted (sids, reverse=True) :
        if  sid > uid :
            sys.stdout.write ('_')
            sys.stdout.flush (   )
            continue

        # best effort: a session which cannot be read is marked and skipped.
        # Only catch `Exception`, so that Ctrl-C and sys.exit() still work.
        try :
            docs = rpu.get_session_docs (db, sid)
        except Exception :
            sys.stdout.write ('?')
            sys.stdout.flush (   )
            continue

        for doc in docs['unit'] :
            if  str(doc['_id']) == uid :
                sys.stdout.write ('!')
                sys.stdout.flush (   )
                print("\nfound unit %s in session %s" % (uid, sid))
                return

        sys.stdout.write ('.')
        sys.stdout.flush (   )

    print("\nunit %s not found in any session" % uid)


# ------------------------------------------------------------------------------
#
def list_resources (verbose) :
    """
    Print one summary per resource configuration found in the
    `_resource_configs` of a newly created `rp.Session`.

    verbose: if true, additionally print notes, the extended set of
             configuration entries, and either the per-schema settings (for
             configs with a non-empty 'schemas' entry) or the job manager and
             filesystem endpoints (for all others).

    Returns `None`; all output goes to stdout.

    NOTE(review): the session created here is never closed in this function --
                  confirm whether `rp.Session` needs an explicit close().
    """

    entry_fmt = "    %-25s : %s"
    separator = " -----------------------------------------------------------------------------"

    # (label, config key) pairs which are only shown in verbose mode, in
    # output order.  Note that the first label differs from its key.
    verbose_entries = [
        ('spmd variation'         , 'spmd_variation'         ),
        ('python_interpreter'     , 'python_interpreter'     ),
        ('default_remote_workdir' , 'default_remote_workdir' ),
        ('valid_roots'            , 'valid_roots'            ),
        ('pre_bootstrap_1'        , 'pre_bootstrap_1'        ),
        ('pre_bootstrap_2'        , 'pre_bootstrap_2'        ),
        ('bootstrapper'           , 'bootstrapper'           ),
        ('task_launch_method'     , 'task_launch_method'     ),
        ('mpi_launch_method'      , 'mpi_launch_method'      ),
        ('agent_mongodb_endpoint' , 'agent_mongodb_endpoint' ),
        ('forward_tunnel_endpoint', 'forward_tunnel_endpoint'),
    ]

    session = rp.Session ()
    configs = session._resource_configs

    for name in configs :

        cfg = configs[name]

        print(separator)
        print(name)
        print(entry_fmt % ('description', cfg['description']))

        if  verbose :
            print(entry_fmt % ('notes', cfg['notes']))

        # configs with schemas keep the job manager endpoint in the section of
        # the first listed schema, all others keep it at top level
        has_schemas = bool('schemas' in cfg and cfg['schemas'])

        if  has_schemas :
            first_schema = cfg['schemas'][0]
            endpoint     = cfg[first_schema]['job_manager_endpoint']
        else :
            endpoint     = cfg['job_manager_endpoint']

        print(entry_fmt % ('host',  ru.Url(endpoint).host))
        print(entry_fmt % ('rm',    cfg['resource_manager']))
        print(entry_fmt % ('queue', cfg['default_queue']))

        if  verbose :
            for label, key in verbose_entries :
                print(entry_fmt % (label, cfg[key]))

        if  has_schemas :
            print(entry_fmt % ('schemas', cfg['schemas']))

            if  verbose :
                print()
                for schema in cfg['schemas'] :
                    schema_cfg = cfg[schema]
                    print("    %s:" % schema)
                    # schema entries are listed in reverse alphabetical order
                    for key in sorted (schema_cfg.keys (), reverse=True) :
                        print("        %-21s : %s" % (key, schema_cfg[key]))

        elif verbose :
            print(entry_fmt % ('job_manager_endpoint', cfg['job_manager_endpoint']))
            print(entry_fmt % ('filesystem_endpoint' , cfg['filesystem_endpoint']))

        if  verbose :
            print(separator)

        print()


# ------------------------------------------------------------------------------
#
def parse_commandline (argv=None) :
    """
    Parse the command line and return the resulting option set.

    argv: list of argument strings to parse.  `None` (the default) makes
          optparse fall back to `sys.argv[1:]`.

    Returns the `optparse.Values` object with the attributes `url`, `mode`,
    `pid`, `uid`, `verbose` and `help`.  Options which were not given on the
    command line are `None`.

    Positional arguments are not supported: if any are found, `usage()` is
    called with an error message, which ends the process with exit code 1.
    """

    import optparse

    # '-h' is handled by this script itself, to show its own usage message
    parser = optparse.OptionParser (add_help_option=False)

    parser.add_option('-d', '--dburl',   dest='url')
    parser.add_option('-m', '--mode',    dest='mode')
    parser.add_option('-p', '--pid',     dest='pid')
    parser.add_option('-u', '--uid',     dest='uid')
    parser.add_option('-v', '--verbose', dest='verbose', action='store_true')
    parser.add_option('-h', '--help',    dest='help',    action="store_true")

    options, args = parser.parse_args (argv)

    if  args :
        usage ("Too many arguments (%s)" % args)

    return options


# ------------------------------------------------------------------------------
#
if __name__ == '__main__' :

    options = parse_commandline ()

    if  options.help or options.mode in ['help'] :
        usage ()

    if  not options.mode :
        usage ("No mode specified")

    mode    = options.mode
    url     = options.url or _DEFAULT_DBURL
    pid     = options.pid
    uid     = options.uid
    verbose = bool(options.verbose)
    modes   = mode.split (',')

    # reject invalid requests up front, before any database client is created
    for m in modes :
        if  m not in ['pid2sid', 'uid2sid', 'resources', 'help'] :
            usage ("Unsupported mode '%s'" % m)

    if  'pid2sid' in modes and not pid :
        usage ("mode 'pid2sid' needs a pilot ID (-p <pid>)")

    if  'uid2sid' in modes and not uid :
        usage ("mode 'uid2sid' needs a unit ID (-u <uid>)")

    host, port, dbname = ru.split_dburl (url, _DEFAULT_DBURL)[0:3]
    dbclient = pymongo.MongoClient (host=host, port=port)

    print("modes   : %s" % mode)
    print("db url  : %s" % url)

    if  pid :
        print("pid     : %s" % pid)

    if  uid :
        print("uid     : %s" % uid)

    # close the client even if a mode handler raises or exits
    try :
        for m in modes :

            # a single if/elif chain: each mode runs exactly one handler
            if   m == 'pid2sid'   : find_sid_by_pid (dbclient, dbname, pid)
            elif m == 'uid2sid'   : find_sid_by_uid (dbclient, dbname, uid)
            elif m == 'resources' : list_resources  (verbose)
            elif m == 'help'      : usage (noexit=True)

    finally :
        dbclient.close ()

# ------------------------------------------------------------------------------

