#!/usr/bin/python

import logging
import traceback
import string
import click
import pkg_resources
import boto3
import botocore
from botocore.exceptions import ClientError

# Distribution name; also used as the logger name and as the program name
# shown by the CLI's --version option.
PROG_NAME = 'aws-default-cleaner'

# Look up the version of the installed distribution. When the module is run
# from a source checkout that has not been installed, the lookup raises
# DistributionNotFound; fall back to a placeholder instead of crashing at
# import time.
try:
    VERSION = pkg_resources.require(PROG_NAME)[0].version
except pkg_resources.DistributionNotFound:
    VERSION = 'unknown'

class AwsDefaultCleaner(object):
    """Finds and optionally deletes the default VPC and its dependent resources.

    For every processed region the cleaner looks up the VPC flagged
    ``isDefault`` and walks its internet gateways, subnets, route tables,
    network ACLs and security groups.  In dry-run mode the resources are only
    logged; otherwise they are deleted, followed by the VPC itself.
    """

    def __init__(self, assume_roles, regions, dry_run=True):
        """Store the run configuration and set up logging.

        :param assume_roles: iterable of role ARNs to assume one after another;
            when empty (or ``None``) the current AWS credentials are used.
        :param regions: iterable of region names to process; when empty (or
            ``None``) every region the session lists for EC2 is processed.
        :param dry_run: when true (the default) resources are only reported,
            never deleted.
        """
        self.dry_run = dry_run
        self.assume_roles = assume_roles
        self.regions = regions
        self._setup_logging()

    def _setup_logging(self):
        """Configure the ``PROG_NAME`` logger with a timestamped stream handler."""
        self._log = logging.getLogger(PROG_NAME)
        self._log.setLevel(logging.DEBUG)
        # logging.getLogger() returns the same logger object for the same name,
        # so attach a handler only the first time; otherwise every additional
        # instance would make each message appear once more.
        if not self._log.handlers:
            ch = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            ch.setFormatter(formatter)
            self._log.addHandler(ch)

    @staticmethod
    def _error_message(error):
        """Return the AWS error message of a ``ClientError``, or ``'Unknown'``.

        Tolerates responses that lack the ``Error``/``Message`` keys so that
        reporting a failure can never raise a ``KeyError`` itself.
        """
        return error.response.get('Error', {}).get('Message', 'Unknown')

    def _should_delete(self, message, skip_reason=None):
        """Log a discovered resource and tell the caller whether to delete it.

        :param message: the "... has been found" line describing the resource.
        :param skip_reason: when set, the resource is never deleted directly
            and (outside dry-run mode) the reason is logged instead.
        :returns: ``True`` only when not in dry-run mode and no skip reason
            was given; the caller is then expected to perform the deletion.
        """
        if self.dry_run:
            self._log.info(message)
            return False
        if skip_reason:
            self._log.info('%s ...skipping (%s)', message, skip_reason)
            return False
        self._log.info('%s ...deleting', message)
        return True

    def _aws_session(self, role_arn=None, session_name='default'):
        """Return a boto3 session, assuming ``role_arn`` through STS when given.

        :param role_arn: ARN of the role to assume; a false value means the
            default credentials are used.
        :param session_name: value passed to STS as ``RoleSessionName``.
        :returns: a ``boto3.Session``.
        :raises ClientError: re-raised after being logged.
        """
        try:
            if not role_arn:
                return boto3.Session()

            client = boto3.client('sts')
            response = client.assume_role(RoleArn=role_arn, RoleSessionName=session_name)
            credentials = response['Credentials']
            return boto3.Session(
                aws_access_key_id=credentials['AccessKeyId'],
                aws_secret_access_key=credentials['SecretAccessKey'],
                aws_session_token=credentials['SessionToken'])
        except ClientError as e:
            self._log.error('Failed to establish AWS session: %s', self._error_message(e))
            raise

    def _process_default_igw(self, client, default_vpc_filter_igw, region):
        """Report or delete the internet gateways matching the given filter.

        :param client: EC2 client for ``region``.
        :param default_vpc_filter_igw: ``attachment.vpc-id`` filter dict.
        :param region: region name, used only in log messages.
        """
        entities = client.describe_internet_gateways(Filters=[default_vpc_filter_igw])

        for entity in entities["InternetGateways"]:
            igw_id = entity["InternetGatewayId"]
            message = '[%s] Default Internet Gateway %s has been found' % (region, igw_id)
            if not self._should_delete(message):
                continue
            try:
                # Detach the gateway from every VPC it is attached to before
                # deleting it.
                for attachment in entity.get("Attachments", []):
                    client.detach_internet_gateway(InternetGatewayId=igw_id, VpcId=attachment["VpcId"])
                client.delete_internet_gateway(InternetGatewayId=igw_id)
            except ClientError as e:
                self._log.error(self._error_message(e))

    def _process_default_subnets(self, client, default_vpc_filter, region):
        """Report or delete the subnets matching the given ``vpc-id`` filter."""
        entities = client.describe_subnets(Filters=[default_vpc_filter])

        for entity in entities["Subnets"]:
            subnet_id = entity["SubnetId"]
            message = '[%s] Default Subnet %s has been found' % (region, subnet_id)
            if not self._should_delete(message):
                continue
            try:
                client.delete_subnet(SubnetId=subnet_id)
            except ClientError as e:
                self._log.error(self._error_message(e))

    def _process_default_sgs(self, client, default_vpc_filter, region):
        """Report or delete the security groups matching the ``vpc-id`` filter.

        The group named ``default`` is never deleted here; it is left to be
        removed together with the VPC.
        """
        entities = client.describe_security_groups(Filters=[default_vpc_filter])

        for entity in entities["SecurityGroups"]:
            group_id = entity["GroupId"]
            message = '[%s] Default Security Group %s has been found' % (region, group_id)
            # default security group will be deleted along with default VPC
            skip_reason = None
            if entity.get('GroupName') == 'default':
                skip_reason = 'will be deleted along with default VPC'
            if not self._should_delete(message, skip_reason):
                continue
            try:
                client.delete_security_group(GroupId=group_id)
            except ClientError as e:
                self._log.error(self._error_message(e))

    def _process_default_route_tables(self, client, default_vpc_filter, region):
        """Report or delete the route tables matching the ``vpc-id`` filter.

        A table with an association flagged ``Main`` is never deleted here; it
        is left to be removed together with the VPC.
        """
        entities = client.describe_route_tables(Filters=[default_vpc_filter])

        for entity in entities["RouteTables"]:
            route_table_id = entity["RouteTableId"]
            message = '[%s] Default Route Table %s has been found' % (region, route_table_id)
            # main route table will be deleted along with default VPC
            skip_reason = None
            if any(association.get('Main') for association in entity.get('Associations', [])):
                skip_reason = 'will be deleted along with default VPC'
            if not self._should_delete(message, skip_reason):
                continue
            try:
                client.delete_route_table(RouteTableId=route_table_id)
            except ClientError as e:
                self._log.error(self._error_message(e))

    def _process_default_nacls(self, client, default_vpc_filter, region):
        """Report or delete the network ACLs matching the ``vpc-id`` filter.

        An ACL flagged ``IsDefault`` is never deleted here; it is left to be
        removed together with the VPC.
        """
        entities = client.describe_network_acls(Filters=[default_vpc_filter])

        for entity in entities["NetworkAcls"]:
            nacl_id = entity["NetworkAclId"]
            message = '[%s] Default Network ACL %s has been found' % (region, nacl_id)
            # default network ACL will be deleted along with default VPC
            skip_reason = None
            if entity.get('IsDefault'):
                skip_reason = 'will be deleted along with default VPC'
            if not self._should_delete(message, skip_reason):
                continue
            try:
                client.delete_network_acl(NetworkAclId=nacl_id)
            except ClientError as e:
                self._log.error(self._error_message(e))

    def _process_default_vpc(self, client, region):
        """Report or delete the default VPC of one region and its resources.

        Dependent resources are handled first (gateways, subnets, route
        tables, network ACLs, security groups), then the VPC itself.  Only the
        first VPC returned by the ``isDefault`` query is processed.  A
        ``ClientError`` from any describe call is logged and ends processing
        of this region without raising.

        :param client: EC2 client bound to ``region``.
        :param region: region name, used in log messages.
        """
        try:
            vpcs = client.describe_vpcs(Filters=[{'Name': 'isDefault', 'Values': ['true']}])
            if not vpcs.get('Vpcs'):
                self._log.info('[%s] No default VPCs has been found', region)
                return

            default_vpc_id = vpcs['Vpcs'][0]['VpcId']
            default_vpc_filter = {"Name": "vpc-id", "Values": [default_vpc_id]}
            # Internet gateways are looked up through their attachment, so
            # they need a differently named filter.
            default_vpc_filter_igw = {"Name": "attachment.vpc-id", "Values": [default_vpc_id]}
            self._process_default_igw(client, default_vpc_filter_igw, region)
            self._process_default_subnets(client, default_vpc_filter, region)
            self._process_default_route_tables(client, default_vpc_filter, region)
            self._process_default_nacls(client, default_vpc_filter, region)
            self._process_default_sgs(client, default_vpc_filter, region)

            message = '[%s] Default VPC %s has been found' % (region, default_vpc_id)
            if self._should_delete(message):
                try:
                    client.delete_vpc(VpcId=default_vpc_id)
                except ClientError as e:
                    self._log.error(self._error_message(e))

        except ClientError as e:
            self._log.error("Unable to describe VPCs in region %s: %s", region, self._error_message(e))

    def _process_regions(self):
        """Process the default VPC in each configured region.

        Uses ``self._session``, which ``run`` sets before calling this method.
        When no regions were supplied (empty tuple, empty list or ``None``),
        every region the session lists for EC2 is processed.
        """
        if not self.regions:
            self._log.info('Checking default VPCs in all regions...')
            regions = self._session.get_available_regions('ec2')
        else:
            self._log.info('Checking default VPCs in supplied regions...')
            regions = self.regions

        for region in regions:
            client = self._session.client('ec2', region_name=region)
            self._process_default_vpc(client, region)

    def run(self):
        """Run the discovery/cleanup for the current or each assumed identity.

        Any exception is logged together with its traceback and not
        propagated to the caller.
        """
        try:
            if not self.assume_roles:
                self._log.info("No roles specified, working with current AWS credentials")
                self._session = self._aws_session()
                self._process_regions()
            else:
                for role in self.assume_roles:
                    self._log.info('Assuming role "%s"', role)
                    self._session = self._aws_session(role)
                    self._process_regions()

            self._log.info('Operation finished successfully!')
        except Exception:
            self._log.error('Operation failed!')
            self._log.error(traceback.format_exc())

# Top-level CLI group; the subcommands are attached to it further down.
@click.group()
@click.version_option(version=VERSION, prog_name=PROG_NAME)
def aws_default_cleaner():
    """Discovers or deletes default VPCs and their resources"""

@click.command()
@click.argument('discover', required=False)
@click.option('--assume', '-a', help='List of roles to assume', multiple=True)
@click.option('--region', '-r', help='List of regions to process', multiple=True)
def discover(discover, assume, region):
    """Discovers default resources that can be deleted"""
    # Dry-run mode: resources are only reported, never removed.  The optional
    # positional argument is accepted but not used.
    cleaner = AwsDefaultCleaner(assume_roles=assume, regions=region, dry_run=True)
    cleaner.run()

@click.command()
@click.argument('delete', required=False)
@click.option('--assume', '-a', help='List of roles to assume', multiple=True)
@click.option('--region', '-r', help='List of regions to process', multiple=True)
def delete(delete, assume, region):
    """Deletes default resources"""
    # Destructive mode: dry_run is switched off, so discovered resources are
    # removed.  The optional positional argument is accepted but not used.
    cleaner = AwsDefaultCleaner(assume_roles=assume, regions=region, dry_run=False)
    cleaner.run()

# Attach both subcommands to the top-level CLI group.
for _subcommand in (discover, delete):
    aws_default_cleaner.add_command(_subcommand)

# Start the CLI only when the file is executed as a script, not on import.
if __name__ == '__main__':
    aws_default_cleaner()
