#!/usr/bin/env python

"""
hdose
Harmonize the dosages between exposure and outcome alleles for Mendelian Randomization.
(Visit github.com/endrebak/harmonize-dosage for examples and help.)
Usage:
    hdose --exposure=EXP --outcome=OUT --dosage=DOS --output-dir=DIR
          [--tolerance=TOL] [--outcome-columns=OC] [--exposure-columns=EC]
    hdose --help
Arguments:
    -e EXP --exposure EXP     Exposure file
    -o OUT --outcome OUT      Outcome file
    -d DOS --dosage DOS       Dosage file
    -t TOL --tolerance TOL    Tolerance for considering AFs similar.  [default: 0.08]
    -D DIR --output-dir DIR   Dir to put output files in.
    -E EC --exposure-columns  Comma-separated names of the columns containing the
                              reference allele, alternate allele, rsid and
                              allele frequency in the exposure file.
                              [default: REF,ALT,rs,AF]
    -O OC --outcome-columns   Comma-separated names of the columns containing the
                              reference allele, alternate allele, rsid and
                              allele frequency in the outcome file.
                              [default: A1,A2,rsid,FreqA1]
Options:
    -h --help                 show this help message
"""

from subprocess import call

import pandas as pd

from docopt import docopt

from harmonize.create_exposure_outcome_table import create_exposure_outcome_table
from harmonize.harmonize import assign_table, sanity_check_table
from harmonize.flip_swap import flip_swap_remove, flip_swap_remove_exposure_alleles, swap_dosages
from harmonize.vcf_to_dosage import vcf_to_dosage


def main():
    """Harmonize outcome and exposure alleles.

    Parses the command line described in the module docstring, reads the
    exposure, outcome and dosage inputs, classifies the SNPs, and writes four
    space-separated files into the output directory:
    ``harmonized_dosage.csv``, ``harmonized_exposure.csv``,
    ``snp_classifications.csv`` and ``missing_data.csv``.

    Returns:
        The SNP classification table returned by ``assign_table``.

    Raises:
        ValueError: if ``--exposure-columns`` or ``--outcome-columns`` does
            not name exactly four comma-separated columns.
    """
    # Local import so this function does not depend on the file-level imports
    # being extended.
    import os

    args = docopt(__doc__)

    def option(name):
        # NOTE(review): every option value is read as ``args[name][0]``, which
        # presumes docopt hands back a list per option. If it returns plain
        # strings this takes the first character instead -- confirm against
        # the docopt version in use.
        return args[name][0]

    tolerance = float(option("--tolerance"))
    outpath = option("--output-dir")

    # Validate user input with real exceptions (asserts vanish under -O), and
    # do it before touching the filesystem.
    exposure_names = option("--exposure-columns").split(",")
    outcome_names = option("--outcome-columns").split(",")
    if len(exposure_names) != 4:
        raise ValueError("Exactly four columns are needed in --exposure-columns")
    if len(outcome_names) != 4:
        raise ValueError("Exactly four columns are needed in --outcome-columns")

    # Create the output directory without going through a shell, so paths
    # containing spaces or shell metacharacters are handled literally.
    if outpath and not os.path.isdir(outpath):
        os.makedirs(outpath)

    # Map the user's column names onto the canonical names used downstream.
    exposure_columns = dict(zip(exposure_names, "REF,ALT,rs,AF".split(",")))
    outcome_columns = dict(zip(outcome_names, "A1,A2,rsid,FreqA1".split(",")))
    # Inverse mapping, used to restore the user's exposure names on output.
    exposure_user_names = {canonical: user
                           for user, canonical in exposure_columns.items()}

    exposure = pd.read_table(option("--exposure"), sep=r"\s+")
    outcome = pd.read_table(option("--outcome"), sep=r"\s+")

    exposure.columns = [exposure_columns.get(c, c) for c in exposure.columns]
    outcome.columns = [outcome_columns.get(c, c) for c in outcome.columns]

    dosage_path = option("--dosage")

    # VCF/BCF inputs are converted; anything else is read as a
    # whitespace-separated table whose first column is the index.
    if dosage_path.endswith((".vcf", ".vcf.gz", ".bcf")):
        dosage = vcf_to_dosage(dosage_path)
    else:
        dosage = pd.read_table(dosage_path, sep=r"\s+", index_col=0)

    exposure = exposure.set_index("rs")
    outcome = outcome.set_index("rsid")

    exposure_outcome = create_exposure_outcome_table(exposure, outcome)

    # Rows with any missing value are set aside and reported separately.
    missing_data = exposure_outcome.loc[exposure_outcome.isnull().sum(axis=1).astype(bool)]
    exposure_outcome = exposure_outcome.dropna()

    sanity_check_table(exposure_outcome)

    snp_classifications = assign_table(exposure_outcome, tolerance)
    to_swap, to_flip, to_remove = flip_swap_remove(snp_classifications, tolerance)
    harmonized_exposure = flip_swap_remove_exposure_alleles(exposure, to_flip, to_swap, to_remove)
    harmonized_exposure = harmonized_exposure.reset_index()

    # set_index/reset_index moves the rsid column to the front, so writing the
    # original header list positionally would mislabel columns whenever rsid
    # was not the first column of the input. Rename by name instead.
    # NOTE(review): assumes flip_swap_remove_exposure_alleles keeps the
    # canonical column names and the "rs" index name -- confirm.
    harmonized_exposure = harmonized_exposure.rename(columns=exposure_user_names)

    harmonized_dosage = swap_dosages(dosage, to_swap, to_remove)

    harmonized_dosage.to_csv(os.path.join(outpath, "harmonized_dosage.csv"), sep=" ")
    harmonized_exposure.to_csv(os.path.join(outpath, "harmonized_exposure.csv"), sep=" ",
                               index=False)
    snp_classifications.to_csv(os.path.join(outpath, "snp_classifications.csv"), sep=" ")
    missing_data.to_csv(os.path.join(outpath, "missing_data.csv"), sep=" ", na_rep="NA")

    return snp_classifications


if __name__ == "__main__":
    # Run the harmonization, then print the percentage of SNPs per SNPType
    # value as space-separated text.
    snps = main()
    type_counts = snps.SNPType.value_counts()
    type_percentages = type_counts * 100 / len(snps)
    print(type_percentages.to_csv(sep=" "))
