#!python
# tool to automate various AWS commands
import calendar
import datetime as dt
import os
import shlex
import subprocess
import sys
import time

import pytz

from ncluster import aws_util as u
from ncluster.aws_backend import INSTANCE_INFO

# Set to True to make vprint() echo its arguments; while False, vprint() prints nothing.
VERBOSE = False


def _run_shell(user_cmd):
  """Runs shell command, returns list of outputted lines
with newlines stripped"""
  #  print(cmd)
  p = subprocess.Popen(user_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
  (stdout, _) = p.communicate()
  stdout = stdout.decode('ascii')
  lines = stdout.split('\n')
  stripped_lines = []
  for l in lines:
    stripped_line = l.strip()
    if l:
      stripped_lines.append(stripped_line)
  return stripped_lines


def _check_instance_found(instances, fragment):
    if not instances:
        print(f"Couldn't find instances matching {fragment}")
        return False
    return True


def vprint(*args):
    """Forward *args to print(), but only when the module-level VERBOSE flag is truthy."""
    if not VERBOSE:
        return
    print(*args)


def toseconds(dt_):
    """Convert a datetime to seconds since the Unix epoch.

    Aware datetimes are normalised to UTC first. Naive datetimes are treated
    as already being in UTC, which is how ``utctimetuple`` handles them.

    ``calendar.timegm`` is the UTC inverse of ``time.gmtime``. ``time.mktime``
    interprets its argument as *local* time, which offset the result by the
    machine's UTC offset.

    Returns:
      float seconds (float, matching the type previously returned by mktime).
    """
    return float(calendar.timegm(dt_.utctimetuple()))


def ls(fragment=''):
    """List running instances"""
    # The docstring above is printed by `help` mode, so it stays a single line.
    # Prints a table of matching instances (hours alive, accumulated cost, type,
    # IP, key owner), then the types of spot requests that are still pending.
    print(f"https://console.aws.amazon.com/ec2/v2/home?region={u.get_region()}")

    stopped_instances = u.lookup_instances(fragment, valid_states=['stopped'])
    stopped_names = [u.get_name(i) for i in stopped_instances]
    if stopped_names:
        print("ignored stopped instances: ", ", ".join(stopped_names))

    instances = u.lookup_instances(fragment)
    print('-' * 80)
    print(
        f"{'name':15s} {'hours_live':>10s} {'cost_in_$':>10s} {'instance_type':>15s} {'ip_address':>15s} "
        f"{'key/owner':>15s}")
    print('-' * 80)

    # Same owner extraction as _user_keypair_check: drop the prefix plus one
    # separator character from the key name (instead of a hard-coded offset).
    owner_offset = len(u.get_prefix()) + 1
    # Current time in UTC (default AWS zone), taken once so every row shares it.
    now_sec = toseconds(dt.datetime.now(pytz.utc))
    for instance in instances[::-1]:
        elapsed_hours = (now_sec - toseconds(instance.launch_time)) / 3600
        instance_type = instance.instance_type
        if instance_type in INSTANCE_INFO:
            cost = INSTANCE_INFO[instance_type]['cost'] * elapsed_hours
        else:
            cost = -1  # no price known for this type
        # Guard against None (e.g. an instance without a public IP) so one
        # such instance doesn't crash the whole listing.
        ip_address = instance.public_ip_address or ''
        owner = (instance.key_name or '')[owner_offset:]
        print(f"{u.get_name(instance):15s} {elapsed_hours:10.1f} {cost:10.0f} {instance_type[:5]:>15s} "
              f"{ip_address:>15s} {owner:>15s}")

    # Spot requests still outstanding; active ones already show up above as instances.
    # TODO(y) also ignore state == 'fulfilled'?
    client = u.get_ec2_client()
    ignored_states = ('cancelled', 'closed', 'active')
    pending_types = [request['LaunchSpecification']['InstanceType']
                     for request in client.describe_spot_instance_requests()['SpotInstanceRequests']
                     if request['State'] not in ignored_states]
    if pending_types:
        print(f"Pending spot instances: {','.join(pending_types)}")


def etchosts(_):
    """Copy/pastable /etc/hosts file"""
    # The docstring above is printed by `help` mode, so it stays a single line.
    # Prints "<ip> <name>" lines sorted by name for every instance u.lookup_instances() returns.
    instances = u.lookup_instances()
    # Skip instances without a public IP: a "None <name>" line is useless in
    # /etc/hosts, and sorting (name, None) against (name, '1.2.3.4') raises
    # TypeError when two instances share a name.
    host_entries = sorted((u.get_name(i), i.public_ip_address)
                          for i in instances if i.public_ip_address)
    print('-' * 80)
    print("paste following into your /etc/hosts")
    print('-' * 80)
    for name, ip in host_entries:
        print(f"{ip} {name}")


def _user_keypair_check(instance):
    """Raise unless `instance` was launched with the current user's keypair.

    The launching user is the part of ``instance.key_name`` that follows
    ``u.get_prefix()`` and one separator character. It must equal the ``USER``
    environment variable.

    Raises:
      AssertionError: when they differ, including when USER is unset or the
        instance has no key name. The message says which USER value would match.
    """
    launching_user = (instance.key_name or '')[len(u.get_prefix()) + 1:]
    # .get() so an unset USER becomes a clear mismatch instead of a KeyError.
    current_user = os.environ.get('USER')
    if launching_user != current_user:
        # Explicit raise instead of `assert` so the check is not stripped by
        # `python -O`; AssertionError keeps the original failure type.
        raise AssertionError(f"Set USER={launching_user} to connect to this machine")


def ssh(fragment=''):
    """SSH into the instance with the given prefix."""
    # The docstring above is printed by `help` mode, so it stays a single line.
    # Connects to the first (most recent) match and attaches to its tmux session.
    instances = u.lookup_instances(fragment)
    if not _check_instance_found(instances, fragment):
        return
    instance = instances[0]
    print(f"Found {len(instances)} instances matching {fragment}, connecting to most recent  {u.get_name(instance)} "
          f"launched by {instance.key_name}")

    _user_keypair_check(instance)
    ip_address = instance.public_ip_address
    if not ip_address:
        # Without this guard the command would target "ubuntu@None".
        print(f"{u.get_name(instance)} has no public IP address, can't connect")
        return
    user_cmd = f"ssh -t -i {u.get_keypair_fn()} -o StrictHostKeyChecking=no ubuntu@{ip_address} " \
               f"tmux attach"
    print(user_cmd)
    os.system(user_cmd)


def old_ssh(fragment=''):
    """SSH into the instance with the given prefix. Works on dumb terminals."""
    # The docstring above is printed by `help` mode, so it stays a single line.
    # Unlike ssh(), this opens a plain login shell: no `-t` and no `tmux attach`.
    instances = u.lookup_instances(fragment)
    if not _check_instance_found(instances, fragment):
        return
    instance = instances[0]
    print(f"Found {len(instances)} instances matching {fragment}, connecting to most recent  {u.get_name(instance)} "
          f"launched by {instance.key_name}")

    _user_keypair_check(instance)
    ip_address = instance.public_ip_address
    if not ip_address:
        # Without this guard the command would target "ubuntu@None".
        print(f"{u.get_name(instance)} has no public IP address, can't connect")
        return
    user_cmd = f"ssh -i {u.get_keypair_fn()} -o StrictHostKeyChecking=no ubuntu@{ip_address}"
    print(user_cmd)
    os.system(user_cmd)


# When True, kill() skips (and reports) instances whose key_name differs from
# u.get_keypair_name(); reallykill() clears it while its kill() call runs.
LIMIT_TO_CURRENT_USER = True


def kill(fragment=''):
    """Terminate instances matching the given fragment, after a y/N confirmation"""
    # The docstring above is printed by `help` mode, so it stays a single line.
    # Considers running and stopped instances. While LIMIT_TO_CURRENT_USER is set,
    # instances launched with another keypair are skipped (see reallykill()).
    instances = u.lookup_instances(fragment, valid_states=['running', 'stopped'])
    instances_to_kill = []
    for i in instances:
        state = i.state['Name']
        if LIMIT_TO_CURRENT_USER and i.key_name != u.get_keypair_name():
            print(f"Skipping instance launched with key {i.key_name}, use reallykill to kill")
            continue
        print(u.get_name(i), i.instance_type, i.key_name,
              state if state == 'stopped' else '')
        instances_to_kill.append(i)

    action = 'terminating'
    # Check what is left after filtering: if every match was skipped there is
    # nothing to terminate, so don't prompt or call terminate_instances() with
    # an empty id list.
    if not _check_instance_found(instances_to_kill, fragment):
        return

    answer = input(f"{len(instances_to_kill)} instances found, {action} in {u.get_region()}? (y/N) ")
    if answer.lower() != "y":
        print("Didn't get y, doing nothing")
        return

    ec2_client = u.get_ec2_client()
    response = ec2_client.terminate_instances(InstanceIds=[i.id for i in instances_to_kill])
    assert u.is_good_response(response), response
    print(f"{action}: success")


def reallykill(*args, **kwargs):
    """Like kill, but also terminates instances launched with other users' keys"""
    # The docstring above is printed by `help` mode, so it stays a single line.
    global LIMIT_TO_CURRENT_USER
    previous = LIMIT_TO_CURRENT_USER
    LIMIT_TO_CURRENT_USER = False
    try:
        kill(*args, **kwargs)
    finally:
        # Restore even if kill() raises (e.g. Ctrl-C at the confirmation prompt).
        LIMIT_TO_CURRENT_USER = previous


def stop(fragment=''):
    # Stop running instances matching `fragment` after a y/N confirmation.
    # (Comment rather than docstring: `help` prints docstrings verbatim.)
    running = u.lookup_instances(fragment, valid_states=['running'])
    for inst in running:
        inst_state = inst.state['Name']
        # Only 'running' instances are looked up, so the last column is '' in practice.
        print(u.get_name(inst), inst.instance_type, inst.key_name,
              inst_state if inst_state == 'stopped' else '')

    action = 'stopping'
    if not _check_instance_found(running, fragment):
        return

    reply = input(f"{len(running)} instances found, {action} in {u.get_region()}? (y/N) ")
    ec2_client = u.get_ec2_client()
    if reply.lower() != "y":
        print("Didn't get y, doing nothing")
        return

    result = ec2_client.stop_instances(InstanceIds=[inst.id for inst in running])
    assert u.is_good_response(result), result
    print(f"{action}: success")


def start(fragment=''):
    # Start stopped instances matching `fragment` after a y/N confirmation, then
    # print the EFS mount commands because the mount must be redone by hand.
    # (Comment rather than docstring: `help` prints docstrings verbatim.)
    stopped = u.lookup_instances(fragment, valid_states=['stopped'])
    for inst in stopped:
        print(u.get_name(inst), inst.instance_type, inst.key_name)

    if not stopped:
        print("no stopped instances found, quitting")
        return

    reply = input(f"{len(stopped)} instances found, start in {u.get_region()}? (y/N) ")
    if reply.lower() != "y":
        print("Didn't get y, doing nothing")
        return

    for inst in stopped:
        print(f"starting {u.get_name(inst)}")
        inst.start()

    print("Warning, need to manually mount efs on instance: ")
    print_efs_mount_command()


def mosh(fragment=''):
    """Mosh into the most recent instance matching the given fragment and attach to tmux"""
    # The docstring above is printed by `help` mode, so it stays a single line.
    instances = u.lookup_instances(fragment)
    if not _check_instance_found(instances, fragment):
        return
    instance = instances[0]
    _user_keypair_check(instance)
    print(f"Found {len(instances)} instances matching {fragment}, connecting to most recent  {u.get_name(instance)}")

    ip_address = instance.public_ip_address
    if not ip_address:
        # Without this guard the command would target "ubuntu@None".
        print(f"{u.get_name(instance)} has no public IP address, can't connect")
        return
    # mosh is given a custom --ssh command so it authenticates with the cluster keypair.
    user_cmd = f"mosh --ssh='ssh -i {u.get_keypair_fn()} -o StrictHostKeyChecking=no' " \
               f"ubuntu@{ip_address} tmux attach"
    print(user_cmd)
    os.system(user_cmd)


def print_efs_mount_command():
    """Print the shell commands that mount this prefix's EFS file system at /ncluster.

    The file system id is looked up in ``u.get_efs_dict()`` under
    ``u.get_prefix()``; a missing entry raises KeyError.
    """
    region = u.get_region()
    file_system_id = u.get_efs_dict()[u.get_prefix()]
    dns_name = f"{file_system_id}.efs.{region}.amazonaws.com"
    commands = (
        'sudo mkdir -p /ncluster',
        "sudo mount -t nfs -o nfsvers=4.1,rsize=1048576,wsize=1048576,hard,timeo=600,retrans=2 "
        f"{dns_name}:/ /ncluster",
    )
    for command in commands:
        print(command)


def efs(_):
    """Show the EFS mount command and list EFS file systems with their mount targets"""
    # The docstring above is printed by `help` mode, so it stays a single line.
    print_efs_mount_command()
    print()
    print()

    efs_client = u.get_efs_client()
    fs_response = efs_client.describe_file_systems()
    assert u.is_good_response(fs_response), fs_response
    # One EC2 resource handle serves every subnet lookup below (was re-created per file system).
    ec2 = u.get_ec2_resource()

    # Abridged shape of one 'FileSystems' entry:
    #   {'FileSystemId': 'fs-0f95ab46', 'Name': 'nexus01', 'LifeCycleState': 'available',
    #    'NumberOfMountTargets': 0, 'PerformanceMode': 'generalPurpose',
    #    'SizeInBytes': {'Value': 6144}, 'CreationTime': datetime(...), ...}
    for file_system in fs_response['FileSystems']:
        efs_id = file_system['FileSystemId']
        tags_response = efs_client.describe_tags(FileSystemId=efs_id)
        assert u.is_good_response(tags_response), tags_response
        key = u.get_name(tags_response.get('Tags', ''))
        print("%-16s %-16s" % (efs_id, key))
        print('-' * 40)

        # List mount targets, one row each. A separate name keeps the
        # describe_file_systems() response from being shadowed.
        mounts_response = efs_client.describe_mount_targets(FileSystemId=efs_id)
        assert u.is_good_response(mounts_response), mounts_response
        mount_targets = mounts_response['MountTargets']
        if not mount_targets:
            print("<no mount targets>")
            continue
        for target in mount_targets:
            zone = ec2.Subnet(target['SubnetId']).availability_zone
            print('%-16s %-16s %-16s %-16s' % (zone, target['IpAddress'],
                                               target['MountTargetId'],
                                               target['LifeCycleState']))


def terminate_tmux(_):
    """Script to clean-up tmux sessions."""
    # The docstring above is printed by `help` mode, so it stays a single line.
    # Kills every local tmux session except the long-lived ones in keep_sessions.
    keep_sessions = {'tensorboard', 'jupyter', 'dropbox'}
    for line in _run_shell('tmux ls'):
        # Session lines look like "name: 1 windows (created ...)", and tmux does
        # not allow ':' in session names. _run_shell merges stderr, so lines
        # without ':' are not sessions and must not be "killed". One example is
        # tmux's error message when no server is running.
        if ':' not in line:
            continue
        session_name = line.split(':', 1)[0]

        if session_name in keep_sessions:
            print("Skipping " + session_name)
            continue
        print("Killing " + session_name)
        # Quote: the name is interpolated into a shell command line.
        _run_shell('tmux kill-session -t ' + shlex.quote(session_name))


def cmd(user_cmd):
    """Finds most recent instance launched by user, runs commands there, pipes output to stdout"""
    # `user_cmd` is appended verbatim to the ssh command line. The local shell
    # (os.system) parses it first, then the remote shell, so quote it accordingly.
    instances = u.lookup_instances(limit_to_current_user=True)
    if not instances:
        # The message now interpolates the user name (braces were missing).
        # Explicit raise rather than `assert` so the check also runs under
        # `python -O`; AssertionError keeps the original failure type.
        raise AssertionError(f"{u.get_username()} doesn't have any instances to connect to")
    instance = instances[0]
    ssh_cmd = f"ssh -t -i {u.get_keypair_fn()} -o StrictHostKeyChecking=no ubuntu@{instance.public_ip_address} " \
              f"{user_cmd}"
    os.system(ssh_cmd)


def cat(user_cmd):
    # Run `cat <user_cmd>` on the most recent instance via cmd().
    cmd(' '.join(('cat', user_cmd)))


def ls_(user_cmd):
    # Run `ls <user_cmd>` on the most recent instance via cmd().
    cmd(' '.join(('ls', user_cmd)))


# Maps the first command-line argument (the "mode") to its handler. main() calls
# the handler with the second argument (or '' when absent). `help` lists these
# keys in this order, each followed by its handler's docstring when it has one.
MODES = {
    'ls': ls,
    'ssh': ssh,
    'ssh_': old_ssh,  # plain ssh login, no tmux attach
    'mosh': mosh,
    'kill': kill,
    'reallykill': reallykill,
    'stop': stop,
    'start': start,
    'efs': efs,
    'cat': cat,
    'ls_': ls_,
    'cmd': cmd,
    '/etc/hosts': etchosts,
    'hosts': etchosts,  # alias of '/etc/hosts'
    'terminate_tmux': terminate_tmux,
}


def _print_modes():
    """Print every mode name, followed by a tab and its docstring when it has one."""
    for name, handler in MODES.items():
        print(f'{name}\t{handler.__doc__}' if handler.__doc__ else name)


def main():
    """CLI entry point: ``<script> [mode] [fragment]``.

    mode defaults to 'ls' and fragment to ''. 'help' lists the modes. An
    unknown mode prints the valid modes and exits with status 1 instead of
    crashing with a KeyError traceback.
    """
    print(f"Region ({u.get_region()}) $USER ({u.get_username()}) account ({u.get_account_number()})")
    mode = sys.argv[1] if len(sys.argv) > 1 else 'ls'
    fragment = sys.argv[2] if len(sys.argv) > 2 else ''

    if mode == 'help':
        _print_modes()
        return
    if mode not in MODES:
        print(f"Unknown mode '{mode}', valid modes are:")
        _print_modes()
        sys.exit(1)
    MODES[mode](fragment)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
