bioneuralnet.network_embedding.gnn_embedding

Functions

get_logger(name)

Retrieves a global logger configured to write to 'bioneuralnet.log' at the project root.

process_dropout(dropout)

Converts the dropout argument to a float rate (True → 0.5, False → 0.0).

skew(a[, axis, bias, nan_policy, keepdims])

Compute the sample skewness of a data set.
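get_logger is documented only by its one-line summary. The sketch below shows one way such a helper can behave with the standard library's logging module: the same name always returns the same logger object, and a file handler for bioneuralnet.log is attached only once. It is a hypothetical stand-in, not bioneuralnet's code. The function name get_logger_sketch, the log_dir argument (the real helper locates the project root itself) and the format string are all assumptions.

```python
import logging
import os

# Hypothetical format string; the real helper's format is not documented here.
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

def get_logger_sketch(name: str, log_dir: str) -> logging.Logger:
    """Return the logger `name`, writing to <log_dir>/bioneuralnet.log.

    Hypothetical stand-in: log_dir replaces the project-root lookup
    that the real get_logger performs.
    """
    # logging keeps a global registry: the same name yields the same object.
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    log_path = os.path.abspath(os.path.join(log_dir, "bioneuralnet.log"))
    # Attach the file handler only once, even if this is called repeatedly.
    has_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_path
        for h in logger.handlers
    )
    if not has_handler:
        handler = logging.FileHandler(log_path)
        handler.setFormatter(logging.Formatter(LOG_FORMAT))
        logger.addHandler(handler)
    return logger
```

Guarding against duplicate handlers matters because modules typically call the helper at import time; without the check, every call would add another handler and each message would be written several times.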
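skew is imported into this module; its signature and summary match scipy.stats.skew. With its default bias=True, scipy computes the moment-based sample skewness g1 = m3 / m2^(3/2), where mk is the k-th central moment about the mean. The pure-Python sketch below implements that formula for a single 1-D sequence, for illustration only. The name sample_skewness and the choice to return nan for constant input are assumptions of the sketch.

```python
import math

def sample_skewness(values):
    """Biased sample skewness g1 = m3 / m2**1.5 (the bias=True formula)."""
    n = len(values)
    if n == 0:
        raise ValueError("need at least one value")
    mean = sum(values) / n
    m2 = sum((x - mean) ** 2 for x in values) / n  # second central moment
    m3 = sum((x - mean) ** 3 for x in values) / n  # third central moment
    if m2 == 0:
        return math.nan  # undefined for constant data (sketch's choice)
    return m3 / m2 ** 1.5
```

A long right tail gives a positive value, a long left tail a negative one, and symmetric data gives zero.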

Classes

ASHAScheduler

alias of AsyncHyperBandScheduler

CLIReporter(*[, metric_columns, ...])

Command-line reporter

GAT(*args, **kwargs)

GCN(*args, **kwargs)

GIN(*args, **kwargs)

GNNEmbedding(adjacency_matrix, omics_data, ...)

Class for generating graph neural network (GNN) based embeddings.

Path(*args, **kwargs)

PurePath subclass that can make system calls.

SAGE(*args, **kwargs)

datetime(year, month, day[, hour[, minute[, ...)

The year, month and day arguments are required.
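GCN is one of the model classes listed above, alongside GAT, GIN and SAGE (model_type defaults to 'GAT'). A GCN layer is conventionally defined by the Kipf and Welling propagation rule H' = ReLU(D^-1/2 (A + I) D^-1/2 H W), where A is the adjacency matrix, D the degree matrix of A + I, H the node features and W a learned weight matrix. The pure-Python sketch below applies one such layer. It illustrates the textbook rule only, assumes a symmetric, non-negative adjacency matrix, and is not taken from bioneuralnet's GCN class, whose internals this page does not document.

```python
import math

def gcn_layer(adj, features, weight):
    """Textbook GCN layer: ReLU(D^-1/2 (A + I) D^-1/2 @ H @ W).

    adj:      n x n symmetric, non-negative adjacency matrix (list of lists)
    features: n x f node feature matrix H
    weight:   f x h weight matrix W
    """
    n = len(adj)
    # Add self-loops so each node also aggregates its own features.
    a_hat = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    # Degree of each node in A + I (always >= 1 for non-negative weights).
    inv_sqrt_deg = [1.0 / math.sqrt(sum(row)) for row in a_hat]
    # Symmetric normalization D^-1/2 (A + I) D^-1/2.
    norm = [[inv_sqrt_deg[i] * a_hat[i][j] * inv_sqrt_deg[j] for j in range(n)]
            for i in range(n)]
    f, h = len(weight), len(weight[0])
    # Neighborhood aggregation: norm @ H.
    agg = [[sum(norm[i][k] * features[k][c] for k in range(n)) for c in range(f)]
           for i in range(n)]
    # Linear transform (agg @ W) followed by ReLU.
    return [[max(0.0, sum(agg[i][c] * weight[c][o] for c in range(f))) for o in range(h)]
            for i in range(n)]
```

For two connected nodes with features 1 and 3 and W = [[1]], both nodes receive the normalized average 2, which is how message passing smooths features across edges.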

class bioneuralnet.network_embedding.gnn_embedding.GNNEmbedding(adjacency_matrix: DataFrame, omics_data: DataFrame, phenotype_data: Series | DataFrame, clinical_data: DataFrame | None = None, phenotype_col: str = 'phenotype', model_type: str = 'GAT', hidden_dim: int = 64, layer_num: int = 4, dropout: bool | float = True, num_epochs: int = 100, lr: float = 0.001, weight_decay: float = 0.0001, gpu: bool = False, activation: str = 'relu', seed: int | None = None, tune: bool | None = False, output_dir: str | None = None)

Bases: object

Class for generating graph neural network (GNN) based node embeddings.

adjacency_matrix

pd.DataFrame

omics_data

pd.DataFrame

phenotype_data

Union[pd.Series, pd.DataFrame]

clinical_data

Optional[pd.DataFrame]

phenotype_col

str

model_type

str

hidden_dim

int

layer_num

int

dropout

Union[bool, float] (if bool, True maps to 0.5, False to 0.0)

num_epochs

int

lr

float

weight_decay

float

gpu

bool

seed

Optional[int]

tune

Optional[bool]
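The dropout attribute above accepts either a bool or a float, with True mapping to 0.5 and False to 0.0. The sketch below encodes that documented rule. The name process_dropout_sketch is hypothetical, and the [0, 1] range check is an assumption added for illustration; neither is taken from bioneuralnet's process_dropout.

```python
def process_dropout_sketch(dropout):
    """Map a bool-or-float dropout argument to a dropout rate.

    Documented rule: True -> 0.5, False -> 0.0; a float is used as given.
    """
    # Check bool before converting: float(True) would silently give 1.0.
    if isinstance(dropout, bool):
        return 0.5 if dropout else 0.0
    p = float(dropout)
    # Hypothetical validation: a dropout probability must lie in [0, 1].
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"dropout must be in [0, 1], got {p}")
    return p
```

The ordering matters because bool is a subclass of int in Python, so a numeric branch placed first would treat True as a rate of 1.0 and drop every activation.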

embed(as_df: bool = False) → torch.Tensor | DataFrame

Generates node embeddings. If tuning is enabled, runs hyperparameter tuning and uses the best configuration. Returns a pandas DataFrame when as_df is True and a torch.Tensor otherwise.

fit() → None

Trains the GNN model using the provided data.

run_gnn_embedding_tuning(num_samples=15)

Runs hyperparameter tuning with Ray Tune, sampling num_samples configurations (default 15).
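In Ray Tune, num_samples sets how many trials are drawn from the search space. The stdlib-only sketch below swaps Ray Tune for plain random search to show the sample, evaluate, keep-the-best loop that run_gnn_embedding_tuning presumably delegates to Ray Tune. The search space over constructor arguments (hidden_dim, layer_num, dropout, lr, weight_decay), its ranges, and the objective callable are hypothetical; the ranges bioneuralnet actually searches are not documented on this page. The module also imports ASHAScheduler, but this sketch does not model any early stopping.

```python
import math
import random

# Hypothetical search space over GNNEmbedding constructor arguments.
SEARCH_SPACE = {
    "hidden_dim": [16, 32, 64, 128],
    "layer_num": [2, 3, 4],
    "dropout": [0.0, 0.2, 0.5],
    "lr": (1e-4, 1e-2),            # tuple -> sampled log-uniformly
    "weight_decay": (1e-5, 1e-3),  # tuple -> sampled log-uniformly
}

def sample_config(rng):
    """Draw one configuration: lists are sampled uniformly, tuples log-uniformly."""
    config = {}
    for name, space in SEARCH_SPACE.items():
        if isinstance(space, tuple):
            lo, hi = space
            config[name] = math.exp(rng.uniform(math.log(lo), math.log(hi)))
        else:
            config[name] = rng.choice(space)
    return config

def random_search(objective, num_samples=15, seed=None):
    """Evaluate num_samples random configs; return (best_config, best_loss).

    `objective` maps a config dict to a loss; lower is better.
    """
    rng = random.Random(seed)
    best_config, best_loss = None, math.inf
    for _ in range(num_samples):
        config = sample_config(rng)
        loss = objective(config)
        if loss < best_loss:
            best_config, best_loss = config, loss
    return best_config, best_loss
```

Learning rate and weight decay are sampled on a log scale because their useful values span orders of magnitude; a uniform draw over [1e-4, 1e-2] would rarely produce values near 1e-4.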