#!python
# coding=utf-8
# Copyright (C) LIGO Scientific Collaboration (2015-)
#
# This file is part of the GW DetChar python package.
#
# GW DetChar is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GW DetChar is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GW DetChar.  If not, see <http://www.gnu.org/licenses/>.

"""Find channels clipping their software saturation limits
"""

from __future__ import print_function

import os
import re
import subprocess
import sys
from multiprocessing import (Pool, cpu_count)
try:
    from itertools import zip_longest
    from io import StringIO
except ImportError:
    from itertools import izip_longest as zip_longest
    from StringIO import StringIO


from matplotlib import (use, rcParams)
use('agg')

import gwdatafind

from glue import markup

from gwpy.io.cache import sieve as sieve_cache
from gwpy.io.gwf import get_channel_names
from gwpy.segments import (Segment, SegmentList,
                           DataQualityFlag, DataQualityDict)
from gwpy.time import to_gps
from gwpy.timeseries import (TimeSeries, TimeSeriesDict, StateTimeSeries)

from gwdetchar import (__version__, cli, const)
from gwdetchar.io import html as htmlio
from gwdetchar.saturation import find_saturations
from gwdetchar.misc import get_gwpy_tex_settings

try:
    from LDAStools import frameCPP
except ImportError:
    HAS_FRAMECPP = False
else:
    HAS_FRAMECPP = True

__author__ = 'Dan Hoak <daniel.hoak@ligo.org>'
__credits__ = 'Duncan Macleod <duncan.macleod@ligo.org>'

# default number of parallel processes: one per core, but never more than 8
DEFAULT_NPROC = min(cpu_count(), 8)

# channel-name suffixes of interest, each anchored to the end of the name
re_limit = re.compile(r'_LIMIT\Z')
re_limen = re.compile(r'_LIMEN\Z')
re_swstat = re.compile(r'_SWSTAT\Z')
# a single pattern matching a name that carries any of the suffixes above
re_software = re.compile('({0})'.format(
    '|'.join(regex.pattern for regex in (re_limit, re_limen, re_swstat))))


# -- utilities ----------------------------------------------------------------

def grouper(iterable, n, fillvalue=None):
    """Split an iterable into consecutive tuples of `n` elements

    The last tuple is padded with `fillvalue` when the iterable is
    exhausted before the tuple is complete.
    """
    # every slot of an output tuple pulls from the *same* iterator, so
    # successive input elements land in successive slots
    shared = iter(iterable)
    slots = (shared,) * n
    return zip_longest(*slots, fillvalue=fillvalue)


def write_flag_html(flag, id=0, parent='accordion', context='warning',
                    title=None, plotdir=None):
    """Render a collapsible HTML panel listing the segments of a flag

    Parameters
    ----------
    flag : `~gwpy.segments.DataQualityFlag`
        the flag whose active segments should be displayed
    id : `int` or `str`
        suffix used to build the ``flag<id>`` identifier of the collapsible
        element, this should be unique within the page
    parent : `str`
        the ``id`` of the enclosing accordion element
    context : `str`
        suffix of the ``panel-<context>`` CSS class applied to the panel
    title : `str`, optional
        text of the panel heading, defaults to the name of the flag
    plotdir : `str`, optional
        directory in which to save a PNG of the segments, which is then
        embedded in the panel; `None` or `False` disables plotting

    Returns
    -------
    page : `markup.page`
        the HTML for this panel

    Notes
    -----
    When plotting is enabled this function reads the module-level `span`
    variable to name the file and set the plot limits.
    """
    page = markup.page()
    page.div(class_='panel panel-%s' % context)
    # -- heading: a link that toggles the collapsible body
    page.div(class_='panel-heading')
    if title is None:
        title = flag.name
    page.a(title, class_="panel-title", href='#flag%s' % id,
           **{'data-toggle': 'collapse', 'data-parent': '#%s' % parent})
    page.div.close()
    # -- body: the segment list
    page.div(id_='flag%s' % id, class_='panel-collapse collapse')
    page.div(class_='panel-body')
    try:
        # the type of the first segment boundary sets the column type;
        # only this lookup is guarded so that errors raised while
        # writing the segments are not mistaken for an empty list
        coltype = type(flag.active[0][0])
    except IndexError:
        page.p("No segments were found.")
    else:
        segs = StringIO()
        flag.active.write(segs, format='segwizard', coltype=coltype)
        page.pre(segs.getvalue())
    page.div.close()
    # -- plot (optional); `False` is accepted as "no plot" because that
    #    is what an unset ``store_true`` command-line option provides
    if plotdir is not None and plotdir is not False:
        flagr = flag.name.replace('-', '_').replace(':', '-', 1)
        png = os.path.join(
            plotdir, '%s-%d-%d.png' % (flagr, span[0], abs(span)))
        plot = plot_saturations(flag, span)
        plot.save(png)
        plot.close()
        # NOTE(review): the image is linked with the same path it was saved
        # to, which only resolves from the page if that path is valid
        # relative to the location of the HTML file -- confirm
        page.a(href=png, target='_blank')
        page.img(style="width: 100%;", src=png)
        page.a.close()
    page.div.close()
    page.div.close()
    return page


def plot_saturations(flag, span, facecolor='red', edgecolor='darkred',
                     known=None):
    """Plot the saturation segments contained within this flag

    Parameters
    ----------
    flag : `~gwpy.segments.DataQualityFlag`
        the flag to plot
    span : `~gwpy.segments.Segment`
        the interval used for the x-axis limits, its start is used as
        the plot epoch
    facecolor : `str`
        face colour passed to ``flag.plot``
    edgecolor : `str`
        edge colour passed to ``flag.plot``
    known : `dict`, optional
        value of the ``known`` argument of ``flag.plot``, defaults to a
        translucent light-gray style

    Returns
    -------
    plot
        the object returned by ``flag.plot``
    """
    if known is None:
        # build the default per call so that no dict is shared between
        # calls (or modified by the plotting code on our behalf)
        known = {'alpha': 0.2, 'facecolor': 'lightgray',
                 'edgecolor': 'gray'}
    # use the TeX-safe name only when matplotlib is rendering with TeX
    name = flag.texname if rcParams["text.usetex"] else flag.name
    plot = flag.plot(
        figsize=[12, 2], facecolor=facecolor, edgecolor=edgecolor,
        known=known, label=' ',
        xlim=span, epoch=span[0],
        title="{} software saturations".format(name),
    )
    plot.subplots_adjust(bottom=0.4, top=0.8)
    return plot


def find_limit_channels(channels, skip=None):
    """Find all 'LIMIT' channels that have a matching 'LIMEN' or 'SWSTAT'

    Parameters
    ----------
    channels : `list` of `str`
        the list of channel names to search
    skip : `list` of `str`, optional
        regular expression patterns, any channel whose name is matched
        (anywhere) by one of these is ignored

    Returns
    -------
    limens : `list` of `str`
        sorted channel prefixes (name without suffix) for which both a
        '_LIMIT' and a '_LIMEN' channel were found
    swstats : `list` of `str`
        sorted channel prefixes (name without suffix) for which both a
        '_LIMIT' and a '_SWSTAT' channel were found
    """
    # keep only channels carrying one of the relevant suffixes
    useful = [x for x in channels
              if x.endswith(('_LIMIT', '_LIMEN', '_SWSTAT'))]
    if skip:
        re_skip = re.compile('(%s)' % '|'.join(skip))
        useful = [x for x in useful if not re_skip.search(x)]

    # map limits to limen or swstat; a set keeps each lookup O(1)
    limits = {x[:-6] for x in useful if x.endswith('_LIMIT')}
    limens = sorted(x[:-6] for x in useful
                    if x.endswith('_LIMEN') and x[:-6] in limits)
    swstats = sorted(x[:-7] for x in useful
                     if x.endswith('_SWSTAT') and x[:-7] in limits)
    return limens, swstats


def _find_saturations(data):
    """Find saturation segments for one ``(output, limit)`` pair of series

    This is a module-level function taking a single argument so that it
    can be used with `map` and `multiprocessing.Pool.map`.
    """
    output, limit = data[0], data[1]
    flag = find_saturations(output, limit, precision=.99, segments=True)
    # drop the trailing 7 characters of the name (presumably the '_OUTPUT'
    # suffix of the output channel) to leave the channel prefix
    flag.name = flag.name[:-len('_OUTPUT')]
    return flag


def is_saturated(channel, cache, start=None, end=None,
                 indicator='LIMEN', nproc=DEFAULT_NPROC, segments=True):
    """Check whether a channel has saturated its software limit

    Parameters
    ----------
    channel : `str`, or `list` of `str`
        either a single channel name, or a list (or tuple) of channel
        names; a trailing '_LIMIT' is stripped from each name, the input
        is not modified
    cache : `list`
        the files to read, the first entry alone is used to check whether
        the limit of each channel is active
    start : `~gwpy.time.LIGOTimeGPS`, `int`
        the GPS start time of the check
    end : `~gwpy.time.LIGOTimeGPS`, `int`
        the GPS end time of the check
    indicator : `str`
        the suffix of the indicator channel, either `'LIMEN'` or `'SWSTAT'`
    nproc : `int`
        the number of parallel processes to use for frame reading and
        for finding saturations
    segments : `bool`, default `True`
        accepted for backwards compatibility, this option is currently
        not used and saturation segments are always returned

    Returns
    -------
    saturated : `~gwpy.segments.DataQualityFlag`, `None`, or `list`
        if a single channel name was given

        - `None` : if the limit of the channel is not active
        - `~gwpy.segments.DataQualityFlag` : otherwise

        OR, if a `list` of channels was given, a `list` of
        `~gwpy.segments.DataQualityFlag`, one for each channel whose limit
        is active (channels without an active limit are omitted)

    Raises
    ------
    ValueError
        if `indicator` is not one of `'LIMEN'` or `'SWSTAT'`
    """
    many = isinstance(channel, (list, tuple))
    # parse prefix, building a new list so the caller's is left untouched
    channels = [c[:-6] if c.endswith('_LIMIT') else c
                for c in (channel if many else [channel])]

    # reject unknown indicators before reading any data
    key = indicator.upper()
    if key not in ('LIMEN', 'SWSTAT'):
        raise ValueError("Don't know how to determine if limit is set for "
                         "indicator %r" % indicator)

    # check limit if set
    indicators = ['%s_%s' % (c, indicator) for c in channels]
    if HAS_FRAMECPP:
        iokwargs = {'type': 'adc', 'format': 'gwf.framecpp'}
    else:
        iokwargs = {'format': 'gwf'}
    # only the first sample of each indicator is inspected, so read just
    # one second when a start time was given
    data = TimeSeriesDict.read(
        cache[0], indicators, start=start,
        end=None if start is None else start + 1, **iokwargs)
    if key == 'LIMEN':
        active = dict((c, data[ind].value[0])
                      for c, ind in zip(channels, indicators))
    else:
        # for SWSTAT, bit 13 of the first sample is used as the flag
        active = dict(
            (c, data[ind].astype('uint32').value[0] >> 13 & 1)
            for c, ind in zip(channels, indicators))

    # get output/limit data for all with active limits
    activechans = [c for c in channels if active[c]]
    if not activechans:
        # nothing to read or search
        return [] if many else None
    datachans = ['%s_%s' % (c, s) for c in activechans for
                 s in ('LIMIT', 'OUTPUT')]
    data = TimeSeriesDict.read(cache, datachans, start=start, end=end,
                               nproc=nproc, **iokwargs)

    # find saturations of the limit for each channel
    pairs = [(data['%s_OUTPUT' % c], data['%s_LIMIT' % c])
             for c in activechans]
    if nproc > 1:
        pool = Pool(processes=nproc)
        try:
            saturations = list(pool.map(_find_saturations, pairs))
        finally:
            # always release the worker processes, even on error
            pool.close()
            pool.join()
    else:
        saturations = list(map(_find_saturations, pairs))

    # return many or one (based on input)
    if many:
        return saturations
    return saturations[0]


# -----------------------------------------------------------------------------
#
# Execution starts here
#
# -----------------------------------------------------------------------------

# -- parse command line -------------------------------------------------------

# build the parser from the options shared by gwdetchar command-line tools
parser = cli.create_parser(description=__doc__)
cli.add_gps_start_stop_arguments(parser)
cli.add_ifo_option(parser)
# the frametype is only required when no default interferometer is known
cli.add_frametype_option(parser, required=const.IFO is None,
                         default=const.IFO is not None and '%s_R' % const.IFO)
cli.add_nproc_option(parser)
# options specific to this tool
# NOTE(review): --channels is parsed but never read in the visible script
parser.add_argument('-c', '--channels',
                    help="file containing columnar list of channels to "
                         "process, default is to find all relevant channels "
                         "from frames")
parser.add_argument('-s', '--skip', nargs='*', default=[],
                    help='skip channels matching this string')
parser.add_argument('-g', '--group-size', default=1024, type=int,
                    help="number of channels to process in a single batch, "
                         "default: %(default)s")
parser.add_argument('-a', '--state-flag', metavar='FLAG',
                    help='restrict search to times when FLAG was active')
parser.add_argument('-p', '--pad-state-end', metavar='PAD', default=0,
                    type=float, help='pad state segments inwards from the end '
                                     'by PAD seconds, default: %(default)s')
parser.add_argument('-m', '--html', help='path to write html output')
parser.add_argument('-v', '--plot', action='store_true', default=False,
                    help='make plots of all saturations, default: %(default)s')

args = parser.parse_args()

ifo = args.ifo.upper()
site = ifo[0]
frametype = args.frametype or '%s_R' % ifo

# set TeX settings
tex_settings = get_gwpy_tex_settings()
rcParams.update(tex_settings)

# get segments
span = Segment(args.gpsstart, args.gpsend)
if args.state_flag:
    state = DataQualityFlag.query(args.state_flag, int(args.gpsstart),
                                  int(args.gpsend),
                                  url=const.DEFAULT_SEGMENT_SERVER)
    # pad the end of each state segment inwards; a segment no longer than
    # the padding has nothing left, so drop it rather than build a segment
    # whose end is not after its start
    padded = []
    for seg in state.active:
        newend = seg[1] - args.pad_state_end
        if newend > seg[0]:
            padded.append(type(seg)(seg[0], newend))
    state.active[:] = padded
    segs = state.active.coalesce()
    print("Recovered %d seconds of time for %s"
          % (abs(segs), args.state_flag))
else:
    segs = SegmentList([span])

# fail with a clear message *before* querying for frames if no datafind
# server is configured
if not os.getenv('LIGO_DATAFIND_SERVER'):
    raise RuntimeError("No LIGO_DATAFIND_SERVER variable set, don't know "
                       "how to discover channels")

# find frames
cache = gwdatafind.find_urls(site, frametype, int(args.gpsstart),
                             int(args.gpsend))

# find channels, using the first frame file as the reference
print("Finding channels in frames...")
if len(cache) == 0:
    raise RuntimeError("No frames recovered for %s in interval [%s, %s)" %
                       (frametype, int(args.gpsstart),
                        int(args.gpsend)))
allchannels = get_channel_names(cache[0])
print("   Found %d channels in frame" % len(allchannels))
sys.stdout.flush()
# `channels` is a pair of lists: (LIMEN prefixes, SWSTAT prefixes)
channels = find_limit_channels(allchannels, skip=args.skip)
print("   Parsed %d channels with '_LIMIT' and '_LIMEN' or '_SWSTAT'"
      % sum(map(len, channels)))


# -- read channels and check limits -------------------------------------------

saturations = DataQualityDict()
bad = set()

# TODO: use multiprocessing to separate channel list into discrete chunks
#       should give a factor of X for X processes

# duration of the longest segment to be read in one go (0 if there are no
# segments); this does not depend on the channel list so compute it once
try:
    dur = max(float(abs(s)) for s in segs)
except ValueError:
    dur = 0.

# check limens, then swstats
for suffix, clist in zip(['LIMEN', 'SWSTAT'], channels):
    nchans = len(clist)
    # group channels in sets for batch processing
    #     min of <number of channels>, user group size (sensible number), and
    #     the number of channels whose data for the longest segment fit in
    #     2 GiB at 4 bytes per sample and 16 samples per second (presumably
    #     single-precision EPICS data)
    ngroup = min(nchans, args.group_size)
    if dur > 0:
        ngroup = min(ngroup, 2 * 1024**3 / 4. / 16. / dur)
    # never let the group size reach zero, which would silently skip
    # every channel
    ngroup = max(1, int(ngroup))
    print('Processing %s channels in groups of %d:' % (suffix, ngroup))
    sys.stdout.flush()
    for i, cset in enumerate(grouper(clist, ngroup)):
        # remove the empty entries used to pad the final group
        cset = [c for c in cset if c is not None]
        for seg in segs:
            # restrict the file list to those overlapping this segment
            cache2 = sieve_cache(cache, segment=seg)
            if not len(cache2):
                continue
            saturated = is_saturated(cset, cache2, seg[0], seg[1],
                                     indicator=suffix, nproc=args.nproc,
                                     segments=True)
            # accumulate results for each channel across segments
            for new in saturated:
                try:
                    saturations[new.name] += new
                except KeyError:
                    saturations[new.name] = new
        # report the status of each channel in this group
        for j, c in enumerate(cset):
            try:
                sat = saturations[c]
            except KeyError:
                # no result recorded for this channel
                print('%40s:      SKIP      [%d/%d]'
                      % (c, i*ngroup + j + 1, nchans), end='\r')
            else:
                if abs(sat.active):
                    print('%40s: ---- FAIL ---- [%d/%d]'
                          % (c, i*ngroup + j + 1, nchans))
                    for seg in sat.active:
                        print(" " * 42 + str(seg))
                    bad.add(c)
                else:
                    print('%40s:      PASS      [%d/%d]'
                          % (c, i*ngroup + j + 1, nchans))
            sys.stdout.flush()

# -- print results and exit ---------------------------------------------------

print("\n----------------------------------------------------------------"
      "-----")
if bad:
    print("Saturations were found for all of the following:")
else:
    print("No software saturations were found in any channels")
print("---------------------------------------------------------------------")
# `bad` is a set, so sort it to make the report order reproducible
for c in sorted(bad):
    print(c)
if bad:
    print("-------------------------------------------------------------------"
          "--")

# write LIGO_LW XML, named <ifo>-SOFTWARE_SATURATIONS-<start>-<duration>
outfile = ('%s-SOFTWARE_SATURATIONS-%d-%d.xml.gz'
           % (ifo, int(args.gpsstart),
              int(args.gpsend) - int(args.gpsstart)))
saturations.write(outfile, overwrite=True)
print("Saturation segments written to \n%s" % outfile)

if args.html:
    htmldir = os.path.dirname(args.html)
    # plots, when requested, are written alongside the HTML page; `None`
    # (and not `False`) is what tells write_flag_html to skip plotting
    plotdir = htmldir if args.plot else None
    # link to the XML file relative to the location of the page
    xmlfile = os.path.relpath(outfile, htmldir)
    if os.path.basename(args.html) == 'index.html':
        page = htmlio.new_bootstrap_page()
    else:
        page = markup.page()
    page.div(class_='container')
    # -- header
    page.div(class_='page-header')
    page.h1('%s software saturations: %d-%d'
            % (ifo, int(args.gpsstart), int(args.gpsend)))
    page.p("This analysis searched %d channels for times during which their "
           "OUTPUT value matched the LIMIT value set in software. Only "
           "those signals achieving saturation are printed in the Segments "
           "section below, while all channels for which the LIMIT was active "
           "are included in the Results summary and the XML record."
           % sum(map(len, channels)))
    page.div.close()
    # -- segments
    page.h2('Segments')
    # link XML file
    page.p()
    page.add('The full output segments are recorded in '
             'LIGO_LW-format XML here:')
    page.a(os.path.basename(xmlfile), href=xmlfile, target='_blank')
    page.p.close()
    # print state segments
    if args.state_flag:
        page.p('This analysis was executed over the following segments:')
        page.div(class_='panel-group', id_='accordion1')
        page.add(str(write_flag_html(state, 'state', parent='accordion1',
                                     context='success')))
        page.div.close()
    # print saturation segments
    if bad:
        page.p('The following channels indicated a software saturation:')
        page.div(class_='panel-group', id_='accordion2')
        for i, flag in enumerate(saturations.values()):
            if abs(flag.active) > 0:
                title = '%s [%d]' % (flag.name, len(flag.active))
                page.add(str(write_flag_html(flag, i, parent='accordion2',
                                             title=title, plotdir=plotdir)))
        page.div.close()
    else:
        page.div(class_='alert alert-info')
        page.p('No software saturations were found')
        page.div.close()
    # -- results table
    page.h2('Results summary')
    page.table(class_='table table-striped table-hover')
    # write table header
    page.thead()
    page.tr()
    for header in ['Channel', 'Result', 'Num. saturations']:
        page.th(header)
    # close the header row before closing the header itself
    page.tr.close()
    page.thead.close()
    # write body, one row per channel with an active limit
    page.tbody()
    for c, seglist in saturations.items():
        passed = abs(seglist.active) == 0
        page.tr(class_='default' if passed else 'warning')
        page.td(c)
        page.td('Pass' if passed else 'Fail')
        page.td(len(seglist.active))
        page.tr.close()
    page.tbody.close()
    page.table.close()
    # -- parameters
    page.h2('Parameters')
    page.p('This analysis used the following parameters:')

    def write_param(param, value):
        """Append a '<strong>param: </strong>value' paragraph to the page
        """
        page.p()
        page.strong('%s: ' % param)
        page.add(str(value))
        page.p.close()

    write_param('Start time', args.gpsstart)
    write_param('End time', args.gpsend)
    write_param('State flag', args.state_flag)
    write_param('State end padding', args.pad_state_end)
    write_param('Skip', ', '.join(map(repr, args.skip)))
    # close and write
    page.div.close()
    with open(args.html, 'w') as fp:
        print(str(page), file=fp)
