#!python
"""
S3 download plugin for CRDS.

Exit status:
0 - success
1 - invalid argument, unrecognized file type, or missing environment variable
10 - failed file size verification
11 - failed checksum verification
other codes - aws-cli failure.  See https://docs.aws.amazon.com/cli/latest/topic/return-codes.html

As of 2021-04-23, copies of this script are maintained in the crds
and caldp repositories.  Please ensure that any bug fixes make it into
both!
"""
import argparse
import os
import random
import subprocess
import time
import sys

from crds.core.utils import checksum
from crds.client.api import get_default_observatory
from crds.core import config 

# Exit statuses for post-download verification failures (see "Exit status"
# in the module docstring).
BAD_SIZE_STATUS: int = 10      # size on disk differs from --file-size
BAD_CHECKSUM_STATUS: int = 11  # SHA-1 differs from --file-sha1sum

try:
    import boto3
    import awscli
except ImportError:
    boto3 = None
    awscli = None

def check_aws_imports():
    """Tell whether the optional AWS dependencies are usable.

    The guarded import near the top of this module sets both ``boto3`` and
    ``awscli`` to None when either import fails, so testing ``boto3`` alone
    is sufficient.

    Returns:
        True if boto3 (and therefore awscli) imported successfully, else False.
    """
    if boto3 is None:
        return False
    return True

def create_s3_path(s3_prefix: str, kind: str, fname: str, obs: str) -> str:
    """Build an S3 object path for ``fname`` using the standard CRDS cache layout.

    The result is ``<s3_prefix>/<kind>/<obs>/<fname>``. Trailing slashes on
    ``s3_prefix`` collapse into one separator. With an empty prefix the path
    begins directly at ``<kind>``, with no leading slash.
    """
    head = s3_prefix.rstrip("/")
    if s3_prefix:
        # A non-empty prefix always gets exactly one separator before `kind`.
        head = f"{head}/"
    return f"{head}{kind}/{obs}/{fname}"


def get_file_attrs(fname):
    """Classify a CRDS file name and find its observatory and cache subdirectory.

    Args:
        fname: bare CRDS file name (no directory component).

    Returns:
        ``(obs, instr, kind)``:
        - ``kind`` is ``"mappings"`` for .pmap/.imap/.rmap files, or
          ``"references"`` for .asdf/.fits files.
        - ``obs`` is the observatory. For mappings it comes from
          ``config.mapping_to_observatory``. For references it is the file-name
          prefix when that prefix is ``roman`` or ``jwst``, and ``hst`` otherwise.
        - ``instr`` is the name of the directory that contains the path
          returned by ``config.locate_reference`` (empty for mappings).

    Exits with status 1 (after printing a message) for any other file type.
    """
    # Only the three CRDS mapping types are accepted. A bare endswith("map")
    # would also route names like "foo.bitmap" into mapping_to_observatory().
    if fname.endswith((".pmap", ".imap", ".rmap")):
        obs = config.mapping_to_observatory(fname)
        return obs, "", "mappings"

    if fname.split(".")[-1] in ("asdf", "fits"):
        # Assumes only references and mappings are fetched, not config files.
        prefix = fname.split("_")[0]
        obs = prefix if prefix in ("roman", "jwst") else "hst"
        ref_path = config.locate_reference(fname, obs)
        instr = str(os.path.basename(os.path.dirname(ref_path)))
        return obs, instr, "references"

    print(f"Unrecognized file type for `source`: {fname} - valid types are mappings (.pmap, .imap, .rmap) and references (.asdf, .fits)")
    sys.exit(1)


def format_uris(**kwargs):
    """Resolve the S3 source URI and the local destination path for a download.

    Keyword Args:
        source: bare CRDS file name, or a full ``s3://`` URI to the file.
        destination: local target. If it refers to the CRDS cache root
            (CRDS_PATH, or ``config.get_crds_path()`` when that is unset), the
            file goes into the cache layout
            ``<crds_path>/<kind>/<obs>/[<instr>/]<fname>``, with the
            instrument subdirectory only in "instrument" subdir mode. If it is
            any other existing directory, the file goes directly inside it.
            Otherwise it is used as the target file path.
        Any other keyword arguments are ignored.

    Returns:
        ``(s3_uri, dest)``. The parent directory of ``dest`` is created if
        it does not already exist.

    Exits with status 1 if ``source`` is not an ``s3://`` URI and the
    CRDS_S3_BUCKET environment variable is not set.
    """
    source, destination = kwargs.pop("source"), kwargs.pop("destination")
    fname = source.split("/")[-1]
    obs, instr, kind = get_file_attrs(fname)

    if source.startswith("s3://"):
        s3_uri = source
    else:
        s3_bucket = os.environ.get("CRDS_S3_BUCKET")
        if not s3_bucket:
            print("Please set the CRDS_S3_BUCKET environment variable or pass the full S3 URI into `source` arg")
            sys.exit(1)
        s3_pfx = "/roman/crds" if obs == "roman" else ""
        s3_key = create_s3_path(s3_pfx, kind, fname, obs)
        s3_uri = f"s3://{s3_bucket}/{s3_key.lstrip('/')}"

    crds_path = os.environ.get("CRDS_PATH", config.get_crds_path())
    if os.path.abspath(destination) == os.path.abspath(crds_path):
        # Compare normalized paths so "/cache/" and "/cache" match. Use the
        # cache layout even if the cache root does not exist yet; otherwise a
        # fresh CRDS_PATH would be written as a single file.
        if instr and config.get_crds_ref_subdir_mode(obs) == "instrument":
            subdir = f"{instr}/"
        else:
            subdir = ""
        dest = f"{crds_path.rstrip('/')}/{kind}/{obs}/{subdir}{fname}"
    elif os.path.isdir(destination):
        dest = f"{destination.rstrip('/')}/{fname}"
    else:
        dest = destination

    dest_dir = os.path.dirname(dest)
    if dest_dir:
        # A bare file name has no parent component; os.makedirs("") would raise.
        os.makedirs(dest_dir, exist_ok=True)
    return s3_uri, dest


def parse_args():
    """Parse the crds_s3_get command line from ``sys.argv``.

    Returns:
        argparse.Namespace with ``source``, ``destination`` (defaults to the
        CRDS_PATH environment variable, or None), ``file_size`` (int or None),
        ``file_sha1sum`` (str or None) and ``max_retries`` (int, default 3).
    """
    cli = argparse.ArgumentParser("crds_s3_get", description="S3 download plugin for CRDS")
    cli.add_argument(
        "source",
        help="filename to download or full S3 URI to the file",
    )
    cli.add_argument(
        "-d", "--destination",
        default=os.environ.get("CRDS_PATH", None),
        help="Destination path on local filesystem",
    )
    cli.add_argument(
        "-s", "--file-size",
        type=int,
        default=None,
        help="Expected file size in bytes",
    )
    cli.add_argument(
        "-c", "--file-sha1sum",
        default=None,
        help="Expected file SHA-1 checksum",
    )
    cli.add_argument(
        "-r", "--max-retries",
        type=int,
        default=3,
        help="Maximum number of retries on download failure",
    )
    namespace = cli.parse_args()
    return namespace


def _run_aws_s3_cp(src, dest):
    """Run a single ``aws s3 cp`` from *src* to *dest*, capturing stdout+stderr as text."""
    return subprocess.run(
        ["aws", "s3", "cp", "--no-progress", src, dest],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        encoding="utf-8",
    )


def _download_with_retries(src, dest, max_retries):
    """Copy *src* to *dest*, retrying up to *max_retries* times on a nonzero exit.

    Returns the CompletedProcess of the last attempt made. A negative
    *max_retries* is treated as 0, meaning a single attempt.
    """
    result = _run_aws_s3_cp(src, dest)
    for retry in range(max(0, max_retries)):
        if result.returncode == 0:
            break
        # Exponential backoff (1s, 2s, 4s, ... capped at 30s) plus up to 1s of
        # random jitter, so concurrent callers do not retry in lockstep.
        delay = min(30, 2 ** min(retry, 5)) + random.uniform(0, 1)
        print(
            f"crds_s3_get: download of '{src}' failed with return code "
            f"{result.returncode}; retry {retry + 1} of {max_retries} in {delay:.1f}s",
            file=sys.stderr,
        )
        time.sleep(delay)
        result = _run_aws_s3_cp(src, dest)
    return result


def _verify_download(source, dest, file_size, file_sha1sum):
    """Check the downloaded file *dest* against the expected size and SHA-1.

    Each check is skipped when its expected value is None. On a mismatch,
    the file is deleted and the process exits with BAD_SIZE_STATUS or
    BAD_CHECKSUM_STATUS. *source* is used only in the messages.
    """
    if file_size is not None:
        downloaded_size = os.path.getsize(dest)
        if downloaded_size != file_size:
            print(f"crds_s3_get: '{source}' failed file size check.  Expected: {file_size} Received: {downloaded_size}")
            os.unlink(dest)
            sys.exit(BAD_SIZE_STATUS)

    if file_sha1sum is not None:
        downloaded_sha1sum = checksum(dest)
        if downloaded_sha1sum != file_sha1sum:
            print(f"crds_s3_get: '{source}' failed checksum.  Expected: {file_sha1sum} Received: {downloaded_sha1sum}")
            os.unlink(dest)
            sys.exit(BAD_CHECKSUM_STATUS)


def main():
    """Command-line entry point: download one CRDS file from S3 and optionally verify it.

    Exit status is 1 for usage or environment errors. It is BAD_SIZE_STATUS or
    BAD_CHECKSUM_STATUS for a verification failure. If the download still
    fails after all retries, it is the last aws-cli return code.

    Raises:
        ImportError: if boto3/awscli could not be imported.
    """
    if not check_aws_imports():
        raise ImportError(
            "You must install awscli and boto3 for the crds_s3_get script to work. "
            "AWS dependencies for CRDS can be installed via `pip install crds[aws]`"
        )
    args = parse_args()
    if args.destination is None:
        print(
            "Destination path defaults to CRDS_PATH but no value was set. \n"
            "Please set the CRDS_PATH variable e.g. `export CRDS_PATH=path/to/local/cache` \n"
            "or pass an absolute path on local disk where you want the file to be downloaded: \n"
            "`crds_s3_get myfile -d abs/path/to/download`"
        )
        sys.exit(1)

    # format_uris pops from its own kwargs dict, so `args` is left untouched.
    src, dest = format_uris(**vars(args))

    result = _download_with_retries(src, dest, args.max_retries)
    if result.returncode != 0:
        # Escape newlines so the whole aws-cli output stays on one log line.
        output = "\\n".join(result.stdout.strip().splitlines())
        print(f"crds_s3_get: Failed to download '{args.source}' with return code {result.returncode}: {output}", file=sys.stderr)
        sys.exit(result.returncode)

    # Verify the file actually written (`dest`). args.destination is often
    # the cache or target directory, not the downloaded file itself.
    _verify_download(args.source, dest, args.file_size, args.file_sha1sum)


if __name__ == "__main__":
    main()
