#!/usr/bin/env python
import argparse
import hivmmer
import os
from multiprocessing import cpu_count, Pool
from subprocess import run

if __name__ == "__main__":
    # Command-line driver for the hivmmer pipeline: merge paired-end reads
    # with PEAR, filter/deduplicate/translate them, align against each gene's
    # pHMM with hmmsearch, tabulate codons, and produce the consensus, AA
    # table, DRM calls and PDF report inside the output directory.

    ### ARGUMENTS ###

    parser = argparse.ArgumentParser(description=hivmmer.__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("FASTQ1",
                        help="FASTQ file with forward Illumina reads")
    parser.add_argument("FASTQ2",
                        help="FASTQ file with reverse Illumina reads")
    parser.add_argument("-o", "--outdir",
                        default=os.getcwd(),
                        help="output directory [current directory]")
    parser.add_argument("-t", "--threads",
                        default=1,
                        metavar="N",
                        type=int,
                        help="number of threads [1]")
    parser.add_argument("-l", "--min-length",
                        default=75,
                        metavar="L",
                        type=int,
                        help="minimum read length to retain [75]")
    parser.add_argument("-q", "--min-quality",
                        default=25,
                        metavar="Q",
                        type=int,
                        help="minimum mean quality score to retain [25]")
    parser.add_argument("-v", "--version",
                        action="version",
                        version="hivmmer {}".format(hivmmer.__version__))
    args = parser.parse_args()

    # Resolve the inputs to absolute paths up front: the working directory is
    # switched to the output directory below, which would otherwise break
    # relative input paths.
    fastq1 = os.path.abspath(args.FASTQ1)
    fastq2 = os.path.abspath(args.FASTQ2)

    # Oversubscribing the CPUs is only warned about, not rejected; the
    # requested thread count is still honoured by every later step.
    if args.threads > cpu_count():
        print("WARNING: --threads {} is larger than cpu count {}".format(args.threads, cpu_count()))

    ### PIPELINE ###

    print("Creating output directories in:")
    print(os.path.abspath(args.outdir))
    os.makedirs(args.outdir, exist_ok=True)
    # Every relative path from here on is relative to the output directory.
    os.chdir(args.outdir)
    for subdir in ("logs", "sequences", "references", "alignments", "report"):
        os.makedirs(subdir, exist_ok=True)

    print("Running PEAR on FASTQ inputs:")
    print(fastq1)
    print(fastq2)
    with open("logs/pear.log", "w") as log:
        # -f takes the forward reads and -r the reverse reads; both stdout
        # and stderr of PEAR are captured in the log file.
        status = run(["pear", "-y", "1G",
                      "-f", fastq1, "-r", fastq2,
                      "-o", "sequences/pear", "-k",
                      "-j", str(args.threads)],
                     stdout=log,
                     stderr=log).returncode
    # Explicit check rather than `assert`, which is removed under `python -O`.
    if status != 0:
        raise SystemExit("ERROR: PEAR exited with status {} - check pear.log".format(status))

    print("Filtering and deduplicating PEAR sequences")
    # `keep` is filled in place by mean_filter across all three PEAR outputs
    # and then written out once as FASTA.
    keep = {}
    for pearfile in ("assembled", "unassembled.forward", "unassembled.reverse"):
        hivmmer.filter.mean_filter("sequences/pear.{}.fastq".format(pearfile), args.min_length, args.min_quality, keep)
    with open("sequences/deduplicated.fa", "w") as f:
        hivmmer.filter.tofasta(keep, f)

    print("Translating deduplicated sequences to amino acid sequences")
    with open("sequences/translated.pfa", "w") as f, open("logs/translate.log", "w") as log:
        hivmmer.translate("sequences/deduplicated.fa", f, log)

    print("Copying pHMM references")
    hivmmer.copy_hmms("references")

    for gene in hivmmer.genes:
        print("Aligning {} with hmmsearch".format(gene))
        with open("logs/hmmsearch.{}.log".format(gene), "w") as log:
            status = run(["hmmsearch", "--max",
                          "--cpu", str(args.threads),
                          "-o", "alignments/{}.txt".format(gene),
                          "references/{}.hmm".format(gene),
                          "sequences/translated.pfa"],
                         stdout=log,
                         stderr=log).returncode
        if status != 0:
            raise SystemExit("ERROR: hmmsearch exited with status {} - check hmmsearch.{}.log".format(status, gene))

    print("Extracting codons from hmmsearch alignments")
    if args.threads > 1:
        # One task per gene. Absolute paths keep the tasks independent of the
        # working directory of the worker processes.
        alignments = [(os.path.abspath("sequences/deduplicated.fa"),
                       os.path.abspath("alignments/{}.txt".format(gene)),
                       gene)
                      for gene in hivmmer.genes]
        # The pool exists only for this step, so it is created whenever more
        # than one thread was requested (including the oversubscribed case)
        # and its workers are always cleaned up, even if a task raises.
        with Pool(processes=args.threads) as pool:
            codons = pool.starmap(hivmmer.codons, alignments)
    else:
        codons = [hivmmer.codons("sequences/deduplicated.fa",
                                 "alignments/{}.txt".format(gene),
                                 gene)
                  for gene in hivmmer.genes]

    print("Merging codons across genes")
    with open("codons.tsv", "w") as f:
        print("hxb2", "codon", "count", sep="\t", file=f)
        # Each element of `codons` is the sequence of output lines for one gene.
        for lines in codons:
            print("\n".join(lines), file=f)

    print("Generating consensus sequences")
    hivmmer.consensus("codons.tsv", "consensus.fa")

    print("Generating AA table")
    hivmmer.aa_table("codons.tsv", "aa.xlsx")

    print("Identifying DRMs in AA table")
    hivmmer.drms("aa.xlsx", "drms.csv")

    print("Plotting WGS coverage")
    hivmmer.report.plot_coverage("codons.tsv", "report/coverage.pdf")

    print("Plotting PRRT coverage")
    hivmmer.report.plot_coverage_prrt("aa.xlsx", "report/coverage-prrt.pdf")

    print("Plotting DRMs/SDRMs")
    drms = hivmmer.report.plot_drms("aa.xlsx", "drms.csv", "Stanford", "report/drms.pdf")
    drmi = hivmmer.report.plot_drms("aa.xlsx", "drms.csv", "IAS", "report/drmi.pdf")
    sdrm = hivmmer.report.plot_drms("aa.xlsx", "drms.csv", "SDRM", "report/sdrm.pdf")

    print("Compiling PDF report")
    # NOTE(review): the plot file names are given without the "report/"
    # prefix here; presumably they are resolved relative to the "report"
    # directory argument - confirm against hivmmer.report.compile.
    hivmmer.report.compile([fastq1, fastq2],
                           "coverage.pdf", "coverage-prrt.pdf",
                           "drms.pdf", "drmi.pdf", "sdrm.pdf",
                           drms, drmi, sdrm,
                           "report", "report.pdf")

    print("Finished.")

# vim: expandtab sw=4 ts=4
