#!python
import torch

from rnaglib.kernels import node_sim
from rnaglib.data_loading import rna_dataset, rna_loader
from rnaglib.representations import GraphRepresentation, RingRepresentation
from rnaglib.learning import models, learning_utils, learn

"""
This script shows a second, more complicated example: learning protein-binding preferences as well as
small-molecule binding from the nucleotide types and the graph structure.
We also add a pretraining phase based on the R_graphlets kernel.
"""

if __name__ == "__main__":
    # Nucleotide-level inputs and labels shared by both phases.
    nt_inputs = ['nt_code']
    nt_labels = ['binding_protein']

    # ------------------------------------------------------------------
    # Phase 1: unsupervised pretraining of the embedder
    # ------------------------------------------------------------------
    print('Starting to pretrain the network')

    # The node similarity function (R_graphlets, depth 2) is handed to the ring
    # representation; the pretraining dataset carries both the ring and the
    # DGL graph representations.
    graphlet_kernel = node_sim.SimFunctionNode(method='R_graphlets', depth=2)
    dgl_graphs = GraphRepresentation(framework='dgl')
    kernel_rings = RingRepresentation(node_simfunc=graphlet_kernel, max_size_kernel=50)
    pretrain_data = rna_dataset.RNADataset(
        nt_features=nt_inputs,
        representations=[kernel_rings, dgl_graphs],
    )
    # split=False: a single loader is requested for pretraining.
    batches = rna_loader.get_loader(dataset=pretrain_data, split=False, num_workers=4)

    # The embedder takes its input width from the dataset and uses two
    # 64-dimensional layers.
    embedder = models.Embedder(infeatures_dim=pretrain_data.input_dim, dims=[64, 64])
    adam = torch.optim.Adam(embedder.parameters())
    pretrain_schedule = learning_utils.LearningRoutine(num_epochs=10)
    reconstruction_options = {
        "similarity": True,
        "normalize": False,
        "use_graph": True,
        "hops": 2,
    }
    learn.pretrain_unsupervised(
        model=embedder,
        optimizer=adam,
        train_loader=batches,
        learning_routine=pretrain_schedule,
        rec_params=reconstruction_options,
    )
    # To keep a copy of the pretrained weights, uncomment:
    # torch.save(embedder.state_dict(), 'pretrained_model.pth')
    print()

    # ------------------------------------------------------------------
    # Phase 2: supervised fine-tuning on top of the pretrained embedder
    # ------------------------------------------------------------------
    print('We have finished pretraining the network, let us fine tune it')

    # The supervised dataset adds the node targets and only needs the graph
    # representation, which is reused from the pretraining phase.
    finetune_data = rna_dataset.RNADataset(
        nt_features=nt_inputs,
        nt_targets=nt_labels,
        representations=[dgl_graphs],
    )
    # Three loaders come back; the middle (validation) one is not used here.
    # NOTE(review): split_train and split_valid are both 0.8 -- confirm against
    # rna_loader.get_loader how these two fractions are interpreted.
    # `batches` is deliberately rebound: the pretraining loader is no longer needed.
    batches, _, test_loader = rna_loader.get_loader(
        dataset=finetune_data,
        split_train=0.8,
        split_valid=0.8,
        num_workers=10,
    )

    # The classifier wraps the pretrained embedder; classif_dims holds a single
    # entry, the output dimension reported by the supervised dataset.
    classifier = models.Classifier(
        embedder=embedder,
        classif_dims=[finetune_data.output_dim],
    )
    # A fresh optimizer over all classifier parameters replaces the pretraining one.
    adam = torch.optim.Adam(classifier.parameters(), lr=0.001)
    finetune_schedule = learning_utils.LearningRoutine(num_epochs=10)
    learn.train_supervised(
        model=classifier,
        optimizer=adam,
        train_loader=batches,
        learning_routine=finetune_schedule,
    )

    # Score the fine-tuned classifier on the test loader and report the result.
    test_metric = learning_utils.evaluate_model_supervised(model=classifier, loader=test_loader)
    print('We get a performance of :', test_metric)
    print()
