#!/usr/bin/env python

"""
Trading strategy runner
"""

import multiprocessing as mp
import os
import sys
import time
import uuid
import random
from datetime import datetime
from math import ceil
from typing import List

import alpaca_trade_api as tradeapi
import pygit2
import toml
from pytz import timezone

from liualgotrader.common import config
from liualgotrader.common.market_data import get_historical_data_from_polygon
from liualgotrader.common.tlog import tlog
from liualgotrader.consumer import consumer_main
from liualgotrader.polygon_producer import polygon_producer_main
from liualgotrader.scanners_runner import main


def motd(filename: str, version: str, unique_id: str) -> None:
    """Print the start-up banner.

    The banner has two sections separated by rule lines: the identity of
    this run (script name, version, unique id) followed by the DSN and the
    symbol limit read from `config`.
    """

    def rule() -> None:
        # Rule lines go straight to stdout; the text lines go through tlog.
        print("+=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=+")

    rule()
    tlog(f"{filename} {version} starting")
    tlog(f"unique id: {unique_id}")
    rule()
    tlog(f"DSN: {config.dsn}")
    tlog(f"MAX SYMBOLS: {config.total_tickers}")
    rule()


def get_trading_windows(tz, api):
    """Return today's market open and close times expressed in `tz`.

    The calendar is queried through `api.get_calendar` with today's date as
    both `start` and `end`, and only the first returned entry is used.

    Args:
        tz: tzinfo in which "now" and the returned datetimes are expressed.
        api: object exposing `get_calendar(start=..., end=...)`; the first
            entry of its result must provide `date`, `open` and `close`.

    Returns:
        A `(market_open, market_close)` tuple of datetimes carrying today's
        date in `tz`, or `(None, None)` when the calendar has no entry or
        its first entry is for a later date than today.
    """
    tlog("checking market schedule")

    # Take a single snapshot of "now" so the date string sent to the API and
    # the date compared against the reply cannot straddle midnight.
    today = datetime.today().astimezone(tz)
    today_str = today.strftime("%Y-%m-%d")

    sessions = api.get_calendar(start=today_str, end=today_str)
    if not sessions:
        # An empty reply used to raise IndexError; report "no session"
        # instead, which callers already handle via the None values.
        tlog(f"no market session returned for {today_str}")
        return None, None
    calendar = sessions[0]

    tlog(f"next open date {calendar.date.date()}")

    if today.date() < calendar.date.date():
        tlog(f"which is not today {today}")
        return None, None

    # Only the hour and minute come from the calendar entry; the date and
    # time zone are taken from `today`.
    market_open = today.replace(
        hour=calendar.open.hour,
        minute=calendar.open.minute,
        second=0,
        microsecond=0,
    )
    market_close = today.replace(
        hour=calendar.close.hour,
        minute=calendar.close.minute,
        second=0,
        microsecond=0,
    )
    return market_open, market_close


"""
process main
"""

def ready_to_start(trading_api: tradeapi) -> bool:
    """Decide whether the run should proceed, waiting for the open if needed.

    Stores today's trading window in `config.market_open` and
    `config.market_close` as a side effect.

    Returns:
        True when the market schedule is bypassed, or when the current New
        York time is before the market close (sleeping until just after the
        open first, if it has not happened yet). False when there is no
        trading window today, the market has already closed, or the wait is
        interrupted with Ctrl-C.
    """
    eastern = timezone("America/New_York")
    config.market_open, config.market_close = get_trading_windows(
        eastern, trading_api
    )

    bypass = config.bypass_market_schedule

    # No session today and no override: nothing to do.
    if not (config.market_open or bypass):
        return False

    if bypass:
        now = datetime.today().astimezone(eastern)
        tlog(f"current time {now}")
        tlog("bypassing market schedule, are we debugging something?")
        return True

    tlog(f"markets open {config.market_open} market close {config.market_close}")

    now = datetime.today().astimezone(eastern)
    tlog(f"current time {now}")

    # The session is already over.
    if now >= config.market_close:
        return False

    until_open = config.market_open - now
    wait_seconds = until_open.total_seconds()
    if wait_seconds > 0:
        try:
            tlog(f"waiting for market open: {until_open} ({wait_seconds} seconds)")
            # Sleep one extra second so we wake up after the open.
            time.sleep(wait_seconds + 1)
        except KeyboardInterrupt:
            return False

    return True


def calc_num_consumer_processes() -> int:
    """Return how many consumer processes to start.

    A positive `config.num_consumers` is returned as-is. Otherwise the count
    is derived from the CPU count, the mean of the 1/5/15-minute system load
    averages and `config.proc_factor`, and is never less than 1.

    Returns:
        The number of consumer processes (>= 1 when computed).
    """
    if config.num_consumers > 0:
        return config.num_consumers

    # os.cpu_count() returns None when the count cannot be determined; map
    # that to 0 so the "unknown CPU count" branch below is taken instead of
    # failing on `None > 0`.
    num_cpus = os.cpu_count() or 0

    try:
        load_avg = sum(os.getloadavg()) / 3
    except (AttributeError, OSError):
        # getloadavg() is unavailable on some platforms and raises OSError
        # when the load cannot be read; fall through to the "no load" path.
        load_avg = 0.0

    tlog(
        f"mean 1/5/15-min load_avg:{load_avg}, num_cpus:{num_cpus}, proc_factor:{config.proc_factor}"
    )
    if not load_avg:
        # Avoid dividing by zero below.
        load_avg = 1.0

    # NOTE(review): the divisor is capped at 1.0, so a load below 1.0
    # increases the process count and a load above 1.0 has no effect.
    # If the intent was to scale down under load this should probably be
    # max(load_avg, 1.0) -- confirm before changing.
    divisor = min(load_avg, 1.0)

    return max(
        1,
        ceil(
            num_cpus / divisor * config.proc_factor
            if num_cpus > 0
            else config.proc_factor / divisor
        ),
    )


"""
starting
"""


if __name__ == "__main__":
    config.filename = os.path.basename(__file__)
    mp.set_start_method("spawn")

    # Build label: prefer the git tag description of the parent directory,
    # falling back to the installed package version (or "") outside a repo.
    try:
        config.build_label = pygit2.Repository("../").describe(
            describe_strategy=pygit2.GIT_DESCRIBE_TAGS
        )
    except pygit2.GitError:
        import liualgotrader

        config.build_label = getattr(liualgotrader, "__version__", "")  # type: ignore

    uid = str(uuid.uuid4())
    motd(filename=config.filename, version=config.build_label, unique_id=uid)

    # load configuration
    tlog(
        f"loading configuration file from {os.getcwd()}/{config.configuration_filename}"
    )
    try:
        conf_dict = toml.load(config.configuration_filename)
    except FileNotFoundError:
        tlog(f"[ERROR] could not locate tradeplan file {config.configuration_filename}")
        # NOTE(review): error paths exit with status 0; kept unchanged so
        # existing wrappers that inspect the exit status are not affected.
        sys.exit(0)

    # parse configuration
    config.bypass_market_schedule = conf_dict.get("bypass_market_schedule", False)

    # basic validation for scanners and strategies
    tlog(f"bypass_market_schedule = {config.bypass_market_schedule}")
    if not conf_dict.get("scanners"):
        tlog("must have at least one scanner configured")
        sys.exit(0)
    if not conf_dict.get("strategies"):
        tlog("must have at least one strategy configured")
        sys.exit(0)

    scanners_conf = conf_dict["scanners"]
    for scanner in scanners_conf:
        tlog(f"- {scanner} scanner detected")

    # The data client always uses the production credentials; it serves the
    # market calendar check and the historical data download below.
    data_api = tradeapi.REST(
        base_url=config.prod_base_url,
        key_id=config.prod_api_key_id,
        secret_key=config.prod_api_secret,
    )

    if ready_to_start(data_api):
        # The trading client follows config.env: production or paper account.
        is_prod = config.env == "PROD"
        base_url = config.prod_base_url if is_prod else config.paper_base_url
        api_key_id = config.prod_api_key_id if is_prod else config.paper_api_key_id
        api_secret = config.prod_api_secret if is_prod else config.paper_api_secret
        trading_api = tradeapi.REST(
            base_url=base_url, key_id=api_key_id, secret_key=api_secret
        )

        # add open positions
        symbols: List = []
        existing_positions = trading_api.list_positions()

        if not existing_positions:
            tlog("no open positions")
        else:
            for position in existing_positions:
                if position.symbol not in symbols:
                    symbols.append(position.symbol)
                    tlog(f"added existing open position in {position.symbol}")

        minute_history = get_historical_data_from_polygon(
            api=data_api,
            symbols=symbols,
            max_tickers=min(config.total_tickers, len(symbols)),
        )
        # Continue only with the symbols for which history was returned.
        symbols = list(minute_history.keys())

        # Consumers first
        num_consumer_processes = calc_num_consumer_processes()
        tlog(f"Starting {num_consumer_processes} consumer processes")

        queues: List[mp.Queue] = [mp.Queue() for _ in range(num_consumer_processes)]

        # Assign every symbol to a randomly chosen consumer queue.
        # q_id_hash maps symbol -> queue index; symbol_by_queue is the
        # inverse mapping (queue index -> list of symbols).
        rng = random.SystemRandom()
        q_id_hash = {}
        symbol_by_queue = {}
        for symbol in symbols:
            _index = rng.randint(0, num_consumer_processes - 1)
            q_id_hash[symbol] = _index
            symbol_by_queue.setdefault(_index, []).append(symbol)

        consumers = [
            mp.Process(
                target=consumer_main,
                args=(
                    queues[i],
                    # None when no symbol was assigned to this queue.
                    symbol_by_queue.get(i),
                    minute_history,
                    uid,
                    conf_dict,
                ),
            )
            for i in range(num_consumer_processes)
        ]
        for p in consumers:
            p.start()

        # The scanner process and the producer share scanner_queue.
        scanner_queue: mp.Queue = mp.Queue()
        producer = mp.Process(
            target=polygon_producer_main,
            args=(
                uid,
                queues,
                symbols,
                q_id_hash,
                config.market_close,
                conf_dict,
                scanner_queue,
                num_consumer_processes,
            ),
        )
        producer.start()

        tlog("Starting scanners process")
        scanner = mp.Process(
            target=main,
            args=(
                conf_dict,
                config.market_open,
                config.market_close,
                scanner_queue,
            ),
        )
        scanner.start()

        # wait for completion and hope everyone plays nicely
        try:
            producer.join()
            scanner.join()
            for p in consumers:
                p.join()
        except KeyboardInterrupt:
            # Ctrl-C: stop every child process instead of waiting for it.
            producer.terminate()
            scanner.terminate()
            for p in consumers:
                p.terminate()

    print("+=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=+")
    tlog(f"run {uid} completed")
    sys.exit(0)
