#!/usr/bin/env python3

import argparse
import os

import torch
import numpy as np
from pytorch_lightning import Trainer
from deepblast.trainer import LightningAligner
from deepblast.dataset import FastaDataset
from deepblast.dataset.utils import collate_fasta_f
from deepblast.dataset.utils import unpack_sequences
from torch.utils.data import DataLoader
from tqdm import tqdm


def fasta_dataloader(args):
    """Create the DataLoader used for scoring.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``query_fasta`` and ``db_fasta`` (passed to
        ``FastaDataset``) as well as ``batch_size`` and ``num_workers``
        (passed to ``DataLoader``).

    Returns
    -------
    torch.utils.data.DataLoader
        Unshuffled loader over the dataset, batching with
        ``collate_fasta_f`` and with pinned memory enabled.
    """
    loader_options = dict(
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate_fasta_f,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    fasta_pairs = FastaDataset(args.query_fasta, args.db_fasta)
    return DataLoader(fasta_pairs, **loader_options)


def main(args):
    """Score each batch from the FASTA dataloader and write a TSV report.

    Each output line holds four tab-separated fields: query id, database
    id, the score rounded to 4 decimals, and the normalized score rounded
    to 4 decimals. Every processed id pair is also echoed to stdout.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``load_from_checkpoint``, ``output_file`` and ``gpu``,
        plus the attributes read by ``fasta_dataloader``.
    """
    if args.gpu:
        # Only touch CUDA when a GPU was requested; selecting a device
        # unconditionally fails on hosts without CUDA.
        torch.cuda.set_device(0)
    print('args', args)
    model = LightningAligner.load_from_checkpoint(
        args.load_from_checkpoint)
    dataloader = fasta_dataloader(args)
    if args.gpu:
        model = model.cuda()
    with open(args.output_file, 'w') as out_handle:
        for batch in dataloader:
            qids, dids, qtoks, dbtoks = batch
            A = model.score(qtoks, dbtoks)
            # Move the scores to host memory for formatting.
            A = A.detach().cpu().numpy()
            for i in range(A.shape[0]):
                q_id, d_id = qids[i], dids[i]
                # NOTE(review): the normalizer is built from the lengths of
                # the qids/dids entries, not from qtoks/dbtoks. If those
                # entries are identifier strings this is probably not the
                # intended sequence-length normalization -- confirm against
                # the dataset/collate code before relying on this column.
                ql, dl = len(q_id), len(d_id)
                aln_score = A[i]
                norm_score = aln_score / (float(ql) * float(dl))
                res = [q_id, d_id, np.round(aln_score, decimals=4),
                       np.round(norm_score, decimals=4)]
                res = list(map(str, res))
                print(q_id, d_id)
                out_handle.write('\t'.join(res) + '\n')


def str2bool(value):
    """Convert a command-line string into a bool (argparse ``type``).

    ``type=bool`` is not usable with argparse because ``bool('False')`` is
    ``True``: any non-empty string enables the flag. This converter maps
    the usual spellings explicitly instead.

    Parameters
    ----------
    value : str or bool
        Raw option value; case and surrounding whitespace are ignored.

    Returns
    -------
    bool

    Raises
    ------
    argparse.ArgumentTypeError
        If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--query-fasta', type=str, required=True)
    parser.add_argument('--db-fasta', type=str, required=True)
    parser.add_argument('--load-from-checkpoint', type=str, required=True)
    parser.add_argument('--output-file', type=str, required=True)
    # Accepts ``--gpu`` on its own (meaning True) or ``--gpu <bool>``.
    parser.add_argument('--gpu', type=str2bool, nargs='?', const=True,
                        default=None)
    parser.add_argument('--num-workers', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=10)

    hparams = parser.parse_args()
    main(hparams)
