#!/usr/bin/env python
# coding=utf-8
# Copyright (C) LIGO Scientific Collaboration (2015-)
#
# This file is part of the GW DetChar python package.
#
# GW DetChar is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GW DetChar is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GW DetChar.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import (division, print_function)

import os
import re
import multiprocessing
import sys

import numpy

from matplotlib import (use, rcParams)
use('agg')  # nopep8
from matplotlib import cm as mcm

from sklearn import linear_model
from sklearn.preprocessing import scale

from gwpy.table import Table
from pandas import set_option

from gwpy.plot import Plot
from gwpy.time import Time
from gwpy.detector import ChannelList
from gwpy.io import nds2 as io_nds2

from gwdetchar import (cli, lasso as gwlasso)
from gwdetchar.lasso import plot as gwplot
from gwdetchar.io import html as htmlio
from gwdetchar.io.datafind import get_data
from gwdetchar.plot import get_gwpy_tex_settings


# Frametype templates for known primary channels, keyed by the channel name
# with its '<ifo>:' prefix removed; '{ifo}' is filled in at runtime
DEFAULT_FRAMETYPE = dict((
    ('GDS-CALIB_STRAIN', '{ifo}_HOFT_C00'),
    ('DMT-SNSH_EFFECTIVE_RANGE_MPC.mean', 'SenseMonitor_hoft_{ifo}_M'),
    ('DMT-SNSW_EFFECTIVE_RANGE_MPC.mean', 'SenseMonitor_CAL_{ifo}_M'),
))


# -- parse command line -------------------------------------------------------

# This script has no module docstring, so ``__doc__`` is None; fall back to
# an explicit description so that ``--help`` says what the tool does.
parser = cli.create_parser(
    description=__doc__ or (
        'Find auxiliary channels that are correlated with a primary '
        'channel using Lasso regression'),
    formatter_class=cli.argparse.ArgumentDefaultsHelpFormatter)
cli.add_gps_start_stop_arguments(parser)
cli.add_ifo_option(parser)
cli.add_nproc_option(parser, default=1)
parser.add_argument('-J', '--nproc-plot', type=int, default=None,
                    help='number of processes to use for plotting')
parser.add_argument('-o', '--output-dir', default=os.curdir,
                    help='output directory for plots')
parser.add_argument('-f', '--channel-file', type=os.path.abspath,
                    help='path for channel file')
parser.add_argument('-T', '--trend-type', default='minute',
                    choices=['second', 'minute'],
                    help='type of trend for correlation')
parser.add_argument('-p', '--primary-channel',
                    default='{ifo}:DMT-SNSH_EFFECTIVE_RANGE_MPC.mean',
                    help='name of primary channel to use')
parser.add_argument('-P', '--primary-frametype', default=None,
                    help='frametype for --primary-channel, default: guess '
                         'by channel name')
parser.add_argument('-O', '--remove-outliers', type=float, default=None,
                    help='Std. dev. limit for removing outliers')
parser.add_argument('-t', '--threshold', type=float, default=0.0001,
                    help='threshold for making a plot')

psig = parser.add_argument_group('Signal processing options')
# with nargs=2 a plain-string metavar is repeated once per argument
# ("FLOW FHIGH FLOW FHIGH"); a tuple names each argument individually
psig.add_argument('-b', '--band-pass', type=float, nargs=2, default=None,
                  metavar=('FLOW', 'FHIGH'),
                  help='lower and upper frequencies for bandpass on h(t)')
psig.add_argument('-x', '--filter-padding', type=float, default=3.,
                  help='amount of time (seconds) to pad data for filtering')

lsig = parser.add_argument_group('Lasso options')
lsig.add_argument('-a', '--alpha', default=None, type=float,
                  help='alpha parameter for lasso fit')
lsig.add_argument('-C', '--no-cluster', action='store_true', default=False,
                  help='do not generate clustered channel plots')
lsig.add_argument('-c', '--cluster-coefficient', default=.85, type=float,
                  help='correlation coefficient threshold for clustering')
lsig.add_argument('-L', '--line-size-primary', default=1, type=float,
                  help='line width of primary channel')
lsig.add_argument('-l', '--line-size-aux', default=0.75, type=float,
                  help='line width of auxiliary channel')

args = parser.parse_args()

# set up logger
logger = cli.logger(name=os.path.basename(__file__))

# set TeX settings
tex_settings = get_gwpy_tex_settings()
rcParams.update(tex_settings)

# analysis segment and filter padding
start = int(args.gpsstart)
end = int(args.gpsend)
pad = args.filter_padding
auto_xlabel = ('Time [hours] from '
               + re.sub(r'\.0+', '',
                        Time(start, format='gps', scale='utc').iso)
               + ' UTC (%d)' % start)

# let's go
logger.info('{} Lasso Correlations {}-{}'.format(args.ifo, start, end))

# get primary channel frametype
primary = args.primary_channel.format(ifo=args.ifo)
range_is_primary = 'EFFECTIVE_RANGE_MPC' in args.primary_channel
if args.primary_frametype is None:
    # drop the '<ifo>:' prefix when there is one; a name without a prefix is
    # looked up as given, so it fails with the helpful error below instead
    # of a bare IndexError
    primary_key = args.primary_channel.split(':', 1)[-1]
    try:
        args.primary_frametype = DEFAULT_FRAMETYPE[primary_key].format(
            ifo=args.ifo)
    except KeyError as exc:
        raise type(exc)("Could not determine primary channel's frametype, "
                        "please specify with --primary-frametype")


# all outputs are written relative to the output directory
if not os.path.isdir(args.output_dir):
    os.makedirs(args.output_dir)
os.chdir(args.output_dir)
# plotting uses its own process count when given, else the general one
nprocplot = args.nproc_plot or args.nproc

# -- load the primary channel, optionally as a band-limited RMS

if args.band_pass:
    # argparse (nargs=2, type=float) guarantees exactly two floats here, so
    # a plain unpack is enough; the previous TypeError fallback tried to
    # unpack ``None`` and could only ever have raised itself
    flower, fupper = args.band_pass

    logger.info("-- Loading primary channel data")
    # read extra data on either side so filter transients can be cropped
    bandts = get_data(
        primary, start-pad, end+pad, verbose='Reading primary:'.rjust(30),
        frametype=args.primary_frametype, nproc=args.nproc)
    if flower < 0 or fupper >= float((bandts.sample_rate/2.).value):
        raise ValueError("bandpass frequency is out of range for this "
                         "channel, band (Hz): {0}, sample rate: {1}".format(
                             args.band_pass, bandts.sample_rate))

    # get darm BLRMS
    logger.debug("-- Filtering data")
    # RMS stride in seconds, chosen from the requested trend type
    if args.trend_type == 'minute':
        stride = 60
    else:
        stride = 1
    if flower:
        darmblrms = (
            bandts.highpass(flower/2., fstop=flower/4.,
                            filtfilt=False, ftype='butter')
            .notch(60, filtfilt=False)
            .bandpass(flower, fupper, fstop=[flower/2., fupper*1.5],
                      filtfilt=False, ftype='butter')
            .crop(start, end).rms(stride))
        darmblrms.name = '%s %s-%s Hz BLRMS' % (primary, flower, fupper)
    else:
        # a lower frequency of zero skips the band-pass entirely
        darmblrms = bandts.notch(60).crop(start, end).rms(stride)
        darmblrms.name = '%s RMS' % primary

    primaryts = darmblrms

else:
    # load primary channel data
    logger.info("-- Loading primary channel data")
    primaryts = get_data(primary, start, end, frametype=args.primary_frametype,
                         verbose='Reading:'.rjust(30),
                         nproc=args.nproc).crop(start, end)

if args.remove_outliers:
    logger.debug("-- Removing outliers above %f sigma" % args.remove_outliers)
    # NOTE(review): the return value is ignored, so this presumably edits
    # ``primaryts`` in place -- confirm against gwdetchar.lasso
    gwlasso.remove_outliers(primaryts, args.remove_outliers)

# statistics used by ``descaler`` to map scaled data back to primary units
primary_mean = numpy.mean(primaryts.value)
primary_std = numpy.std(primaryts.value)


def descaler(l, *coef):
    """Map scaled values back onto the units of the primary channel

    Each value is multiplied by the standard deviation of the primary
    channel (and by ``coef[0]`` when a coefficient is given) and then
    offset by the mean of the primary channel.

    Parameters
    ----------
    l : iterable of numbers
        the scaled values to convert

    *coef
        optional multiplier; only the first value is used

    Returns
    -------
    descaled : `list`
        the converted values, in the same order as the input
    """
    if not coef:
        return [(value * primary_std) + primary_mean for value in l]
    factor = coef[0]
    return [(value * primary_std * factor) + primary_mean for value in l]


# get aux data
logger.info("-- Loading auxiliary channel data")
if args.channel_file is None:
    # no channel list given: ask the first NDS2 host for this IFO for
    # every minute-trend channel whose name ends in '.mean'
    host, port = io_nds2.host_resolution_order(args.ifo)[0]
    channels = ChannelList.query_nds2('*.mean', host=host, port=port,
                                      type='m-trend')
else:
    # one channel name per line; strip surrounding whitespace (including
    # any '\r' from Windows line endings) and skip blank lines so they are
    # not passed on as empty channel names
    with open(args.channel_file, 'r') as f:
        channels = [line.strip() for line in f if line.strip()]
nchan = len(channels)
logger.debug("Identified %d channels" % nchan)

if args.trend_type == 'minute':
    frametype = '%s_M' % args.ifo  # for minute trends
else:
    frametype = '%s_T' % args.ifo  # for second trends

auxdata = get_data(
    channels, start, end, verbose='Reading:'.rjust(30),
    frametype=frametype, nproc=args.nproc, pad=0).crop(start, end)

# -- remove flat data, to be reported separately later

logger.info('-- Pre-processing auxiliary channel data')

auxdata = gwlasso.remove_flat(auxdata)
# whatever was requested but is no longer present was dropped as flat
flattab = Table(data=(list(set(channels) - set(auxdata.keys())),),
                names=('Channels',))
logger.debug('Removed {0} channels with flat data'.format(len(flattab)))
logger.debug('{0} channels remaining'.format(len(auxdata)))

# -- remove bad data

logger.info("Removing any channels with bad data...")
nbefore = len(auxdata)
auxdata = gwlasso.remove_bad(auxdata)
nafter = len(auxdata)
logger.debug('Removed {0} channels with bad data'.format(nbefore - nafter))
logger.debug('{0} channels remaining'.format(nafter))
# standardise every channel and stack them as the columns of the data matrix
data = numpy.array([scale(ts.value) for ts in auxdata.values()]).T

# -- perform lasso regression -------------------------------------------------

# create model
logger.info('-- Fitting data to target')
target = scale(primaryts.value)
model = gwlasso.fit(data, target, alpha=args.alpha)
logger.info('Alpha: {}'.format(model.alpha))

# restructure results for convenience
# 'rank' (the absolute coefficient) is only a temporary sort/selection key
allresults = Table(
    data=(list(auxdata.keys()), model.coef_, numpy.abs(model.coef_)),
    names=('Channel', 'Lasso coefficient', 'rank'))
allresults.sort('rank')
allresults.reverse()  # largest absolute coefficient first
useful = allresults['rank'] > 0
allresults.remove_column('rank')
results = allresults[useful]  # non-zero coefficient
zeroed = allresults[numpy.invert(useful)]  # zero coefficient

# extract data for useful channels
nonzerodata = {name: auxdata[name] for name in results['Channel']}
nonzerocoef = {name: coeff for name, coeff in results.as_array()}

# print results
# ``results`` holds every non-zero coefficient; --threshold is only applied
# later when deciding which channels get plots, so do not claim it here
logger.info('Found {} channels with non-zero Lasso coefficient:\n\n'.format(
            len(results)))
print(results)
print('\n\n')

# convert to pandas
set_option('max_colwidth', 200)
df = results.to_pandas()
df.index += 1  # number rows from 1 in the HTML table

# write results to files
gpsstub = '%d-%d' % (start, end-start)
resultsfile = '%s-LASSO_RESULTS-%s.txt' % (args.ifo, gpsstub)
results.write(resultsfile, format='ascii', overwrite=True)
zerofile = '%s-ZERO_COEFFICIENT_CHANNELS-%s.txt' % (args.ifo, gpsstub)
zeroed.write(zerofile, format='ascii', overwrite=True)
flatfile = '%s-FLAT_CHANNELS-%s.txt' % (args.ifo, gpsstub)
flattab.write(flatfile, format='ascii', overwrite=True)

# -- generate lasso plots

# model prediction, in the same scaled units as ``target``
modelFit = model.predict(data)

# settings shared by the plots that follow
re_delim = re.compile(r'[:_-]')
form = '%.{0}d'.format(len(str(nchan)))
p1 = (.1, .15, .9, .9)  # global plot defaults for plot1, lasso model

times = primaryts.times.value
xlim = primaryts.span
cmap = mcm.get_cmap('tab20')
# one colour per non-zero channel, plus one spare
colors = list(map(cmap, numpy.linspace(0, 1, len(nonzerodata)+1)))

# primary channel against the full model
if range_is_primary:
    model_ylabel = 'Sensitive range [Mpc]'
    model_title = 'Lasso Model of Range'
else:
    model_ylabel = 'Primary Channel Units'
    model_title = 'Lasso Model of Primary Channel'

plot = Plot(figsize=(12, 4))
plot.subplots_adjust(*p1)
ax = plot.gca(xscale='auto-gps', epoch=start, xlim=xlim)
ax.plot(times, descaler(target),
        label=primary.replace('_', r'\_'),
        color='black',
        linewidth=args.line_size_primary)
ax.plot(times, descaler(modelFit),
        label='Lasso model',
        linewidth=args.line_size_aux)
ax.set_ylabel(model_ylabel)
ax.set_title(model_title)
ax.legend(loc='best')
plot1 = gwplot.save_figure(
    plot,
    '%s-LASSO_MODEL-%s.png' % (args.ifo, gpsstub),
    bbox_inches='tight')

# summed contributions: a running total of the channel terms, added in
# order of decreasing absolute coefficient, drawn after each addition
plot = Plot(figsize=(12, 4))
plot.subplots_adjust(*p1)
ax = plot.gca(xscale='auto-gps', epoch=start, xlim=xlim)
ax.plot(times, descaler(target),  # primary
        label=primary.replace('_', r'\_'),
        color='black',
        linewidth=args.line_size_primary)
summed = 0
for i, name in enumerate(results['Channel']):
    summed += scale(nonzerodata[name].value) * nonzerocoef[name]
    label = 'Channels 1-{0}'.format(i+1) if i else 'Channel 1'
    ax.plot(times, descaler(summed),
            label=label,
            color=colors[i],
            linewidth=args.line_size_aux)
ax.set_ylabel(
    'Sensitive range [Mpc]' if range_is_primary else 'Primary Channel Units')
ax.set_title('Summations of Channel Contributions to Model')
ax.legend(loc='center left', bbox_to_anchor=(1.05, 0.5))
plot2 = gwplot.save_figure(
    plot,
    '%s-LASSO_CHANNEL_SUMMATION-%s.png' % (args.ifo, gpsstub),
    bbox_inches='tight')


# individual contributions: each channel's own term of the model drawn
# separately against the primary channel
plot = Plot(figsize=(12, 4))
plot.subplots_adjust(*p1)
ax = plot.gca(xscale='auto-gps', epoch=start, xlim=xlim)
ax.plot(times, descaler(target), label=primary.replace('_', r'\_'),  # primary
        color='black', linewidth=args.line_size_primary)
for i, name in enumerate(results['Channel']):
    # lines are labelled by channel name here, so no 'Channels 1-N' label
    # is needed (one used to be computed and then discarded)
    this = descaler(scale(nonzerodata[name].value) * nonzerocoef[name])
    ax.plot(times, this, label=name.replace('_', r'\_'), color=colors[i],
            linewidth=args.line_size_aux)
if range_is_primary:
    ax.set_ylabel('Sensitive range [Mpc]')
else:
    ax.set_ylabel('Primary Channel Units')
ax.set_title('Individual Channel Contributions to Model')
ax.legend(loc='center left', bbox_to_anchor=(1.05, 0.5))
plot3 = gwplot.save_figure(
    plot, '%s-LASSO_CHANNEL_CONTRIBUTIONS-%s.png' % (args.ifo, gpsstub),
    bbox_inches='tight')

# -- process aux channels, making plots

logger.info("-- Processing channels")
# progress counter shared between the plotting worker processes
counter = multiprocessing.Value('i', 0)
p4 = (.1, .1, .9, .95)  # global plot defaults for plot4, timeseries subplots


def process_channel(input_,):
    """Make the per-channel plots for one channel of ``nonzerodata``

    Parameters
    ----------
    input_ : `tuple`
        an ``(index, (channel, timeseries))`` pair, as produced by
        enumerating the items of ``nonzerodata``; the index is not used

    Returns
    -------
    chan, lassocoef, plot4, plot5, plot6, ts : `tuple`
        the channel name, its Lasso coefficient, the values returned by
        `gwplot.save_figure` for the trends, comparison, and scatter
        figures (all `None` when the coefficient is zero or below
        ``args.threshold``, in which case nothing is plotted), and the
        channel's timeseries
    """
    gwplot.configure_mpl()
    chan, ts = input_[1][0], input_[1][1]
    lassocoef = nonzerocoef[chan]
    plot4 = plot5 = plot6 = None

    # only channels at or above the threshold get figures
    if lassocoef != 0 and not abs(lassocoef) < args.threshold:
        # Pearson coefficient, shown on the scatter plot; only computed
        # for minute trends
        if args.trend_type == 'minute':
            pcorr = numpy.corrcoef(ts.value, primaryts.value)[0, 1]
        else:
            pcorr = 0.0

        # create time series subplots
        fig = Plot(figsize=(12, 8))
        fig.subplots_adjust(*p4)
        ax1 = fig.add_subplot(2, 1, 1, xscale='auto-gps', epoch=start)
        ax1.plot(primaryts, label=primary.replace('_', r'\_'),
                 color='black', linewidth=args.line_size_primary)
        ax1.set_xlabel(None)
        ax2 = fig.add_subplot(2, 1, 2, sharex=ax1, xlim=xlim)
        ax2.plot(ts, label=chan.replace('_', r'\_'),
                 linewidth=args.line_size_aux)
        if range_is_primary:
            ax1.set_ylabel('Sensitive range [Mpc]')
        else:
            ax1.set_ylabel('Primary channel units')
        ax2.set_ylabel('Channel units')
        for ax in fig.axes:
            ax.legend(loc='best')
        # file-name stub: delimiters become '_', the first of them '-'
        channelstub = re_delim.sub('_', str(chan)).replace('_', '-', 1)
        plot4 = gwplot.save_figure(
            fig, '%s_TRENDS-%s.png' % (channelstub, gpsstub),
            bbox_inches='tight')

        # create scaled, sign-corrected, and overlayed timeseries
        tsscaled = scale(ts.value)
        if lassocoef < 0:
            tsscaled = numpy.negative(tsscaled)
        fig = Plot(figsize=(12, 4))
        fig.subplots_adjust(*p1)
        ax = fig.gca(xscale='auto-gps', epoch=start, xlim=xlim)
        ax.plot(times, descaler(target), label=primary.replace('_', r'\_'),
                color='black', linewidth=args.line_size_primary)
        ax.plot(times, descaler(tsscaled), label=chan.replace('_', r'\_'),
                linewidth=args.line_size_aux)
        if range_is_primary:
            ax.set_ylabel('Sensitive range [Mpc]')
        else:
            ax.set_ylabel('Primary Channel Units')
        ax.legend(loc='best')
        plot5 = gwplot.save_figure(
            fig, '%s_COMPARISON-%s.png' % (channelstub, gpsstub),
            bbox_inches='tight')

        # scatter plot, with a straight-line fit of primary against channel
        tsCopy = ts.value.reshape(-1, 1)
        primarytsCopy = primaryts.value.reshape(-1, 1)
        primaryReg = linear_model.LinearRegression()
        primaryReg.fit(tsCopy, primarytsCopy)
        primaryFit = primaryReg.predict(tsCopy)
        fig = Plot(figsize=(12, 4))
        fig.subplots_adjust(*p1)
        ax = fig.gca()
        ax.set_xlabel(chan.replace('_', r'\_') + ' [Channel units]')
        if range_is_primary:
            ax.set_ylabel('Sensitive range [Mpc]')
        else:
            ax.set_ylabel('Primary channel units')
        ax.text(.9, .1, 'r = ' + '{0:.2}'.format(pcorr),
                verticalalignment='bottom', horizontalalignment='right',
                transform=ax.transAxes, color='black', size=20,
                bbox=dict(boxstyle='square', facecolor='white', alpha=.75,
                          edgecolor='black'))
        ax.scatter(ts.value, primaryts.value, color='red')
        ax.plot(ts.value, primaryFit, color='black')
        ax.autoscale_view(tight=False, scalex=True, scaley=True)
        plot6 = gwplot.save_figure(
            fig, '%s_SCATTER-%s.png' % (channelstub, gpsstub),
            bbox_inches='tight')

    # increment counter and print status; the counter is read while the
    # lock is still held so the reported count is the one just written
    with counter.get_lock():
        counter.value += 1
        pc = 100 * counter.value / len(nonzerodata)
        logger.info("Completed [%d/%d] %3d%% %-50s"
                    % (counter.value, len(nonzerodata), pc,
                       '(%s)' % str(chan)))
        sys.stdout.flush()
    return chan, lassocoef, plot4, plot5, plot6, ts


# process channels
pool = multiprocessing.Pool(nprocplot)
try:
    results = pool.map(process_channel, enumerate(list(nonzerodata.items())))
finally:
    # release the worker processes once plotting has finished (or failed)
    pool.close()
    pool.join()
# from here on ``results`` is a list of per-channel tuples, ordered by
# decreasing absolute Lasso coefficient
results = sorted(results, key=lambda x: abs(x[1]), reverse=True)

#  generate clustered time series plots
counter = multiprocessing.Value('i', 0)  # fresh progress counter
p7 = (.135, .15, .95, .9)  # global plot defaults for plot7, clusters
max_correlated_channels = 20  # most correlated channels drawn per figure


def generate_cluster(input_,):
    """Find and plot the channels highly correlated with one channel

    Every channel in ``auxdata`` whose Pearson correlation with the given
    channel has an absolute value of at least ``args.cluster_coefficient``
    is written to a text table, and at most ``max_correlated_channels`` of
    them (strongest first) are drawn on a single figure.

    Parameters
    ----------
    input_ : `tuple`
        an ``(index, result)`` pair, where ``result`` is a tuple as
        returned by `process_channel`; the channel name, Lasso coefficient
        and timeseries are read from positions 0, 1 and 5 respectively

    Returns
    -------
    plot7, plot7_list : `tuple`
        the value returned by `gwplot.save_figure` for the cluster figure
        and the file name of the cluster table; both are `None` when no
        channel passes the correlation threshold
    """
    gwplot.configure_mpl()
    current = input_[0]
    currentchan = input_[1][0]
    currentcoef = input_[1][1]
    currentts = input_[1][5]
    cluster_threshold = args.cluster_coefficient
    plot7 = None
    plot7_list = None

    if current < len(nonzerodata):
        # collect (channel, timeseries, coefficient) for correlated channels
        cluster = []
        for chan_, ts_ in auxdata.items():
            if chan_ == currentchan:
                continue
            pcorr = numpy.corrcoef(currentts.value, ts_.value)[0, 1]
            if abs(pcorr) >= cluster_threshold:
                cluster.append((chan_, ts_, pcorr))

        if cluster:
            # strongest correlation first
            cluster.sort(key=lambda x: abs(x[2]), reverse=True)
            currentstub = re_delim.sub(
                '_', str(currentchan)).replace('_', '-', 1)

            # write cluster table to file; the column data must be given
            # in the same order as the column names (they used to be
            # swapped, putting coefficients under 'Channel')
            clustertab = Table(
                data=([item[0] for item in cluster],
                      [item[2] for item in cluster]),
                names=('Channel', 'Pearson Coefficient'))
            plot7_list = '%s_CLUSTER_LIST-%s.txt' % (currentstub, gpsstub)
            clustertab.write(plot7_list, format='ascii', overwrite=True)

            ncluster = min(len(cluster), max_correlated_channels)
            colors2 = [cmap(i) for i in numpy.linspace(0, 1, ncluster+1)]
            # flip by the sign of the Lasso coefficient, and each cluster
            # member additionally by the sign of its correlation
            coefsign = numpy.sign(currentcoef)

            # plot
            fig = Plot(figsize=(12, 4))
            fig.subplots_adjust(*p7)
            ax = fig.gca(xscale='auto-gps')
            ax.plot(
                times, scale(currentts.value)*coefsign,
                label=currentchan.replace('_', r'\_'),
                linewidth=args.line_size_aux, color=colors[0])

            for j in range(ncluster):
                chan_, ts_, pcorr = cluster[j]
                ax.plot(
                    times,
                    scale(ts_.value) * coefsign * numpy.sign(pcorr),
                    color=colors2[j+1],
                    linewidth=args.line_size_aux,
                    label=('{0}, r = {1:.2}'.format(
                        chan_.replace('_', r'\_'), pcorr)),
                )

            ax.margins(x=0)
            ax.set_ylabel('Scaled amplitude [arbitrary units]')
            ax.set_title('Highly Correlated Channels')
            ax.legend(loc='center left', bbox_to_anchor=(1.05, 0.5))
            plot7 = gwplot.save_figure(
                fig, '%s_CLUSTER-%s.png' % (currentstub, gpsstub),
                bbox_inches='tight')

    # increment counter and print status
    with counter.get_lock():
        counter.value += 1
        pc = 100 * counter.value / len(nonzerodata)
        logger.info("Completed [%d/%d] %3d%% %-50s"
                    % (counter.value, len(nonzerodata), pc,
                       '(%s)' % str(currentchan)))
        sys.stdout.flush()
    return plot7, plot7_list


# cluster plots are optional; ``clusters`` is only defined when they are made
if not args.no_cluster:
    logger.info("-- Generating clusters")
    pool = multiprocessing.Pool(nprocplot)
    try:
        clusters = pool.map(generate_cluster, enumerate(results))
    finally:
        # release the worker processes once clustering is done (or failed)
        pool.close()
        pool.join()

# record the full list of channels that were searched
channelsfile = '%s-CHANNELS-%s.txt' % (args.ifo, gpsstub)
numpy.savetxt(channelsfile, channels, fmt='%s')


# format results table
def format_indices(value):
    """Wrap a table index value in a centre-aligned ``<div>``"""
    template = "<div style='text-align:center;'>{}</div>"
    return template.format(value)


def format_channels(value):
    """Wrap a channel name in a left-aligned, monospaced ``<div>``"""
    style = 'text-align:left;font-family:monospace,monospace;'
    return "<div style='" + style + "'>{}</div>".format(value)


def format_coefficients(value):
    """Wrap a Lasso coefficient in a left-aligned ``<div>``"""
    opening, closing = "<div style='text-align:left;'>", "</div>"
    return opening + "{}".format(value) + closing


def style_table(html_table):
    """Add inline styling and row-hover highlighting to an HTML table

    Every ``<table``, ``<tr>``, ``<th>`` and ``<td>`` opening tag in the
    input is replaced by a copy carrying an inline ``style`` attribute;
    rows also get ``onmouseover``/``onmouseout`` handlers that toggle a
    light-grey background.

    Parameters
    ----------
    html_table : `str`
        HTML markup for a table

    Returns
    -------
    styled : `str`
        the same markup with the styled opening tags substituted in
    """
    table_tag = ("<table style="
                 "'border-collapse:collapse;"
                 "border-style:hidden;"
                 "width:85%; '")
    # the hover colour needs its leading '#' to be a valid CSS colour, and
    # the attributes need whitespace between them to be valid HTML
    row_tag = ('''<tr style='''
               '''"border-left:0 !important;'''
               '''border-right:0 !important;" '''
               '''onmouseover="this.style.backgroundColor='#f2f2f2';" '''
               '''onmouseout="this.style.backgroundColor='';">''')
    header_tag = ("<th style="
                  "'border-left:0 !important;"
                  "border-right:0 !important;"
                  "padding-left:10px;"
                  "padding-right:10px;"
                  "padding-top:5px; "
                  "padding-bottom:5px;'>")
    cell_tag = ("<td style = "
                "'border-left:0 !important;"
                "border-right:0 !important;'>")
    substitutions = (
        ('<table', table_tag),
        ('<tr>', row_tag),
        ('<th>', header_tag),
        ('<td>', cell_tag),
    )
    for plain, styled in substitutions:
        html_table = html_table.replace(plain, styled)
    return html_table


# write html: page skeleton, navigation bar, and the run parameters
title = '%s Lasso Slow Correlations: %d-%d' % (args.ifo, start, end)
links = [(section, '#%s' % section.lower())
         for section in ('Parameters', 'Model', 'Results')]
brand, class_ = htmlio.get_brand(args.ifo, 'Lasso', start)
navbar = htmlio.navbar(links, class_=class_, brand=brand)
page = htmlio.new_bootstrap_page(title=title, navbar=navbar)

page.div(class_='page-header')
page.h1(title)
page.div.close()  # page-header

# links to the channel lists written out earlier
channels_link = "<a href= %s target='_blank'>channel list</a>" % channelsfile
flat_link = "<a href= %s target='_blank'>flat channel list</a>" % (flatfile)

# (name, value) pairs for the parameters table
content = [
    ('Primary channel',
     '{} ({})'.format(primary, args.primary_frametype or '-')),
    ('Channels searched', '{} ({})'.format(nchan, channels_link)),
    ('Number of flat channels', '{} ({})'.format(len(flattab), flat_link)),
    ('Lasso coefficient threshold', '{}'.format(args.threshold)),
    ('Sigma for outlier removal', '{}'.format(args.remove_outliers)),
    ('Cluster coefficient threshold', '{}'.format(args.cluster_coefficient)),
]
if args.band_pass:
    filter_settings = 'Flow = {} Hz, Fhigh = {} Hz '.format(flower, fupper)
    content.append(('Filter settings', filter_settings))
page.add(htmlio.write_arguments(content, start=start, end=end))

# -- model information: fit summary, results table, and the model plots
page.h2('Model Information', id_='model')

page.div(class_='model')
page.div(class_='model-body')

page.div(class_='model-info')
page.add(htmlio.write_param('Model', 'Lasso'))
page.add(htmlio.write_param(
    'Non-zero coefficients', '%d' % numpy.count_nonzero(model.coef_)))
page.add(htmlio.write_param('Alpha', '%g' % model.alpha))
page.add(htmlio.write_param(
    'Zero coefficients',
    '%d (%s)' % (len(zeroed),
                 "<a href= %s target='_blank'>zeroed channel list</a>"
                 % (zerofile))))
page.div.close()  # model-info

page.div(class_='scaffold well')

page.div(class_='row')
page.div(class_='results-table', align='center')
# table of channels with non-zero coefficients, one styled <div> per cell
page.add(style_table(df.to_html(
    index=True,
    formatters={
        'Lasso coefficient': format_coefficients,
        'Channel': format_channels,
        '__index__': format_indices},
    escape=False)))
# no <p> is opened in this section, so there is nothing to close here
# (a stray closing tag used to be emitted)
page.div.close()  # results-table
page.div.close()  # row

page.div(class_='row', id_='primary-lasso')
page.div(class_='col-md-8 col-md-offset-2')
img1 = htmlio.FancyPlot(plot1)
page.add(htmlio.fancybox_img(img1))  # primary lasso plot
page.div.close()  # col-md-8 col-md-offset-2
page.div.close()  # primary-lasso
page.add('<hr class="row-divider">')

page.div(class_='row', id_='channel-summation')
img2 = htmlio.FancyPlot(plot2)
page.div(class_='col-md-8 col-md-offset-2')
page.add(htmlio.fancybox_img(img2))
page.div.close()  # col-md-8 col-md-offset-2
page.div.close()  # channel-summation

page.div(class_='row', id_='channels-and-primary')
img3 = htmlio.FancyPlot(plot3)
page.div(class_='col-md-8 col-md-offset-2')
page.add(htmlio.fancybox_img(img3))
page.div.close()  # col-md-8 col-md-offset-2
page.div.close()  # channels-and-primary

page.div.close()  # scaffold well
page.div.close()  # model-body
page.div.close()  # model

# results
page.h2('Top Channels', id_='results')

# NOTE(review): this id duplicates the one on the heading above, and the
# per-channel 'clusters' id below is repeated for every channel -- confirm
# whether the 'data-parent' accordion behaviour relies on them
page.div(class_='panel-group', id_='results')

# for each auxiliary channel create information container and put plots in it
for i, (ch, lassocoef, plot4, plot5, plot6, ts) in enumerate(results):
    # results are sorted by decreasing |coefficient|, so a zero ends the loop
    if lassocoef == 0:
        break
    below_threshold = abs(lassocoef) < args.threshold

    # set heading and container color/context based on lasso coefficient;
    # the coefficient is always a number here (abs() above would already
    # have failed otherwise), so no separate ``None`` handling is needed
    if below_threshold:
        h = '%s [lasso coefficient = %.4f] (Below threshold)' % (ch, lassocoef)
        context = 'panel-default'
    else:
        h = '%s [lasso coefficient = %.4f]' % (ch, lassocoef)
        if abs(lassocoef) >= .5:
            context = 'panel-danger'
        elif abs(lassocoef) >= .2:
            context = 'panel-warning'
        else:
            context = 'panel-info'
    page.div(class_='panel %s' % context)

    # heading
    page.div(class_='panel-heading')
    page.a(h, class_='panel-title', href='#channel%d' % i,
           **{'data-toggle': 'collapse', 'data-parent': '#results'})
    page.div.close()  # panel-heading
    # body
    page.div(id_='channel%d' % i, class_='panel-collapse collapse')
    page.div(class_='panel-body')
    if below_threshold:
        page.p('Lasso coefficient below the threshold of %g.'
               % (args.threshold))
    else:
        for image in [plot4, plot5, plot6]:
            img = htmlio.FancyPlot(image)
            page.div(class_='row')
            page.div(class_='col-md-8 col-md-offset-2')
            page.add(htmlio.fancybox_img(img))
            page.div.close()  # col-md-8 col-md-offset-2
            page.div.close()  # row
            page.add('<hr class="row-divider">')
        # ``clusters`` is in the same order as ``results``
        if not args.no_cluster:
            if clusters[i][0] is None:
                page.p("<font size='3'><br />No channels were highly"
                       " correlated with this channel.</font>")
            else:
                page.div(class_='row', id_='clusters')
                page.div(class_='col-md-12')
                cimg = htmlio.FancyPlot(clusters[i][0])
                page.add(htmlio.fancybox_img(cimg))
                page.div.close()  # col-md-12
                page.div.close()  # clusters
                if clusters[i][1] is not None:
                    page.p("<font size='3'><br />Only the first %d highly"
                           " correlated channels are reported in this"
                           " figure. <a href='%s' target='_blank'>Here</a>"
                           " is the full list.</font>"
                           % (max_correlated_channels, clusters[i][1]))
    page.div.close()  # panel-body
    page.div.close()  # panel-collapse
    page.div.close()  # panel
page.div.close()  # panel-group
htmlio.close_page(page, 'index.html')  # save and close
logger.info("-- Process Completed")
