#!/usr/bin/env python
import argparse
import configparser
import os
import logging
import subprocess
import json
import dateutil.parser
import datetime
import pytz


# Directory for this tool's files, located in the user's home directory.
CONFIG_DIR = os.path.join(os.path.expanduser('~'), '.aws-mfa-auth')
# Config file; sections are named after profiles.
CONFIG_FILE = os.path.join(CONFIG_DIR, 'config.cnf')
# JSON file caching the credentials returned by the last authentication.
SESSION_FILE = os.path.join(CONFIG_DIR, 'session.json')
# Name of the config option that holds the MFA device ARN of a profile.
MFA_ARN_OPTION = 'mfa_device_arn'


def _args():
    """ Parse arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--configure',
        help='Configure the tool.',
        default=False,
        action='store_true',
    )
    parser.add_argument(
        '--profile',
        help='Name of AWS profile to authenticate against.',
        default=os.environ.get('AWS_PROFILE', 'default'),
    )
    parser.add_argument(
        '--token',
        help='AWS MFA Auth Token',
    )

    return parser.parse_args()


def _update_config(args, conf_path):
    """ Update the config."""
    pass


def _get_conf(path, section, option):
    """ Retrieve a specific config option from a file."""
    config = configparser.ConfigParser()
    config.read(path)
    try:
        mfa_arn = config.get(section, option)
    except configparser.NoSectionError:
        return None
    except configparser.NoOptionError:
        return None

    return mfa_arn


def _set_conf(path, section, option, value):
    """ Set a specific config item."""
    config = configparser.ConfigParser()
    config.read(path)

    if section not in config:
        logging.info('Section [{}] does not exist. Adding it.'.format(section))
        config.add_section(section)

    config.set(section, option, value)

    with open(path, 'w') as config_file:
        config.write(config_file)


def _configure(args):
    """ Interactively store the MFA device ARN of a profile.

    Prompts for the profile (an empty answer keeps ``args.profile``) and for
    the MFA device ARN (an empty answer keeps the value currently stored in
    CONFIG_FILE), then saves the ARN under the profile's section.

    :param args: Parsed arguments; only ``profile`` is read.
    """
    profile_prompt = 'Profile to Configure ({}): '.format(args.profile)
    profile = input(profile_prompt) or args.profile

    stored_arn = _get_conf(CONFIG_FILE, profile, MFA_ARN_OPTION)
    arn_prompt = 'MFA Device ARN ({}): '.format(stored_arn)
    mfa_arn = input(arn_prompt) or stored_arn

    if mfa_arn:
        _set_conf(CONFIG_FILE, profile, MFA_ARN_OPTION, mfa_arn)
        print('Configuration updated.')
    else:
        # Nothing entered and nothing stored: there is no value to save.
        print('Please provide a valid ARN for your MFA device.')


def _set_auth_envs(session):
    """ Print credentials."""
    key_id = session.get('AccessKeyId')
    secret = session.get('SecretAccessKey')
    token = session.get('SessionToken')

    print(
        'export AWS_ACCESS_KEY_ID={key_id} && '
        'export AWS_SECRET_ACCESS_KEY={secret} &&'
        'export AWS_SECURITY_TOKEN={token}'.format(
            key_id=key_id,
            secret=secret,
            token=token,
        )
    )


def _handle_mfa_auth(args):
    """ Handle MFA auth.

    With ``args.token`` set, temporary credentials are requested through
    ``aws sts get-session-token``, cached in SESSION_FILE and printed as
    export commands. Otherwise (or when the response carries no
    credentials) the cached session is printed, provided it has not
    expired yet.

    :param args: Parsed arguments; ``profile`` and ``token`` are read.
    """
    if args.token:
        mfa_arn = _get_conf(CONFIG_FILE, args.profile, MFA_ARN_OPTION)
        if not mfa_arn:
            # A None entry in the argument list would make subprocess raise
            # TypeError, so stop early with an actionable message instead.
            print(
                'No MFA device configured for profile {}, '
                'run again with --configure argument.'.format(args.profile)
            )
            return

        try:
            out = subprocess.check_output([
                'aws', 'sts', 'get-session-token',
                '--serial-number', mfa_arn,
                '--token-code', args.token,
                '--profile', args.profile
            ])
        except (subprocess.CalledProcessError, OSError):
            # OSError covers an `aws` executable that is missing or cannot
            # be started.
            logging.exception('Unable to authenticate.')
            return

        session = json.loads(out).get('Credentials')
        if session:
            # The cache holds secrets: create it readable by the owner only
            # (the mode is applied when the file is created).
            os.makedirs(CONFIG_DIR, exist_ok=True)
            fd = os.open(
                SESSION_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, 'w') as session_file:
                json.dump(session, session_file)

            _set_auth_envs(session)
            return

    # No fresh credentials: fall back to the cached session.
    try:
        with open(SESSION_FILE, 'r') as session_file:
            session = json.load(session_file)
    except (FileNotFoundError, ValueError):
        # Missing cache, or content that is not valid JSON
        # (json.JSONDecodeError is a ValueError).
        session = None

    session_exp = session.get('Expiration') if isinstance(session, dict) else None
    if session_exp:
        exp_date = dateutil.parser.parse(session_exp)
        if exp_date > datetime.datetime.now(pytz.UTC):
            _set_auth_envs(session)
            return

    print('Session expired, run again with --token argument.')


def main():
    """ Parse the arguments and run the requested action.

    ``--configure`` starts the interactive configuration; otherwise the MFA
    authentication is handled.
    """
    args = _args()
    action = _configure if args.configure else _handle_mfa_auth
    action(args)


# Run main() only when the file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
