#!/usr/bin/env python3

import argparse
import os
import pandas as pd
from clubcpg.Imputation import Imputation

# Command-line interface. Integer options are parsed by argparse itself
# (type=int) so that a malformed value produces a usage error instead of a
# traceback later in the script.
arg_parser = argparse.ArgumentParser()
# The bam file and the coverage file are both consumed unconditionally by the
# script body, so they are declared as required options.
arg_parser.add_argument("-a", "--input_bam_file", required=True,
                        help="Input bam file, coordinate sorted with index present")
arg_parser.add_argument("-c", "--coverage", required=True,
                        help="output file from clubcpg-coverage, filtered for at least 1 read and 2 cpgs")
arg_parser.add_argument("-o", "--output", help="folder to save generated model files", default=None)
# A default is needed here: the script converts this value with int(), which
# fails on None when the option is omitted.
arg_parser.add_argument("-n", type=int, default=1, help="number of cpu cores to use, default=1")
arg_parser.add_argument("-l", "--limit_samples", type=int, default=10000,
                        help="Limit the number of samples used to train the model, "
                             "this will speed up training. default=10000")
arg_parser.add_argument("--read1_5", type=int, default=0,
                        help="integer, read1 5' m-bias ignore bp, default=0")
arg_parser.add_argument("--read1_3", type=int, default=0,
                        help="integer, read1 3' m-bias ignore bp, default=0")
arg_parser.add_argument("--read2_5", type=int, default=0,
                        help="integer, read2 5' m-bias ignore bp, default=0")
arg_parser.add_argument("--read2_3", type=int, default=0,
                        help="integer, read2 3' m-bias ignore bp, default=0")

def resolve_output_folder(output, input_bam_file):
    """Return the folder the generated model files should be saved in.

    Args:
        output: Value of the ``-o/--output`` option; may be None or empty.
        input_bam_file: Path of the input bam file, used as the fallback.

    Returns:
        ``output`` when it is set, otherwise the directory containing the
        bam file. A bam path with no directory component resolves to "."
        because ``os.path.dirname`` returns an empty string for it, which
        cannot be created or written to.
    """
    if output:
        return output
    return os.path.dirname(input_bam_file) or "."


def main():
    """Parse the command line and train one model per CpG density (2-5)."""
    args = arg_parser.parse_args()

    # Fail with a usage message rather than a traceback from os.path/pandas
    # when an input the script cannot run without is missing.
    if args.input_bam_file is None or args.coverage is None:
        arg_parser.error("both --input_bam_file and --coverage must be provided")

    # M-bias values are handed to Imputation as plain integers
    # (bp to ignore at each read end, 0 by default per the option help).
    mbias_read1_5 = int(args.read1_5)
    mbias_read1_3 = int(args.read1_3)
    mbias_read2_5 = int(args.read2_5)
    mbias_read2_3 = int(args.read2_3)
    # -n may be absent; fall back to a single process instead of int(None).
    processes = 1 if args.n is None else int(args.n)
    sample_limit = int(args.limit_samples)

    # Set output dir; makedirs also creates any missing parent folders.
    output_folder = resolve_output_folder(args.output, args.input_bam_file)
    try:
        os.makedirs(output_folder)
    except FileExistsError:
        print("Output folder already exists... no need to create it...")

    # Read in coverage file (headerless CSV with three columns)
    coverage_data = pd.read_csv(args.coverage, header=None)
    coverage_data.columns = ['bin', 'reads', 'cpgs']

    # Train models, one per cpg density
    for cpg_density in range(2, 6):
        print("Starting training cpg density: {}".format(cpg_density))
        trainer = Imputation(cpg_density, args.input_bam_file,
                             mbias_read1_5, mbias_read1_3, mbias_read2_5, mbias_read2_3,
                             processes)
        matrices = trainer.extract_matrices(coverage_data, sample_limit=sample_limit)
        # The return value is not needed here; train_model is given the
        # output folder (presumably where it saves the model — confirm in
        # clubcpg.Imputation).
        trainer.train_model(output_folder, matrices)

    print("done")


if __name__ == "__main__":
    main()







    


