#!/usr/bin/env python
"""
A Puppet ENC which assigns Nodes based on their AWS EC2 metadata.

For more information about developing a custom ENC, see:
https://puppet.com/docs/puppet/latest/nodes_external.html

The Puppet Master will need the following IAM Policy applied so that it can query
the EC2 API for Instance metadata:

    {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": ["ec2:DescribeInstances"],
                "Resource": "*"
            }
        ]
    }

"""
from __future__ import print_function

import sys

# Version of this ENC script.
__version__ = '1.0.0'

# Fallback values used when the environment or the Instance tags do not
# provide one.
DEFAULT_AWS_REGION = 'us-east-1'    # used when AWS_DEFAULT_REGION is unset or empty
DEFAULT_ENVIRONMENT = 'production'  # used when the Environment tag is missing or empty
DEFAULT_ROLE = 'default'            # used when the Role tag is missing or empty

# Names of the EC2 Instance tags read during classification.
ROLE_TAG = 'Role'
ENVIRONMENT_TAG = 'Environment'

def die(msg=None, status=1):
    """
    Report an error as a YAML document on stderr, then terminate.

    msg is the text to report; when it is None or empty a generic message is
    used instead. status is the exit code handed to sys.exit().
    """

    if not msg:
        msg = 'An unknown error occurred.'
    yaml_print({'error': msg}, file=sys.stderr)
    sys.exit(status)

def yaml_print(data, file=None):
    """
    Write data to a stream as a block-style YAML document.

    The document is preceded by an explicit '---' marker and None values are
    rendered as empty scalars rather than 'null'.

    data is any structure PyYAML's SafeDumper can represent. file is the
    stream to write to; when omitted, sys.stdout is looked up at call time so
    that a replaced or redirected sys.stdout is honoured.
    """
    from yaml import SafeDumper, dump

    class _BlankNullDumper(SafeDumper):
        """
        SafeDumper variant private to this function, so that the custom None
        representer does not leak into other users of yaml.safe_dump.
        """

    # print '' instead of 'null' for None values
    _BlankNullDumper.add_representer(
        type(None),
        lambda dumper, value: dumper.represent_scalar(u'tag:yaml.org,2002:null', '')
    )

    if file is None:
        file = sys.stdout

    print('---', file=file)
    dump(data,
         stream=file,
         Dumper=_BlankNullDumper,
         default_flow_style=False,
         canonical=False)

def get_one_instance(value, key='instance-id'):
    """
    Look up exactly one running EC2 Instance by a key/value filter.

    key is the name of the describe_instances filter and value the value it
    must match. The region comes from the AWS_DEFAULT_REGION environment
    variable, falling back to DEFAULT_AWS_REGION.

    Returns the matching Instance as returned by the EC2 API. Calls die() when
    no Instance, or more than one Instance, matches.
    """

    import os
    import boto3

    region_name = os.environ.get('AWS_DEFAULT_REGION') or DEFAULT_AWS_REGION
    filters = [
        {'Name': key, 'Values': [value]},
        # only Instances that are currently running are considered
        {'Name': 'instance-state-name', 'Values': ['running']},
    ]
    response = boto3.client('ec2', region_name=region_name).describe_instances(
        Filters=filters)

    # flatten the Instances of every Reservation into a single list
    matches = [
        candidate
        for reservation in response['Reservations']
        for candidate in reservation['Instances']
    ]

    if len(matches) > 1:
        die('More than one EC2 Instance found in %s where %s=%s' % (region_name, key, value))
    found = matches[-1] if matches else None
    if not found:
        die('No EC2 Instance found in %s where %s=%s' % (region_name, key, value))
    return found

def get_tag(key, instance):
    """
    Returns the value of an EC2 Instance Tag or None if it does not exist.

    key is the tag name to look for and instance is an Instance mapping whose
    optional 'Tags' entry is a list of {'Key': ..., 'Value': ...} mappings.
    An Instance without a 'Tags' entry is treated as having no tags.
    """

    # 'Tags' may be absent (classify() guards for that as well), so fall back
    # to an empty sequence instead of raising KeyError.
    for tag in instance.get('Tags') or ():
        if tag['Key'] == key:
            return tag['Value']
    return None

def tag_to_tsv(tag):
    """
    Turn one EC2 tag into a single-entry dict holding a Puppet top scope
    variable.

    The variable name is 'ec2_tag_' followed by the lower-cased tag key, with
    every run of characters other than a-z and 0-9 replaced by one underscore.
    The variable value is the tag value converted with str().

    Inspired by https://github.com/BIAndrews/ec2tagfacts
    """

    import re

    lowered = tag['Key'].lower()
    variable = re.sub(r'[^a-z0-9]+', '_', 'ec2_tag_%s' % lowered)
    # collapse any repeated underscores that might remain
    variable = re.sub(r'_{2,}', '_', variable)

    result = {}
    result[variable] = str(tag['Value'])
    return result

def classify(instance):
    """
    Emit the Puppet Node classification for an EC2 Instance as YAML on stdout.

    The environment and role come from the Instance's Environment and Role
    tags, falling back to DEFAULT_ENVIRONMENT and DEFAULT_ROLE. The single
    class 'role::<role>' is assigned. Every tag is exposed twice in the
    parameters: as its own 'ec2_tag_*' top scope variable and as an entry of
    the 'ec2_tags' hash keyed by the original tag name.
    """

    environment = get_tag(ENVIRONMENT_TAG, instance) or DEFAULT_ENVIRONMENT
    role = get_tag(ROLE_TAG, instance) or DEFAULT_ROLE

    parameters = {'ec2_tags': {}}
    for tag in instance.get('Tags', []):
        parameters.update(tag_to_tsv(tag))
        parameters['ec2_tags'][tag['Key']] = tag['Value']

    yaml_print({
        'environment': environment,
        # the class takes no parameters, hence the None value
        'classes': {'role::%s' % role: None},
        'parameters': parameters,
    })

def usage(status=0):
    """
    Show the command line synopsis and terminate with the given status.

    The synopsis goes to stdout when status is 0 and to stderr otherwise.
    """

    if status == 0:
        stream = sys.stdout
    else:
        stream = sys.stderr
    synopsis = 'usage: %s instance-id' % sys.argv[0]
    print(synopsis, file=stream)
    sys.exit(status)

def main():
    """
    Classify the EC2 Instance named by the single command line argument.

    The argument selects the EC2 filter used for the lookup: a value starting
    with 'ip-' is matched against the private DNS name, one starting with
    'i-' against the Instance ID, and anything else against the Name tag.
    Prefix matching is case-insensitive. Any other number of arguments prints
    the usage message and exits with status 1.
    """

    if len(sys.argv) != 2:
        usage(1)

    value = sys.argv[1]
    lowered = value.lower()
    if lowered.startswith('ip-'):
        key = 'private-dns-name'
    elif lowered.startswith('i-'):
        key = 'instance-id'
    else:
        key = 'tag:Name'

    instance = get_one_instance(value, key=key)
    classify(instance)

# Run the classifier only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
