#!/usr/bin/env python3

import argparse
import json
import os
from pkg_resources import Requirement, resource_filename, DistributionNotFound
import time
import pickle

from firm import firm
from firm import pssm
from firm.miRvestigator import miRvestigator

# Name of the distribution that is handed to pkg_resources when looking up
# the data files below.
PACKAGE_NAME = "isb-firm"

# Resource names of the data files/directories shipped with the distribution.
# The script resolves each one through pkg_resources; if the distribution is
# not found, the value is used as-is, i.e. as a path relative to the current
# working directory.

# 3' UTR sequences, read by firm.read_sequences() in firm_findmotifs()
USER_P3UTR_PATH = 'firm/common/p3utrSeqs_Homo_sapiens.csv.gz'
# input for firm.make_refseq2entrez()
USER_G2REFSEQ_PATH = 'firm/common/gene2refseq.gz'
# passed to miRvestigator and to firm.write_combined_report()
USER_MATUREFA_PATH = 'firm/common/mature.fa.gz'
# passed to firm.find_motifs() as its frequency files location
USER_FREQFILES_PATH = 'firm/FreqFiles'
# passed to firm.run_target_prediction_dbs() as pred_db_dir
USER_PREDDIR = 'firm/TargetPredictionDatabases'


def firm_findmotifs(p3utrseqs_path, g2refseq_path, freqfiles_path,
                    expdir, outdir, tmpdir, use_entrez):
    """Run the Weeder motif search stage and store its results in outdir.

    Two files are written into outdir:

    - 'pssms.json': the result of firm.find_motifs(), JSON-encoded
    - 'seqs.txt': every value of the sequence collection returned by
      firm.read_sequences(), one per line

    If firm.check_weederlauncher() reports that Weeder is unavailable, a
    hint is printed and the function returns without writing anything.

    Parameters:
    p3utrseqs_path -- passed to firm.read_sequences()
    g2refseq_path  -- passed to firm.make_refseq2entrez()
    freqfiles_path -- passed to firm.find_motifs()
    expdir, tmpdir, use_entrez -- passed to firm.prepare_weeder_input()
    outdir         -- directory receiving 'pssms.json' and 'seqs.txt'
    """
    weeder_available = firm.check_weederlauncher()
    if not weeder_available:
        print("Could not find Weeder on this system, please install weeder-1.4.3.tar.gz from https://github.com/baliga-lab/firm2")
        return

    utr_seqs = firm.read_sequences(p3utrseqs_path)
    refseq_to_entrez = firm.make_refseq2entrez(g2refseq_path)

    pssms_file = os.path.join(outdir, 'pssms.json')
    seqs_file = os.path.join(outdir, 'seqs.txt')

    # Stage 1: let Weeder search the input clusters for motifs and persist
    # the resulting PSSMs as JSON. The timer covers everything from here on.
    started = time.time()
    weeder_inputs = firm.prepare_weeder_input(utr_seqs, refseq_to_entrez,
                                              use_entrez, expdir, tmpdir)
    motif_pssms = firm.find_motifs(weeder_inputs, tmpdir, freqfiles_path)
    with open(pssms_file, 'w') as out:
        json.dump(motif_pssms, out)

    # Stage 2: dump the 3' UTR sequences as a plain line file; the
    # miRvestigator step reads it back as its sequence filter. The original
    # 3' UTR source file would work too, but a line file is simpler to read.
    with open(seqs_file, 'w') as out:
        out.writelines("%s\n" % sequence for sequence in utr_seqs.values())
    finished = time.time()

    elapsed_minutes = (finished - started) / 60.
    print("Completed weeder find motifs in {:.2f} minutes".format(elapsed_minutes))


def mirvestigator(indir, outdir, maturefa_path):
    """Run miRvestigator on stored motif results and pickle the outcome.

    Reads 'seqs.txt' and 'pssms.json' from indir, builds one pssm.pssm
    object per JSON entry, runs miRvestigator on them and pickles the
    resulting object to 'm2m_standalone.pkl' in outdir.

    Parameters:
    indir         -- directory containing 'seqs.txt' and 'pssms.json'
    outdir        -- directory for the pickle; also passed to
                     miRvestigator as baseDir
    maturefa_path -- passed through to miRvestigator
    """
    seqs_file = os.path.join(indir, 'seqs.txt')
    pssms_file = os.path.join(indir, 'pssms.json')
    pickle_file = os.path.join(outdir, 'm2m_standalone.pkl')

    # The lines are handed over exactly as read, line endings included
    with open(seqs_file) as handle:
        utr_lines = list(handle)

    with open(pssms_file) as handle:
        motif_records = json.load(handle)
    motifs = [pssm.pssm(name=record['name'],
                        sites=record['sites'],
                        evalue=record['evalue'],
                        pssm=record['matrix'])
              for record in motif_records]

    # The timer covers the miRvestigator run as well as pickling its result
    started = time.time()
    seed_models = [6, 7, 8]
    result = miRvestigator(motifs, utr_lines, maturefa_path,
                           seedModel=seed_models, minor=True, p5=True,
                           p3=True, wobble=False, wobbleCut=0.25,
                           use_multiprocessing=True,
                           baseDir=outdir)
    with open(pickle_file, 'wb') as handle:
        pickle.dump(result, handle)
    finished = time.time()

    elapsed_minutes = (finished - started) / 60.
    print("Completed miRvestigator in {:.2f} minutes".format(elapsed_minutes))



def firm_results(g2refseq_path, preddir, maturefa_path,
                 expdir, outdir, tmpdir, use_entrez):
    """Run the target prediction database step and write the combined report.

    Calls firm.run_target_prediction_dbs() with the mapping built from
    g2refseq_path, then firm.write_combined_report() with the file
    'miRNA/scores.csv' below outdir.

    Parameters:
    g2refseq_path -- passed to firm.make_refseq2entrez()
    preddir       -- passed as pred_db_dir
    maturefa_path -- passed to firm.write_combined_report()
    expdir, outdir, tmpdir, use_entrez -- passed to
                     firm.run_target_prediction_dbs()
    """
    scores_file = os.path.join(outdir, 'miRNA', 'scores.csv')
    refseq_to_entrez = firm.make_refseq2entrez(g2refseq_path)

    firm.run_target_prediction_dbs(
        refseq_to_entrez,
        exp_dir=expdir,
        outdir=outdir,
        tmpdir=tmpdir,
        pred_db_dir=preddir,
        use_entrez=use_entrez,
    )
    firm.write_combined_report(scores_file, maturefa_path, outdir=outdir)

DESCRIPTION = """firm - Run FIRM pipeline"""


def _bundled_path(relpath):
    """Locate a data file or directory shipped with the distribution.

    Returns the path pkg_resources reports for relpath inside the
    PACKAGE_NAME distribution. If that distribution is not found, relpath
    is returned unchanged, which makes it a path relative to the current
    working directory.
    """
    try:
        return resource_filename(Requirement.parse(PACKAGE_NAME), relpath)
    except DistributionNotFound:
        return relpath


if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description=DESCRIPTION)
    parser.add_argument('-ue', '--use_entrez', action='store_true',
                        help="input file uses entrez IDs instead of RefSeq")
    parser.add_argument('-t', '--tmpdir', default='tmp',
                        help="temporary directory")
    parser.add_argument('expdir', help='expression input directory')
    parser.add_argument('outdir', help='output directory')
    args = parser.parse_args()

    if not os.path.exists(args.expdir):
        raise Exception("Input directory '%s' does not exist" % args.expdir)

    # Create directories if needed. exist_ok avoids the check-then-create
    # race of a separate os.path.exists() test.
    os.makedirs(args.outdir, exist_ok=True)
    os.makedirs(args.tmpdir, exist_ok=True)

    # File paths that are contained in the distribution now, in the future can be
    # overridden by command line options
    p3utr_path = _bundled_path(USER_P3UTR_PATH)
    g2refseq_path = _bundled_path(USER_G2REFSEQ_PATH)
    freqfiles_path = _bundled_path(USER_FREQFILES_PATH)
    maturefa_path = _bundled_path(USER_MATUREFA_PATH)
    preddir = _bundled_path(USER_PREDDIR)

    seqs_path = os.path.join(args.outdir, 'seqs.txt')
    pssms_path = os.path.join(args.outdir, 'pssms.json')

    # Each stage is skipped when its output files already exist in outdir,
    # so an interrupted run can be resumed.
    if os.path.exists(pssms_path) and os.path.exists(seqs_path):
        print("findmotifs results found, continuing to mirvestigator")
    else:
        firm_findmotifs(p3utr_path, g2refseq_path, freqfiles_path,
                        args.expdir, args.outdir,
                        args.tmpdir, args.use_entrez)
        # firm_findmotifs() returns without writing anything when Weeder is
        # not available. Stop here instead of failing later with a
        # confusing "file not found" error in the mirvestigator stage.
        if not (os.path.exists(pssms_path) and os.path.exists(seqs_path)):
            raise SystemExit("findmotifs did not produce '%s' and '%s', stopping"
                             % (pssms_path, seqs_path))

    if not os.path.exists(os.path.join(args.outdir, 'm2m_standalone.pkl')):
        mirvestigator(args.outdir, args.outdir, maturefa_path)
    else:
        print('mirvestigator results found, continuing to target database matching')

    firm_results(g2refseq_path, preddir, maturefa_path,
                 args.expdir, args.outdir, args.tmpdir, args.use_entrez)
