#!/usr/bin/env python3

import sys
import logging
import os
from clubcpg.ClusterReads import ClusterReadsWithImputation
import argparse
import datetime


def str2bool(v):
    """Convert a command-line token to a bool, for use as an argparse ``type``.

    The match is case-insensitive and accepts only the words ``true`` and
    ``false``. A value that is already a ``bool`` is returned unchanged so
    that boolean defaults or constants can be routed through this function
    safely.

    :param v: the raw value to convert (normally a ``str``)
    :return: ``True`` or ``False``
    :raises argparse.ArgumentTypeError: if ``v`` is neither a bool nor a
        string spelling of true/false
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        # Raise the exception type argparse knows how to report, instead of
        # letting ``v.lower()`` fail with an AttributeError.
        raise argparse.ArgumentTypeError("True or False value expected.")

    lowered = v.lower()
    if lowered == 'true':
        return True
    if lowered == 'false':
        return False
    raise argparse.ArgumentTypeError("True or False value expected.")


# Set command line arguments.
# Numeric options declare type=int so that argparse validates them up front
# and reports a usage error on bad input; int defaults are left untouched.
arg_parser = argparse.ArgumentParser()

# Input / output locations
arg_parser.add_argument("-a", "--input_bam_A",
                        help="First Input bam file, coordinate sorted with index present, REQUIRED")
arg_parser.add_argument("-b", "--input_bam_B",
                        help="Second Input bam file, coordinate sorted with index present, OPTIONAL", default=None)
arg_parser.add_argument("--bins",
                        help="File with each line being one bin to extract and analyze, "
                             "generated by clubcpg-coverage, REQUIRED")
arg_parser.add_argument("-o", "--output_dir",
                        help="Output directory to save figures, defaults to bam file location")

# Analysis parameters
arg_parser.add_argument("--bin_size", help="Size of bins to extract and analyze, default=100",
                        type=int, default=100)
arg_parser.add_argument("-m", "--cluster_member_minimum",
                        help="Minimum number of members a cluster should have for it to be considered, default=4",
                        type=int, default=4)
arg_parser.add_argument("-r", "--read_depth",
                        help="Minimum number of reads covering all CpGs that the bins should have to analyze, "
                             "default=10",
                        type=int, default=10)
arg_parser.add_argument("-n", "--num_processors",
                        help="Number of processors to use for analysis, default=1",
                        type=int, default=1)

# M-bias trimming, in base pairs, for each end of each read
arg_parser.add_argument("--read1_5", help="integer, read1 5' m-bias ignore bp, default=0", type=int, default=0)
arg_parser.add_argument("--read1_3", help="integer, read1 3' m-bias ignore bp, default=0", type=int, default=0)
arg_parser.add_argument("--read2_5", help="integer, read2 5' m-bias ignore bp, default=0", type=int, default=0)
arg_parser.add_argument("--read2_3", help="integer, read2 3' m-bias ignore bp, default=0", type=int, default=0)

# Boolean switches: a bare flag yields const=True, an explicit value is parsed
# by str2bool, and the string default 'True' is converted by str2bool as well.
arg_parser.add_argument("--no_overlap", help="bool, remove any overlap between paired reads and stitch"
                                             " reads together when possible, default=True",
                        type=str2bool, const=True, default='True', nargs='?')
arg_parser.add_argument("--remove_noise", help="bool, Discard the cluster containing noise points (-1)"
                                               " after clustering, default=True",
                        type=str2bool, const=True, default='True', nargs='?')

arg_parser.add_argument("--suffix",
                        help="Any additional info to include in the output file name, chromosome for example",
                        default=None)

# Imputation options
arg_parser.add_argument("--models_A", help="Models to impute for input_bam_A, OPTIONAL", default=None)
arg_parser.add_argument("--models_B", help="Models to impute for input_bam_B, OPTIONAL", default=None)
arg_parser.add_argument("--chunksize",
                        help="How large of chunks to split bins into during imputation. "
                             "Higher will go faster, but uses more memory",
                        type=int, default=10000)

if __name__ == "__main__":
    # Script entry point: parse the command line, validate the inputs, set up
    # file logging in the output directory, then run the clustering.

    args = arg_parser.parse_args()

    # Copy the parsed options into local names, coercing numeric options to
    # int so the clustering class never receives them as strings.
    input_bam_a = args.input_bam_A
    input_bam_b = args.input_bam_B
    bins_file = args.bins
    bin_size = int(args.bin_size)
    cluster_min = int(args.cluster_member_minimum)
    read_depth_req = int(args.read_depth)
    num_processors = int(args.num_processors)
    no_overlap = args.no_overlap
    mbias_read1_5 = int(args.read1_5)
    mbias_read1_3 = int(args.read1_3)
    mbias_read2_5 = int(args.read2_5)
    mbias_read2_3 = int(args.read2_3)
    remove_noise = args.remove_noise
    models_A = args.models_A
    models_B = args.models_B
    # A value given on the command line may arrive as a string; only the
    # default is guaranteed to be an int.
    chunksize = int(args.chunksize)

    # The optional suffix is embedded in output file names with a leading dot
    if args.suffix:
        suffix = "." + str(args.suffix)
    else:
        suffix = ""

    # Both the first bam file and the bins file are mandatory. The
    # parentheses matter: without them `not` binds only to input_bam_a and a
    # missing bins file goes undetected.
    if not (input_bam_a and bins_file):
        print("You must supply input_bam_a and a bins file. Exiting")
        sys.exit(1)

    single_file_mode = not input_bam_b
    if single_file_mode:
        print("Only one input bam detected. Running in single-file mode")

    # Get or assign output directory
    if args.output_dir:
        output_dir = args.output_dir
    else:
        # dirname() is "" when the bam path has no directory component;
        # fall back to the current directory so makedirs() has a valid path.
        output_dir = os.path.dirname(input_bam_a) or "."

    # Create output dir if it doesn't exist (exist_ok avoids a check-then-create race)
    os.makedirs(output_dir, exist_ok=True)

    # Set up logging: one log file per input bam / suffix / start date
    start_time = datetime.datetime.now().strftime("%y-%m-%d")
    log_file = os.path.join(
        output_dir,
        "Clustering.{}{}.{}.log".format(os.path.basename(input_bam_a), suffix, start_time))
    logging.basicConfig(filename=log_file, level=logging.DEBUG)

    # Record the full parsed argument namespace in the log
    logging.info(args)

    logging.info("Input files supplied:\n"
                 "A: {}\n"
                 "B: {}\n"
                 "Bins: {}"
                 .format(input_bam_a, input_bam_b, bins_file))

    logging.info("Starting workers pool, using {} processors".format(num_processors))
    logging.info("M bias inputs received, ignoring the following:\nread 1 5': {}bp\n"
                 "read1 3': {}bp\nread2 5: {}bp\nread2 3': {}bp".format(
                     mbias_read1_5, mbias_read1_3, mbias_read2_5, mbias_read2_3))

    # Instantiate the clustering class
    cluster_reads = ClusterReadsWithImputation(
        bam_a=input_bam_a,
        bam_b=input_bam_b,
        bin_size=bin_size,
        bins_file=bins_file,
        output_directory=output_dir,
        num_processors=num_processors,
        cluster_member_min=cluster_min,
        read_depth_req=read_depth_req,
        remove_noise=remove_noise,
        mbias_read1_5=mbias_read1_5,
        mbias_read1_3=mbias_read1_3,
        mbias_read2_5=mbias_read2_5,
        mbias_read2_3=mbias_read2_3,
        suffix=suffix,
        no_overlap=no_overlap,
        models_A=models_A,
        models_B=models_B,
        chunksize=chunksize
    )

    logging.debug(args)

    # Perform clustering
    cluster_reads.execute()

    logging.info("Done")
    print("Complete")
