#!python

import torch

from rnaglib.learning import models, learn
from rnaglib.data_loading import rna_dataset, rna_loader
from rnaglib.representations import GraphRepresentation

"""
This script shows a first, very basic example: learning binding protein preferences
from the nucleotide types and the graph structure.

To do so, we choose our data, create a data loader around it, build an RGCN model and train it.
"""

if __name__ == "__main__":
    # --- Data -------------------------------------------------------------
    # A single nucleotide-level input feature ('nt_code') and a single
    # nucleotide-level target ('binding_protein'), with graphs requested
    # through the 'dgl' framework.
    dgl_graphs = GraphRepresentation(framework='dgl')
    dataset = rna_dataset.RNADataset(
        nt_features=['nt_code'],
        nt_targets=['binding_protein'],
        representations=[dgl_graphs],
    )

    # get_loader returns three loaders; only the training one is used below.
    loaders = rna_loader.get_loader(dataset=dataset)
    training_loader, _validation_loader, _test_loader = loaders

    # --- Model ------------------------------------------------------------
    # The dataset reports its own input/output sizes. The embedder is built
    # with dims=[10, 10] (a 10-dimensional embedding), and a classifier whose
    # output size matches the target dimension is placed on top of it.
    n_inputs = dataset.input_dim
    n_outputs = dataset.output_dim
    embedder = models.Embedder(dims=[10, 10], infeatures_dim=n_inputs)
    classifier = models.Classifier(embedder=embedder, classif_dims=[n_outputs])

    # --- Training ---------------------------------------------------------
    # Adam over the classifier's parameters, then hand everything to the
    # library's supervised training routine.
    adam = torch.optim.Adam(classifier.parameters(), lr=0.001)
    learn.train_supervised(
        model=classifier,
        optimizer=adam,
        train_loader=training_loader,
    )
