#!/usr/bin/env python3

import json
import os
import subprocess
from argparse import ArgumentParser
from collections import ChainMap
from configparser import ConfigParser, SectionProxy
from pathlib import Path


def main(vault_addr, vault_login_kwargs, extra_vl_flags,
         login_data_json,
         consul_template_hcl, consul_template_flag, extra_ct_flags,
         ):
    """Make sure a Vault session exists, then hand over to consul-template.

    :param vault_addr: exported as ``VAULT_ADDR`` when truthy.
    :param vault_login_kwargs: keyword arguments forwarded to ``vault_login``.
    :param extra_vl_flags: extra flags forwarded to ``vault_login``.
    :param login_data_json: string exported as ``VAULTAWS_LOGINDATA_JSON``.
    :param consul_template_hcl: value passed after ``-config``.
    :param consul_template_flag: value passed after ``-template``.
    :param extra_ct_flags: extra arguments appended to the consul-template
        command line.

    On success this function does not return: the current process is
    replaced by ``consul-template``, which inherits ``os.environ``.
    """
    env = os.environ
    if vault_addr:
        # Set before any vault invocation so the child processes see it.
        env['VAULT_ADDR'] = vault_addr

    already_logged_in = is_vault_logged_in()
    if not already_logged_in:
        vault_login(extra_vl_flags=extra_vl_flags, **vault_login_kwargs)

    # Exported only after the login step, right before the exec.
    env['VAULTAWS_LOGINDATA_JSON'] = login_data_json

    command = [
        'consul-template',
        '-config', consul_template_hcl,
        '-template', consul_template_flag,
    ]
    command.extend(extra_ct_flags)
    os.execvp(command[0], command)


def gen_login_data(
        profile, config,
        login_template_section='login_template',
        login_override_section='login_override {}'):
    """Yield one login record per id listed in ``profile['login_ids']``.

    ``login_ids`` is a comma-separated list. Whitespace around each id is
    ignored and empty entries (e.g. from a trailing comma) are skipped.

    For every id the options ``aws_profile_name`` and ``vault_secret_path``
    are read from the ``login_template_section`` of ``config``. While they
    are read, ``login_id`` is available for ``%(login_id)s`` interpolation,
    and the options of the section named
    ``login_override_section.format(login_id)`` (when it exists) take
    precedence over the template section.

    :param profile: mapping providing the ``login_ids`` option.
    :param config: ``ConfigParser`` holding the template/override sections.
    :param login_template_section: name of the template section.
    :param login_override_section: format string producing the name of the
        per-id override section.
    :returns: generator of dicts with the keys ``aws_profile_name``
        (falls back to the login id when not configured) and
        ``vault_secret_path`` (``None`` when not configured).
    :raises KeyError: when ``login_ids`` or the template section is missing.
    """
    login_template = config[login_template_section]

    # "dev, prod" must produce 'prod', not ' prod': the id is used verbatim
    # in the override section name and as an interpolation value.
    login_ids = [
        raw_id.strip()
        for raw_id in profile['login_ids'].split(',')]

    for login_id in login_ids:
        if not login_id:
            continue  # stray/trailing comma

        override_vars = dict(login_id=login_id)
        override_name = login_override_section.format(login_id)
        if override_name in config:
            # login_id stays first so an override section cannot shadow it.
            override_vars = ChainMap(override_vars, config[override_name])

        aws_profile_name = login_template.get('aws_profile_name',
                                              vars=override_vars,
                                              fallback=login_id)
        vault_secret_path = login_template.get('vault_secret_path',
                                               vars=override_vars)
        yield {
            'aws_profile_name': aws_profile_name,
            'vault_secret_path': vault_secret_path}


def is_vault_logged_in():
    """Return True when ``vault token lookup`` exits with status 0.

    The command's standard output is discarded; standard error is left
    attached to the caller's stream.
    """
    lookup_cmd = ['vault', 'token', 'lookup']
    exit_status = subprocess.call(lookup_cmd, stdout=subprocess.DEVNULL)
    return exit_status == 0


def vault_login(method=None, extra_vl_flags=(), **kwargs):
    """Run ``vault login`` and raise if it fails.

    :param method: when truthy, passed as ``-method=<method>``.
    :param extra_vl_flags: iterable of extra arguments placed after the
        optional ``-method`` flag.
    :param kwargs: each item is appended as a ``key=value`` argument.
    :raises subprocess.CalledProcessError: when the command exits non-zero.
    """
    argv = ['vault', 'login']
    if method:
        argv.append(f'-method={method}')
    argv.extend(extra_vl_flags)
    # key=value pairs go last, after all flags.
    argv.extend(f'{name}={val}' for name, val in kwargs.items())
    subprocess.check_call(argv)


def gen_config(cli_profile_section, cli_config, cli_login,
               login_template_section='login_template'):
    """Build the ``ConfigParser`` used by the rest of the script.

    The parser is seeded with default file locations under the user's home
    directory and a default ``template_cmd``, then ``cli_config`` is read on
    top of them (``ConfigParser.read`` silently skips a missing file).

    :param cli_profile_section: profile section name; created when absent.
    :param cli_config: path of the configuration file to read.
    :param cli_login: when truthy, stored as ``login_ids`` in the profile
        section, replacing any value from the file.
    :param login_template_section: section that is created when absent.
    :returns: the populated ``ConfigParser``.
    """
    def default_path(location):
        # ConfigParser values must be strings (Path objects are only
        # converted implicitly on Python >= 3.7), and every value read back
        # is interpolated, so a literal '%' in the home directory has to be
        # escaped as '%%'.
        resolved = Path(location).expanduser().absolute()
        return str(resolved).replace('%', '%%')

    config = ConfigParser(dict(
        credentials_file=default_path('~/.aws/credentials'),
        consul_template_hcl=default_path('~/.vault-aws/credentials.hcl'),
        template_source=default_path('~/.vault-aws/credentials.tpl'),
        template_dest=default_path('~/.aws/vaultprovided-credentials'),
        template_cmd='%(aws_credentials_merge_path)s '
                     '-o %(credentials_file)s '
                     '-i %(credentials_file)s %(template_dest)s',
        aws_credentials_merge_path='aws_credentials_merge',))
    config.read(cli_config)

    # Make sure both sections exist so later lookups fall through to the
    # defaults instead of raising for a missing section.
    for section in (login_template_section, cli_profile_section):
        if section not in config:
            config.add_section(section)
    if cli_login:
        config.set(cli_profile_section, 'login_ids', value=cli_login)

    return config


def gen_vault_login_kwargs(profile, key_prefix='vault_login'):
    """Collect the ``<key_prefix>_<name>`` options of ``profile``.

    :param profile: mapping of option names to values.
    :param key_prefix: prefix (without the trailing underscore) that marks
        an option as a vault login argument.
    :returns: dict mapping ``<name>`` to the option value. Options that only
        share the prefix text without the ``_`` separator, or that have
        nothing after the separator, are ignored.
    """
    # Match on the prefix *including* its separator: matching the bare
    # prefix while trimming one extra character turns a key equal to the
    # prefix into '' and mangles keys such as 'vault_loginx_y'.
    full_prefix = key_prefix + '_'
    offset = len(full_prefix)
    return {
        key[offset:]: value
        for key, value
        in profile.items()
        if key.startswith(full_prefix) and len(key) > offset}


def gen_consul_template_flag(profile: SectionProxy):
    """Build the colon-separated value for consul-template's ``-template``.

    The result is ``template_source:template_dest[:template_cmd]``; any of
    the three parts that is empty or missing (only ``template_cmd`` may be
    missing) is left out.

    :raises KeyError: when ``template_source`` or ``template_dest`` is not
        available in ``profile``.
    """
    parts = [profile['template_source'], profile['template_dest']]
    parts.append(profile.get('template_cmd', fallback=None))
    return ':'.join(part for part in parts if part)


def parse_context():
    """Parse the command line and derive the keyword arguments for ``main``.

    Options: ``-p/--profile`` (default ``default``), ``-c/--config``
    (default ``~/.vault-aws/config``) and ``-l/--login``.

    ``extra_ct_flags`` and ``extra_vl_flags`` are read from the profile as
    JSON and default to an empty list.

    :returns: dict whose keys match the parameters of ``main``.
    """
    cli = ArgumentParser()
    cli.add_argument('-p', '--profile', default='default')
    cli.add_argument('-c', '--config', default='~/.vault-aws/config')
    cli.add_argument('-l', '--login', default=None)
    options = cli.parse_args()

    config_path = Path(options.config).expanduser().absolute()
    config = gen_config(options.profile, config_path, options.login)
    profile = config[options.profile]

    # Entries are evaluated top to bottom, so lookups happen in a fixed order.
    context = {
        'vault_addr': profile.get('vault_addr', fallback=None),
        'vault_login_kwargs': gen_vault_login_kwargs(profile),
        'consul_template_hcl': profile['consul_template_hcl'],
        'consul_template_flag': gen_consul_template_flag(profile),
        'login_data_json': json.dumps(list(gen_login_data(profile, config))),
    }
    for flags_option in ('extra_ct_flags', 'extra_vl_flags'):
        raw_flags = profile.get(flags_option, fallback='[]')
        context[flags_option] = json.loads(raw_flags)
    return context


# Script entry point: parse_context() returns a dict keyed by main()'s
# parameter names, so it is unpacked straight into the call.
if __name__ == '__main__':
    main(**(parse_context()))
