#!python

import os
import sys
import datetime
import yaml
import calendar
import json
import time
import requests
import math
import pwd
import importlib
import decimal
from pushover import Client
from monzo_utils.lib.monzo_api import MonzoAPI
from monzo_utils.lib.config import Config
from monzo_utils.lib.db import DB
from monzo_utils.lib.monzo_sync import MonzoSync
from monzo_utils.model.provider import Provider
from monzo_utils.model.account import Account
from monzo_utils.model.pot import Pot
from monzo_utils.model.transaction import Transaction
from monzo_utils.model.flex_summary import FlexSummary

PROVIDER = 'Monzo'

class PaymentFundsCheck:
    def __init__(self, account_name, output_json, abbreviate=False):
        """Load the account configuration and look up the account record.

        account_name -- an account name, or a path to a config file
                        (resolved by load_config()).
        output_json  -- when true, main() prints a JSON document instead of
                        the text table.
        abbreviate   -- passed through to the payment objects' data() calls
                        and used to suppress automatic actions/notifications.

        Exits with status 1 when the account is not found in the database.
        """
        self.account_name = account_name
        self.json = output_json
        self.abbreviate = abbreviate
        self.output = []

        # load_config() sets self.config / self.config_path and replaces
        # self.account_name with the 'account' value from the config file.
        self.load_config()
        self.validate_config()

        self.db = DB(None, self.config_path)

        self.seen, self.exchange_rates = [], {}

        self.provider = Provider().find_by_name(PROVIDER)
        self.account = Account().find_by_provider_id_and_name(
            self.provider.id,
            self.account_name
        )

        if self.account:
            return

        sys.stderr.write(f"account {self.account_name} not found in the database\n")
        sys.exit(1)


    def load_config(self):
        if os.path.exists(self.account_name):
            account_config_file = self.account_name
            self.config_path = os.path.dirname(account_config_file)
        else:
            homedir = pwd.getpwuid(os.getuid()).pw_dir
            self.config_path = f"{homedir}/.monzo"

            if not os.path.exists(self.config_path):
                os.mkdir(self.config_path, 0o755)

            account_config_file = f"{self.config_path}/{self.account_name}.yaml"

        if not os.path.exists(account_config_file):
            sys.stderr.write(f"Cannot find account config file: {account_config_file}\n")
            sys.exit(1)

        try:
            self.config = yaml.safe_load(open(account_config_file).read())
            self.account_name = self.config['account']
        except Exception as e:
            sys.stderr.write(f"Cannot read or parse account config file {account_config_file}: {str(e)}\n")
            sys.exit(1)


    def validate_config(self):
        """Exit with status 1 if a mandatory config key is missing or empty.

        'payments', 'salary_description' and 'salary_payment_day' are always
        required. 'pushover_key' and 'pushover_app' are additionally required
        when 'notify_shortfall' or 'notify_credit' is enabled.
        """
        settings = self.config

        missing = next(
            (key for key in ('payments', 'salary_description', 'salary_payment_day') if not settings.get(key)),
            None
        )

        if missing is not None:
            sys.stderr.write(f"Missing config key: {missing}\n")
            sys.exit(1)

        if not (settings.get('notify_shortfall') or settings.get('notify_credit')):
            return

        missing = next(
            (key for key in ('pushover_key', 'pushover_app') if not settings.get(key)),
            None
        )

        if missing is not None:
            sys.stderr.write(f"Push is enabled but push config key is missing: {missing}\n")
            sys.exit(1)


    def main(self):
        """Report what is still due and whether the balance covers it.

        Resolves the salary dates, renders every configured payment list
        (display_payment_list() accumulates the running totals), then
        compares the amount due with the balance of the configured pot, or
        of the account itself when no 'pot' key is configured.

        With JSON output enabled a single JSON document is printed and the
        method returns. Otherwise a text summary is printed and
        handle_shortfall() or handle_credit() is invoked as appropriate.

        Exits with status 1 when a pot is configured but cannot be found.
        """
        self.last_salary_date = self.get_last_salary_date()
        self.next_salary_date = self.get_next_salary_date(self.last_salary_date)
        self.following_salary_date = self.get_next_salary_date(self.next_salary_date)

        # Running totals; display_payment_list() adds display_amount * 100 to
        # these, and they are divided by 100 again when printed.
        self.due = 0
        self.total_this_month = 0
        self.next_month = 0
        self.next_month_bills_pot = 0

        if self.json is False:
            self.display_columns()

        for payment_list in self.config['payments']:
            if not payment_list['payments']:
                continue

            self.display_payment_list(payment_list['type'], payment_list['payments'])

        if self.config.get('refunds_due'):
            self.display_payment_list('Refund', self.config['refunds_due'])

        if 'pot' in self.config:
            pot = Pot().find_by_account_id_and_name(self.account.id, self.config['pot'])

            # Same falsy-means-missing convention as the account lookup in
            # __init__; without this the balance access below would fail
            # with an AttributeError.
            if not pot:
                sys.stderr.write(f"pot {self.config['pot']} not found in the database\n")
                sys.exit(1)
        else:
            pot = self.account

        balance = round(pot.balance * 100)
        shortfall = (self.due - balance) / 100
        credit = (balance - self.due) / 100

        if self.json:
            data = {
                'balance': float(pot.balance),
                'due': float(self.due) / 100,
                'total_this_month': self.total_this_month / 100,
                'total_next_month': self.next_month / 100,
                'payments': self.sanitise(self.output)
            }

            if shortfall * 100 <0:
                data['credit'] = credit
                data['shortfall'] = 0
            else:
                data['shortfall'] = shortfall
                data['credit'] = 0

            print(json.dumps(data,indent=4))
            return

        def print_total(label, amount):
            # One right-hand summary line: 25 spaces, padded label, amount.
            sys.stdout.write(" " * 25)
            sys.stdout.write(label.ljust(31))
            print("£%.2f" % (amount))

        exclude_yearly = self.config.get('exclude_yearly_from_bills')

        sys.stdout.write("\n")

        print_total("TOTAL THIS MONTH:", self.total_this_month / 100)

        if exclude_yearly:
            sys.stdout.write("\n")

        print_total("TOTAL NEXT MONTH:", self.next_month / 100)

        if round(shortfall * 100) <0:
            print_total("LESS CREDIT BALANCE:", (self.next_month / 100) - credit)

        if exclude_yearly:
            print_total("Bills pot payment:", self.next_month_bills_pot / 100)

        if round(shortfall * 100) >0:
            print("      due: £%.2f" % (self.due / 100))
            print("  balance: £%.2f" % (pot.balance))
            print("shortfall: £%.2f" % (shortfall))

            self.handle_shortfall(pot, shortfall)
            return

        print("    due: £%.2f" % (self.due / 100))
        print("balance: £%.2f" % (pot.balance))

        # Only mention / act on a credit that is non-zero after rounding.
        if round(credit * 100) != 0:
            print(" credit: £%.2f" % (credit))

            self.handle_credit(pot, credit)


    def display_columns(self):
        print("%s %s %s %s %s %s %s %s" % (
            'Status'.rjust(8),
            'Type'.ljust(15),
            'Name'.ljust(25),
            ''.ljust(4),
            'Amount'.ljust(8),
            'Left'.ljust(8),
            'Last date'.ljust(12),
            'Due date'.ljust(10)
        ))

        print("-" * 97)


    def prompt_action(self, prompt):
        while 1:
            sys.stdout.write(prompt)
            sys.stdout.flush()

            i = sys.stdin.readline().rstrip().lower()

            if i in ['y','n']:
                break

        return i == 'y'

 
    def display_payment_list(self, payment_type, payment_list):
        payments = self.get_payments(payment_type, payment_list)

        if payment_type == 'Flex':
            total_this_month = 0
            total_next_month = 0
            today = datetime.datetime.now()
            today = datetime.date(today.year, today.month, today.day)

            for payment in payments:
                total_this_month += payment.amount_for_date(self.last_salary_date)
                total_next_month += payment.amount_for_date(self.next_salary_date)

            date = self.last_salary_date

            while date.day != self.config['flex_payment_date']:
                date += datetime.timedelta(days=1)

            if date > today:
                flex_status = 'DUE'
            else:
                flex_status = 'PAID'

            flex_remaining = 0

            for payment in payments:
                flex_remaining += payment.remaining

            summary = FlexSummary(self.config, flex_status, total_this_month, total_next_month, flex_remaining)

            if self.json:
                self.output.append(summary.data(self.abbreviate))
            else:
                summary.display()

        for payment in payments:
            if payment.status in ['DUE','PAID']:
                self.total_this_month += payment.display_amount * 100

            if payment.status == 'DUE':
                self.due += payment.display_amount * 100

            if payment.due_next_month:
                if payment.status == 'SKIPPED' or payment_type != 'Refund':
                    self.next_month += payment.display_amount * 100

                    if 'exclude_yearly_from_bills' not in self.config or self.config['exclude_yearly_from_bills'] is False or 'yearly_month' not in payment_config:
                        self.next_month_bills_pot += payment.display_amount * 100

            if self.json:
                self.output.append(payment.data(self.abbreviate))
            else:
                payment.display()


    def get_payments(self, payment_type, payment_list):
        """Build one payment model object per entry in payment_list.

        The model class is resolved from payment_type: the module is
        monzo_utils.model.<lower-cased type, spaces as underscores> and the
        class is the title-cased type with spaces removed.

        Returns the list of constructed objects, in config order.
        """
        def build(payment_config):
            module_name = payment_type.lower().replace(' ','_')
            class_name = payment_type.title().replace(' ','')

            model_module = importlib.import_module(f"monzo_utils.model.{module_name}")
            model_class = getattr(model_module, class_name)

            return model_class(
                self.config,
                payment_list,
                payment_config,
                self.last_salary_date,
                self.next_salary_date,
                self.following_salary_date
            )

        return [build(payment_config) for payment_config in payment_list]


    def notify(self, event, message):
        """Send `message` via Pushover, using `event` as the title.

        Credentials come from the 'pushover_key' and 'pushover_app' config
        keys.
        """
        settings = self.config
        client = Client(settings['pushover_key'], api_token=settings['pushover_app'])

        client.send_message(message, title=event)


    def abbreviate_string(self, string):
        """Return only the upper-case characters of `string`, in order.

        For example 'Direct Debit' becomes 'DD'. A string with no upper-case
        characters yields ''.
        """
        return ''.join(char for char in string if char.isupper())


    def get_last_salary_date(self):
        if 'salary_account' in self.config and self.config['salary_account'] != self.account_name:
            account = Account().find_by_provider_id_and_name(self.provider.id, self.config['salary_account'])
        else:
            account = self.account

        last_salary_transaction = account.last_salary_transaction(
            description=self.config['salary_description'],
            payment_day=self.config['salary_payment_day'],
            salary_minimum=self.config['salary_minimum'] if 'salary_minimum' in self.config else 1000
        )

        if not last_salary_transaction:
            sys.stderr.write("failed to find last salary transaction.\n")
            sys.stderr.write(f"SQL: {sql}\n")
            sys.stderr.write(f"params: {json.dumps(params,indent=4)}\n")
            sys.exit(1)

        last_salary_date = last_salary_transaction['date']

        return last_salary_date


    def get_next_salary_date(self, last_salary_date):
        while last_salary_date.day != 15:
            try:
                last_salary_date = datetime.date(last_salary_date.year, last_salary_date.month, last_salary_date.day+1)
            except:
                last_salary_date = datetime.date(last_salary_date.year, last_salary_date.month+1, 1)

        next_salary_date = datetime.date(last_salary_date.year, last_salary_date.month, last_salary_date.day+1)

        while next_salary_date.day != 15:
            try:
                next_salary_date = datetime.date(next_salary_date.year, next_salary_date.month, next_salary_date.day+1)
            except:
                if next_salary_date.month == 12:
                    next_salary_date = datetime.date(next_salary_date.year+1, 1, 1)
                else:
                    next_salary_date = datetime.date(next_salary_date.year, next_salary_date.month+1, 1)

        while next_salary_date.weekday() in [5,6]:
            next_salary_date = datetime.date(next_salary_date.year, next_salary_date.month, next_salary_date.day-1)

        return next_salary_date


    def handle_shortfall(self, pot, shortfall):
        deposit = False
        notify = False

        if 'pot' in self.config and 'auto_deposit' in self.config and self.config['auto_deposit'] and not sys.stdout.isatty() and not self.abbreviate:
            deposit = True
        elif not sys.stdout.isatty():
            if 'notify_shortfall' in self.config and self.config['notify_shortfall'] and not self.abbreviate:
                notify = True
        else:
            if 'pot' in self.config and 'prompt_deposit' in self.config and self.config['prompt_deposit']:
                deposit = self.prompt_action("\ndeposit shortfall? [y/N] ")

            elif 'notify_shortfall' in self.config and self.config['notify_shortfall']:
                notify = True

        if deposit:
            m = MonzoAPI()

            if not m.deposit_to_pot(self.account.account_id, pot, shortfall):
                sys.stderr.write("ERROR: failed to deposit funds\n")

                if 'notify_deposit' in self.config and self.config['notify_deposit']:
                    self.notify(
                        '%s - pot deposit failed' % (self.account.name),
                        "£%.2f\n£%.2f due, £%.2f available" % (
                            credit,
                            self.due / 100,
                            pot.balance
                        )
                    )
            else:
                ms = MonzoSync()
                ms.sync()

                if 'notify_deposit' in self.config and self.config['notify_deposit']:
                    self.notify(
                        '%s - pot topped up' % (self.account.name),
                        "£%.2f\n£%.2f due, £%.2f available" % (
                            shortfall,
                            self.due / 100,
                            pot.balance
                        )
                    )

        elif notify:
            self.notify(
                '%s - shortfall' % (self.account.name),
                "£%.2f\n£%.2f due, £%.2f available" % (
                    shortfall,
                    self.due / 100,
                    pot.balance
                )
            )


    def handle_credit(self, pot, credit):
        withdraw = False
        notify = False

        if 'pot' in self.config and 'auto_withdraw' in self.config and self.config['auto_withdraw'] and not sys.stdout.isatty() and not self.abbreviate:
            withdraw = True
        elif not sys.stdout.isatty():
            if 'notify_credit' in self.config and self.config['notify_credit'] and not self.abbreviate:
                notify = True
        else:
            if 'pot' in self.config and 'prompt_withdraw' in self.config and self.config['prompt_withdraw']:
                withdraw = self.prompt_action("\nwithdraw credit? [y/N] ")

            elif 'notify_credit' in self.config and self.config['notify_credit']:
                notify = True

        if withdraw:
            m = MonzoAPI()

            if not m.withdraw_credit(self.account.account_id, pot, credit):
                sys.stderr.write("ERROR: failed to withdraw credit\n")

                if 'notify_withdraw' in self.config and self.config['notify_withdraw']:
                    self.notify(
                        '%s - pot credit withdraw failed' % (self.account.name),
                        "£%.2f\n£%.2f due, £%.2f available" % (
                            credit,
                            self.due / 100,
                            pot.balance
                        )
                    )
            else:
                ms = MonzoSync()
                ms.sync()

                if 'notify_withdraw' in self.config and self.config['notify_withdraw']:
                    self.notify(
                        '%s - pot credit withdrawn' % (self.account.name),
                        "£%.2f\n£%.2f due, £%.2f available" % (
                            credit,
                            self.due / 100,
                            pot.balance
                        )
                    )

        elif notify:
            self.notify(
                '%s - credit' % (self.account.name),
                "£%.2f\n£%.2f due, £%.2f available" % (
                    credit,
                    self.due / 100,
                    pot.balance
                )
            )


    def sanitise(self, output):
        new_output = []

        for item in output:
            obj = {}
            
            for key in item:
                if type(item[key]) == datetime.date:
                    obj[key] = item[key].strftime('%Y-%m-%d')
                elif type(item[key]) == datetime.datetime:
                    obj[key] = item[key].strftime('%Y-%m-%d %H:%M:%S')
                elif type(item[key]) == decimal.Decimal:
                    obj[key] = float(item[key])
                else:
                    obj[key] = item[key]

            new_output.append(obj)

        return new_output


# Command line: [-o json] [-a] <account_name>
#   -o json  print a JSON document instead of the text table
#   -a       abbreviate mode
# The same usage text is used for both argument errors.
usage = "usage: %s [-o json] [-a] <account_name>" % (sys.argv[0].split('/')[-1])

if len(sys.argv) <2:
    print(usage)
    sys.exit(1)

output_json = False
account = None
abbreviate = False

for i in range(1, len(sys.argv)):
    arg = sys.argv[i]

    if arg == '-o':
        # Only 'json' is recognised as a format; the flag itself is never
        # taken as the account name.
        if i+1 < len(sys.argv) and sys.argv[i+1] == 'json':
            output_json = True
        continue

    if arg == '-a':
        abbreviate = True
        continue

    # Skip the value that belongs to a preceding -o.
    if sys.argv[i-1] == '-o':
        continue

    account = arg

if account is None:
    print(usage)
    sys.exit(1)

p = PaymentFundsCheck(account, output_json, abbreviate)
p.main()
