#!python

import os
import sys
import datetime
import yaml
import calendar
import json
import time
import requests
import math
import pwd
import importlib
import decimal
from pushover import Client
from monzo_utils.lib.monzo_api import MonzoAPI
from monzo_utils.lib.config import Config
from monzo_utils.lib.db import DB
from monzo_utils.lib.monzo_sync import MonzoSync
from monzo_utils.model.provider import Provider
from monzo_utils.model.account import Account
from monzo_utils.model.pot import Pot
from monzo_utils.model.transaction import Transaction
from monzo_utils.model.flex_summary import FlexSummary

PROVIDER = 'Monzo'

class PaymentFundsCheck:
    def __init__(self, account_name, output_json, abbreviate=False):
        """Load the account configuration and resolve the account to check.

        Args:
            account_name: Either the name of an account (its config is then
                read from ``~/.monzo/<name>.yaml``) or a path to an existing
                config file; see load_config().
            output_json: When true, main() prints a JSON document instead of
                the text report.
            abbreviate: Passed through to the payment/summary ``data()``
                calls; it also suppresses automatic transfers and
                notifications in the non-interactive branches.

        Exits the process with status 1 if the configuration is invalid, the
        database cannot be opened, or the provider/account is not found.
        """
        self.account_name = account_name
        self.json = output_json
        self.abbreviate = abbreviate
        self.output = []

        self.load_config()
        self.validate_config()

        try:
            self.db = DB(None, self.config_path)
        except Exception as e:
            # Fatal errors go to stderr so JSON output on stdout stays clean.
            sys.stderr.write(f"failed to connect to the database: {e}\n")
            sys.exit(1)

        self.seen = []
        self.exchange_rates = {}

        self.provider = Provider().find_by_name(PROVIDER)

        if not self.provider:
            # Without this guard the lookup below fails on self.provider.id.
            sys.stderr.write(f"provider {PROVIDER} not found in the database\n")
            sys.exit(1)

        self.account = Account().find_by_provider_id_and_name(self.provider.id, self.account_name)

        if not self.account:
            sys.stderr.write(f"account {self.account_name} not found in the database\n")
            sys.exit(1)

        # Per-account marker files used to delay automatic transfers.
        homedir = pwd.getpwuid(os.getuid()).pw_dir
        self.credit_tracker = f"{homedir}/{self.config['account']}.credit"
        self.shortfall_tracker = f"{homedir}/{self.config['account']}.shortfall"


    def load_config(self):
        """Locate and parse the account's YAML configuration file.

        If ``self.account_name`` is an existing path it is used as the config
        file and its directory becomes ``self.config_path``. Otherwise the
        file ``~/.monzo/<account_name>.yaml`` is used, and ``~/.monzo`` is
        created if it does not exist.

        On success ``self.config`` holds the parsed document and
        ``self.account_name`` is replaced by the config's ``account`` value.
        Exits the process with status 1 if the file is missing, unreadable,
        unparsable, or has no ``account`` key.
        """
        if os.path.exists(self.account_name):
            account_config_file = self.account_name
            self.config_path = os.path.dirname(account_config_file)
        else:
            homedir = pwd.getpwuid(os.getuid()).pw_dir
            self.config_path = f"{homedir}/.monzo"

            if not os.path.exists(self.config_path):
                os.mkdir(self.config_path, 0o755)

            account_config_file = f"{self.config_path}/{self.account_name}.yaml"

        if not os.path.exists(account_config_file):
            sys.stderr.write(f"Cannot find account config file: {account_config_file}\n")
            sys.exit(1)

        try:
            # The context manager closes the file even when parsing fails.
            with open(account_config_file) as config_file:
                self.config = yaml.safe_load(config_file)

            self.account_name = self.config['account']
        except Exception as e:
            sys.stderr.write(f"Cannot read or parse account config file {account_config_file}: {str(e)}\n")
            sys.exit(1)


    def validate_config(self):
        """Exit with status 1 unless all mandatory configuration is present.

        ``payments``, ``salary_description`` and ``salary_payment_day`` must
        always be set to a truthy value. When either ``notify_shortfall`` or
        ``notify_credit`` is enabled, ``pushover_key`` and ``pushover_app``
        must be set as well. The first missing key is reported on stderr.
        """
        settings = self.config

        for key in ('payments', 'salary_description', 'salary_payment_day'):
            if settings.get(key):
                continue

            sys.stderr.write(f"Missing config key: {key}\n")
            sys.exit(1)

        # Push credentials only matter when a push notification is enabled.
        if not (settings.get('notify_shortfall') or settings.get('notify_credit')):
            return

        for key in ('pushover_key', 'pushover_app'):
            if settings.get(key):
                continue

            sys.stderr.write(f"Push is enabled but push config key is missing: {key}\n")
            sys.exit(1)


    def main(self):
        """Report the payments for the current salary period and act on the result.

        Works out the salary dates, displays every configured payment list
        (plus ``refunds_due`` when configured), then compares the amount still
        due with the balance of the configured pot (or of the account itself
        when no ``pot`` is configured).

        In JSON mode a single JSON document is printed and nothing else is
        done. In text mode the totals are printed, a shortfall or credit is
        passed to handle_shortfall() / handle_credit(), pot payments are
        checked when today is the last salary date, and a MonzoSync is run
        if any of those steps reported that a sync is required.

        Exits with status 1 if the configured pot cannot be found.
        """
        self.last_salary_date = self.get_last_salary_date()
        self.next_salary_date = self.get_next_salary_date(self.last_salary_date)
        self.following_salary_date = self.get_next_salary_date(self.next_salary_date)

        # Running totals, accumulated by display_payment_list() in pence.
        self.due = 0
        self.total_this_month = 0
        self.next_month = 0
        self.next_month_bills_pot = 0

        if self.json is False:
            self.display_columns()

        for payment_list in self.config['payments']:
            if not payment_list['payments']:
                continue

            self.display_payment_list(payment_list['type'], payment_list['payments'])

        if self.config.get('refunds_due'):
            self.display_payment_list('Refund', self.config['refunds_due'])

        if 'pot' in self.config:
            pot = Pot().find_by_account_id_and_name_and_deleted(self.account.id, self.config['pot'], 0)

            if not pot:
                # Without this guard the balance lookup below would raise
                # AttributeError for a misspelt or deleted pot.
                sys.stderr.write(f"pot {self.config['pot']} not found in the database\n")
                sys.exit(1)
        else:
            pot = self.account

        balance_pence = round(pot.balance * 100)
        shortfall = (self.due - balance_pence) / 100

        if self.json:
            data = {
                'balance': float(pot.balance),
                'due': float(self.due) / 100,
                'total_this_month': self.total_this_month / 100,
                'total_next_month': self.next_month / 100,
                'payments': self.sanitise(self.output)
            }

            if shortfall * 100 < 0:
                data['credit'] = (balance_pence - self.due) / 100
                data['shortfall'] = 0
            else:
                data['shortfall'] = shortfall
                data['credit'] = 0

            print(json.dumps(data, indent=4))
            return

        exclude_yearly = bool(self.config.get('exclude_yearly_from_bills'))

        def summary_line(label, pounds):
            # One indented "LABEL:  £x.xx" row of the totals section.
            sys.stdout.write(" " * 25)
            sys.stdout.write(label.ljust(31))
            print("£%.2f" % (pounds))

        sys.stdout.write("\n")

        summary_line("TOTAL THIS MONTH:", self.total_this_month / 100)

        if exclude_yearly:
            sys.stdout.write("\n")

        summary_line("TOTAL NEXT MONTH:", self.next_month / 100)

        credit = (balance_pence - self.due) / 100

        if round(shortfall * 100) < 0:
            summary_line("LESS CREDIT BALANCE:", (self.next_month / 100) - credit)

        if exclude_yearly:
            summary_line("Bills pot payment:", self.next_month_bills_pot / 100)

        sync_required = False

        if round(shortfall * 100) > 0:
            print("      due: £%.2f" % (self.due / 100))
            print("  balance: £%.2f" % (pot.balance))
            print("shortfall: £%.2f" % (shortfall))

            sync_required = self.handle_shortfall(pot, shortfall)

        else:
            # No shortfall any more, so any pending deposit delay is stale.
            if os.path.exists(self.shortfall_tracker):
                os.remove(self.shortfall_tracker)

            print("    due: £%.2f" % (self.due / 100))
            print("balance: £%.2f" % (pot.balance))

            if round(credit * 100) == 0:
                credit = 0

                if os.path.exists(self.credit_tracker):
                    os.remove(self.credit_tracker)

            else:
                print(" credit: £%.2f" % (credit))

                sync_required = self.handle_credit(pot, credit)

        today = datetime.datetime.now()

        # Pot payments are only checked on the day the salary arrived.
        if today.strftime('%Y%m%d') == self.last_salary_date.strftime('%Y%m%d'):
            if self.check_pot_payments():
                sync_required = True

        if sync_required:
            ms = MonzoSync()
            ms.sync()


    def display_columns(self):
        """Print the heading row of the payment table followed by a rule."""
        # (title, width, padding function) for each column, left to right.
        columns = (
            ('Status', 8, str.rjust),
            ('Type', 15, str.ljust),
            ('Name', 25, str.ljust),
            ('', 4, str.ljust),
            ('Amount', 8, str.ljust),
            ('Left', 8, str.ljust),
            ('Last date', 12, str.ljust),
            ('Due date', 10, str.ljust),
        )

        heading = " ".join(pad(title, width) for title, width, pad in columns)

        print(heading)
        print("-" * 97)


    def prompt_action(self, prompt):
        """Ask a yes/no question on the terminal and return the answer.

        The prompt is repeated until the user enters ``y`` or ``n``
        (case-insensitive).

        Args:
            prompt: Text written to stdout before each read.

        Returns:
            True if the user answered ``y``; False if they answered ``n`` or
            if stdin reached end-of-file before a valid answer was given.
        """
        while True:
            sys.stdout.write(prompt)
            sys.stdout.flush()

            line = sys.stdin.readline()

            if line == '':
                # readline() returns '' only at EOF; re-prompting would loop
                # forever, so treat a closed stdin as "no".
                return False

            answer = line.rstrip().lower()

            if answer in ('y', 'n'):
                return answer == 'y'

 
    def display_payment_list(self, payment_type, payment_list):
        """Display one list of payments and add it to the running totals.

        Builds the payment objects via get_payments(). For the ``Flex`` type
        a FlexSummary is produced first and its ``last_payment`` is handed to
        every payment. Each payment then contributes to ``self.due``,
        ``self.total_this_month``, ``self.next_month`` and
        ``self.next_month_bills_pot`` and is either printed or, in JSON mode,
        appended to ``self.output``.

        Args:
            payment_type: Name of the payment type, e.g. ``'Flex'`` or
                ``'Refund'``.
            payment_list: The list of per-payment config entries.
        """
        payments = self.get_payments(payment_type, payment_list)

        summary = None

        if payment_type == 'Flex':
            total_this_month = 0
            total_next_month = 0
            flex_remaining = 0

            for payment in payments:
                total_this_month += payment.amount_for_period(self.last_salary_date, self.next_salary_date)
                total_next_month += payment.amount_for_period(self.next_salary_date, self.following_salary_date)
                flex_remaining += payment.remaining

            summary = FlexSummary(self.config, total_this_month, total_next_month, flex_remaining, self.last_salary_date)

            if self.json:
                self.output.append(summary.data(self.abbreviate))
            else:
                summary.display()

        exclude_yearly = bool(self.config.get('exclude_yearly_from_bills'))

        # get_payments() returns one payment per config entry, in order, so
        # zipping pairs every payment with the entry it was built from.
        for payment_config, payment in zip(payment_list, payments):
            if summary:
                payment.last_flex_payment = summary.last_payment

            if payment.status in ['DUE', 'PAID']:
                self.total_this_month += payment.display_amount * 100

            if payment.status == 'DUE':
                self.due += payment.display_amount * 100

            if payment.due_next_month:
                if payment.status == 'SKIPPED' or payment_type != 'Refund':
                    self.next_month += payment.next_month_amount * 100

                    # With exclude_yearly_from_bills enabled, entries that
                    # carry a yearly_month are left out of the bills pot total.
                    if not exclude_yearly or 'yearly_month' not in payment_config:
                        self.next_month_bills_pot += payment.next_month_amount * 100

            if self.json:
                self.output.append(payment.data(self.abbreviate))
            else:
                payment.display()


    def get_payments(self, payment_type, payment_list):
        """Instantiate a payment object for every entry of ``payment_list``.

        The class is looked up dynamically: a type such as ``'Some Type'`` is
        loaded as class ``SomeType`` from ``monzo_utils.model.some_type``.

        Args:
            payment_type: Name of the payment type.
            payment_list: The list of per-payment config entries.

        Returns:
            A list with one payment object per entry, in the same order.
        """
        def build(entry):
            # Resolve the model class for this type, then construct it.
            module_name = payment_type.lower().replace(' ','_')
            class_name = payment_type.title().replace(' ','')
            model = importlib.import_module(f"monzo_utils.model.{module_name}")

            return getattr(model, class_name)(
                self.config,
                payment_list,
                entry,
                self.last_salary_date,
                self.next_salary_date,
                self.following_salary_date
            )

        return [build(entry) for entry in payment_list]


    def notify(self, event, message):
        """Send a Pushover message using the configured credentials.

        Args:
            event: Used as the notification title.
            message: The notification body.
        """
        settings = self.config
        client = Client(settings['pushover_key'], api_token=settings['pushover_app'])

        client.send_message(message, title=event)


    def abbreviate_string(self, string):
        """Return only the upper-case characters of ``string``, in order."""
        return ''.join(char for char in string if char.isupper())


    def get_last_salary_date(self):
        """Return the date of the most recent salary transaction.

        The transaction is looked up on the account named by
        ``salary_account`` when that is configured and differs from the
        account being checked; otherwise on the account itself. The search
        uses ``salary_description``, ``salary_payment_day`` and
        ``salary_minimum`` (1000 when not configured).

        Exits with status 1 if the salary account or the transaction cannot
        be found.
        """
        if 'salary_account' in self.config and self.config['salary_account'] != self.account_name:
            account = Account().find_by_provider_id_and_name(self.provider.id, self.config['salary_account'])

            if not account:
                sys.stderr.write(f"salary account {self.config['salary_account']} not found in the database\n")
                sys.exit(1)
        else:
            account = self.account

        criteria = {
            'description': self.config['salary_description'],
            'payment_day': self.config['salary_payment_day'],
            'salary_minimum': self.config.get('salary_minimum', 1000)
        }

        last_salary_transaction = account.last_salary_transaction(**criteria)

        if not last_salary_transaction:
            # Report the search criteria; no SQL or parameters exist in this
            # scope, so referencing them here would raise NameError.
            sys.stderr.write("failed to find last salary transaction.\n")
            sys.stderr.write(f"criteria: {json.dumps(criteria, indent=4, default=str)}\n")
            sys.exit(1)

        return last_salary_transaction['date']


    def get_next_salary_date(self, last_salary_date, payment_day=15):
        """Return the salary date that follows ``last_salary_date``.

        The given date is first advanced to the next nominal pay day (it is
        left alone if it already falls on one), then to the nominal pay day
        after that. If that day is a Saturday or Sunday the result is moved
        back to the preceding Friday.

        Args:
            last_salary_date: A ``datetime.date`` (or ``datetime.datetime``)
                on or shortly before a nominal pay day.
            payment_day: Nominal day of the month on which salary is paid.
                Must be between 1 and 28 so that every month contains it.

        Returns:
            A ``datetime.date``.

        Raises:
            ValueError: If ``payment_day`` is outside 1..28.
        """
        if not 1 <= payment_day <= 28:
            raise ValueError(f"payment_day must be between 1 and 28, got {payment_day}")

        one_day = datetime.timedelta(days=1)

        # Normalise to a plain date; timedelta arithmetic handles month and
        # year roll-over (including December) without special cases.
        current = datetime.date(last_salary_date.year, last_salary_date.month, last_salary_date.day)

        while current.day != payment_day:
            current += one_day

        # Step off the pay day, then walk to the same day of the next month.
        current += one_day

        while current.day != payment_day:
            current += one_day

        # weekday() 5 and 6 are Saturday and Sunday.
        while current.weekday() in (5, 6):
            current -= one_day

        return current


    def handle_shortfall(self, pot, shortfall):
        """Deposit into the pot, or notify, when the balance is below the amount due.

        Decision rules:
          * stdout not a terminal, a ``pot`` configured, ``auto_deposit``
            enabled and not abbreviating: deposit automatically, subject to
            handle_deposit_delay() when ``auto_deposit_delay_mins`` is set;
          * stdout not a terminal otherwise: notify if ``notify_shortfall``
            is enabled and not abbreviating;
          * stdout is a terminal: prompt if a ``pot`` is configured and
            ``prompt_deposit`` is enabled, else notify if
            ``notify_shortfall`` is enabled.

        Args:
            pot: The pot (or account) whose balance was checked.
            shortfall: Amount missing, in pounds.

        Returns:
            True if a deposit was made (the caller should re-sync),
            otherwise False.
        """
        deposit = False
        notify = False
        interactive = sys.stdout.isatty()

        if 'pot' in self.config and self.config.get('auto_deposit') and not interactive and not self.abbreviate:
            if 'auto_deposit_delay_mins' in self.config:
                # Must feed the shortfall into the delay check and store the
                # result in `deposit`, otherwise no deposit is ever made.
                deposit = self.handle_deposit_delay(shortfall)
            else:
                deposit = True

        elif not interactive:
            if self.config.get('notify_shortfall') and not self.abbreviate:
                notify = True
        else:
            if 'pot' in self.config and self.config.get('prompt_deposit'):
                deposit = self.prompt_action("\ndeposit shortfall? [y/N] ")

            elif self.config.get('notify_shortfall'):
                notify = True

        if deposit:
            m = MonzoAPI()

            if not m.deposit_to_pot(self.account.account_id, pot, shortfall):
                sys.stderr.write("ERROR: failed to deposit funds\n")

                if self.config.get('notify_deposit'):
                    self.notify(
                        '%s - pot deposit failed' % (self.account.name),
                        "£%.2f\n£%.2f due, £%.2f available" % (
                            shortfall,
                            self.due / 100,
                            pot.balance
                        )
                    )
            else:
                # The shortfall is resolved, so drop the delay marker.
                if os.path.exists(self.shortfall_tracker):
                    os.remove(self.shortfall_tracker)

                if self.config.get('notify_deposit'):
                    self.notify(
                        '%s - pot topped up' % (self.account.name),
                        "£%.2f\n£%.2f due, £%.2f available" % (
                            shortfall,
                            self.due / 100,
                            pot.balance
                        )
                    )

                return True

        elif notify:
            self.notify(
                '%s - shortfall' % (self.account.name),
                "£%.2f\n£%.2f due, £%.2f available" % (
                    shortfall,
                    self.due / 100,
                    pot.balance
                )
            )

        return False


    def handle_credit(self, pot, credit):
        """Withdraw from the pot, or notify, when the balance exceeds the amount due.

        Decision rules:
          * stdout not a terminal, a ``pot`` configured, ``auto_withdraw``
            enabled and not abbreviating: withdraw automatically, subject to
            handle_withdrawal_delay() when ``auto_withdraw_delay_mins`` is
            set;
          * stdout not a terminal otherwise: notify if ``notify_credit`` is
            enabled and not abbreviating;
          * stdout is a terminal: prompt if a ``pot`` is configured and
            ``prompt_withdraw`` is enabled, else notify if ``notify_credit``
            is enabled.

        Args:
            pot: The pot (or account) whose balance was checked.
            credit: Surplus amount, in pounds.

        Returns:
            True if a withdrawal was made (the caller should re-sync),
            otherwise False.
        """
        withdraw = False
        notify = False
        interactive = sys.stdout.isatty()

        if 'pot' in self.config and self.config.get('auto_withdraw') and not interactive and not self.abbreviate:
            if 'auto_withdraw_delay_mins' in self.config:
                withdraw = self.handle_withdrawal_delay(credit)
            else:
                withdraw = True
        elif not interactive:
            if self.config.get('notify_credit') and not self.abbreviate:
                notify = True
        else:
            if 'pot' in self.config and self.config.get('prompt_withdraw'):
                withdraw = self.prompt_action("\nwithdraw credit? [y/N] ")

            elif self.config.get('notify_credit'):
                notify = True

        if withdraw:
            m = MonzoAPI()

            if not m.withdraw_credit(self.account.account_id, pot, credit):
                sys.stderr.write("ERROR: failed to withdraw credit\n")

                if self.config.get('notify_withdraw'):
                    self.notify(
                        '%s - pot credit withdraw failed' % (self.account.name),
                        "£%.2f\n£%.2f due, £%.2f available" % (
                            credit,
                            self.due / 100,
                            pot.balance
                        )
                    )
            else:
                # The credit is gone, so drop the delay marker. The path
                # lives on the instance, not in a local variable.
                if os.path.exists(self.credit_tracker):
                    os.remove(self.credit_tracker)

                if self.config.get('notify_withdraw'):
                    self.notify(
                        '%s - pot credit withdrawn' % (self.account.name),
                        "£%.2f\n£%.2f due, £%.2f available" % (
                            credit,
                            self.due / 100,
                            pot.balance
                        )
                    )

                return True

        elif notify:
            self.notify(
                '%s - credit' % (self.account.name),
                "£%.2f\n£%.2f due, £%.2f available" % (
                    credit,
                    self.due / 100,
                    pot.balance
                )
            )

        return False


    def handle_withdrawal_delay(self, credit):
        """Decide whether a pending credit has been stable long enough to withdraw.

        The credit (in pence) is recorded in ``self.credit_tracker``. Whenever
        the recorded amount is missing or differs from the current one, the
        file is rewritten, which restarts the delay.

        Args:
            credit: Surplus amount, in pounds.

        Returns:
            True once the same credit has been on record for at least
            ``auto_withdraw_delay_mins`` minutes, otherwise False.
        """
        # round() rather than int(): float error can leave e.g. 0.29 * 100
        # just below 29, which int() would truncate to 28.
        credit_s = str(round(credit * 100))

        recorded = None

        if os.path.exists(self.credit_tracker):
            with open(self.credit_tracker) as f:
                recorded = f.read().rstrip()

        if recorded != credit_s:
            with open(self.credit_tracker, 'w') as f:
                f.write(credit_s)

            return False

        elapsed = time.time() - os.stat(self.credit_tracker).st_mtime

        # elapsed is in seconds; the configured delay is in minutes.
        return elapsed >= self.config['auto_withdraw_delay_mins'] * 60


    def handle_deposit_delay(self, shortfall):
        """Decide whether a pending shortfall has been stable long enough to deposit.

        The shortfall (in pence) is recorded in ``self.shortfall_tracker``.
        Whenever the recorded amount is missing or differs from the current
        one, the file is rewritten, which restarts the delay.

        Args:
            shortfall: Amount missing, in pounds.

        Returns:
            True once the same shortfall has been on record for at least
            ``auto_deposit_delay_mins`` minutes, otherwise False.
        """
        # round() rather than int(): float error can leave e.g. 0.29 * 100
        # just below 29, which int() would truncate to 28.
        shortfall_s = str(round(shortfall * 100))

        recorded = None

        if os.path.exists(self.shortfall_tracker):
            with open(self.shortfall_tracker) as f:
                recorded = f.read().rstrip()

        if recorded != shortfall_s:
            with open(self.shortfall_tracker, 'w') as f:
                f.write(shortfall_s)

            return False

        elapsed = time.time() - os.stat(self.shortfall_tracker).st_mtime

        # elapsed is in seconds; the configured delay is in minutes.
        return elapsed >= self.config['auto_deposit_delay_mins'] * 60


    def sanitise(self, output):
        """Return a copy of ``output`` whose values can be serialised as JSON.

        ``output`` is a list of dicts. In each dict, values whose type is
        exactly ``datetime.date`` become ``YYYY-MM-DD`` strings, exactly
        ``datetime.datetime`` become ``YYYY-MM-DD HH:MM:SS`` strings and
        exactly ``decimal.Decimal`` become floats. Every other value is
        passed through untouched.
        """
        # Keyed on the exact type: datetime.datetime is a subclass of
        # datetime.date, so isinstance() would not tell the two apart.
        converters = {
            datetime.date: lambda value: value.strftime('%Y-%m-%d'),
            datetime.datetime: lambda value: value.strftime('%Y-%m-%d %H:%M:%S'),
            decimal.Decimal: float,
        }

        def convert(value):
            converter = converters.get(type(value))

            return value if converter is None else converter(value)

        return [
            {key: convert(item[key]) for key in item}
            for item in output
        ]


    def check_pot_payments(self):
        """Make the configured monthly transfers into other pots.

        For every entry of the ``payments_to_pots`` config list, the named
        pot is looked up; unless its ``last_monthly_transfer_date`` already
        equals the last salary date, the amount from get_transfer_amount()
        is deposited and the pot's transfer date is updated. Progress is
        written to stdout and failures to stderr.

        Returns:
            True if at least one transfer succeeded (the caller should
            re-sync), otherwise False.
        """
        transfers = self.config.get('payments_to_pots')

        if not isinstance(transfers, list):
            return False

        m = None
        sync_required = False

        for payment in transfers:
            pot = Pot().find_by_name_and_deleted(payment['name'], 0)

            if not pot:
                # A misspelt or deleted pot must not abort the other transfers.
                sys.stderr.write(f"ERROR: pot {payment['name']} not found in the database\n")
                continue

            if pot.last_monthly_transfer_date == self.last_salary_date:
                continue

            amount_to_transfer = self.get_transfer_amount(pot, payment)

            if amount_to_transfer == 0:
                continue

            if not m:
                # Create the API client lazily, only when a transfer is needed.
                m = MonzoAPI()
                sys.stdout.write("\n")

            sys.stdout.write("Transferring £%.2f to pot: %s ... " % (amount_to_transfer / 100, pot.name))
            sys.stdout.flush()

            if m.deposit_to_pot(self.account.account_id, pot, amount_to_transfer / 100):
                sys.stdout.write("OK\n")

                pot.last_monthly_transfer_date = self.last_salary_date
                pot.save()

                sync_required = True
            else:
                sys.stdout.write("FAILED\n\n")

                sys.stderr.write("ERROR: failed to transfer £%.2f to pot: %s\n" % (amount_to_transfer / 100, pot.name))

        return sync_required


    def get_transfer_amount(self, pot, payment):
        """Return the amount, in pence, to transfer into ``pot``.

        ``payment['amount']`` is given in pounds. When ``payment['topup']``
        is truthy only the difference needed to bring the pot's balance up
        to that amount is returned (never less than zero); otherwise the
        full amount is returned.
        """
        target = round(payment['amount'] * 100)

        if not ('topup' in payment and payment['topup']):
            return target

        current = round(pot.balance * 100)

        return max(0, target - current)


# Command-line entry point.
#
#   -o json   print the report as JSON instead of text
#   -a        abbreviate the output
#
# The last remaining argument is the account name (or a path to a config file).
usage = "usage: %s [-o json] [-a] <account_name>" % (sys.argv[0].split('/')[-1])

if len(sys.argv) < 2:
    print(usage)
    sys.exit(1)

output_json = False
account = None
abbreviate = False

i = 1

while i < len(sys.argv):
    arg = sys.argv[i]

    if arg == '-o':
        # -o needs a value and 'json' is the only supported one. Rejecting
        # anything else stops '-o' itself being taken as the account name.
        if i + 1 >= len(sys.argv) or sys.argv[i + 1] != 'json':
            print(usage)
            sys.exit(1)

        output_json = True
        i += 2
        continue

    if arg == '-a':
        abbreviate = True
    else:
        account = arg

    i += 1

if account is None:
    print(usage)
    sys.exit(1)

p = PaymentFundsCheck(account, output_json, abbreviate)
p.main()
