#!/usr/bin/env python
# tool to automate various AWS commands
import datetime as dt
import os
import shlex
import subprocess
import sys
import time

import pytz

import ncluster

from ncluster import aws_util as u
from ncluster import util
from ncluster.aws_backend import INSTANCE_INFO
from boto3_type_annotations.ec2 import Volume
from boto3_type_annotations.ec2 import Image

# When True, vprint() forwards its arguments to print(); when False it is silent.
VERBOSE = False


def _run_shell(user_cmd):
  """Runs shell command, returns list of outputted lines
with newlines stripped"""
  #  print(cmd)
  p = subprocess.Popen(user_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
  (stdout, _) = p.communicate()
  stdout = stdout.decode('ascii')
  lines = stdout.split('\n')
  stripped_lines = []
  for l in lines:
    stripped_line = l.strip()
    if l:
      stripped_lines.append(stripped_line)
  return stripped_lines


def _check_instance_found(instances, fragment, states=()):
  if not instances:
    print(f"Couldn't find instances in state {states} matching '{fragment}' for key {u.get_keypair_name()}")
    return False
  return True


def vprint(*args):
  """Print `args` only when the module-level VERBOSE flag is set."""
  if not VERBOSE:
    return
  print(*args)


def toseconds(dt_):
  """Convert a datetime object to seconds since the Unix epoch (UTC).

  Args:
    dt_: datetime object. Per `datetime.utctimetuple`, an aware value is
      normalized to UTC and a naive value is taken as already being in UTC.

  Returns:
    Seconds since 1970-01-01 00:00:00 UTC as a float (sub-second part dropped).
  """
  # Function-scope import keeps this helper self-contained.
  import calendar

  # utctimetuple() yields a UTC struct_time. time.mktime() would reinterpret it
  # as *local* time and skew the result by the machine's UTC offset;
  # calendar.timegm() is the UTC-based conversion.
  return float(calendar.timegm(dt_.utctimetuple()))


def ls(fragment=''):
  """List running instances with uptime and estimated cost, plus pending spot requests.

  Args:
    fragment: name fragment passed to `u.lookup_instances` to select instances.
  """
  print(f"https://console.aws.amazon.com/ec2/v2/home?region={u.get_region()}")

  stopped_instances = u.lookup_instances(fragment, valid_states=['stopped'])
  stopped_names = [u.get_name(i) for i in stopped_instances]
  if stopped_names:
    print("ignored stopped instances: ", ", ".join(stopped_names))

  instances = u.lookup_instances(fragment)
  print('-' * 80)
  print(
    f"{'name':18s} {'hours_live':>10s} {'cost_in_$':>10s} {'instance_type':>15s} {'public_ip':>15s} "
    f"{'key/owner':>15s} {'private ip':>15s}")
  print('-' * 80)

  # current time in UTC zone (default AWS); taken once so every row uses the same reference
  now_time = dt.datetime.now(pytz.utc)
  for instance in instances[::-1]:
    elapsed_sec = toseconds(now_time) - toseconds(instance.launch_time)
    elapsed_hours = elapsed_sec / 3600
    instance_type = instance.instance_type
    if instance_type in INSTANCE_INFO:
      cost = INSTANCE_INFO[instance_type]['cost'] * elapsed_hours
    else:
      cost = -1  # no price information for this instance type
    # key_name and the IP addresses may be None (e.g. no public address
    # assigned); str() keeps the '>15s' format specs from raising TypeError.
    key_name = str(instance.key_name)
    public_ip = str(instance.public_ip_address)
    private_ip = str(instance.private_ip_address)
    print(f"{u.get_name(instance):18s} {elapsed_hours:10.1f} {cost:10.0f} {instance_type[:5]:>15s} "
          f"{public_ip:>15s} {key_name[9:]:>15s} {private_ip:>15s} "
          f"{instance.placement_group.name} ")

  # list spot requests, ignore active ones since they show up already
  client = u.get_ec2_client()
  # TODO(y) also ignore state == 'fulfilled'?
  ignored_states = ('cancelled', 'closed', 'active')
  spot_requests = [request['LaunchSpecification']['InstanceType']
                   for request in client.describe_spot_instance_requests()['SpotInstanceRequests']
                   if request['State'] not in ignored_states]
  if spot_requests:
    print(f"Pending spot instances: {','.join(spot_requests)}")


def etchosts(_):
  """Print name/public-IP pairs of the looked-up instances in /etc/hosts form.

  The positional argument is accepted for CLI dispatch and ignored.
  """
  separator = '-' * 80
  entries = []
  for instance in u.lookup_instances():
    entries.append((u.get_name(instance), instance.public_ip_address))

  print(separator)
  print("paste following into your /etc/hosts")
  print(separator)
  # Entries are emitted ordered by (name, ip).
  for host_name, address in sorted(entries):
    print(f"{address} {host_name}")

  print("""\n127.0.0.1    localhost
255.255.255.255    broadcasthost
::1             localhost""")

def _user_keypair_check(instance):
  """Assert that the current $USER matches the user whose keypair launched `instance`.

  The launching user is the instance's key name with the ncluster prefix and
  the one character following it removed.

  Raises:
    AssertionError: if the launching user differs from os.environ['USER'].
  """
  key_name = instance.key_name
  launching_user = key_name[len(u.get_prefix()) + 1:]
  current_user = os.environ['USER']
  assert launching_user == current_user, (
    f"Set USER={launching_user} to connect to this machine, and make sure their "
    f".pem file is in your ~/.ncluster")


def ssh(fragment=''):
  """SSH into the instance with the given prefix."""
  matches = u.lookup_instances(fragment)
  if not _check_instance_found(matches, fragment):
    return
  target = matches[0]
  if len(matches) > 1:
    # Several candidates: the first lookup result is used.
    print(f"Found {len(matches)} instances matching {fragment}, connecting to most recent  {u.get_name(target)} "
          f"launched by {target.key_name}")
  else:
    print(f"Connecting to  {u.get_name(target)} "
          f"launched by {target.key_name}")

  _user_keypair_check(target)
  keypair_fn = u.get_keypair_fn()
  destination = f"{u.get_aws_username(target)}@{target.public_ip_address}"
  # -t forces a tty; the trailing space is part of the original command line.
  user_cmd = (f"ssh -t -i {keypair_fn} -o StrictHostKeyChecking=no -o ServerAliveCountMax=1 "
              f"-o ServerAliveInterval=60 "
              f"{destination} ")
  print(user_cmd)
  os.system(user_cmd)


def reboot(fragment=''):
  """Reboot the first instance matching the given name fragment.

  Args:
    fragment: name fragment passed to `u.lookup_instances`.
  """
  instances = u.lookup_instances(fragment)
  if not _check_instance_found(instances, fragment):
    return
  instance = instances[0]
  # Messages say "rebooting" (not "connecting") so the user sees what will
  # actually happen to the selected instance.
  if len(instances) > 1:
    print(f"Found {len(instances)} instances matching {fragment}, rebooting most recent  {u.get_name(instance)} "
          f"({instance.id}) launched by {instance.key_name}")
  else:
    print(f"Rebooting {u.get_name(instance)} ({instance.id}) "
          f"launched by {instance.key_name}")

  # Only the user whose keypair launched the instance may reboot it.
  _user_keypair_check(instance)
  instance.reboot()


def old_ssh(fragment=''):
  """SSH into the instance with the given prefix. Works on dumb terminals."""
  matches = u.lookup_instances(fragment)
  if not _check_instance_found(matches, fragment):
    return
  target = matches[0]
  if len(matches) > 1:
    print(f"Found {len(matches)} instances matching {fragment}, connecting to most recent  {u.get_name(target)} "
          f"launched by {target.key_name}")
  else:
    print(f"Connecting to  {u.get_name(target)} "
          f"launched by {target.key_name}")

  _user_keypair_check(target)
  keypair_fn = u.get_keypair_fn()
  destination = f"{u.get_aws_username(target)}@{target.public_ip_address}"
  # Unlike ssh(), no -t flag is passed here.
  user_cmd = (f"ssh -i {keypair_fn} -o StrictHostKeyChecking=no -o ConnectTimeout=10  "
              f"-o ServerAliveCountMax=1 "
              f"-o ServerAliveInterval=60 "
              f"{destination}")
  print(user_cmd)
  os.system(user_cmd)


def connect(fragment=''):
  """SSH into the instance using authorized keys mechanism.

  By default attaches to the remote tmux session ('tmux a') and, if that
  fails, creates one with 'tmux new'. When INSIDE_EMACS or NO_TMUX is present
  in the environment, tmux is skipped and a plain ssh session is opened.

  Args:
    fragment: name fragment passed to `u.lookup_instances`.
  """
  instances = u.lookup_instances(fragment)
  if not _check_instance_found(instances, fragment):
    return
  instance = instances[0]
  if len(instances) > 1:
    print(f"Found {len(instances)} instances matching {fragment}, connecting to most recent  {u.get_name(instance)} "
          f"launched by {instance.key_name}")
  else:
    print(f"Connecting to  {u.get_name(instance)} "
          f"launched by {instance.key_name}")

  ssh_cmd = f"ssh -t -o StrictHostKeyChecking=no -o ConnectTimeout=10  " \
            f"-o ServerAliveCountMax=1 " \
            f"-o ServerAliveInterval=60 " \
            f"{u.get_aws_username(instance)}@{instance.public_ip_address} "

  skip_tmux = False
  if 'INSIDE_EMACS' in os.environ:
    print("detected Emacs, skipping tmux attach")
    skip_tmux = True
  if 'NO_TMUX' in os.environ:
    print("detected NO_TMUX, skipping tmux attach")
    skip_tmux = True

  if skip_tmux:
    # Plain ssh session; there is no tmux session to recreate on failure.
    print(ssh_cmd)
    exit_code = os.system(ssh_cmd)
    if exit_code != 0:
      print(f"command {ssh_cmd} failed with code '{exit_code}'")
    return

  connect_cmd = ssh_cmd + " tmux a"
  print(connect_cmd)
  exit_code = os.system(connect_cmd)
  if exit_code != 0:
    # Attaching failed (possibly no session exists yet); try creating one.
    print(f"Creating ssh tmux a returned {exit_code}, recreating tmux")
    fix_cmd = ssh_cmd + " tmux new"
    exit_code = os.system(fix_cmd)
    if exit_code != 0:
      print(f"command {fix_cmd} failed with code '{exit_code}'")


def connectm(fragment=''):
  """Like connect, but uses mosh"""
  matches = u.lookup_instances(fragment)
  if not _check_instance_found(matches, fragment):
    return
  target = matches[0]
  if len(matches) > 1:
    print(f"Found {len(matches)} instances matching {fragment}, connecting to most recent  {u.get_name(target)} "
          f"launched by {target.key_name}")
  else:
    print(f"Connecting to  {u.get_name(target)} "
          f"launched by {target.key_name}")

  destination = f"{u.get_aws_username(target)}@{target.public_ip_address}"
  # The ssh options are handed to mosh through its --ssh argument.
  user_cmd = (f"mosh --ssh='ssh -o StrictHostKeyChecking=no -o ConnectTimeout=10  "
              f"-o ServerAliveCountMax=1 "
              f"-o ServerAliveInterval=60' "
              f"{destination}")
  print(user_cmd)
  os.system(user_cmd)


def kill(fragment: str = '', stop_instead_of_kill: bool = False):
  """Terminate or stop instances matching a name fragment, after confirmation.

  Instances whose key name differs from the current user's keypair name are
  skipped while the module-level LIMIT_TO_CURRENT_USER flag is True.

  Args:
    fragment: name fragment passed to `u.lookup_instances`.
    stop_instead_of_kill: use stop_instances instead of terminate_instances
  """
  if stop_instead_of_kill:
    valid_states = ['running']
    action = 'stopping'
    override_action = 'reallystop'
  else:
    valid_states = ['running', 'stopped']
    action = 'terminating'
    override_action = 'reallykill'

  # Look up everybody's instances; ownership filtering happens below so that
  # skipped instances can be reported.
  instances = u.lookup_instances(fragment, valid_states=valid_states, limit_to_current_user=False)
  instances_to_kill = []
  instances_to_skip = []
  users_to_skip = set()
  instances_to_kill_formatted = []
  for i in instances:
    state = i.state['Name']

    if not LIMIT_TO_CURRENT_USER or i.key_name == u.get_keypair_name():
      instances_to_kill_formatted.append(("  ", u.get_name(i), i.instance_type, i.key_name, state if state == 'stopped' else ''))
      instances_to_kill.append(i)
    else:
      instances_to_skip.append(u.get_name(i))
      # key_name could be None; str() keeps the slice from raising TypeError
      users_to_skip.add(str(i.key_name)[9:])

  if instances_to_skip:
    # sorted() gives a stable, reproducible message
    print(f"Skipping {','.join(instances_to_skip)} launched by ({', '.join(sorted(users_to_skip))}), override with {override_action}")
  if not _check_instance_found(instances_to_kill, fragment, valid_states):
    return

  print(f"{action}:")
  for line in instances_to_kill_formatted:
    print(*line)

  ec2_client = u.get_ec2_client()
  num_instances = len(instances_to_kill)
  # Confirmation is always requested, for stopping as well as terminating.
  answer = input(f"{num_instances} instances found, {action} in {u.get_region()}? (y/N) ")

  if answer.lower() != "y":
    print("Didn't get y, doing nothing")
    return

  instance_ids = [i.id for i in instances_to_kill]
  if stop_instead_of_kill:
    response = ec2_client.stop_instances(InstanceIds=instance_ids)
  else:
    response = ec2_client.terminate_instances(InstanceIds=instance_ids)

  assert u.is_good_response(response), response
  print(f"{action} {num_instances} instances: success")


def stop(fragment=''):
  """Stop (rather than terminate) instances matching the fragment."""
  kill(fragment=fragment, stop_instead_of_kill=True)


# While True, kill() skips instances whose key name is not the current user's
# keypair name; reallykill()/reallystop() turn it off for one call.
LIMIT_TO_CURRENT_USER = True


def reallykill(*args, **kwargs):
  """Kill instances, including ones launched by other users."""
  global LIMIT_TO_CURRENT_USER
  LIMIT_TO_CURRENT_USER = False
  try:
    kill(*args, **kwargs)
  finally:
    # Restore the safety flag even if kill() raises or is interrupted.
    LIMIT_TO_CURRENT_USER = True


def reallystop(*args, **kwargs):
  """Stop instances, including ones launched by other users."""
  global LIMIT_TO_CURRENT_USER
  LIMIT_TO_CURRENT_USER = False
  try:
    stop(*args, **kwargs)
  finally:
    # Restore the safety flag even if stop() raises or is interrupted.
    LIMIT_TO_CURRENT_USER = True


def start(fragment=''):
  """Start stopped instances matching the fragment (no confirmation is asked).

  Args:
    fragment: name fragment passed to `u.lookup_instances`.
  """
  instances = u.lookup_instances(fragment, valid_states=['stopped'])
  for i in instances:
    print(u.get_name(i), i.instance_type, i.key_name)

  if not instances:
    print("no stopped instances found, quitting")
    return

  # The confirmation prompt was disabled (the answer was hard-coded to 'y'),
  # so the instances are started unconditionally.
  for i in instances:
    print(f"starting {u.get_name(i)}")
    i.start()

  print("Warning, need to manually mount efs on instance: ")
  print_efs_mount_command()


def mosh(fragment=''):
  """Open a mosh session to the first matching instance and run 'tmux attach' there."""
  matches = u.lookup_instances(fragment)
  if not _check_instance_found(matches, fragment):
    return
  target = matches[0]
  print(f"Found {len(matches)} instances matching {fragment}, connecting to most recent  {u.get_name(target)}")
  _user_keypair_check(target)

  keypair_fn = u.get_keypair_fn()
  destination = f"{u.get_aws_username(target)}@{target.public_ip_address}"
  # The keypair is supplied to mosh through the ssh command in --ssh.
  user_cmd = (f"mosh --ssh='ssh -i {keypair_fn} -o StrictHostKeyChecking=no' "
              f"{destination} tmux attach")
  print(user_cmd)
  os.system(user_cmd)


def print_efs_mount_command():
  """Print the EFS mount command returned by u.get_efs_mount_command()."""
  mount_command = u.get_efs_mount_command()
  print(mount_command)


def efs(_):
  """Print EFS mount instructions, then each file system with its mount targets.

  The positional argument is accepted for CLI dispatch and ignored.
  """
  print("EFS information. To upload to remote EFS use 'ncluster efs_sync'")
  print_efs_mount_command()
  print()
  print()

  efs_client = u.get_efs_client()
  response = efs_client.describe_file_systems()
  assert u.is_good_response(response), response

  # Looked up once: it does not depend on the file system being listed.
  ec2 = u.get_ec2_resource()

  for efs_response in response['FileSystems']:
    # Each entry is a dict with keys such as 'FileSystemId', 'Name',
    # 'LifeCycleState', 'NumberOfMountTargets'; only 'FileSystemId' is used.
    efs_id = efs_response['FileSystemId']
    tags_response = efs_client.describe_tags(FileSystemId=efs_id)
    assert u.is_good_response(tags_response), tags_response
    key = u.get_name(tags_response.get('Tags', ''))
    print("%-16s %-16s" % (efs_id, key))
    print('-' * 40)

    # list mount points; a separate name avoids rebinding `response` while
    # its 'FileSystems' list is being iterated
    mounts_response = efs_client.describe_mount_targets(FileSystemId=efs_id)
    assert u.is_good_response(mounts_response), mounts_response
    mount_targets = mounts_response['MountTargets']
    if not mount_targets:
      print("<no mount targets>")
      continue
    for mount_response in mount_targets:
      subnet = ec2.Subnet(mount_response['SubnetId'])
      zone = subnet.availability_zone
      state = mount_response['LifeCycleState']
      id_ = mount_response['MountTargetId']
      ip = mount_response['IpAddress']
      print('%-16s %-16s %-16s %-16s' % (zone, ip, id_, state,))


def terminate_tmux(_):
  """Script to clean-up tmux sessions.

  Kills every session reported by 'tmux ls' except 'tensorboard', 'jupyter'
  and 'dropbox'. The positional argument is ignored.
  """
  protected_sessions = ('tensorboard', 'jupyter', 'dropbox')

  for line in _run_shell('tmux ls'):
    # The session name is the text before the first ':' of each line.
    session_name = line.split(':', 1)[0]

    if session_name in protected_sessions:
      print("Skipping " + session_name)
      continue
    print("Killing " + session_name)
    # Quote the name: it is interpolated into a shell command line and may
    # contain spaces or shell metacharacters.
    _run_shell('tmux kill-session -t ' + shlex.quote(session_name))


def nano(*_unused_args):
  """Bring up t2.nano instance."""
  task_options = {'name': 'shell', 'instance_type': 't2.nano'}
  ncluster.make_task(**task_options)


def cmd(user_cmd):
  """Run a command over ssh on the current user's first looked-up instance.

  The remote command inherits the local terminal, so its output appears on stdout.

  Raises:
    AssertionError: if the current user has no instances.
  """
  candidates = u.lookup_instances(limit_to_current_user=True)
  assert candidates, (f"{u.get_username()} doesn't have an instances to connect to. Use 'ncluster nano'"
                      f" to bring up a small instance.")
  target = candidates[0]
  keypair_fn = u.get_keypair_fn()
  destination = f"{u.get_aws_username(target)}@{target.public_ip_address}"
  remote_invocation = f"ssh -t -i {keypair_fn} -o StrictHostKeyChecking=no {destination} {user_cmd}"
  os.system(remote_invocation)


def cat(user_cmd):
  """Run 'cat' with the given arguments on the remote instance via cmd()."""
  remote_command = 'cat ' + user_cmd
  cmd(remote_command)


def ls_(user_cmd):
  """Run 'ls' with the given arguments on the remote instance via cmd()."""
  remote_command = 'ls ' + user_cmd
  cmd(remote_command)


def cleanup_placement_groups(*_args):
  """Try to delete every placement group, reporting success or failure per group.

  Deletion is best-effort: a failure for one group is reported and the
  remaining groups are still processed.
  """
  print("Deleting all placement groups")
  # TODO(y): don't delete groups that have currently stopped instances
  client = u.get_ec2_client()
  for group in client.describe_placement_groups().get('PlacementGroups', []):
    name = group['GroupName']
    sys.stdout.write(f"Deleting {name} ... ")
    sys.stdout.flush()
    try:
      client.delete_placement_group(GroupName=name)
    except Exception as e:
      # Keep going, but surface the reason instead of swallowing it.
      print(f"failed ({e})")
    else:
      print("success")


# ncluster launch --image_name=dlami23-efa --instance_type=c5.large --name=test
def launch(args_str):
  """Launch a task described by a command-line style argument string.

  Recognized flags: --name, --image_name, --instance_type, --disk_size.

  Args:
    args_str: flags as one string; split with shlex before parsing.

  Returns:
    The result of ncluster.make_task called with the parsed options.
  """
  import argparse

  option_specs = (
    ('--name', dict(type=str, default='ncluster_launch', help="instance name")),
    # can also use --image_name='Deep Learning AMI (Amazon Linux) Version 23.0' # cybertronai01
    ('--image_name', dict(type=str, default='Deep Learning AMI (Ubuntu) Version 23.0')),
    ('--instance_type', dict(type=str, default='c5.large', help="type of instance")),
    ('--disk_size', dict(type=int, default=0,
                         help="size of disk in GBs. If 0, use default size for the image")),
  )
  parser = argparse.ArgumentParser()
  for flag, options in option_specs:
    parser.add_argument(flag, **options)

  argv = shlex.split(args_str)
  task_options = vars(parser.parse_args(argv))
  return ncluster.make_task(**task_options)


def fix_default_security_group(_):
  """Allows ncluster and ncluster_nd security groups to exchange traffic with each other.

  Raises:
    AssertionError: if authorizing an ingress rule fails for any reason other
      than the rule already existing.
  """

  def authorize(group, rule):
    """Add one ingress rule to `group`, tolerating an already-existing rule."""
    try:
      group.authorize_ingress(IpPermissions=[rule])
    except Exception as e:
      # The error code lives on the exception's `response` payload, not on
      # the return value of the call that failed.
      error_response = getattr(e, 'response', None)
      error_code = None
      if isinstance(error_response, dict):
        error_code = error_response.get('Error', {}).get('Code')
      if error_code == 'InvalidPermission.Duplicate':
        print("Warning, got " + str(e))
      else:
        # Raised explicitly (instead of `assert False`) so it survives python -O.
        raise AssertionError("Failed while authorizing ingress with " + str(e)) from e

  def peer(current, other):
    """allow current group to accept all traffic from other group"""

    groups = u.get_security_group_dict()
    current_group = groups[current]
    other_group = groups[other]
    # (protocol, FromPort, ToPort): -1/-1 for icmp, 0-65535 for tcp and udp
    port_ranges = [('icmp', -1, -1), ('tcp', 0, 65535), ('udp', 0, 65535)]
    for protocol, from_port, to_port in port_ranges:
      rule = {'FromPort': from_port,
              'IpProtocol': protocol,
              'IpRanges': [],
              'PrefixListIds': [],
              'ToPort': to_port,
              'UserIdGroupPairs': [{'GroupId': other_group.id}]}
      authorize(current_group, rule)

  group1 = u.get_security_group_name()
  group2 = u.get_security_group_nd_name()

  peer(group1, group2)
  peer(group2, group1)


def keys(_):
  """Print the public key from util.get_public_key() with sharing instructions.

  Per the original note, ssh-keygen is run if necessary to obtain the key.
  """
  public_key = util.get_public_key()
  instructions = ("Your public key is below. Append all of your your team-members  public keys to "
                  "NCLUSTER_AUTHORIZED_KEYS env var separated by ; ie \n"
                  "NCLUSTER_AUTHORIZED_KEYS=<key1>;<key2>;<key3>\n")
  print(instructions)
  print(public_key)


def lookup_image(image_id: str) -> Image:
  """Look up an image by id (e.g. 'ami-0cc96feef8c6bbff3'), print its name and return it.

  Raises:
    AssertionError: if the id lacks the 'ami-' prefix, or the lookup yields
      zero or more than one image.
  """
  assert image_id.startswith('ami-')
  ec2 = u.get_ec2_resource()
  matches = list(ec2.images.filter(ImageIds=[image_id]))
  assert matches, f"No images found with id={image_id}"
  assert len(matches) == 1, f"Multiple images found with id={image_id}: {','.join(i.name for i in matches)}"
  exact = [candidate for candidate in matches if candidate.id == image_id]
  image = exact[0]
  print(image.name)
  return image


def grow_disks(fragment: str, target_size_gb=500):
  """Grow every volume of the matching instance to `target_size_gb`.

  Args:
    fragment: name fragment passed to `u.lookup_instance`.
    target_size_gb: desired volume size in GB (default 500).

  Raises:
    AssertionError: if a modify_volume call does not return a good response.
  """
  instance = u.lookup_instance(fragment)
  client = u.get_ec2_client()

  for vol in instance.volumes.all():
    volume: Volume = vol
    # "Grow" only: a volume already at or above the target is left alone
    # rather than sending a no-op or shrink request that could abort the loop.
    if volume.size >= target_size_gb:
      print("Skipping %s, already %s GB" % (volume.id, volume.size))
      continue
    print("Growing %s to %s" % (volume.id, target_size_gb))
    response = client.modify_volume(VolumeId=volume.id, Size=target_size_gb)
    assert u.is_good_response(response), response


def disks(fragment):
  """Print device, size, type and id for each volume of matching running/stopped instances."""
  matches = u.lookup_instances(fragment, valid_states=('running', 'stopped'))
  for instance in matches:
    print()
    print(f"Disks on instance '{u.get_name(instance)}'")
    for volume in instance.volumes.all():
      # The device name is read from the volume's first attachment record.
      first_attachment = volume.attachments[0]
      device = first_attachment['Device']
      print(f"{device:>10s} {volume.size:>10d} {volume.volume_type:>10s} {volume.id:>30s} ")


def fixkeys(_):
  """Delete the current user's AWS keypair and its local .pem file.

  If non-terminated instances were launched with that keypair, they are listed
  and confirmation is requested first, since they become inaccessible. The
  positional argument is ignored.
  """
  key_name = u.get_keypair_name()
  pairs = u.get_keypair_dict()
  if key_name not in pairs:
    print(f"Default keypair {key_name} does not exist, returning")
    return
  keypair = pairs[key_name]

  print(f"Deleting current user keypair {key_name}")
  ec2 = u.get_ec2_resource()
  # instance.state is a dict with a 'Name' entry, so compare that entry;
  # comparing the dict itself to 'terminated' never matches.
  instance_list = [instance for instance in ec2.instances.all()
                   if instance.key_name == key_name and instance.state['Name'] != 'terminated']
  if instance_list:
    print("Warning, after deleting keypair, the following instances will be no longer accessible:")
    for i in instance_list:
      print(u.get_name(i), i.id)
    answer = input("Proceed? (y/N) ")
  else:
    # nothing depends on the keypair, no confirmation needed
    answer = "y"
  if answer.lower() == 'y':
    keypair_fn = u.get_keypair_fn()
    if os.path.exists(keypair_fn):
      print(f"Deleting local .pem file '{keypair_fn}'")
      # Quote the path so spaces or shell metacharacters cannot alter the command.
      os.system(f'sudo rm -f {shlex.quote(keypair_fn)}')
    print(f"Deleting AWS keypair '{keypair.name}'")
    keypair.delete()


def efs_sync(_):
  """Starts a daemon to sync local /ncluster/sync with remote ncluster sync.

  Creates the local /ncluster/sync directory if needed, picks the current
  user's first instance (launching a t2.nano when there is none) and runs
  nsync against it. The positional argument is ignored.
  """
  print("Syncing local /ncluster/sync to remote /ncluster/sync")
  if not os.path.exists('/ncluster/sync'):
    print("Local /ncluster/sync doesn't exist, creating")
    if not os.path.exists('/ncluster'):
      os.system('sudo mkdir /ncluster')
      # The directory was just created with sudo, so changing its owner needs
      # sudo as well. `whoami` is expanded by the invoking shell before sudo
      # runs, so it still names the current user.
      os.system('sudo chown `whoami` /ncluster')
    os.system('mkdir /ncluster/sync')

  instances = u.lookup_instances(limit_to_current_user=True)
  if instances:
    instance = instances[0]
    print(f"Found {len(instances)} instances owned by {u.get_username()}, using {u.get_name(instance)} for syncing")
  else:
    print(f"Found no instances by {u.get_username()}, Launching t2.nano instance to do the sync.")
    task = ncluster.make_task(name='shell', instance_type='t2.nano')
    instance = task.instance

  os.system(f'cd /ncluster/sync && nsync -m {u.get_name(instance)} -d /ncluster/sync')


# Dispatch table used by main(): maps the command name given on the command
# line to the function that implements it. Several names may map to the same
# handler (e.g. '/etc/hosts' and 'hosts').
COMMANDS = {
  'ls': ls,
  'ssh': ssh,
  'ssh_': old_ssh,
  'mosh': mosh,
  'kill': kill,
  'reallykill': reallykill,
  'stop': stop,
  'reallystop': reallystop,
  'start': start,
  'efs': efs,
  'cat': cat,
  'ls_': ls_,
  'nano': nano,
  'cmd': cmd,
  '/etc/hosts': etchosts,
  'hosts': etchosts,
  'terminate_tmux': terminate_tmux,
  'cleanup_placement_groups': cleanup_placement_groups,
  'lookup_image': lookup_image,
  'reboot': reboot,
  'launch': launch,
  'fix_default_security_group': fix_default_security_group,
  'keys': keys,
  'connect': connect,
  'connectm': connectm,
  'grow_disks': grow_disks,
  'disks': disks,
  'fixkeys': fixkeys,
  'efs_sync': efs_sync,
}


def main():
  """CLI entry point: dispatch the first argument to its COMMANDS handler.

  With no command, 'ls' is run. 'help' prints every command with its
  docstring. Remaining arguments are shell-quoted and joined into a single
  string passed as the handler's only positional argument.

  Raises:
    AssertionError: if the command is not in COMMANDS.
  """
  print(f"Region ({u.get_region()}) $USER ({u.get_username()}) account ({u.get_account_number()}:{u.get_account_name()})")
  mode = sys.argv[1] if len(sys.argv) >= 2 else 'ls'

  if mode == 'help':
    for k, v in COMMANDS.items():
      if v.__doc__:
        print(f'{k}\t{v.__doc__}')
      else:
        print(k)
    return

  if mode not in COMMANDS:
    # Raised explicitly (instead of `assert False`) so the check is not
    # stripped under `python -O`.
    raise AssertionError(
      f"unknown command '{mode}', available commands are {', '.join(str(a) for a in COMMANDS.keys())}")
  COMMANDS[mode](' '.join([shlex.quote(arg) for arg in sys.argv[2:]]))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
  main()
