#!/usr/bin/python

# Copyright (C) 2013 Michael Coughlin
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Earthquake xml file generator.

This script generates earthquake xml files using notices from the
internet and USGS PDL client.

Comments should be e-mailed to michael.coughlin@ligo.org.

"""

import os, sys, glob, optparse, warnings, time, json
import numpy as np
import subprocess
from subprocess import Popen
from lxml import etree

import lal.gpstime

from seismon import (eqmon, utils)

import matplotlib
# Select the Agg backend before pyplot is imported; every figure in this
# script is written to disk with savefig.
matplotlib.use("AGG")
# Default font size applied to all figures produced below.
matplotlib.rcParams.update({'font.size': 18})
from matplotlib import pyplot as plt
from matplotlib import cm

import pymultinest

# Module metadata.  __version__ is stored as a float and is read by
# parse_commandline for the option parser's version and the verbose banner.
__author__ = "Michael Coughlin <michael.coughlin@ligo.org>"
__version__ = 1.0
__date__    = "9/22/2013"

# =============================================================================
#
#                               DEFINITIONS
#
# =============================================================================

def parse_commandline():
    """@parse the options given on the command-line.

    Recognised options are -i/--inputFiles, -o/--outputDirectory and
    -v/--verbose.  With --verbose the parsed options are echoed to stderr.

    @return
        optparse.Values instance holding the parsed options
    """
    # optparse applies string methods to `version` when handling --version,
    # so pass a string even though __version__ is stored as a number.
    parser = optparse.OptionParser(usage=__doc__, version=str(__version__))

    parser.add_option("-i", "--inputFiles",
                      help="Comma-separated list of seismon files.")
    parser.add_option("-o", "--outputDirectory", help="Output directory.")

    parser.add_option("-v", "--verbose", action="store_true", default=False,
                      help="Run verbosely. (Default: False)")

    opts, args = parser.parse_args()

    # show parameters
    if opts.verbose:
        # sys.stderr.write emits the same text under Python 2 and Python 3,
        # unlike the `print >> stream` statement form.
        report = [
            "",
            "running network_eqmon...",
            "version: %s" % __version__,
            "",
            "***************** PARAMETERS ********************",
        ]
        for name, value in opts.__dict__.items():
            report.append(name + ":")
            report.append("%s" % (value,))
        report.append("")
        sys.stderr.write("\n".join(report) + "\n")

    return opts

def params_struct(opts):
    """@build the params dictionary from the command line options

    @param opts
        parsed command line options; the inputFiles and outputDirectory
        attributes are read

    @return
        dictionary with the keys "inputFiles" and "outputDirectory"
    """

    copied_options = ("inputFiles", "outputDirectory")
    return dict((name, getattr(opts, name)) for name in copied_options)

def ampRf(M,r,h,Rf0,Rfs,cd,rs):
    """@Rf amplitude estimate

    Evaluates

        fc = 10**(2.3 - M/2)
        Rf = 1e-3 * M * Rf0/fc**Rfs * exp(-2*pi*h*fc/cd) / r**rs

    where the exponential models the coupling of the source to surface
    waves.  An earlier form of the model also carried a dissipation factor
    exp(-2*pi*r*fc/ch/Q) with Q = Q0/fc**Qs; that factor (and its Q0, Qs and
    ch parameters) is not part of this version.  Scalars and numpy arrays
    are both accepted and follow numpy broadcasting.

    @param M
        magnitude
    @param r
        distance [km]
    @param h
        depth [km]
    @param Rf0
        Rf amplitude parameter
    @param Rfs
        exponent of power law for f-dependent Rf amplitude
    @param cd
        speed parameter for surface coupling [km/s]
    @param rs
        exponent of the power-law fall-off with distance

    @return
        Rf amplitude estimate, same shape as the broadcast inputs
    """

    # Divide by 2.0, not 2: an integer magnitude would otherwise be
    # floor-divided under Python 2.
    fc = 10**(2.3 - M/2.0)
    Af = Rf0/(fc**Rfs)

    # coupling of source to surface waves
    coupling = np.exp(-2*np.pi*h*fc/cd)

    Rf = 1e-3 * M*Af*coupling/(r**rs)

    return Rf

def myprior(cube, ndim, nparams):
    """@scale the first four entries of cube in place

    Entries 0..3 are multiplied by 1.0, 2.0, 5000.0 and 2.0 respectively.
    In opt_predict these correspond to Rf0, Rfs, cd and rs.

    @param cube
        indexable, mutable sequence with at least four entries
    @param ndim
        unused
    @param nparams
        unused

    @return
        the same cube object, after scaling
    """
    upper_bounds = (1.0, 2.0, 5000.0, 2.0)
    for position, bound in enumerate(upper_bounds):
        cube[position] = cube[position]*bound
    return cube

def myloglike(cube, ndim, nparams):
    """@score a parameter sample against the measured amplitudes

    The event data are read from the module-level globals M, r, h and amp.
    The returned score is the fraction of events whose prediction agrees
    with the measurement to within a factor of 5 in either direction, so it
    lies between 0 and 1.

    @param cube
        indexable sample; its first nparams entries are unpacked as
        Rf0, Rfs, cd, rs
    @param ndim
        unused
    @param nparams
        number of entries read from cube (four are unpacked)

    @return
        fraction of events within a factor of 5, or 0.0 if that fraction
        is not finite
    """

    sample = [cube[k] for k in range(nparams)]
    Rf0, Rfs, cd, rs = sample
    model = ampRf(M, r, h, Rf0, Rfs, cd, rs)

    # Symmetric misfit: whichever of measured/predicted and
    # predicted/measured is larger, per event.
    misfit = np.vstack((amp/model, model/amp)).max(axis=0)

    n_within = np.count_nonzero(misfit < 5.0)
    prob = float(n_within) / float(len(misfit))

    return prob if np.isfinite(prob) else 0.0

def opt_predict(plotDirectory,amp,M,r,h):
    """@fit the ampRf parameters with MultiNest

    Runs pymultinest with myprior and myloglike, writing its files under
    plotDirectory with the prefix "2-", then returns the live point with the
    largest value in the column that follows the four parameters.

    @param plotDirectory
        directory receiving the MultiNest output files
    @param amp
        unused here; myloglike reads the module-level global instead
    @param M
        unused here; myloglike reads the module-level global instead
    @param r
        unused here; myloglike reads the module-level global instead
    @param h
        unused here; myloglike reads the module-level global instead

    @return
        tuple (Rf0, Rfs, cd, rs) of the selected live point
    """

    basename = '%s/2-'%plotDirectory

    # Remove the output of earlier runs.  Done without a shell so that
    # directory names with spaces or shell metacharacters are handled; a
    # file that cannot be removed is reported and skipped.
    for stale in glob.glob(basename + "*"):
        try:
            os.remove(stale)
        except OSError as err:
            sys.stderr.write("could not remove %s: %s\n" % (stale, err))

    # number of dimensions our problem has
    parameters = ["Rf0","Rfs","cd","rs"]
    n_params = len(parameters)

    # we want to see some output while it is running
    progress = pymultinest.ProgressPlotter(n_params = n_params, outputfiles_basename=basename)
    progress.start()
    try:
        # run MultiNest
        pymultinest.run(myloglike, myprior, n_params, importance_nested_sampling = False, resume = True, verbose = True, sampling_efficiency = 'model', n_live_points = 1000, outputfiles_basename=basename)
    finally:
        # stop the progress watcher even when the run fails
        progress.stop()

    livefile = os.path.join(plotDirectory,"2-phys_live.points")
    # atleast_2d keeps the column indexing valid for a single-row file
    livepoints = np.atleast_2d(np.loadtxt(livefile))

    # The column after the parameters is presumably the value returned by
    # myloglike for each live point -- TODO confirm against the MultiNest
    # file format.
    best = livepoints[np.argmax(livepoints[:,n_params]),:]

    Rf0, Rfs, cd, rs = best[:n_params]

    return Rf0,Rfs,cd,rs

# One row of earthquakes.txt: 23 whitespace-separated columns.
_EVENT_ROW_FORMAT = "%.1f %.1f %.1f %.1f %.1f %.1f %.1f %.5e %d %d %.1f %.1f %e %.1f %.1f %.1f %.5e %.1f %.5e %.1f %.5e %.1f %d\n"
_EVENT_ROW_WIDTH = 23


def _load_events(path):
    """Load an event table as a 2-D array, also for zero or one rows."""
    events = np.loadtxt(path)
    if events.size == 0:
        return np.empty((0, _EVENT_ROW_WIDTH))
    return np.atleast_2d(events)


def _write_events(fid, events):
    """Write one formatted line per event row to the open file fid."""
    for event in events:
        fid.write(_EVENT_ROW_FORMAT % tuple(event[:_EVENT_ROW_WIDTH]))


def _save_current_figure(outputDirectory, basename):
    """Save the current figure as png, eps and pdf, then close it."""
    for extension in ('png', 'eps', 'pdf'):
        plotName = os.path.join(outputDirectory, '%s.%s' % (basename, extension))
        plt.savefig(plotName)
    plt.close()


def _plot_amplitude_scatter(x, y, amplitude, xlabel, ylabel, logx,
                            outputDirectory, basename):
    """Scatter y against x coloured by log10(amplitude) and save the figure."""
    plt.figure()
    ax = plt.gca()
    sc = ax.scatter(x, y, c=np.log10(amplitude), vmin=-8, vmax=-4)
    if logx:
        ax.set_xscale('log')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    cbar = plt.colorbar(sc)
    cbar.set_label("log10(Ground velocity [m/s])")
    _save_current_figure(outputDirectory, basename)


def predict_events():
    """@fit the amplitude model to the input events and plot the results

    Reads the command line, merges the rows of every input file into
    <outputDirectory>/earthquakes.txt, fits the ampRf parameters with
    opt_predict, stores the prediction in column 7, rewrites the file sorted
    by that column, and saves six figures (png, eps and pdf each) in the
    output directory.

    Columns used, as far as this function shows: 1 magnitude, 7 predicted
    amplitude, 12 distance [m], 13 depth, 14 azimuth [deg], 16 measured
    amplitude, 22 a flag with values 0/1/2; columns 0 and 15 are times
    whose difference is used as a travel time (presumably seconds -- TODO
    confirm).

    The module-level globals M, r, h and amp are set here because
    myloglike reads them.

    @raise ValueError
        if an option is missing or the input files contain no events
    """

    warnings.filterwarnings("ignore")

    # Parse command line
    opts = parse_commandline()
    params = params_struct(opts)

    if not params["inputFiles"] or not params["outputDirectory"]:
        raise ValueError("both --inputFiles and --outputDirectory are required")

    outputDirectory = params["outputDirectory"]
    if not os.path.isdir(outputDirectory):
        utils.mkdir(outputDirectory)
    outputFile = os.path.join(outputDirectory,"earthquakes.txt")

    # Merge all input files into one table.  The table is reloaded from disk
    # so that the values carry exactly the precision that was written.
    inputFiles = params["inputFiles"].split(",")
    with open(outputFile,'w') as fid:
        for inputFile in inputFiles:
            _write_events(fid, _load_events(inputFile))
    events = _load_events(outputFile)

    if len(events) == 0:
        raise ValueError("no events found in %s" % params["inputFiles"])

    global M, r, h, amp
    M = events[:,1]
    r = events[:,12]/1000.0
    h = events[:,13]
    amp = events[:,16]

    Rf0,Rfs,cd,rs = opt_predict(outputDirectory,events[:,16],M,r,h)
    pred = ampRf(M,r,h,Rf0,Rfs,cd,rs)
    events[:,7] = pred
    events = events[events[:,7].argsort()]

    with open(outputFile,'w') as fid:
        _write_events(fid, events)
    events = _load_events(outputFile)

    # Measured and predicted amplitude per event, ordered by prediction
    ind = np.arange(len(events[:,7]))

    plt.figure()
    ax = plt.gca()
    sc = ax.scatter(ind,events[:,16],s=20,c=events[:,1],vmin=5.0,vmax=7.0,label='Measured')
    ax.set_yscale('log')
    # Asymmetric bars spanning prediction/5 to prediction*5
    yerr = np.zeros((2,len(events)))
    yerr[0,:] = events[:,7]*4.0/5.0
    yerr[1,:] = events[:,7]*4.0
    plt.errorbar(ind,events[:,7],yerr=yerr,label='Predicted')
    plt.legend(loc='best',numpoints=1)
    plt.xlabel("Event number")
    plt.ylabel("Ground Velocity [m/s]")
    plt.xlim([-0.5,len(events[:,0])+0.5])
    plt.ylim([1e-8,1e-3])
    cbar = plt.colorbar(sc)
    cbar.set_label("Earthquake Magnitude")
    _save_current_figure(outputDirectory, 'prediction')

    # Apparent propagation speed: distance converted from m to km, divided
    # by the difference of the two time columns.
    vel = (events[:,12]/1000.0)/(events[:,15]-events[:,0])

    plt.figure()
    ax = plt.gca()
    ax.plot(events[:,7],vel,'kx')
    ax.plot([1e-7,1e-3],[3.5,3.5],'k--')
    ax.set_xscale('log')
    plt.xlabel("Ground velocity [m/s]")
    plt.ylabel("Earthquake velocity [km/s]")
    plt.xlim([1e-8,1e-3])
    plt.ylim([2,5])
    _save_current_figure(outputDirectory, 'velocity')

    # Pairwise views of distance, magnitude and azimuth, coloured by the
    # measured amplitude
    _plot_amplitude_scatter(events[:,12], events[:,1], events[:,16],
                            "Distance [m]", "Earthquake Magnitude", True,
                            outputDirectory, 'mag_distance')
    _plot_amplitude_scatter(events[:,12], events[:,14], events[:,16],
                            "Distance [m]", "Azimuth [deg]", True,
                            outputDirectory, 'azimuth_distance')
    _plot_amplitude_scatter(events[:,14], events[:,1], events[:,16],
                            "Azimuth [deg]", "Earthquake Magnitude", False,
                            outputDirectory, 'mag_azimuth')

    # Measured amplitude against distance, one marker style per flag value
    plt.figure()
    ax = plt.gca()
    flag_styles = (
        (0, dict(c='b')),
        (1, dict(c='g', marker='+')),
        (2, dict(c='r', marker='x')),
    )
    for flag, style in flag_styles:
        indexes = np.where(events[:,22] == flag)[0]
        ax.scatter(events[indexes,12],events[indexes,16],s=20,**style)
    ax.set_xscale('log')
    ax.set_yscale('log')
    plt.xlim([1e5,1e8])
    plt.ylim([8e-7,1e-3])
    plt.xlabel("Distance [m]")
    plt.ylabel("Peak ground motion [m/s]")
    plt.grid()
    _save_current_figure(outputDirectory, 'lockloss_vel_distance')


# =============================================================================
#
#                                    MAIN
#
# =============================================================================

# Run the full fit-and-plot pipeline only when executed as a script.
if __name__=="__main__":

    predict_events()

