#!/usr/bin/env python

import logging
import pathlib
import configparser
import shutil
import os
import argparse
from Bio import SeqIO
import pandas as pd

import modules.genome_simulator as gsim
import modules.genome_quality_filter as gqf
import modules.msa as msa_filter
import modules.probe_generation as probeg
import modules.primer_generation as primerg
import modules.validator as validator
import modules.sensitivity_check as psc
import modules.visualizer as viz

import warnings
from Bio import BiopythonWarning
# Ignore every warning whose category is BiopythonWarning.
warnings.simplefilter('ignore', BiopythonWarning)

# Log level applied to both the logger and its stream handler in main().
VERBOSITY = logging.INFO
# Absolute path of the directory containing this script; parse_args() uses it
# to build the default config and negative-control file paths.
SCRIPT_DIR = pathlib.Path(__file__).parent.absolute()

# TODO use biopython?
def combine_fasta_files(files: list, output_name: str):
    """Concatenate several FASTA files into a single file.

    Each input is copied verbatim, in the order given. A newline is appended
    after any non-empty input that does not already end with one, so the next
    file's first line always starts on its own line. Empty inputs contribute
    nothing (previously an empty first file raised ``IndexError``).

    Args:
        files: Paths of the files to concatenate, in output order.
        output_name: Path of the combined file; overwritten if it exists.
    """
    chunks = []
    for path in files:
        # The context manager closes each input handle as soon as it is read.
        with open(path) as handle:
            content = handle.read()
        if not content:
            continue
        chunks.append(content)
        if not content.endswith('\n'):
            chunks.append('\n')

    with open(output_name, 'w') as output_file:
        output_file.write("".join(chunks))

def combine_pdcsv_files(files: list, output_name: str):
    """Stack several CSV files row-wise and write them out as one CSV.

    Every path in ``files`` is loaded with ``pandas.read_csv``; the resulting
    frames are concatenated in the given order and saved to ``output_name``
    without the index column.

    Args:
        files: Paths of the CSV files to combine.
        output_name: Destination path of the combined CSV.
    """
    frames = []
    for csv_path in files:
        frames.append(pd.read_csv(csv_path))
    merged = pd.concat(frames)
    merged.to_csv(output_name, index=False)

def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, so existing ``parse_args()`` calls behave as
            before.

    Returns:
        A ``(parser, args)`` tuple: the ``ArgumentParser`` itself (callers use
        it to exit on invalid argument combinations) and the parsed
        ``Namespace``.
    """
    parser = argparse.ArgumentParser(description="Probe finder")
    parser.add_argument(
        "reference",
        type=str,
        help="Reference genome in fasta format")
    parser.add_argument(
        "genomes",
        nargs="+",
        help="List of all input genomes")
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default="output",
        help="Prefix of output directory.")
    parser.add_argument(
        "--id",
        type=str,
        default="results",
        help="ID of run. Results will be stored at output/[ID]")
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default=os.path.join(SCRIPT_DIR, "data", "default_config.ini"),
        help="Path to configuration file")
    parser.add_argument(
        "-t",
        "--threads",
        type=int,
        default=3,
        help="Default number of threads to use")
    parser.add_argument(
        "--probes-only",
        action='store_true',
        help="Do not generate primers")

    # Simulator args
    simulator_filter_parser = parser.add_argument_group(description="Simulator/Filter arguments")
    simulator_filter_parser.add_argument(
        "--csv",
        type=str,
        help="Metadata in csv format")
    simulator_filter_parser.add_argument(
        "-r",
        "--reads",
        type=str,
        help="Directory containing read datasets")
    simulator_filter_parser.add_argument(
        "-s",
        "--simulate",
        action="store_true",
        help="Simulate genomes from read data and use in pipeline")
    simulator_filter_parser.add_argument(
        "--simulate-use-existing",
        type=str,
        help="Use simulated genomes from this directory rather than simulate new ones.")
    simulator_filter_parser.add_argument(
        "--coverage-threshold",
        type=float,
        default=0.5,
        help="input coverage threshold, mappings above the threshold would be \
            taken into consideration for generating simulated genome")
    simulator_filter_parser.add_argument(
        "--min-base-quality",
        type=int,
        default=0,
        help="input min base quality, bases with quality above this threshold \
        would be takin into consideration during mapping quality check")
    simulator_filter_parser.add_argument(
        "--min-af",
        type=float,
        default=0.01,
        help="Minimum AF threshold")
    simulator_filter_parser.add_argument(
        "--contiguousn",
        type=int,
        default=20,
        help="contiguous N threshold, sequences with number of \
            contiguous N above this threshold are filtered out")
    simulator_filter_parser.add_argument(
        "--totalnthreshold",
        type=int,
        default=30,
        help="Total N threshold, sequences with number of ambiguous base over \
            this threshold are filtered out")
    simulator_filter_parser.add_argument(
        "--lenthreshold",
        type=int,
        default=145,
        help="Genome length threshold, sequences with length not in range \
            of plus/minus threshold bp compare to reference are filtered out")
    simulator_filter_parser.add_argument(
        "--endbases",
        type=int,
        default=200,
        help="Ignore bases at both end while filtering genomes")
    simulator_filter_parser.add_argument(
        "--keep-insertion",
        action='store_true',
        help="While filtering, keep genomes with insertions")
    simulator_filter_parser.add_argument(
        "--keep-ambiguous",
        action='store_true',
        help="While filtering, keep genomes with ambiguous base not on both ends")

    # Probe/Primer generation args
    # The four length options default to None so that main() can tell "not
    # given on the command line" apart from an explicit value and fall back
    # to the configuration file.
    generator_parser = parser.add_argument_group(description="Probe/Primer generation arguments")
    generator_parser.add_argument(
        "--negative-control",
        type=str,
        default=os.path.join(SCRIPT_DIR, "data", "negative_control.fasta"),
        help="Path to negative control fasta file")
    generator_parser.add_argument(
        "--kmer-min",
        type=int,
        default=None,
        help="Minimum length of desired k-mer probes")
    generator_parser.add_argument(
        "--kmer-max",
        type=int,
        default=None,
        help="Maximum length of desired k-mer probes")
    generator_parser.add_argument(
        "--primer-min",
        type=int,
        default=None,
        help="Minimum primer length")
    generator_parser.add_argument(
        "--primer-max",
        type=int,
        default=None,
        help="Maximum primer length")

    # Validation args
    validator_parser = parser.add_argument_group(description="Validator/Sensitivity arguments")
    validator_parser.add_argument(
        "--run-all",
        action='store_true',
        help="Run sensitivity check on all oligos even if they fail validation")
    validator_parser.add_argument(
        "--blastdb",
        type=str,
        help="Directory of blast database, otherwise will use blast remote, \
            which have limitation of input sequence number")
    validator_parser.add_argument(
        "--email",
        type=str,
        help="Email to use for Entrez request")
    validator_parser.add_argument(
        "--api",
        type=str,
        help="API key to use for Entrez request")
    validator_parser.add_argument(
        "--cross-check",
        action="store_true",
        help="Check for dimers across separate probes/primers")

    # Flags for skipping pipeline stages
    debug_parser = parser.add_argument_group(description="Flags for skipping modules")
    debug_parser.add_argument(
        "--sw-generation",
        action='store_true',
        help="Skip quality filter and MSA. Useful for rerunning w/ different config for primer/probes")

    # The parser is returned alongside the namespace so callers can use
    # parser.exit() for validation failures.
    return parser, parser.parse_args(argv)

def validate(args, parser, logger):
    """Check that at least one of email, API key or BLAST database is set.

    When ``args.email``, ``args.api`` and ``args.blastdb`` are all falsy, an
    error is logged and ``parser.exit(2)`` is called.

    Args:
        args: Parsed arguments exposing ``email``, ``api`` and ``blastdb``.
        parser: Object whose ``exit`` method is called on failure.
        logger: Logger used to report the problem.
    """
    has_any = args.email or args.api or args.blastdb
    if has_any:
        return
    logger.error("Need email, api key, or blastdb for validation step")
    parser.exit(2)

def main():
    """Run the oligo design pipeline from the command line.

    Stages, in order: optional genome simulation, genome quality filter,
    MSA, probe generation, probe validation and probe sensitivity check;
    then, unless ``--probes-only`` is given, primer generation, validation
    and sensitivity check; finally the per-stage outputs are merged into the
    ``final_report`` directory and passed to the visualizer.
    ``--sw-generation`` skips simulation, quality filter and MSA.

    Returns:
        ``None`` when every stage completes. If the MSA step or a sensitivity
        check reports a non-zero status, that status is returned and the
        remaining stages are skipped.
    """
    parser, args = parse_args()
    logger = logging.getLogger("Olivar")
    logger.setLevel(VERBOSITY)

    ch = logging.StreamHandler()
    ch.setLevel(VERBOSITY)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    # Stops the run early when no email, API key or BLAST database was given.
    validate(args, parser, logger)

    # General inputs
    reference_fasta = args.reference
    reference_obj = SeqIO.read(reference_fasta, "fasta")
    genome_list = args.genomes
    input_metadata = args.csv
    output_dir = args.output
    output_w_id = os.path.join(output_dir, args.id)
    read_dir = args.reads
    config = configparser.ConfigParser()
    # Keep option names case-sensitive (ConfigParser lower-cases them by default).
    config.optionxform = str
    config.read(args.config)
    num_threads = args.threads
    probes_only = args.probes_only

    # Simulator / quality-filter settings
    coverage_threshold = args.coverage_threshold
    min_base_quality = args.min_base_quality
    min_af_threshold = args.min_af
    contiguous_n_threshold = args.contiguousn
    total_n_threshold = args.totalnthreshold
    length_threshold = args.lenthreshold
    endbases = args.endbases
    # An existing directory of simulated genomes takes precedence over
    # running a new simulation.
    if args.simulate_use_existing:
        simulated_genome_dir = args.simulate_use_existing
    elif args.simulate:
        simulated_genome_dir = os.path.join(output_dir, "simulate", "simulated_genomes")
    else:
        simulated_genome_dir = None

    # Probe/primer lengths come from the config file unless overridden on the
    # command line. The truthiness test means an explicit 0 is treated the
    # same as "not given".
    kmer_length_min = config.getint("Probe", "Length_min")
    kmer_length_max = config.getint("Probe", "Length_max")
    primer_length_min = config.getint("Primer", "Length_min")
    primer_length_max = config.getint("Primer", "Length_max")
    negative_control = args.negative_control
    if args.kmer_min:
        kmer_length_min = args.kmer_min
    if args.kmer_max:
        kmer_length_max = args.kmer_max
    if args.primer_min:
        primer_length_min = args.primer_min
    if args.primer_max:
        primer_length_max = args.primer_max

    # Validator / sensitivity settings
    run_all = args.run_all
    blastdb_path = args.blastdb
    email = args.email
    api_key = args.api
    cross_check = args.cross_check

    # Create the run directory and one sub-directory per stage. makedirs with
    # exist_ok=True is idempotent and avoids the check-then-create race of
    # os.path.exists() followed by os.mkdir().
    quality_filter_output_dir = os.path.join(output_w_id, "quality_filter")
    probe_generation_output_dir = os.path.join(output_w_id, "probe_generation")
    primer_generation_output_dir = os.path.join(output_w_id, "primer_generation")
    final_oligos_unfiltered_path = os.path.join(output_w_id, "final_report")
    for stage_dir in (
            output_w_id,
            quality_filter_output_dir,
            probe_generation_output_dir,
            primer_generation_output_dir,
            final_oligos_unfiltered_path):
        os.makedirs(stage_dir, exist_ok=True)

    # Simulate genomes
    if args.simulate and not args.simulate_use_existing and not args.sw_generation:
        simulation_output_dir = os.path.join(output_dir, "simulate")
        os.makedirs(simulation_output_dir, exist_ok=True)

        # Run simulation
        gsim.simulate(
            input_metadata,
            reference_fasta,
            read_dir,
            simulation_output_dir,
            coverage_threshold,
            min_base_quality,
            min_af_threshold,
            num_threads)

    # Filter genomes (quality)
    if not args.sw_generation:
        logger.info("Running quality filter...")
        gqf.filter_genome(
            genome_list,
            reference_obj,
            quality_filter_output_dir,
            simulated_genome_dir,
            contiguous_n_threshold,
            total_n_threshold,
            length_threshold,
            endbases,
            num_threads)

    # The MSA input file name depends on whether simulated genomes were used.
    if args.simulate or args.simulate_use_existing:
        filtered_genomes_fasta = os.path.join(
            quality_filter_output_dir,
            "filtered_genomes_wsim.fasta")
    else:
        filtered_genomes_fasta = os.path.join(
            quality_filter_output_dir,
            "filtered_genomes.fasta")

    # Filter genomes (align)
    if not args.sw_generation:
        logger.info("Running msa...")
        msa_ret = msa_filter.run(
            filtered_genomes_fasta,
            quality_filter_output_dir,
            num_threads)
        if msa_ret != 0:
            logger.critical("MAFFT alignment failed with exit code {}".format(msa_ret))
            # Return the failing status, as the sensitivity-check failure
            # paths below do, instead of a bare return (None).
            return msa_ret
    # From here on the pipeline always reads these two files from the
    # quality-filter directory, whether or not simulated genomes were used.
    filtered_genomes_fasta = os.path.join(quality_filter_output_dir, "filtered_genomes.fasta")
    filtered_msa_fasta = os.path.join(quality_filter_output_dir, "filtered_msa.fasta")

    # Generate probes
    logger.info("Generating probes...")
    probeg.run(
        os.path.join(quality_filter_output_dir, "filtered_genomes"),
        negative_control,
        probe_generation_output_dir,
        kmer_length_min,
        kmer_length_max,
        num_threads)
    probes_fasta = os.path.join(
        probe_generation_output_dir,
        "probes",
        "final_kmers.fasta")

    # Validate probes
    logger.info("Validating probes...")
    validator.run(probes_fasta, probe_generation_output_dir, config, cross_check, num_threads)
    probes_validator = os.path.join(probe_generation_output_dir, "validator-report.csv")

    # Sensitivity for probes
    logger.info("Running sensitivity check on probes...")
    sens_ret = psc.run(
        probes_validator,
        probe_generation_output_dir,
        filtered_genomes_fasta,
        blastdb_path,
        num_threads,
        is_validated=True,
        config=config,
        run_all=run_all,
        email=email,
        api_key=api_key)
    if sens_ret != 0:
        logger.critical("Sensitivity check failed with return value {}".format(sens_ret))
        return sens_ret
    probes_sensitivity = os.path.join(probe_generation_output_dir, "sensitivity-report.csv")
    # Downstream steps use the probes from the sensitivity-passed file.
    probes_fasta = os.path.join(probe_generation_output_dir, "sensitivity-passed.fasta")

    if not probes_only:
        logger.info("Generating primers...")
        # Generate primers
        primerg.generate_primers(
            probes_fasta,
            primer_generation_output_dir,
            filtered_msa_fasta,
            primer_length_min,
            primer_length_max,
            config,
            num_threads)
        primers_fasta = os.path.join(primer_generation_output_dir, "primers.fasta")

        # Validate primers
        validator.run(primers_fasta, primer_generation_output_dir, config, cross_check, num_threads)
        primers_validator = os.path.join(primer_generation_output_dir, "validator-report.csv")

        # Sensitivity for primers
        logger.info("Running sensitivity check on primers...")
        sens_ret = psc.run(
            primers_validator,
            primer_generation_output_dir,
            filtered_genomes_fasta,
            blastdb_path,
            num_threads,
            is_validated=True,
            config=config,
            run_all=run_all,
            email=email,
            api_key=api_key)
        if sens_ret != 0:
            logger.critical("Sensitivity check failed with return value {}".format(sens_ret))
            return sens_ret
        primers_sensitivity = os.path.join(primer_generation_output_dir, "sensitivity-report.csv")
        primers_fasta = os.path.join(primer_generation_output_dir, "sensitivity-passed.fasta")

    # Visualize
    logger.info("Finalizing output...")
    oligos_fasta = os.path.join(final_oligos_unfiltered_path, "oligos-passed.fasta")
    oligos_validator = os.path.join(final_oligos_unfiltered_path, "oligos-report.csv")
    oligos_sensitivity = os.path.join(final_oligos_unfiltered_path, "oligos-sensitivity.csv")
    if probes_only:
        # Only probe outputs exist, so they are copied as the final report.
        shutil.copy(probes_fasta, oligos_fasta)
        shutil.copy(probes_sensitivity, oligos_sensitivity)
        shutil.copy(probes_validator, oligos_validator)
    else:
        # Merge primer and probe outputs (primers first) into the final report.
        combine_fasta_files([primers_fasta, probes_fasta], oligos_fasta)
        combine_pdcsv_files([primers_sensitivity, probes_sensitivity], oligos_sensitivity)
        combine_pdcsv_files([primers_validator, probes_validator], oligos_validator)
    # NOTE(review): "data" is a relative path, unlike the SCRIPT_DIR-based
    # defaults in parse_args(); confirm how viz.visualize resolves it when the
    # script is launched from another working directory.
    template_dir = "data"
    template_file = "visualization-template.html"
    viz.visualize(oligos_validator, oligos_sensitivity, final_oligos_unfiltered_path, template_dir, template_file)


if __name__ == "__main__":
    # Use main()'s return value as the process exit status: None exits with
    # 0, while a status returned by a failed stage becomes the exit code
    # instead of being silently discarded.
    raise SystemExit(main())
