FGCZ banner

Project: RCTD Vectorization Accuracy Check
Datasets: spacexr Slide-Seq vignette (19 types, 100 pixels) + Xenium Region 3 (45 types, 66k pixels)
Objective: Compare numerical parity, runtimes, and spatial mapping of RCTD implementations against the original R spacexr.
Rendered: 2026-03-01 18:06:53 CET

Equivalence Metrics Against Gold Standard (R CPU)

This section evaluates how closely each backend implementation agrees with the original spacexr CPU implementation, which serves as the gold standard.

Table 1: Execution Time Profiling

Total doublet-mode deconvolution execution time for each implementation and device, measured over 2,000 spatial cells.

| Implementation Framework | Device Evaluated | Execution Time (s) | Relative Speedup |
|---|---|---|---|
| spacexr (Original R) | CPU Standard | 14.61 | 1.00x |
| spacexr (Reticulate) | GPU (PyTorch) | 2.58 | 5.66x |
| pyrctd (Native Python) | CPU (16 Cores) | 3.05 | 4.79x |
| pyrctd (Native Python) | GPU (PyTorch) | 4.71 | 3.10x |
| rctd-py (Native JAX) | CPU (16 Cores) | 8.77 | 1.67x |
| rctd-py (Native JAX) | GPU (XLA) | 14.52 | 1.01x |

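Note: rctd-py execution times include one-off XLA compilation overhead, which is not amortised on a dataset this small. (The previously quoted ~60 s compilation penalty exceeds the measured rctd-py totals above, so that figure needs re-checking.)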

Table 2: Convergence Parity (vs Spacexr CPU)

Agreement of each implementation's doublet-mode results (weights, spot_class, and first_type) with the original R spacexr CPU output.

| Implementation | Weight Correlation | spot_class Concordance | first_type Concordance |
|---|---|---|---|
| Target Metric | > 0.9900 | > 95.0% | > 97.0% |
| spacexr GPU | 1.0000 | 100.0% | 100.0% |
| pyrctd CPU | -0.0297 | 60.6% | 6.2% |
| pyrctd GPU | -0.0021 | 60.6% | 7.8% |
| rctd-py JAX CPU | 0.8933 | 85.9% | 96.9% |
| rctd-py JAX GPU | 0.8933 | 85.9% | 96.9% |

Visual Profiling

library(tidyr)

# Execution-time bar chart ----
# Timings are read from the Python session via reticulate: R-side runs live
# in py$times_r, Python-side runs in py$times_py, indexed by backend key.
methods <- c("spacexr (CPU)", "spacexr (GPU)", "pyrctd (CPU)", "pyrctd (GPU)", "rctd-py (JAX CPU)", "rctd-py (JAX GPU)")
times <- c(py$times_r['spacexr_cpu'], py$times_r['spacexr_gpu'], py$times_py['pytorch_cpu'], py$times_py['pytorch_gpu'], py$times_py['jax_cpu'], py$times_py['jax_gpu'])
# Factor levels fix the bar order to the order of `methods`.
time_df <- data.frame(Method = factor(methods, levels = methods), Time = times)

p1 <- ggplot(time_df, aes(x = Method, y = Time, fill = Method)) +
  geom_bar(stat = "identity") +
  theme_classic() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
  scale_fill_brewer(palette = "Set2") +
  labs(title = "Execution Times (s)", y = "Seconds") +
  guides(fill = "none")

# Parity bar chart ----
# Each py$m_* object holds (spot concordance %, first_type concordance %,
# weight correlation) at positions 1, 2, 3 as indexed below; the two
# concordances are rescaled from percent to fractions.
models <- c("spacexr GPU", "pyrctd CPU", "pyrctd GPU", "JAX CPU", "JAX GPU")
w_corr <- c(py$m_sx_gpu[[3]], py$m_pt_cpu[[3]], py$m_pt_gpu[[3]], py$m_jax_cpu[[3]], py$m_jax[[3]])
spot_conc <- c(py$m_sx_gpu[[1]]/100, py$m_pt_cpu[[1]]/100, py$m_pt_gpu[[1]]/100, py$m_jax_cpu[[1]]/100, py$m_jax[[1]]/100)
type_conc <- c(py$m_sx_gpu[[2]]/100, py$m_pt_cpu[[2]]/100, py$m_pt_gpu[[2]]/100, py$m_jax_cpu[[2]]/100, py$m_jax[[2]]/100)

parity_df <- data.frame(Method = models, Weight = w_corr, Spot = spot_conc, Type = type_conc) %>%
  pivot_longer(cols = c("Weight", "Spot", "Type"), names_to = "Metric", values_to = "Score")
parity_df$Method <- factor(parity_df$Method, levels = models)

# The axis must cover every score. With a fixed c(0, 1) range, ggplot2 sets
# out-of-range values to NA and silently drops their bars. That affected the
# negative pyrctd weight correlations, and would also drop a correlation of
# 1 + floating-point epsilon.
score_range <- range(c(0, 1, parity_df$Score), na.rm = TRUE)

p2 <- ggplot(parity_df, aes(x = Method, y = Score, fill = Metric)) +
  geom_bar(stat = "identity", position = "dodge") +
  theme_classic() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
  scale_y_continuous(labels = scales::percent, limits = score_range) +
  scale_fill_manual(values = c("Weight" = "#1b9e77", "Spot" = "#d95f02", "Type" = "#7570b3")) +
  labs(title = "Mathematical Parity vs Spacexr CPU", y = "Concordance / Correlation")

library(gridExtra)
grid.arrange(p1, p2, ncol = 2)

Correlation Heatmap

library(reshape2)

# Pairwise weight-correlation matrix from Python, melted into long format
# (one row per Method1 x Method2 pair) for geom_tile.
corr_melt <- melt(as.matrix(py$corr_df))
colnames(corr_melt) <- c("Method1", "Method2", "Correlation")

ggplot(corr_melt, aes(x = Method1, y = Method2, fill = Correlation)) +
  geom_tile(color = "white") +
  # squish instead of the default censoring: a correlation marginally outside
  # [-1, 1] from floating-point error (e.g. 1 + eps on the diagonal) would
  # otherwise become NA and be drawn as a grey tile.
  scale_fill_gradient2(low = "#d73027", mid = "white", high = "#1b9e77",
                       midpoint = 0, limits = c(-1, 1), oob = scales::squish,
                       name = "Pearson\nCorrelation") +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)) +
  geom_text(aes(label = round(Correlation, 3)), color = "black", size = 3.5) +
  coord_fixed() +
  labs(title = "Pairwise Weight Correlation Heatmap", x = "", y = "")

Spatial Map Check

# Side-by-side spatial maps of spot_class calls: the R CPU baseline, PyTorch
# GPU and JAX GPU results are tagged with a display label, stacked into one
# long table and drawn with one facet per implementation.
method_levels <- c("Spacexr (R CPU Baseline)", "Pyrctd (PyTorch GPU)", "rctd-py (JAX GPU)")

df_base <- py$df_sx_cpu
df_base[["Method"]] <- method_levels[[1]]

df_pt <- py$df_pt_gpu
df_pt[["Method"]] <- method_levels[[2]]

df_jax <- py$df_jax
df_jax[["Method"]] <- method_levels[[3]]

map_cols <- c("x_centroid", "y_centroid", "spot_class", "Method")
df_combined <- do.call(
  rbind,
  lapply(list(df_base, df_pt, df_jax), function(d) d[, map_cols])
)
# Pixels without an x coordinate cannot be placed on the map.
df_combined <- df_combined[!is.na(df_combined[["x_centroid"]]), ]
# Factor levels fix the facet order: baseline first.
df_combined[["Method"]] <- factor(df_combined[["Method"]], levels = method_levels)

ggplot(df_combined, aes(x = x_centroid, y = y_centroid, color = spot_class)) +
  geom_point(size = 1.6, alpha = 0.8) +
  theme_void() +
  theme(aspect.ratio = 1, legend.position = "right") +
  scale_color_brewer(palette = "Set1") +
  facet_wrap(~Method, ncol = 3) +
  ggtitle("Pixel Classification Map Correlated Cross-Implementations")

Confident Cell Types

# Rank first types among non-reject pixels of the PyTorch GPU run and map the
# 12 most frequent ones, one facet per type.
kept_classes <- c("singlet", "doublet_certain", "doublet_uncertain")
df_pt_kept <- df_pt %>% filter(spot_class %in% kept_classes)

top_types <- df_pt_kept %>%
  count(first_type, sort = TRUE) %>%
  head(12) %>%
  pull(first_type)

df_subset <- df_pt_kept %>% filter(first_type %in% top_types)

if (nrow(df_subset) == 0) {
  cat("No confident cell types found for subsetting.")
} else {
  p <- ggplot(df_subset,
              aes(x = x_centroid, y = y_centroid, color = as.factor(first_type))) +
    geom_point(size = 0.8, alpha = 0.7) +
    theme_void() +
    theme(aspect.ratio = 1) +
    facet_wrap(~as.factor(first_type), ncol = 4) +
    scale_color_discrete(guide = "none") +
    ggtitle("Spatial Cell Type Distribution")
  print(p)
}

Xenium Region 3 Validation

Real-world validation on Xenium duodenum Region 3 (45 cell types, ~66k spatial pixels) comparing rctd-py (JAX GPU) against R spacexr CPU.

# Xenium Region 3 summary metrics, read from the Python session via reticulate.
# Each cat() below writes a line of a markdown table; the rendered output of
# each call appears on the line(s) directly after it.
# NOTE(review): the separator rows in the rendered output show typographic
# dashes; the code itself emits ASCII "-".
xm <- py$xm

cat("### Dataset Overview\n\n")

Dataset Overview

cat(sprintf("| Property | Value |\n"))
Property | Value |
cat(sprintf("|----------|-------|\n"))

|———-|——-|

# Counts are formatted with thousands separators; elapsed time is in seconds.
cat(sprintf("| Spatial pixels (total) | %s |\n", format(xm$n_spatial, big.mark=",")))
Spatial pixels (total) | 66,611 |
cat(sprintf("| Pixels after UMI filter | %s |\n", format(xm$n_filtered, big.mark=",")))
Pixels after UMI filter | 58,191 |
cat(sprintf("| R spacexr pixels | %s |\n", format(xm$r_n_cells, big.mark=",")))
R spacexr pixels | 62,313 |
cat(sprintf("| Common pixels compared | %s |\n", format(xm$n_common, big.mark=",")))
Common pixels compared | 58,188 |
cat(sprintf("| Cell types | %d |\n", xm$n_types))
Cell types | 45 |
cat(sprintf("| rctd-py elapsed (GPU) | %.1f s |\n\n", xm$elapsed_s))
rctd-py elapsed (GPU) | 3477.7 s |
cat("### Concordance Metrics (rctd-py JAX GPU vs spacexr CPU)\n\n")

Concordance Metrics (rctd-py JAX GPU vs spacexr CPU)

cat("| Metric | Value |\n")
Metric | Value |
cat("|--------|-------|\n")

|——–|——-|

# Agreement values are fractions in xm and are printed as percentages;
# correlations are printed as-is to 4 decimals.
cat(sprintf("| Dominant type agreement | **%.1f%%** |\n", xm$dominant_type_agreement * 100))
Dominant type agreement | 99.7% |
cat(sprintf("| Per-pixel weight correlation (median) | **%.4f** |\n", xm$pixel_corr_median))
Per-pixel weight correlation (median) | 1.0000 |
cat(sprintf("| Per-pixel weight correlation (mean) | %.4f |\n", xm$pixel_corr_mean))
Per-pixel weight correlation (mean) | 0.9998 |
cat(sprintf("| Pixels with correlation > 0.8 | %.1f%% |\n", xm[["pixel_corr_gt_0.8"]] * 100))
Pixels with correlation > 0.8 | 100.0% |
cat(sprintf("| Pixels with correlation > 0.5 | %.1f%% |\n\n", xm[["pixel_corr_gt_0.5"]] * 100))
Pixels with correlation > 0.5 | 100.0% |
cat("### Spot Class Distribution (rctd-py)\n\n")

Spot Class Distribution (rctd-py)

cat("| Class | Count | Fraction |\n")
Class | Count | Fraction |
cat("|-------|-------|----------|\n")

|——-|——-|———-|

# One row per spot class: pixel count and its share of the sum over all
# classes listed in spot_class_distribution.
spot <- xm$spot_class_distribution
total <- Reduce("+", spot)
for (nm in names(spot)) {
  cat(sprintf("| %s | %s | %.1f%% |\n", nm, format(spot[[nm]], big.mark=","), spot[[nm]] / total * 100))
}
reject | 8,476 | 14.6% |
singlet | 33,106 | 56.9% |
doublet_certain | 8,343 | 14.3% |
doublet_uncertain | 8,266 | 14.2% |
cat("\n")

Elapsed Time Comparison

xm <- py$xm

# Build timing data ----
# Optional timings may be absent from `xm`. A missing value becomes NA, so
# the has_* flags below can switch the corresponding bar or plot off.
r_total    <- xm$r_elapsed_s
py_total   <- xm$elapsed_s
b200_total <- if (!is.null(xm$b200_elapsed_s)) xm$b200_elapsed_s else NA
py_sigma   <- if (!is.null(xm$sigma_elapsed_s))  xm$sigma_elapsed_s  else NA
py_deconv  <- if (!is.null(xm$deconv_elapsed_s)) xm$deconv_elapsed_s else NA

has_r_timing  <- !is.null(r_total)    && !is.na(r_total)
has_b200      <- !is.null(b200_total) && !is.na(b200_total)
has_breakdown <- !is.na(py_sigma) && !is.na(py_deconv)

# Total time barplot ----
impls  <- "rctd-py (JAX GPU, L40S)"
totals <- py_total
if (has_r_timing) {
  impls  <- c("spacexr (R CPU, 8 cores)", impls)
  totals <- c(r_total, totals)
}
if (has_b200) {
  impls  <- c(impls, "rctd-py (JAX GPU, Blackwell B200)\u2020")
  totals <- c(totals, b200_total)
}

timing_df <- data.frame(
  # Reversed levels, so that after coord_flip() the first entry sits on top.
  Implementation = factor(impls, levels = rev(impls)),
  Minutes = totals / 60,
  stringsAsFactors = FALSE
)

colors <- c(
  "spacexr (R CPU, 8 cores)"                         = "#d95f02",
  "rctd-py (JAX GPU, L40S)"                          = "#7570b3",
  "rctd-py (JAX GPU, Blackwell B200)\u2020"           = "#1b9e77"
)

# The dagger footnote explains the B200 bar only, so attach it only when
# that bar is drawn. A NULL caption leaves the plot without a caption.
total_caption <- if (has_b200) {
  "\u2020 Blackwell B200: sigma 173s (measured) + doublet 36s (measured separately); see text."
} else {
  NULL
}

p_total <- ggplot(timing_df, aes(x = Implementation, y = Minutes, fill = Implementation)) +
  geom_bar(stat = "identity", width = 0.6) +
  geom_text(aes(label = sprintf("%.1f min", Minutes)), hjust = -0.1, size = 4) +
  coord_flip() +
  scale_fill_manual(values = colors) +
  theme_classic() +
  theme(legend.position = "none") +
  labs(
    title = "End-to-End Elapsed Time: Xenium Region 3 (58k pixels, doublet mode)",
    caption = total_caption,
    x = "", y = "Minutes"
  ) +
  # 20% headroom so the value labels to the right of the bars are not clipped.
  ylim(0, max(timing_df$Minutes) * 1.2)
print(p_total)

# Phase breakdown (sigma estimation vs deconvolution) ----
# Drawn only when both phase timings are available. Percentages are relative
# to the rctd-py end-to-end time (py_total).
if (has_breakdown) {
  breakdown_df <- data.frame(
    Phase = c("Sigma estimation", "Deconvolution (GPU)"),
    Seconds = c(py_sigma, py_deconv),
    stringsAsFactors = FALSE
  )
  breakdown_df$Phase <- factor(breakdown_df$Phase,
                                levels = rev(breakdown_df$Phase))
  breakdown_df$Minutes <- breakdown_df$Seconds / 60

  p_breakdown <- ggplot(breakdown_df, aes(x = Phase, y = Minutes, fill = Phase)) +
    geom_bar(stat = "identity", width = 0.6) +
    geom_text(aes(label = sprintf("%.1f min (%.0f%%)", Minutes,
                                   Seconds / py_total * 100)),
              hjust = -0.1, size = 4) +
    coord_flip() +
    scale_fill_manual(values = c("Sigma estimation" = "#e7298a",
                                  "Deconvolution (GPU)" = "#1b9e77")) +
    theme_classic() +
    theme(legend.position = "none") +
    labs(title = "rctd-py Time Breakdown: Sigma Estimation vs GPU Deconvolution",
         x = "", y = "Minutes") +
    ylim(0, max(breakdown_df$Minutes) * 1.3)
  print(p_breakdown)
}

Key finding: rctd-py completed in 58.0 min for 58k pixels in doublet mode.

Sigma optimisation (applied in this run): The sigma estimation step — previously the dominant bottleneck at ~66 min — was reduced to ~3 min via three targeted changes: (1) caching the 437×437 tridiagonal matrix inverse that depends only on the fixed λ-grid, eliminating ~144 redundant O(n³) inversions; (2) precomputing all 126 spline coefficient matrices once at startup; and (3) replacing 85 sequential calc_log_likelihood calls per epoch with a single jax.vmap/jax.jit fused kernel. The sigma value (σ = 0.72) and all accuracy metrics are numerically unchanged.

Note on GPU hardware: The doublet-mode deconvolution runtime varies strongly with GPU memory bandwidth. On HPC-class GPUs (Blackwell B200, ~8 TB/s HBM3e), the IRWLS loop completes in ~36 s for 58k pixels. On the L40S (~864 GB/s GDDR6, a rendering/inference card), the same step takes ~55 min, because the fancy-indexed Q-matrix lookups are memory-bandwidth-bound. Total end-to-end time on a B200 is ~3.5 min (sigma ~173 s + deconvolution ~36 s), versus ~51 min for R spacexr.

Per-Type Correlation (Top 20)

# Per-cell-type weight correlation (rctd-py vs spacexr), read from Python and
# reduced to the 20 best-correlated types.
tc <- py$x_type_corr
tc_df <- data.frame(CellType = names(tc), Correlation = as.numeric(tc))
# A type with an undefined (NA) correlation cannot be ranked, so drop it
# rather than let it occupy a top-20 slot.
tc_df <- tc_df[!is.na(tc_df$Correlation), ]
tc_df <- tc_df[order(-tc_df$Correlation), ]
tc_df <- head(tc_df, 20)
# Reversed levels, so the highest correlation is drawn on top after coord_flip().
tc_df$CellType <- factor(tc_df$CellType, levels = rev(tc_df$CellType))

ggplot(tc_df, aes(x = CellType, y = Correlation, fill = Correlation)) +
  geom_bar(stat = "identity") +
  coord_flip() +
  # squish keeps the colour for values outside [0, 1] (negative correlations,
  # or 1 + floating-point epsilon) instead of censoring them to grey NA fill.
  scale_fill_gradient2(low = "#d73027", mid = "#ffffbf", high = "#1b9e77",
                       midpoint = 0.5, limits = c(0, 1), oob = scales::squish) +
  theme_classic() +
  labs(title = "Per-Type Weight Correlation (rctd-py vs spacexr)", x = "", y = "Pearson Correlation") +
  guides(fill = "none")

Dominant Type Distribution Comparison

# Dominant-type pixel counts per cell type for both implementations, limited
# to the 20 types with the largest count in either implementation.
td <- py$x_type_dist
# na.rm so a type with a count on only one side is still ranked by that
# count instead of being pushed to the bottom by an NA rank.
# NOTE(review): whether the exporter emits NA or 0 for a type missing on one
# side is not visible here.
td <- td[order(-pmax(td$python_count, td$r_count, na.rm = TRUE)), ]
td <- head(td, 20)

td_long <- tidyr::pivot_longer(td, cols = c("python_count", "r_count"),
                               names_to = "Method", values_to = "Count")
td_long$Method <- ifelse(td_long$Method == "python_count", "rctd-py (JAX GPU)", "spacexr (R CPU)")
# Reversed levels, so the largest type is drawn on top after coord_flip().
td_long$cell_type <- factor(td_long$cell_type, levels = rev(td$cell_type))

ggplot(td_long, aes(x = cell_type, y = Count, fill = Method)) +
  geom_bar(stat = "identity", position = "dodge") +
  coord_flip() +
  scale_fill_manual(values = c("rctd-py (JAX GPU)" = "#1b9e77", "spacexr (R CPU)" = "#d95f02")) +
  theme_classic() +
  labs(title = "Dominant Cell Type Counts: rctd-py vs spacexr (Top 20)", x = "", y = "Pixel Count")

Spatial Cell Type Map (rctd-py JAX GPU)

# Spatial map of rctd-py spot_class calls for the Xenium pixels exported from
# Python, with a fixed colour per class.
df_sp <- py$xenium_spatial

spot_colors <- c(
  singlet           = "#1b9e77",
  doublet_certain   = "#d95f02",
  doublet_uncertain = "#7570b3",
  reject            = "#cccccc"
)

# The caption compares the number of plotted pixels with the filtered total.
shown_txt <- format(nrow(df_sp), big.mark = ",")
total_txt <- format(py$xm$n_filtered, big.mark = ",")
map_caption <- sprintf("Showing %s of %s filtered pixels (random subsample)",
                       shown_txt, total_txt)

ggplot(df_sp, aes(x = x_centroid, y = y_centroid, color = spot_class)) +
  geom_point(size = 0.25, alpha = 0.6) +
  scale_color_manual(values = spot_colors, name = "Spot class") +
  theme_void() +
  theme(aspect.ratio = 1, legend.position = "right",
        plot.title = element_text(hjust = 0.5, size = 13)) +
  # Larger legend keys than the tiny plotted points.
  guides(color = guide_legend(override.aes = list(size = 3))) +
  labs(title = "Pixel Classification Map — rctd-py (JAX GPU, Xenium Region 3)",
       caption = map_caption)

df_sp <- py$xenium_spatial

# Spot classes that carry a usable first_type call (reject pixels excluded).
assigned_classes <- c("singlet", "doublet_certain", "doublet_uncertain")

# Top 12 most frequent first types among non-reject pixels
top_types <- df_sp |>
  dplyr::filter(spot_class %in% assigned_classes) |>
  dplyr::count(first_type, sort = TRUE) |>
  head(12) |>
  dplyr::pull(first_type)

# Apply the same spot-class filter used for the ranking, so reject pixels whose
# first_type happens to be among the top types are not plotted.
df_top <- df_sp |>
  dplyr::filter(spot_class %in% assigned_classes, first_type %in% top_types)

type_colors <- unname(pals::polychrome(length(top_types)))
names(type_colors) <- top_types

ggplot(df_top, aes(x = x_centroid, y = y_centroid, color = first_type)) +
  geom_point(size = 0.4, alpha = 0.7) +
  scale_color_manual(values = type_colors, name = "Cell type") +
  theme_void() +
  theme(aspect.ratio = 1, legend.position = "right",
        plot.title = element_text(hjust = 0.5, size = 13)) +
  guides(color = guide_legend(override.aes = list(size = 3))) +
  labs(title = sprintf("Dominant Cell Type Map — Top %d Types (rctd-py JAX GPU)",
                       length(top_types)),
       caption = "Only non-reject pixels (singlet / doublet) are shown, coloured by their first (dominant) cell type.")

Doublet Mode Concordance

# Doublet-mode concordance metrics (rctd-py vs spacexr), read from the Python
# session via reticulate. Each cat() writes a line of a markdown table; the
# rendered output appears on the line(s) directly after each call.
xdm <- py$xdm

cat("| Metric | Value |\n")
Metric | Value |
cat("|--------|-------|\n")

|——–|——-|

# Agreement values are fractions in xdm and are printed as percentages.
cat(sprintf("| Common pixels compared | %s |\n", format(xdm$n_common_doublet, big.mark=",")))
Common pixels compared | 58,187 |
cat(sprintf("| **Spot class agreement** | **%.1f%%** |\n", xdm$spot_class_agreement * 100))
Spot class agreement | 100.0% |
cat(sprintf("| **First type agreement** (non-reject) | **%.1f%%** |\n", xdm$first_type_agreement * 100))
First type agreement (non-reject) | 98.9% |
# Optional metric: the row is written only when it was exported.
if (!is.null(xdm$second_type_agreement)) {
  cat(sprintf("| Second type agreement (both doublet) | %.1f%% |\n", xdm$second_type_agreement * 100))
}
Second type agreement (both doublet) | 100.0% |
cat("\n")
cat("### Spot Class Distribution Comparison\n\n")

Spot Class Distribution Comparison

cat("| Class | rctd-py | spacexr | Difference |\n")
Class | rctd-py | spacexr | Difference |
cat("|-------|---------|---------|------------|\n")

|——-|———|———|————|

# Per-class pixel counts for both implementations, in a fixed class order.
# A class missing from a distribution counts as 0; Difference is rctd-py
# minus spacexr, printed with an explicit sign.
py_dist <- xdm$py_spot_distribution
r_dist <- xdm$r_spot_distribution
for (cls in c("reject", "singlet", "doublet_certain", "doublet_uncertain")) {
  py_n <- if (!is.null(py_dist[[cls]])) py_dist[[cls]] else 0
  r_n <- if (!is.null(r_dist[[cls]])) r_dist[[cls]] else 0
  diff_n <- py_n - r_n
  cat(sprintf("| %s | %s | %s | %+d |\n", cls, format(py_n, big.mark=","), format(r_n, big.mark=","), diff_n))
}
reject | 8,474 | 8,475 | -1 |
singlet | 33,104 | 33,104 | +0 |
doublet_certain | 8,343 | 8,345 | -2 |
doublet_uncertain | 8,266 | 8,263 | +3 |
cat("\n")
# Fallback messages for when the comparison data are missing.
# NOTE(review): these show no rendered output here, so they presumably come
# from alternative chunks/branches not evaluated in this render; confirm.
cat("**Doublet comparison data not yet generated.** Run `scripts/extract_r_doublet_full.R` then re-run `scripts/generate_xenium_report_data.py`.\n")
cat("**Xenium data not yet generated.** Run `scripts/generate_xenium_report_data.py` on a GPU node first.\n")

Session Info

# Record the R version, platform and package versions used for this render.
sessionInfo()
## R version 4.5.0 (2025-04-11)
## Platform: x86_64-pc-linux-gnu
## Running under: Debian GNU/Linux 12 (bookworm)
## 
## Matrix products: default
## BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.11.0 
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.11.0  LAPACK version 3.11.0
## 
## locale:
##  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
##  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
##  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
## 
## time zone: Europe/Zurich
## tzcode source: system (glibc)
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] reshape2_1.4.5    gridExtra_2.3     tidyr_1.3.2       reticulate_1.44.1
## [5] dplyr_1.2.0       ggplot2_4.0.2    
## 
## loaded via a namespace (and not attached):
##  [1] Matrix_1.7-3       gtable_0.3.6       jsonlite_2.0.0     compiler_4.5.0    
##  [5] maps_3.4.3         tidyselect_1.2.1   Rcpp_1.1.1         stringr_1.6.0     
##  [9] pals_1.10          dichromat_2.0-0.1  jquerylib_0.1.4    scales_1.4.0      
## [13] png_0.1-8          yaml_2.3.12        fastmap_1.2.0      lattice_0.22-7    
## [17] plyr_1.8.9         R6_2.6.1           labeling_0.4.3     generics_0.1.4    
## [21] mapproj_1.2.12     knitr_1.51         tibble_3.3.1       bslib_0.10.0      
## [25] pillar_1.11.1      RColorBrewer_1.1-3 rlang_1.1.7        stringi_1.8.7     
## [29] cachem_1.1.0       xfun_0.56          sass_0.4.10        S7_0.2.1          
## [33] otel_0.2.0         cli_3.6.5          withr_3.0.2        magrittr_2.0.4    
## [37] digest_0.6.39      grid_4.5.0         lifecycle_1.0.5    vctrs_0.7.1       
## [41] evaluate_1.0.5     glue_1.8.0         farver_2.1.2       colorspace_2.1-2  
## [45] purrr_1.2.1        rmarkdown_2.30     tools_4.5.0        pkgconfig_2.0.3   
## [49] htmltools_0.5.9
# Record the Python interpreter and numpy that reticulate used for the py$ objects.
reticulate::py_config()
## python:         /home/pgueguen/git/rctd-py/.venv/bin/python
## libpython:      /misc/ngseq12/miniforge3/lib/libpython3.10.so
## pythonhome:     /home/pgueguen/git/rctd-py/.venv:/home/pgueguen/git/rctd-py/.venv
## version:        3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]
## numpy:          /home/pgueguen/git/rctd-py/.venv/lib/python3.10/site-packages/numpy
## numpy_version:  2.2.6
## 
## NOTE: Python version was forced by RETICULATE_PYTHON