#!/usr/bin/env python
"""
Test the ai-jup backend API endpoints.
Helps diagnose issues with the server-side handlers.
"""
import os
import sys
import json
import urllib.request
import urllib.error
from http.cookies import SimpleCookie

def describe():
    """Print the toolbox manifest (tool name, description and parameters).

    NOTE(review): the description advertises /ai-jup/tool-execute, but none
    of the actions handled by execute() exercises that endpoint.
    """
    manifest_lines = [
        "name: jupyter-api-test",
        "description: Test the ai-jup backend API endpoints (/ai-jup/prompt, /ai-jup/models, /ai-jup/tool-execute). Use this to diagnose backend API issues.",
        "base_url: string (optional) - JupyterLab base URL (default: http://localhost:8888)",
        "token: string (optional) - Jupyter authentication token",
        "action: string (one of: test-all, test-models, test-prompt, test-health) - which endpoint to test. Default is 'test-all'.",
    ]
    print("\n".join(manifest_lines))

def get_jupyter_token():
    """Return a Jupyter auth token, or "" if none can be found.

    Lookup order:
      1. The ``JUPYTER_TOKEN`` environment variable, if non-empty.
      2. The first entry printed by ``jupyter server list --json`` (parsed as
         one JSON object per line) whose ``token`` field is non-empty.

    Discovery is best-effort: if the ``jupyter`` CLI is missing, times out,
    exits non-zero or prints unparsable output, "" is returned.
    """
    token = os.environ.get("JUPYTER_TOKEN", "")
    if token:
        return token

    import subprocess  # only needed for the CLI fallback

    try:
        result = subprocess.run(
            ["jupyter", "server", "list", "--json"],
            capture_output=True, text=True, timeout=10,
        )
    except (OSError, ValueError, subprocess.SubprocessError):
        # CLI not found / not executable (OSError), undecodable output
        # (UnicodeDecodeError is a ValueError) or timeout (TimeoutExpired).
        return ""
    if result.returncode != 0:
        return ""

    for line in result.stdout.splitlines():
        if not line.strip():
            continue
        try:
            server = json.loads(line)
        except json.JSONDecodeError:
            continue  # skip lines that are not valid JSON
        # Only JSON objects can carry a "token" field; ignore any other value.
        if isinstance(server, dict) and server.get("token"):
            return server["token"]
    return ""

def get_running_servers():
    """Return the running Jupyter servers reported by ``jupyter server list``.

    Returns:
        A list of dicts, one per JSON object line printed by
        ``jupyter server list --json``. If the command cannot be run, times
        out, or exits non-zero, the list instead holds a single
        ``{"error": <message>}`` entry so the caller can show why discovery
        failed rather than an empty server list.
    """
    import subprocess  # local import, matching the rest of this script

    try:
        result = subprocess.run(
            ["jupyter", "server", "list", "--json"],
            capture_output=True, text=True, timeout=10,
        )
    except (OSError, ValueError, subprocess.SubprocessError) as e:
        # CLI missing / not executable, undecodable output, or timeout.
        return [{"error": str(e)}]

    if result.returncode != 0:
        # A failing CLI would otherwise look identical to "no servers running".
        detail = result.stderr.strip() or f"exit status {result.returncode}"
        return [{"error": f"`jupyter server list` failed: {detail}"}]

    servers = []
    for line in result.stdout.splitlines():
        if not line.strip():
            continue
        try:
            entry = json.loads(line)
        except json.JSONDecodeError:
            continue  # skip lines that are not valid JSON
        # Callers treat entries as dicts; ignore any other JSON value.
        if isinstance(entry, dict):
            servers.append(entry)
    return servers

def make_request(url: str, method: str = "GET", data: dict = None, token: str = None,
                 timeout: float = 30):
    """Send an HTTP request to the Jupyter server and summarise the outcome.

    Never raises for request failures; every outcome is reported in the
    returned dict so the diagnostic checks can print it.

    Args:
        url: Absolute URL to request.
        method: HTTP method, e.g. "GET" or "POST".
        data: Payload serialised as a JSON body when not None (an empty dict
            is sent as ``{}``).
        token: Jupyter token, sent as ``Authorization: token <token>`` when
            truthy.
        timeout: Timeout in seconds passed to ``urllib.request.urlopen``.

    Returns:
        On success: ``{"status", "body", "headers"}``.
        On an HTTP error status: ``{"status": <code>, "error", "body"}``.
        On any other failure (bad URL, unserialisable payload, connection
        error, ...): ``{"status": 0, "error"}``.
    """
    headers = {"Content-Type": "application/json"}
    if token:
        headers["Authorization"] = f"token {token}"

    try:
        # Build the request inside the try: a malformed URL makes Request()
        # raise ValueError and a non-serialisable payload makes json.dumps
        # raise TypeError -- both should be reported, not crash the tool.
        request_data = None if data is None else json.dumps(data).encode("utf-8")
        req = urllib.request.Request(url, data=request_data, headers=headers, method=method)
        with urllib.request.urlopen(req, timeout=timeout) as response:
            return {
                "status": response.status,
                # errors="replace": a non-UTF-8 body must not turn a real
                # response into a status-0 failure.
                "body": response.read().decode("utf-8", errors="replace"),
                "headers": dict(response.headers),
            }
    except urllib.error.HTTPError as e:
        # HTTPError subclasses URLError, so it must be handled first. Reading
        # the error body can itself fail; that must not escape this handler.
        try:
            error_body = e.read().decode("utf-8", errors="replace") if e.fp else ""
        except OSError:
            error_body = ""
        return {
            "status": e.code,
            "error": str(e),
            "body": error_body,
        }
    except urllib.error.URLError as e:
        return {
            "status": 0,
            "error": str(e.reason),
        }
    except Exception as e:
        return {
            "status": 0,
            "error": str(e),
        }

def test_models_endpoint(base_url: str, token: str):
    """Check that GET <base_url>/ai-jup/models answers 200 with a JSON body.

    Prints a human-readable report.

    Returns:
        True if the status is 200 and the body parses as JSON, else False.
    """
    url = f"{base_url}/ai-jup/models"
    print(f"\n📡 Testing GET {url}")

    result = make_request(url, "GET", token=token)

    if result.get("status") != 200:
        print(f"   ❌ Status: {result.get('status', 'N/A')}")
        print(f"   ❌ Error: {result.get('error', result.get('body', 'Unknown'))}")
        # When an "error" entry exists the line above never shows the body,
        # which may hold the server handler's own error details.
        if result.get("error") and result.get("body"):
            print(f"   📄 Body: {result['body'][:500]}")
        return False

    print(f"   ✅ Status: {result['status']}")
    try:
        body = json.loads(result["body"])
    except json.JSONDecodeError:
        print(f"   ⚠️  Invalid JSON: {result['body'][:200]}")
        return False
    print(f"   ✅ Response: {json.dumps(body, indent=4)}")
    return True

def test_prompt_endpoint(base_url: str, token: str, model: str = "claude-sonnet-4-20250514"):
    """POST a minimal prompt to <base_url>/ai-jup/prompt and report the reply.

    Args:
        base_url: Server base URL without a trailing slash.
        token: Jupyter token (falsy to send no Authorization header).
        model: Value sent in the request's "model" field.

    Returns:
        True if the server answered 200 and the response body does not
        contain "error" (case-insensitive); False otherwise.
    """
    url = f"{base_url}/ai-jup/prompt"
    print(f"\n📡 Testing POST {url}")

    # Simple test prompt
    data = {
        "prompt": "Say 'Hello from ai-jup test' in exactly those words.",
        "context": {
            "variables": {},
            "functions": {},
            "preceding_code": "# Test cell",
        },
        "model": model,
    }

    print(f"   📤 Request body: {json.dumps(data, indent=4)[:500]}")

    result = make_request(url, "POST", data=data, token=token)

    if result.get("status") != 200:
        print(f"   ❌ Status: {result.get('status', 'N/A')}")
        print(f"   ❌ Error: {result.get('error', result.get('body', 'Unknown'))}")
        # When an "error" entry exists the line above never shows the body,
        # which may hold the server handler's own error details.
        if result.get("error") and result.get("body"):
            print(f"   📄 Body: {result['body'][:500]}")
        return False

    body = result["body"]
    preview_len = 500
    print(f"   ✅ Status: {result['status']}")
    print(f"   ✅ Got streaming response (first {preview_len} chars):")
    print(f"      {body[:preview_len]}")

    # Heuristic: any mention of "error" in the stream counts as a failure.
    error_pos = body.lower().find("error")
    if error_pos != -1:
        if error_pos + len("error") > preview_len:
            # The match is outside the preview above; show where it occurs.
            start = max(0, error_pos - 200)
            print(f"   ⚠️  'error' at offset {error_pos}: ...{body[start:error_pos + 300]}")
        print("   ⚠️  Response may contain errors - check above")
        return False
    return True

def test_health(base_url: str, token: str):
    """Check basic connectivity via GET <base_url>/api/status.

    Distinguishes "no HTTP response at all" (make_request reports status 0)
    from "the server answered but rejected the request" (any other non-200
    status), since the fix differs.

    Returns:
        True if the server answered 200, else False.
    """
    url = f"{base_url}/api/status"
    print(f"\n📡 Testing GET {url}")

    result = make_request(url, "GET", token=token)
    status = result.get("status")

    if status == 200:
        print("   ✅ Jupyter server is running")
        return True

    if status:
        # An HTTP status came back, so the server is reachable.
        print(f"   ❌ Jupyter server responded with HTTP {status}")
        if status in (401, 403):
            print("   💡 Request was not authorized - check the token")
    else:
        print("   ❌ Cannot reach Jupyter server")
    print(f"   ❌ Error: {result.get('error', 'Unknown')}")
    return False

def run_all_tests(base_url: str, token: str):
    """Run every API check against *base_url* and print a summary.

    Lists the servers reported by get_running_servers(), then runs the health
    check; the models and prompt checks only run if the health check passes.

    Args:
        base_url: Server base URL without a trailing slash.
        token: Jupyter token (falsy when not set).

    Returns:
        Dict mapping "health", "models" and "prompt" to True/False; checks
        skipped because the health check failed are reported as False.
    """
    print("=" * 60)
    print("AI-JUP API ENDPOINT TESTS")
    print("=" * 60)
    print(f"\nBase URL: {base_url}")
    print(f"Token: {'set' if token else 'not set'}")

    # List running servers
    print("\n🖥️  Running Jupyter Servers:")
    servers = get_running_servers()
    if not servers:
        print("   (none found)")
    for s in servers:
        if "error" in s:
            print(f"   Error: {s['error']}")
        else:
            print(f"   - {s.get('url', 'unknown')} (token: {'yes' if s.get('token') else 'no'})")

    # The endpoint checks need a server that answers, so gate them on health.
    results = {"health": test_health(base_url, token)}
    skipped = not results["health"]
    if skipped:
        print("\n⚠️  Skipping other tests - health check failed")
        results["models"] = False
        results["prompt"] = False
    else:
        results["models"] = test_models_endpoint(base_url, token)
        results["prompt"] = test_prompt_endpoint(base_url, token)

    # Summary
    print("\n" + "=" * 60)
    print("TEST SUMMARY")
    print("=" * 60)
    for name, passed in results.items():
        note = " (skipped)" if skipped and name != "health" else ""
        print(f"   {'✅' if passed else '❌'} {name}{note}")

    if not all(results.values()):
        print("\n💡 Troubleshooting tips:")
        if skipped:
            # models/prompt never ran, so their tips would only be noise.
            print("   - Make sure JupyterLab is running: jupyter lab")
            print("   - Check the base URL is correct")
            print("   - Ensure you have the right token")
        else:
            if not results["models"]:
                print("   - Check if the extension is installed: pip install -e .")
                print("   - Restart JupyterLab after installation")
            if not results["prompt"]:
                print("   - Verify ANTHROPIC_API_KEY is set in the Jupyter server environment")
                print("   - Check Jupyter server logs for errors")

    return results

def execute():
    """Run the requested check using parameters read from stdin.

    stdin holds ``key: value`` lines. The recognised keys are ``action``,
    ``base_url`` and ``token``; other lines are ignored, later lines override
    earlier ones, and keys with an empty value fall back to their defaults.

    Defaults: action "test-all"; base_url "http://localhost:8888" (any
    trailing "/" is stripped); token from get_jupyter_token(), which is only
    consulted when no token is supplied and the action is valid.
    """
    params = {}
    for line in sys.stdin.read().splitlines():
        key, sep, value = line.partition(":")
        value = value.strip()
        # An empty value means "not provided" -- keep the default instead of
        # e.g. an empty base_url that would produce an invalid request URL.
        if sep and value:
            params[key.strip()] = value

    action = params.get("action", "test-all")
    base_url = params.get("base_url", "http://localhost:8888").rstrip("/")

    handlers = {
        "test-all": run_all_tests,
        "test-models": test_models_endpoint,
        "test-prompt": test_prompt_endpoint,
        "test-health": test_health,
    }
    handler = handlers.get(action)
    if handler is None:
        print(f"Unknown action: {action}")
        print("Valid actions: test-all, test-models, test-prompt, test-health")
        return

    # Token discovery may shell out to `jupyter server list` (up to 10 s),
    # so only do it when the caller did not pass a token.
    token = params.get("token") or get_jupyter_token()
    handler(base_url, token)

if __name__ == "__main__":
    # TOOLBOX_ACTION selects the toolbox phase: "describe" prints the tool
    # manifest, "execute" reads parameters from stdin.
    action = os.environ.get("TOOLBOX_ACTION")
    if action == "describe":
        describe()
    elif action == "execute":
        execute()
    else:
        # Direct execution for manual testing; an optional first CLI argument
        # overrides the default base URL.
        base_url = sys.argv[1].rstrip("/") if len(sys.argv) > 1 else "http://localhost:8888"
        token = get_jupyter_token()
        run_all_tests(base_url, token)
