#!/usr/bin/env python
# coding=utf-8
# Copyright (C) LIGO Scientific Collaboration (2015-)
#
# This file is part of the GW DetChar python package.
#
# GW DetChar is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GW DetChar is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GW DetChar.  If not, see <http://www.gnu.org/licenses/>.

"""Compute an omega scan for a list of channels around a given GPS time

This utility can be used to process an arbitrary list of detector channels
with minimal effort in finding data. The input should be an INI-formatted
configuration file that lists processing options and channels in contextual
blocks. For more information, see gwdetchar.omega.config.
"""

from __future__ import division

import os
import sys
import warnings
import numpy

from gwpy.table import Table
from gwpy.time import to_gps

from matplotlib import use
use('agg')  # nopep8

from gwdetchar import (cli, omega)
from gwdetchar.omega import (config, html, plot)
from gwdetchar.io.datafind import (check_flag, get_data)

__author__ = 'Alex Urban <alexander.urban@ligo.org>'
__credits__ = 'Duncan Macleod <duncan.macleod@ligo.org>'


# -- parse command line -------------------------------------------------------

# build the parser; the module docstring doubles as its description
parser = cli.create_parser(description=__doc__)

# positional: the time to scan, converted on input with `to_gps`
parser.add_argument(
    'gpstime',
    type=to_gps,
    help='GPS time or datestring to scan',
)
cli.add_ifo_option(parser, required=False)

# where to write results and which configuration file(s) to read
parser.add_argument(
    '-o', '--output-directory',
    help=('output directory for the omega scan, '
          'default: ~/public_html/wdq/{IFO}_{gpstime}'),
)
parser.add_argument(
    '-f', '--config-file',
    action='append',
    default=None,
    help=('path to configuration file to use, can be given '
          'multiple times (files read in order), default: '
          'choose a standard one based on IFO and GPS time'),
)

# boolean switches, each False unless given on the command line
for short_flag, long_flag, description in (
        ('-d', '--disable-correlation',
         'disable cross-correlation of aux channels, default: False'),
        ('-D', '--disable-checkpoint',
         'disable checkpointing from previous runs, default: False'),
        ('-s', '--ignore-state-flags',
         'ignore state flag definitions in the configuration, '
         'default: False'),
):
    parser.add_argument(short_flag, long_flag, action='store_true',
                        default=False, help=description)

# tunable processing and plotting parameters
parser.add_argument(
    '-t', '--far-threshold',
    type=float,
    default=3.171e-8,
    help=('white noise false alarm rate threshold (Hz) for '
          'processing channels, default: %(default)s'),
)
parser.add_argument(
    '-c', '--colormap',
    default='viridis',
    help='name of colormap to use, default: %(default)s',
)
cli.add_nproc_option(parser)

args = parser.parse_args()

# set up logger, named after this script's file name
logger = cli.logger(name=os.path.basename(__file__))

# get run parameters: label the scan 'Network' when no IFO was given, and
# round the requested time to two decimal places
ifo = args.ifo or 'Network'
gps = numpy.around(float(args.gpstime), 2)

logger.info("{} Omega Scan {}".format(ifo, gps))

# get default configuration when none was given on the command line
if args.config_file is None:
    args.config_file = config.get_default_configuration(ifo, gps)

# parse configuration files, referring to each by its absolute path
args.config_file = list(map(os.path.abspath, args.config_file))
logger.debug('Parsing the following configuration files:')
for fname in args.config_file:
    logger.debug(' -- ' + fname)
cp = config.OmegaConfigParser(ifo=ifo)
cp.read(args.config_file)

# parse primary channel; it is only read when cross-correlation is enabled
if not args.disable_correlation:
    try:
        primary_options = dict(cp.items('primary'))
        primary = config.OmegaChannelList('primary', **primary_options)
    except config.configparser.NoSectionError:
        # no [primary] section: fall back to running without correlation
        logger.warning(
            'No primary configured, continuing without cross-correlation')
        args.disable_correlation = True
# the section is dropped unconditionally, whether or not it was parsed
# (presumably so it is not returned as a channel block -- confirm)
cp.remove_section('primary')

# get contextual channel blocks
blocks = cp.get_channel_blocks()

# set up analyzed channel dict: a plain dict on python 3.7+, where insertion
# order is guaranteed by the language, and an OrderedDict on older versions
if sys.version_info < (3, 7):
    from collections import OrderedDict
    analyzed = OrderedDict()
else:
    analyzed = {}

# prepare html variables, passed as keyword arguments to the page writers
htmlv = dict(
    title='{} Qscan | {}'.format(ifo, gps),
    config=args.config_file,
    refresh=True,
)

# set output directory, defaulting to a per-scan folder under ~/public_html
if args.output_directory is None:
    outdir = os.path.expanduser(
        '~/public_html/wdq/{ifo}_{gps}'.format(ifo=ifo, gps=gps))
else:
    outdir = args.output_directory
outdir = os.path.abspath(outdir)

# create it if needed and work from inside it, so that the relative paths
# used further down ('plots', 'about', 'data') resolve against this directory
if not os.path.isdir(outdir):
    os.makedirs(outdir)
os.chdir(outdir)
logger.debug('Output directory created as {}'.format(outdir))


# -- Compute Qscan ------------------------------------------------------------

# make subdirectories, relative to the current working directory
plotdir, aboutdir, datadir = 'plots', 'about', 'data'
for subdir in (plotdir, aboutdir, datadir):
    if os.path.isdir(subdir):
        continue  # already present, nothing to do
    os.makedirs(subdir)

# determine checkpoints
summary = os.path.join(datadir, 'summary.csv')
if args.disable_checkpoint or not os.path.exists(summary):
    # nothing to resume from, or resuming was disabled on the command line
    record, completed = [], []
else:
    logger.debug('Checkpointing from {}'.format(
        os.path.abspath(summary)))
    record = Table.read(summary)
    # names of the channels already listed in the summary table
    completed = list(record['Channel'])
    # a record without this column cannot be used when correlation is on
    wants_correlation = not args.disable_correlation
    if wants_correlation and 'Standard Deviation' not in record.colnames:
        raise KeyError(
            'Cross-correlation is not available from this record, '
            'consider running without correlation or starting from '
            'scratch with --disable-checkpoint')

# set up html output: write an initial page before any channel is processed
logger.debug('Setting up HTML at {}/index.html'.format(outdir))
html.write_qscan_page(ifo, gps, analyzed, **htmlv)

# launch omega scans
logger.info('Launching omega scans')

# construct a matched-filter from primary channel
if args.disable_correlation:
    # no primary channel in use: later steps test `correlate is not None`
    correlate = None
else:
    logger.debug('Processing primary channel')
    duration = primary.duration
    fftlength = primary.fftlength
    name = primary.channel.name
    # read `duration` seconds of data centered on gps, with one extra
    # second of padding on either side
    start = gps - duration/2. - 1
    end = gps + duration/2. + 1
    correlate = get_data(
        name, start, end,
        frametype=primary.frametype,
        source=primary.source,
        nproc=args.nproc,
        verbose='Reading primary:'.rjust(30),
    )
    # replace the raw data with the output of `omega.primary`
    correlate = omega.primary(
        gps, primary.length, correlate, fftlength,
        resample=primary.resample,
        f_low=primary.flow,
    )
    plot.timeseries_plot(
        correlate, gps, primary.length, name, 'plots/primary.png',
        ylabel='Whitened Amplitude')
    # prepare HTML output
    htmlv['correlated'] = True
    htmlv['primary'] = name

# Index the checkpoint record once, up front. `completed` is never modified
# while scanning, so membership tests and row lookups can be O(1) instead of
# rebuilding a set for every block and searching the list for every channel.
checkpoint_rows = {}
for row, done_name in enumerate(completed):
    # keep the first occurrence, matching what `list.index` would return
    checkpoint_rows.setdefault(done_name, row)

# range over channel blocks
for block in blocks.values():
    logger.debug('Processing block {}'.format(block.key))
    chans = [c.name for c in block.channels]
    # get configuration
    duration = block.duration
    fftlength = block.fftlength
    # check that analysis flag is active for all of `duration`
    if block.flag and (not args.ignore_state_flags):
        logger.info(' -- Querying state flag {}'.format(block.flag))
        if not check_flag(block.flag, gps, duration, pad=1):
            logger.info(
                ' -- {} not active, skipping block'.format(block.flag))
            continue
    # read `duration` seconds of data (plus one second of padding on either
    # side), but only for channels that still need scanning: checkpointed
    # channels are restored from the record below and never touch `data`
    pending = [chan for chan in chans if chan not in checkpoint_rows]
    if pending:
        start = gps - duration/2. - 1
        end = gps + duration/2. + 1
        data = get_data(
            pending, start, end, frametype=block.frametype,
            source=block.source, nproc=args.nproc,
            verbose='Reading block:'.rjust(30))

    # process individual channels
    for channel in block.channels:
        if channel.name in checkpoint_rows:  # load checkpoint
            logger.info(' -- Checkpointing {} from a previous '
                        'run'.format(channel.name))
            cindex = checkpoint_rows[channel.name]
            channel.load_loudest_tile_features(
                record[cindex], correlated=correlate is not None)
        else:
            try:  # scan the channel
                logger.info(' -- Scanning channel {}'.format(channel.name))
                series = omega.scan(
                    gps, channel, data[channel.name].astype('float64'),
                    fftlength, resample=block.resample,
                    fthresh=args.far_threshold, search=block.search)
            except (ValueError, KeyError) as exc:
                # a failure on one channel must not abort the whole scan
                warnings.warn("Skipping {}: [{}] {}".format(
                    channel.name, type(exc), str(exc)), UserWarning)
                continue

            if series is None:
                logger.warning(
                    ' -- Channel not significant at white noise false alarm '
                    'rate {} Hz'.format(args.far_threshold))
                continue  # channel is insignificant, skip it

            logger.info(
                ' -- Plotting omega scan products for {}'.format(
                    channel.name))
            plot.write_qscan_plots(
                gps, channel, series, colormap=args.colormap)

            if correlate is not None:
                logger.info(' -- Cross-correlating {}'.format(channel.name))
                correlation = omega.cross_correlate(series[2], correlate)
                channel.save_loudest_tile_features(
                    series[3], correlation, gps=gps, dt=block.dt)
            else:
                channel.save_loudest_tile_features(series[3])

        # the channel was either restored or scanned successfully: add it to
        # the table of contents and rewrite the page so progress is visible
        analyzed = html.update_toc(analyzed, channel,
                                   name=blocks[channel.section].name)
        htmlv['toc'] = analyzed
        html.write_qscan_page(ifo, gps, analyzed, **htmlv)


# -- Prepare HTML -------------------------------------------------------------

# write HTML page and finish
logger.debug('Finalizing HTML at {}/index.html'.format(outdir))
htmlv['refresh'] = False  # turn off auto-refresh
if not analyzed:
    # no channel was added to `analyzed`: write a page explaining why
    reason = 'No significant channels found during active analysis segments'
    html.write_null_page(ifo, gps, reason, **htmlv)
else:
    html.write_qscan_page(ifo, gps, analyzed, **htmlv)
logger.info("-- index.html written, all done --")
