import os
import json
import pandas as pd
import networkx as nx
import numpy as np
from typing import Optional,Union
from datetime import datetime
from pathlib import Path
import networkx as nx
import ray
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.ensemble import RandomForestRegressor
import tempfile
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from .gnn_models import GCN, GAT, SAGE, GIN, process_dropout
from ..utils.logger import get_logger
from scipy.stats import skew
from torch_geometric.utils import add_self_loops
[docs]
class GNNEmbedding:
"""
GNNEmbedding Class for Generating Graph Neural Network (GNN) Based Embeddings.
Attributes:
adjacency_matrix : pd.DataFrame
omics_data : pd.DataFrame
phenotype_data : pd.DataFrame
clinical_data : Optional[pd.DataFrame]
phenotype_col : str
model_type : str
hidden_dim : int
layer_num : int
dropout: Union[bool, float] (if bool, True maps to 0.5, False to 0.0)
num_epochs : int
lr : float
weight_decay : float
gpu : bool
seed : Optional[int]
tune : Optional[bool]
"""
def __init__(
self,
adjacency_matrix: pd.DataFrame,
omics_data: pd.DataFrame,
phenotype_data: Union[pd.Series, pd.DataFrame],
clinical_data: Optional[pd.DataFrame] = None,
phenotype_col: str = "phenotype",
model_type: str = "GAT",
hidden_dim: int = 64,
layer_num: int = 4,
dropout: Union[bool, float] = True,
num_epochs: int = 100,
lr: float = 1e-3,
weight_decay: float = 1e-4,
gpu: bool = False,
activation: str = "relu",
seed: Optional[int] = None,
tune: Optional[bool] = False,
output_dir: Optional[str] = None,
):
"""
Initializes the GNNEmbedding instance.
"""
self.logger = get_logger(__name__)
# Input validation
if adjacency_matrix.empty:
raise ValueError("Adjacency matrix cannot be empty.")
if omics_data.empty:
raise ValueError("Omics data cannot be empty.")
if adjacency_matrix.shape[0] == omics_data.shape[0]:
raise ValueError("Adjacency matrix, omics data must have the same number of samples.")
if clinical_data is not None and clinical_data.empty:
raise ValueError("Clinical data was provided but is empty.")
if isinstance(phenotype_data, pd.Series):
self.phenotype_data = phenotype_data.copy(deep=True)
elif isinstance(phenotype_data, pd.DataFrame):
if phenotype_col and phenotype_col in phenotype_data.columns:
self.phenotype_data = phenotype_data[phenotype_col].copy(deep=True)
elif phenotype_data.shape[1] == 1:
self.phenotype_data = phenotype_data.iloc[:, 0].copy(deep=True)
else:
raise ValueError(
f"Cannot determine phenotype column. "
f"Either provide a single-column DataFrame or set 'phenotype_col' to a valid column name."
)
else:
raise ValueError("Phenotype data must be a Series or a DataFrame.")
if seed is not None:
torch.manual_seed(seed)
np.random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
self.seed = seed
self.adjacency_matrix = adjacency_matrix.copy(deep=True)
self.omics_data = omics_data.copy(deep=True)
self.clinical_data = clinical_data.copy(deep=True) if clinical_data is not None else None
self.phenotype_col = phenotype_col
self.model_type = model_type
self.hidden_dim = hidden_dim
self.layer_num = layer_num
self.dropout = process_dropout(dropout)
self.num_epochs = num_epochs
self.lr = lr
self.weight_decay = weight_decay
self.activation = activation
self.gpu = gpu
self.device = torch.device("cuda" if self.gpu and torch.cuda.is_available() else "cpu")
self.logger.info(f"Initialized GNNEmbedding. device={self.device}")
self.model = None
self.data = None
self.embeddings = None
self.tune = tune
if output_dir is None:
self.temp_dir_obj = tempfile.TemporaryDirectory()
self.output_dir = self.temp_dir_obj.name
self.logger.info(f"No output_dir provided; using temporary directory: {self.output_dir}")
else:
self.output_dir = Path(output_dir)
self.logger.info(f"Output directory set to: {self.output_dir}")
self.output_dir.mkdir(parents=True, exist_ok=True)
[docs]
def fit(self) -> None:
"""
Trains the GNN model using the provided data.
"""
self.logger.info("Starting training process.")
try:
node_features = self._prepare_node_features()
node_labels = self._prepare_node_labels()
self.data = self._build_pyg_data(node_features, node_labels)
self.model = self._initialize_gnn_model().to(self.device)
self._train_gnn(self.model, self.data)
self.logger.info("Training completed successfully.")
except Exception as e:
self.logger.error(f"Error during training: {e}")
raise
[docs]
def embed(self, as_df: bool = False) -> Union[torch.Tensor, pd.DataFrame]:
"""
Generates node embeddings.
If tuning is enabled, runs hyperparameter tuning and uses the best configuration.
"""
self.logger.info("Generating node embeddings.")
if not self.tune and (self.model is None or self.data is None):
raise ValueError("Model not trained. Call fit() first.")
if self.tune:
self.logger.info("Tuning is enabled. Running hyperparameter tuning.")
best_config = self.run_gnn_embedding_tuning()
self.logger.info(f"Best tuning config: {best_config}")
self.model_type = best_config["model_type"]
self.hidden_dim = best_config["hidden_dim"]
self.layer_num = best_config["layer_num"]
self.dropout = best_config["dropout"]
self.num_epochs = best_config["num_epochs"]
self.lr = best_config["lr"]
self.weight_decay = best_config["weight_decay"]
self.activation = best_config["activation"]
self.tune = False
self.logger.info(f"Retraining with best config: {best_config}")
self.fit()
self.logger.info("Model retrained with best hyperparameters.")
try:
self.embeddings = self._generate_embeddings(self.model, self.data)
self.logger.info("Node embeddings generated.")
if as_df:
embeddings_df = self._tensor_to_df(self.embeddings, self.adjacency_matrix)
return embeddings_df
else:
return self.embeddings
except Exception as e:
self.logger.error(f"Error during embedding generation: {e}")
raise
def _tensor_to_df(self, embeddings_tensor: torch.Tensor, network: pd.DataFrame) -> pd.DataFrame:
"""
Convert embeddings tensor to a DataFrame with node (feature) names as the index,
and embedding dimension labels as columns.
"""
try:
self.logger.info("Converting embeddings tensor to DataFrame.")
if embeddings_tensor is None:
raise ValueError("Embeddings tensor is empty (None).")
if network is None:
raise ValueError("Network (adjacency matrix) is empty (None).")
if embeddings_tensor.shape[0] != len(network.index):
raise ValueError(
f"Mismatch: embeddings tensor has {embeddings_tensor.shape[0]} rows, "
f"but network index has {len(network.index)} rows."
)
self.logger.debug(f"Embeddings tensor shape: {embeddings_tensor.shape}")
embeddings_df = pd.DataFrame(
embeddings_tensor.numpy(),
index=network.index,
columns=[f"Embed_{i+1}" for i in range(embeddings_tensor.shape[1])]
)
return embeddings_df
except Exception as e:
self.logger.error(f"Error during conversion: {e}")
raise
def _prepare_node_features(self) -> pd.DataFrame:
"""
1. Align network & omics nodes.
2. Compute graph-centralities (pagerank, eigenvector, katz).
3. If clinical_data exists:
- compute Pearson correlations vs. each clinical var.
Else:
- compute mean, log-skew, median-abs-dev of omics per node.
4. Rank-scale every feature to [-1,1].
5. Save and return the final features DataFrame.
"""
self.logger.info("Preparing node features.")
network_features = self.adjacency_matrix.columns
nodes = sorted(network_features.intersection(self.omics_data.columns))
if len(nodes) == 0:
raise ValueError("No common features found between the network and omics data.")
if len(nodes) != len(network_features):
missing = set(network_features) - set(nodes)
self.logger.warning(f"Length of common features: {len(nodes)}")
self.logger.warning(f"Length of network features: {len(network_features)}")
self.logger.warning(f"Missing features: {missing}")
raise ValueError("Mismatch between network features and omics data columns.")
self.logger.info(f"Found {len(nodes)} common features between network and omics data.")
omics_filtered = self.omics_data[nodes]
network_filtered = self.adjacency_matrix.loc[nodes, nodes]
G = nx.from_pandas_adjacency(network_filtered)
pagerank = nx.pagerank(G, alpha=0.85, weight="weight", max_iter=1000)
katz = nx.katz_centrality_numpy(G, alpha=0.005, beta=1.0, weight="weight")
eigenvector = {}
for comp in nx.connected_components(G):
sub = G.subgraph(comp)
try:
ec = nx.eigenvector_centrality(sub, max_iter=1000, tol=1e-6, weight="weight")
except nx.PowerIterationFailedConvergence:
ec = {}
self.logger.warning(
f"Eigenvector centrality failed for component size {len(sub)}; defaulting to 0."
)
for node in sub.nodes():
ec[node] = 0.0
eigenvector.update(ec)
nodes = list(network_filtered.index)
centralities_df = pd.DataFrame({
"pagerank": pagerank,
"eigenvector": eigenvector,
"katz": katz
}).reindex(nodes)
if self.clinical_data is not None and not self.clinical_data.empty:
clinical_cols = list(self.clinical_data.columns)
common_index = self.clinical_data.index.intersection(omics_filtered.index)
if common_index.empty:
raise ValueError("No common indices between omics and clinical data.")
node_features_dict = {}
for node in nodes:
vec = pd.to_numeric(omics_filtered[node].loc[common_index], errors="coerce")
corr_vector = {}
for cvar in clinical_cols:
clinical_series = self.clinical_data[cvar].loc[common_index]
corr_val = vec.corr(clinical_series)
corr_vector[cvar] = corr_val if not pd.isna(corr_val) else 0.0
full_feature_vec = {
"pagerank": pagerank[node],
"eigenvector": eigenvector[node],
"katz": katz[node],
}
full_feature_vec.update(corr_vector)
node_features_dict[node] = full_feature_vec
node_features_df = pd.DataFrame.from_dict(node_features_dict, orient="index")
self.logger.info(f"Built feature matrix with clinical correlations shape: {node_features_df.shape}")
else:
self.logger.warning("No clinical data found. Using centrality measures and statistical features.")
if self.phenotype_data is None or self.phenotype_data.empty:
raise ValueError("No phenotype data available for statistical features.")
pheno = self.phenotype_data.loc[omics_filtered.index].dropna()
stat_features = {}
for node in nodes:
vec = omics_filtered[node].loc[pheno.index].dropna()
if vec.empty:
mean_val = np.nan
skew_val = np.nan
mad_val = np.nan
else:
mean_val = vec.mean()
skew_val = skew(vec)
log_skew_val = np.sign(skew_val) * np.log1p(abs(skew_val)) if not np.isnan(skew_val) else 0.0
mad_val = np.median(np.abs(vec - np.median(vec)))
stat_features[node] = {"mean": mean_val, "log_skew": log_skew_val, "mad": mad_val}
stat_df = pd.DataFrame.from_dict(stat_features, orient="index")
node_features_df = stat_df.join(centralities_df)
self.logger.info(f"Built statistical feature matrix shape: {node_features_df.shape}")
ranks = node_features_df.rank(method="average")
scale_den = (ranks.max() - ranks.min()).replace(0, 1)
scaled_ranks = 2 * (ranks - ranks.min()) / scale_den - 1
node_features_df = scaled_ranks
timestamp = datetime.now().strftime("%m%d_%H_%M_%S")
labels_file = self.output_dir / f"features_{network_filtered.shape[0]}_{timestamp}.txt"
with open(labels_file, "w") as f:
f.write(node_features_df.to_string())
self.logger.info(f"Node features prepared successfully and saved to {labels_file}.")
return node_features_df
def _prepare_node_labels(self) -> pd.Series:
"""
Build node labels using either Pearson correlation OR mutual information
between each omics feature and the specified phenotype column.
"""
self.logger.info("Preparing node labels.")
nodes = sorted(self.adjacency_matrix.index.intersection(self.omics_data.columns))
samples = self.omics_data.index.intersection(self.phenotype_data.index)
omics_data = self.omics_data.loc[samples, nodes]
pheno = self.phenotype_data.loc[samples]
if len(samples)==0:
raise ValueError("No overlapping samples between omics and phenotype.")
if len(nodes)==0:
raise ValueError("No overlapping nodes between adjacency and omics.")
labels_dict = {}
for node in nodes:
vec = pd.to_numeric(omics_data[node], errors="coerce")
val = vec.corr(pheno)
labels_dict[node] = 0.0 if pd.isna(val) else val
labels_series = pd.Series(labels_dict, index=nodes)
ranks = labels_series.rank(method="average")
den = (ranks.max() - ranks.min()) or 1
scaled = 2*(ranks - ranks.min())/den - 1
timestamp = datetime.now().strftime("%m%d_%H_%M_%S")
labels_file = self.output_dir / f"labels_{self.adjacency_matrix.shape[0]}_{timestamp}.txt"
with open(labels_file, "w") as f:
f.write(scaled.to_string())
self.logger.info(f"Node labels prepared successfully and saved to {labels_file}.")
return scaled
def _build_pyg_data(self, node_features: pd.DataFrame, node_labels: pd.Series) -> Data:
self.logger.info("Constructing PyTorch Geometric Data object.")
if not node_labels.index.equals(node_features.index):
raise ValueError("`node_labels` must have the same index and order as `node_features`.")
nodes = node_features.index
adj = self.adjacency_matrix.loc[nodes, nodes]
G = nx.from_pandas_adjacency(adj)
node_mapping = {name: i for i, name in enumerate(nodes)}
G = nx.relabel_nodes(G, node_mapping)
data = from_networkx(G)
data.num_nodes = len(nodes)
edge_attr = getattr(data, "edge_attr", None)
if edge_attr is not None:
data.edge_weight = edge_attr.view(-1)
del data.edge_attr
else:
# no edge_attr
data.edge_weight = torch.ones(data.edge_index.size(1))
# data.edge_index, data.edge_weight = add_self_loops(
# data.edge_index, data.edge_weight, fill_value=1.0, num_nodes=len(nodes)
# )
# if isinstance(conv, (SAGEConv, GINConv)):
# data.edge_index, data.edge_weight = add_self_loops(
# data.edge_index, data.edge_weight, fill_value=1.0, num_nodes=data.num_nodes)
data.x = torch.tensor(node_features.loc[nodes].values, dtype=torch.float)
data.y = torch.tensor(node_labels.loc[nodes].values, dtype=torch.float)
self.logger.info("PyTorch Geometric Data object constructed successfully.")
return data
def _initialize_gnn_model(self) -> nn.Module:
"""
Initialize the GNN model based on the specified type.
Returns:
nn.Module
"""
self.logger.info(f"Initializing GNN model of type '{self.model_type}' with hidden_dim={self.hidden_dim} and layer_num={self.layer_num}.")
if self.data is None or not hasattr(self.data, "x") or self.data.x is None:
raise ValueError("Data is not initialized or is missing the 'x' attribute.")
input_dim = self.data.x.shape[1]
if self.model_type.upper() == "GCN":
return GCN(input_dim=input_dim, hidden_dim=self.hidden_dim, layer_num=self.layer_num, dropout=self.dropout,activation=self.activation, seed = self.seed)
elif self.model_type.upper() == "GAT":
return GAT(input_dim=input_dim, hidden_dim=self.hidden_dim, layer_num=self.layer_num, dropout=self.dropout,activation=self.activation, seed = self.seed)
elif self.model_type.upper() == "SAGE":
return SAGE(input_dim=input_dim, hidden_dim=self.hidden_dim, layer_num=self.layer_num, dropout=self.dropout,activation=self.activation, seed = self.seed)
elif self.model_type.upper() == "GIN":
return GIN(input_dim=input_dim, hidden_dim=self.hidden_dim, layer_num=self.layer_num, dropout=self.dropout, activation=self.activation, seed = self.seed)
else:
self.logger.error(f"Unsupported model_type: {self.model_type}")
raise ValueError(f"Unsupported model_type: {self.model_type}")
def _train_gnn(self, model: nn.Module, data: Data) -> None:
"""
Train the GNN model using MSE loss to predict node labels with early stopping.
"""
self.logger.info("Starting GNN training with early stopping.")
data = data.to(self.device)
model = model.to(self.device)
model.train()
optimizer = optim.Adam(model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
loss_fn = nn.MSELoss()
best_loss = float("inf")
patience = 100
counter = 0
for epoch in range(1, self.num_epochs + 1):
optimizer.zero_grad()
output = model(data)
output = output.view(-1)
target = data.y.to(self.device)
loss = loss_fn(output, target)
if torch.isnan(loss):
self.logger.error("NaN loss encountered. Stopping early.")
break
loss.backward()
optimizer.step()
if loss.item() < best_loss - 1e-6:
best_loss = loss.item()
counter = 0
else:
counter += 1
if epoch % 50 == 0 or epoch == 1 or epoch == self.num_epochs:
self.logger.info(f"Epoch [{epoch}/{self.num_epochs}] - Loss: {loss.item():.4f} - EarlyStop: {counter}/{patience}")
if counter >= patience:
self.logger.warning(f"Early stopping triggered at epoch {epoch}. Best loss: {best_loss:.4f}")
break
self.logger.info("GNN training finished.")
def _generate_embeddings(self, model: nn.Module, data: Data) -> torch.Tensor:
"""
Retrieve node embeddings from the penultimate layer of the trained GNN model.
Returns:
torch.Tensor
"""
self.logger.info("Generating node embeddings from the trained GNN model.")
model.eval()
data = data.to(self.device)
with torch.no_grad():
embeddings = model.get_embeddings(data)
return embeddings.cpu()
def _tune_helper(self, config):
"""
The function that each Ray Tune trial calls.
"""
try:
tuned_instance = GNNEmbedding(
adjacency_matrix=self.adjacency_matrix,
omics_data=self.omics_data,
phenotype_data=self.phenotype_data,
clinical_data=self.clinical_data,
phenotype_col=self.phenotype_col,
model_type=config.get("model_type", self.model_type),
hidden_dim=config["hidden_dim"],
layer_num=config["layer_num"],
dropout=config["dropout"],
num_epochs=config["num_epochs"],
lr=config["lr"],
weight_decay=config["weight_decay"],
gpu=self.device.type,
seed=self.seed,
tune=False,
activation=self.activation,
output_dir=self.output_dir,
)
tuned_instance.fit()
node_embeddings = tuned_instance.embed()
X = node_embeddings.detach().cpu().numpy()
dim_stds = np.std(X, axis=0)
keep_dims = dim_stds >= 1e-4
num_dims_kept = np.sum(keep_dims)
if num_dims_kept == 0:
self.logger.warning(
"All embedding dimensions are nearly constant. Discarding trial."
)
tune.report({
"mse": 1e6,
"composite_score": 1e6,
"mean_dim_std": 0.0
})
return
X = X[:, keep_dims]
new_dim_stds = dim_stds[keep_dims]
mean_dim_std = np.mean(new_dim_stds)
y = tuned_instance._prepare_node_labels().values
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=self.seed
)
reg = RandomForestRegressor(
n_estimators=150,
max_depth=None,
n_jobs=-1,
random_state=self.seed
)
reg.fit(X_train, y_train)
y_pred = reg.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
composite_score = mse / (mean_dim_std + 1e-6)
tune.report({
"mse": mse,
"composite_score": composite_score,
"mean_dim_std": mean_dim_std,
"dims_original": len(dim_stds),
"dims_dropped": int(len(dim_stds) - num_dims_kept)
})
except Exception as e:
self.logger.error(f"[Tuning Trial Error] config={config}")
self.logger.error(f"[Tuning Trial Error] Exception: {e}")
import traceback
traceback.print_exc()
tune.report({"mse": 1e8, "composite_score": 1e8, "mean_dim_std": 0.0})
[docs]
def run_gnn_embedding_tuning(self, num_samples=15):
"""
Run hyperparameter tuning with Ray Tune.
"""
num_nodes = self.adjacency_matrix.shape[0]
config = {
"model_type": tune.choice(["GAT","GCN","SAGE","GIN"]),
"hidden_dim": tune.choice([16, 32, 64, 128, 256, 512]),
"layer_num": tune.choice([2, 3, 4, 5, 6]),
"dropout": tune.choice([0.0, 0.1, 0.2, 0.3, 0.4, 0.5]),
"num_epochs": tune.choice([128, 256, 512, 1024, 2048]),
"lr": tune.loguniform(1e-6, 1e-3),
"weight_decay": tune.choice([0.0, 1e-6, 1e-5, 1e-4, 1e-3]),
"activation": tune.choice(["relu", "elu", "leaky_relu"]),
}
scheduler = ASHAScheduler(metric="composite_score", mode="min", grace_period=1, reduction_factor=2)
reporter = CLIReporter(metric_columns=["mse", "training_iteration"])
def short_dirname_creator(trial):
return f"_{trial.trial_id}"
resources = {"cpu": 1, "gpu": 1} if self.device.type == "cuda" else {"cpu": 1, "gpu": 0}
result = tune.run(
tune.with_parameters(self._tune_helper),
config=config,
num_samples=num_samples,
scheduler=scheduler,
verbose=1,
progress_reporter=reporter,
storage_path=os.path.expanduser("~/gnn"),
trial_dirname_creator=short_dirname_creator,
resources_per_trial=resources,
name="e",
)
timestamp = datetime.now().strftime("%m%d_%H_%M_%S")
save_dir = Path(self.output_dir)/"tuning_results"
os.makedirs(save_dir, exist_ok=True)
best_trial = result.get_best_trial("composite_score", "min", "last")
best_config_json = json.dumps(best_trial.config, indent=4)
try:
df = result.get_dataframe()
except AttributeError:
df = result.dataframe(metric="composite_score", mode="min")
summary_file = save_dir / f"summary_{num_nodes}_{timestamp}.txt"
with open(summary_file, "w") as f:
f.write(f"Best trial\n")
f.write(best_config_json)
f.write("\n\n")
f.write(df.to_string(index=False))
self.logger.info(f"Full trial summary saved to {summary_file}")
# best trial results
self.logger.info(f"Best trial config: {best_trial.config}")
self.logger.info(f"Best trial final MSE: {best_trial.last_result['mse']}")
# best config as a JSON file
timestamp = datetime.now().strftime("%m%d_%H_%M_%S")
save_dir.mkdir(exist_ok=True)
best_params_file = save_dir / f"emb_tuned_{num_nodes}_{timestamp}.json"
with open(best_params_file, "w") as f:
json.dump(best_trial.config, f, indent=4)
self.logger.info(f"Best embedding parameters saved to {best_params_file}")
return best_trial.config