#!/usr/bin/env python

import os
from dartqc.DartUtils import stamp, CommandLine, Installer


def main():
    """Entry point of the DartQC command line interface.

    Parses the command line via ``CommandLine`` and dispatches on the parsed
    arguments:

    * ``version`` flag set: print the version string and exit with status 0.
    * ``install``: run ``Installer`` and exit with status 0.
    * ``prepare``: run ``DartPreparator`` on the given file and sheet.
    * ``process``: read the call file, then pre-process it against the raw
      read counts (see ``_read_dart`` and ``_preprocess_dart``).
    * ``filter``: run the filter chain (``_filter_dart``), write an SNP
      summary and export the result as JSON and PLINK files.

    For every workflow other than version/install, the output directory
    ``args["out_path"]`` is created first if it does not exist yet.

    Raises:
        SystemExit: after printing the version or running the installer.
    """
    # sys.exit() instead of the builtin exit(): the latter is injected by the
    # site module for interactive sessions and is not guaranteed to exist
    # (e.g. when the interpreter is started with -S).
    import sys

    cmd_line = CommandLine()
    args = vars(cmd_line.args)

    if args["version"]:
        print("0.0.1")
        sys.exit(0)

    command = args["subparser"]

    if command == "install":
        Installer()
        sys.exit(0)

    # These modules are imported only after the install task has been handled.
    # NOTE(review): presumably they need dependencies that the installer sets
    # up - confirm before moving them to the top of the file.
    from dartqc.DartWriter import DartWriter
    from dartqc.DartPrepare import DartPreparator
    from dartqc.DartModules import SummaryModule

    os.makedirs(args["out_path"], exist_ok=True)

    # Workflows (the sub-commands are mutually exclusive)

    if command == "prepare":

        DartPreparator(file_path=args["file"], excel_sheet=args["sheet"], output_name=args["output_name"],
                       output_path=args["out_path"])

    elif command == "process":

        # Import the called reads from the standard file in basic mode, that is import a pre-formatted
        # data matrix with columns (C): CloneID, AlleleID, Sequence, Replication Average and Calls

        data, attributes = _read_dart(args)

        # The pre-processed result is written to JSON inside _preprocess_dart,
        # so the returned values are not needed here.
        _preprocess_dart(args, data, attributes)

    elif command == "filter":

        data, attributes = _filter_dart(args)

        stamp("Initialising Summary Module...")
        # Summary module for writing a summary of SNP parameters:
        sm = SummaryModule(data=data, attributes=attributes)

        sm.write_snp_summary(summary_parameters=["maf", "call_rate", "rep_average", "hwe"])

        stamp("Initialising Writing Module...")
        dart_writer = DartWriter(data, attributes)

        filtered_name = args["project"] + "_filtered"
        dart_writer.write_json(filtered_name)
        dart_writer.write_plink(filtered_name, remove_space=True)


def _filter_dart(args):
    """Load the data set and run it through the complete filter chain.

    If ``args["processed_path"]`` is set, data and attributes are loaded from
    the JSON files ``<project>_data.json`` and ``<project>_attr.json`` in that
    directory; otherwise they are read from the call file via ``_read_dart``.

    The filters are then applied in this fixed order: ``_filter_mind``,
    ``_filter_monomorphic``, ``_filter_snps``, ``_filter_redundancy``.

    :param args: dictionary of parsed command line arguments.
    :return: tuple ``(data, attributes)`` after all filters were applied.
    """
    from dartqc.DartReader import DartReader

    processed_path = args["processed_path"]

    if processed_path is None:
        # No pre-processed JSON given, start from the call file:
        data, attributes = _read_dart(args)
    else:
        # Reading the data from JSON after Preprocessing:
        data_file = os.path.join(processed_path, args["project"] + "_data.json")
        attr_file = os.path.join(processed_path, args["project"] + "_attr.json")

        stamp("Reading data from pre-processed JSON at path", args["processed_path"])
        stamp("Data file:", data_file)
        stamp("Attribute file:", attr_file)

        data, attributes = DartReader().read_json(data_file=data_file, attribute_file=attr_file)

    # Insert summary module here for before filtering snapshot of data (including parameters)
    # otherwise find summary of modules in attributes (removed, retained)

    # Every stage takes and returns (data, attributes); the order is significant.
    filter_chain = (_filter_mind, _filter_monomorphic, _filter_snps, _filter_redundancy)

    for apply_filter in filter_chain:
        data, attributes = apply_filter(args, data, attributes)

    return data, attributes


def _filter_monomorphic(args, data, attributes):

    from dartqc.DartModules import PopulationModule

    if args["pop_file"] is None:
        stamp("Cannot load Population Module, no population file specified.")
        return data, attributes

    if args["mono"] is not None:
        pm = PopulationModule(data=data, attributes=attributes)
        data, attributes = pm.get_data(args["mono"], comparison=args["mono_comp"])

    return data, attributes


def _filter_mind(args, data, attributes):
    """Run the Sample Module MIND filter if a MIND value was supplied.

    The filter is skipped, and the input returned unchanged, when the first
    entry of ``args["mind"]`` is ``None``.

    :param args: dictionary of parsed command line arguments; reads ``mind``.
    :param data: data object as produced by the reader.
    :param attributes: attributes belonging to ``data``.
    :return: tuple ``(data, attributes)``.
    """
    from dartqc.DartModules import SampleModule

    mind_values = args["mind"]

    if mind_values[0] is None:
        return data, attributes

    sample_module = SampleModule(data, attributes)

    # Not recalculating as MIND before SNP Module
    data, attributes = sample_module.filter_data(mind=mind_values, recalculate=False)

    return data, attributes


def _filter_redundancy(args, data, attributes):
    """Remove redundant SNPs with the Redundancy Module.

    Depending on the command line flags, SNPs with duplicate clone IDs
    (``args["remove_duplicates"]``) and SNPs whose clone sequences cluster at
    the nucleotide identity ``args["identity"]`` (``args["remove_clusters"]``)
    are indexed, and the reduced data set is exported from the module.

    :param args: dictionary of parsed command line arguments; reads the keys
        ``remove_duplicates``, ``remove_clusters`` and ``identity``.
    :param data: data object as produced by the previous filters.
    :param attributes: attributes belonging to ``data``.
    :return: tuple ``(data, attributes)`` as exported by the module.
    :raises SystemExit: with status 1 if no data is left after filtering.
    """
    # sys.exit() instead of the builtin exit(): the latter is injected by the
    # site module for interactive sessions and is not guaranteed to exist.
    import sys

    from dartqc.DartModules import RedundancyModule

    stamp("Initialising Redundancy Module...")

    rm = RedundancyModule(data=data, attributes=attributes, tmp_remove=True)

    # The same parameters are passed as selectors to both removal steps.
    selectors = ("maf", "call_rate", "rep_average")

    # Indexing duplicate and identity clusters:
    if args["remove_duplicates"]:
        stamp("Removing SNPs with duplicate clone IDs...")

        rm.remove_duplicates(selector_list=selectors)

    if args["remove_clusters"]:
        stamp("Removing SNPs in which clone sequences cluster at nucleotide identity of",
              str(args["identity"] * 100) + "%")

        # The clustering executable is referenced by name only ("cd-hit-est").
        rm.remove_clusters(selector_list=selectors, cdhit_path="cd-hit-est",
                           identity=args["identity"])

    # Export data with duplicates and clustered SNPs removed:
    data, attributes = rm.get_data(duplicates=args["remove_duplicates"], clusters=args["remove_clusters"])

    # len() rather than truthiness: the concrete type of data is defined by
    # the module and may not support a boolean test.
    if len(data) == 0:
        stamp("All data was filtered, cannot write data to file.")
        sys.exit(1)

    return data, attributes


def _filter_snps(args, data, attributes):
    """Filter SNPs by MAF, call rate, replication average and HWE p-value.

    Only the first value of each of ``args["maf"]``, ``args["call_rate"]``,
    ``args["rep"]`` and ``args["hwe"]`` is deployed as a threshold (single
    filter use only at the moment). If all four are ``None`` the data is
    passed through ``SNPModule.get_data`` without any threshold.

    :param args: dictionary of parsed command line arguments; reads the keys
        ``maf``, ``call_rate``, ``rep`` and ``hwe``.
    :param data: data object as produced by the previous filters.
    :param attributes: attributes belonging to ``data``.
    :return: tuple ``(data, attributes)`` as exported by the SNP Module.
    :raises SystemExit: with status 1 if filters were applied and no data is
        left afterwards.
    """
    # sys.exit() instead of the builtin exit(): the latter is injected by the
    # site module for interactive sessions and is not guaranteed to exist.
    import sys

    from dartqc.DartModules import SNPModule

    stamp("Initialising SNP Module...")

    maf = args["maf"][0]
    call_rate = args["call_rate"][0]
    rep = args["rep"][0]
    hwe = args["hwe"][0]

    if all(value is None for value in (maf, call_rate, rep, hwe)):
        stamp("No filters specified for SNPs.")

        snp_module = SNPModule(data=data, attributes=attributes)
        data, attributes = snp_module.get_data(threshold=None, multiple=None)

        return data, attributes

    stamp("MAF <=", maf)
    stamp("Call Rate <=", call_rate)
    stamp("Replication Average <=", rep)
    stamp("Hardy-Weinberg p-value <=", hwe)

    # Initializing the SNP Module, which handles filtering:
    snp_module = SNPModule(data=data, attributes=attributes)

    # Indexing all filter values defined above
    # True/False for retaining the SNP across all filter values and SNPs

    stamp("Indexing filters...")

    # (command line key, module parameter); the complete value list of each
    # argument is indexed, always with the "<=" comparison.
    indexed_filters = (
        ("maf", "maf"),                # minor allele frequency
        ("hwe", "hwe"),                # p-value
        ("rep", "rep_average"),        # replication average by DArT
        ("call_rate", "call_rate"),    # call rate of SNP
    )

    for arg_key, parameter in indexed_filters:
        snp_module.filter_data(args[arg_key], parameter=parameter, comparison="<=")

    # Deploying filter across pre-indexed values (values given here must be present in lists above)
    # Filtered data is exported from module for further use.
    # Single filter use only at the moment:

    data, attributes = snp_module.get_data(multiple=[("maf", maf),
                                                     ("call_rate", call_rate),
                                                     ("rep_average", rep),
                                                     ("hwe", hwe)])

    # len() rather than truthiness: the concrete type of data is defined by
    # the module and may not support a boolean test.
    if len(data) == 0:
        stamp("All data was filtered, cannot process data.")
        sys.exit(1)

    return data, attributes


def _read_dart(args):
    """Read the call file (and optional population file) with ``DartReader``.

    The reader is configured with project name, output path and call scheme.
    For the ``filter`` sub-command the clone splitting options taken from
    ``args["split_clones"]`` are passed on as well.

    :param args: dictionary of parsed command line arguments; reads the keys
        ``project``, ``out_path``, ``call_file``, ``call_scheme``,
        ``subparser``, ``pop_file`` and, for ``filter``, ``split_clones``.
    :return: tuple ``(data, attributes)`` as returned by the reader.
    """
    from dartqc.DartReader import DartReader

    stamp("Reading data from call file...")
    stamp("Project", args["project"])
    stamp("Output to:", args["out_path"])
    stamp("Call file:", args["call_file"])
    stamp("Scheme file:", os.path.basename(args["call_scheme"]))

    reader = DartReader()

    # This is weird, check if it makes sense when file input of both files in process task:

    # Options that are only passed on for the filter task:
    clone_options = {}

    if args["subparser"] == "filter":
        clone_split = args["split_clones"]
        # split_clone is the plain on/off switch derived from the raw value.
        clone_options = {"split_clone": bool(clone_split), "clone_split": clone_split}

    reader.set_options(project=args["project"], out_path=args["out_path"], scheme=args["call_scheme"],
                       **clone_options)

    reader.read_double_row(file=args["call_file"], basic=True)

    population_file = args["pop_file"]

    if population_file is not None:
        stamp("Reading population file at:", population_file)
        reader.read_pops(population_file, sep=",")

    data, attributes = reader.get_data()

    stamp("SNPs before QC:", attributes["snps"])
    stamp("Number of samples before QC:", len(attributes["sample_names"]))

    return data, attributes


def _preprocess_dart(args, data, attributes):
    """Pre-process the call data with the raw read counts and store it as JSON.

    The call data is handed to ``Preprocessor``, the raw read count file is
    read with the raw scheme, the read count filter is applied with
    ``args["raw_read_threshold"]`` and the result is written to JSON under
    the project name.

    :param args: dictionary of parsed command line arguments; reads the keys
        ``project``, ``out_path``, ``call_file``, ``raw_file``,
        ``raw_scheme``, ``call_scheme`` and ``raw_read_threshold``.
    :param data: call data as returned by ``_read_dart``.
    :param attributes: attributes belonging to ``data``.
    :return: tuple ``(data, attributes)`` as exported by the preprocessor.
    """
    from dartqc.DartProcessor import Preprocessor
    from dartqc.DartWriter import DartWriter

    stamp("Pre-processing with raw read data...")
    stamp("Setting calls to missing, where raw read sum is <=", args["raw_read_threshold"])
    stamp("Project", args["project"])
    stamp("Output to:", args["out_path"])
    stamp("Call file:", args["call_file"])
    stamp("Raw file", args["raw_file"])
    stamp("Scheme file (raw):", os.path.basename(args["raw_scheme"]))
    stamp("Scheme file (call):", os.path.basename(args["call_scheme"]))

    # The preprocessor starts out from the called reads ...
    preprocessor = Preprocessor(call_data=data, call_attributes=attributes)

    # ... is configured for the raw read count file ...
    preprocessor.set_options(project=args["project"], scheme=args["raw_scheme"])

    # ... and then loads the raw read counts:
    preprocessor.read_count_data(args["raw_file"])

    # Calls are set to missing based on the sum of minor and major read counts for the SNP.
    # NOTE(review): the log message above says "<=" - confirm whether the comparison
    # against the threshold is inclusive.
    # COULD USE ADDITIONAL FILTERS

    preprocessor.filter_read_counts(threshold=args["raw_read_threshold"])

    # Hand the result back for the filtering modules...
    processed_data, processed_attributes = preprocessor.get_data()

    # ... and keep a JSON copy, as pre-processing can take a while.
    DartWriter(processed_data, processed_attributes).write_json(args["project"])

    return processed_data, processed_attributes

# Run the command line interface only when executed as a script, so that
# importing this module does not parse arguments or start a workflow.
if __name__ == "__main__":
    main()

