#!/usr/bin/env python3

import argparse
import os
import pandas as pd
import tempfile
import logging
import numpy as np
from clubcpg.Imputation import Imputation

def create_dictionary(bins, matrices):
    """Pair every bin with the matrix found at the same position.

    Pairing stops at the end of the shorter input. If a bin occurs more than
    once, the matrix paired with its last occurrence is the one kept.

    :param bins: iterable of hashable bin identifiers
    :param matrices: iterable of values ordered like ``bins``
    :return: dict mapping each bin to its matrix
    """
    return dict(zip(bins, matrices))

### Get Input params ###
# Numeric options are parsed with type=int so a malformed value is rejected by
# argparse with a usage message instead of failing later with a traceback.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("-a", "--input_bam_file",
                        help="Input bam file, coordinate sorted with index present", default=None)
arg_parser.add_argument("-c", "--coverage", help="output file from clubcpg-coverage, "
                                                 "filtered for at least 1 read and 2 cpgs",
                        default=None)
arg_parser.add_argument("-m", "--models", help="Path to folder containing saved models", default=None)
arg_parser.add_argument("-o", "--output", help="folder to save imputed coverage data", default=None)
# Default to a single core: without a default, omitting -n left the value as
# None and the later int() conversion raised a TypeError.
arg_parser.add_argument("-n", help="number of cpu cores to use, default=1", type=int, default=1)
arg_parser.add_argument("-chr", "--chromosome", help="Optional, perform only on one chromosome. "
                                      "Default=all chromosomes provided in -c. Example: 'chr7'",
                        default=None)
arg_parser.add_argument("--read1_5", help="integer, read1 5' m-bias ignore bp, default=0", type=int, default=0)
arg_parser.add_argument("--read1_3", help="integer, read1 3' m-bias ignore bp, default=0", type=int, default=0)
arg_parser.add_argument("--read2_5", help="integer, read2 5' m-bias ignore bp, default=0", type=int, default=0)
arg_parser.add_argument("--read2_3", help="integer, read2 3' m-bias ignore bp, default=0", type=int, default=0)

if __name__ == "__main__":
    # Script entry point: impute missing methylation calls for the bins listed
    # in the coverage file, then write a copy of the coverage file in which the
    # read/cpg counts of the imputed bins are replaced by the post-imputation
    # counts.

    # Extract arguments from command line and set as correct types
    args = arg_parser.parse_args()

    # Both inputs are mandatory. Stop with a usage message here rather than
    # with a TypeError/ValueError raised from os.path or pandas further down.
    if not args.input_bam_file or not args.coverage:
        arg_parser.error("both -a/--input_bam_file and -c/--coverage are required")

    # Set output dir. Defaults to the folder holding the bam file; dirname()
    # returns "" for a bare file name, which stands for the current directory.
    if args.output:
        output_folder = args.output
    else:
        output_folder = os.path.dirname(args.input_bam_file) or os.curdir

    # makedirs creates missing parent folders and accepts an existing folder
    os.makedirs(output_folder, exist_ok=True)

    # Get selected chromosome (None means: every chromosome in the coverage file)
    chromosome = args.chromosome

    # Set up logging. Only the base name of the bam file goes into the log file
    # name; embedding the full path would point into folders that do not exist
    # below output_folder.
    log_file = os.path.join(
        output_folder,
        "clubcpg-impute-coverage.{}.{}log".format(
            os.path.basename(args.input_bam_file),
            (chromosome + ".") if chromosome else ""))
    logging.basicConfig(filename=log_file, level=logging.DEBUG)

    # log all passed command line arguments
    logging.info(args)

    # Get the mbias inputs and the worker count as integers
    mbias_read1_5 = int(args.read1_5)
    mbias_read1_3 = int(args.read1_3)
    mbias_read2_5 = int(args.read2_5)
    mbias_read2_3 = int(args.read2_3)
    # -n may be absent; fall back to a single process instead of int(None)
    processes = int(args.n) if args.n is not None else 1
    models = args.models

    ### Read in coverage file ###
    coverage_data = pd.read_csv(args.coverage, header=None)
    coverage_data.columns = ['bin', 'reads', 'cpgs']

    if chromosome:
        print("Performing imputation only on chromosome: {}".format(chromosome))
        # bin names are matched on the "<chromosome>_" prefix
        coverage_data = coverage_data[coverage_data['bin'].str.startswith(chromosome + "_")]

    ### Impute from models ###
    after = None
    imputed_rows = 0
    # temp file for storage; the context manager guarantees it is closed
    with tempfile.TemporaryFile(mode="w+t") as tfile:
        for cpg_density in range(2, 6):
            print("Starting with cpg density: {}...".format(cpg_density), flush=True)
            imputer = Imputation(cpg_density, args.input_bam_file, mbias_read1_5, mbias_read1_3,
                                 mbias_read2_5, mbias_read2_3, processes)
            # Get matrices with unknowns as -1
            print("Extracting cpg matrices from genome...", flush=True)
            bins, matrices = imputer.extract_matrices(coverage_data, return_bins=True)

            # Going through a dict keeps a single matrix per bin name
            data_dict = create_dictionary(bins, matrices)

            print("Beginning imputation...", flush=True)
            imputed_matrices = imputer.impute_from_model(models_folder=models,
                                                         matrices=list(data_dict.values()),
                                                         postprocess=True)

            # Re-attach the bin names to the imputed matrices (same order)
            data_dict = create_dictionary(data_dict.keys(), imputed_matrices)

            for bin_name, matrix in data_dict.items():
                # Count only the rows that have no NaN left after imputation
                complete = matrix[~np.isnan(matrix).any(axis=1)]
                tfile.write("{},{},{}\n".format(bin_name, complete.shape[0], complete.shape[1]))
                imputed_rows += 1

        ### Update Coverage File and Save New ###
        print("Updating original data with imputed data...", flush=True)
        # Parsing an empty file would raise, so only read back when something
        # was actually written.
        if imputed_rows:
            tfile.seek(0)
            after = pd.read_csv(tfile, header=None)
            after.columns = ['bin', 'reads', 'cpgs']

    # set indexes for update; set_index returns a new frame, so the (possibly
    # filtered) coverage_data slice is never modified in place
    before = coverage_data.set_index(['bin'])

    if after is not None:
        after.set_index(['bin'], inplace=True)

        # update rows with imputed values. update() may upcast integer columns
        # to float on some pandas versions, so the original dtypes are restored
        # afterwards to keep the saved counts formatted as before.
        original_dtypes = before.dtypes.to_dict()
        before.update(after)
        before = before.astype(original_dtypes)

    # Set filename and Save
    if chromosome:
        suffix = ".{}.IMPUTED.csv".format(chromosome)
    else:
        suffix = ".IMPUTED.csv"
    outfile = os.path.join(output_folder, os.path.basename(args.coverage) + suffix)

    print("Saving updated data to file...", flush=True)
    before.to_csv(outfile, header=False)

    print("done")
