#!/usr/bin/env python

import subprocess, sys, os, argparse, boto3
from objects.exceptions import CommandError

#sys.tracebacklimit=0

def get_aws_session(profile_name=None):
	"""Build a boto3 Session for the eu-west-1 region.

	:param profile_name: optional named profile from the local AWS config. When it
		is falsy no profile is passed and credential resolution is left to boto3.
	:returns: a new boto3.Session.
	"""
	session_kwargs = {'region_name': 'eu-west-1'}
	if profile_name:
		session_kwargs['profile_name'] = profile_name
	return boto3.Session(**session_kwargs)

def get_aws_subaccount_credentials(args, conn_client):
	"""Return a boto3 client authenticated in the sub-account owning args.environment_name.

	The account name is the first two dash-separated parts of the environment name
	(e.g. 'abtesti-production-web' -> 'abtesti-production'). The account ID is looked up
	among the ACTIVE accounts of the AWS Organization. The 'Developers' role in that
	account is assumed through STS, and a client for `conn_client` is built from the
	temporary credentials.

	:param args: parsed CLI arguments; only args.environment_name is read.
	:param conn_client: boto3 service name for the returned client (e.g. 'ec2', 's3').
	:returns: a boto3 client for `conn_client` in eu-west-1.
	Exits the process with status 1 if the environment name is malformed or matches
	no ACTIVE account.
	"""
	# Validate the name before making any AWS calls. A name without two non-empty
	# dash-separated parts cannot map to an account (and used to raise IndexError).
	name_parts = args.environment_name.split("-")
	if len(name_parts) < 2 or not name_parts[0] or not name_parts[1]:
		print("ERROR: wrong environment_name '" + args.environment_name + "'", file=sys.stderr)
		sys.exit(1)
	desired_account_name = name_parts[0] + "-" + name_parts[1]

	conn_organizations = get_aws_session().client('organizations')
	# list_accounts returns at most 20 results per request, so page through all of them.
	paginator = conn_organizations.get_paginator('list_accounts')
	response_iterator = paginator.paginate(PaginationConfig={'PageSize': 20})

	# Map every ACTIVE account name to its ID, e.g.
	# {'devops-production': '229424175607', 'datascienceplatform-production': '789006245671', ...}
	accounts_dict = {}
	for page in response_iterator:
		for account in page['Accounts']:
			if account['Status'] == 'ACTIVE':
				accounts_dict[account['Name']] = account['Id']

	account_id = accounts_dict.get(desired_account_name)
	if account_id is None:
		print("ERROR: wrong environment_name '" + args.environment_name + "'", file=sys.stderr)
		sys.exit(1)

	conn_sts = get_aws_session().client('sts')
	assumed_role_object = conn_sts.assume_role(
		RoleArn="arn:aws:iam::" + account_id + ":role/Developers",
		RoleSessionName="AssumeRoleSession1",
		DurationSeconds=43100
	)
	# Temporary credentials used for all subsequent API calls in the sub-account.
	credentials = assumed_role_object['Credentials']

	return boto3.client(
		conn_client,
		aws_access_key_id=credentials['AccessKeyId'],
		aws_secret_access_key=credentials['SecretAccessKey'],
		aws_session_token=credentials['SessionToken'],
		region_name='eu-west-1'
	)

def get_instances_list(args):
	"""Return the EC2 instance IDs of the Elastic Beanstalk environment args.environment_name.

	Uses the local AWS profile args.profile when one is given. Otherwise it
	authenticates in the sub-account derived from the environment name.

	:param args: parsed CLI arguments (profile, environment_name).
	:returns: a non-empty list of instance ID strings.
	Exits the process with status 1 if the environment cannot be described or has
	no instances.
	"""
	if args.profile:
		conn_eb = get_aws_session(args.profile).client('elasticbeanstalk')
	else:
		conn_eb = get_aws_subaccount_credentials(args, 'elasticbeanstalk')

	try:
		response = conn_eb.describe_environment_resources(EnvironmentName=args.environment_name)
		instances_list = [instance['Id'] for instance in response['EnvironmentResources']['Instances']]
	except Exception as err:
		# Unlike the former bare `except:`, this lets KeyboardInterrupt/SystemExit through.
		# The underlying error is shown so credential or permission problems are not
		# mistaken for a bad environment name.
		print("ERROR: wrong environment_name '" + args.environment_name + "'", file=sys.stderr)
		print("       " + str(err), file=sys.stderr)
		sys.exit(1)

	if not instances_list:
		# An empty list would leave the interactive chooser with no valid answer.
		print("ERROR: environment '" + args.environment_name + "' has no instances", file=sys.stderr)
		sys.exit(1)

	return instances_list

def get_input(output, default):
	"""Prompt the user and return their stripped reply, or `default` if it is blank.

	The prompt is `output` followed by ': '. `default` is returned unchanged
	(it is not converted to a string).
	"""
	reply = input(output + ': ').strip()
	if not reply:
		return default
	return reply


def prompt_for_instance_in_list(instances_list, default=1):
	"""Show a numbered menu of instances_list and return the 0-based index the user picks.

	Re-prompts until the reply is an integer between 1 and len(instances_list).
	A blank reply selects `default` (1-based).

	:raises ValueError: if instances_list is empty, since no answer could ever be valid.
	"""
	if not instances_list:
		# Without this guard the loop below could never accept an answer.
		raise ValueError('instances_list is empty; there is nothing to choose from.')

	for number, instance in enumerate(instances_list, start=1):
		print(str(number) + ')', instance)

	count = len(instances_list)
	while True:
		try:
			choice = int(get_input('(default is ' + str(default) + ')', default))
		except ValueError:
			choice = None  # not an integer; reported as invalid below
		if choice is not None and 0 < choice <= count:
			return choice - 1
		print('Sorry, that is not a valid choice. Please choose a number between 1 and ' + str(count) + '.')

def get_ssh_key_from_s3(args, instance):
	"""Download an instance's SSH private key from S3 into ~/.ssh and make it owner read-only.

	The key is read from bucket 'ssh-key-<KeyName>', object '<KeyName>', and written to
	~/.ssh/<KeyName>. Uses the local AWS profile args.profile when one is given;
	otherwise it authenticates in the sub-account for args.environment_name.

	:param args: parsed CLI arguments (profile, environment_name).
	:param instance: an EC2 instance description; only instance['KeyName'] is read.
	"""
	key_name = instance['KeyName']
	print("INFO: Downloading s3://ssh-key-"+key_name+"/"+key_name)

	if args.profile:
		conn_s3 = get_aws_session(args.profile).client('s3')
	else:
		conn_s3 = get_aws_subaccount_credentials(args, 's3')

	ssh_dir = os.path.join(os.path.expanduser("~"), ".ssh")
	# ~/.ssh may not exist yet on a fresh machine, in which case the download would fail.
	os.makedirs(ssh_dir, mode=0o700, exist_ok=True)
	key_path = os.path.join(ssh_dir, key_name)

	conn_s3.download_file("ssh-key-"+key_name, key_name, key_path)
	# 0o400 = read-only for the owner (OpenSSH rejects private keys readable by others).
	os.chmod(key_path, 0o400)

def ssh_into_instance(args, instance_id, custom_ssh=None, command=None):
	"""Open an SSH session to an EC2 instance as ec2-user.

	Connects to the instance's private IP address, falling back to its private DNS
	name. If ~/.ssh/<KeyName> is missing, the key pair is first downloaded from S3.

	:param args: parsed CLI arguments. args.profile selects the AWS credentials;
		without a profile, args.environment_name selects the sub-account.
	:param instance_id: ID of the EC2 instance to connect to.
	:param custom_ssh: optional SSH command line used instead of 'ssh -i <keyfile>'.
		It is split with shell-like quoting and must not include user@host.
	:param command: optional remote command to run instead of an interactive shell.
	:raises CommandError: if the instance has no private address or ssh exits non-zero.
	"""
	import shlex

	if args.profile:
		conn_ec2 = get_aws_session(args.profile).client('ec2')
	else:
		conn_ec2 = get_aws_subaccount_credentials(args, 'ec2')

	instance = conn_ec2.describe_instances(InstanceIds=[instance_id])['Reservations'][0]['Instances'][0]

	# Use the expanded path rather than relying on the ssh binary to expand '~'.
	keypair_file = os.path.join(os.path.expanduser("~"), ".ssh", instance['KeyName'])

	# Download the SSH key from S3 if it is not already present locally.
	if not os.path.isfile(keypair_file):
		get_ssh_key_from_s3(args, instance)

	# Prefer the private IP; fall back to the private DNS name when no IP is reported.
	address = instance.get('PrivateIpAddress') or instance.get('PrivateDnsName')
	if not address:
		raise CommandError('Instance ' + instance_id + ' has no private IP address or DNS name.')

	user = 'ec2-user'

	if custom_ssh:
		ssh_command = shlex.split(custom_ssh)
	else:
		ssh_command = ['ssh', '-i', keypair_file]
	ssh_command.append(user + '@' + address)
	if command:
		ssh_command.append(command)

	print('INFO: Running ' + ' '.join(ssh_command))
	returncode = subprocess.call(ssh_command)
	if returncode != 0:
		raise CommandError('An error occurred while running: ' + ssh_command[0] + '.')

def main(args):
	"""Run the `ssh` sub-command: list instances, let the user pick one, then SSH into it.

	:param args: parsed command-line arguments (profile, environment_name).
	"""
	candidates = get_instances_list(args)
	picked = prompt_for_instance_in_list(candidates)
	# Open the SSH connection to the chosen instance.
	ssh_into_instance(args, candidates[picked])


### Init ###

# Intercept only the 'ssh' sub-command and pass everything else through to the real `eb` CLI.
# A bare invocation is also routed to the parser, which reports the missing arguments.
if 'ssh' in sys.argv[1:] or len(sys.argv) == 1:

	parser = argparse.ArgumentParser(add_help=False)

	parser.add_argument('ssh')
	parser.add_argument('--profile', help='--profile <application-[production|staging]>', required=False)
	parser.add_argument('environment_name', help='<application-[production|staging]-[web|activejobs|cronjobs]>')

	args = parser.parse_args()

	main(args)

else:
	print('Forwarding command to the official awsebcli..')

	aws_eb_cli = ['eb'] + sys.argv[1:]

	try:
		returncode = subprocess.call(aws_eb_cli)
	except FileNotFoundError:
		print("ERROR: the 'eb' command was not found; is awsebcli installed and on PATH?", file=sys.stderr)
		sys.exit(127)

	# Propagate eb's exit status so shells and CI can detect failures.
	sys.exit(returncode)




