#! /usr/bin/python3
# -*- coding: utf-8 -*-

__author__ = "Alan Loh"
__copyright__ = "Copyright 2023, nenupy"
__credits__ = ["Alan Loh"]
__maintainer__ = "Alan"
__email__ = "alan.loh@obspm.fr"
__status__ = "Production"

from tkinter import Tk, Button, Checkbutton, Label, Entry, Frame, StringVar, IntVar, DoubleVar
from tkcalendar import DateEntry
from datetime import datetime, timedelta
import numpy as np
from matplotlib.backend_bases import key_press_handler
from matplotlib.backends.backend_tkagg import (
    FigureCanvasTkAgg
)
from matplotlib.figure import Figure
from matplotlib.collections import LineCollection
import matplotlib.pyplot as plt
from matplotlib.projections.polar import PolarAxes
import matplotlib.dates as mdates
from matplotlib import colormaps

from typing import Union
import argparse
from abc import ABC

import astropy.units as u
from astropy.time import Time, TimeDelta
from astropy.coordinates import solar_system_ephemeris, SkyCoord, Angle

from nenupy.astro.target import Target, FixedTarget, SolarSystemTarget
from nenupy.astro.astro_tools import hour_angle

# --- GUI geometry, passed as width/height options to the tkinter widgets ---
# NOTE: despite its name, PLOT_FRAME_WIDTH is used as the *height* of the
# plot and input frames in linear mode. PLOT_FRAME_LENGTH is the plot frame
# width in both modes, and also the frame height in polar mode.
PLOT_FRAME_WIDTH = 550
PLOT_FRAME_LENGTH = 800
INPUT_FRAME_WIDTH = 310  # width of the input frame (right-hand grid column)
ENTRY_WIDTH = 10  # `width` of the Entry/DateEntry widgets (doubled for the name entry)
BUTTON_WIDTH = 5  # `width` of the "Go" button

TIME_STEPS = 100  # number of time samples in the 24h ramp that is plotted
FONTSIZE = 9  # font size of the polar plot tick labels
PAD = 20  # padding used for the input grid rows and the polar tick labels
FRAME_RELIEF = "flat"  # `relief` option of the plot and input frames
FRAME_BORDER_WIDTH = 5  # `borderwidth` option of the plot and input frames

# Matplotlib colormap name used to color the elevation track and thresholds
CMAP = "copper" # "RdYlGn"
# Face color of the figure and of the axes
BACKGROUND_COLOR = "0.9"

# Elevation thresholds (in degrees) highlighted on the linear plot.
# These two names are rebound at runtime by LinearInputFrame when the user
# edits the corresponding entries.
ELEVATION_MIN = 1
ELEVATION_MID = 40

# Help text displayed next to the date selector
DATE_HELP = (
    "Select a day.\n"
    "Click on the figure.\n"
    "<- or ->: +/- 1 day.\n"
    "shift+(<- or ->): +/- 30 days."
)

# ============================================================= #
# ---------------------- is_solar_system ---------------------- #
def is_solar_system(src_name: str) -> bool:
    """Tell whether a source name designates a Solar System body.

    The comparison is case-insensitive: the lowercased name is looked up in
    the bodies listed by astropy's ``solar_system_ephemeris``.
    """
    normalized_name = src_name.lower()
    known_bodies = solar_system_ephemeris.bodies
    return normalized_name in known_bodies

# ============================================================= #
# --------------------- target_from_name ---------------------- #
def target_from_name(src_name: str, time: Time) -> Target:
    """Instantiate a Target object from a source name.

    Names recognized by :func:`is_solar_system` yield a ``SolarSystemTarget``;
    any other name is handed to ``FixedTarget.from_name``.
    """
    target_class = SolarSystemTarget if is_solar_system(src_name) else FixedTarget
    return target_class.from_name(src_name, time=time)

# ============================================================= #
# ------------------- target_from_position -------------------- #
def target_from_position(ra: Union[str, float], dec: Union[str, float], time: Time) -> Target:
    """Instantiate a FixedTarget from an (RA, Dec) position.

    Parameters
    ----------
    ra, dec
        Either numbers in degrees, strings holding numbers in degrees, or
        sexagesimal strings (RA read as an hour angle, Dec in degrees).
    time
        Time(s) attached to the returned target.

    Returns
    -------
    Target
        A ``FixedTarget`` located at the requested coordinates.

    Raises
    ------
    ValueError
        Raised by ``SkyCoord`` when the values cannot be interpreted.
    """
    if isinstance(ra, str) and isinstance(dec, str):
        # Only the number parsing decides between the two string formats.
        # The SkyCoord construction is kept out of the ``try`` so that an
        # invalid numeric value is reported as such instead of being silently
        # retried as a sexagesimal string.
        try:
            ra_deg, dec_deg = float(ra), float(dec)
        except ValueError:
            # Sexagesimal format
            skycoord = SkyCoord(ra, dec, unit=(u.hourangle, u.deg))
        else:
            # Degrees given as strings
            skycoord = SkyCoord(ra_deg, dec_deg, unit=u.deg)
    else:
        # They are in degrees
        skycoord = SkyCoord(ra, dec, unit=u.deg)
    return FixedTarget(coordinates=skycoord, time=time)

# ============================================================= #
# -------------------- generate_time_ramp --------------------- #
def generate_time_ramp(start_time: Time, n_times: int) -> Time:
    """Build ``n_times`` regularly spaced times covering 24h from ``start_time``.

    The step is 24h / ``n_times``; the first sample is ``start_time`` itself
    and the instant ``start_time`` + 24h is not included.
    """
    one_day = TimeDelta(24*3600, format="sec")
    step = one_day / n_times
    sample_indices = np.arange(n_times)
    return start_time + sample_indices*step

# ============================================================= #
# ------------------ display_time_elevation ------------------- #
def display_time_elevation(ax, time, elevation, color) -> None:
    """Annotate ``ax`` with a dash-dotted vertical marker and a date label.

    Nothing is drawn when ``time`` is empty (``time.size == 0``). Otherwise a
    short vertical line is drawn at ``time.datetime``, starting at
    ``elevation``, and ``time.iso`` is written vertically at its other end:
    below ``elevation`` when it is larger than 40, above it otherwise.

    Parameters
    ----------
    ax
        Axes receiving the ``vlines`` and ``text`` calls.
    time
        Object exposing ``size``, ``datetime`` and ``iso``.
    elevation
        Elevation at which the marker starts.
    color
        Color of the line and of the label box edge.
    """
    if time.size == 0:
        return

    label_offset = 7
    label_goes_below = elevation > 40
    ytext = elevation - label_offset if label_goes_below else elevation + label_offset
    line_bottom, line_top = (ytext, elevation) if label_goes_below else (elevation, ytext)
    text_anchor = "top" if label_goes_below else "bottom"

    ax.vlines(time.datetime, ymin=line_bottom, ymax=line_top, color=color, linestyle="-.")
    label_box = dict(facecolor="white", edgecolor=color, boxstyle="round")
    ax.text(
        time.datetime, ytext, time.iso,
        in_layout=False,
        rotation="vertical",
        va=text_anchor, ha="center",
        color="black",
        clip_on=True,
        bbox=label_box
    )

# ============================================================= #
# --------------- plot_source_visibility_linear --------------- #
def plot_source_visibility_linear(ax, source, sun) -> None:
    """Draw the elevation of ``source`` as a function of UTC time on ``ax``.

    The track is colored by elevation (0 to 90 deg, colormap ``CMAP``). When
    ``sun`` is not None, the time before the sunrise preceding the solar
    meridian transit and after the following sunset is shaded. Every meridian
    transit of ``source`` is labelled, together with the crossings of the
    ``ELEVATION_MIN`` and ``ELEVATION_MID`` thresholds found around it (the
    thresholds are read when the function is called).

    Parameters
    ----------
    ax
        Rectilinear matplotlib axes to draw on.
    source
        Target whose ``time`` and ``horizontal_coordinates`` are plotted.
    sun
        Sun target sampled on the same times, or None to skip the shading.
    """
    alt_deg = source.horizontal_coordinates.alt.deg
    elevation_norm = plt.Normalize(0, 90)

    # Elevation track: consecutive samples joined by short colored segments
    mpl_times = mdates.date2num(source.time.datetime)
    vertices = np.array([mpl_times, alt_deg]).T.reshape(-1, 1, 2)
    track = LineCollection(
        np.concatenate([vertices[:-1], vertices[1:]], axis=1),
        cmap=CMAP,
        norm=elevation_norm,
        zorder=10
    )
    track.set_array(alt_deg)
    track.set_linewidth(3)
    ax.add_collection(track)

    # Time axis: one major tick per hour, minor ticks every 15 minutes,
    # and a 30 minute margin on both sides of the sampled interval
    ax.set_xlabel("Time (UTC)")
    ax.xaxis.set_minor_locator(mdates.MinuteLocator(byminute=[15, 30, 45]))
    ax.xaxis.set_major_locator(mdates.HourLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%H", usetex=True))
    margin = TimeDelta(1800, format="sec")
    ax.set_xlim(
        mdates.date2num((source.time[0] - margin).datetime),
        mdates.date2num((source.time[-1] + margin).datetime),
    )

    # Night shading on both sides of the day
    if sun is not None:
        solar_noon = sun.next_meridian_transit(sun.time[0])
        sunrise = sun.previous_rise_time(solar_noon)
        sunset = sun.next_set_time(solar_noon)
        x_left, x_right = ax.get_xlim()
        ax.axvspan(x_left, mdates.date2num(sunrise.datetime), color="0.4", zorder=-1)
        ax.axvspan(mdates.date2num(sunset.datetime), x_right, color="0.4", zorder=-1)

    # Transit labels and threshold crossings
    transits = source.meridian_transit(t_min=source.time[0], precision=TimeDelta(1, format="sec"))
    cmap = colormaps[CMAP]
    crossing_finders = (
        source.previous_rise_time,
        source.next_rise_time,
        source.previous_set_time,
        source.next_set_time,
    )
    for transit in transits:
        peak_elevation = max(alt_deg)
        peak_color = cmap(elevation_norm(peak_elevation))
        transit.precision = 0  # whole seconds in the displayed label
        display_time_elevation(ax, transit, peak_elevation, peak_color)

        for threshold in [ELEVATION_MIN, ELEVATION_MID]:
            threshold_color = cmap(elevation_norm(threshold))
            ax.axhline(threshold, color=threshold_color, linestyle=":")
            for find_crossing in crossing_finders:
                display_time_elevation(ax,
                    time=find_crossing(time=transit, elevation=threshold*u.deg),
                    elevation=threshold,
                    color=threshold_color
                )

    # Elevation axis
    ax.set_ylim(0, 90)
    ax.set_ylabel("Elevation (deg)")

    # Cosmetics
    ax.grid(True, color="0.6", linewidth=0.3)
    ax.set_facecolor(BACKGROUND_COLOR)

# ============================================================= #
# --------------- plot_source_visibility_polar ---------------- #
def plot_source_visibility_polar(ax: PolarAxes, source: Target, sun: SolarSystemTarget = None, fixed_solar_midday: bool = False) -> None:
    """Draw the elevation of ``source`` against its hour angle on a polar axes.

    The angular coordinate is the hour angle of the source (radians) and the
    radial coordinate its elevation (degrees). The track is colored by
    elevation (0 to 90 deg, colormap ``CMAP``). Each angular tick is labelled
    with the hour angle and the UTC date/time interpolated at that angle.

    Parameters
    ----------
    ax
        Polar matplotlib axes to draw on.
    source
        Target whose ``time``, ``coordinates`` and ``horizontal_coordinates``
        are plotted.
    sun
        Sun target sampled on the same times. When given, the angular sector
        where its elevation is negative is shaded. None skips the shading.
    fixed_solar_midday
        When true, the plot is rotated by the offset between the time of the
        first angular tick and the next solar meridian transit. This option
        uses ``sun``, which must therefore be provided.
    """

    def find_first_night_index(night_mask):
        """Return the sample index where the night sector to shade starts.

        The first run of night samples that does not begin at index 0 is
        preferred (a run starting at index 0 is the end of the previous
        night). If the only run begins at index 0, 0 is returned. None is
        returned when there is no night sample at all.
        """
        # Pad with zeros so that every run has both a start and an end edge
        padded = np.concatenate(([0], night_mask.astype(np.int8), [0]))
        run_edges = np.where(np.abs(np.diff(padded)) == 1)[0].reshape(-1, 2)
        for run_start, _ in run_edges:
            if run_start != 0:
                return run_start
        return 0 if run_edges.size else None

    # Get the coordinates to plot
    hor_coord = source.horizontal_coordinates
    theta, r = hour_angle(source.coordinates, source.time).rad, hor_coord.alt.deg

    # Plot the colorized line segments
    points = np.array([theta, r]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    norm = plt.Normalize(0, 90)
    lc = LineCollection(segments, cmap=CMAP, norm=norm, zorder=10)
    lc.set_array(hor_coord.alt.deg)
    lc.set_linewidth(3)
    _ = ax.add_collection(lc)

    # Mask the night time
    if sun is not None:
        night = sun.horizontal_coordinates.alt.deg < 0
        night_start_idx = find_first_night_index(night)
        if night_start_idx is not None:
            # The sector spans as many angular steps as there are night
            # samples, starting from the selected night onset.
            dtheta = theta[1] - theta[0]
            night_angles = Angle(
                theta[night_start_idx] + np.arange(night.sum())*dtheta,
                unit="rad"
            ).wrap_at(Angle(180, unit="deg"))
            ax.fill_between(
                night_angles.rad,
                0,
                90,
                color="0.4"
            )

    # Set radial ticks
    ax.set_rmin(-10)
    ax.set_rmax(90)
    ax.set_rlabel_position(0)
    ax.set_yticks(np.arange(0, 100, 10))
    ax.set_yticklabels(
        Angle(np.arange(0, 100, 10), unit="deg").to_string(
            unit=u.deg, precision=0, fields=1, format='latex'
        ),
        rotation=45,
        zorder=100,
        fontsize=FONTSIZE
    )

    # Set azimuthal ticks
    x_ticks = np.radians(np.linspace(0, 360, 24, endpoint=False))
    # Interpolate the JD at each xtick
    if np.any(np.diff(theta) < -np.pi):
        # The hour angle wraps once within the ramp: unwrap it (in place) so
        # that it keeps increasing, as required by the interpolation below.
        theta[np.arange(theta.size) > np.argmin(np.diff(theta))] += 2*np.pi
        # Apply the same shift to the ticks located before the first sample
        x_ticks_increasing = x_ticks.copy()
        x_ticks_increasing[x_ticks < theta[0]] += 2*np.pi
    else:
        x_ticks_increasing = x_ticks
    xtick_jday = np.interp(x_ticks_increasing, theta, source.time.jd)
    ax.set_xticks(x_ticks)
    hour_angle_labels = Angle(x_ticks, unit="rad").wrap_at(Angle(180, unit="deg")).to_string(
        unit=u.hourangle, precision=0, fields=1, format='latex'
    )
    utc_labels = Time(xtick_jday, format="jd", precision=0).iso
    ax.set_xticklabels(
        [
            # Raw f-string: "\m" is not a valid escape sequence in a regular string
            "\n".join([rf"$\mathrm{{HA}}=${ha}"] + utc.split())
            for ha, utc in zip(hour_angle_labels, utc_labels)
        ],
        fontsize=FONTSIZE
    )
    ax.tick_params(pad=PAD)
    if fixed_solar_midday:
        # Rotate the plot by an offset equivalent to the difference between the start time of the plot and the sun culmination
        # (a difference in days is a fraction of a full turn)
        delta_deg = np.degrees(
            (sun.next_meridian_transit(sun.time[0]).jd - xtick_jday[0])*2*np.pi
        )
        ax.set_theta_zero_location("N",
            offset=delta_deg
        )
    else:
        ax.set_theta_zero_location("N")
    ax.set_theta_direction(-1)

    # Set other properties
    ax.grid(True, color="0.6", linewidth=0.3)
    ax.set_facecolor(BACKGROUND_COLOR)

# ============================================================= #
# ---------------------- BaseInputFrame ----------------------- #
class BaseInputFrame(Frame, ABC):
    """Input panel shared by the linear and polar applications.

    It holds the date selector, the source name entry and the RA/Dec entries
    (grid rows 0 to 3). ``master`` must expose an ``update`` method accepting
    an optional event argument; it is called whenever a new plot is requested.

    Attributes set here: ``src_name``, ``ra``, ``dec`` (``StringVar``) and
    ``date_entry`` (``DateEntry``).
    """

    def __init__(self, master, **kwargs):
        super().__init__(master=master, **kwargs)

        date_explanation_label = Label(self, text=DATE_HELP, justify="left", wraplength=self["width"])

        src_name_label = Label(self, text="Name:", anchor="e")
        self.src_name = StringVar(self)
        src_name_entry = Entry(self, textvariable=self.src_name, width=ENTRY_WIDTH*2)
        src_name_entry.bind("<Return>", self.plot_after_name)

        ra_label = Label(self, text="RA:", anchor="e")
        self.ra = StringVar(self)
        ra_entry = Entry(self, textvariable=self.ra, width=ENTRY_WIDTH)

        dec_label = Label(self, text="Dec:", anchor="e")
        self.dec = StringVar(self)
        dec_entry = Entry(self, textvariable=self.dec, width=ENTRY_WIDTH)

        # The calendar entry bugs on Mac OS devices
        # A double-click is required to display the dropdown menu.
        # Related issue: https://github.com/j4321/tkcalendar/issues/41
        # The current date is sampled once so that year, month and day
        # always belong to the same day.
        today = datetime.now()
        self.date_entry = DateEntry(self, selectmode="day",
            year=today.year,
            month=today.month,
            day=today.day,
            date_pattern="yyyy-mm-dd",
            width=ENTRY_WIDTH,
            justify="center"
        )
        self.date_entry.bind("<<DateEntrySelected>>", master.update)
        self.date_entry.bind("<Return>", master.update)

        coord_plot_button = Button(self, text="Go", command=self.plot_after_coord, width=BUTTON_WIDTH)

        # Date row
        date_explanation_label.grid(column=0, row=0, columnspan=2, pady=(PAD, 0))
        self.date_entry.grid(column=2, row=0, pady=(PAD, 0))

        # Source name row (validated with <Return>, there is no button)
        src_name_label.grid(column=0, row=1, pady=(PAD, 0))
        src_name_entry.grid(column=1, row=1, columnspan=2, pady=(PAD, 0))

        # Coordinates rows
        ra_label.grid(column=0, row=2, pady=(PAD, 0))
        ra_entry.grid(column=1, row=2, pady=(PAD, 0))
        dec_label.grid(column=0, row=3)
        dec_entry.grid(column=1, row=3)
        coord_plot_button.grid(column=2, row=2, rowspan=2, pady=(PAD, 0))

        # Keep the size requested at construction instead of shrinking to the content
        self.grid_propagate(False)

    def plot_after_name(self, event=None):
        """Plot from the source name: clear RA/Dec, then refresh the master.

        ``event`` is the tkinter event received when used as a binding; it is
        not used.
        """
        self.reset_coordinates()
        self.master.update()

    def plot_after_coord(self):
        """Plot from the coordinates: clear the name, then refresh the master."""
        self.reset_srcname()
        self.master.update()

    def reset_coordinates(self):
        """Empty the RA and Dec entries."""
        self.ra.set("")
        self.dec.set("")

    def reset_srcname(self):
        """Empty the source name entry."""
        self.src_name.set("")

# ============================================================= #
# ---------------------- PolarInputFrame ---------------------- #
class PolarInputFrame(BaseInputFrame):
    """Input panel of the polar plot: base inputs plus a "Midday on top" option."""

    def __init__(self, master, **kwargs):
        super().__init__(master=master, **kwargs)

        # State of the check box (0 or 1), read when the polar plot is refreshed
        self.set_midday_top = IntVar()

        # Toggling the box immediately refreshes the master
        Checkbutton(
            self,
            text="Midday on top",
            variable=self.set_midday_top,
            command=master.update
        ).grid(column=0, row=4, columnspan=3, pady=(PAD, 0))

# ============================================================= #
# --------------------- LinearInputFrame ---------------------- #
class LinearInputFrame(BaseInputFrame):
    """Input panel of the linear plot: base inputs plus two elevation thresholds.

    Pressing <Return> in a threshold entry rebinds the matching module-level
    value (``ELEVATION_MIN`` or ``ELEVATION_MID``) and refreshes the master.
    """

    def __init__(self, master, **kwargs):
        super().__init__(master=master, **kwargs)

        self.el_1_val = DoubleVar(value=ELEVATION_MIN)
        self._add_threshold_row(4, "1st elevation threshold:", self.el_1_val, self.update_elevation_min)

        self.el_2_val = DoubleVar(value=ELEVATION_MID)
        self._add_threshold_row(5, "2nd elevation threshold:", self.el_2_val, self.update_elevation_mid)

    def _add_threshold_row(self, row, caption, variable, on_return):
        """Create and place one "label + entry" row bound to ``variable``."""
        label = Label(self, text=caption, anchor="e")
        entry = Entry(self, textvariable=variable, width=ENTRY_WIDTH)

        entry.bind("<Return>", on_return)
        label.grid(column=0, row=row, pady=(PAD, 0), columnspan=2)
        entry.grid(column=2, row=row, pady=(PAD, 0), columnspan=1)

    def update_elevation_min(self, event):
        """Store the first threshold entry in ``ELEVATION_MIN`` and replot."""
        global ELEVATION_MIN
        ELEVATION_MIN = self.el_1_val.get()
        self.master.update()

    def update_elevation_mid(self, event):
        """Store the second threshold entry in ``ELEVATION_MID`` and replot."""
        global ELEVATION_MID
        ELEVATION_MID = self.el_2_val.get()
        self.master.update()

# ============================================================= #
# ------------------------ ResultFrame ------------------------ #
class ResultFrame(Frame):
    """Text panel summarizing the meridian transit and the visibility windows.

    The displayed text is held by the ``output_message`` ``StringVar``.
    """

    def __init__(self, master, **kwargs):
        super().__init__(master=master, **kwargs)

        self.output_message = StringVar()
        output_label = Label(self, textvariable=self.output_message, justify="left")

        output_label.pack()
        # Keep the size requested at construction instead of shrinking to the label
        self.pack_propagate(False)

    def update(self, source: Target) -> None:
        """Refresh the displayed summary for ``source``.

        One line gives the meridian transit found from ``source.time[0]``.
        Then, for each of the module thresholds ``ELEVATION_MIN`` and
        ``ELEVATION_MID`` (same values as the ones drawn on the linear plot),
        one line gives the rise time preceding the transit and the set time
        following it. Times are printed with whole seconds.
        """
        meridian_transit = source.meridian_transit(t_min=source.time[0])
        meridian_transit.precision = 0

        lines = [f"Meridian Transit: {meridian_transit.isot}\n"]
        for threshold in (ELEVATION_MIN, ELEVATION_MID):
            rise_time = source.previous_rise_time(time=meridian_transit, elevation=threshold*u.deg)
            set_time = source.next_set_time(time=meridian_transit, elevation=threshold*u.deg)
            rise_time.precision = 0
            set_time.precision = 0
            lines.append(f"Elevation > {threshold}: {(rise_time.isot, set_time.isot)}\n")

        self.output_message.set("".join(lines))

# ============================================================= #
# ----------------------- BasePlotFrame ----------------------- #
class BasePlotFrame(Frame, ABC):
    """Frame owning the matplotlib figure shared by the plot frames.

    The constructor only creates ``self.fig``; subclasses add their axes and
    then call :meth:`set_figure_canvas` to embed the figure.
    """

    def __init__(self, master, **kwargs):
        super().__init__(master=master, **kwargs)

        self.fig = Figure(dpi=100, facecolor=BACKGROUND_COLOR, tight_layout=True)

    def set_figure_canvas(self, ax, fig) -> None:
        """Embed ``self.fig`` in this frame and install the keyboard shortcuts.

        ``ax`` and ``fig`` are accepted but not used: the canvas is always
        built from ``self.fig``. The canvas is stored in ``self.canvas``.
        """
        self.canvas = FigureCanvasTkAgg(self.fig, master=self)
        self.canvas.get_tk_widget().pack(fill="both", expand=True)

        # Number of days added to the selected date for each handled key
        day_shifts = {
            "right": 1,
            "left": -1,
            "shift+right": 30,
            "shift+left": -30,
        }

        def shift_date_on_key(event):
            """Move the selected date and replot when a handled key is pressed."""
            date_entry = self.master.input_frame.date_entry
            current_date = date_entry.get_date()
            shift = day_shifts.get(event.key)
            if shift is None:
                # Any other key is ignored
                return
            date_entry.set_date(current_date + timedelta(days=shift))
            self.master.update()

        self.canvas.mpl_connect("key_press_event", shift_date_on_key)

        # Keep the size requested at construction instead of shrinking to the canvas
        self.pack_propagate(False)

# ============================================================= #
# ---------------------- PolarPlotFrame ----------------------- #
class PolarPlotFrame(BasePlotFrame):
    """Plot frame displaying the source visibility on a polar axes."""

    def __init__(self, master, **kwargs):
        super().__init__(master=master, **kwargs)

        self.ax = self.fig.add_subplot(projection="polar")

        # Initialize empty plot
        self.ax.set_facecolor(BACKGROUND_COLOR)
        self.ax.set_rlim(-10, 90)
        self.ax.text(0, -10, "Select a source", ha="center", bbox=dict(facecolor="white", edgecolor="black", boxstyle="round"))

        # Passed by keyword: the parameters are declared in the (ax, fig) order
        self.set_figure_canvas(ax=self.ax, fig=self.fig)

    def update(self, source: Target, sun: SolarSystemTarget) -> None:
        """Clear the axes, redraw the visibility of ``source`` and refresh the canvas.

        The "Midday on top" option is read from the master's input frame.
        """
        self.ax.clear()
        plot_source_visibility_polar(
            ax=self.ax,
            source=source,
            sun=sun,
            # The check box variable holds an int
            fixed_solar_midday=bool(self.master.input_frame.set_midday_top.get())
        )
        self.canvas.draw()

# ============================================================= #
# ---------------------- LinearPlotFrame ---------------------- #
class LinearPlotFrame(BasePlotFrame):
    """Plot frame displaying the source elevation versus time on a rectilinear axes."""

    def __init__(self, master, **kwargs):
        super().__init__(master=master, **kwargs)

        self.ax = self.fig.add_subplot(projection="rectilinear")

        # Initialize empty plot
        self.ax.set_facecolor(BACKGROUND_COLOR)
        self.ax.set_ylim(0, 90)
        self.ax.set_ylabel("Elevation (deg)")
        self.ax.set_xlabel("Time (UTC)")
        self.ax.text(0.5, 45, "Select a source", ha="center", bbox=dict(facecolor="white", edgecolor="black", boxstyle="round"))

        # Passed by keyword: the parameters are declared in the (ax, fig) order
        self.set_figure_canvas(ax=self.ax, fig=self.fig)

    def update(self, source: Target, sun: SolarSystemTarget) -> None:
        """Clear the axes, redraw the visibility of ``source`` and refresh the canvas."""
        self.ax.clear()
        plot_source_visibility_linear(
            ax=self.ax,
            source=source,
            sun=sun
        )
        self.canvas.draw()

# ============================================================= #
# ----------------------- VisibilityApp ----------------------- #
class VisibilityApp(Tk):
    """Main window: a plot frame (left) next to an input frame (right).

    NOTE: :meth:`update` replaces tkinter's own ``update()`` on this window;
    it is the callback used by the child frames to request a new plot.
    """

    def __init__(self, plot_type: str):
        """Build the window for ``plot_type`` ("linear" or "polar", any case).

        Raises
        ------
        ValueError
            If ``plot_type`` is not recognized.
        """
        super().__init__()

        # plot frame class, input frame class, height of both frames
        layouts = {
            "linear": (LinearPlotFrame, LinearInputFrame, PLOT_FRAME_WIDTH),
            "polar": (PolarPlotFrame, PolarInputFrame, PLOT_FRAME_LENGTH),
        }
        layout = layouts.get(plot_type.lower())
        if layout is None:
            raise ValueError(f"Unrecognized plot_type '{plot_type}'!")
        plot_frame_class, input_frame_class, frame_height = layout

        self.plot_frame = plot_frame_class(
            self,
            borderwidth=FRAME_BORDER_WIDTH,
            relief=FRAME_RELIEF,
            width=PLOT_FRAME_LENGTH,
            height=frame_height
        )
        self.input_frame = input_frame_class(
            self,
            borderwidth=FRAME_BORDER_WIDTH,
            relief=FRAME_RELIEF,
            width=INPUT_FRAME_WIDTH,
            height=frame_height
        )
        self.geometry(f"{PLOT_FRAME_LENGTH + INPUT_FRAME_WIDTH}x{frame_height}")

        # Allows the plot Frame to expand and shrink
        self.rowconfigure(0, minsize=200, weight=1)
        self.columnconfigure(0, minsize=200, weight=1)

        # The ResultFrame is currently not displayed
        self.plot_frame.grid(column=0, row=0, rowspan=1, sticky="news")
        self.input_frame.grid(column=1, row=0)

        self.eval("tk::PlaceWindow . center")
        self.title("nenupy - NenuFAR source visibility")

    def update(self, *args) -> None:
        """Read the inputs, build the targets and redraw the plot.

        A non-empty source name takes precedence over the coordinates, and
        the RA/Dec entries are then filled with the mean coordinates of the
        resolved source. Nothing happens when neither a name nor a complete
        RA/Dec pair has been entered. ``args`` (e.g. a tkinter event) is
        ignored.
        """
        inputs = self.input_frame

        # Gather the inputs
        src_name = inputs.src_name.get()
        ra = inputs.ra.get()
        dec = inputs.dec.get()
        selected_day = inputs.date_entry.get_date()
        times = generate_time_ramp(
            start_time=Time(selected_day.isoformat(), format="iso"),
            n_times=TIME_STEPS
        )

        # Instantiate the Target object
        if src_name:
            inputs.reset_coordinates()
            src = target_from_name(src_name=src_name, time=times)
            # Auto update RA, Dec
            inputs.ra.set(f"{np.mean(src.coordinates.ra.deg)}")
            inputs.dec.set(f"{np.mean(src.coordinates.dec.deg)}")
        elif ra and dec:
            inputs.reset_srcname()
            src = target_from_position(ra=ra, dec=dec, time=times)
        else:
            # Nothing has been entered
            return

        # Plot
        sun = SolarSystemTarget.from_name("Sun", time=times)
        self.plot_frame.update(src, sun)


# ============================================================= #
# -------------------- run_visibility_app --------------------- #
def run_visibility_app(plot_type: str = "linear") -> None:
    """Create the visibility window for ``plot_type`` and run its event loop."""
    VisibilityApp(plot_type=plot_type).mainloop()

if __name__ == "__main__":

    # Command line entry point: the only option is the kind of plot
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-p", "--plot_type",
        default="linear",
        type=str,
        required=False,
        help="plot type (either 'linear' or 'polar')",
    )
    cli_options = cli.parse_args()

    run_visibility_app(plot_type=cli_options.plot_type)
