#!/usr/bin/env python
"""
Start and stop JupyterLab server for testing the ai-jup extension.
"""
import os
import sys
import json
import time
import signal
import subprocess
from pathlib import Path

def describe():
    """Print this tool's manifest: its name, description and accepted inputs."""
    manifest_lines = [
        "name: jupyter-server",
        "description: Start or stop JupyterLab server for testing ai-jup. Manages server lifecycle with a fixed token for easy API testing.",
        "action: string (one of: start, stop, status, restart) - what to do. Default is 'status'.",
        "port: number (optional) - Jupyter server port. Default is 8888.",
        "token: string (optional) - Authentication token. Default is 'debug-token'.",
    ]
    print("\n".join(manifest_lines))


def get_project_root():
    """Return $WORKSPACE_ROOT when it is set, otherwise the current working directory."""
    cwd = os.getcwd()
    env = os.environ
    return env["WORKSPACE_ROOT"] if "WORKSPACE_ROOT" in env else cwd


def get_pid_file(port: int) -> Path:
    """Return the path where the PID of the server on *port* is recorded.

    The file lives directly in the project root and is named
    ``.jupyter-server-<port>.pid``.
    """
    filename = ".jupyter-server-{}.pid".format(port)
    return Path(get_project_root()).joinpath(filename)


def find_running_servers():
    """Return the server-info dicts reported by ``jupyter server list --json``.

    The command prints one JSON object per line; each line is parsed
    independently. This is best-effort: if the ``jupyter`` executable cannot be
    run, the command times out (10 s) or exits non-zero, an empty list is
    returned. Blank lines, lines that are not valid JSON, and JSON values that
    are not objects are skipped, so every returned item supports ``.get()``.

    Returns:
        list[dict]: One dict per running server (possibly empty).
    """
    try:
        result = subprocess.run(
            ["jupyter", "server", "list", "--json"],
            capture_output=True, text=True, timeout=10,
        )
    except (OSError, subprocess.SubprocessError, ValueError):
        # OSError: jupyter missing / not executable; SubprocessError covers
        # TimeoutExpired; ValueError covers undecodable output in text mode.
        # Treat all of these as "no servers found".
        return []

    if result.returncode != 0:
        return []

    servers = []
    for line in result.stdout.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            info = json.loads(line)
        except json.JSONDecodeError:
            continue
        # Callers do info.get(...); anything other than an object would crash them.
        if isinstance(info, dict):
            servers.append(info)
    return servers


def is_port_in_use(port: int, host: str = "localhost", timeout: float = 1.0) -> bool:
    """Return True if something accepts IPv4 TCP connections on *host*:*port*.

    Args:
        port: TCP port to probe.
        host: Host name or address to connect to (default ``"localhost"``).
        timeout: Seconds to wait for the connection attempt. Without a
            timeout, ``connect()`` on a filtered port can block for a long
            time and stall every caller that polls this function.

    Returns:
        True if the connection succeeded, False if it was refused or failed.
    """
    import socket
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)
        # connect_ex returns 0 on success and an error code instead of raising
        # for connection-level failures (refused, timed out, ...).
        return sock.connect_ex((host, port)) == 0


def wait_for_server(port: int, timeout: int = 30, interval: float = 1.0) -> bool:
    """Poll until *port* accepts TCP connections or *timeout* seconds elapse.

    The deadline is measured in wall-clock time (``time.monotonic``), so slow
    connection probes do not stretch the wait beyond *timeout*. The port is
    always probed at least once, and once more at the deadline.

    Args:
        port: TCP port to wait for.
        timeout: Maximum number of seconds to wait.
        interval: Seconds to sleep between probes.

    Returns:
        True as soon as the port is in use, False if the deadline passed first.
    """
    deadline = time.monotonic() + timeout
    while True:
        if is_port_in_use(port):
            return True
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        # Never sleep past the deadline; the loop re-probes once more there.
        time.sleep(min(interval, remaining))


def _print_log_tail(log_file: Path, count: int = 10) -> None:
    """Print the last *count* lines of *log_file*, indented; do nothing if it can't be read."""
    try:
        # The server writes arbitrary bytes; don't let a decode error hide the log.
        with open(log_file, "r", encoding="utf-8", errors="replace") as f:
            lines = f.readlines()
    except OSError:
        return
    if lines:
        print("\n   Last log lines:")
        for line in lines[-count:]:
            print(f"   {line.rstrip()}")


def start_server(port: int = 8888, token: str = "debug-token"):
    """Start JupyterLab in the background on *port* with a fixed *token*.

    If the port is already in use and ``jupyter server list`` reports a
    server on exactly that port, its URL and token are printed and the call
    succeeds without starting anything. Otherwise ``jupyter lab`` is launched
    in a new session (detached from this process), with output written to
    ``.jupyter-server-<port>.log`` in the project root. Its PID is recorded in
    the file returned by ``get_pid_file(port)``.

    Returns:
        True if a server is reachable on *port*, False otherwise. Failure
        cases are: the port is held by an unknown process, ``jupyter`` could
        not be launched, or the server did not start listening within 30 s.
    """
    print(f"🚀 Starting JupyterLab on port {port}")
    print("-" * 50)

    # Check if already running
    if is_port_in_use(port):
        print(f"⚠️  Port {port} is already in use")
        for s in find_running_servers():
            # Compare the reported port exactly: a substring test on the URL
            # would let e.g. port 888 "match" a server on 8888.
            if s.get("port") == port:
                print(f"   Server running: {s.get('url', 'unknown')}")
                print(f"   Token: {s.get('token', 'unknown')}")
                return True
        print("   Unknown process using port")
        return False

    project_root = get_project_root()
    pid_file = get_pid_file(port)
    log_file = Path(project_root) / f".jupyter-server-{port}.log"

    # Permissive origin/XSRF settings so tests can call the HTTP API directly;
    # intended for local testing only.
    cmd = [
        "jupyter", "lab",
        f"--port={port}",
        "--no-browser",
        f"--notebook-dir={project_root}",
        f"--IdentityProvider.token={token}",
        "--ServerApp.allow_origin=*",
        "--ServerApp.disable_check_xsrf=True",
    ]

    print(f"   Command: {' '.join(cmd[:4])}...")
    print(f"   Log file: {log_file}")

    # Start server in background; the child keeps its own handle to the log.
    try:
        with open(log_file, "w") as log:
            process = subprocess.Popen(
                cmd,
                cwd=project_root,
                stdout=log,
                stderr=subprocess.STDOUT,
                start_new_session=True,
                env={**os.environ, "JUPYTER_TOKEN": token},
            )
    except OSError as e:
        # e.g. `jupyter` not on PATH, or the log file cannot be created.
        print(f"\n❌ Could not launch JupyterLab: {e}")
        return False

    pid_file.write_text(str(process.pid))

    print(f"   PID: {process.pid}")
    print("   Waiting for server to start...")

    if wait_for_server(port, timeout=30):
        print("\n✅ JupyterLab started successfully!")
        print(f"   URL: http://localhost:{port}/lab?token={token}")
        print(f"   API: http://localhost:{port}/ai-jup/")
        return True

    exit_code = process.poll()
    if exit_code is not None:
        print(f"\n❌ Server exited with code {exit_code} before it started listening")
        # The recorded PID no longer belongs to our server; don't leave it behind.
        try:
            pid_file.unlink()
        except OSError:
            pass
    else:
        # Still running but not listening yet; leave it (and its PID file) so
        # a later `stop` can clean it up.
        print("\n❌ Server failed to start within 30 seconds")
    print(f"   Check log file: {log_file}")
    _print_log_tail(log_file)
    return False


def stop_server(port: int = 8888):
    """Stop the JupyterLab server on *port*.

    Two strategies are tried, in order:
      1. If a PID file exists, send SIGTERM to the recorded PID, then delete
         the PID file.
      2. Run ``jupyter server stop <port>``.
    Afterwards the port is polled for up to ~5 seconds until it is free.

    NOTE(review): the PID from the file is signalled without verifying it
    still belongs to a Jupyter process; a stale file whose PID was reused
    would signal an unrelated process.

    Returns:
        True if the port is free after waiting. Otherwise, True if a stop was
        at least requested: a signal was sent, the process was already gone,
        or ``jupyter server stop`` succeeded. False in all other cases.
    """
    print(f"🛑 Stopping JupyterLab on port {port}")
    print("-" * 50)

    pid_file = get_pid_file(port)
    stopped = False

    # Strategy 1: the PID recorded by start_server.
    if pid_file.exists():
        pid = None
        try:
            pid = int(pid_file.read_text().strip())
            print(f"   Found PID file: {pid}")
            os.kill(pid, signal.SIGTERM)
            print(f"   Sent SIGTERM to {pid}")
            stopped = True
        except ProcessLookupError:
            print(f"   Process {pid} not found (already stopped?)")
            stopped = True
        except (OSError, ValueError, OverflowError) as e:
            # Unreadable/corrupt PID file, PID out of range, or no permission.
            print(f"   Error killing process: {e}")

        # The file is stale either way once we've acted on it.
        try:
            pid_file.unlink()
        except OSError:
            pass

    # Strategy 2: ask Jupyter itself (covers servers we didn't start).
    try:
        result = subprocess.run(
            ["jupyter", "server", "stop", str(port)],
            capture_output=True, text=True, timeout=10,
        )
    except (OSError, subprocess.SubprocessError):
        # jupyter CLI unavailable or hung; rely on the PID-file path above.
        result = None
    if result is not None and result.returncode == 0:
        print("   jupyter server stop succeeded")
        stopped = True

    # Wait (up to 10 * 0.5 s) for the port to be released.
    for _ in range(10):
        if not is_port_in_use(port):
            print(f"\n✅ Server stopped (port {port} is free)")
            return True
        time.sleep(0.5)

    if stopped:
        print(f"\n⚠️  Stop signal sent but port {port} still in use")
        print("   May need to wait or kill manually")
    else:
        print(f"\n❌ Could not stop server on port {port}")

    return stopped


def server_status(port: int = 8888):
    """Print the status of the JupyterLab server on *port*.

    Reports the PID file contents (if any), whether the port accepts
    connections, and, for a server that ``jupyter server list`` reports on
    exactly this port, its URL, token and root directory.

    Returns:
        True only if ``jupyter server list`` reports a server on *port*.
    """
    print(f"📊 JupyterLab Server Status (port {port})")
    print("-" * 50)

    pid_file = get_pid_file(port)
    if pid_file.exists():
        try:
            pid_text = pid_file.read_text().strip()
        except OSError as e:
            pid_text = f"unreadable ({e})"
        print(f"   PID file: {pid_text}")
    else:
        print("   PID file: not found")

    port_in_use = is_port_in_use(port)
    print(f"   Port {port}: {'in use' if port_in_use else 'free'}")

    # Match the reported port exactly; `str(port) in url` would make 888 match 8888.
    matching = [s for s in find_running_servers() if s.get("port") == port]

    if matching:
        s = matching[0]
        print("\n✅ Server is running:")
        print(f"   URL: {s.get('url', 'unknown')}")
        print(f"   Token: {s.get('token', 'unknown')}")
        # jupyter_server reports "root_dir"; the classic notebook server used "notebook_dir".
        root_dir = s.get("root_dir") or s.get("notebook_dir", "unknown")
        print(f"   Notebook dir: {root_dir}")
        return True
    if port_in_use:
        print(f"\n⚠️  Port {port} in use but not by Jupyter")
        return False
    print("\n❌ Server is not running")
    return False


def restart_server(port: int = 8888, token: str = "debug-token"):
    """Stop, then start, the JupyterLab server on *port*.

    The outcome of the stop phase is ignored; the return value is whatever
    ``start_server(port, token)`` returns.
    """
    banner = "=" * 50
    print("🔄 Restarting JupyterLab")
    print(banner)
    stop_server(port)
    # Pause before starting again so the old server can finish shutting down.
    time.sleep(2)
    started = start_server(port, token)
    return started


def execute():
    """Run the action described by ``key: value`` lines read from stdin.

    Recognised keys:
      * ``action``: one of start, stop, status, restart (case-insensitive;
        default ``status``).
      * ``port``: integer 1-65535 (default 8888). Invalid values are reported
        and ignored.
      * ``token``: server token (default ``debug-token``).

    Unknown keys are ignored. A key with an empty value keeps its default.
    This matters for ``token:``: an empty token would turn off Jupyter's
    token authentication.

    Returns:
        The selected action's result (bool), or None for an unknown action.
    """
    action = "status"
    port = 8888
    token = "debug-token"

    # splitlines() also copes with "\r\n" input; strip() tolerates indentation.
    for raw_line in sys.stdin.read().splitlines():
        key, sep, value = raw_line.strip().partition(":")
        key = key.strip()
        value = value.strip()
        if not sep or not value:
            continue
        if key == "action":
            action = value.lower()
        elif key == "port":
            try:
                candidate = int(value)
            except ValueError:
                candidate = -1
            if 0 < candidate < 65536:
                port = candidate
            else:
                print(f"⚠️  Ignoring invalid port {value!r}; using {port}")
        elif key == "token":
            token = value

    # Called only after parsing finishes, so the closures see the final values.
    handlers = {
        "start": lambda: start_server(port, token),
        "stop": lambda: stop_server(port),
        "status": lambda: server_status(port),
        "restart": lambda: restart_server(port, token),
    }
    handler = handlers.get(action)
    if handler is None:
        print(f"Unknown action: {action}")
        print("Valid actions: start, stop, status, restart")
        return None
    return handler()


if __name__ == "__main__":
    # TOOLBOX_ACTION selects the toolbox phase ("describe" or "execute").
    # Any other value, including unset, means direct execution: show status.
    _entry_points = {"describe": describe, "execute": execute}
    _entry_points.get(os.environ.get("TOOLBOX_ACTION"), server_status)()
