#!python

import boto3
from botocore.exceptions import ClientError
import argparse
import sys
import shlex
import re
import os
import signal
import configparser

###############################################################################
#
# Constants
###############################################################################
# Version of this script.
__version__ = '0.0.3'

# Location of the MFA config file; the expanded path is the default value of
# the --config-file option.
DEFAULT_MFA_CONFIG_FILE_DIR = "~/.aws/"
DEFAULT_MFA_CONFIG_FILE_NAME = "mfa-config"
DEFAULT_MFA_CONFIG_FILE_PATH = os.path.expanduser(os.path.join(DEFAULT_MFA_CONFIG_FILE_DIR,
                                DEFAULT_MFA_CONFIG_FILE_NAME))
# Session lifetime in seconds (36000 s = 10 hours); sent to STS as
# DurationSeconds when a profile does not set mfa_duration.
DEFAULT_MFA_DURATION = 36000
# Credentials file that main() passes to update_config() to be rewritten.
AWS_CREDENTIALS_FILE_PATH = os.path.expanduser("~/.aws/credentials")


################################################################################
#
# Functions
################################################################################

# Handle signal
def signal_handler(signal, frame):
    """Report a received SIGINT and terminate the process with exit status 1.

    The two arguments are the ones the ``signal`` module passes to every
    handler; neither is used here.
    """
    print("Error: SIGINT received.")
    # sys.exit(1) is exactly this: raising SystemExit with the exit status.
    raise SystemExit(1)

# Input Validators
def arg_config_file(arg_string):
    """argparse ``type`` callback: accept only a path to an existing file.

    Environment variables and a leading ``~`` are expanded for the existence
    check only; the string is returned exactly as it was given.

    Raises:
        argparse.ArgumentTypeError: if the expanded path is not a regular file.
    """
    resolved_path = os.path.expanduser(os.path.expandvars(arg_string))
    if os.path.isfile(resolved_path):
        return arg_string
    raise argparse.ArgumentTypeError("Config file not found: {0}".format(resolved_path))

def arg_profile(arg_string):
    """argparse ``type`` callback for the profile name.

    No validation rule is implemented yet, so every name is accepted and
    returned unchanged. The rejection path is kept as a placeholder for a
    future check.
    """
    is_valid = True  # placeholder: replace with a real check on arg_string
    if is_valid:
        return arg_string
    raise argparse.ArgumentTypeError(
            "Not a valid profile name '{0}'.".format(arg_string))

def validate_token_code(arg_string):
    """Return ``arg_string`` unchanged if it is a six-digit MFA token code.

    The entire string must be exactly six ASCII digits.

    Raises:
        ValueError: if ``arg_string`` is anything other than six digits.
    """
    # fullmatch anchors both ends; re.match only anchors the start and would
    # let longer input such as "1234567" or "123456abc" through. [0-9] is used
    # instead of \d because \d also matches non-ASCII Unicode digits.
    if not re.fullmatch(r"[0-9]{6}", arg_string):
        raise ValueError(
                "token_code must be 6 digits: '{0}'.".format(arg_string))
    return arg_string


def parse_args():
    """Build the command-line parser and return the parsed options.

    The returned namespace carries ``debug`` (bool), ``config_file``
    (validated by ``arg_config_file``) and ``profile`` (validated by
    ``arg_profile``).
    """
    cli = argparse.ArgumentParser(
        usage="%(prog)s [options]",
        description="updates aws credentials file with temporary sts credentials obtained with mfa",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    cli.add_argument(
        '-d', '--debug',
        dest='debug',
        action='store_true',
        default=False,
        help='Enable debug',
    )
    cli.add_argument(
        '-c', '--config-file',
        dest='config_file',
        action="store",
        type=arg_config_file,
        default=DEFAULT_MFA_CONFIG_FILE_PATH,
        help='config file to load mfa details [~/.aws/mfa-config]',
    )
    cli.add_argument(
        '-p', '--profile',
        dest='profile',
        action="store",
        type=arg_profile,
        default="default",
        help='profile to be loaded from the config file [default]',
    )
    return cli.parse_args()


def read_config(config_file_path):
    """Load the MFA config file and return its usable profiles.

    Args:
        config_file_path: path to the config file; environment variables and
            a leading ``~`` are expanded.

    Returns:
        dict mapping profile name to its configparser section. A section
        named ``profile <name>`` is stored under ``<name>``; the ``default``
        section is stored under ``default``. ``profile`` sections that lack
        ``mfa_serial`` or ``dest_profile``, or whose header does not split
        into exactly two words, are skipped.

    Raises:
        ValueError: if the file does not exist or contains no sections.
    """
    config_file_path = os.path.expanduser(os.path.expandvars(config_file_path))
    if not os.path.isfile(config_file_path):
        raise ValueError("Config file not found: {0}".format(config_file_path))
    config = configparser.ConfigParser()
    config.read(config_file_path)
    if not config.sections():
        raise ValueError("Error parsing config file: {0}".format(config_file_path))

    parsed_config = {}
    for section in config.sections():
        if section == "default":
            # default is a special section in boto3 that doesn't require
            # "profile" as a prefix, so it is accepted here as-is.
            # https://github.com/boto/botocore/blob/c6ebb3be3acc946e3b706333294320dc7e304dd7/botocore/configloader.py#L265
            parsed_config[section] = config[section]
            continue
        if not section.startswith("profile"):
            continue
        options = config[section]
        # .get() returns None for an absent option; indexing would raise
        # KeyError and abort the whole load instead of skipping the section.
        if not options.get("mfa_serial") or not options.get("dest_profile"):
            continue
        try:
            # shlex handles quoted names such as: [profile "my prod"]
            split_section = shlex.split(section)
        except ValueError:
            # Unbalanced quotes in the header: ignore the section.
            continue
        if len(split_section) == 2:
            parsed_config[split_section[1]] = options
    return parsed_config


def read_credentials(aws_credentials_file):
    """Parse the credentials file and return the ``ConfigParser`` holding it.

    ``ConfigParser.read`` ignores files it cannot open, so a missing file
    yields a parser with no sections.
    """
    credentials_parser = configparser.ConfigParser()
    credentials_parser.read(aws_credentials_file)
    return credentials_parser


def update_config(aws_credentials_file, dest_profile, credential_dict):
    """Store temporary credentials under ``dest_profile`` in the credentials file.

    Args:
        aws_credentials_file: path of the credentials file to rewrite.
        dest_profile: section that receives the credentials; it is created
            when missing, and any other options already in it are kept.
        credential_dict: mapping with the keys ``AccessKeyId``,
            ``SecretAccessKey`` and ``SessionToken``.

    Raises:
        KeyError: if ``credential_dict`` lacks one of the three keys.
        OSError: if the credentials file cannot be written.
    """
    credentials_config = read_credentials(aws_credentials_file)
    if dest_profile not in credentials_config.sections():
        credentials_config[dest_profile] = {}
    section = credentials_config[dest_profile]
    section["aws_access_key_id"] = credential_dict["AccessKeyId"]
    section["aws_secret_access_key"] = credential_dict["SecretAccessKey"]
    section["aws_session_token"] = credential_dict["SessionToken"]
    # The file holds secrets: if it has to be created, create it readable by
    # the owner only. The mode is ignored when the file already exists, so
    # existing permissions are left untouched.
    fd = os.open(aws_credentials_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'w') as credfile:
        credentials_config.write(credfile)
    return


def get_credential_dict(source_profile, mfa_serial, region, token_code, mfa_duration):
    """Request temporary session credentials from STS using an MFA token.

    Args:
        source_profile: boto3 profile used to sign the request; when falsy,
            boto3's own default credential resolution applies.
        mfa_serial: serial number / ARN of the MFA device.
        region: region for the STS client; when falsy, boto3's default applies.
        token_code: current code shown by the MFA device.
        mfa_duration: session lifetime in seconds as an int or numeric string;
            when falsy, ``DEFAULT_MFA_DURATION`` is used.

    Returns:
        The ``Credentials`` entry of the STS response; ``update_config`` reads
        its ``AccessKeyId``, ``SecretAccessKey`` and ``SessionToken`` keys.

    Raises:
        ValueError: if ``mfa_duration`` is not an integer, or if the STS call
            fails for any reason other than ``AccessDenied``.
        SystemExit: with status 1 when STS answers ``AccessDenied``.
    """
    conn_args = {}
    if source_profile:
        conn_args["profile_name"] = source_profile
    if region:
        conn_args["region_name"] = region
    # Values read through configparser are strings, but STS validates
    # DurationSeconds as an integer, so convert before sending.
    try:
        duration_seconds = int(mfa_duration) if mfa_duration else DEFAULT_MFA_DURATION
    except (TypeError, ValueError) as err:
        raise ValueError(
            "mfa_duration must be an integer number of seconds: '{0}'.".format(mfa_duration)
        ) from err
    session = boto3.session.Session(**conn_args)
    sts_client = session.client("sts")
    try:
        response = sts_client.get_session_token(
            DurationSeconds=duration_seconds,
            SerialNumber=mfa_serial,
            TokenCode=token_code,
        )
    except ClientError as err:
        if err.response.get('Error', {}).get('Code') == 'AccessDenied':
            print("Invalid Token Code. Exiting")
            sys.exit(1)
        # Any other service error used to be swallowed here and surfaced later
        # as an AttributeError on None; report it like every other failure.
        raise ValueError("{0}".format(err)) from err
    except Exception as err:
        raise ValueError("{0}".format(err)) from err
    return response.get('Credentials')


def get_token_code_from_user(mfa_serial):
    """Prompt on stdin for the MFA token code of ``mfa_serial``.

    The entered text is passed through ``validate_token_code``, whose result
    is returned and whose ``ValueError`` propagates on invalid input.
    """
    prompt = "token code for {0}: ".format(mfa_serial)
    entered_code = input(prompt)
    return validate_token_code(entered_code)


################################################################################
#
# Main
################################################################################

def main():
    """Refresh the destination profile with MFA-backed temporary credentials.

    Steps: parse the command line, load the MFA config, prompt for the token
    code, request session credentials from STS, and write them to the AWS
    credentials file under the profile's ``dest_profile``.

    Exits with status 1 and a message on stderr when the selected profile is
    not in the config file or lacks ``mfa_serial`` / ``dest_profile``.
    """
    options = parse_args()
    config = read_config(options.config_file)

    profile = config.get(options.profile)
    if profile is None:
        sys.exit("Error: profile '{0}' not found in config file: {1}".format(
            options.profile, options.config_file))
    # Check both required options before prompting, so a missing dest_profile
    # is reported up front rather than after the token has been spent on STS.
    for required in ("mfa_serial", "dest_profile"):
        if not profile.get(required):
            sys.exit("Error: no {0} defined for profile '{1}'.".format(
                required, options.profile))

    token_code = get_token_code_from_user(profile["mfa_serial"])
    credential_dict = get_credential_dict(
        # Without an explicit source_profile, sign with the selected profile.
        source_profile=profile.get("source_profile", options.profile),
        mfa_serial=profile["mfa_serial"],
        mfa_duration=profile.get("mfa_duration", DEFAULT_MFA_DURATION),
        region=profile.get("region", None),
        token_code=token_code,
    )
    update_config(
        dest_profile=profile["dest_profile"],
        credential_dict=credential_dict,
        aws_credentials_file=AWS_CREDENTIALS_FILE_PATH,
    )

# Script entry point: runs only when the file is executed directly.
if __name__ == '__main__':
    # Route SIGINT to signal_handler, which prints an error and exits with
    # status 1.
    signal.signal(signal.SIGINT, signal_handler)
    main()
    # Reaching this line means main() returned without raising.
    sys.exit(0)

