#!python
# coding=utf-8
# Copyright (C) LIGO Scientific Collaboration (2015-)
#
# This file is part of the GW DetChar python package.
#
# GW DetChar is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GW DetChar is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GW DetChar.  If not, see <http://www.gnu.org/licenses/>.

"""Find times when an input signal crosses a particular threshold
"""
from __future__ import print_function

import numpy
import os
import sys

from matplotlib import (use, rcParams)
use('agg')

import gwdatafind

from glue.ligolw.utils import write_fileobj

from gwpy.segments import (DataQualityFlag, DataQualityDict,
                           Segment, SegmentList)

from gwpy.io.cache import (cache_segments, sieve as sieve_cache)
from gwpy.utils import gprint

from gwdetchar import (const, cli, cds, __version__)
from gwdetchar.io import (ligolw, html as htmlio)
from gwdetchar.io.datafind import get_data
from gwdetchar.daq import find_crossings
from gwdetchar.plot import get_gwpy_tex_settings

__author__ = 'TJ Massinger <thomas.massinger@ligo.org>'
__credits__ = 'Duncan Macleod <duncan.macleod@ligo.org>'

# build the command-line interface; the module docstring is used as the
# program description and the shared options are registered through the
# gwdetchar.cli helpers
parser = cli.create_parser(description=__doc__)
cli.add_gps_start_stop_arguments(parser)
cli.add_ifo_option(parser)
cli.add_nproc_option(parser)
# the frametype must be given explicitly when no default IFO is known;
# otherwise it defaults to '<IFO>_R' (the `and` expression evaluates to
# False when const.IFO is None, in which case the option is required anyway)
cli.add_frametype_option(parser, required=const.IFO is None,
                         default=const.IFO is not None and '%s_R' % const.IFO)
parser.add_argument('-a', '--state-flag', metavar='FLAG',
                    help='restrict search to times when FLAG was active')
# default to the current directory: without a default, omitting this option
# left `args.output_path` as None, which crashed later when the path was
# joined with the generated file names
parser.add_argument('-o', '--output-path', default=os.curdir,
                    help='path to output xml file, name will be '
                         'automatically generated based on IFO and GPS times')
# one or more thresholds may be given; each one gets its own output table
parser.add_argument('-t', '--threshold', type=float, nargs='+',
                    default=[0., 2.**16, -2.**16],
                    help='threshold for marking input data crossings')
parser.add_argument('-c', '--channel', required=True, type=str,
                    help='channel to read for input data')
parser.add_argument('-r', '--rate-thresh', default=16., type=float,
                    help='if the trigger rate (Hz) is above this value '
                         'an XML file will not be written')
args = parser.parse_args()

# the full GPS interval requested on the command line, and its length
span = Segment(args.gpsstart, args.gpsend)
duration = abs(span)

gprint('Processing channel %s over span %d - %d'
       % (args.channel, args.gpsstart, args.gpsend))

# analyse the whole requested span unless a state flag was given, in which
# case only the active segments of that flag are kept
if not args.state_flag:
    statea = SegmentList([span])
else:
    state = DataQualityFlag.query(
        args.state_flag,
        int(args.gpsstart),
        int(args.gpsend),
        url=const.DEFAULT_SEGMENT_SERVER,
    )
    statea = state.active

# push the TeX settings into the matplotlib configuration
rcParams.update(get_gwpy_tex_settings())

# build the output file path for each threshold, keyed by str(threshold);
# the name encodes the channel, the integer threshold (with '-' written as
# 'n'), the integer GPS start and the duration of the requested span
channel_tag = args.channel.replace('-', '_').replace(':', '-')
gps_start = int(args.gpsstart)
outfiles = {}
for thresh in args.threshold:
    thresh_tag = str(int(thresh)).replace('-', 'n')
    filename = '%s_%s_DAC-%d-%d.xml.gz' % (
        channel_tag, thresh_tag, gps_start, duration)
    outfiles[str(thresh)] = os.path.join(args.output_path, filename)

# look up the frame files covering the requested GPS interval; the first
# element of `args.ifo` is passed as the site
# (presumably the observatory letter, e.g. 'H' for 'H1' -- confirm)
site = args.ifo[0]
gps_from = int(args.gpsstart)
gps_to = int(args.gpsend)
cache = gwdatafind.find_urls(site, args.frametype, gps_from, gps_to)

# only process times that are both in the analysis state and covered by
# the frame files that were found
cachesegs = statea & cache_segments(cache)

# create the output directory if it is not there yet
if not os.path.exists(args.output_path):
    os.makedirs(args.output_path)

# create a new sngl_burst table for each threshold, keyed by str(threshold)
tables = {
    str(thresh): ligolw.new_table(
        'sngl_burst',
        columns=['peak_time', 'peak_time_ns', 'peak_frequency', 'snr'])
    for thresh in args.threshold
}

# for each analysis segment, read the channel data from the frames and look
# for crossings of every threshold; crossings are appended to the matching
# sngl_burst table only when their rate over the segment is below
# args.rate_thresh
search_name = os.path.basename(__file__)
for seg in cachesegs:
    gprint('Processing segment %d - %d' % (seg[0], seg[1]))
    # restrict the frame cache to the files overlapping this segment
    c = sieve_cache(cache, segment=seg)
    data = get_data(args.channel, seg[0], seg[1], nproc=args.nproc, source=c)
    for thresh in args.threshold:
        times = find_crossings(data, thresh)
        ncrossings = len(times)
        gprint('Found %d crossings for threshold %d' % (ncrossings, thresh))
        # compute the rate once, with float division, so that the value
        # that is logged is exactly the value tested against the veto
        # (plain `/` on two integers truncates under Python 2)
        rate = float(ncrossings) / abs(seg)
        gprint('Rate of crossings: %.2f Hz' % rate)
        if ncrossings and rate < args.rate_thresh:
            # snr and peak_frequency are fixed placeholder values
            tables[str(thresh)].extend(ligolw.sngl_burst_from_times(
                times, snr=10, peak_frequency=100,
                search=search_name))

# write one gzip-compressed XML document per threshold
for thresh in args.threshold:
    outfile = outfiles[str(thresh)]
    gprint('Writing output XML: %s' % outfile)
    # open in binary mode because gzip-compressed bytes are written to the
    # file object, and use a context manager so the file is flushed and
    # closed deterministically instead of being left open
    with open(outfile, 'wb') as fobj:
        write_fileobj(ligolw.table_to_document(tables[str(thresh)]), fobj,
                      gz=True)
