#!/usr/bin/env python3

"""
Command line utility to invoke the functionality of the smb* jobs.

Relies on the lava connector subsystem.

"""

import argparse
import logging
import os
import sys
from shutil import rmtree
from tempfile import mkdtemp

import boto3

from lava.connection import get_smb_connection
from lava.lavacore import LOGNAME, LavaError, get_realm_info
from lava.lib.aws import s3_download, s3_split, s3_upload
from lava.lib.logging import setup_logging
from lava.lib.smb import SMBOperationError, smb_mkdirs
from lava.version import __version__

__author__ = 'Murray Andrews'

# Program name as invoked: base name of argv[0] with any extension removed.
# Used as the argparse prog name, the default log tag, the logging prefix,
# the temporary directory prefix and the prefix of top-level error messages.
PROG = os.path.splitext(os.path.basename(sys.argv[0]))[0]
# Logger shared with the lava libraries (LOGNAME comes from lava.lavacore).
LOG = logging.getLogger(LOGNAME)
# Default logging level; used before argument parsing and as the -l/--level default.
LOGLEVEL = 'info'


# ------------------------------------------------------------------------------
def process_cli_args() -> argparse.Namespace:
    """
    Process the command line arguments.

    A sub-command (``put`` or ``get``) is mandatory; the selected handler is
    stored in ``args.func`` and its name in ``args.command``. If no realm is
    given with -r/--realm, it is taken from the LAVA_REALM environment
    variable.

    :return:    The args namespace.
    :raises SystemExit: On invalid arguments, a missing sub-command or a
                missing realm (via argparse's error handling).
    """

    argp = argparse.ArgumentParser(prog=PROG, description='Operate on SMB file shares.')

    argp.add_argument('--profile', action='store', help='As for AWS CLI.')

    argp.add_argument(
        '-c', '--conn-id', dest='conn_id', required=True, action='store', help='Lava connection ID.'
    )

    argp.add_argument(
        '-r',
        '--realm',
        action='store',
        help=(
            'Lava realm name. If not specified, the environment variable LAVA_REALM must be set.'
        ),
    )

    argp.add_argument(
        '--tmpdir',
        action='store',
        help=(
            'Place temporary files in the specified directory.'
            ' This directory must not exist and will be deleted on exit.'
        ),
    )

    argp.add_argument('-v', '--version', action='version', version=__version__)

    # ------------------------------
    # Logging options

    logp = argp.add_argument_group('logging arguments')
    logp.add_argument(
        '--no-colour',
        '--no-color',
        dest='colour',
        action='store_false',
        default=True,
        help='Don\'t use colour in information messages.',
    )

    logp.add_argument(
        '-l',
        '--level',
        metavar='LEVEL',
        default=LOGLEVEL,
        help=(
            'Print messages of a given severity level or above. The standard logging level names'
            ' are available but debug, info, warning and error are most useful.'
            f' The Default is {LOGLEVEL}.'
        ),
    )

    logp.add_argument(
        '--log',
        action='store',
        help='Log to the specified target. This can be either a file'
        ' name or a syslog facility with an @ prefix (e.g. @local0).',
    )

    logp.add_argument(
        '--tag',
        action='store',
        default=PROG,
        help=f'Tag log entries with the specified value. The default is {PROG}.',
    )

    # ----------------------------------------
    # Sub parsers

    # A sub-command is mandatory. Without required=True, omitting it leaves
    # args.func unset and main() fails with an obscure AttributeError rather
    # than a usage error. A dest is needed so argparse can name the missing
    # argument in its error message.
    subp = argp.add_subparsers(dest='command', required=True)

    # . . . . . . . . . . . . . . . . . . . .
    # "put" command

    c_put = subp.add_parser('put', help='Copy a file to an SMB file share.')
    c_put.set_defaults(func=cmd_put)

    c_put.add_argument(
        '-m',
        '--mkdir',
        action='store_true',
        help='Create the target directory if it doesn\'t exist',
    )

    c_put.add_argument(
        'file',
        action='store',
        help='Source file. Values starting with s3:// will be copied from S3.',
    )

    c_put.add_argument(
        'smb_path',
        action='store',
        metavar='SMB-path',
        help='Target location. Must be in the form share-name:path.',
    )

    # . . . . . . . . . . . . . . . . . . . .
    # "get" command

    c_get = subp.add_parser('get', help='Copy a file from an SMB file share.')
    c_get.set_defaults(func=cmd_get)

    c_get.add_argument(
        '-k',
        '--kms-key-id',
        dest='kms_key_id',
        action='store',
        help='AWS KMS key to use for uploading data to S3.',
    )

    c_get.add_argument(
        'smb_path',
        action='store',
        metavar='SMB-path',
        help='Source location. Must be in the form share-name:path.',
    )

    c_get.add_argument(
        'file',
        action='store',
        help='Target file. Values starting with s3:// will be copied to S3.',
    )

    args = argp.parse_args()
    if not args.realm:
        try:
            args.realm = os.environ['LAVA_REALM']
        except KeyError:
            argp.error(
                'Lava realm must be specified via -r, --realm or LAVA_REALM environment variable.'
            )

    return args


# ---------------------------------------------------------------------------------------
def cmd_get(args: argparse.Namespace) -> None:
    """
    Get a file from an SMB share.

    The namespace argument must contain the following:

        - realm
            Realm name.

        - conn_id
            A connection ID for the source SMB server.

        - smb_path
            Source path within the remote file share. It must be in the form
            `share:path`, with a non-empty share name and path.

        - file
            Target file. If it starts with `s3://` it is assumed to be an object in
            S3; the file is staged in the temporary directory and then uploaded.
            Otherwise it is a local file. A relative local path is resolved against
            the current working directory.

        - kms_key_id
            AWS KMS key to use for uploading data to S3 (may be None).

        - profile
            AWS profile name used to create the boto3 session (may be None).

        - tmpdir
            Temporary directory used to stage the file when the target is in S3.

    :param args:      The argparse arguments namespace.

    :rtype:         None
    :raises LavaError: If the SMB path is malformed, the source is a directory
                    or an SMB operation fails.
    """

    # basedir is not defined by this CLI's parser, so read it defensively; it
    # is only consulted when tmpdir is unset.
    filedir = args.tmpdir or getattr(args, 'basedir', None)

    # Require exactly one colon with a non-empty share name and path.
    smb_parts = args.smb_path.split(':')
    if len(smb_parts) != 2 or not all(smb_parts):
        raise LavaError(f'{args.smb_path}: Bad path - must be in the form share:path')
    share_name, path = smb_parts

    aws_session = boto3.Session(profile_name=args.profile)
    realm_table = aws_session.resource('dynamodb').Table('lava.realms')

    # ----------------------------------------
    # Get realm info

    realm_info = get_realm_info(args.realm, realm_table)
    LOG.debug(f'Realm info: {realm_info}')

    # ----------------------------------------
    # Prepare destination file

    dst_file = args.file

    if dst_file.startswith('s3://'):
        # Download to a staging file first; it is uploaded to S3 afterwards.
        dst_bucket, dst_key = s3_split(dst_file)
        dst_local = os.path.join(filedir, os.path.basename(dst_key))
    elif os.path.isabs(dst_file):
        # Local absolute path
        dst_bucket, dst_key = None, None
        dst_local = dst_file
    else:
        # Local relative path - resolved against the current working directory.
        dst_bucket, dst_key = None, None
        dst_local = os.path.abspath(dst_file)

    # ----------------------------------------
    # Get file from SMB file share.

    LOG.debug(f'Getting SMB connection for conn_id {args.conn_id}')
    conn = get_smb_connection(args.conn_id, args.realm, aws_session=aws_session)

    try:
        LOG.debug(f'Checking {share_name}:{path} exists and is normal file')

        if conn.get_attributes(share_name, path).is_directory:
            raise LavaError(f'{share_name}:{path}: is a directory')

        LOG.debug(f'Downloading {path} to {dst_local}')
        with open(dst_local, 'wb') as fp:
            conn.retrieve_file(share_name, path, fp)

    except SMBOperationError as e:
        raise LavaError(e, data={'NTSTATUS': e.ntstatus}) from e

    finally:
        try:
            LOG.debug(f'Closing SMB connection for conn_id {args.conn_id}')
            conn.close()
        except Exception as e:
            # Log rather than raise: raising inside finally would replace any
            # exception already propagating from the transfer above.
            LOG.error(f'Could not close SMB connection: {e}')

    # ----------------------------------------
    # Push it back to S3
    if dst_bucket:
        LOG.debug(f'Uploading {dst_local} to s3://{dst_bucket}/{dst_key}')
        s3_upload(
            bucket=dst_bucket,
            key=dst_key,
            filename=dst_local,
            kms_key=args.kms_key_id,
            s3_client=aws_session.client('s3'),
        )


# ---------------------------------------------------------------------------------------
def cmd_put(args: argparse.Namespace) -> None:
    """
    Put a file into an SMB share.

    The namespace argument must contain the following:

    - realm
        Realm name.

    - conn_id
        A connection ID for the target SMB server.

    - smb_path
        Target path within the remote file share. It must be in the form
        `share:path`, with a non-empty share name and path.

    - file
        Source file. If it starts with `s3://` it is assumed to be an object in
        S3, which is first downloaded into the temporary directory. Otherwise it
        is a local file. A relative local path is resolved against the current
        working directory.

    - mkdir
        If True, the target directory will be created if it doesn't exist.
        Default is False.

    - profile
        AWS profile name used to create the boto3 session (may be None).

    - tmpdir
        Temporary directory used to stage the file when the source is in S3.

    :param args:        The argparse arguments namespace.

    :raises LavaError:  If the SMB path is malformed or an SMB operation fails.
    """

    # basedir is not defined by this CLI's parser, so read it defensively; it
    # is only consulted when tmpdir is unset.
    # PyCharm is a bit clueless at times.
    # noinspection PyTestUnpassedFixture
    filedir = args.tmpdir or getattr(args, 'basedir', None)

    # Require exactly one colon with a non-empty share name and path.
    smb_parts = args.smb_path.split(':')
    if len(smb_parts) != 2 or not all(smb_parts):
        raise LavaError(f'{args.smb_path}: Bad path - must be in the form share:path')
    share_name, path = smb_parts

    aws_session = boto3.Session(profile_name=args.profile)
    realm_table = aws_session.resource('dynamodb').Table('lava.realms')

    # ----------------------------------------
    # Get realm info

    realm_info = get_realm_info(args.realm, realm_table)
    LOG.debug(f'Realm info: {realm_info}')

    # ----------------------------------------
    # Prepare source file

    src_file = args.file

    if src_file.startswith('s3://'):
        # Stage the S3 object locally before copying it to the share.
        src_bucket, src_key = s3_split(src_file)
        src_local = os.path.join(filedir, os.path.basename(src_key))
        s3_download(src_bucket, src_key, src_local, aws_session.client('s3'))
    elif os.path.isabs(src_file):
        src_local = src_file
    else:
        # Local relative path - resolved against the current working directory.
        src_local = os.path.abspath(src_file)

    # ----------------------------------------
    # Copy to SMB file share.

    LOG.debug(f'Getting SMB connection for conn_id {args.conn_id}')
    conn = get_smb_connection(args.conn_id, args.realm, aws_session=aws_session)

    try:
        if args.mkdir:
            smb_dir = os.path.split(path)[0]
            if smb_dir:
                LOG.debug(f'Creating {share_name}:{smb_dir}')
                smb_mkdirs(conn, share_name, smb_dir)

        with open(src_local, 'rb') as fp:
            LOG.debug(f'Uploading {src_local} to {share_name}:{path}')
            conn.store_file(share_name, path, fp)

    except SMBOperationError as e:
        raise LavaError(e, data={'NTSTATUS': e.ntstatus}) from e

    finally:
        try:
            LOG.debug(f'Closing SMB connection for conn_id {args.conn_id}')
            conn.close()
        except Exception as e:
            # Log rather than raise so a close failure cannot mask an
            # exception already propagating from the transfer above.
            LOG.error(f'Could not close SMB connection: {e}')


# ---------------------------------------------------------------------------------------
def main() -> int:
    """
    Do the business.

    Parse the command line, configure logging, prepare the temporary working
    directory and dispatch to the selected sub-command handler. The temporary
    directory is removed on exit whether or not the command succeeds; failure
    to remove it is logged as a warning rather than raised.

    :return:    status (0 on success; failures propagate as exceptions).
    """

    # Basic logging first so problems during argument processing are reported.
    setup_logging(LOGLEVEL, name=LOGNAME, prefix=PROG)
    args = process_cli_args()

    setup_logging(args.level, name=LOGNAME, target=args.log, colour=args.colour, prefix=args.tag)

    if args.tmpdir:
        # makedirs raises if the directory already exists. That is deliberate:
        # the directory is deleted wholesale on exit, so it must be ours.
        os.makedirs(args.tmpdir)
    else:
        args.tmpdir = mkdtemp(prefix=f'{PROG}.')

    try:
        args.func(args)
    finally:
        try:
            # No ignore_errors: a failed removal must reach the warning below
            # instead of being silently reported as "Removed".
            rmtree(args.tmpdir)
        except Exception as e:
            LOG.warning(f'Cannot remove {args.tmpdir}: {e}')
        else:
            LOG.debug(f'Removed {args.tmpdir}')

    return 0


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Uncomment for debugging (lets exceptions escape with a full traceback)
    # sys.exit(main())  # noqa: ERA001
    try:
        # sys.exit rather than the site-injected exit(), which is not
        # guaranteed to exist (e.g. python -S or frozen interpreters).
        sys.exit(main())
    except Exception as ex:
        print(f'{PROG}: {ex}', file=sys.stderr)
        sys.exit(1)
    except KeyboardInterrupt:
        print('Interrupt', file=sys.stderr)
        sys.exit(2)
