#!/usr/bin/env python

# TODO: https://cloud.google.com/storage/docs/json_api/v1/how-tos/batch

import os, sys, json, textwrap, readline, time, datetime, logging
from argparse import Namespace

import click, tweak, jwt, requests

from gs.util.printing import page_output, tabulate, GREEN, BLUE, BOLD
from gs.version import __version__
from gs.util.exceptions import NoServiceCredentials

# Root command group. Subcommands are attached below via cli.add_command(); the docstring is
# shown verbatim by click as the top-level help text.
@click.group()
@click.version_option(version=__version__)
def cli():
    """gs is a minimalistic CLI for Google Cloud Storage."""

@click.command()
def configure():
    """Set gs config options, including the API key."""
    # The service account key is accepted either as a path to the key file (typed, pasted or
    # drag-and-dropped) or as the pasted JSON contents, on one line or spread over many.
    msg = ("Please open " + BOLD("https://console.cloud.google.com/iam-admin/serviceaccounts") + ", create a service "
           "account and download its private key. The service account should have a role with Google Storage access. "
           "Drag & drop the key file into this terminal window, or paste the file location or JSON contents below.")
    print("\n".join(textwrap.wrap(msg, 120)))
    prompt = "Service account key file path or contents: "
    key, pasted = None, []
    while key is None:
        line = input(prompt).strip()
        if not line:
            if pasted:
                # A blank line ends pasted contents; json.loads reports it if they are incomplete.
                key = json.loads("\n".join(pasted))
            # Blank lines before any input are ignored.
            continue
        if not pasted and not line.startswith("{"):
            # Anything that does not open a JSON object is taken as the key file's path.
            with open(os.path.expanduser(line)) as fh:
                key = json.load(fh)
            continue
        pasted.append(line)
        try:
            # Stop as soon as the pasted text forms a complete JSON document; this handles
            # single-line JSON as well as nested objects whose inner "}" lines would end input early.
            key = json.loads("\n".join(pasted))
        except ValueError:
            prompt = ""  # incomplete so far: keep reading continuation lines without re-prompting
    config.service_credentials = key
    config.save()
    print("Key configuration saved.")

cli.add_command(configure)

class GSClient:
    """Minimal authenticated client for the Google Cloud Storage JSON API.

    Credentials are resolved lazily on first use, in this order: ``config.service_credentials``
    (written by ``gs configure``), the key file named by GOOGLE_APPLICATION_CREDENTIALS, and finally
    the GCE instance metadata server.
    """
    base_url = "https://www.googleapis.com/storage/v1/"
    scope = "https://www.googleapis.com/auth/cloud-platform"
    token_url = "https://www.googleapis.com/oauth2/v4/token"
    instance_metadata_url = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token"
    # Seconds to wait for the metadata server, so the fallback cannot hang off GCE.
    instance_metadata_timeout = 5

    def __init__(self, config, **session_kwargs):
        """
        :param config: gs config object; must support ``"service_credentials" in config`` and
            attribute get/set of ``service_credentials``.
        :param session_kwargs: attributes to set on the underlying ``requests.Session``
            (for example ``verify`` or ``proxies``).
        """
        self.config = config
        self._service_jwt = None
        self._oauth2_token = None
        self._session = None
        self._session_kwargs = session_kwargs

    def get_session(self):
        """Return the shared authenticated session, creating it (and fetching a token) on first use."""
        if self._session is None:
            # Obtain the token first so a failure never leaves an unauthenticated session cached.
            token = self.get_oauth2_token()
            session = requests.Session()
            # requests.Session() takes no constructor arguments; its settings are plain attributes.
            for name, value in self._session_kwargs.items():
                setattr(session, name, value)
            session.headers.update({"Authorization": "Bearer " + token,
                                    "User-Agent": self.__class__.__name__})
            self._session = session
        return self._session

    def get_oauth2_token(self):
        """Return an OAuth2 access token, cached after the first successful fetch.

        Uses a service account JWT when service credentials are available, otherwise the GCE
        metadata server. Exits the program with a message if neither source works.
        """
        # TODO: invalidate and refetch before expiration
        if self._oauth2_token is None:
            try:
                service_jwt = self.get_service_jwt()
            except NoServiceCredentials:
                res = self._get_instance_token_response()
            else:
                params = dict(grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer", assertion=service_jwt)
                # Plain request: the shared session is built from this token, so it cannot fetch it.
                res = requests.post(self.token_url, data=params)
                # Report an HTTP error instead of a KeyError on the missing "access_token".
                res.raise_for_status()
            self._oauth2_token = res.json()["access_token"]
        return self._oauth2_token

    def _get_instance_token_response(self):
        """Request a token for the instance's default service account from the metadata server."""
        try:
            # The metadata server rejects requests that do not carry this header.
            res = requests.get(self.instance_metadata_url,
                               headers={"Metadata-Flavor": "Google"},
                               timeout=self.instance_metadata_timeout)
            res.raise_for_status()
        except requests.exceptions.RequestException:
            sys.exit('API credentials not configured. Please run "gs configure" '
                     'or set GOOGLE_APPLICATION_CREDENTIALS.')
        return res

    def _load_service_credentials(self):
        """Ensure ``self.config.service_credentials`` is set, loading GOOGLE_APPLICATION_CREDENTIALS if needed.

        :raises NoServiceCredentials: if neither the config nor the environment provides credentials.
        """
        if "service_credentials" in self.config:
            return
        key_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
        if key_path is None:
            raise NoServiceCredentials()
        logging.info("Using GOOGLE_APPLICATION_CREDENTIALS file %s", key_path)
        with open(key_path) as fh:
            self.config.service_credentials = json.load(fh)

    def get_service_jwt(self):
        """Return a signed RS256 JWT assertion for the service account (cached once built).

        :raises NoServiceCredentials: if no service account credentials are available.
        """
        if self._service_jwt is None:
            self._load_service_credentials()
            credentials = self.config.service_credentials
            # Read the clock once so exp - iat is exactly 60 minutes.
            now = datetime.datetime.now(datetime.timezone.utc)
            payload = {
                'iss': credentials["client_email"],
                'sub': credentials["client_email"],
                'scope': self.scope,
                'aud': self.token_url,
                'iat': now,
                'exp': now + datetime.timedelta(minutes=60)
            }
            additional_headers = {'kid': credentials["private_key_id"]}
            token = jwt.encode(payload,
                               credentials["private_key"],
                               headers=additional_headers,
                               algorithm='RS256')
            # PyJWT < 2 returns bytes, PyJWT >= 2 returns str.
            if isinstance(token, bytes):
                token = token.decode()
            self._service_jwt = token
        return self._service_jwt

    def request(self, method, resource, **kwargs):
        """Send a request for *resource* (relative to ``base_url``) and raise on HTTP errors.

        Returns the raw response for streamed requests and deletes, otherwise the decoded JSON body.
        """
        res = self.get_session().request(method=method, url=self.base_url + resource, **kwargs)
        res.raise_for_status()
        return res if kwargs.get("stream") is True or method == "delete" else res.json()

    def get(self, resource, **kwargs):
        return self.request(method="get", resource=resource, **kwargs)

    def post(self, resource, **kwargs):
        return self.request(method="post", resource=resource, **kwargs)

    def patch(self, resource, **kwargs):
        return self.request(method="patch", resource=resource, **kwargs)

    def put(self, resource, **kwargs):
        return self.request(method="put", resource=resource, **kwargs)

    def delete(self, resource, **kwargs):
        return self.request(method="delete", resource=resource, **kwargs)

    def get_project(self):
        """Return the project ID to use, or None if it cannot be determined.

        GOOGLE_CLOUD_PROJECT takes precedence; otherwise the service account key's ``project_id`` is
        used, loading the key from GOOGLE_APPLICATION_CREDENTIALS if it has not been loaded yet.
        """
        if "GOOGLE_CLOUD_PROJECT" in os.environ:
            return os.environ["GOOGLE_CLOUD_PROJECT"]
        try:
            self._load_service_credentials()
        except NoServiceCredentials:
            return None
        return self.config.service_credentials["project_id"]

class GSUploadClient(GSClient):
    """GSClient that targets the JSON API upload endpoint instead of the regular one (used by ``cp`` uploads)."""
    base_url = "https://www.googleapis.com/upload/storage/v1/"

def _iter_pages(resource, params):
    """Yield every page of a paginated list response, following ``nextPageToken`` until exhausted."""
    params = dict(params)  # copy: pageToken is added between requests
    while True:
        page = client.get(resource, params=params)
        yield page
        next_token = page.get("nextPageToken")
        if not next_token:
            return
        params["pageToken"] = next_token


@click.command()
@click.argument('path', required=False)
def ls(path):
    """List buckets or objects in a bucket/prefix."""
    # Without a path, list the project's buckets. Otherwise list one "directory" level of a bucket:
    # delimiter="/" makes the API return sub-prefixes separately from the objects.
    if path is None:
        pages = _iter_pages("b", dict(project=client.get_project()))
        # The API omits "items" entirely when there is nothing to list.
        buckets = [bucket for page in pages for bucket in page.get("items", [])]
        columns = ["name", "timeCreated", "updated", "location", "storageClass"]
        page_output(tabulate(buckets, args=Namespace(columns=columns, max_col_width=40)))
        return
    if path.startswith("gs://"):
        path = path[len("gs://"):]
    bucket, _, prefix = path.partition("/")
    params = dict(delimiter="/")
    # A trailing "*" is treated as a plain prefix match; no other wildcard syntax is supported.
    prefix = prefix.rstrip("*")
    if prefix:
        params["prefix"] = prefix
    prefixes, objects = [], []
    for page in _iter_pages("b/{}/o".format(bucket), params):
        prefixes.extend(dict(name=p) for p in page.get("prefixes", []))
        objects.extend(page.get("items", []))
    columns = ["name", "size", "timeCreated", "updated", "contentType", "storageClass"]
    page_output(tabulate(prefixes + objects, args=Namespace(columns=columns, max_col_width=40)))

cli.add_command(ls)

def read_file_chunks(filename, chunk_size=1024 * 1024):
    """Lazily yield the bytes of *filename* in pieces of at most *chunk_size* bytes.

    The file is opened in binary mode and closed once the generator is exhausted; an empty
    file yields nothing.
    """
    with open(filename, "rb") as stream:
        piece = stream.read(chunk_size)
        while piece:
            yield piece
            piece = stream.read(chunk_size)

def _split_gs_path(path):
    """Split ``gs://bucket/key`` into ``(bucket, key)``; key is "" for ``gs://bucket`` or ``gs://bucket/``."""
    bucket, _, key = path[len("gs://"):].partition("/")
    return bucket, key


def _basename_of_key(key):
    """Return the last "/"-separated component of an object key (keys always use "/", never os.sep)."""
    return key.rsplit("/", 1)[-1]


def _dest_object_key(dest_prefix, name, multiple_sources):
    """Return the destination object key for a source whose base name is *name*.

    The destination is treated as a "directory" prefix (and *name* appended) when it is empty,
    ends with "/", or several sources are copied; otherwise it is used as the exact key.
    """
    # TODO: check if dest_prefix is a prefix on the remote
    if not (multiple_sources or dest_prefix == "" or dest_prefix.endswith("/")):
        return dest_prefix
    if dest_prefix and not dest_prefix.endswith("/"):
        dest_prefix += "/"
    return dest_prefix + name


def _copy_between_buckets(sources, dest):
    """Server-side copy of each gs:// source to the gs:// destination."""
    api_method_template = "b/{source_bucket}/o/{source_key}/copyTo/b/{dest_bucket}/o/{dest_key}"
    dest_bucket, dest_prefix = _split_gs_path(dest)
    for path in sources:
        source_bucket, source_key = _split_gs_path(path)
        api_args = dict(source_bucket=source_bucket,
                        source_key=source_key,
                        dest_bucket=dest_bucket,
                        dest_key=_dest_object_key(dest_prefix, _basename_of_key(source_key), len(sources) > 1))
        print("Copying gs://{source_bucket}/{source_key} to gs://{dest_bucket}/{dest_key}".format(**api_args))
        # Every path segment is fully escaped, including "/" inside object keys.
        escaped_args = {k: requests.compat.quote(v, safe="") for k, v in api_args.items()}
        client.post(api_method_template.format(**escaped_args))


def _download(sources, dest):
    """Download each gs:// source to the local file *dest*, or into *dest* if it is a directory."""
    # TODO: support remote wildcards
    for path in sources:
        source_bucket, source_key = _split_gs_path(path)
        dest_filename = dest
        if os.path.isdir(dest) or len(sources) > 1:
            dest_filename = os.path.join(dest, _basename_of_key(source_key))
        print("Copying gs://{bucket}/{key} to {dest_filename}".format(bucket=source_bucket, key=source_key,
                                                                    dest_filename=dest_filename))
        resource = "b/{bucket}/o/{key}".format(bucket=requests.compat.quote(source_bucket, safe=""),
                                               key=requests.compat.quote(source_key, safe=""))
        res = client.get(resource, params=dict(alt="media"), stream=True)
        try:
            # Open the local file only after the request succeeded, so a failed download
            # (e.g. a missing object) does not create or truncate the destination.
            with open(dest_filename, "wb") as fh:
                while True:
                    chunk = res.raw.read(1024 * 1024)
                    if not chunk:
                        break
                    fh.write(chunk)
        finally:
            res.close()


def _upload(sources, dest):
    """Upload each local file to the gs:// destination using a simple media upload."""
    upload_client = GSUploadClient(config=config)
    dest_bucket, dest_prefix = _split_gs_path(dest)
    upload_path = "b/{bucket}/o".format(bucket=requests.compat.quote(dest_bucket))
    for path in sources:
        dest_key = _dest_object_key(dest_prefix, os.path.basename(path), len(sources) > 1)
        print("Copying {path} to gs://{bucket}/{key}".format(path=path, bucket=dest_bucket, key=dest_key))
        upload_client.post(upload_path, params=dict(uploadType="media", name=dest_key), data=read_file_chunks(path))


@click.command()
@click.argument('paths', nargs=-1, required=True)
def cp(paths):
    """Copy files to, from, or between buckets."""
    # The last path is the destination. Supported directions: bucket -> bucket, bucket -> local,
    # local -> bucket. Mixed local/remote sources are rejected.
    if len(paths) < 2:
        # Explicit check rather than assert, which is stripped under python -O.
        raise click.BadParameter("expected one or more sources followed by a destination", param_hint="paths")
    sources, dest = paths[:-1], paths[-1]
    sources_remote = [p.startswith("gs://") for p in sources]
    if dest.startswith("gs://") and all(sources_remote):
        _copy_between_buckets(sources, dest)
    elif all(sources_remote):
        _download(sources, dest)
    elif dest.startswith("gs://") and not any(sources_remote):
        _upload(sources, dest)
    else:
        raise click.BadParameter("paths")

cli.add_command(cp)

@click.command()
@click.argument('paths', nargs=-1, required=True)
def mv(paths):
    """Move files to, from, or between buckets."""
    # cp raises on any failure, so sources are deleted only after every copy has succeeded.
    cp.main(list(paths), standalone_mode=False)
    sources = paths[:-1]
    remote_sources = [p for p in sources if p.startswith("gs://")]
    if remote_sources:
        # Run rm through click.main like cp above; calling the command object directly runs it
        # in standalone mode, which ends the process with SystemExit.
        rm.main(remote_sources, standalone_mode=False)
    for path in sources:
        if not path.startswith("gs://"):
            # Local sources of an upload: rm only accepts gs:// paths, so remove them here.
            print("Deleting {}".format(path))
            os.remove(path)

cli.add_command(mv)

@click.command()
@click.argument('paths', nargs=-1, required=True)
def rm(paths):
    """Delete objects (files) from buckets."""
    if not all(p.startswith("gs://") for p in paths):
        raise click.BadParameter("paths")
    # Validate every path before deleting anything, so a bad argument cannot cause a partial delete.
    targets = []
    for path in paths:
        bucket, _, key = path[len("gs://"):].partition("/")
        if not bucket or not key:
            # e.g. "gs://bucket" or "gs://bucket/": there is no object to delete (use rb for buckets).
            raise click.BadParameter("expected gs://bucket/key, got {}".format(path), param_hint="paths")
        targets.append((bucket, key))
    for bucket, key in targets:
        print("Deleting gs://{bucket}/{key}".format(bucket=bucket, key=key))
        # The key is fully escaped (including "/") because it is a single URL path segment.
        client.delete("b/{bucket}/o/{key}".format(bucket=requests.compat.quote(bucket),
                                                  key=requests.compat.quote(key, safe="")))

cli.add_command(rm)

@click.command()
@click.argument('paths', nargs=2, required=True)
def sync(paths):
    """Sync a directory of files with bucket/prefix."""
    # Placeholder: registered so it appears in the CLI, but not implemented yet.
    raise NotImplementedError

cli.add_command(sync)

@click.command()
@click.argument('bucket_name')
@click.option('--location')
@click.option('--storage-class')
def mb(bucket_name, storage_class=None, location=None):
    """Create a new bucket."""
    print("Creating new Google Storage bucket {}".format(bucket_name))
    # Only the optional settings the user supplied (non-empty) are sent in the bucket resource.
    optional_settings = (("location", location), ("storageClass", storage_class))
    bucket_resource = dict(name=bucket_name)
    bucket_resource.update((field, value) for field, value in optional_settings if value)
    project_params = dict(project=client.get_project())
    created = client.post("b", params=project_params, json=bucket_resource)
    print(json.dumps(created, indent=4))

cli.add_command(mb)

@click.command()
@click.argument('bucket_name')
def rb(bucket_name):
    """Permanently delete an empty bucket."""
    print("Deleting Google Storage bucket {}".format(bucket_name))
    # The bucket name is URL-quoted with quote()'s default safe characters.
    bucket_resource = "b/" + requests.compat.quote(bucket_name)
    client.delete(bucket_resource)

cli.add_command(rb)

# Module-level singletons shared by the commands above. They are created at import time after
# all commands are defined; the commands only look them up when they run. `configure` persists
# changes explicitly via config.save().
config = tweak.Config("gs", save_on_exit=False)

client = GSClient(config=config)

if __name__ == '__main__':
    # Only ERROR and above are emitted, so the logging.info() calls in this file stay silent.
    logging.basicConfig(level=logging.ERROR)
    cli()
