#!/usr/bin/env python

"""
.. module:: radical.pilot-profiler
   :platform: Unix
   :synopsis: A simple runtime profiler for RADICAL-Pilot.

.. moduleauthor:: Ole Weidner <ole.weidner@rutgers.edu>
"""

__copyright__ = "Copyright 2013-2014, http://radical.rutgers.edu"
__license__   = "MIT"

import sys
import numpy
import pprint
import logging
import optparse
import radical.pilot
import datetime, time
import matplotlib.pyplot as plt

# Running count of pilots; get_data() increments it once per pilot it reads.
N_PILOTS = 0

#-----------------------------------------------------------------------------

def get_data(logger, session_id, database_url, database_name):
    """Collect pilot and unit timing records for a session.

    Currently only supports the execution pattern where the mapping of
    Unit Manager to Pilot Manager is one-to-one and only one pilot is
    launched per Pilot Manager: every unit is attributed to the first
    pilot listed by its unit manager.

    The module-level N_PILOTS counter is incremented once per pilot read.

    :param logger:        logger used for progress and error messages.
    :param session_id:    uid of the session to inspect.
    :param database_url:  URL of the MongoDB database.
    :param database_name: name of the database.
    :returns: ``(pmgr_pilot_data, umgr_task_data)``.  Each pilot record is
              ``[uid, cores, submission_time, start_time, stop_time]``;
              each unit record is
              ``[pilot_uid, submission_time, start_time, stop_time]``.

    On a radical.pilot.PilotException the error is logged and the process
    exits with status 255.
    """
    global N_PILOTS

    pmgr_pilot_data = []
    umgr_task_data = []

    try:
        session = radical.pilot.Session(
            session_uid   = session_id,
            database_url  = database_url,
            database_name = database_name
        )
        logger.info("Connected to session %s", session)

        #----------------------------------------------------------
        # get pilot data
        for pmgr in session.get_pilot_managers():
            for pilot in pmgr.get_pilots():
                pmgr_pilot_data.append([
                    pilot.uid,
                    pilot.description['cores'],
                    pilot.submission_time,
                    pilot.start_time,
                    pilot.stop_time,
                ])
                N_PILOTS += 1

        #---------------------------------------------------------
        # get task data
        for umgr in session.get_unit_managers():
            pilot_ids = umgr.list_pilots()

            # Units are attributed to the first pilot of their manager; a
            # manager without pilots has nothing to attribute them to.
            if not pilot_ids:
                logger.warning("Unit manager has no pilots; skipping its units")
                continue

            for task in umgr.get_units():
                umgr_task_data.append([
                    pilot_ids[0],
                    task.submission_time,
                    task.start_time,
                    task.stop_time,
                ])

    # 'as' form is valid on Python 2.6+ and Python 3 (the comma form is not)
    except radical.pilot.PilotException as ex:
        logger.error(ex)
        sys.exit(255)

    return pmgr_pilot_data, umgr_task_data
    
#-----------------------------------------------------------------------------

def calculate_pilot_runtimes(pmgr_pilot_data, umgr_task_data):
    """Compute the runtime of every pilot.

    A pilot's runtime is the span, in seconds, between the earliest start
    time and the latest stop time among the units whose first field equals
    the pilot's uid.

    Returns a list of ``[pilot_uid, runtime]`` pairs, one per pilot and in
    the order of ``pmgr_pilot_data``.
    """
    epoch = datetime.datetime(1970, 1, 1)
    results = []

    for record in pmgr_pilot_data:
        uid = record[0]

        # units that were attributed to this pilot
        owned = [task for task in umgr_task_data if uid == task[0]]

        started = numpy.array([(task[2] - epoch).total_seconds() for task in owned])
        stopped = numpy.array([(task[3] - epoch).total_seconds() for task in owned])

        results.append([uid, stopped.max() - started.min()])

    return results

#-----------------------------------------------------------------------------

def ten_round(x, base=10):
    """Round ``x`` to the nearest multiple of ``base`` and return it as an int."""
    steps = round(float(x) / base)
    return int(base * steps)

#-----------------------------------------------------------------------------

def plot_avg_cu_runtime(session_id, pmgr_pilot_data, umgr_task_data):
    """Plot the mean unit execution time of every pilot.

    For each pilot, the execution times (stop time minus start time, in
    seconds) of the units whose first field equals the pilot uid are
    averaged.  The figure is saved as ``avg-cu-runtimes-<session_id>.png``.

    :param session_id:      session id, used in the title and file name.
    :param pmgr_pilot_data: pilot records; field 0 is the uid, field 1 the
                            pilot size shown on the x axis.
    :param umgr_task_data:  unit records; field 0 is the pilot uid, fields
                            2 and 3 the start and stop times.
    """
    plt_title = "Average CU execution times. Session ID: " + session_id
    figure_name = "avg-cu-runtimes-" + session_id + ".png"

    # Size everything from the data that was passed in rather than from a
    # module-level counter, so x and y always have matching lengths.
    n_pilots = len(pmgr_pilot_data)

    # populating pilot_sizes and cu_runtimes
    #--------------------------------------------
    pilot_sizes = [pilot[1] for pilot in pmgr_pilot_data]
    cu_runtimes = []
    for pilot in pmgr_pilot_data:
        cu_times = [(unit[3] - unit[2]).total_seconds()
                    for unit in umgr_task_data if pilot[0] == unit[0]]
        cu_runtimes.append(numpy.mean(cu_times))
    #--------------------------------------------
    t_max = max(cu_runtimes)

    # calculating y ticks
    #--------------------------------------------
    y_interval = ten_round(t_max / n_pilots)
    if y_interval < 1:
        y_interval = int(t_max)
    y_ticks = [i * y_interval for i in range(n_pilots)]
    #--------------------------------------------

    # define plot size in inches (width, height) & resolution (DPI)
    plt.figure(figsize=(12, 8), dpi=100)

    # define font size
    plt.rc("font", size=10)

    # plot data: pilots sit at x = 1..n_pilots
    x = numpy.arange(1, n_pilots + 1, 1)
    y = cu_runtimes

    plt.plot(x, y, linestyle="solid", marker="o", color="red", label = "Average CU execution times")

    # configure x axes: set tick positions and pilot-size labels together so
    # each label stays attached to the data point it describes
    plt.xlim(0.5, n_pilots + 0.5)
    plt.xticks(x, pilot_sizes, size='small')

    # configure y axes
    plt.ylim(0.0, (int(t_max) + y_interval))
    plt.yticks(y_ticks)

    plt.xlabel("Pilot size")
    plt.ylabel("Time in secs", size=10)
    plt.title(plt_title, size=10)

    # print each y value next to its data point
    for xpoint, ypoint in zip(x, y):
        plt.annotate(round(ypoint,2), ((xpoint-0.1),(ypoint-y_interval/4)), ha='right', va='bottom', bbox=dict(fc='none', ec='none'))

    plt.legend()
    plt.savefig( figure_name )

#-----------------------------------------------------------------------------

def plot_cu_throughput(session_id, pmgr_pilot_data, umgr_task_data,  pilot_runtimes):
    """Plot the unit throughput of every pilot.

    Throughput is the number of units whose first field equals the pilot
    uid, divided by that pilot's runtime from ``pilot_runtimes``.  The
    figure is saved as ``cu-throughput-<session_id>.png``.

    :param session_id:      session id, used in the title and file name.
    :param pmgr_pilot_data: pilot records; field 0 is the uid, field 1 the
                            pilot size shown on the x axis.
    :param umgr_task_data:  unit records; field 0 is the pilot uid.
    :param pilot_runtimes:  ``[pilot_uid, runtime]`` pairs.
    """
    plt_title = "Throughput of CUs. Session ID: " + session_id
    figure_name = "cu-throughput-" + session_id + ".png"

    # Size everything from the data that was passed in rather than from a
    # module-level counter.
    n_pilots = len(pmgr_pilot_data)

    # populating pilot_sizes and cu_throughput
    #--------------------------------------------
    pilot_sizes = [pilot[1] for pilot in pmgr_pilot_data]
    cu_throughput = []
    for pilot in pmgr_pilot_data:
        count = sum(1 for unit in umgr_task_data if pilot[0] == unit[0])
        for entry in pilot_runtimes:
            if pilot[0] == entry[0]:
                cu_throughput.append( count / entry[1] )
    #--------------------------------------------
    t_max = max(cu_throughput)

    # calculating y ticks
    #--------------------------------------------
    y_interval = ten_round(t_max / n_pilots)
    if y_interval < 1:
        y_interval = int(t_max)
    y_ticks = [i * y_interval for i in range(n_pilots)]
    #--------------------------------------------

    # define plot size in inches (width, height) & resolution (DPI)
    plt.figure(figsize=(12, 8), dpi=100)

    # define font size
    plt.rc("font", size=10)

    # plot data: pilots sit at x = 1..n_pilots
    x = numpy.arange(1, n_pilots + 1, 1)
    y = cu_throughput

    plt.plot(x, y, linestyle="solid", marker="o", color="red", label = "CU throughput")

    # configure x axes: set tick positions and pilot-size labels together so
    # each label stays attached to the data point it describes
    plt.xlim(0.5, n_pilots + 0.5)
    plt.xticks(x, pilot_sizes, size='small')

    # configure y axes
    plt.ylim(0.0, (int(t_max) + y_interval))
    plt.yticks(y_ticks)

    plt.xlabel("Pilot size")
    # the plotted quantity is a rate (units divided by runtime), not a time
    plt.ylabel("CUs per second", size=10)
    plt.title(plt_title, size=10)

    # print each y value next to its data point
    for xpoint, ypoint in zip(x, y):
        plt.annotate(round(ypoint,2), ((xpoint-0.1),(ypoint-y_interval/4)), ha='right', va='bottom', bbox=dict(fc='none', ec='none'))

    plt.legend()
    plt.savefig( figure_name )

#-----------------------------------------------------------------------------

def plot_pilot_queue_time(session_id, pmgr_pilot_data, umgr_task_data):
    """Plot the time every pilot spent between submission and start.

    The queue time is field 3 minus field 2 of each pilot record, in
    seconds.  The figure is saved as ``pilot-queue-times-<session_id>.png``.

    :param session_id:      session id, used in the title and file name.
    :param pmgr_pilot_data: pilot records; field 1 is the pilot size shown
                            on the x axis, fields 2 and 3 the submission and
                            start times.
    :param umgr_task_data:  unused; kept so all plot functions share a
                            similar signature.
    """
    plt_title = "Pilot times from submission to start. Session ID: " + session_id
    figure_name = "pilot-queue-times-" + session_id + ".png"

    # Size everything from the data that was passed in rather than from a
    # module-level counter, so x and y always have matching lengths.
    n_pilots = len(pmgr_pilot_data)

    # populating pilot_sizes and pilot_qtimes
    #--------------------------------------------
    pilot_sizes = [pilot[1] for pilot in pmgr_pilot_data]
    pilot_qtimes = [(pilot[3] - pilot[2]).total_seconds()
                    for pilot in pmgr_pilot_data]
    #--------------------------------------------
    t_max = max(pilot_qtimes)

    # calculating y ticks
    #--------------------------------------------
    y_interval = ten_round(t_max / n_pilots)
    if y_interval < 1:
        y_interval = int(t_max)
    y_ticks = [i * y_interval for i in range(n_pilots)]
    #--------------------------------------------

    # define plot size in inches (width, height) & resolution (DPI)
    plt.figure(figsize=(12, 8), dpi=100)

    # define font size
    plt.rc("font", size=10)

    # plot data: pilots sit at x = 1..n_pilots
    x = numpy.arange(1, n_pilots + 1, 1)
    y = pilot_qtimes

    plt.plot(x, y, linestyle="solid", marker="o", color="red", label = "pilot queue times")

    # configure x axes: set tick positions and pilot-size labels together so
    # each label stays attached to the data point it describes
    plt.xlim(0.5, n_pilots + 0.5)
    plt.xticks(x, pilot_sizes, size='small')

    # configure y axes
    plt.ylim(0.0, (int(t_max) + y_interval))
    plt.yticks(y_ticks)

    plt.xlabel("Pilot size")
    plt.ylabel("Time in secs", size=10)
    plt.title(plt_title, size=10)

    # print each y value next to its data point
    for xpoint, ypoint in zip(x, y):
        plt.annotate(round(ypoint,2), ((xpoint-0.1),(ypoint-y_interval/4)), ha='right', va='bottom', bbox=dict(fc='none', ec='none'))

    plt.legend()
    plt.savefig( figure_name )

#-----------------------------------------------------------------------------

def plot_pilot_runtime(session_id, pmgr_pilot_data, umgr_task_data, pilot_runtimes):
    """Plot the runtime of every pilot.

    Runtimes are looked up in ``pilot_runtimes`` by pilot uid.  The figure
    is saved as ``pilot-runtimes-<session_id>.png``.

    :param session_id:      session id, used in the title and file name.
    :param pmgr_pilot_data: pilot records; field 0 is the uid, field 1 the
                            pilot size shown on the x axis.
    :param umgr_task_data:  unused; kept so all plot functions share a
                            similar signature.
    :param pilot_runtimes:  ``[pilot_uid, runtime]`` pairs.
    """
    plt_title = "Pilot runtimes. Session ID: " + session_id
    figure_name = "pilot-runtimes-" + session_id + ".png"

    # Size everything from the data that was passed in rather than from a
    # module-level counter.
    n_pilots = len(pmgr_pilot_data)

    # populating pilot_sizes and pilot_rtimes
    #--------------------------------------------
    pilot_sizes = [pilot[1] for pilot in pmgr_pilot_data]
    pilot_rtimes = []
    for pilot in pmgr_pilot_data:
        for entry in pilot_runtimes:
            if pilot[0] == entry[0]:
                pilot_rtimes.append( entry[1] )
    #--------------------------------------------
    t_max = max(pilot_rtimes)

    # calculating y ticks
    #--------------------------------------------
    y_interval = ten_round(t_max / n_pilots)
    if y_interval < 1:
        y_interval = int(t_max)
    y_ticks = [i * y_interval for i in range(n_pilots)]
    #--------------------------------------------

    # define plot size in inches (width, height) & resolution (DPI)
    plt.figure(figsize=(12, 8), dpi=100)

    # define font size
    plt.rc("font", size=10)

    # plot data: pilots sit at x = 1..n_pilots
    x = numpy.arange(1, n_pilots + 1, 1)
    y = pilot_rtimes

    plt.plot(x, y, linestyle="solid", marker="o", color="red", label = "pilot runtimes")

    # configure x axes: set tick positions and pilot-size labels together so
    # each label stays attached to the data point it describes
    plt.xlim(0.5, n_pilots + 0.5)
    plt.xticks(x, pilot_sizes, size='small')

    # configure y axes
    plt.ylim(0.0, (int(t_max) + y_interval))
    plt.yticks(y_ticks)

    plt.xlabel("Pilot size")
    plt.ylabel("Time in secs", size=10)
    plt.title(plt_title, size=10)

    # print each y value next to its data point
    for xpoint, ypoint in zip(x, y):
        plt.annotate(round(ypoint,2), ((xpoint-0.1),(ypoint-y_interval/4)), ha='right', va='bottom', bbox=dict(fc='none', ec='none'))

    plt.legend()
    plt.savefig( figure_name )

#-----------------------------------------------------------------------------

def setup_logger():
    """Configures the logging facility.

    Since this is a command line tool, we simply log to the console.

    Returns the 'radical.pilot.profiler' logger at INFO level.  The console
    handler is attached only if the logger has no handler yet, so calling
    this function repeatedly does not duplicate every log line.
    """
    logger = logging.getLogger('radical.pilot.profiler')
    logger.setLevel(logging.INFO)

    # getLogger() hands back the same object for the same name, so a second
    # unconditional addHandler() would make each message print twice.
    if not logger.handlers:
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)

        format_string = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        ch.setFormatter(logging.Formatter(format_string))

        logger.addHandler(ch)

    return logger

#-----------------------------------------------------------------------------

def parse_commandline():
    """Parse and validate the command line (read from sys.argv).

    Recognised options:

      -d / --mongodb-url    URL of the MongoDB database (required)
      -s / --session-id     id of the session to inspect (required)
      -n / --database-name  database name (default: 'radicalpilot')

    :returns: the optparse options object with the attributes
              ``database_url``, ``database_name`` and ``session_id``.

    A missing required option is reported through ``parser.error()``,
    which prints the message and exits.
    """
    usage = "usage: %prog -d -s [-n]"
    parser = optparse.OptionParser(usage=usage)

    parser.add_option('-d', '--mongodb-url',
                      metavar='URL',
                      dest='database_url',
                      help='specifies the url of the MongoDB database.')

    # metavar is what --help prints as the value placeholder; this option
    # takes a name, not a URL
    parser.add_option('-n', '--database-name',
                      metavar='NAME',
                      dest='database_name',
                      default='radicalpilot',
                      help='specifies the name of the database [default: %default].')

    parser.add_option('-s', '--session-id',
                      metavar='SID',
                      dest='session_id',
                      help='specifies the id of the session you want to inspect.')

    # positional arguments are accepted but ignored
    (options, args) = parser.parse_args()

    if options.database_url is None:
        parser.error("You must define MongoDB URL (-d/--mongodb-url). Try --help for help.")
    elif options.database_name is None:
        # defensive only: the option has a default, so this is not expected
        parser.error("You must define a database name (-n/--database-name). Try --help for help.")
    elif options.session_id is None:
        parser.error("You must define a session id (-s/--session-id). Try --help for help.")

    return options

#-----------------------------------------------------------------------------
#
if __name__ == "__main__":

    # read the command line first, then bring up console logging
    options = parse_commandline()
    logger  = setup_logger()

    # fetch the pilot and unit records of the requested session
    pmgr_pilot_data, umgr_task_data = get_data(
        logger,
        options.session_id,
        options.database_url,
        options.database_name,
    )

    pilot_runtimes = calculate_pilot_runtimes(pmgr_pilot_data, umgr_task_data)

    # generating plots, one PNG file per plot
    sid = options.session_id
    plot_pilot_runtime(sid, pmgr_pilot_data, umgr_task_data, pilot_runtimes)
    plot_pilot_queue_time(sid, pmgr_pilot_data, umgr_task_data)
    plot_cu_throughput(sid, pmgr_pilot_data, umgr_task_data, pilot_runtimes)
    plot_avg_cu_runtime(sid, pmgr_pilot_data, umgr_task_data)

    sys.exit(0)


