#!/usr/bin/env python
# coding=utf-8
# Copyright (C) Duncan Macleod (2014)
#
# This file is part of GWSumm.
#
# GWSumm is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GWSumm is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GWSumm.  If not, see <http://www.gnu.org/licenses/>.

"""Generate Omega scans of the loudest event triggers.

This tool can be used to find the loudest N triggers of any ETG that follows
the DetChar trigger conventions set out in LIGO-T1300468, sort them by
one of a large number of columnar/functional statistics, and generate a
HTCondor Directed Acyclic Graph (DAG) to generate an Omega scan report of
each.
"""

from __future__ import print_function

import os
import argparse

from glue.datafind import GWDataFindHTTPConnection
from glue.lal import Cache
from glue import pipeline

from gwpy.segments import (Segment, SegmentList, DataQualityFlag)
from gwpy.table.lsctables import SnglBurstTable
from gwpy.time import to_gps

import gwtrigfind

from gwsumm import __version__

__author__ = 'Duncan Macleod <duncan.macleod@ligo.org>'


# -----------------------------------------------------------------------------
# Utilities

class OmegaPipelineJob(pipeline.CondorDAGJob):
    """Job representing a configurable instance of omega scan.

    Parameters
    ----------
    universe : `str`
        Condor universe, passed to `pipeline.CondorDAGJob`
    executable : `str`
        path of the executable, passed to `pipeline.CondorDAGJob`
    command : `str`, optional
        sub-command inserted ahead of all other arguments in the
        submit file, default: ``'scan'``
    tag : `str`, optional
        prefix for the names of the submit, log, stderr, and stdout files
    subdir : `str`, optional
        directory in which to write the submit file
    logdir : `str`, optional
        directory in which to write the log, stderr, and stdout files
    **cmds
        all other keyword arguments are added to the job as Condor
        commands; ``getenv='True'`` is set unless given explicitly
    """
    # Condor macros used to give each process its own log/err/out files
    logtag = '$(cluster)-$(process)'

    def __init__(self, universe, executable, command='scan', tag='omega_scan',
                 subdir=None, logdir=None, **cmds):
        pipeline.CondorDAGJob.__init__(self, universe, executable)
        if subdir:
            subdir = os.path.abspath(subdir)
            self.set_sub_file(os.path.join(subdir, '%s.sub' % (tag)))
        if logdir:
            logdir = os.path.abspath(logdir)
            self.set_log_file(os.path.join(
                logdir, '%s-%s.log' % (tag, self.logtag)))
            self.set_stderr_file(os.path.join(
                logdir, '%s-%s.err' % (tag, self.logtag)))
            self.set_stdout_file(os.path.join(
                logdir, '%s-%s.out' % (tag, self.logtag)))
        cmds.setdefault('getenv', 'True')
        # dict.items() works on both python2 and python3 (iteritems does not)
        for key, val in cmds.items():
            self.add_condor_cmd(key, val)
        # add sub-command option
        self._command = command

    def add_opt(self, opt, value=''):
        # stringify the value so that callers can pass numbers directly
        pipeline.CondorDAGJob.add_opt(self, opt, str(value))
    add_opt.__doc__ = pipeline.CondorDAGJob.add_opt.__doc__

    def set_command(self, command):
        """Set the sub-command for this job.
        """
        self._command = command

    def get_command(self):
        """Return the sub-command for this job.
        """
        return self._command

    def write_sub_file(self):
        """Write the submit file, then insert the sub-command into it.

        The file written by `pipeline.CondorDAGJob.write_sub_file` is read
        back and re-written with the sub-command placed at the start of
        the ``arguments`` string.
        """
        pipeline.CondorDAGJob.write_sub_file(self)
        if self.get_command():
            with open(self.get_sub_file(), 'r') as f:
                sub = f.read()
            sub = sub.replace('arguments = "', 'arguments = " %s'
                              % self.get_command())
            with open(self.get_sub_file(), 'w') as f:
                f.write(sub)


class OmegaPipelineDAGNode(pipeline.CondorDAGNode):
    """DAG node for an `OmegaPipelineJob`.
    """
    def get_cmd_line(self):
        """Return the node command line, prefixed by the job sub-command.

        If the job has no sub-command the command line from
        `pipeline.CondorDAGNode.get_cmd_line` is returned unmodified.
        """
        arguments = pipeline.CondorDAGNode.get_cmd_line(self)
        subcommand = self.job().get_command()
        if not subcommand:
            return arguments
        return '%s %s' % (subcommand, arguments)


def which(program):
    """Find full path of executable program

    Parameters
    ----------
    program : `str`
        name of a program to search for on the ``PATH``, or a path
        (anything with a directory component) to check directly

    Returns
    -------
    path : `str` or `None`
        the first matching path that is a regular file with execute
        permission, or `None` if no match was found
    """
    def is_exe(fpath):
        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

    # a path was given explicitly: check it, but do not search
    if os.path.dirname(program):
        return program if is_exe(program) else None

    # fall back to the OS default search path if PATH is not set,
    # rather than failing with a KeyError
    search_path = os.environ.get("PATH", os.defpath)
    for path in search_path.split(os.pathsep):
        exe_file = os.path.join(path, program)
        if is_exe(exe_file):
            return exe_file
    return None


def lal_cache_to_wcache(cache):
    """Summarise a cache of frame files as one entry per directory.

    Parameters
    ----------
    cache : iterable
        cache entries, each with ``path``, ``segment``, ``observatory``,
        and ``description`` attributes

    Returns
    -------
    wcachedict : `dict`
        mapping of directory name to a `list` of
        ``[observatory, description, start, end, duration, directory]``,
        where ``start`` and ``end`` are the earliest start and latest end
        (as `int`) of all entries in that directory, and the remaining
        values are taken from the first entry seen for that directory
    """
    wcachedict = {}
    for e in cache:
        dirname = os.path.dirname(e.path)
        start = int(e.segment[0])
        end = int(e.segment[1])
        try:
            row = wcachedict[dirname]
        except KeyError:
            wcachedict[dirname] = [e.observatory, e.description, start, end,
                                   int(abs(e.segment)), dirname]
        else:
            # widen the span, keeping integer times so that every row is
            # stored with a consistent type
            row[2] = min(row[2], start)
            row[3] = max(row[3], end)
    return wcachedict


# -----------------------------------------------------------------------------
# Parse command line

class GPSAction(argparse.Action):
    """Argument action that stores its value after conversion by `to_gps`.
    """
    def __call__(self, parser, namespace, values, option_string=False):
        # numeric input is handed to `to_gps` as a float, anything that
        # cannot be converted is handed over unchanged
        try:
            gps = float(values)
        except (TypeError, ValueError):
            gps = values
        setattr(namespace, self.dest, to_gps(gps))


# use the module docstring as the description shown by --help
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('-V', '--version', action='version', version=__version__)
parser.add_argument('-g', '--trigger-generator', action='store', type=str,
                    default='Omicron', help="name of event trigger generator, "
                                            "default: %(default)s")

topts = parser.add_argument_group('trigger options')
topts.add_argument('-r', '--rank-by', action='store', type=str, default='snr',
                   help="name of column by which to rank events, "
                        "default: %(default)s")
topts.add_argument('-m', '--minimum-rank', action='store', type=float,
                   default=8,
                   help="exclude events below some rank, "
                        "default: %(default)s")
topts.add_argument('-n', '--number', action='store', type=int, default=10,
                   help="number of scans to run, default: %(default)s")
topts.add_argument('-t', '--min-delta-t', action='store', type=float,
                   metavar='dT', default=5,
                   help="treat multiple events within dT as the same, "
                        "default: %(default)s")

sopts = parser.add_argument_group('segment options')
sopts.add_argument('-f', '--flag', action='store', type=str,
                   default='DMT-DC_READOUT_LOCKED:1',
                   help='data-quality flag with which to restrict triggers, '
                        'default: %(default)s')
sopts.add_argument('-U', '--segment-url', action='store', type=str,
                   default='https://segdb-er.ligo.caltech.edu',
                   help='URL of segment database, default: %(default)s')

oopts = parser.add_argument_group('Omega options')
oopts.add_argument('-c', '--config-file', action='store', type=str,
                   help='path to omega scan configuration file')

copts = parser.add_argument_group('Condor options')
copts.add_argument('-u', '--universe', action='store', type=str,
                   default='vanilla',
                   help="Condor universe to use, default: %(default)s")
# the default is resolved from PATH at parse time, and is None if
# wpipeline cannot be found
copts.add_argument('-e', '--executable', action='store', type=str,
                   default=which('wpipeline'),
                   help='path to Omega pipeline executable, '
                        'default: %(default)s')
copts.add_argument('-a', '--file-tag', action='store', type=str,
                   default=os.path.basename(__file__),
                   help='file tag, default: %(default)s')
copts.add_argument('-o', '--output-dir', action='store', type=str,
                   default=os.curdir,
                   help='output directory, default: %(default)s')
copts.add_argument('-l', '--log-dir', action='store', type=str,
                   default=os.getenv('LOGDIR', None),
                   help='Condor log file directory, default: %(default)s')

parser.add_argument('channel', action='store', type=str,
                    help="name of channel to search")
parser.add_argument('gpsstart', action=GPSAction,
                    help="GPS start time of search")
parser.add_argument('gpsend', action=GPSAction,
                    help="GPS end time of search")

args = parser.parse_args()

# get directories
if args.log_dir is None:
    # neither --log-dir nor $LOGDIR was given: fail with a usage message
    # rather than an obscure error from os.path.abspath(None)
    parser.error("no Condor log directory given, please use --log-dir or "
                 "set the LOGDIR environment variable")
outdir = os.path.abspath(args.output_dir)
logdir = os.path.join(outdir, 'logs')
htclogdir = os.path.abspath(args.log_dir)
for d in [outdir, logdir, htclogdir]:
    if not os.path.isdir(d):
        os.makedirs(d)

# get channel information
ifo = args.channel[:2]
site = ifo[0]
span = Segment(args.gpsstart, args.gpsend)

# find wpipeline
requested_executable = args.executable or 'wpipeline'
args.executable = which(requested_executable)
if args.executable is None:
    # do not go on to build a DAG around a missing executable
    parser.error("cannot find executable %r" % requested_executable)

# find config file
if not args.config_file:
    # choose the default configuration by the start time of the search
    if args.gpsstart < 1072569616:
        epoch = 'ER4'
    elif args.gpsstart < 1085616016:
        epoch = 'ER5'
    else:
        epoch = 'ER6'
    args.config_file = os.path.expanduser('~detchar/etc/omega/{0}/{1}-{1}1_R-'
                                          'selected.txt'.format(epoch, site))

# -----------------------------------------------------------------------------
# Get segments

# make sure the flag name carries the interferometer prefix
flag_prefix = '%s:' % ifo
if not args.flag.startswith(flag_prefix):
    args.flag = flag_prefix + args.flag

print('Downloading segments for %s' % args.flag)
segments = DataQualityFlag.query(
    args.flag, args.gpsstart, args.gpsend, url=args.segment_url)

# report the number of active segments and their total duration
active_segments = segments.active
print('%d segments found, %ds livetime'
      % (len(active_segments), int(abs(active_segments))))

# -----------------------------------------------------------------------------
# Read triggers

cache = gwtrigfind.find_trigger_files(
    args.channel, args.trigger_generator, args.gpsstart, args.gpsend,
    verbose=True)

# check files
nfiles = len(cache)
if nfiles == 0:
    raise RuntimeError("No trigger files found")
print("%d files found" % nfiles)


# read triggers
def filt(row):
    """Return `True` if the peak time of ``row`` is in [gpsstart, gpsend)
    """
    return args.gpsstart <= float(row.get_peak()) < args.gpsend


print('Reading triggers...', end=' ')
trigs = SnglBurstTable.read(cache, format='ligolw', filt=filt)
if len(trigs) == 0:
    raise RuntimeError("No triggers found")
# NOTE(review): presumably this keeps only those triggers inside the
# active segments of the flag -- confirm against `vetoed`
trigs = trigs.vetoed(segments.active)
ntrigs = len(trigs)
print("%d triggers found" % ntrigs)

# cannot scan more triggers than were found
args.number = min(args.number, ntrigs)

# -----------------------------------------------------------------------------
# Find triggers

# normalise the column name once, so that sorting and ranking always
# read the same attribute
rank_column = args.rank_by.lower()

# loudest first
trigs.sort(key=lambda row: float(getattr(row, rank_column)), reverse=True)

times = []
snrs = []
for trig in trigs:
    if len(times) >= args.number:
        break
    rank = float(getattr(trig, rank_column))
    # triggers are sorted, so everything after this one is also too quiet
    if rank < args.minimum_rank:
        break
    peak = float(trig.get_peak())
    # skip anything within dT of a (louder) time already selected
    if any(abs(peak - t2) < args.min_delta_t for t2 in times):
        continue
    times.append(peak)
    snrs.append(rank)

print('Found the following scan times:')
for t, snr in zip(times, snrs):
    print('    %.3f (%s = %.2f)' % (t, args.rank_by, snr))

# -----------------------------------------------------------------------------
# Build frame cache

print('Generating frame file cache...')

cache = Cache()

# collect the unique frame types named in the Omega configuration
frametypes = set()
with open(args.config_file, 'r') as f:
    for line in f:
        if 'frameType' in line:
            # the value is the last whitespace-delimited token on the line;
            # split() (rather than split(' ')) copes with tabs and
            # trailing whitespace
            # NOTE(review): eval() will execute arbitrary code found in
            # the configuration file, only use with trusted files
            frametypes.add(eval(line.split()[-1]))

# find frames for each type, padding the search by 12 hours on each side
pad = 43200
conn = GWDataFindHTTPConnection()
for ft in frametypes:
    # the first character of the frame type is used as the observatory
    cache.extend(conn.find_frame_urls(ft[0], ft, args.gpsstart-pad,
                                      args.gpsend+pad,
                                      urltype='file'))
wcachedict = lal_cache_to_wcache(cache)

cachefile = os.path.join(outdir,
                         '%s-OMEGA_FRAME_CACHE-%d-%d.wcf'
                         % (ifo, args.gpsstart, args.gpsend-args.gpsstart))

# write one line per frame directory
with open(cachefile, 'w') as f:
    for row in wcachedict.values():
        f.write("%s %s %s %s %s %s\n" % tuple(row))

print('Frame cache written to\n%s' % cachefile)

# -----------------------------------------------------------------------------
# Build Omega scan DAG

dag = pipeline.CondorDAG(os.path.join(htclogdir, '%s.log' % args.file_tag))
dag.set_dag_file(os.path.join(outdir, args.file_tag))
job = OmegaPipelineJob(args.universe, args.executable, subdir=outdir,
                       logdir=logdir)

# options shared by every scan
job.add_opt('configuration', args.config_file)
job.add_opt('report')
job.add_opt('framecache', cachefile)

# one node per scan time, each writing to its own sub-directory
for t in times:
    node = OmegaPipelineDAGNode(job)
    # use the absolute output directory (as for the submit and cache
    # files) so that the scan location does not depend on the working
    # directory of the job
    node.add_var_opt('outdir', os.path.join(outdir, str(t)))
    node.add_var_arg(str(t))
    node.set_category('omega_scan')
    node.set_retry(1)
    dag.add_node(node)

dag.write_sub_files()
dag.write_dag()
dag.write_script()

print('DAG written to:\n%s' % os.path.abspath(dag.get_dag_file()))
