#!/usr/bin/env python3
#
# basically:
#
#     curl -X POST \
#       -H 'Authorization: Bearer '$NERSC_TOKEN \
#       -H 'Content-Type: application/x-www-form-urlencoded' \
#       -d 'isPath=false' \
#       -d 'job@{{base}}/scripts/job' \
#       https://api.nersc.gov/api/v1.2/compute/jobs/perlmutter
#
# except that the NERSC token is gathered from oidc.nersc.gov/c2id/token,
# psik hot-start and the jobndx are added to the script,
# and this script polls the returned task to extract the jobid.

from typing import Optional, Tuple, List, Any
import sys
import os
import asyncio
from contextlib import asynccontextmanager
import json
from pathlib import Path
from base64 import b64decode

import aiohttp

from authlib.integrations.requests_client import OAuth2Session
from authlib.oauth2.rfc7523 import PrivateKeyJWT
from authlib.integrations.base_client.errors import OAuthError

def get_token(nersc_name: Optional[str] = None) -> str:
    """Fetch an OAuth2 access token using a key stored in ~/.superfacility.

    The credential file is read as ASCII text: its first line is used as
    the client id and the remainder as the private key.  When
    `nersc_name` is not given, ``key.pem`` is preferred; otherwise the
    first regular file (in sorted name order) whose name ends in ``.key``
    or ``.pem`` is used.

    Args:
        nersc_name: Name of the credential file inside
            ``$HOME/.superfacility``.  Discovered automatically when None.

    Returns:
        The ``access_token`` field of the granted token.

    Raises:
        KeyError: no credential file could be discovered, including the
            case where ``$HOME/.superfacility`` does not exist.
        OSError: the selected credential file cannot be opened.
        OAuthError: the token endpoint did not grant a token.
    """
    nersc_dir = Path(os.environ.get("HOME", "/")) / ".superfacility"
    if nersc_name is None and (nersc_dir / "key.pem").is_file():
        nersc_name = "key.pem"
    if nersc_name is None and nersc_dir.is_dir():
        # The is_dir() guard keeps a missing directory on the documented
        # KeyError path instead of leaking FileNotFoundError from
        # iterdir().  Sorting makes the choice deterministic when several
        # key files are present (iterdir() order is arbitrary).
        candidates = sorted(
            f.name
            for f in nersc_dir.iterdir()
            if f.is_file() and f.name.endswith((".key", ".pem"))
        )
        if candidates:
            nersc_name = candidates[0]
    if nersc_name is None:
        raise KeyError(f"No <name>.pem file in {nersc_dir}")

    with open(nersc_dir / nersc_name, "r", encoding="ascii") as f:
        client_id = f.readline().strip()
        private_key = f.read()
    token_url = "https://oidc.nersc.gov/c2id/token"

    session = OAuth2Session(
        client_id,
        private_key,
        PrivateKeyJWT(token_url),
        grant_type="client_credentials",
        token_endpoint=token_url,
    )

    try:
        token = session.fetch_token()
    except OAuthError:
        print("Auth. error.")
        raise
    # The granted token is a mapping shaped like:
    # {'access_token': 'A.B.C', 'scope': '...', 'token_type': 'Bearer',
    #  'expires_in': 600, 'expires_at': 1734463094}
    return token['access_token']

# Sample job-submission form fields.  NOTE(review): nothing in the code
# visible here reads this dict; it looks like reference documentation for
# the payload shape -- confirm before removing.
nersc_jobspec = dict(
    isPath=False,
    job="pwd",
    args="a,b",  # comma-separated args
)

def nersc_path(path: str) -> str:
    """Return `path` prefixed with the versioned API root ``/api/v1.2/``."""
    api_root = "/api/v1.2/"
    return api_root + path

@asynccontextmanager
async def nersc_client(token: str):
    """Yield an aiohttp session preconfigured for https://api.nersc.gov.

    The session is created with the bearer `token`, a form-encoded
    Content-Type and a JSON ``accept`` header as its headers, and is
    closed when the context exits.
    """
    default_headers = {
        "Authorization": 'Bearer ' + token,
        "Content-Type": "application/x-www-form-urlencoded",
        "accept": "application/json",
    }
    client = aiohttp.ClientSession(base_url="https://api.nersc.gov",
                                   headers=default_headers)
    async with client as session:
        yield session

async def get_json(session, path: str, params=None):
    """GET `path` from the API and return the decoded JSON body.

    Args:
        session: Client session exposing an aiohttp-style ``get`` method.
        path: API path relative to the versioned root (it is passed
            through `nersc_path`).
        params: Optional query parameters forwarded to ``session.get``.
            Defaults to no parameters.

    Returns:
        The decoded JSON response.

    Raises:
        ValueError: the server answered with a non-2xx status, or decoding
            the body as JSON failed with a ValueError.
    """
    # A fresh dict per call avoids sharing one mutable default object
    # between invocations.
    query = {} if params is None else params
    async with session.get(nersc_path(path), params=query) as resp:
        if resp.status // 100 != 2:
            err = await resp.text()
            raise ValueError(f"Server returned error {resp.status}: {err}")
        try:
            return await resp.json()
        except ValueError as exc:
            # Same message post_json uses for an undecodable body.
            raise ValueError("Invalid response from server.") from exc

async def post_json(session, path: str, data):
    """POST `data` to `path` on the API and return the decoded JSON body.

    Args:
        session: Client session exposing an aiohttp-style ``post`` method.
        path: API path relative to the versioned root (it is passed
            through `nersc_path`).
        data: Request body forwarded to ``session.post`` as ``data``.

    Returns:
        The decoded JSON response.

    Raises:
        ValueError: the server answered with a non-2xx status, or decoding
            the body as JSON failed with a ValueError.
    """
    async with session.post(nersc_path(path), data=data) as resp:
        if resp.status // 100 != 2:
            err = await resp.text()
            raise ValueError(f"Server returned error {resp.status}: {err}")
        try:
            return await resp.json()
        except ValueError as exc:
            # Chain explicitly so the decoding failure stays attached as
            # the cause of the friendlier error.
            raise ValueError("Invalid response from server.") from exc

async def post_job(session, machine: str, script: str) -> str:
    """Submit `script` as a job on `machine` and wait for the submit task.

    Args:
        session: Client session passed on to `post_json` / `poll_task`.
        machine: Machine name used in the ``compute/jobs/{machine}`` path.
        script: Full text of the job script (sent with ``isPath=false``).

    Returns:
        Whatever `poll_task` returns for the task created by the POST.

    Raises:
        ValueError: the POST failed, the reply carried no usable
            ``task_id``, or polling the task failed.
    """
    # Example reply: {"task_id":"2170946","status":"OK","error":null}
    ans = await post_json(session, f"compute/jobs/{machine}",
                          {"isPath": "false",
                           "job": script})
    try:
        task_id = int(ans.get("task_id", ""))
    except (TypeError, ValueError) as exc:
        # int(None) raises TypeError (a JSON null task_id), while a
        # missing or non-numeric value raises ValueError; report both the
        # same way.
        raise ValueError("Invalid task_id from server.") from exc
    return await poll_task(session, task_id)

async def poll_task(session, task_id: int, *,
                    max_polls: int = 10,
                    initial_delay: float = 0.1) -> str:
    """Poll ``tasks/{task_id}`` until its status is no longer "new".

    The delay between polls starts at `initial_delay` seconds and doubles
    after every wait.

    Args:
        session: Client session passed on to `get_json`.
        task_id: Id of the task to poll.
        max_polls: Maximum number of status requests before giving up.
        initial_delay: Seconds to wait after the first unsuccessful poll.

    Returns:
        The task's ``result`` field (empty string when absent).

    Raises:
        ValueError: the task is still "new" after `max_polls` requests, or
            it ended with a status other than "completed".
    """
    delay = initial_delay
    for attempt in range(max_polls):
        ans = await get_json(session, f"tasks/{task_id}")
        stat = ans.get("status", "failed")
        if stat != "new":
            break
        # No point sleeping after the final poll: nothing is checked
        # afterwards, it would only postpone the timeout error.
        if attempt + 1 < max_polls:
            await asyncio.sleep(delay)
            delay *= 2.0
    else:
        raise ValueError(f"Task {task_id} timed out.")

    ret = ans.get("result", "")
    if stat != "completed":
        raise ValueError(f"Task returned {stat}: {ret}")
    return ret

async def main(argv):
    """Fill in the job script template, submit it, and print the result.

    ``argv[1]`` is parsed as the integer job index.  The job spec is read
    from ``spec.json`` and re-serialized compactly; it and the job index
    are substituted into ``scripts/job`` via %-formatting before the
    script is posted to perlmutter.
    """
    spec_text = Path("{{base}}/spec.json").read_text(encoding="utf-8")
    jobspec = json.dumps(json.loads(spec_text), separators=(',', ':'))

    jobndx = int(argv[1])
    template = Path("{{base}}/scripts/job").read_text(encoding="utf-8")
    substitutions = {
        # single quotes in the spec are doubled before substitution
        'jobspec': jobspec.replace("'", "''"),
        'jobndx': jobndx,
    }
    script = template % substitutions

    token = get_token()
    # Note: beware the dreaded {"detail":"Not authenticated"}
    async with nersc_client(token) as session:
        ans = await post_job(session, "perlmutter", script)
    print(ans)

if __name__ == "__main__":
    # Script entry point: hand the raw command line to the async main().
    cli_args = sys.argv
    asyncio.run(main(cli_args))
