#!/usr/bin/python

# Copyright (C) 2013 Michael Coughlin
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Earthquake xml file generator.

This script generates earthquake xml files using notices from the
internet and USGS PDL client.

Comments should be e-mailed to michael.coughlin@ligo.org.

"""

import os, sys, glob, optparse, warnings, time, json
import numpy as np
import subprocess
from subprocess import Popen
from lxml import etree

import lal.gpstime

from seismon import (eqmon, utils)

import matplotlib
matplotlib.use("AGG")
matplotlib.rcParams.update({'font.size': 18})
from matplotlib import pyplot as plt
from matplotlib import cm

import pandas as pd
import statsmodels.api as sm

# gwpy is an optional dependency: report its absence instead of aborting.
# Catch Exception rather than using a bare except so that KeyboardInterrupt
# and SystemExit raised during the import still propagate.
try:
    import gwpy.time, gwpy.timeseries, gwpy.plotter
    import gwpy.segments
    from gwpy.timeseries import StateVector
except Exception:
    print("gwpy import fails... no plotting possible.")

# Module metadata; __version__ is also reported by parse_commandline().
__author__ = "Michael Coughlin <michael.coughlin@ligo.org>"
__version__ = 1.0
__date__    = "9/22/2013"

# =============================================================================
#
#                               DEFINITIONS
#
# =============================================================================

def parse_commandline():
    """@parse the options given on the command-line.

    When --verbose is given, the parsed options are echoed to stderr.

    @return
        optparse Values object holding the parsed options
    """
    # optparse expands "%prog" in the version with str.replace, so the
    # version must be a string; a float makes --version raise AttributeError.
    parser = optparse.OptionParser(usage=__doc__, version=str(__version__))

    parser.add_option("-i", "--inputFileDirectory", help="Seismon files.",
                      default="/home/mcoughlin/Seismon/Text_Files/Timeseries/H0_PEM-LVEA_SEISZ/64/")

    parser.add_option("-a", "--accelerationFileDirectory", help="Seismon acceleration files.",
                      default="/home/mcoughlin/Seismon/Text_Files/Acceleration/H0_PEM-LVEA_SEISZ/64/")
    parser.add_option("-d", "--displacementFileDirectory", help="Seismon displacement files.")
    parser.add_option("-p", "--predictionFile", help="Seismon prediction file.",
                      default="/home/mcoughlin/Seismon/H1/H1S6/931035615-971654415/earthquakes/earthquakes.txt")
    parser.add_option("-l", "--lockFile", help="Lock segment file.",
                      default="/home/mcoughlin/seismon/RfPrediction/data/segs_Locked_H_1126569617_1136649617.txt")
    parser.add_option("-c", "--lockChannel", help="Lock channel.",
                      default="L1:GRD-ISC_LOCK_STATE_N")
    parser.add_option("-o", "--outputDirectory", help="output file.",
                      default="/home/mcoughlin/Seismon/Predictions/H1S6/")

    parser.add_option("-v", "--verbose", action="store_true", default=False,
                      help="Run verbosely. (Default: False)")

    # Positional arguments are accepted by optparse but not used.
    opts, _ = parser.parse_args()

    # show parameters
    if opts.verbose:
        banner = [
            "",
            "running network_eqmon...",
            "version: %s" % __version__,
            "",
            "***************** PARAMETERS ********************",
        ]
        for name, value in opts.__dict__.items():
            banner.append(name + ":")
            banner.append(str(value))
        banner.append("")
        sys.stderr.write("\n".join(banner) + "\n")

    return opts

def params_struct(opts):
    """@create params structure

    @param opts
        command line options
    @return
        dictionary with the input, acceleration and displacement
        directories, the prediction file, the lock file and the output
        directory
    """

    displacementFileDirectory = opts.displacementFileDirectory
    if displacementFileDirectory is None:
        # No explicit displacement directory: derive it from the
        # acceleration directory by swapping the path component name.
        displacementFileDirectory = opts.accelerationFileDirectory.replace("Acceleration","Displacement")

    params = {}
    params["inputFileDirectory"] = opts.inputFileDirectory
    params["accelerationFileDirectory"] = opts.accelerationFileDirectory
    params["displacementFileDirectory"] = displacementFileDirectory
    params["predictionFile"] = opts.predictionFile
    params["lockFile"] = opts.lockFile
    params["outputDirectory"] = opts.outputDirectory

    return params

def myprior(cube, ndim, nparams):
    """@rescale the first four entries of cube in place

    @param cube
        mutable sequence with at least four entries; entries 0..3 are
        multiplied by 1.0, 2.0, 5000.0 and 2.0 respectively
    @param ndim
        unused
    @param nparams
        unused
    @return
        the same cube object, after rescaling
    """
    widths = (1.0, 2.0, 5000.0, 2.0)
    for position, width in enumerate(widths):
        cube[position] = cube[position]*width
    return cube

def _overlaps(start, end, other_starts, other_ends):
    """@test whether whole-second time spans intersect

    Each span covers the integer seconds from floor(start) up to, but not
    including, ceil(end). For positive times this gives the same answer as
    intersecting np.arange(floor(start), ceil(end)) ranges, without
    materialising them.

    @param start
        start of the reference span
    @param end
        end of the reference span
    @param other_starts
        scalar or array of starts to compare against
    @param other_ends
        scalar or array of ends to compare against
    @return
        boolean (or boolean array) that is True where the spans share at
        least one whole second
    """
    latest_start = np.maximum(np.floor(start), np.floor(other_starts))
    earliest_end = np.minimum(np.ceil(end), np.ceil(other_ends))
    return latest_start < earliest_end


def _load_or_placeholder(directory, basename, shape):
    """@load directory/basename, or return an array of -1 with the given shape if the file is missing"""
    filename = os.path.join(directory, basename)
    if os.path.isfile(filename):
        return np.loadtxt(filename)
    return -1*np.ones(shape)


def _save_figure(outputDirectory, stem):
    """@save the current figure as <stem>.png, .eps and .pdf in outputDirectory, then close it"""
    for extension in ("png", "eps", "pdf"):
        plt.savefig(os.path.join(outputDirectory, "%s.%s" % (stem, extension)))
    plt.close()


def _plot_velocity_scatter(events, xcol, ycol, xlabel, ylabel, logx, outputDirectory, stem):
    """@scatter two event columns against each other, coloured by log10 of column 25, and save the figure"""
    plt.figure()
    ax = plt.gca()
    sc = ax.scatter(events[:,xcol],events[:,ycol],c=np.log10(events[:,25]),vmin=-8,vmax=-4)
    if logx:
        ax.set_xscale('log')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    cbar = plt.colorbar(sc)
    cbar.set_label("log10(Ground velocity [m/s])")
    _save_figure(outputDirectory, stem)


def predict_events():
    """@compare predicted and measured earthquake ground motion and flag lock losses

    Takes no arguments; everything is read from the command line. Steps:

    1. Load the prediction file (15 or 27 columns per event). For every
       event whose "<col9>-<col10>.txt" file exists in the input directory,
       write a 30-column row to <outputDirectory>/earthquakes.txt: the
       event columns followed by row 1 of the velocity, acceleration and
       displacement files (-1 where a file is missing).
       NOTE(review): row 1 looks like (time, amplitude) of the peak --
       confirm against the code that writes those files.
    2. Keep events whose column 25 (measured velocity) is >= 1e-7 and,
       going from the largest column-7 (predicted) value down, drop events
       whose [col 2, col 4] time span overlaps one already kept.
    3. Flag each event against the lock segment file: 0 = overlaps no lock
       segment, 1 = overlaps lock segments and none ends by the event end,
       2 = an overlapping segment ends at or before the event end (its end
       time is recorded). Rewrite earthquakes.txt with these two extra
       columns.
    4. Print the fraction of each flag among events with column 25 > 1e-6
       and save the diagnostic plots as png/eps/pdf in the output directory.
    """

    warnings.filterwarnings("ignore")

    # Parse command line
    opts = parse_commandline()
    params = params_struct(opts)
    outputDirectory = params["outputDirectory"]

    if not os.path.isdir(outputDirectory):
        utils.mkdir(outputDirectory)
    outputFile = os.path.join(outputDirectory,"earthquakes.txt")

    # Row layout shared by both passes over the output file.
    head_fmt = "%.1f %.1f %.1f %.1f %.1f %.1f %.1f %.5e %d %d %.1f %.1f %e %.1f %.1f"
    extra_fmt = " ".join(["%.1f"]*9)
    measured_fmt = "%.1f %.5e %.1f %.5e %.1f %.5e"

    # ndmin=2 keeps a single-event file two-dimensional.
    events = np.loadtxt(params["predictionFile"], ndmin=2)
    ncols = events.shape[1]
    if ncols not in (15, 27):
        # Validate before the output file is truncated, and report failure
        # through a non-zero exit status.
        print("Incorrect number of columns")
        sys.exit(1)

    with open(outputFile,'w') as fid:
        for event in events:
            gpsMin = event[9]
            gpsMax = event[10]
            basename = "%d-%d.txt"%(gpsMin,gpsMax)

            filename = os.path.join(params["inputFileDirectory"],basename)
            if not os.path.isfile(filename):
                continue
            data_out_vel = np.loadtxt(filename)
            data_out_acc = _load_or_placeholder(params["accelerationFileDirectory"],basename,data_out_vel.shape)
            data_out_disp = _load_or_placeholder(params["displacementFileDirectory"],basename,data_out_vel.shape)

            if ncols == 15:
                # Short prediction rows: pad columns 15..23 with nan.
                extras = " ".join(["nan"]*9)
            else:
                extras = extra_fmt % tuple(event[15:24])
            measured = measured_fmt % (data_out_vel[1,0],data_out_vel[1,1],
                                       data_out_acc[1,0],data_out_acc[1,1],
                                       data_out_disp[1,0],data_out_disp[1,1])
            fid.write("%s %s %s\n" % (head_fmt % tuple(event[:15]), extras, measured))

            print("%s %s" % (gpsMin, gpsMax))

    events = np.loadtxt(outputFile, ndmin=2)
    if events.size == 0:
        print("No events with seismic data found")
        return

    thresh = 1.0e-7
    events = events[events[:,25] >= thresh]
    if len(events) == 0:
        print("No events above threshold")
        return
    events = events[events[:,7].argsort()[::-1]]

    # Greedy selection, largest predicted amplitude first: keep an event
    # only if its time span is clear of every event already kept.
    kept = []
    for ii, event in enumerate(events):
        if not any(_overlaps(event[2], event[4], events[jj,2], events[jj,4]) for jj in kept):
            kept.append(ii)

    events = events[kept,:]
    events = events[events[:,7].argsort()]

    segments = np.loadtxt(params["lockFile"], ndmin=2)
    if segments.size == 0:
        # No lock segments at all: every event is flagged 0.
        segments = np.zeros((0,2))

    locklosstimes = []
    flags = []
    for event in events:
        eqStart = event[2]
        eqEnd = event[4]

        locked = segments[_overlaps(eqStart, eqEnd, segments[:,0], segments[:,1])]
        lost = locked[locked[:,1] <= eqEnd]

        if len(locked) == 0:
            flag, locklosstime = 0, -1
        elif len(lost) == 0:
            flag, locklosstime = 1, -1
        else:
            flag, locklosstime = 2, lost[0,1]

        locklosstimes.append(locklosstime)
        flags.append(flag)

    order = events[:,7].argsort()
    locklosstimes = np.array(locklosstimes)[order]
    flags = np.array(flags)[order]
    events = events[order]

    with open(outputFile,'w') as fid:
        for event, locklosstime, flag in zip(events, locklosstimes, flags):
            fid.write("%s %s %s %.1f %d\n" % (head_fmt % tuple(event[:15]),
                                               extra_fmt % tuple(event[15:24]),
                                               measured_fmt % tuple(event[24:30]),
                                               locklosstime, flag))

    # Reload so the in-memory table has the two new columns (30: lock loss
    # time, 31: flag) with the same rounding as the file on disk.
    events = np.loadtxt(outputFile, ndmin=2)

    thresh = 1.0e-6
    strong = events[events[:,25] > thresh]
    nolock = float(np.sum(strong[:,31] == 0))
    nolockloss = float(np.sum(strong[:,31] == 1))
    lockloss = float(np.sum(strong[:,31] == 2))

    toteqs = nolockloss + lockloss + nolock

    print("Number of EQs above M4.5 and ground motion >= 1 micron/s: %d"%toteqs)
    if toteqs > 0:
        # Fractions are undefined (division by zero) when nothing qualifies.
        print("Lockloss: %.3f"%(lockloss/toteqs))
        print("No lockloss: %.3f"%(nolockloss/toteqs))
        print("Not locked: %.3f"%(nolock/toteqs))

    # Figures are saved without plt.show(): the module selects the AGG
    # backend, where show() does nothing.
    ind = np.arange(len(events))

    plt.figure()
    ax = plt.gca()
    sc = ax.scatter(ind,events[:,25],s=20,c=events[:,1],vmin=5.0,vmax=7.0,label='Measured')
    ax.set_yscale('log')
    # Asymmetric bars: lower extent 4/5 of, upper extent 4 times, the prediction.
    yerr = np.vstack((events[:,7]*4.0/5.0, events[:,7]*4.0))
    plt.errorbar(ind,events[:,7],yerr=yerr,label='Predicted')
    plt.legend(loc='best',numpoints=1)
    plt.xlabel("Event number")
    plt.ylabel("Ground Velocity [m/s]")
    plt.xlim([-0.5,len(events)+0.5])
    plt.ylim([1e-8,1e-3])
    cbar = plt.colorbar(sc)
    cbar.set_label("Earthquake Magnitude")
    _save_figure(outputDirectory,'prediction')

    # Column 12 is plotted elsewhere as "Distance [m]", so dividing by 1000
    # and by the elapsed time (col 24 - col 0) yields km/s.
    vel = (events[:,12]/1000.0)/(events[:,24]-events[:,0])

    plt.figure()
    ax = plt.gca()
    ax.plot(events[:,7],vel,'kx')
    ax.plot([1e-7,1e-3],[3.5,3.5],'k--')
    ax.set_xscale('log')
    plt.xlabel("Ground velocity [m/s]")
    plt.ylabel("Earthquake velocity [km/s]")
    plt.xlim([1e-8,1e-3])
    plt.ylim([2,5])
    _save_figure(outputDirectory,'velocity')

    _plot_velocity_scatter(events,12,1,"Distance [m]","Earthquake Magnitude",
                           True,outputDirectory,'mag_distance')
    _plot_velocity_scatter(events,12,14,"Distance [m]","Azimuth [deg]",
                           True,outputDirectory,'azimuth_distance')
    _plot_velocity_scatter(events,14,1,"Azimuth [deg]","Earthquake Magnitude",
                           False,outputDirectory,'mag_azimuth')

    plt.figure()
    ax = plt.gca()
    # One marker style per lock flag (0, 1, 2).
    for flag, style in ((0, dict(c='b')),
                        (1, dict(c='g',marker='+')),
                        (2, dict(c='r',marker='x'))):
        selected = np.where(events[:,31] == flag)[0]
        ax.scatter(events[selected,12],events[selected,25],s=20,**style)
    ax.set_xscale('log')
    ax.set_yscale('log')
    plt.xlim([1e5,1e8])
    plt.ylim([7e-7,1e-3])
    plt.xlabel("Distance [m]")
    plt.ylabel("Peak ground motion [m/s]")
    plt.grid()
    _save_figure(outputDirectory,'lockloss_vel_distance')


# =============================================================================
#
#                                    MAIN
#
# =============================================================================

if __name__ == "__main__":
    # Run the full pipeline only when executed as a script, not on import.
    predict_events()

