#!/usr/bin/env python
"""
    Obzplan - an observation planning tool for MeerKAT
    Copyright (C) 2018-2019  Benjamin Hugo, South African Radio Astronomy Observatory

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
import obzplan
import ephem
import sys
import os
import numpy as np
from matplotlib import pyplot as plt
import argparse
from mpl_toolkits.mplot3d import Axes3D
import datetime
import matplotlib.dates as mdates
import logging

def create_logger():
    """ Create a console logger

    Fetches the logger named after the ``obzplan`` package, sets it to INFO
    and attaches a new stream handler (also at INFO) that formats records as
    ``name - time level - message``.

    Note that every call attaches an additional handler to the same logger.

    Returns:
        tuple: ``(log, console, cfmt)`` -- the logger, the console handler
        that was attached to it and the formatter used by that handler.
    """
    log = logging.getLogger(obzplan.__name__)
    cfmt = logging.Formatter('%(name)s - %(asctime)s %(levelname)s - %(message)s')
    log.setLevel(logging.INFO)

    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(cfmt)

    log.addHandler(console)

    return log, console, cfmt

def remove_log_handler(hndl):
    """ Detach the handler ``hndl`` from the module-level ``log`` logger """
    log.removeHandler(hndl)

def add_log_handler(hndl):
    """ Attach the handler ``hndl`` to the module-level ``log`` logger """
    log.addHandler(hndl)

# Create the log object, keeping hold of its console handler and formatter
# so that they remain reachable at module level
log, log_console_handler, log_formatter = create_logger()


# Command line interface.
# The elevation cutoff and the solar / lunar / satelite separations are given
# in degrees (they are passed through np.deg2rad where they are used).
parser = argparse.ArgumentParser(description="Observation Planner (v {0:s})".format(obzplan.__version__))
parser.add_argument("--lat",
                    type=str,
                    default="-30:42:47.41",
                    help="Observer Latitude")
parser.add_argument("--long",
                    type=str,
                    default="21:26:38.0",
                    help="Observer Longitude")
parser.add_argument("--elev",
                    type=float,
                    default=1054,
                    help="Observer Elevation")
# The start and end strings are assigned directly to the observer date and
# are reported as UTC in the log output and on the plot axes, so the help
# text states UTC as well.
parser.add_argument("-s",
                    "--start",
                    type=str,
                    default="2017/3/21 08:00:00",
                    help="Observation start time (UTC)")
parser.add_argument("-e",
                    "--end",
                    type=str,
                    default="2017/3/21 20:00:00",
                    help="Observation end time (UTC)")
parser.add_argument("--elev-cutoff",
                    type=float,
                    default=15,
                    help="Elevation cutoff")
parser.add_argument("sources", metavar="src", type=str, nargs="+",
                    help="Sources to plot")
parser.add_argument("--add-to-catalog", type=str, nargs="+",
                    help="Add source to catalog (format e.g. name,f|J,00:00:00.00000,-90:00:00.00000,0.0)")
parser.add_argument("--solar-separation",
                    type=float,
                    default=15,
                    help="Minimum solar separation")

parser.add_argument("--lunar-separation",
                    type=float,
                    default=0.75,
                    help="Minimum lunar separation")

# One matplotlib format string per plotted source, consumed in order
parser.add_argument("--plot-styles",
                    type=str,
                    default=["k*-", "k^-", "kv-", "kx-", "k<-", "k>-", "kD-",
                             "k|-", "k_-", "kp-"],
                    nargs="+",
                    help="Plot style")
parser.add_argument("--marker-size",
                    type=float,
                    default=6.0,
                    help="Marker size")
# NOTE: the option name is misspelled ("satelite") but is part of the public
# command line interface, so it is kept as is.
parser.add_argument("--satelite-separation",
                    type=float,
                    default=1.0,
                    help="Minimum satelite separation")

args = parser.parse_args()
# Each source is drawn with its own style, so there must be at least as many
# styles as sources
if len(args.plot_styles) < len(args.sources):
    raise ValueError("Not enough plot styles defined for the number of sources to be plotted")
meerkat = ephem.Observer()

# Set the observer at the Karoo (or wherever --lat / --long / --elev point to)
meerkat.lat, meerkat.long, meerkat.elevation = args.lat, args.long, args.elev
log.info("Observer at (lat, long, alt): (%s, %s, %s)",
         meerkat.lat, meerkat.long, meerkat.elevation)
meerkat.epoch = ephem.J2000

# Fixed body at RA 0h, Dec -90 deg, i.e. the south celestial pole, which is
# later drawn as a reference marker on the sky plot.
# The declination is written with a decimal point before the fractional
# seconds (dd:mm:ss.sssss), matching the right ascension field.
scp = ephem.readdb("SCP,f|J,00:00:00.00000,-90:00:00.00000,0.0")

# Every body that can be requested on the command line, keyed by name.
# The Sun and Moon are always present because they double as interferers.
sources = {"Sun": ephem.Sun(),
           "Moon": ephem.Moon(),
          }

# Packaged catalog: each line is split on the ":custom" marker into the source
# name and the remainder of the database entry.
with open(os.path.join(os.path.dirname(obzplan.__file__), "data", "catalog.txt")) as f:
    for line in f:
        try:
            name, catline = line.split(":custom")
            sources[name] = ephem.readdb(name + catline)
        except Exception:
            # Any parsing failure is reported uniformly as a ValueError.
            # Unlike a bare except this lets KeyboardInterrupt / SystemExit through.
            raise ValueError("Malformed database catalog string {0:s}. Check database file.".format(line))
        log.info("Adding source {0:s}".format(name))

# User supplied entries (--add-to-catalog); the name is the first comma
# separated field. These may override entries of the packaged catalog.
if args.add_to_catalog is not None:
    for src in args.add_to_catalog:
        try:
            name = src.split(",")[0]
            sources[name] = ephem.readdb(src)
        except Exception:
            raise ValueError("Malformed user input catalog string {0:s}".format(src))
        log.info("Adding source {0:s}".format(name))

# Satellites treated as potential RFI sources, keyed by name
satelites = {}

with open(os.path.join(os.path.dirname(obzplan.__file__), "data", "sat_tles.txt")) as f:
    tle_lines = f.readlines()

# The file is a flat sequence of three-line records: a name line followed by
# the two element lines. An incomplete trailing record is ignored.
for first in range(0, len(tle_lines) - 2, 3):
    name = tle_lines[first].strip()
    l1 = tle_lines[first + 1]
    l2 = tle_lines[first + 2]
    # Two independent body objects are created: one selectable as a plotted
    # source, one used purely for the interference check
    sources[name] = ephem.readtle(name, l1, l2)
    satelites[name] = ephem.readtle(name, l1, l2)
    log.info("Adding satelite {0:s} to RFI sources".format(name))

def _per_source_lists():
    """ Return a fresh ``{source name: []}`` accumulator covering every known source """
    return {k: [] for k in sources}

# Track accumulators for the samples at which a source is observable
positions = _per_source_lists()
pa = _per_source_lists()

# Angular unit conversion factors (radians <-> degrees, radians <-> hours)
pi = np.pi
r2d = 180.0/pi
d2r = pi/180.0

r2h = 12.0/pi
h2r = pi/12.0

# Pin the observer to the start of the observation; the pole marker is
# evaluated at this instant
meerkat.date = args.start
st = meerkat.date
start_lst = meerkat.sidereal_time()
log.info("Observation start: UTC %s (LST %s = %f)", st, start_lst, start_lst*r2h)
scp.compute(meerkat)

# ... and then to the end of the observation
meerkat.date = args.end
et = meerkat.date
end_lst = meerkat.sidereal_time()
log.info("Observation end: UTC %s (LST %s = %f)", et, end_lst, end_lst*r2h)

# Pole marker as (azimuth, zenith angle)
scppos = (scp.az, np.pi/2 - scp.alt)

# 1500 evenly spaced sample dates spanning the observation
time = np.linspace(st,et,1500)

# Remaining per-source accumulators: sample dates and local sidereal times of
# the observable samples, plus the samples lost to interference
times_above = _per_source_lists()
times_interference = _per_source_lists()
positions_interference = _per_source_lists()
pa_interference = _per_source_lists()
sidereals = _per_source_lists()

def angle_between(az_1, alt_1, az_2, alt_2):
    """ Great-circle separation between two horizontal directions

    Arguments are taken in (azimuth, altitude) order for each direction, in
    radians, which is the order in which every call site supplies them.
    The altitude plays the role of latitude and the azimuth that of longitude
    in the spherical law of cosines.

    Returns:
        The angular separation in radians, in the range [0, pi].
    """
    cos_sep = (np.sin(alt_1) * np.sin(alt_2) +
               np.cos(alt_1) * np.cos(alt_2) * np.cos(az_1 - az_2))
    # Rounding can push the cosine marginally outside [-1, 1] for (nearly)
    # coincident or opposite directions, which would make arccos return NaN
    return np.arccos(np.clip(cos_sep, -1.0, 1.0))

# Fail early, with a readable message, when a requested source is not known
unknown_sources = [k for k in args.sources if k not in sources]
if unknown_sources:
    raise KeyError("Unknown source(s) requested: {0:s}".format(", ".join(unknown_sources)))

# Thresholds are supplied in degrees on the command line
elev_cutoff_rad = np.deg2rad(args.elev_cutoff)
solar_sep_rad = np.deg2rad(args.solar_separation)
lunar_sep_rad = np.deg2rad(args.lunar_separation)
sat_sep_rad = np.deg2rad(args.satelite_separation)

for t in time:
    meerkat.date = t

    # The interferer positions depend only on the time step, so they are
    # evaluated once here rather than once per plotted source
    sun, moon = sources["Sun"], sources["Moon"]
    sun.compute(meerkat)
    sun_az, sun_alt = float(sun.az), float(sun.alt)
    moon.compute(meerkat)
    moon_az, moon_alt = float(moon.az), float(moon.alt)

    sat_positions = []
    for sat in satelites:
        try:
            satelites[sat].compute(meerkat)
            sat_positions.append((sat,
                                  float(satelites[sat].az),
                                  float(satelites[sat].alt)))
        except (ValueError, RuntimeError):
            # This satellite cannot be evaluated at this date; leave it out
            # of the interference check for this time step
            continue

    for k in args.sources:
        src = sources[k]
        try:
            src.compute(meerkat)
        except ValueError:
            # No sample at all is recorded for this source at this time step
            continue

        src_az, src_alt = float(src.az), float(src.alt)

        # A body below the horizon cannot interfere: give it an infinite
        # separation so that it always passes the threshold tests
        if sun_alt < 0.0:
            dot_src_sun = np.inf
        else:
            dot_src_sun = angle_between(sun_az, sun_alt, src_az, src_alt)
        if moon_alt < 0.0:
            dot_src_moon = np.inf
        else:
            dot_src_moon = angle_between(moon_az, moon_alt, src_az, src_alt)

        # Closest satellite to the source at this time step
        min_sat_separation = np.inf
        sat_name = ""
        for sat, sat_az, sat_alt in sat_positions:
            dot_sat_src = angle_between(src_az, src_alt, sat_az, sat_alt)
            if min_sat_separation > dot_sat_src:
                min_sat_separation = dot_sat_src
                sat_name = sat

        above_cutoff = src_alt > elev_cutoff_rad
        # The interferers themselves are never flagged as interfered with
        exempt = k in ("Sun", "Moon") or k in satelites
        unobstructed = (min_sat_separation > sat_sep_rad and
                        dot_src_sun > solar_sep_rad and
                        dot_src_moon > lunar_sep_rad)

        if above_cutoff and (unobstructed or exempt):
            # Positions are stored as (azimuth, zenith angle)
            positions[k].append([src_az, float(np.pi * 0.5 - src_alt)])
            pa[k].append(src.parallactic_angle())
            times_above[k].append(t)
            sidereals[k].append(meerkat.sidereal_time())
            continue

        if above_cutoff:
            # The source is up but blocked: it is necessarily a non-exempt
            # target here. One interference sample is recorded per cause.
            stamp = str(ephem.Date(t).datetime())
            messages = []
            if dot_src_sun <= solar_sep_rad:
                messages.append("{0:s} experiences solar interference at {1:s}".format(k, stamp))
            if dot_src_moon <= lunar_sep_rad:
                messages.append("{0:s} is behind the moon at {1:s}".format(k, stamp))
            if min_sat_separation <= sat_sep_rad:
                messages.append("{0:s} (elevation {1:f}) is within {2:f} degrees from satelite '{3:s}' at {4:s}".format(k,
                    np.rad2deg(src_alt), np.rad2deg(min_sat_separation), sat_name, stamp))
            for message in messages:
                log.warning(message)
                times_interference[k].append(t)
                pa_interference[k].append(src.parallactic_angle())
                positions_interference[k].append([src_az,
                                                  float(np.pi * 0.5 - src_alt)])

        # Not observable at this time step: pad with NaN so that the track
        # arrays stay aligned and the plotted line is broken here
        times_above[k].append(np.nan)
        positions[k].append([np.nan, np.nan])
        pa[k].append(np.nan)
        sidereals[k].append(np.nan)

# From here on the positions are held as numpy arrays in degrees
for k in sources:
    positions[k] = np.rad2deg(np.array(positions[k]))
    positions_interference[k] = np.rad2deg(np.array(positions_interference[k]))

def _sky_to_cartesian(az, zen):
    """ Map azimuth and zenith angle (radians) onto the unit sphere """
    return (np.sin(zen) * np.cos(az),
            np.sin(zen) * np.sin(az),
            np.cos(zen))

fig = plt.figure("Observation Planner")
ax = fig.add_subplot(111, projection="3d")

# Wireframe of the upper unit hemisphere as a backdrop for the tracks
u, v = np.mgrid[0:2*np.pi:20j, 0:np.pi/2:10j]
x = np.cos(u)*np.sin(v)
y = np.sin(u)*np.sin(v)
z = np.cos(v)
ax.plot_wireframe(x, y, z, color="0.75")

# Reference marker: scppos holds (azimuth, zenith angle) of the pole
scp_x, scp_y, scp_z = _sky_to_cartesian(scppos[0], scppos[1])
ax.plot([scp_x], [scp_y], [scp_z], "k*")

for si, s in enumerate(args.sources):
    track = positions[s]
    # A source for which no sample at all was recorded yields an empty 1-D
    # array; treat it like a source that is never observable
    if track.ndim == 2:
        visible = np.where(np.logical_not(np.isnan(track[:, 0])))[0]
    else:
        visible = np.array([], dtype=int)

    if len(visible) == 0:
        log.info("{0:s} is never above elevation limit!".format(s))
        continue

    # Column 0 is azimuth, column 1 is zenith angle, both in degrees.
    # The first / last samples may be NaN when the source is not observable
    # at the very start / end of the observation.
    log.info("%s: Starting (elevation, azimuth): (%f,%f)", s, 90 - track[0, 1], track[0, 0])
    log.info("%s: Ending (elevation, azimuth): (%f,%f)", s, 90 - track[-1, 1], track[-1, 0])
    log.info("%s: Minimum (elevation, azimuth): (%f, %f)", s,
             90 - np.nanmax(track[:, 1]), np.nanmin(track[:, 0]))
    log.info("%s: Maximum (elevation, azimuth): (%f, %f)", s,
             90 - np.nanmin(track[:, 1]), np.nanmax(track[:, 0]))

    # Sidereal times are recorded in radians and aligned sample by sample
    # with the track; convert to hours before printing as HH:MM. The first
    # and last observable samples are used (rather than min / max) so that
    # the window is still right when the LST wraps through 0h.
    lst_hours = np.asarray(sidereals[s], dtype=np.float64) * r2h
    lst_start = lst_hours[visible[0]]
    lst_end = lst_hours[visible[-1]]
    log.info("Observable times LST %02.0f:%02.0f to %02.0f:%02.0f",
             np.floor(lst_start),
             np.floor((lst_start - np.floor(lst_start))*60),
             np.floor(lst_end),
             np.floor((lst_end - np.floor(lst_end))*60))

    # Mark the first (^) and last (v) observable directions
    rise = np.deg2rad(track[visible[0], :])
    rise_x, rise_y, rise_z = _sky_to_cartesian(rise[0], rise[1])
    ax.plot([rise_x], [rise_y], [rise_z], "k^")
    fall = np.deg2rad(track[visible[-1], :])
    fall_x, fall_y, fall_z = _sky_to_cartesian(fall[0], fall[1])
    ax.plot([fall_x], [fall_y], [fall_z], "kv")

    # The track itself; NaN samples leave gaps where the source is not observable
    trk_x, trk_y, trk_z = _sky_to_cartesian(np.deg2rad(track[:, 0]),
                                            np.deg2rad(track[:, 1]))
    ax.plot(trk_x, trk_y, trk_z, args.plot_styles[si],
            label=s, markersize=args.marker_size, markevery=25)

    # Samples lost to solar / lunar / satelite interference, as small red dots
    blocked = positions_interference[s]
    if len(blocked) > 0:
        blk_x, blk_y, blk_z = _sky_to_cartesian(np.deg2rad(blocked[:, 0]),
                                                np.deg2rad(blocked[:, 1]))
        ax.plot(blk_x, blk_y, blk_z, "r.",
                markersize=args.marker_size/6.0)

ax.legend()
ax.grid(False)
ax.set_xlim(-1,1)
ax.set_ylim(-1,1)
ax.set_zlim(-1,1)

fig = plt.figure("Elevation v Time")
ax = fig.add_subplot(111)
for si, s in enumerate(args.sources):
    # A source without any recorded sample yields an empty 1-D array that
    # cannot be indexed by column; there is nothing to draw for it
    if positions[s].ndim != 2:
        continue

    # Samples at which the source is not observable carry a NaN date and a
    # NaN elevation. The date is replaced by a placeholder so that it can be
    # converted; the point still has a NaN elevation.
    time = [ephem.Date(t).datetime() if not np.isnan(t) else datetime.datetime.now() for t in
            times_above[s]]
    # Column 1 holds the zenith angle in degrees
    ax.plot(time, 90 - positions[s][:,1], args.plot_styles[si], label=s,
            markersize = args.marker_size, markevery=25)

    # Samples lost to interference, as small red dots
    if len(positions_interference[s]) > 0:
        time = [ephem.Date(t).datetime() if not np.isnan(t) else datetime.datetime.now() for t in
                times_interference[s]]
        ax.plot(time, 90 - positions_interference[s][:,1], "r.",
                markersize = args.marker_size/6.0)
hfmt = mdates.DateFormatter('%H:%M')

ax.xaxis.set_major_formatter(hfmt)
plt.xlabel("Time (%s) UTC" % str(args.start))
plt.ylabel("Elevation angle")
plt.grid(True)
plt.legend()

fig = plt.figure("PA v Time")
ax = fig.add_subplot(111)

for si, s in enumerate(args.sources):
    # Dates of the track samples; a NaN date (source not observable) is
    # swapped for a placeholder so that it can be converted
    stamps = []
    for t in times_above[s]:
        if np.isnan(t):
            stamps.append(datetime.datetime.now())
        else:
            stamps.append(ephem.Date(t).datetime())
    ax.plot(stamps, pa[s], args.plot_styles[si],
            label=s, markersize=args.marker_size, markevery=25)

    # Samples lost to interference, as small red dots
    blocked_stamps = []
    for t in times_interference[s]:
        if np.isnan(t):
            blocked_stamps.append(datetime.datetime.now())
        else:
            blocked_stamps.append(ephem.Date(t).datetime())
    if pa_interference[s]:
        ax.plot(blocked_stamps, pa_interference[s], "r.",
                markersize=args.marker_size/6.0)

# Label the time axis as hours:minutes
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
ax.set_xlabel("Time (%s) UTC" % str(args.start))
ax.set_ylabel("Parallactic angle")
ax.grid(True)
ax.legend()

# Hand over to the GUI event loop to display all figures
plt.show()
