#!python
# coding=utf-8
# Copyright (C) LIGO Scientific Collaboration (2015-)
#
# This file is part of the GW DetChar python package.
#
# GW DetChar is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GW DetChar is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GW DetChar.  If not, see <http://www.gnu.org/licenses/>.

"""Find times when an input signal crosses a particular threshold
"""
from __future__ import print_function

import numpy
import os
import sys

from matplotlib import (use, rcParams)
use('agg')

import gwdatafind

from glue.ligolw.utils import write_fileobj

from gwpy.segments import (DataQualityFlag, DataQualityDict,
                           Segment, SegmentList)

from gwpy.timeseries import (TimeSeries, TimeSeriesDict)
from gwpy.io.cache import (cache_segments, sieve as sieve_cache)
from gwpy.utils import gprint

from gwdetchar import (const, cli, cds, __version__)
from gwdetchar.io import (ligolw, html as htmlio)
from gwdetchar.daq import find_crossings
from gwdetchar.misc import get_gwpy_tex_settings

__author__ = 'TJ Massinger <thomas.massinger@ligo.org>'
__credits__ = 'Duncan Macleod <duncan.macleod@ligo.org>'

def table_from_times(times):
    """Convert a set of crossing times into ``sngl_burst`` rows

    Every event is given the same placeholder values (``snr=10``,
    ``peak_frequency=100``) and is tagged with the name of this script
    as its ``search``.

    Parameters
    ----------
    times
        the times to convert, passed straight through to
        `gwdetchar.io.ligolw.sngl_burst_from_times`

    Returns
    -------
    whatever `ligolw.sngl_burst_from_times` returns (presumably a
    ``sngl_burst`` table -- confirm against that function)
    """
    search_name = os.path.basename(__file__)
    return ligolw.sngl_burst_from_times(
        times,
        snr=10,
        peak_frequency=100,
        search=search_name)

# -- command-line interface ---------------------------------------------------

# options shared with the other gwdetchar tools
parser = cli.create_parser(description=__doc__)
cli.add_gps_start_stop_arguments(parser)
cli.add_ifo_option(parser)
cli.add_nproc_option(parser)

# the frametype is mandatory when ``const.IFO`` is unset, otherwise it
# defaults to '<IFO>_R'
if const.IFO is None:
    cli.add_frametype_option(parser, required=True, default=False)
else:
    cli.add_frametype_option(parser, required=False,
                             default='%s_R' % const.IFO)

# options specific to this tool
parser.add_argument(
    '-a', '--state-flag',
    metavar='FLAG',
    help='restrict search to times when FLAG was active')
parser.add_argument(
    '-o', '--output-path',
    help='path to output xml file, name will be '
         'automatically generated based on IFO and GPS times')
parser.add_argument(
    '-t', '--threshold',
    nargs='+',
    type=float,
    default=[0., 2.**16, -2.**16],
    help='threshold for marking input data crossings')
parser.add_argument(
    '-c', '--channel',
    required=True,
    type=str,
    help='channel to read for input data')
parser.add_argument(
    '-r', '--rate-thresh',
    type=float,
    default=16.,
    help='if the trigger rate (Hz) is above this value '
         'an XML file will not be written')

args = parser.parse_args()

# the full GPS interval requested on the command line, and its length
span = Segment(args.gpsstart, args.gpsend)
duration = abs(span)

gprint('Processing channel %s over span %d - %d'
       % (args.channel, args.gpsstart, args.gpsend))

# work out which times to analyse: the whole span by default, or only the
# active segments of the state flag when one was requested
if not args.state_flag:
    statea = SegmentList([span])
else:
    state = DataQualityFlag.query(
        args.state_flag,
        int(args.gpsstart),
        int(args.gpsend),
        url=const.DEFAULT_SEGMENT_SERVER)
    statea = state.active

# apply the gwpy TeX settings to matplotlib
rcParams.update(get_gwpy_tex_settings())

# directory that receives the XML files; --output-path has no default, so
# fall back to the current directory rather than failing on ``None``
outdir = args.output_path or os.curdir

# initialize output files for each threshold and store them in a dict keyed
# by the string form of the threshold. The file name encodes the channel
# (with '-' -> '_' and ':' -> '-'), the integer part of the threshold (with
# '-' spelled 'n'), the GPS start time and the duration.
# NOTE: thresholds sharing an integer part (e.g. 0.2 and 0.7) map to the
# same file name.
channel_tag = args.channel.replace('-', '_').replace(':', '-')
outfiles = {}
for thresh in args.threshold:
    thresh_tag = str(int(thresh)).replace('-', 'n')
    filename = '%s_%s_DAC-%d-%d.xml.gz' % (
        channel_tag, thresh_tag, int(args.gpsstart), duration)
    # join with a proper separator so that '-o dir' and '-o dir/' both
    # write inside the directory
    outfiles[str(thresh)] = os.path.join(outdir, filename)

# get frame cache
cache = gwdatafind.find_urls(args.ifo[0], args.frametype,
                             int(args.gpsstart), int(args.gpsend))

# restrict the analysis to times that are both selected and covered by frames
cachesegs = statea & cache_segments(cache)

# make sure the output directory exists before anything is written to it
if not os.path.isdir(outdir):
    os.makedirs(outdir)

# initialize a ligolw table for each threshold and store them in a dict
tables = {}
for thresh in args.threshold:
    tables[str(thresh)] = ligolw.new_table(
        'sngl_burst',
        columns=['peak_time', 'peak_time_ns', 'peak_frequency', 'snr'])

# for each science segment, read in the data from frames, check for threshold
# crossings, and if the rate of crossings is less than rate_thresh, add them
# to the sngl_burst table for that threshold. The rate cut is applied per
# segment and per threshold: crossings from a segment that fails it are
# dropped, crossings from other segments are kept.
for seg in cachesegs:
    gprint('Processing segment %d - %d' % (seg[0], seg[1]))
    segcache = sieve_cache(cache, segment=seg)
    data = TimeSeries.read(segcache, args.channel, nproc=args.nproc,
                           start=seg[0], end=seg[1])
    # cast once so the rate below is a true (float) division on Python 2 as
    # well, and the printed rate is exactly the one that gets tested
    segdur = float(abs(seg))
    for thresh in args.threshold:
        times = find_crossings(data, thresh)
        ncross = len(times)
        rate = ncross / segdur
        gprint('Found %d crossings for threshold %d' % (ncross, thresh))
        gprint('Rate of crossings: %.2f Hz' % rate)
        if ncross and rate < args.rate_thresh:
            tables[str(thresh)].extend(table_from_times(times))

# write one gzip-compressed XML document per threshold
for thresh in args.threshold:
    outfile = outfiles[str(thresh)]
    gprint('Writing output XML: %s' % outfile)
    xmldoc = ligolw.table_to_document(tables[str(thresh)])
    # binary mode: the gzip stream is bytes, which a text-mode file rejects
    # on Python 3; the context manager guarantees the file is flushed and
    # closed even if writing fails
    with open(outfile, 'wb') as fobj:
        write_fileobj(xmldoc, fobj, gz=True)

