bioneuralnet.network_embedding

Classes

GAT(*args, **kwargs)

GCN(*args, **kwargs)

GIN(*args, **kwargs)

GNNEmbedding(adjacency_matrix, omics_data, ...)

GNNEmbedding class for generating Graph Neural Network (GNN)-based embeddings.

SAGE(*args, **kwargs)

class bioneuralnet.network_embedding.GAT(*args: Any, **kwargs: Any)

Bases: Module

forward(data)
get_embeddings(data)
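The reference lists only the method names for GAT. For orientation, here is a minimal NumPy sketch of one single-head graph attention layer as it is commonly defined (LeakyReLU-scored attention, softmax over each node's neighbours plus itself). It is not BioNeuralNet's implementation, and the helper name gat_layer and its parameters are hypothetical.

```python
import numpy as np


def gat_layer(h, adj, W, a, negative_slope=0.2):
    """Single-head graph attention layer (hypothetical helper, NumPy only).

    h:   (n, f_in) node features
    adj: (n, n) 0/1 adjacency matrix; self-loops are added internally
    W:   (f_in, f_out) shared linear projection
    a:   (2 * f_out,) attention vector
    Returns the updated features (n, f_out) and the attention matrix (n, n).
    """
    n = h.shape[0]
    z = h @ W                                    # project every node
    f_out = z.shape[1]
    # a^T [z_i || z_j] splits into a source term plus a target term
    e = (z @ a[:f_out])[:, None] + (z @ a[f_out:])[None, :]
    e = np.where(e > 0, e, negative_slope * e)   # LeakyReLU
    # attend only over neighbours and the node itself
    mask = (adj + np.eye(n)) > 0
    e = np.where(mask, e, -np.inf)
    e = e - e.max(axis=1, keepdims=True)         # stabilise the softmax
    alpha = np.exp(e)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)
    return alpha @ z, alpha
```

Each row of the returned attention matrix sums to 1, and non-neighbours receive exactly zero weight.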
class bioneuralnet.network_embedding.GCN(*args: Any, **kwargs: Any)

Bases: Module

forward(data)
get_embeddings(data)
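As a hedged illustration of the propagation rule a GCN layer usually applies, the NumPy sketch below computes ReLU(D^-1/2 (A + I) D^-1/2 H W) for one layer. The helper gcn_layer is hypothetical; whether the GCN class stacks layer_num such layers exactly this way is an assumption, not something this reference states.

```python
import numpy as np


def gcn_layer(h, adj, W):
    """One graph convolution: ReLU(D^-1/2 (A + I) D^-1/2 H W).

    h: (n, f_in) node features, adj: (n, n) adjacency, W: (f_in, f_out).
    """
    a_hat = adj + np.eye(adj.shape[0])        # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    # symmetric normalisation, written with broadcasting instead of diag matrices
    norm = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]
    return np.maximum(norm @ h @ W, 0.0)
```

For two connected nodes with features 1 and 3 and W = [[1]], both nodes end up with the normalised average 2.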
class bioneuralnet.network_embedding.GIN(*args: Any, **kwargs: Any)

Bases: Module

forward(data)
get_embeddings(data)
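For comparison, a Graph Isomorphism Network layer sums neighbour features, adds a (1 + eps)-scaled copy of the node's own features, and passes the result through an MLP. The NumPy sketch below shows that update; gin_layer and two_layer_mlp are hypothetical helpers, not part of bioneuralnet.

```python
import numpy as np


def gin_layer(h, adj, mlp, eps=0.0):
    """GIN update: h_i' = MLP((1 + eps) * h_i + sum of neighbour features)."""
    return mlp((1.0 + eps) * h + adj @ h)


def two_layer_mlp(W1, W2):
    """Return x -> ReLU(x W1) W2 (biases omitted for brevity)."""
    def mlp(x):
        return np.maximum(x @ W1, 0.0) @ W2
    return mlp
```

Sum aggregation (rather than mean) is the design choice that lets GIN distinguish neighbourhoods that contain the same features in different multiplicities.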
class bioneuralnet.network_embedding.GNNEmbedding(adjacency_matrix: DataFrame, omics_data: DataFrame, phenotype_data: Series | DataFrame, clinical_data: DataFrame | None = None, phenotype_col: str = 'phenotype', model_type: str = 'GAT', hidden_dim: int = 64, layer_num: int = 4, dropout: bool | float = True, num_epochs: int = 100, lr: float = 0.001, weight_decay: float = 0.0001, gpu: bool = False, activation: str = 'relu', seed: int | None = None, tune: bool | None = False, output_dir: str | None = None)

Bases: object

GNNEmbedding class for generating Graph Neural Network (GNN)-based embeddings.

adjacency_matrix

pd.DataFrame

omics_data

pd.DataFrame

phenotype_data

Union[pd.Series, pd.DataFrame]

clinical_data

Optional[pd.DataFrame]

phenotype_col

str

model_type

str

hidden_dim

int

layer_num

int

dropout

Union[bool, float] (if bool, True maps to 0.5, False to 0.0)

num_epochs

int

lr

float

weight_decay

float

gpu

bool

activation

str

seed

Optional[int]

tune

Optional[bool]

output_dir

Optional[str]
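The dropout attribute accepts either a bool or a float, with True mapping to 0.5 and False to 0.0. A minimal sketch of that mapping follows; resolve_dropout is a hypothetical helper, and the range check is an added assumption rather than documented behaviour.

```python
def resolve_dropout(dropout):
    """Map a bool | float dropout argument to a dropout rate (hypothetical helper)."""
    # Check bool first: in Python, bool is a subclass of int.
    if isinstance(dropout, bool):
        return 0.5 if dropout else 0.0
    rate = float(dropout)
    # Assumed validation: a dropout rate must lie in [0, 1).
    if not 0.0 <= rate < 1.0:
        raise ValueError(f"dropout must be in [0, 1), got {rate}")
    return rate
```

Testing for bool before converting to float matters, because float(True) would silently yield 1.0 instead of 0.5.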

embed(as_df: bool = False) → torch.Tensor | DataFrame

Generates node embeddings and returns them as a torch.Tensor, or as a pandas DataFrame when as_df=True. If tuning is enabled, runs hyperparameter tuning first and uses the best configuration found.
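When embeddings come back as a tensor, it is often convenient to label them by node. The sketch below wraps an (n_nodes, dim) array into a DataFrame; the helper name, the emb_i column names, and the assumption that rows follow the order of adjacency_matrix.index are all hypothetical. A CPU tensor without gradients converts via np.asarray; otherwise call .detach().cpu().numpy() first.

```python
import numpy as np
import pandas as pd


def embeddings_to_frame(embeddings, node_names, prefix="emb"):
    """Label an (n_nodes, dim) embedding matrix with node and column names."""
    emb = np.asarray(embeddings)
    # Assumption: row i of the embedding matrix belongs to node_names[i].
    if emb.shape[0] != len(node_names):
        raise ValueError("one embedding row per node is required")
    columns = [f"{prefix}_{i}" for i in range(emb.shape[1])]
    return pd.DataFrame(emb, index=list(node_names), columns=columns)
```

A labelled frame can then be joined to other per-feature tables by index instead of by position.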

fit() → None

Trains the GNN model using the provided data.
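This reference does not say which loss or optimiser fit() uses. As a hedged illustration of how num_epochs, lr and weight_decay typically interact in a training loop, the sketch below runs full-batch gradient descent with coupled L2 weight decay on a linear model predicting a phenotype vector. It is a stand-in on a linear model, not the GNN training procedure, and train_linear_head is hypothetical.

```python
import numpy as np


def train_linear_head(X, y, lr=1e-3, weight_decay=1e-4, num_epochs=100):
    """Full-batch gradient descent on mean squared error with L2 weight decay.

    X: (n, d) input representations, y: (n,) phenotype values.
    Returns the learned weights and the loss recorded at each epoch.
    """
    n, d = X.shape
    w = np.zeros(d)
    losses = []
    for _ in range(num_epochs):
        residual = X @ w - y
        losses.append(float(np.mean(residual ** 2)))
        grad = 2.0 * X.T @ residual / n   # gradient of the MSE
        grad += weight_decay * w          # gradient of (weight_decay / 2) * ||w||^2
        w -= lr * grad
    return w, losses
```

Larger weight_decay values shrink the learned weights toward zero, while lr controls the step size per epoch.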

run_gnn_embedding_tuning(num_samples=15)

Runs hyperparameter tuning with Ray Tune.
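In Ray Tune, num_samples is the number of trials sampled from the search space, and num_samples=15 here presumably plays the same role. The plain-Python random search below shows the basic idea under that assumption. It is not Ray Tune, and the search space and toy objective in the usage are hypothetical.

```python
import random


def random_search(objective, space, num_samples=15, seed=None):
    """Draw num_samples random configurations from space; keep the lowest-scoring one.

    space maps each hyperparameter name to a list of candidate values.
    objective(config) returns a score to minimise, e.g. a validation loss.
    """
    rng = random.Random(seed)
    best_config, best_score = None, float("inf")
    for _ in range(num_samples):
        config = {name: rng.choice(values) for name, values in space.items()}
        score = objective(config)
        if score < best_score:
            best_config, best_score = config, score
    return best_config, best_score
```

Ray Tune builds on this basic loop with parallel trial execution and early-stopping schedulers.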

class bioneuralnet.network_embedding.SAGE(*args: Any, **kwargs: Any)

Bases: Module

forward(data)
get_embeddings(data)
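GraphSAGE-style layers combine a node's own features with an aggregate of its neighbours' features. The NumPy sketch below uses a mean aggregator and omits the optional L2 normalisation step. The helper sage_layer is hypothetical, and the choice of aggregator used by the SAGE class is not stated in this reference.

```python
import numpy as np


def sage_layer(h, adj, W_self, W_neigh):
    """GraphSAGE layer with a mean aggregator.

    h_i' = ReLU(h_i W_self + mean_{j in N(i)} h_j W_neigh), which is equivalent
    to applying one weight matrix to the concatenation [h_i || mean_j h_j].
    """
    deg = adj.sum(axis=1, keepdims=True)
    neigh_mean = (adj @ h) / np.maximum(deg, 1.0)   # isolated nodes get a zero mean
    return np.maximum(h @ W_self + neigh_mean @ W_neigh, 0.0)
```

Keeping separate weights for the node itself and for its neighbourhood lets the layer weigh its own features differently from the aggregated context.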

Modules

gnn_embedding

gnn_models