#!/usr/bin/python

import logging
import random
import os
import sys
import boto.sts
# Python 2 names the module ConfigParser; Python 3 renamed it to configparser.
# Only a failed import should trigger the fallback, so catch ImportError
# rather than everything (a bare except would also swallow KeyboardInterrupt).
try:
  import ConfigParser
except ImportError:
  import configparser as ConfigParser
from os.path import expanduser

##########################################################################

# Profile to sign in to, taken from AWS_PROFILE; empty when the variable is unset.
profile = os.environ.get('AWS_PROFILE', '')

# Section of the AWS config file to read: "profile NAME" for a named profile,
# "default" when no profile was selected.
profile_section = 'profile ' + profile if profile != '' else 'default'

# Region taken from AWS_REGION, falling back to eu-west-1.
region = os.environ.get('AWS_REGION', 'eu-west-1')

aws_credentials_file = os.environ.get('AWS_SHARED_CREDENTIALS_FILE', '~/.aws/credentials')
# The config file is looked up next to the credentials file. os.path.join is
# used instead of string concatenation so that a bare file name (empty
# dirname) resolves to a relative "config" rather than the absolute "/config".
aws_config_file = os.path.join(os.path.dirname(aws_credentials_file), 'config')

# Uncomment to enable low level debugging
# logging.basicConfig(level=logging.DEBUG)

##########################################################################

def _require_option(parser, section, option, missing_message):
  """Return the value of ``option`` in ``section``.

  When the option is absent, print ``missing_message`` and terminate the
  script with exit status 1.
  """
  if not parser.has_option(section, option):
    print(missing_message)
    sys.exit(1)
  return parser.get(section, option)


# Load the AWS config file. read() silently skips files it cannot open, so a
# missing file shows up below as a missing profile section.
config = ConfigParser.RawConfigParser()
config.read(expanduser(aws_config_file))

if not config.has_section(profile_section):
  print('No profile section defined in configuration file')
  sys.exit(1)

# Mandatory login_* options of the selected profile.
source_profile = _require_option(
  config, profile_section, 'login_source_profile',
  'No source_profile specified for the selected profile.')
mfa_serial = _require_option(
  config, profile_section, 'login_mfa_serial',
  'No mfa_serial specified for the selected profile.')
role_arn = _require_option(
  config, profile_section, 'login_role_arn',
  'No role_arn specified for the selected profile.')

# Optional session lifetime; defaults to one hour. The value is parsed as an
# integer here so that it has the same type as the default and a malformed
# entry is rejected before the user is asked for an MFA code.
if config.has_option(profile_section, 'login_duration_seconds'):
  try:
    duration_seconds = config.getint(profile_section, 'login_duration_seconds')
  except ValueError:
    print('login_duration_seconds must be an integer number of seconds.')
    sys.exit(1)
else:
  duration_seconds = 60 * 60

# A region set on the profile takes precedence over the current value.
if config.has_option(profile_section, 'region'):
  region = config.get(profile_section, 'region')

print('Signing in using MFA device: ' + mfa_serial)

# Python 2 offers the plain-text prompt as raw_input; on Python 3 that name
# does not exist (NameError) and the builtin input is the equivalent.
# The prompt function gets its own name so the builtin is not rebound.
try:
  read_line = raw_input
except NameError:
  read_line = input
# Drop surrounding whitespace (e.g. from a pasted code) before sending it.
mfa_code = read_line("MFA token: ").strip()

# Connect to STS with the source profile's credentials.
conn = boto.sts.connect_to_region(region, profile_name=source_profile)
if conn is None:
  # boto returns None instead of raising when it does not know the region.
  print('Unable to connect to STS in region: ' + region)
  sys.exit(1)

# Assume the configured role, authenticating with the MFA device and code.
# The random suffix keeps session names distinct between runs.
token = conn.assume_role(
  role_arn=role_arn,
  role_session_name=profile + '-' + str(random.randint(1, 9999999)),
  duration_seconds=duration_seconds,
  mfa_serial_number=mfa_serial,
  mfa_token=mfa_code,
)

# Write the temporary STS credentials into the AWS credentials file.

# Read in the existing credentials file so its other sections are kept.
# NOTE: the file is re-serialised by ConfigParser on write, so comments in it
# are not preserved.
credentials_config_path = expanduser(aws_credentials_file)

credentials_config = ConfigParser.RawConfigParser()
credentials_config.read(credentials_config_path)

# Store the credentials under the selected profile's own section. When no
# profile was selected, use "default": an empty section name would be written
# as "[]", which ConfigParser cannot parse back.
# NOTE: Python 2's ConfigParser rejects add_section('default') with
# ValueError, so there the section must already exist in the file.
credentials_section = profile if profile != '' else 'default'
if not credentials_config.has_section(credentials_section):
  credentials_config.add_section(credentials_section)

credentials = token.credentials
credentials_config.set(credentials_section, 'aws_access_key_id', credentials.access_key)
credentials_config.set(credentials_section, 'aws_secret_access_key', credentials.secret_key)
credentials_config.set(credentials_section, 'aws_session_token', credentials.session_token)
credentials_config.set(credentials_section, 'expiration', credentials.expiration)

# Write the updated file. It holds secrets, so when it has to be created it
# is created readable and writable by the owner only; the mode argument has
# no effect on a file that already exists.
credentials_fd = os.open(
  credentials_config_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(credentials_fd, 'w') as config_file:
  credentials_config.write(config_file)

# Report the successful sign-in and the expiry time of the new credentials.
expires_at = token.credentials.expiration
summary_lines = (
  '\n\n--------------------------------------------------------------------------------------------------------',
  'Your are now signed in. The token will expire at {0}.'.format(expires_at),
  'After this time, you may safely rerun this script to refresh your access key pair.',
  '--------------------------------------------------------------------------------------------------------\n\n',
)
for summary_line in summary_lines:
  print(summary_line)
