Metadata-Version: 2.4
Name: pygedai
Version: 1.0.0
Summary: PyGEDAI: Generalized Eigenvalue De-Artifacting Instrument in Python
Home-page: https://github.com/JoelKessler/PyGEDAI
Author: Joel Kessler
License: PolyForm Noncommercial License 1.0.0
Project-URL: Source, https://github.com/JoelKessler/PyGEDAI
Project-URL: License, https://github.com/JoelKessler/PyGEDAI/blob/main/LICENSE
Project-URL: Bug Tracker, https://github.com/JoelKessler/PyGEDAI/issues
Project-URL: Documentation, https://github.com/JoelKessler/PyGEDAI#readme
Classifier: License :: Other/Proprietary License
Classifier: Intended Audience :: Science/Research
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.2.0
Provides-Extra: torch
Requires-Dist: torch>=2.2.0; extra == "torch"
Dynamic: author
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: license
Dynamic: license-file
Dynamic: project-url
Dynamic: provides-extra
Dynamic: requires-dist
Dynamic: summary

# PyGEDAI Usage Guide

This library implements the Generalized Eigenvalue De-Artifacting Instrument (GEDAI) for EEG cleaning. The core API mirrors the original MATLAB tooling while embracing PyTorch tensors for efficient numerical work. This document provides a concise reference for integrating `gedai()` and `batch_gedai()` into your projects while staying faithful to the algorithmic description in Ros et al. (2025).

---

## What GEDAI Does

GEDAI (Generalized Eigenvalue De-Artifacting Instrument) is an unsupervised, theoretically grounded denoiser for heavily contaminated EEG. It contrasts each epoch’s covariance with a physics-based forward model (leadfield) and retains only components that behave like genuine neural activity. No clean calibration data or manual supervision is required.

**Core mechanism:** A generalized eigenvalue decomposition (GEVD) compares the data covariance (`dataCOV`) with a leadfield-derived reference covariance (`refCOV`). Components aligned with the brain subspace are kept; orthogonal components are treated as artifacts. The Signal & Noise Subspace Alignment Index (SENSAI) automatically selects the optimal rejection threshold. See [Ros et al., 2025](https://doi.org/10.1101/2025.10.04.680449) for full details.
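
To make the GEVD step concrete, here is a minimal NumPy sketch of a generalized eigenvalue decomposition between two covariance matrices. It is not PyGEDAI's implementation: the function name `generalized_eig`, the random test matrices, and the Cholesky-whitening approach are assumptions chosen for illustration, and the SENSAI thresholding step is left out.

```python
import numpy as np


def generalized_eig(data_cov, ref_cov):
    """Solve data_cov @ v = lam * ref_cov @ v (ref_cov must be positive definite)."""
    # Factor the reference covariance: ref_cov = chol @ chol.T
    chol = np.linalg.cholesky(ref_cov)
    chol_inv = np.linalg.inv(chol)
    # Whitening turns the generalized problem into an ordinary symmetric one.
    whitened = chol_inv @ data_cov @ chol_inv.T
    lam, y = np.linalg.eigh(whitened)
    # Map the eigenvectors back to sensor space.
    return lam, chol_inv.T @ y


rng = np.random.default_rng(0)
n_channels, n_samples = 4, 500

# Toy "data" covariance from random samples (illustrative only).
x = rng.standard_normal((n_channels, n_samples))
data_cov = x @ x.T / n_samples

# Toy positive-definite "reference" covariance (illustrative only).
a = rng.standard_normal((n_channels, n_channels))
ref_cov = a @ a.T + n_channels * np.eye(n_channels)

eigenvalues, eigenvectors = generalized_eig(data_cov, ref_cov)

# Each column v_j satisfies data_cov @ v_j = eigenvalues[j] * ref_cov @ v_j.
residual = data_cov @ eigenvectors - (ref_cov @ eigenvectors) * eigenvalues
```

Each column of `eigenvectors` is a spatial component, and the matching eigenvalue is the ratio of the variance it carries in the data to the variance it carries in the reference.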

---

## Background and References

- Ros, T., Férat, V., Huang, Y., Colangelo, C., Kia, S. M., Wolfers, T., Vulliemoz, S., & Michela, A. (2025). *Return of the GEDAI: Unsupervised EEG Denoising based on Leadfield Filtering*. bioRxiv. https://doi.org/10.1101/2025.10.04.680449
- Original MATLAB/EEGLAB plugin: https://github.com/neurotuning/GEDAI-master (this Python port follows the architecture and processing stages documented there).


---

## Quick Start

Install from PyPI for a quick start:

```bash
pip install pygedai
```

```python
import torch
from pygedai import gedai, batch_gedai

eeg = torch.load("eeg.pt") # (channels, samples)
leadfield = torch.load("leadfield.pt") # (channels, channels)

# Clean a single recording
result = gedai(eeg, sfreq=100.0, leadfield=leadfield)
cleaned = result["cleaned"]

# Clean a mini-batch of trials
batch = eeg.unsqueeze(0) # add batch dimension
cleaned_batch = batch_gedai(batch, sfreq=100.0, leadfield=leadfield)

# result contains cleaned EEG, per-band thresholds, and SENSAI quality metrics
print(f"SENSAI quality score: {result['sensai_score']:.3f}")
```

The notebook `testing/Test.ipynb` covers an end-to-end example, including plots that compare raw and cleaned signals.

---

## Real Time Streaming

`GEDAIStream` keeps a rolling buffer of incoming EEG chunks, periodically recomputes artifact thresholds, and applies them to each new chunk without reinitializing the optimizer. Use the `gedai_stream()` factory to create a configured stream when you need continuous denoising or want to embed GEDAI inside acquisition software.

- `sfreq`: Sampling frequency in Hz.
- `leadfield`: Square reference covariance (`channels × channels`) that anchors the GEVD.
- `threshold_update_interval_sec`: How often (seconds of streamed data) to refresh artifact thresholds once the initial pass has completed.
- `initial_threshold_delay_sec`: Minimum buffered duration before the first threshold computation, letting the optimizer see a representative window.
- `buffer_max_sec`: Size cap (seconds) on the rolling buffer used for threshold estimation to keep memory usage bounded.
- `processing_window_sec`: Optional batching window (seconds) for cleaning. When set, incoming chunks are concatenated until the window length is reached, then processed as a single batch; synchronous calls return `None` until enough data is accumulated.
- `moving_window_chunk_sec`: Size of the raw-history tail (seconds) that is prepended to every chunk before cleaning. GEDAI uses this overlapping context to avoid boundary artifacts; when both options are set the moving-window duration must exceed `processing_window_sec` so enough historical context is available beyond the active chunk.
- `denoising_strength`: Same semantics as `gedai()` (`"auto"`, `"auto-"`, `"auto+"`, or numeric); governs artifact rejection aggressiveness.
- `epoch_size_in_cycles`, `lowcut_frequency`, `wavelet_levels`, `matlab_levels`: Wavelet configuration forwarded to `gedai()`, controlling frequency resolution and band selection.
- `device`, `dtype`: Target torch device/dtype for buffering and computation.
- `TolX`, `maxiter`: Convergence tolerance and iteration cap for SENSAI's golden-section search during threshold discovery.
- `max_concurrent_chunks`: Upper bound on how many chunks are allowed to be in-flight at once when using callbacks. Set to `-1` to disable back-pressure entirely and let submissions proceed without waiting.
- `num_workers`: Thread pool size used for asynchronous cleaning; pass an explicit value to cap executor threads even when `max_concurrent_chunks=-1`. When omitted it mirrors `max_concurrent_chunks`, or falls back to Python's default when back-pressure is disabled.

```python
import torch
from pygedai import gedai_stream

leadfield = torch.load("leadfield_61ch.pt")
stream = gedai_stream(sfreq=250.0, leadfield=leadfield, device="cpu", threshold_update_interval_sec=10.0, initial_threshold_delay_sec=5.0)

with stream:
  # acquire_eeg_chunks() and handle_cleaned_eeg() are placeholders for your own I/O code.
  for chunk in acquire_eeg_chunks():
    cleaned_chunk = stream.next(chunk)
    handle_cleaned_eeg(cleaned_chunk)
```

When you need non-blocking streaming, pass a `callback` to `stream.next`, control back-pressure with `max_concurrent_chunks`, and adjust `num_workers` to set the thread count. Setting `max_concurrent_chunks=-1` disables back-pressure entirely so the main loop never waits, while `num_workers` caps the executor's parallelism when provided.

```python
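import torch
from pygedai import gedai_stream

# fs, leadfield_cov, and eeg_chunks are placeholders for your own sampling rate,
# reference covariance, and chunk source.
cleaned_chunks: dict[int, torch.Tensor] = {}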

def handle_cleaned_chunk(cleaned_chunk: torch.Tensor, chunk_index: int, raw_chunk: torch.Tensor) -> None:
  # Cache results while the main loop keeps submitting new chunks.
  cleaned_chunks[chunk_index] = cleaned_chunk.detach().cpu()

stream = gedai_stream(
  sfreq=fs,
  leadfield=leadfield_cov,
  initial_threshold_delay_sec=10.0,
  threshold_update_interval_sec=10.0,
  max_concurrent_chunks=-1,
  num_workers=4,
)

with stream:
  for chunk in eeg_chunks:
    stream.next(chunk, callback=handle_cleaned_chunk)
    # Perform logging, plotting, or acquisition work while GEDAI runs on a worker thread.

# Cleaned tensors are available in cleaned_chunks once the callbacks have fired.
```
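
The back-pressure behaviour described above can be pictured with a bounded semaphore in front of a thread pool. This is a generic standard-library sketch, not the code inside `GEDAIStream`; `run_with_backpressure`, `max_in_flight`, and the doubling "cleaner" are hypothetical names used only for illustration.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def run_with_backpressure(chunks, work, max_in_flight, num_workers):
    """Submit chunks to a thread pool, never keeping more than max_in_flight pending."""
    slots = threading.BoundedSemaphore(max_in_flight)
    results = {}

    def task(index, chunk):
        try:
            results[index] = work(chunk)
        finally:
            slots.release()  # free a slot so the producer can submit again

    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        for index, chunk in enumerate(chunks):
            slots.acquire()  # blocks while max_in_flight chunks are still pending
            pool.submit(task, index, chunk)
    # Leaving the with-block waits for every submitted task to finish.
    return results


# Stand-in for a cleaning call: doubles every sample of a chunk.
results = run_with_backpressure(
    chunks=[[1, 2], [3, 4], [5, 6]],
    work=lambda chunk: [2 * value for value in chunk],
    max_in_flight=2,
    num_workers=2,
)
```

Dropping the semaphore corresponds to the "no back-pressure" case: the producer loop never waits, and only the pool size limits parallelism.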

If you set `processing_window_sec`, the stream buffers consecutive chunks until that many seconds of data are collected. Each window is cleaned (and delivered via callback, if provided) as a single block so downstream consumers always see window-aligned segments.
When `callback` is omitted, `next()` returns `None` until a full window has been accumulated and then yields the cleaned window-sized tensor.
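
The window-buffering behaviour can be sketched in a few lines of plain Python. The class below is a hypothetical stand-in (single channel, samples stored in a list), not the stream's actual buffer; it only shows why a synchronous call yields `None` until a full window exists.

```python
class WindowAccumulator:
    """Hypothetical helper: buffer samples until a full processing window exists."""

    def __init__(self, sfreq, processing_window_sec):
        self.window_samples = int(round(sfreq * processing_window_sec))
        self.pending = []

    def push(self, chunk):
        """Add a chunk; return one full window, or None if more data is needed."""
        self.pending.extend(chunk)
        if len(self.pending) < self.window_samples:
            return None
        window = self.pending[: self.window_samples]
        self.pending = self.pending[self.window_samples :]
        return window


# 10 Hz with a 1 s window means 10 samples per window; chunks arrive 4 samples at a time.
accumulator = WindowAccumulator(sfreq=10.0, processing_window_sec=1.0)
outputs = [accumulator.push(list(range(start, start + 4))) for start in (0, 4, 8)]
```

The first two calls return `None`; the third returns the first ten samples and leaves the remaining two buffered for the next window.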

Configure `moving_window_chunk_sec` when you want each chunk (or processing window) to include a short tail of the immediately preceding raw data. This overlapping context keeps GEDAI’s wavelet bands from seeing abrupt edges, reduces artifacts near chunk boundaries, and determines how much history is stored in the stream’s `state`. Leave it unset to run on non-overlapping chunks, or choose a value strictly greater than `processing_window_sec` to keep extra historical context beyond the active window.
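
As a rough illustration of the overlapping-context idea, the sketch below prepends a tail of raw history to each chunk, runs a cleaner on the extended segment, and returns only the part belonging to the new chunk. The helper names and the identity "cleaner" are assumptions, and the sketch treats the moving window purely as a history tail, which may differ from how the stream sizes it internally.

```python
def clean_with_history(chunk, history, tail_samples, clean):
    """Prepend up to tail_samples of raw history, clean, then drop the context."""
    context = history[-tail_samples:] if tail_samples > 0 else []
    cleaned = clean(context + chunk)
    # Keep only the part that corresponds to the new chunk.
    return cleaned[len(context):]


seen_lengths = []


def placeholder_clean(samples):
    """Identity cleaner standing in for the real denoiser; records the input length."""
    seen_lengths.append(len(samples))
    return list(samples)


history = []
emitted = []
for chunk in ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]):
    emitted.append(clean_with_history(chunk, history, tail_samples=2, clean=placeholder_clean))
    history.extend(chunk)  # raw (uncleaned) samples form the history
```

The second call sees five samples (two of context plus three new ones) but still emits exactly three.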

Threshold updates run on the main streaming thread. When it's time to refresh, the stream waits for all currently running cleaning jobs to finish, then recomputes the thresholds, and only after that lets new chunks be processed. Jobs that were already running use the old thresholds. All chunks after the update use the new ones.
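
The ordering guarantee can be sketched with `concurrent.futures`: wait for in-flight jobs, bump the thresholds, then resume submitting. The `state` dictionary, the `clean` placeholder, and the fixed refresh point are hypothetical; the real stream decides when to refresh from the streamed duration.

```python
from concurrent.futures import ThreadPoolExecutor, wait

state = {"thresholds_version": 0}


def clean(chunk_index, thresholds_version):
    # Placeholder for cleaning: report which thresholds the job was given.
    return chunk_index, thresholds_version


futures = []
with ThreadPoolExecutor(max_workers=2) as pool:
    for chunk_index in range(6):
        if chunk_index == 3:  # pretend the refresh interval has elapsed
            wait(futures)  # let every running job finish first
            state["thresholds_version"] += 1  # stand-in for recomputing thresholds
        # The thresholds a job uses are fixed at submission time.
        futures.append(pool.submit(clean, chunk_index, state["thresholds_version"]))

versions = dict(future.result() for future in futures)
```

Chunks 0 to 2 report version 0 and chunks 3 to 5 report version 1, matching the rule that jobs already running keep the old thresholds.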

The notebook `testing/RealTimeEEG.ipynb` contains a full example that pairs this pattern with a `queue.Queue` to plot real-time updates while GEDAI runs concurrently.

Call `stream.reset()` to clear thresholds while keeping the leadfield, or `stream.close()` when shutting down the pipeline. The `state` property surfaces the current buffer and thresholds so you can checkpoint progress between sessions.

---

## Deployment & Local Testing

### Build source and wheel distributions

Use the existing `setup.py` helper to create both source (`sdist`) and wheel (`bdist_wheel`) artifacts before publishing or sharing locally:

```bash
python setup.py sdist bdist_wheel
```

The resulting archives land in `dist/` and are ready for installation in any compatible Python environment.

As a final step to publish the built artifacts to PyPI, upload them with Twine:

```bash
twine upload dist/*
```

### Install into a fresh environment (example with conda)

```bash
conda create -n pygedai python=3.12 -y
conda activate pygedai
pip install mne
pip install "torch==2.2.2"
pip install "numpy==1.26.4"
pip install dist/pygedai-1.0.0-py3-none-any.whl --force-reinstall
```

Adjust the Python version and dependency pins as needed for your platform (the above works well on Intel macOS).

### Verify dependency versions

```bash
python - <<'PY'
import torch, numpy as np
print("torch:", torch.__version__, "numpy:", np.__version__)
PY
```

No errors should be emitted; the versions printed should match your expectations.

### Install from PyPI instead

Once a release is published, install the public package (with optional Torch extras) directly:

```bash
pip install mne
pip install "numpy==1.26.4"
pip install "pygedai[torch]"
```

### Smoke-test the pipeline locally

Run this script to ensure PyGEDAI processes bundled sample data end to end:

```bash
python - <<'PY'
import pathlib
import torch
import mne
import numpy as np
from pygedai import gedai

root = pathlib.Path.cwd()
raw_filepath = root / "testing" / "samples" / "with_artifacts" / "artifact_jumps.set"
print(raw_filepath)
raw = mne.io.read_raw_eeglab(str(raw_filepath), preload=True)
raw.set_eeg_reference(ref_channels="average", projection=False, verbose=False)

device = "cuda" if torch.cuda.is_available() else "cpu"

eeg = torch.from_numpy(raw.get_data(picks="eeg")).to(device=device, dtype=torch.float32)

leadfield_filepath = root / "testing" / "leadfield_calibrated" / "leadfield4GEDAI_eeg_61ch.npy"
leadfield = torch.from_numpy(np.load(str(leadfield_filepath))).to(device=device, dtype=torch.float32)

result = gedai(
  eeg,
  sfreq=raw.info["sfreq"],
  denoising_strength="auto",
  leadfield=leadfield,
  device=device,
)

cleaned = result["cleaned"].detach().cpu().numpy()

print("cleaned shape:", cleaned.shape)
print("SENSAI score:", float(result["sensai_score"]))
PY
```

Successful execution prints the cleaned array shape and a SENSAI quality score.

---

## `gedai()`

`gedai(eeg, sfreq, denoising_strength="auto", leadfield=None, *, epoch_size_in_cycles=12.0, lowcut_frequency=0.5, wavelet_levels=9, matlab_levels=None, chanlabels=None, device="cpu", dtype=torch.float32, skip_checks_and_return_cleaned_only=False, batched=False, verbose_timing=False, TolX=1e-1, maxiter=500)`

### Purpose

Execute the GEDAI pipeline on a single EEG recording shaped `(channels, samples)` by applying rank-safe referencing, broadband denoising, multi-resolution wavelet cleaning, and artifact scoring.
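
The rank-safe referencing step divides the channel sum by `n_channels + 1` rather than `n_channels` (see Tips and Troubleshooting). The toy sketch below contrasts the two on a single time point; `average_reference` is a hypothetical helper written for this guide, not a PyGEDAI function.

```python
def average_reference(channels, rank_safe=True):
    """Re-reference one time point; channels is a list of per-channel values."""
    n_channels = len(channels)
    divisor = n_channels + 1 if rank_safe else n_channels
    offset = sum(channels) / divisor
    return [value - offset for value in channels]


sample = [4.0, -2.0, 7.0]
standard = average_reference(sample, rank_safe=False)
referenced = average_reference(sample, rank_safe=True)

# Standard referencing forces the channels to sum to zero (one lost rank);
# the n_channels + 1 divisor leaves a nonzero sum.
standard_sum = sum(standard)
referenced_sum = sum(referenced)
```

With the standard reference the channels always sum to zero, so one channel is a linear combination of the others and the covariance loses a rank; with the `n_channels + 1` divisor the sum stays nonzero.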

### Required Parameters

- `eeg`: PyTorch tensor or array-like with shape `(channels, samples)`. For best performance convert to a torch tensor before calling.
- `sfreq`: Sampling frequency in Hertz. Guides epoch sizing and band selection.
- `leadfield`: `(channels, channels)` reference covariance matrix derived from your EEG forward model (leadfield) that defines the theoretical brain signal subspace. Accepts a filepath, numpy array, or torch tensor. The row and column order must match the EEG channel ordering because tensors carry no channel labels.

### Key Optional Parameters

- `denoising_strength`: Controls artifact suppression aggressiveness.
  - `"auto"` (default): SENSAI-optimized threshold (noise multiplier = 3.0).
  - `"auto-"`: More aggressive filtering (noise multiplier = 6.0).
  - `"auto+"`: More conservative filtering (noise multiplier = 1.0).
  - Numeric value (`0.0–12.0` typical): Manual threshold passed directly to the optimizer.
  Internally this value is forwarded to `artifact_threshold_type` in `gedai_per_band()`.
- `epoch_size_in_cycles`: Number of wave cycles to cover when determining per-band epoch lengths (default `12.0`). Lower values shorten high-frequency epochs; higher values lengthen low-frequency epochs.
- `lowcut_frequency`: Exclude wavelet bands whose upper frequency bound is at or below this threshold (Hz). Defaults to `0.5` to remove slow drifts.
- `wavelet_levels`: Number of Haar MODWT levels when `matlab_levels` is `None`. Typical values fall between `7` and `9`.
- `matlab_levels`: Alternative to `wavelet_levels`, recreating MATLAB level numbering with `2**matlab_levels + 1` bands. Leave `None` unless porting MATLAB scripts directly.
- `chanlabels`: Placeholder for channel label remapping. Currently not implemented and raises an error when supplied.
- `device`: Target torch device such as `"cpu"` or `"cuda"`. EEG data, leadfield, and internal buffers move to this device.
- `dtype`: Torch dtype used during computation, defaulting to `torch.float32` for a balanced memory and compute footprint; set `torch.float64` when maximum numerical accuracy is required and resources permit.
- `skip_checks_and_return_cleaned_only`: When `True`, bypass validation and return only the cleaned tensor to reduce overhead.
- `batched`: Internal flag used by `batch_gedai()`. Leave `False` in user-facing calls.
- `verbose_timing`: Enables profiling markers emitted by `profiling.py`, useful for benchmarking.
- `TolX`: Convergence tolerance for the golden-section search used during automatic thresholding (default `1e-1`).
- `maxiter`: Maximum iterations allowed for the threshold optimizer (default `500`).
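
To see how `wavelet_levels`, `lowcut_frequency`, and `epoch_size_in_cycles` interact, the sketch below uses the usual dyadic approximation of wavelet passbands (level `j` covers roughly `sfreq / 2**(j+1)` to `sfreq / 2**j`). The helper names are hypothetical, and the epoch rule (cycles divided by the band's lower edge) is an assumption for illustration; check `GEDAI.py` for the exact formulas.

```python
def dyadic_bands(sfreq, levels):
    """Nominal (low_hz, high_hz) passband of detail levels 1..levels."""
    return [(sfreq / 2 ** (level + 1), sfreq / 2 ** level) for level in range(1, levels + 1)]


def drop_slow_bands(bands, lowcut_frequency):
    """Keep only bands whose upper edge lies above the low-cut frequency."""
    return [band for band in bands if band[1] > lowcut_frequency]


def epoch_seconds(band, epoch_size_in_cycles):
    """Assumed rule: an epoch spans this many periods of the band's lower edge."""
    return epoch_size_in_cycles / band[0]


bands = dyadic_bands(sfreq=128.0, levels=9)
kept = drop_slow_bands(bands, lowcut_frequency=0.5)
durations = [epoch_seconds(band, 12.0) for band in kept]
```

At 128 Hz with nine levels, the two slowest bands (upper edges 0.5 Hz and 0.25 Hz) are dropped, and under the assumed rule the epochs range from 0.375 s for the 32 to 64 Hz band to 24 s for the 0.5 to 1 Hz band.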

The broadband stage follows MATLAB by using a 1 s epoch (rounded to an even number of samples) before the wavelet decomposition step.
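
The even-length requirement, together with the reflection padding mentioned under Tips and Troubleshooting, can be sketched as follows. Both helpers are hypothetical: the rounding rule (round, then bump odd counts up by one) and padding only at the end of the signal are assumptions, not a statement of what the library does internally.

```python
def even_epoch_samples(sfreq, epoch_sec=1.0):
    """Samples per epoch, bumped up to the next even count if needed (assumed rule)."""
    samples = int(round(sfreq * epoch_sec))
    return samples + (samples % 2)


def reflect_pad(signal, epoch_samples):
    """Mirror the tail so the length becomes a multiple of epoch_samples."""
    pad = -len(signal) % epoch_samples
    if pad >= len(signal):
        raise ValueError("signal too short to reflect-pad")
    # Mirror around the last sample without repeating it.
    mirrored = signal[-2:-pad - 2:-1] if pad else []
    return signal + mirrored, pad


epoch = even_epoch_samples(125.0)  # 125 samples is odd, so one sample is added
padded, pad = reflect_pad([1, 2, 3, 4, 5], epoch_samples=4)
trimmed = padded[: len(padded) - pad]  # remove the padding after processing
```

At 125 Hz this gives 126 samples per epoch, which corresponds to an effective epoch duration of 1.008 s.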

### Returns

By default returns a dictionary with:

- `cleaned`: Denoised EEG `(channels, samples)`.
- `artifacts`: Removed components (`input_referenced - cleaned`).
- `sensai_score`: Overall quality metric (higher means better alignment with the brain subspace).
- `sensai_score_per_band`: Per-band SENSAI scores (length = number of wavelet bands plus the broadband pass).
- `artifact_threshold_per_band`: Thresholds applied to each wavelet band.
- `artifact_threshold_broadband`: Threshold used during the initial broadband pass.
- `epoch_size_used`: Actual epoch duration in seconds after enforcing an even sample count.
- `refCOV`: Reference covariance matrix used for GEVD.
- `epoch_sizes_per_band`: Per-band epoch durations (seconds) derived from `epoch_size_in_cycles`.
- `lowcut_frequency_used`: Effective low-cut frequency after adjusting for data length constraints.

When `skip_checks_and_return_cleaned_only=True`, the function returns only the `cleaned` tensor.

### Typical Workflow

1. Load or calculate a `(channels, channels)` leadfield covariance from `leadfield_calibrated/` or your own pipeline.
2. Convert raw EEG to a torch tensor and send it to the intended device and dtype.
3. Call `gedai(...)` and capture the cleaned signal along with diagnostics.
4. Visualize the results using utilities such as `plot_eeg` in `testing/Test.ipynb`.

---

## `batch_gedai()`

`batch_gedai(eeg_batch, sfreq, denoising_strength="auto", leadfield=None, *, epoch_size_in_cycles=12.0, lowcut_frequency=0.5, wavelet_levels=9, matlab_levels=None, chanlabels=None, device="cpu", dtype=torch.float32, parallel=True, max_workers=None, verbose_timing=False, TolX=1e-1, maxiter=500)`

### Purpose

Run the GEDAI pipeline across a batch dimension. Input tensors must be shaped `(batch, channels, samples)`. Each batch element (one recording) is processed independently, optionally in parallel via a thread pool.

### Required Parameters

- `eeg_batch`: PyTorch tensor containing EEG recordings arranged as `(batch, channels, samples)`.
- `sfreq`: Sampling frequency shared across the batch.
- `leadfield`: `(channels, channels)` reference covariance matrix reused for every batch element. Ensure its row and column order mirrors the channel order in `eeg_batch`.

### Key Optional Parameters

- `denoising_strength`, `epoch_size_in_cycles`, `lowcut_frequency`, `wavelet_levels`, `matlab_levels`, `chanlabels`, `device`, `dtype`, `TolX`, `maxiter`: Match the semantics of the corresponding arguments on `gedai()` and are forwarded per sample.
- `parallel`: When `True`, executes each batch element in a `ThreadPoolExecutor`. Set to `False` for serial execution or debugging.
- `max_workers`: Overrides the number of worker threads when `parallel` is enabled. Defaults to Python's heuristic based on CPU count.
- `verbose_timing`: Aggregates profiling information across the batch to assist throughput measurements.

### Returns

PyTorch tensor shaped `(batch, channels, samples)` containing the cleaned EEG for each input sample. Internally the function gathers the `cleaned` value from each `gedai()` call and stacks the results.
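
The gather-and-stack pattern can be sketched with the standard library alone. `clean_batch` and the mean-removing placeholder are hypothetical and operate on plain lists; the point is only that results come back in input order whether or not the work runs in parallel.

```python
from concurrent.futures import ThreadPoolExecutor


def clean_batch(batch, clean_one, parallel=True, max_workers=None):
    """Apply clean_one to every batch element and return results in input order."""
    if not parallel:
        return [clean_one(item) for item in batch]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Executor.map preserves input order even if jobs finish out of order.
        return list(pool.map(clean_one, batch))


def demean(recording):
    """Placeholder cleaner: remove the recording's mean."""
    mean = sum(recording) / len(recording)
    return [value - mean for value in recording]


batch = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
cleaned_parallel = clean_batch(batch, demean, parallel=True, max_workers=2)
cleaned_serial = clean_batch(batch, demean, parallel=False)
```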

### Usage Example

```python
from pathlib import Path
import numpy as np
import torch
from pygedai import batch_gedai

project_root = Path.cwd()
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load one trial shaped (channels, samples) and the matching reference covariance.
eeg_trial = torch.load(project_root / "testing" / "samples" / "with_artifacts" / "artifact_jumps_tensor.pt").to(device)
leadfield = torch.from_numpy(np.load(project_root / "testing" / "leadfield_calibrated" / "leadfield4GEDAI_eeg_61ch.npy")).to(device)

batch = eeg_trial.unsqueeze(0)  # (1, channels, samples)
cleaned = batch_gedai(batch, sfreq=125.0, leadfield=leadfield, device=device, verbose_timing=True)
```

This mirrors the workflow shown in `testing/Test.ipynb`, where the cleaned batch is plotted against the raw recording.

---

## Tips and Troubleshooting

- Ensure the leadfield reference covariance shape matches the EEG channel count. The functions raise a `ValueError` when dimensions disagree.
- **Channel order is critical:** Double-check that the row and column order of the reference covariance matches your EEG channel order (e.g., channel index 0 in both tensors corresponds to C1); misalignment silently degrades cleaning quality.
- The `leadfield` parameter should be a reference covariance computed from a forward model of your montage. Precomputed examples are available in `leadfield_calibrated/`.
- **Average referencing**: GEDAI applies a non-rank-deficient average reference (dividing by `n_channels + 1`) to prevent ICA ghost components (see [Kim et al., 2023](https://doi.org/10.3389/frsip.2023.1064138)). Do not apply standard average referencing before calling GEDAI.
- The default threshold search uses golden-section (`"parabolic"`). An optional debug mode (`"grid"`) exhaustively evaluates thresholds from 0.0 to 12.0 in 0.1 steps and is roughly 100× slower.
- GEDAI typically processes ~1 s of 64-channel EEG in 0.5–2 s on CPU, depending on `wavelet_levels` and `denoising_strength`. CPU execution is usually faster than GPU because the `sensai_fminbnd` minimization dominates runtime and benefits little from GPU acceleration.
- Use `batch_gedai()` for multiple independent trials or subjects; with `parallel=True` and adequate CPU cores, throughput scales nearly linearly with batch size.
- If thread-pool contention or hangs arise when running `batch_gedai()` in parallel mode, set single-threaded math libraries before importing torch:
  ```python
  import os
  os.environ["OMP_NUM_THREADS"] = "1"
  os.environ["MKL_NUM_THREADS"] = "1"
  os.environ["OPENBLAS_NUM_THREADS"] = "1"
  os.environ["NUMEXPR_NUM_THREADS"] = "1"
  os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
  os.environ["BLIS_NUM_THREADS"] = "1"

  import torch
  try:
      torch.set_num_threads(1) # intra-op
      torch.set_num_interop_threads(1) # inter-op
  except RuntimeError:
      pass # torch was already initialised
  ```
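  The environment variables must be set before `torch` is imported, and `torch.set_num_interop_threads` must be called before any inter-op parallel work has started (otherwise it raises a `RuntimeError`, which the snippet ignores); future Python 3.14+ releases are expected to reduce the need for this workaround.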
- The pipeline enforces even epoch lengths. Incomplete epochs are padded via reflection rather than cropped, and the padding is trimmed after denoising.
- Adjust `epoch_size_in_cycles` or `lowcut_frequency` when targeting specific bandwidths: higher cycle counts improve low-frequency stability while higher low-cut values skip slow drifts and reduce required epoch durations.
- When running on GPU, move both EEG data and leadfield tensors to the target device prior to calling the API.
- Enable `verbose_timing=True` during development to gather profiling markers such as `start_batch`, `modwt_analysis`, and `batch_done`.
- If you only require cleaned signals, set `skip_checks_and_return_cleaned_only=True` to avoid collecting diagnostic metadata.
- Automatic threshold selection relies on `sensai_fminbnd` (golden-section minimization) and may run up to `maxiter` iterations; supplying fixed thresholds dramatically reduces runtime.
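
For intuition about how `TolX` and `maxiter` bound the automatic search, here is a textbook golden-section minimizer in plain Python. It is a generic sketch, not the library's `sensai_fminbnd`; the function name, the toy objective, and the `tol_x` spelling are assumptions made for illustration.

```python
import math


def golden_section_minimize(func, lower, upper, tol_x=1e-1, maxiter=500):
    """Minimize a unimodal func on [lower, upper]; stop at tol_x or maxiter."""
    inv_phi = (math.sqrt(5.0) - 1.0) / 2.0  # about 0.618
    a, b = lower, upper
    c = b - inv_phi * (b - a)
    d = a + inv_phi * (b - a)
    fc, fd = func(c), func(d)
    iterations = 0
    while (b - a) > tol_x and iterations < maxiter:
        if fc < fd:
            # Minimum lies in [a, d]: shrink from the right and reuse c as the new d.
            b, d, fd = d, c, fc
            c = b - inv_phi * (b - a)
            fc = func(c)
        else:
            # Minimum lies in [c, b]: shrink from the left and reuse d as the new c.
            a, c, fc = c, d, fd
            d = a + inv_phi * (b - a)
            fd = func(d)
        iterations += 1
    return (a + b) / 2.0, iterations


# Toy objective with its minimum at 4.2, searched over the 0.0 to 12.0 range.
best, n_iter = golden_section_minimize(lambda t: (t - 4.2) ** 2, 0.0, 12.0)
```

Each iteration costs one new objective evaluation and shrinks the bracket by a factor of about 0.618, so a loose tolerance such as `1e-1` stops well before the `maxiter` cap on a well-behaved objective.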

---

## License

This port follows the PolyForm Noncommercial License 1.0.0, identical to the original GEDAI plugin. The core algorithms are patent pending; commercial use requires obtaining the appropriate license from the patent holders. See [LICENSE](https://github.com/JoelKessler/PyGEDAI/blob/main/LICENSE) for full terms and contact information.

---

## Further Resources

- `GEDAI.py`: Core implementation with inline comments describing the Haar MODWT pipeline, SENSAI scoring, and artifact reconstruction.
- `auxiliaries/`: Helper modules including `GEDAI_per_band.py`, `SENSAI_basic.py`, and `clean_EEG.py`, which provide per-band denoisers and optimization routines.
- `testing/Test.ipynb`: Practical notebook demonstrating data loading, covariance handling, calls to `batch_gedai()`, and visualization of results.

For issues or feature requests, please open a GitHub issue in this repository.
