#!/usr/bin/env python3
"""

Full research-grade satellite multibeam coverage simulator with hybrid TLE source:
 - Primary: CelesTrak (gp.php?GROUP=...)
 - Fallback: local TLE file (one TLE block per 3-line entry)

Features:
 - SGP4 propagation of TLEs to chosen epoch
 - TEME -> ECEF conversion (precise via astropy if installed; GMST fallback if not)
 - Nadir pointing attitude and beam direction generation
 - Four beam models:
     * gaussian_beam_gain
     * cosn_beam_gain
     * uniform_beam_gain (top-hat)
     * multi_beam_generator (hex/circular/random)
 - Per-beam frequency assignment; preferred bands including E-band
 - Chunked aggregation for large constellations (scales to 35k with tuning)
 - Optional CuPy GPU acceleration (auto-detected)
 - Simple link-budget placeholder for EIRP/path-loss (can be enabled)
 - Outputs: PNG heatmaps and CSVs
"""
# requests matplotlib sgp4 astropy cupy

import math, time, os, sys, traceback
from typing import List, Tuple, Dict, Optional, Any
import numpy as np
import requests
import datetime
import matplotlib.pyplot as plt

# SGP4
from sgp4.api import Satrec, jday

# Optional dependencies: astropy (for precise TEME->ECEF), cupy (GPU)
ASTROPY_AVAILABLE = False
try:
    import astropy.time as atime
    import astropy.coordinates as ac
    import astropy.units as u
    ASTROPY_AVAILABLE = True
except Exception:
    ASTROPY_AVAILABLE = False

CUPY_AVAILABLE = False
try:
    import cupy as cp
    CUPY_AVAILABLE = True
except Exception:
    cp = None
    CUPY_AVAILABLE = False

# -----------------------
# User-configurable parameters (tune these)
# -----------------------
# https://celestrak.org/NORAD/elements/
#
ALL_CELESTRAK_GROUPS = ["Intelsat", "SES", "Eutelsat", "Telesat", "Starlink", "OneWeb", "Qianfan", "Hulianwang Digui", "Kuiper", "Iridium NEXT", "Orbcomm", "Globalstar", "Amateur Radio", "SatNOGS" ]
#
# TLE sources & local fallback
CELESTRAK_GROUPS = ["starlink", "oneweb","Globalstar","Intelsat"]
LOCAL_TLE_FALLBACK = "local_tles.txt"   # file path for fallback if celestrak fails

# Simulation grid and scale
DEFAULT_N_TARGET  = 35000     # target number of TLEs to consider (set lower for testing)
DEFAULT_GRID_NLAT = 180       # lat resolution
DEFAULT_GRID_NLON = 360       # lon resolution

# Beam / antenna settings
DEFAULT_N_BEAMS_PER_SAT = 7
DEFAULT_BEAM_HALF_ANGLE_DEG = 0.8
DEFAULT_BEAM_PATTERN = "hex"   # hex / circular / random
DEFAULT_BEAM_MODEL = "gaussian"  # gaussian / cosn / uniform / multibeam
DEFAULT_BEAM_MAX_TILT_DEG = 10.0

# Gain & frequency settings
DEFAULT_GAIN_THRESHOLD = 0.25
PREFERRED_BANDS = {
    "E-uplink": (71e9, 76e9),
    "E-downlink": (81e9, 86e9),
    "Ku": (10.7e9, 14.5e9),
    "Ka": (17.7e9, 30e9)
}
DEFAULT_FREQUENCY_BAND = "E-band"  # used by multi-beam generator to assign frequencies

# Chunking and GPU
DEFAULT_CHUNK_SAT = 256
DEFAULT_CHUNK_GROUND = 20000
USE_GPU_IF_AVAILABLE = False  # change to True to auto-enable GPU when cupy is present

# Earth constants
WGS84_A_M = 6378137.0
WGS84_B_M = 6356752.314245
WGS84_E2 = 1 - (WGS84_B_M**2 / WGS84_A_M**2)
R_EARTH_KM = 6378.137
KM2M = 1000.0
M2KM = 1.0 / 1000.0
RAD = math.pi / 180.0

# -----------------------
# Utilities: TLE download / parsing
# -----------------------
def fetch_tle_group_celestrak(group: str, timeout: int = 30) -> str:
    url = f"https://celestrak.com/NORAD/elements/gp.php?GROUP={group}&FORMAT=tle"
    r = requests.get(url, timeout=timeout)
    r.raise_for_status()
    return r.text

def parse_tle_text(raw: str) -> List[Tuple[str,str,str]]:
    """
    Parse TLE text (3-line blocks: name, line1, line2).
    Returns list of (name, line1, line2).
    Robust to extra blank lines.
    """
    lines = [ln.rstrip() for ln in raw.splitlines() if ln.strip() != ""]
    tles = []
    i = 0
    while i + 2 < len(lines):
        name = lines[i].strip()
        l1 = lines[i+1].strip()
        l2 = lines[i+2].strip()
        if l1.startswith("1 ") and l2.startswith("2 "):
            tles.append((name, l1, l2))
            i += 3
        else:
            i += 1
    return tles

def load_local_tles(filepath: str) -> List[Tuple[str,str,str]]:
    if not os.path.exists(filepath):
        return []
    with open(filepath, "r", encoding="utf-8") as f:
        data = f.read()
    return parse_tle_text(data)

# -----------------------
# Propagation: SGP4 to TEME/ECI
# -----------------------
def propagate_tles_to_epoch(tles: List[Tuple[str,str,str]], epoch_dt: datetime.datetime):
    """
    Propagate list of TLEs to epoch. Returns:
      - names: list[str]
      - pos_teme_km: ndarray (N,3)
      - vel_teme_km_s: ndarray (N,3)
      - satrecs: list of Satrec objects
    """
    names = []
    pos_list = []
    vel_list = []
    satrecs = []
    jd, fr = jday(epoch_dt.year, epoch_dt.month, epoch_dt.day,
                  epoch_dt.hour, epoch_dt.minute, epoch_dt.second + epoch_dt.microsecond*1e-6)
    for (name, l1, l2) in tles:
        sat = Satrec.twoline2rv(l1, l2)
        e, r, v = sat.sgp4(jd, fr)
        if e != 0:
            # skip satellites that cannot be propagated at epoch
            continue
        names.append(name)
        pos_list.append(np.array(r, dtype=float))
        vel_list.append(np.array(v, dtype=float))
        satrecs.append(sat)
    if len(pos_list) == 0:
        return names, np.zeros((0,3)), np.zeros((0,3)), satrecs
    return names, np.vstack(pos_list), np.vstack(vel_list), satrecs

# -----------------------
# TEME -> ECEF conversion
# -----------------------
def teme_to_ecef_km(pos_teme_km: np.ndarray, epoch_dt: datetime.datetime):
    """
    Convert TEME positions (km) to ECEF positions (km).
    If astropy is available, use its built-in transforms (recommended).
    Otherwise fallback to a GMST rotation (approximate).
    pos_teme_km: (N,3) array
    """
    if pos_teme_km.shape[0] == 0:
        return pos_teme_km.copy()

    if ASTROPY_AVAILABLE:
        try:
            # Use astropy to transform from TEME (if supported) to ITRS (ECEF)
            # Astropy does not have direct 'TEME' in older versions; attempt using TEME via ITRS/other if available.
            # We construct SkyCoord with TEME representation if supported; else fall back.
            # Approach: build CartesianRepresentation in TEME using 'coord' frame if available
            t = atime.Time(epoch_dt, scale='utc')
            # Try direct TEME transform (some astropy versions implement TEME)
            try:
                teme_frame = ac.TEME(obstime=t)
                cart = ac.CartesianRepresentation(pos_teme_km * u.km)
                sc = ac.SkyCoord(cart, frame=teme_frame)
                itrs = sc.transform_to(ac.ITRS(obstime=t))
                x = itrs.x.to(u.km).value
                y = itrs.y.to(u.km).value
                z = itrs.z.to(u.km).value
                return np.column_stack([x,y,z])
            except Exception:
                # Fallback: convert TEME->GCRS->ITRS (if possible) via astropy transformations
                # Build a generic ITRS vector via built-in transformation pipeline
                # Note: This may not be available in all astropy versions; if it fails, fallback to GMST.
                try:
                    from astropy.coordinates import GCRS
                    # Create a GCRS coordinate approximately from TEME by applying frame rotation (not exact)
                    # Use a simple GMST rotation here as last fallback (below)
                    pass
                except Exception:
                    pass
        except Exception as e:
            # If any astropy error, fallback to GMST method
            print("Astropy TEME->ECEF conversion failed, falling back to GMST rotation:", str(e))

    # --- GMST rotation fallback (approximate) ---
    jd = datetime_to_julian_date(epoch_dt)
    gmst = julian_date_to_gmst_rad(jd)
    cosg = math.cos(gmst); sing = math.sin(gmst)
    R = np.array([[cosg, sing, 0.0], [-sing, cosg, 0.0], [0.0, 0.0, 1.0]])
    return (R @ pos_teme_km.T).T

# --- helpers for julian/gmst ---
def datetime_to_julian_date(dt: datetime.datetime) -> float:
    year = dt.year; month = dt.month; day = dt.day
    hour = dt.hour; minute = dt.minute; second = dt.second + dt.microsecond*1e-6
    if month <= 2:
        year -= 1; month += 12
    A = year // 100
    B = 2 - A + A // 4
    day_frac = (hour + minute/60.0 + second/3600.0) / 24.0
    jd = int(365.25*(year + 4716)) + int(30.6001*(month + 1)) + day + B - 1524.5 + day_frac
    return float(jd)

def julian_date_to_gmst_rad(jd: float) -> float:
    T = (jd - 2451545.0) / 36525.0
    gmst_seconds = 67310.54841 + (876600.0*3600.0 + 8640184.812866)*T + 0.093104*T*T - 6.2e-6*T*T*T
    gmst_seconds = gmst_seconds % 86400.0
    return (gmst_seconds / 86400.0) * 2.0 * math.pi

# -----------------------
# ECEF <-> geodetic (WGS84)
# -----------------------
def ecef_to_geodetic_wgs84_km(r_ecef_km: np.ndarray):
    """
    Vectorized ECEF (km) -> geodetic lat, lon, alt (lat, lon rad; alt in km)
    Iterative Bowring method.
    """
    if r_ecef_km.shape[0] == 0:
        return np.array([]), np.array([]), np.array([])
    x = r_ecef_km[:,0] * 1000.0
    y = r_ecef_km[:,1] * 1000.0
    z = r_ecef_km[:,2] * 1000.0
    a = WGS84_A_M; e2 = WGS84_E2
    lon = np.arctan2(y, x)
    p = np.sqrt(x*x + y*y)
    lat = np.arctan2(z, p * (1 - e2))
    for _ in range(10):
        sinl = np.sin(lat)
        N = a / np.sqrt(1 - e2 * sinl * sinl)
        alt = p / np.cos(lat) - N
        lat_new = np.arctan2(z, p * (1 - e2 * (N/(N+alt))))
        if np.max(np.abs(lat_new - lat)) < 1e-12:
            lat = lat_new; break
        lat = lat_new
    sinl = np.sin(lat)
    N = a / np.sqrt(1 - e2 * sinl * sinl)
    alt = p / np.cos(lat) - N
    return lat, lon, alt * M2KM

# -----------------------
# Beam gain models
# -----------------------
def gaussian_beam_gain(theta_rad: np.ndarray, half_angle_deg: float):
    bw = math.radians(half_angle_deg)
    sigma = bw / math.sqrt(2.0 * math.log(2.0))
    return np.exp(-0.5 * (theta_rad / sigma)**2)

def cosn_beam_gain(theta_rad: np.ndarray, half_angle_deg: float):
    bw = math.radians(half_angle_deg)
    denom = math.log(max(1e-12, math.cos(bw)))
    n = math.log(0.5) / denom if denom != 0 else 1.0
    g = np.cos(theta_rad)**n
    return np.clip(g, 0.0, 1.0)

def uniform_beam_gain(theta_rad: np.ndarray, half_angle_deg: float):
    bw = math.radians(half_angle_deg)
    return (theta_rad <= bw).astype(float)

# -----------------------
# Multi-beam generator
# -----------------------
def multi_beam_generator(n_beams: int,
                         beam_half_angle_deg: float,
                         pattern: str = "hex",
                         max_tilt_deg: float = 60.0,
                         frequency_band: str = "E-band",
                         rng: Optional[np.random.Generator] = None):
    """
    Generate beam center directions in satellite body frame (z = nadir),
    beam half-angles and frequencies (Hz).
    """
    if rng is None:
        rng = np.random.default_rng()

    # frequencies in Hz
    if frequency_band == "E-band":
        freqs = np.linspace(71e9, 86e9, n_beams)
    elif frequency_band == "Ku":
        freqs = np.linspace(10.7e9, 14.5e9, n_beams)
    elif frequency_band == "Ka":
        freqs = np.linspace(17.7e9, 30e9, n_beams)
    else:
        freqs = np.linspace(10e9, 86e9, n_beams)

    if pattern == "random":
        tilt = math.radians(max_tilt_deg)
        cos_theta = rng.uniform(math.cos(tilt), 1.0, n_beams)
        sin_theta = np.sqrt(1 - cos_theta**2)
        phi = rng.uniform(0.0, 2.0*math.pi, n_beams)
        dirs = np.column_stack([sin_theta*np.cos(phi), sin_theta*np.sin(phi), cos_theta])
    elif pattern == "circular":
        theta = math.radians(max_tilt_deg)
        phi = np.linspace(0.0, 2.0*math.pi, n_beams, endpoint=False)
        dirs = np.column_stack([np.sin(theta)*np.cos(phi), np.sin(theta)*np.sin(phi), np.cos(theta)*np.ones(n_beams)])
    else:
        # hex tiling
        dirs_list = []
        rings = int(np.ceil(np.sqrt(n_beams)))
        count = 0
        for r in range(rings):
            frac = 0.0 if rings <= 1 else (r/(rings-1))
            theta = math.radians(max_tilt_deg) * frac
            n_in_ring = max(1, 6 * max(1, r))
            for k in range(n_in_ring):
                if count >= n_beams:
                    break
                phi = 2.0 * math.pi * (k / n_in_ring)
                dirs_list.append([math.sin(theta)*math.cos(phi), math.sin(theta)*math.sin(phi), math.cos(theta)])
                count += 1
        dirs = np.array(dirs_list[:n_beams])

    dirs = dirs / np.linalg.norm(dirs, axis=1, keepdims=True)
    half_angles = np.ones(n_beams) * beam_half_angle_deg
    return dirs, half_angles, freqs

# -----------------------
# Dispatcher: generate_beam_pattern
# -----------------------
def generate_beam_pattern(model: str,
                          sat_pos_ecef_m: np.ndarray,
                          sat_vel_eci_km_s: np.ndarray,
                          n_beams: int = 1,
                          beam_half_angle_deg: float = 1.5,
                          pattern: str = "hex",
                          max_tilt_deg: float = 60.0,
                          frequency_band: str = "E-band",
                          rng: Optional[np.random.Generator] = None):
    """
    Return:
      - boresights_ecef_unit (nb,3) unit vectors in ECEF frame
      - half_angles_deg (nb,)
      - freqs_hz (nb,)
    sat_pos_ecef_m: satellite position in ECEF (meters)
    sat_vel_eci_km_s: optional velocity (km/s) used for attitude - but we use ECEF-based nadir frame
    """
    # Build nadir-pointing body->ECEF rotation
    r = sat_pos_ecef_m.astype(float)
    r_hat = r / np.linalg.norm(r)
    z_body = -r_hat
    # choose a cross-track axis - use Earth's Z to define approximate along-track if velocity missing
    north = np.array([0.0, 0.0, 1.0])
    y_body = np.cross(z_body, north)
    ynorm = np.linalg.norm(y_body)
    if ynorm < 1e-10:
        y_body = np.array([0.0, 1.0, 0.0])
    else:
        y_body = y_body / ynorm
    x_body = np.cross(y_body, z_body)
    x_body = x_body / np.linalg.norm(x_body)
    R_b2ecef = np.column_stack([x_body, y_body, z_body])  # 3x3

    if model in ("gaussian", "cosn", "uniform"):
        # single boresight is nadir
        dirs_body = np.array([[0.0, 0.0, 1.0]])
        half_angles = np.array([beam_half_angle_deg])
        freqs = np.array([_default_freq_for_band(frequency_band)])
    elif model == "multibeam":
        dirs_body, half_angles, freqs = multi_beam_generator(n_beams=n_beams,
                                                             beam_half_angle_deg=beam_half_angle_deg,
                                                             pattern=pattern,
                                                             max_tilt_deg=max_tilt_deg,
                                                             frequency_band=frequency_band,
                                                             rng=rng)
    else:
        raise ValueError("Unknown beam model: " + str(model))
    # convert body dirs to ECEF boresight vectors
    boresights_ecef = (dirs_body @ R_b2ecef.T)
    boresights_ecef = boresights_ecef / np.linalg.norm(boresights_ecef, axis=1, keepdims=True)
    return boresights_ecef, half_angles, freqs

def _default_freq_for_band(band_name: str):
    if band_name.lower().startswith("e"):
        return 83e9
    if band_name.lower().startswith("ku"):
        return 12e9
    if band_name.lower().startswith("ka"):
        return 20e9
    return 12e9

# -----------------------
# Link-budget helper (placeholder)
# -----------------------
def free_space_path_loss_db(freq_hz: float, distance_m: float):
    """
    FSPL (dB) = 20 log10(4π d / λ)
    """
    c = 299792458.0
    lam = c / freq_hz
    with np.errstate(divide='ignore'):
        fspl = 20.0 * np.log10(4.0 * math.pi * distance_m / lam)
    return fspl

def link_budget_received_db(eirp_dbw: float, freq_hz: float, distance_m: float, rx_gain_db: float = 0.0, losses_db: float = 0.0):
    """
    Very simple link budget: Pr_dBW = EIRP_dBW - FSPL_dB + Gr_dB - losses
    EIRP_dBW desired in dBW.
    """
    fspl = free_space_path_loss_db(freq_hz, distance_m)
    pr = eirp_dbw - fspl + rx_gain_db - losses_db
    return pr

# -----------------------
# Aggregation: beams -> ground (chunked; CPU & optional GPU support)
# -----------------------
def aggregate_beams_to_ground(
    sat_ecef_km: np.ndarray,
    sat_vel_eci_km_s: np.ndarray,
    sat_names: List[str],
    ground_lat_rad: np.ndarray,
    ground_lon_rad: np.ndarray,
    model: str = "multibeam",
    n_beams_per_sat: int = DEFAULT_N_BEAMS_PER_SAT,
    beam_half_angle_deg: float = DEFAULT_BEAM_HALF_ANGLE_DEG,
    beam_pattern: str = DEFAULT_BEAM_PATTERN,
    beam_max_tilt_deg: float = DEFAULT_BEAM_MAX_TILT_DEG,
    beam_gain_model: str = DEFAULT_BEAM_MODEL,
    gain_threshold: float = DEFAULT_GAIN_THRESHOLD,
    frequency_band: str = DEFAULT_FREQUENCY_BAND,
    preferred_bands: Dict[str, Tuple[float,float]] = PREFERRED_BANDS,
    chunk_sat: int = DEFAULT_CHUNK_SAT,
    chunk_ground: int = DEFAULT_CHUNK_GROUND,
    use_gpu: bool = False,
    compute_power_map: bool = False,
    eirp_dbw: float = 47.0  # example EIRP in dBW per beam (placeholder)
):
    """
    Core aggregation routine. Returns:
     - total_beams_per_ground (G,)
     - preferred_beams_per_ground (G,)
     - cofreq_map: dict freq_hz -> (G,) counts
     - (optional) received_power_map (G,) in dBW if compute_power_map True (sum of linear powers)
    Notes:
     - sat_ecef_km: (N,3) satellite ECEF positions in km
     - sat_vel_eci_km_s: (N,3) provided but used only for attitude if needed
     - ground_lat_rad & ground_lon_rad are flattened arrays (G,)
    """
    use_gpu = use_gpu and CUPY_AVAILABLE
    xp = cp if use_gpu else np

    N = sat_ecef_km.shape[0]
    G = ground_lat_rad.size

    # Build ground ECEF (meters) and normals
    alt0_m = np.zeros_like(ground_lat_rad)
    r_g_m  = geodetic_to_ecef_m(ground_lat_rad, ground_lon_rad, alt0_m)  # (G,3)
    n_g    = r_g_m / np.linalg.norm(r_g_m, axis=1, keepdims=True)

    # G = total number of ground points
    Nvis 		= np.zeros(G, dtype=int)	# needs reworking
    total_counts	= np.zeros(G, dtype=int)	# also uncorrected
    preferred_counts	= np.zeros(G, dtype=int)
    cofreq_map: Dict[float, np.ndarray] = {}
    if compute_power_map:
        # store linear power sums (watts) per ground point
        power_linear = np.zeros(G, dtype=float)
    else:
        power_linear = None

    # helper: is freq in any preferred band
    def is_preferred(freq_hz: float) -> bool:
        for (lo, hi) in preferred_bands.values():
            if freq_hz >= lo and freq_hz <= hi:
                return True
        return False

    # iterate sats in chunks
    for s0 in range(0, N, chunk_sat):
        s1 = min(N, s0 + chunk_sat)
        chunk_idxs = range(s0, s1)

        # pre-generate beams for chunk (list)
        beam_list = []  # entries: (sat_origin_m, boresight_unit_m (3,), half_angle_deg, freq_hz)
        for si in chunk_idxs:
            sat_ecef_m = sat_ecef_km[si] * KM2M
            sat_vel = sat_vel_eci_km_s[si] if sat_vel_eci_km_s is not None else None
            boresights_ecef, half_angles, freqs = generate_beam_pattern(
                model=model,
                sat_pos_ecef_m=sat_ecef_m,
                sat_vel_eci_km_s=sat_vel,
                n_beams=n_beams_per_sat,
                beam_half_angle_deg=beam_half_angle_deg,
                pattern=beam_pattern,
                max_tilt_deg=beam_max_tilt_deg,
                frequency_band=frequency_band
            )
            for b in range(boresights_ecef.shape[0]):
                beam_list.append((sat_ecef_m, boresights_ecef[b], float(half_angles[b]), float(freqs[b])))

        B = len(beam_list)
        if B == 0:
            continue

        # process ground in chunks
        for g0 in range(0, G, chunk_ground):
            g1 = min(G, g0 + chunk_ground)
            r_g_chunk = r_g_m[g0:g1]    # (gchunk, 3) meters
            n_g_chunk = n_g[g0:g1]
            counts_chunk = np.zeros(g1-g0, dtype=int)
            pref_chunk = np.zeros(g1-g0, dtype=int)
            freq_acc: Dict[float, np.ndarray] = {}
            pow_chunk_linear = np.zeros(g1-g0, dtype=float) if compute_power_map else None

            # loop beams (vectorizing across ground points)
            for (sat_origin_m, boresight_u, half_angle_deg, freq_hz) in beam_list:
                #
                v = r_g_chunk - sat_origin_m			# vektor satellit → mark
                dist = np.linalg.norm(v, axis=1)		# v_norm
                dist_safe = np.where(dist <= 0.0, 1e-6, dist)	# avoid divide-by-zero
                v_unit = v / dist_safe[:,None]			# v_unit = v / v_norm[:, None]
                #
                # beam-mask: θ mellan boresight och riktning från sat → mark
                cos_theta = np.einsum('ij,j->i', v_unit, boresight_u)	# boresight_u = satellitens boresight (nadir eller tiltad)
                cos_theta = np.clip(cos_theta, -1.0, 1.0)		# begränsa värdespann
                theta = np.arccos(cos_theta)
                mask_beam = theta <= math.radians(half_angle_deg)
                if not np.any(mask_beam):
                    continue
                #
                # elevation (grader); elevation/horizon check: dot of v with local normal > 0 for visibility
                dotvn = np.einsum('ij,ij->i', v, n_g_chunk)
                min_elev_angle = 0.0				# horisonten
                min_elev_func  = np.sin(min_elev_angle)
                #
                # synlighetsmask
                visible_mask = dotvn >= min_elev_func
                final_mask = mask_beam & visible_mask
                if not np.any(final_mask):
                    continue

                # Nvis: alla synliga beams, oavsett frekvens
                Nvis[g0:g1][visible_mask] += 1

                # evaluate gain depending on model
                if model == "gaussian" or model == "multibeam":
                    gain_vals = gaussian_beam_gain(theta[final_mask], half_angle_deg)
                elif model == "cosn":
                    gain_vals = cosn_beam_gain(theta[final_mask], half_angle_deg)
                elif model == "uniform":
                    gain_vals = uniform_beam_gain(theta[final_mask], half_angle_deg)
                else:
                    gain_vals = gaussian_beam_gain(theta[final_mask], half_angle_deg)

                mask_gain = gain_vals >= gain_threshold
                if not np.any(mask_gain):
                    continue
                idxs = np.nonzero(final_mask)[0][mask_gain]
                counts_chunk[idxs] += 1
                if is_preferred(freq_hz):
                    pref_chunk[idxs] += 1
                #
                # Nco: per frekvensbin co-frequency accumulation
                freq_key = round(freq_hz, 3)				# bin på kHz-nivå, byt till 6 för MHz
                if freq_key not in freq_acc:
                    freq_acc[freq_key] = np.zeros(g1-g0, dtype=int)	# freq_acc[freq_hz] hz <-> key
                freq_acc[freq_key][idxs] += 1				# freq_acc[freq_hz][idxs] += 1

                # optional received power accumulation (very simple FSPL model from sat to ground)
                if compute_power_map:
                    # convert linear gain (0..1) to dBi or relative factor? Here treat gain_vals as normalized main-lobe factor.
                    # compute free-space path loss using mean distance for idxs
                    d_sel = dist[idxs]
                    # convert freq to Hz
                    pr_db = link_budget_received_db(eirp_dbw=eirp_dbw, freq_hz=freq_hz, distance_m=np.mean(d_sel), rx_gain_db=0.0, losses_db=0.0)
                    pr_lin = 10**(pr_db/10.0)  # convert dBW to watts approx (assuming dBW scale)
                    # scale by gain fraction (normalized)
                    # We'll add same pr_lin to each idxs weighted by normalized gain
                    for k, gi in enumerate(idxs):
                        pow_chunk_linear[gi] += pr_lin * gain_vals[mask_gain][k]

            # merge chunk results into global
            total_counts[g0:g1] += counts_chunk
            preferred_counts[g0:g1] += pref_chunk
            if compute_power_map:
                power_linear[g0:g1] += pow_chunk_linear
            for f, arr in freq_acc.items():
                if f not in cofreq_map:
                    cofreq_map[f] = np.zeros(G, dtype=int)
                cofreq_map[f][g0:g1] += arr

    if compute_power_map:
        # convert total linear watts to dBW per ground point (handle zeros)
        with np.errstate(divide='ignore'):
            power_dbw = 10.0 * np.log10(power_linear, where=(power_linear>0.0))
        power_dbw = np.where(power_linear>0.0, power_dbw, -999.0)
    else:
        power_dbw = None

    return total_counts, preferred_counts, cofreq_map, power_dbw, Nvis

# -----------------------
# ECEF/Geodetic helpers
# -----------------------
def geodetic_to_ecef_m(lat_rad: np.ndarray, lon_rad: np.ndarray, alt_m: np.ndarray):
    a = WGS84_A_M; e2 = WGS84_E2
    sinl = np.sin(lat_rad); cosl = np.cos(lat_rad)
    N = a / np.sqrt(1 - e2*sinl*sinl)
    x = (N + alt_m) * cosl * np.cos(lon_rad)
    y = (N + alt_m) * cosl * np.sin(lon_rad)
    z = (N * (1 - e2) + alt_m) * sinl
    return np.stack([x,y,z], axis=-1)

# -----------------------
# plotting & saving helpers
# -----------------------
def plot_heatmap(grid2d: np.ndarray, lat_vals_rad: np.ndarray, lon_vals_rad: np.ndarray, filename: str, title: str = ""):
    plt.figure(figsize=(12,6))
    extent = [math.degrees(lon_vals_rad.min()), math.degrees(lon_vals_rad.max()),
              math.degrees(lat_vals_rad.min()), math.degrees(lat_vals_rad.max())]
    plt.imshow(grid2d, origin='lower', extent=extent, aspect='auto', cmap='inferno')
    plt.colorbar(label='count')
    plt.title(title)
    plt.xlabel('Longitude (deg)'); plt.ylabel('Latitude (deg)')
    plt.tight_layout()
    plt.savefig(filename, dpi=200)
    plt.close()

def save_flat_csv(flat_arr: np.ndarray, filename: str, header: str = "value"):
    np.savetxt(filename, flat_arr, delimiter=",", header=header, comments='')

# -----------------------
# Top-level pipeline
# -----------------------
def run_full_simulation(
    out_dir: str = "simulator_output",
    groups: List[str] = CELESTRAK_GROUPS,
    local_tle_file: str = None,
    N_target: int = DEFAULT_N_TARGET,
    grid_nlat: int = DEFAULT_GRID_NLAT,
    grid_nlon: int = DEFAULT_GRID_NLON,
    model: str = DEFAULT_BEAM_MODEL,
    n_beams_per_sat: int = DEFAULT_N_BEAMS_PER_SAT,
    beam_half_angle_deg: float = DEFAULT_BEAM_HALF_ANGLE_DEG,
    beam_pattern: str = DEFAULT_BEAM_PATTERN,
    beam_max_tilt_deg: float = DEFAULT_BEAM_MAX_TILT_DEG,
    beam_gain_model: str = DEFAULT_BEAM_MODEL,
    gain_threshold: float = DEFAULT_GAIN_THRESHOLD,
    frequency_band: str = DEFAULT_FREQUENCY_BAND,
    preferred_bands: Dict[str, Tuple[float,float]] = PREFERRED_BANDS,
    chunk_sat: int = DEFAULT_CHUNK_SAT,
    chunk_ground: int = DEFAULT_CHUNK_GROUND,
    use_gpu_if_available: bool = USE_GPU_IF_AVAILABLE,
    compute_power_map: bool = False,
    save_tles_to_disk: bool = False,
    do_random_sampling:bool = False,
):
    os.makedirs(out_dir, exist_ok=True)
    # 1) gather TLEs (CelesTrak primary, local fallback)
    tles = []
    if local_tle_file is None :
        for g in groups:
            try:
                print(f"Fetching TLEs for group '{g}' from CelesTrak...")
                raw = fetch_tle_group_celestrak(g)
                tles_group = parse_tle_text(raw)
                print(f"  parsed {len(tles_group)} TLEs from {g}")
                if save_tles_to_disk :
                    fo = open(f"{out_dir+'/'}{g}TLE.txt","w")
                    print ( raw , file=fo )
                    fo.close()
                tles.extend(tles_group)
            except Exception as e:
                print(f"  failed to fetch {g} from CelesTrak: {e}; continuing")

    if len(tles) == 0 :
        # local file
        print("No TLEs downloaded from CelesTrak; attempting to load local TLE file:", local_tle_file)
        try:
            tles = load_local_tles(local_tle_file)
            if len(tles) == 0:
                raise RuntimeError("No TLEs available: CelesTrak failed and local file not found/empty.")
        except Exception as e:
            print(f"  failed to obtain tle data from {local_tle_file} : {e}; continuing")
    # Trim to N_target
    if N_target is not None and len(tles) > N_target:
        if do_random_sampling :
            import random
            indices = random.sample( range(len(tles)) , N_target )
            tles = [ tles[ idx ] for idx in indices ]
        else :
            tles = tles[:N_target]
    print("Total TLEs to be used:", len(tles))

    # 2) propagate to epoch
    epoch = datetime.datetime.utcnow()
    print("Propagating TLEs to epoch (UTC):", epoch.isoformat())
    names, pos_teme_km, vel_teme_km_s, satrecs = propagate_tles_to_epoch(tles, epoch)
    print("  propagated:", pos_teme_km.shape[0], "satellites")

    # 3) TEME -> ECEF (km)
    print("Converting TEME -> ECEF (km) (astropy fallback if available)")
    pos_ecef_km = teme_to_ecef_km(pos_teme_km, epoch)

    # 4) geodetic sub-satellite points (lat/lon/alt)
    lat_s_rad, lon_s_rad, alt_s_km = ecef_to_geodetic_wgs84_km(pos_ecef_km)

    # 5) Build ground grid
    print("Building ground grid (lat/lon)...")
    lat_vals = np.linspace(-60*RAD, 60*RAD, grid_nlat)
    lon_vals = np.linspace(-180*RAD, 180*RAD, grid_nlon)
    lat2d, lon2d = np.meshgrid(lat_vals, lon_vals, indexing='ij')
    ground_lat_flat = lat2d.ravel()
    ground_lon_flat = lon2d.ravel()
    G = ground_lat_flat.size
    print(f"  ground grid: {grid_nlat} x {grid_nlon} = {G} points")

    # 6) decide on GPU usage
    use_gpu = use_gpu_if_available and CUPY_AVAILABLE
    if use_gpu_if_available and not CUPY_AVAILABLE:
        print("CuPy requested but not available; running on CPU (NumPy).")
    if use_gpu:
        print("CuPy detected and will be used for parts of computation (GPU).")

    # 7) aggregate beams to ground
    print("Aggregating beams to ground (this can be slow for large N; tune chunks)...")
    t0 = time.time()
    total_counts, pref_counts, cofreq_map, power_dbw, Nvis = aggregate_beams_to_ground(
        sat_ecef_km=pos_ecef_km,
        sat_vel_eci_km_s=vel_teme_km_s,
        sat_names=names,
        ground_lat_rad=ground_lat_flat,
        ground_lon_rad=ground_lon_flat,
        model=model,
        n_beams_per_sat=n_beams_per_sat,
        beam_half_angle_deg=beam_half_angle_deg,
        beam_pattern=beam_pattern,
        beam_max_tilt_deg=beam_max_tilt_deg,
        beam_gain_model=beam_gain_model,
        gain_threshold=gain_threshold,
        frequency_band=frequency_band,
        preferred_bands=preferred_bands,
        chunk_sat=chunk_sat,
        chunk_ground=chunk_ground,
        use_gpu=use_gpu,
        compute_power_map=compute_power_map
    )
    t1 = time.time()
    print(f"Aggregation complete in {t1-t0:.1f} s")

    # reshape to 2D for plotting
    total_grid = total_counts.reshape(lat2d.shape)
    pref_grid = pref_counts.reshape(lat2d.shape)
    combined_cofreq_flat = np.zeros_like(total_counts)
    for f, arr in cofreq_map.items():
        combined_cofreq_flat += arr
    combined_cofreq_grid = combined_cofreq_flat.reshape(lat2d.shape)

    # save outputs
    out_total_png = os.path.join(out_dir, "total_beams_heatmap.png")
    out_pref_png = os.path.join(out_dir, "preferred_beams_heatmap.png")
    out_cofreq_png = os.path.join(out_dir, "cofreq_heatmap.png")

    print("Saving heatmaps...")
    plot_heatmap(total_grid, lat_vals, lon_vals, out_total_png, title="Total beams")
    plot_heatmap(pref_grid, lat_vals, lon_vals, out_pref_png, title="Preferred-band beams")
    plot_heatmap(combined_cofreq_grid, lat_vals, lon_vals, out_cofreq_png, title="Co-frequency beams")

    # Save CSVs and grids
    out_nvis_csv = os.path.join(out_dir, "nvis_beams.csv")
    out_total_csv = os.path.join(out_dir, "total_beams.csv")
    out_pref_csv = os.path.join(out_dir, "preferred_beams.csv")
    out_cofreq_csv = os.path.join(out_dir, "cofreq_beams.csv")
    save_flat_csv(Nvis, out_nvis_csv, header="nvis_beams")
    save_flat_csv(total_counts, out_total_csv, header="total_beams")
    save_flat_csv(pref_counts, out_pref_csv, header="preferred_beams")
    save_flat_csv(combined_cofreq_flat, out_cofreq_csv, header="cofreq_beams")
    np.save(os.path.join(out_dir, "lat_grid.npy"), lat2d)
    np.save(os.path.join(out_dir, "lon_grid.npy"), lon2d)

    if compute_power_map and (power_dbw is not None):
        out_power_png = os.path.join(out_dir, "received_power_heatmap.png")
        power_grid = power_dbw.reshape(lat2d.shape)
        plot_heatmap(power_grid, lat_vals, lon_vals, out_power_png, title="Received power (dBW)")
        save_flat_csv(power_dbw, os.path.join(out_dir, "received_power.csv"), header="received_power_dBW")

    print("All outputs written to:", out_dir)
    return {
        "total_png"  : out_total_png  ,
        "pref_png"   : out_pref_png   ,
        "cofreq_png" : out_cofreq_png ,
        "total_csv"  : out_total_csv  ,
        "pref_csv"   : out_pref_csv   ,
        "cofreq_csv" : out_cofreq_csv ,
        "nvis_csv"   : out_nvis_csv   
    }

def download_tle_data (
    out_dir: str = "tle_downloads",
    groups: List[str] = ALL_CELESTRAK_GROUPS,
):
    os.makedirs(out_dir, exist_ok=True)
    # gather TLEs from CelesTrak primary
    tles  = []
    names = []

    for g in groups:
        try:
            print(f"Fetching TLEs for group '{g}' from CelesTrak...")
            raw = fetch_tle_group_celestrak(g)
            tles_group = parse_tle_text(raw)
            names.append( f"{out_dir+'/'}{g}TLE.txt" )
            print(f"  parsed {len(tles_group)} TLEs from {g}")
            fo = open(f"{out_dir+'/'}{g}TLE.txt","w")
            print ( raw , file=fo )
            fo.close()
            tles.extend(tles_group)
        except Exception as e:
            print(f"  failed to fetch {g} from CelesTrak: {e}; continuing")

def gather_tle_data (
    out_dir: str = "tle_downloads",
    groups: List[str] = ALL_CELESTRAK_GROUPS,
    local_tle_file: str = "tle_local.txt",
):
    os.makedirs(out_dir, exist_ok=True)
    # gather TLEs from CelesTrak primary
    tles  = []
    names = []

    for g in groups:
        try:
            print(f"Fetching TLEs for group '{g}' from CelesTrak...")
            raw = fetch_tle_group_celestrak(g)
            tles_group = parse_tle_text(raw)
            names.append( f"{out_dir+'/'}{g}TLE.txt" )
            print(f"  parsed {len(tles_group)} TLEs from {g}")
            fo = open(f"{out_dir+'/'}{g}TLE.txt","w")
            print ( raw , file=fo )
            fo.close()
            tles.extend(tles_group)
        except Exception as e:
            print(f"  failed to fetch {g} from CelesTrak: {e}; continuing")
    fo = open(local_tle_file,"w")
    fo .close()
    fo = open(local_tle_file,"a")
    for name in names :
        print ( name , ':', os.path.getsize(name) )
        with open(name,"r") as input :
            try:
                for line in input :
                    if len(line.replace(" ","").replace("\n","")) > 0 :
                        print ( line.replace( "\n" , "" ) , file=fo )
            except Exception as err:
                continue
    fo.close()

def collate_tle_data(
    out_dir: str        = "tle_downloads",
    groups: List[str]   = ALL_CELESTRAK_GROUPS,
    local_tle_file: str = "tle_local.txt",
):
    os.makedirs(out_dir, exist_ok=True)
    # gather TLEs from CelesTrak primary
    tles  = []
    names = []
    fo = open(local_tle_file,"w")
    fo .close()
    fo = open(local_tle_file,"a")
    for g in groups :
        name = f"{out_dir+'/'}{g}TLE.txt"
        print ( name , ':', os.path.getsize(name) )
        with open(name,"r") as input :
            try:
                for line in input :
                    if len(line.replace(" ","").replace("\n","")) > 0 :
                        print ( line.replace( "\n" , "" ) , file=fo )
            except Exception as err:
                continue
    fo.close()

""" PSUEDO CODE
dev++
start_time = now()
dt_seconds = 30
n_steps = int(24*3600 / dt_seconds)
times = [start_time + i*dt_seconds for i in range(n_steps)]

Nvis_acc = np.zeros(G)
Nco_acc = {}
total_acc = np.zeros(G)

for t in times:
    sats = propagate_all_satellites(t)
    tot, pref, co, pwr, Nvis = aggregate_beams_to_ground_gpu(...)
    
    Nvis_acc += Nvis
    total_acc += tot
    
    for fk in co:
        if fk not in Nco_acc:
            Nco_acc[fk] = co[fk].copy()
        else:
            Nco_acc[fk] += co[fk]

Nvis_avg = Nvis_acc / len(times)
total_avg = total_acc / len(times)
Nco_avg = {f: arr / len(times) for f, arr in Nco_acc.items()}
"""
#
# -----------------------
# Small sanity-run when executed directly
# -----------------------
if __name__ == "__main__":
    if False :
        download_tle_data()
        collate_tle_data()
        exit(1)

    try:
        out = run_full_simulation(
            out_dir="sim_20251212_dev",
            groups=ALL_CELESTRAK_GROUPS,	# CELESTRAK_GROUPS,
            local_tle_file="tle_local.txt", 	# LOCAL_TLE_FALLBACK,
            N_target=10000,               	# set to 35000 for full-scale runs (ensure resources)
            grid_nlat=120,
            grid_nlon=240,
            model="multibeam",
            n_beams_per_sat=7,
            beam_half_angle_deg=0.8,
            beam_pattern="hex",
            beam_max_tilt_deg=10.0,
            beam_gain_model="gaussian",
            gain_threshold=0.25,
            frequency_band="E-band",
            preferred_bands=PREFERRED_BANDS,
            chunk_sat=256,
            chunk_ground=20000,
            use_gpu_if_available=False,   # set True if you installed cupy
            compute_power_map = True,
            do_random_sampling = True,
        )
        print("Simulation finished. Outputs:", out)
    except Exception as err:
        print("Error during simulation:", err)
        traceback.print_exc()

    import pandas as pd
    tdf = pd.concat( (	pd.read_csv(out['total_csv']),	pd.read_csv(out['pref_csv']),
			pd.read_csv(out['cofreq_csv']),	pd.read_csv(out['nvis_csv'])) )
    print ( tdf .describe() )

