#!/usr/bin/env python
# coding=utf-8
# Copyright (C) LIGO Scientific Collaboration (2015-)
#
# This file is part of the GW DetChar python package.
#
# GW DetChar is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GW DetChar is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GW DetChar.  If not, see <http://www.gnu.org/licenses/>.

"""Search for evidence of beam scattering based on optic velocity

This tool identifies time segments when evidence for scattering is strong,
to compare projected fringes against spectrogram measurements for a specific
time please use the command-line module:
`python -m gwdetchar.scattering --help`
"""

from __future__ import division

import os.path
import re
import random
import warnings

from six.moves import StringIO

import numpy

from matplotlib import (use, rcParams)
use('agg')  # noqa

import gwtrigfind

from gwpy.plot import Plot
from gwpy.plot.tex import label_to_latex
from gwpy.table import (Table, EventTable)
from gwpy.table.filters import in_segmentlist
from gwpy.segments import (DataQualityFlag, DataQualityDict,
                           Segment, SegmentList)

from gwdetchar import (cli, const, scattering)
from gwdetchar.io import html as htmlio
from gwdetchar.io.datafind import get_data
from gwdetchar.omega import batch
from gwdetchar.plot import get_gwpy_tex_settings

# package metadata: primary author and additional contributors
__author__ = 'Duncan Macleod <duncan.macleod@ligo.org>'
# NOTE: adjacent string literals are concatenated verbatim, so the separator
#       between the two contributors has to be spelled out explicitly
__credits__ = 'Siddharth Soni <siddharth.soni@ligo.org>, ' \
              'Alex Urban <alexander.urban@ligo.org>'

# -- plot styling: apply the shared gwpy TeX configuration first, then layer
#    this tool's own figure defaults on top (later values win on conflict)
tex_settings = get_gwpy_tex_settings()
rcParams.update(tex_settings)
rcParams.update(dict([
    ('axes.labelsize', 20),
    ('figure.subplot.bottom', 0.17),
    ('figure.subplot.left', 0.1),
    ('figure.subplot.right', 0.9),
    ('figure.subplot.top', 0.90),
    ('grid.color', 'gray'),
    ('image.cmap', 'viridis'),
    ('svg.fonttype', 'none'),
]))

# -- parse command line -------------------------------------------------------

# the module docstring doubles as the `--help` description
parser = cli.create_parser(description=__doc__)
cli.add_gps_start_stop_arguments(parser)
cli.add_ifo_option(parser)
parser.add_argument('-a', '--state-flag', metavar='FLAG', default=None,
                    help='restrict search to times when FLAG was active')
parser.add_argument('-z', '--fmin', metavar='FMIN', type=float, default=10.,
                    help='minimum frequency (Hz) of triggers, '
                         'default: %(default)s')
parser.add_argument('-t', '--frequency-threshold', type=float, default=15.,
                    help='critical fringe frequency threshold (Hz), '
                         'default: %(default)s')
parser.add_argument('-x', '--multiplier-for-threshold', type=int,
                    default=4, choices=scattering.FREQUENCY_MULTIPLIERS,
                    help='fringe frequency multiplier to use when applying '
                         '--frequency-threshold, default: %(default)s')
# list the known optics by name; interpolating `.keys()` directly would
# render as `dict_keys([...])` in the help text under Python 3
parser.add_argument('-m', '--optic', action='append',
                    help='optic to search for scattering signal, can be given '
                         'multiple times, default: %s'
                         % ', '.join(
                             str(key) for key in
                             scattering.OPTIC_MOTION_CHANNELS))
parser.add_argument('-p', '--segment-padding', type=float, default=.05,
                    help='time with which to pad scattering segments on '
                         'either side, default: %(default)s')
parser.add_argument('-s', '--segment-start-pad', type=float, default=30.0,
                    help='amount of time to remove from the start of each '
                         'analysis segment')
parser.add_argument('-e', '--segment-end-pad', type=float, default=30.0,
                    help='amount of time to remove from the end of each '
                         'analysis segment')
# `{IFO}` is substituted with the interferometer prefix after parsing
parser.add_argument('-c', '--main-channel',
                    default='{IFO}:GDS-CALIB_STRAIN',
                    help='name of main (h(t)) channel, default: %(default)s')
cli.add_frametype_option(parser, default='%s_R' % const.IFO)
parser.add_argument('-o', '--output-dir', type=os.path.abspath,
                    default=os.curdir,
                    help='output directory for analysis, default: %(default)s')
# no explicit default: the attribute is None unless the option is given
parser.add_argument('-w', '--omega-scans', type=int, metavar='NSCAN',
                    help='generate a workflow of omega scans for active '
                         'scattering, this will launch automatically to '
                         'condor, default: None')
parser.add_argument('-v', '--verbose', action='store_true', default=False,
                    help='log verbose output, default: %(default)s')
cli.add_nproc_option(parser)

args = parser.parse_args()

# set up logger, named after this script; verbosity follows --verbose
if args.verbose:
    loglevel = 'DEBUG'
else:
    loglevel = 'INFO'
logger = cli.logger(name=os.path.basename(__file__), level=loglevel)

multiplier = args.multiplier_for_threshold

# whole-number thresholds are stored as int so they print without a
# trailing '.0' in file names and messages
if args.frequency_threshold.is_integer():
    args.frequency_threshold = int(args.frequency_threshold)

# filename-safe tags for the threshold and for the GPS start/duration
tstr = str(args.frequency_threshold).replace('.', '_')
gpsstr = '%s-%s' % (
    int(args.gpsstart),
    int(args.gpsend-args.gpsstart),
)

# default to every optic known to the scattering module
if args.optic is None:
    args.optic = scattering.OPTIC_MOTION_CHANNELS.keys()

# all output is written relative to the output directory
outdir = args.output_dir
if not os.path.isdir(outdir):
    os.makedirs(outdir)
os.chdir(outdir)

segfile = '%s-SCATTERING_SEGMENTS_%s_HZ-%s.h5' % (args.ifo, tstr, gpsstr)
summfile = '{}-SCATTERING_SUMMARY-{}.csv'.format(args.ifo, gpsstr)

logger.info(
    '{} Scattering {}-{}'.format(
        args.ifo,
        int(args.gpsstart),
        int(args.gpsend),
    )
)

# -- get state segments -------------------------------------------------------

# full GPS interval requested on the command line
span = Segment(args.gpsstart, args.gpsend)

# get segments
if args.state_flag is not None:
    state = DataQualityFlag.query(args.state_flag, int(args.gpsstart),
                                  int(args.gpsend),
                                  url=const.DEFAULT_SEGMENT_SERVER)
    padding = args.segment_start_pad + args.segment_end_pad
    # contract every active segment by the start/end pads; a segment that is
    # not strictly longer than the total padding cannot be contracted and is
    # dropped from the analysis, as the log message states
    padded = type(state.active)()
    for seg in state.active:
        if abs(seg) > padding:
            padded.append(type(seg)(
                seg[0]+args.segment_start_pad, seg[1]-args.segment_end_pad))
        else:
            logger.debug(
                "Segment length {} shorter than padding length {}, "
                "skipping segment {}-{}".format(abs(seg), padding, *seg))
    state.active = padded
    state.coalesce()
    statea = state.active
    livetime = float(abs(statea))
    logger.debug("Downloaded %d segments for %s [%.2fs livetime]"
                 % (len(statea), args.state_flag, livetime))
else:
    # no state flag: analyse the whole requested span
    statea = SegmentList([span])
    # livetime is needed later for deadtime and histogram limits, so it
    # must be defined on this branch as well
    livetime = float(abs(statea))

# -- load h(t) ----------------------------------------------------------------

args.main_channel = args.main_channel.format(IFO=args.ifo)
logger.debug("Loading Omicron triggers for %s" % args.main_channel)

# triggers are read as HDF5 for start times at or after the GPS epoch below
# and as LIGO_LW XML before it; the column names differ between the two.
# `names` is always ordered (time, frequency, snr) and is indexed
# positionally by the rest of the script.
if args.gpsstart >= 1230336018:  # Jan 1 2019
    ext = "h5"
    names = ["time", "frequency", "snr"]
    read_kw = {
        "columns": names,
        "selection": [
            "{0} < frequency < {1}".format(
                args.fmin, multiplier * args.frequency_threshold),
            ("time", in_segmentlist, statea),
        ],
        "format": "hdf5",
        "path": "triggers",
    }
else:
    try:
        from glue.ligolw import lsctables  # noqa: F401
    except ImportError as exc:
        # prepend the original message to the installation hint
        exc.args = (
            "{!s}, `lscsoft-glue` is required to read LIGO_LW XML "
            "files".format(exc),
        )
        raise
    ext = "xml.gz"
    names = ['peak', 'peak_frequency', 'snr']
    read_kw = {
        "columns": names,
        "selection": [
            "{0} < peak_frequency < {1}".format(
                args.fmin, multiplier * args.frequency_threshold),
            ('peak', in_segmentlist, statea),
        ],
        "format": 'ligolw',
        "tablename": "sngl_burst",
    }

# collect trigger files for each analysis segment, using the file extension
# that matches the reader configured above
fullcache = []
for seg in statea:
    cache = gwtrigfind.find_trigger_files(
        args.main_channel, 'omicron', seg[0], seg[1], ext=ext,
    )
    if len(cache) == 0:
        warnings.warn("No Omicron triggers found for %s in segment [%d .. %d)"
                      % (args.main_channel, seg[0], seg[1]))
        continue
    fullcache.extend(cache)

# read triggers
if fullcache:
    trigs = EventTable.read(fullcache, nproc=args.nproc, **read_kw)
else:  # no files (no livetime?)
    trigs = EventTable(names=names)

# subset of triggers used for the efficiency statistics and summary table
highsnrtrigs = trigs[trigs['snr'] >= 8]
logger.debug("%d read" % len(trigs))

# -- prepare HTML -------------------------------------------------------------

# navigation bar: one anchor per page section, plus the omega scan section
# when scans were requested
links = []
for section in ['Parameters', 'Segments']:
    links.append((section, '#%s' % section.lower()))
if args.omega_scans:
    links.append(('Omega Scans', '#omega-scans'))
(brand, class_) = htmlio.get_brand(args.ifo, 'Scattering', args.gpsstart)
navbar = htmlio.navbar(links, class_=class_, brand=brand)
page = htmlio.new_bootstrap_page(
    title='%s Scattering' % args.ifo,
    navbar=navbar,
)

# page header with a short description of the analysis
page.div(class_='page-header')
page.h1(
    '%s Scattering: %d-%d' % (
        args.ifo,
        int(args.gpsstart),
        int(args.gpsend),
    )
)
page.p("This analysis searched for evidence of beam scattering between {} and "
       "{} Hz based on the velocity of optic motion. The fringe frequency is "
       "projected using equation (3) of <a href=\"http://iopscience.iop.org/"
       "article/10.1088/0264-9381/27/19/194011\" target=\"_blank\">Accadia "
       "et al. (2010)</a>.".format(args.fmin,
                                   multiplier * args.frequency_threshold))
page.div.close()
page.add(
    htmlio.write_arguments(
        [],
        start=int(args.gpsstart),
        end=int(args.gpsend),
        flag=args.state_flag,
    )
)

# link the output products: the segments file, then the trigger summary file
page.h2('Segments', id_='segments')
for blurb, product in (
        ('Full output segments are recorded in HDF5 format here:', segfile),
        ('A summary of triggers that occur during active scattering '
         'segments is recorded in CSV format here:', summfile),
):
    page.p()
    page.add(blurb)
    page.a(os.path.basename(product), href=product, target='_blank')
    page.p.close()

# record state segments
# NOTE(review): this panel group and the channel panel group opened below
#               both use the id 'accordion1' -- confirm that is intended
if args.state_flag is not None:
    known_style = {'facecolor': 'red', 'edgecolor': 'darkred', 'height': 0.4}
    page.p('This analysis was done over the following data-quality flag:')
    page.div(class_='panel-group', id_='accordion1')
    page.add(
        htmlio.write_flag_html(
            state, span, 'state',
            parent='accordion1',
            context='success',
            plotdir='',
            facecolor=(0.2, 0.8, 0.2),
            edgecolor='darkgreen',
            known=known_style,
        )
    )
    page.div.close()

if not statea:
    # nothing to analyse: say so instead of opening the channel panel group
    page.div(class_='alert alert-info')
    page.p('No active analysis segments were found')
    page.div.close()
else:
    # this panel group is left open here and closed once all channels
    # have been processed
    page.p("The following channels were searched for evidence of scattering "
           "(yellow = weak evidence, red = strong evidence):")
    page.div(class_='panel-group', id_='accordion1')

# -- find scattering evidence -------------------------------------------------

# expand each requested optic into its IFO-prefixed motion channels
allchannels = []
for optic in args.optic:
    for c in scattering.OPTIC_MOTION_CHANNELS[optic]:
        allchannels.append('%s:%s' % (args.ifo, c))

logger.info("Reading all data")
alldata = []
n = len(statea)
for i, seg in enumerate(statea):
    # progress prefix handed to the reader, only when running verbosely
    if args.verbose:
        msg = "{0}/{1} {2}:".rjust(30).format(
            str(i).rjust(len(str(n))),
            n,
            str(seg),
        )
    else:
        msg = False
    segdata = get_data(
        allchannels, seg[0], seg[1],
        frametype=args.frametype,
        verbose=msg,
        nproc=args.nproc,
    )
    alldata.append(segdata.resample(128))

# every segment is read with the same channel list, so the first entry
# provides the channel names; with no segments there is nothing to process
if alldata:
    channels = alldata[0].keys()
else:
    channels = []

# per-channel scattering flags, and the union of all active segments
scatter_segments = DataQualityDict()
actives = SegmentList()

# For each channel: project the fringe frequency in every analysis segment,
# flag times at or above the threshold, compute deadtime/efficiency against
# the high-SNR triggers, draw the summary figure and histogram, and add an
# HTML panel when there is at least weak evidence of scattering.
for i, channel in enumerate(sorted(channels)):
    logger.info("-- Processing %s --" % channel)
    # filename-safe channel name, e.g. 'X1:SYS-NAME' -> 'X1-SYS_NAME'
    chanstr = re.sub('[:-]', '_', channel).replace('_', '-', 1)
    # optic name: the token between the first '-' and the following '_'
    optic = channel.split('-')[1].split('_')[0]
    flag = '%s:DCH-%s_SCATTERING_GE_%s_HZ:1' % (args.ifo, optic, tstr)
    scatter_segments[channel] = DataQualityFlag(
        flag,
        isgood=False,
        description="Evidence for scattering above {0} Hz from {1} in "
                    "{2}".format(args.frequency_threshold, optic, channel),
    )
    # set up plot(s): four stacked panels sharing the time axis
    plot = Plot(figsize=[12, 12])
    axes = {}
    axes['position'] = plot.add_subplot(411, xscale='auto-gps', xlabel='')
    axes['fringef'] = plot.add_subplot(412, sharex=axes['position'], xlabel='')
    axes['triggers'] = plot.add_subplot(413, sharex=axes['position'],
                                        xlabel='')
    axes['segments'] = plot.add_subplot(414, projection='segments',
                                        sharex=axes['position'])
    plot.subplots_adjust(bottom=.07, top=.95)
    # fringe-frequency samples accumulated per multiplier, for the histogram
    histdata = dict((m, numpy.ndarray((0,))) for
                    m in scattering.FREQUENCY_MULTIPLIERS)
    # colours are taken from the first line drawn and reused afterwards so
    # that every segment is drawn in the same colour
    linecolor = None
    fringecolors = [None] * len(scattering.FREQUENCY_MULTIPLIERS)
    # loop over state segments and find scattering fringes
    for j, seg in enumerate(statea):
        logger.debug("Processing segment [%d .. %d)" % seg)
        ts = alldata[j][channel]
        # get raw data and plot
        line = axes['position'].plot(ts, color=linecolor)[0]
        linecolor = line.get_color()
        # get fringe frequency and plot
        fringef = scattering.get_fringe_frequency(ts, multiplier=1)
        for k, m in list(enumerate(scattering.FREQUENCY_MULTIPLIERS))[::-1]:
            fm = fringef * m
            # only label the first segment to avoid duplicate legend entries
            line = axes['fringef'].plot(
                fm, color=fringecolors[k],
                label=(r'$f\times%d$' % m if j == 0 else None))[0]
            fringecolors[k] = line.get_color()
            # grow the accumulator in place and append the new samples
            histdata[m].resize((histdata[m].size + fm.size,))
            histdata[m][-fm.size:] = fm.value
        # get segments and plot
        if ((fringef * multiplier).value.max() <
                args.frequency_threshold):
            # never reaches the threshold: known time only, nothing active
            scatter = DataQualityFlag(flag, known=[ts.span])
        else:
            scatter = (
                fringef * multiplier >=
                args.frequency_threshold * fringef.unit).to_dqflag(name=flag)
            scatter.active = scatter.active.protract(args.segment_padding)
            scatter.coalesce()
        axes['segments'].plot(scatter, facecolor='red', edgecolor='darkred',
                              known={'alpha': 0.6, 'facecolor': 'lightgray',
                                     'edgecolor': 'gray', 'height': 0.4},
                              height=0.8, y=0, label=' ')
        scatter_segments[channel] += scatter
        logger.debug(
            "    Found %d scattering segments" % (len(scatter.active)))
    logger.debug("Completed channel, found %d segments in total."
                 % len(scatter_segments[channel].active))

    # calculate efficiency and deadtime of veto
    deadtime = abs(scatter_segments[channel].active)
    if livetime:
        deadtimepc = deadtime / livetime * 100
    else:
        deadtimepc = 0.
    logger.info("Deadtime: %.2f%% (%.2f/%ds)"
                % (deadtimepc, deadtime, livetime))
    # cast to a plain int: a numpy integer divided by zero yields nan/inf
    # with a warning instead of raising, so guard the empty case explicitly
    efficiency = int(in_segmentlist(
        highsnrtrigs[names[0]], scatter_segments[channel].active).sum())
    if len(highsnrtrigs):
        efficiencypc = efficiency / len(highsnrtrigs) * 100
    else:
        efficiencypc = 0.
    logger.info("Efficiency (SNR>=8): %.2f%% (%d/%d)"
                % (efficiencypc, efficiency, len(highsnrtrigs)))
    if deadtimepc == 0.:
        effdt = 0
    else:
        effdt = efficiencypc/deadtimepc
    logger.info("Efficiency/Deadtime: %.2f" % effdt)

    if abs(scatter_segments[channel].active):
        actives.extend(scatter_segments[channel].active)

    # finalize plot
    logger.debug("Plotting")
    name = label_to_latex(channel) if rcParams["text.usetex"] else channel
    axes['position'].set_title("Scattering evidence in %s" % name)
    axes['position'].set_xlabel('')
    axes['position'].set_ylabel(r'Position [\textmu m]')
    axes['position'].text(
        0.01, 0.95, 'Optic position',
        transform=axes['position'].transAxes, va='top', ha='left',
        bbox={'edgecolor': 'none', 'facecolor': 'white', 'alpha': .5})
    # dashed horizontal line marking the frequency threshold
    axes['fringef'].plot(
        span, [args.frequency_threshold, args.frequency_threshold], 'k--')
    axes['fringef'].set_xlabel('')
    axes['fringef'].set_ylabel(r'Frequency [Hz]')
    axes['fringef'].yaxis.tick_right()
    axes['fringef'].yaxis.set_label_position("right")
    axes['fringef'].set_ylim(0, multiplier * args.frequency_threshold)
    axes['fringef'].text(
        0.01, 0.95, 'Calculated fringe frequency',
        transform=axes['fringef'].transAxes, va='top', ha='left',
        bbox={'edgecolor': 'none', 'facecolor': 'white', 'alpha': .5})
    # multipliers were drawn in reverse order, so reverse the legend back
    handles, labels = axes['fringef'].get_legend_handles_labels()
    axes['fringef'].legend(handles[::-1], labels[::-1], loc='upper right',
                           borderaxespad=0, bbox_to_anchor=(-0.01, 1.),
                           handlelength=1)

    # trigger time vs frequency, coloured by SNR
    axes['triggers'].scatter(
        trigs[names[0]],
        trigs[names[1]],
        c=trigs[names[2]],
        edgecolor='none',
    )
    name = args.main_channel
    if rcParams["text.usetex"]:
        name = label_to_latex(name)
    axes['triggers'].text(
        0.01, 0.95,
        '%s event triggers (Omicron)' % name,
        transform=axes['triggers'].transAxes, va='top', ha='left',
        bbox={'edgecolor': 'none', 'facecolor': 'white', 'alpha': .5})
    axes['triggers'].set_ylabel('Frequency [Hz]')
    axes['triggers'].set_ylim(args.fmin, multiplier * args.frequency_threshold)
    axes['triggers'].colorbar(cmap='YlGnBu', clim=(3, 100), norm='log',
                              label='Signal-to-noise ratio')
    axes['segments'].set_ylim(-.55, .55)
    axes['segments'].text(
        0.01, 0.95,
        r'Time segments with $f\times%d > %.2f$ Hz' % (
            multiplier, args.frequency_threshold),
        transform=axes['segments'].transAxes, va='top', ha='left',
        bbox={'edgecolor': 'none', 'facecolor': 'white', 'alpha': .5})
    for ax in axes.values():
        ax.set_epoch(int(args.gpsstart))
        ax.set_xlim(*span)
    png = '%s_SCATTERING_%s_HZ-%s.png' % (chanstr, tstr, gpsstr)
    try:
        plot.save(png)
    except OverflowError as e:
        # re-apply the fringe-frequency axis limits and try once more
        warnings.warn(str(e))
        plot.axes[1].set_ylim(0, multiplier * args.frequency_threshold)
        plot.refresh()
        plot.save(png)
    plot.close()
    logger.debug("%s written." % png)

    # make histogram
    histogram = Plot(figsize=[12, 6])
    ax = histogram.gca()
    hrange = (0, multiplier * args.frequency_threshold)
    for m, color in list(zip(histdata, fringecolors))[::-1]:
        if histdata[m].size:
            # each sample is weighted by the sample spacing of the last
            # time series read, so bin heights are in seconds
            ax.hist(
                histdata[m], facecolor=color, alpha=.6, range=hrange,
                bins=50, histtype='stepfilled', label=r'$f\times%d$' % m,
                cumulative=-1, weights=ts.dx.value, bottom=1e-100, log=True)
        else:
            # no data: draw an empty line so the legend entry still exists
            ax.plot(histdata[m], color=color, label=r'$f\times%d$' % m)
            ax.set_yscale('log')
    ax.set_ylim(.01, float(livetime))
    ax.set_ylabel('Time with fringe above frequency [s]')
    ax.set_xlim(*hrange)
    ax.set_xlabel('Frequency [Hz]')
    ax.set_title(axes['position'].get_title())
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles[::-1], labels[::-1], loc='upper right')
    hpng = '%s_SCATTERING_HISTOGRAM-%s.png' % (chanstr, gpsstr)
    histogram.save(hpng)
    histogram.close()
    logger.debug("%s written." % hpng)

    # write HTML
    # 'danger' when the veto is more than twice as efficient as its deadtime;
    # 'warning' for weaker evidence, or when the projected fringe reaches at
    # least half the threshold; otherwise the channel is omitted from the page
    if deadtime != 0 and effdt > 2:
        context = 'danger'
    elif ((deadtime != 0 and effdt < 2) or
          (histdata[multiplier].size and
           histdata[multiplier].max() >=
              args.frequency_threshold/2.)):
        context = 'warning'
    else:
        continue
    page.div(class_='panel panel-%s' % context)
    page.div(class_='panel-heading')
    # NOTE(review): 'data-parent' targets '#accordion', while the enclosing
    #               panel group is created with id 'accordion1' -- confirm
    page.a(channel, class_="panel-title", href='#flag%s' % i,
           **{'data-toggle': 'collapse', 'data-parent': '#accordion'})
    page.div.close()
    page.div(id_='flag%s' % i, class_='panel-collapse collapse')
    img = htmlio.FancyPlot(
        png, caption=scattering.SCATTER_CAPTION.format(CHANNEL=channel))
    page.add(htmlio.fancybox_img(img))
    himg = htmlio.FancyPlot(
        hpng, caption=scattering.HIST_CAPTION.format(CHANNEL=channel))
    page.add(htmlio.fancybox_img(himg))
    page.div(class_='panel-body')
    segs = StringIO()
    if deadtime:
        page.p("%d segments were found predicting a scattering fringe above "
               "%.2f Hz." % (len(scatter_segments[channel].active),
                             args.frequency_threshold))
        page.table(class_='table table-condensed table-hover')
        page.tbody()
        page.tr()
        page.th('Deadtime')
        page.td('%.2f/%d seconds' % (deadtime, livetime))
        page.td('%.2f%%' % deadtimepc)
        page.tr.close()
        page.tr()
        # well-formed markup: both '<' signs escaped, no unmatched </sub>
        page.th('Efficiency<br><small>(SNR&ge;8 and '
                '%.2f Hz&lt;f<sub>peak</sub>&lt;%.2f Hz)</small>'
                % (args.fmin, multiplier * args.frequency_threshold))
        page.td('%d/%d events' % (efficiency, len(highsnrtrigs)))
        page.td('%.2f%%' % efficiencypc)
        page.tr.close()
        page.tr()
        page.th('Efficiency/Deadtime')
        page.td()
        page.td('%.2f' % effdt)
        page.tr.close()
        page.tbody.close()
        page.table.close()
        # embed the segment list in the page as preformatted text
        scatter_segments[channel].active.write(segs, format='segwizard',
                                               coltype=float)
        page.pre(segs.getvalue())
    else:
        page.p("No segments were found with scattering above %.2f Hz."
               % args.frequency_threshold)
    # close panel-body, panel-collapse and panel, in that order
    page.div.close()
    page.div.close()
    page.div.close()

if statea:
    # close panel group
    page.div.close()

actives = actives.coalesce()  # merge contiguous segments
if not actives:
    page.div(class_='alert alert-info')
    page.p('No evidence of scattering found in the channels analyzed')
    page.div.close()

# identify triggers during active segments
logger.debug('Writing a summary CSV record')

# row indices of the high-SNR triggers whose time lies in an active segment
ind = []
for i, trigtime in enumerate(highsnrtrigs[names[0]]):
    if trigtime in actives:
        ind.append(i)
gps = highsnrtrigs[names[0]][ind]
freq = highsnrtrigs[names[1]][ind]
snr = highsnrtrigs[names[2]][ind]

# pair each selected trigger with the active segment(s) containing it
segs = []
for x in gps:
    for y in actives:
        if x in y:
            segs.append(y)
segstarts = [seg[0] for seg in segs]
segends = [seg[1] for seg in segs]

table = Table(
    [gps, freq, snr, segstarts, segends],
    names=(
        'trigger_time',
        'trigger_frequency',
        'trigger_snr',
        'segment_start',
        'segment_end',
    ),
)
logger.info('The following {} triggers fell within active scattering '
            'segments:\n\n'.format(len(table)))
print(table)
print('\n\n')
table.write(summfile, overwrite=True)

# -- launch omega scans -------------------------------------------------------

# --omega-scans has no default, so the attribute is None unless the option
# was given; treat that as zero scans rather than comparing None with an int
# (which raises TypeError on Python 3)
nscans = min(args.omega_scans or 0, len(table))
if nscans > 0:
    # launch scans for a random selection of the summarised triggers
    scandir = 'scans'
    ind = random.sample(range(0, len(table)), nscans)
    omegatimes = [str(t) for t in table['trigger_time'][ind]]
    logger.debug('Collected {} event times to omega scan: {}'.format(
        nscans, ', '.join(omegatimes)))
    logger.info('Creating workflow for omega scans')
    flags = batch.get_command_line_flags(ifo=args.ifo, ignore_state_flags=True)
    condorcmds = batch.get_condor_arguments(timeout=4, gps=args.gpsstart)
    dagman = batch.generate_dag(omegatimes, flags=flags, submit=True,
                                outdir=scandir, condor_commands=condorcmds)
    logger.info('Launched {} omega scans to condor'.format(nscans))
    # render HTML
    page.h2('Omega Scans', id_='omega-scans')
    page.p('The following event times correspond to significant Omicron '
           'triggers that occur during the scattering segments found above. '
           'To compare these against fringe frequency projections, please '
           'use the scattering module:')
    page.pre('$ python -m gwdetchar.scattering --help')
    page.add(htmlio.scaffold_omega_scans(
        omegatimes, args.main_channel, scandir=scandir))
elif args.omega_scans:
    # scans were requested but no trigger fell inside an active segment
    logger.info('No events found during active scattering segments')

# -- finalize -----------------------------------------------------------------

# store the per-channel scattering flags in the segment file
scatter_segments.write(
    segfile,
    overwrite=True,
    path="segments",
)
logger.debug("%s written" % segfile)

# close out the report and write it to disk
htmlio.close_page(page, 'index.html')
logger.info("-- index.html written, all done --")
