#!python

## Required modules
import argparse, sys, logging, eaglec, os

# Package version string; reported by the -v/--version command-line flag.
currentVersion = eaglec.__version__

def getargs():
    """Build the command-line parser and return the parsed arguments.

    Returns
    -------
    argparse.Namespace
        Attributes: ``sv_file``, ``mcool``, ``output_file``, ``model_path``,
        ``resolutions`` (comma-separated string), ``balance_type``,
        ``search_buffer``, ``nproc`` and ``batch_size``.

    Notes
    -----
    Arguments are read from ``sys.argv``. argparse terminates the process
    itself on ``-h``/``-v`` or on invalid input.
    """
    parser = argparse.ArgumentParser(
        description="Re-evaluate Structural Variation (SV) probabilities based on existing coordinates using a contact map.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # --- Input/Output Arguments ---
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s', currentVersion]),
                        help='Print version number and exit.')
    parser.add_argument('-i', '--sv-file', required=True,
                        help="Input file with Structural Variation (SV) coordinates, format: 'chr1 pos1 chr2 pos2'.")
    parser.add_argument('-m', '--mcool', required=True,
                        help="Path to the input .mcool file.")
    # The value is used as a prefix by the caller, not as a literal file name.
    parser.add_argument('-o', '--output-file', required=True,
                        help="Output prefix. Results are written to '<output-file>.SV_evaluate.txt' "
                             "and a log file is written next to it.")
    # Not marked as required: combining required=True with a default made the
    # default unreachable, forcing users to always pass the option.
    parser.add_argument('--model-path', default='./EagleC2-models/EagleC2-models.TF',
                        help="Path to the folder containing pre-trained EagleC2 TensorFlow models.")

    # --- Analysis Arguments ---
    parser.add_argument('--resolutions', default='5000,10000,25000',
                        help="Resolutions at which the contact matrices will be used for SV evaluation.")
    parser.add_argument('--balance-type', default='ICE', choices=['ICE', 'CNV', 'Raw'],
                        help="Normalization method.")
    parser.add_argument('--search-buffer', type=int, default=3,
                        help="Buffer size around SV coordinates to search for the optimal position (in multiples of resolution).")

    # --- Performance Arguments ---
    parser.add_argument('-p', '--nproc', type=int, default=8,
                        help="Maximum number of processes for parallel computation.")
    parser.add_argument('--batch-size', type=int, default=256,
                        help="Batch size for model prediction.")

    return parser.parse_args()



def main():
    """Re-score user-supplied SV coordinates against a Hi-C contact map.

    Workflow, as implemented below:

    1. Parse the command-line arguments and set up logging to stdout and to
       a ``.log`` file derived from the output name (its extension, if any,
       is replaced by ``.log``).
    2. Load the SV coordinates and the saved models.
    3. For every requested resolution, cut windows around each SV within the
       search buffer, run every model on them, average the probabilities and
       keep the best-scoring window per SV.
    4. Write a tab-separated table to ``<output-file>.SV_evaluate.txt``.

    The process exits with status 1 when no model can be found under the
    model folder. A resolution that yields no usable window is skipped with
    a warning instead of aborting the run.
    """
    args = getargs()

    # These imports run only once argument parsing has succeeded.
    import cooler
    import glob
    import numpy as np
    import tensorflow as tf
    from collections import defaultdict

    from eaglec.utilities import distance_normaize_core, image_normalize, calculate_expected
    from eaglec.extractMatrix import check_sparsity

    def parse_SVs(fil):
        """Read SV breakpoints from a whitespace-delimited text file.

        Returns a dict mapping ``(chrom1, chrom2)`` to a list of
        ``(pos1, pos2)`` tuples. Each pair is ordered so that
        ``chrom1 <= chrom2`` by string comparison, with the positions
        swapped along with the chromosomes.

        Two column layouts are accepted. ``chr1 chr2 pos1 pos2`` (the layout
        this script has always consumed) is tried first; if columns 3 and 4
        are not integers, ``chr1 pos1 chr2 pos2`` (the layout described in
        the ``--sv-file`` help) is tried. Purely numeric chromosome names
        are therefore always read in the first layout. Only the first four
        columns are used. Blank lines and lines starting with ``#`` are
        skipped.

        Raises ``ValueError`` for a line that fits neither layout.
        """
        coords = {}
        with open(fil, 'r') as source:
            for lineno, line in enumerate(source, 1):
                fields = line.split()
                if not fields or fields[0].startswith('#'):
                    continue
                if len(fields) < 4:
                    raise ValueError(
                        f"{fil}, line {lineno}: expected at least 4 columns, got {len(fields)}"
                    )
                try:
                    c1, c2 = fields[0], fields[1]
                    p1, p2 = int(fields[2]), int(fields[3])
                except ValueError:
                    try:
                        c1, c2 = fields[0], fields[2]
                        p1, p2 = int(fields[1]), int(fields[3])
                    except ValueError:
                        raise ValueError(
                            f"{fil}, line {lineno}: cannot parse SV coordinates from {line.rstrip()!r}"
                        ) from None
                if c1 > c2:
                    c1, c2 = c2, c1
                    p1, p2 = p2, p1
                coords.setdefault((c1, c2), []).append((p1, p2))

        return coords

    def load_models(root_folder):
        """Load every saved model matching ``EagleC2-model*`` in *root_folder*.

        Returns a list with the ``serving_default`` signature of each model,
        in sorted path order; the list is empty when nothing matches.
        """
        model_paths = sorted(glob.glob(os.path.join(root_folder, 'EagleC2-model*')))
        models = []
        for f in model_paths:
            model = tf.saved_model.load(f)
            inference_func = model.signatures['serving_default']
            models.append(inference_func)

        return models

    def transform_dataset(image):
        """Turn a batch of images into the two model inputs.

        The first output drops 5 pixels from every border of the last two
        axes, the second is the full image; both get a trailing axis of
        length 1.
        """
        input_1 = tf.expand_dims(image[:, 5:-5, 5:-5], axis=-1)
        input_2 = tf.expand_dims(image, axis=-1)

        return input_1, input_2

    def convert2TF(images, batch_size=256):
        """Wrap a NumPy image stack as a batched, cached float32 ``tf.data.Dataset``."""
        images = images.astype(np.float32)
        images = tf.data.Dataset.from_tensor_slices(images)
        images = images.batch(batch_size)
        images = images.map(transform_dataset)
        images = images.cache()

        return images

    def collect_images_core(clr, coords, balance, exp_dict, w=15, buff=20000):
        """Cut candidate windows around every SV breakpoint.

        For each SV, every bin offset within ``+/- buff`` base pairs of
        ``pos1`` (rows) and ``pos2`` (columns) is taken as a window centre.
        A ``(2*w+1) x (2*w+1)`` window is extracted when it lies completely
        inside the chromosome-pair matrix. Windows for which
        ``check_sparsity`` returns a false value are dropped.

        Returns ``(images, sv_coords, window_coords)``. ``images`` is the
        stacked windows (an empty array if none qualified). ``sv_coords[i]``
        is the ``(chrom1, chrom2, pos1, pos2)`` SV that window ``i`` belongs
        to, and ``window_coords[i]`` is
        ``(chrom1, chrom2, centre1_bp, centre2_bp)`` of that window.
        """
        res = clr.binsize
        images = []
        sv_coords = []
        window_coords = []
        for (c1, c2), positions in coords.items():
            Matrix = clr.matrix(balance=balance, sparse=True).fetch(c1, c2).tocsr()
            n_rows, n_cols = Matrix.shape
            for p1, p2 in positions:
                sv = (c1, c2, p1, p2)
                for x in range(p1 - buff, p1 + buff + res, res):
                    xi = x // res
                    # The window must fit entirely inside the matrix rows.
                    if (xi - w < 0) or (xi + w + 1 > n_rows):
                        continue
                    for y in range(p2 - buff, p2 + buff + res, res):
                        yi = y // res
                        # ... and entirely inside the matrix columns.
                        if (yi - w < 0) or (yi + w + 1 > n_cols):
                            continue

                        window = Matrix[xi-w:xi+w+1, yi-w:yi+w+1].toarray()
                        window = window.astype(exp_dict[c1].dtype)
                        window[np.isnan(window)] = 0

                        if not check_sparsity(window):
                            continue

                        # Distance normalisation is applied to
                        # intra-chromosomal windows only.
                        if c1 == c2:
                            window = distance_normaize_core(window, exp_dict[c1], xi, yi, w)

                        window = image_normalize(window)
                        images.append(window)
                        sv_coords.append(sv)
                        window_coords.append((c1, c2, xi*res, yi*res))

        images = np.r_[images]

        return images, sv_coords, window_coords

    def predict_core(concrete_func, dataset):
        """Run one model signature over *dataset* and return its outputs as one array.

        The signature's input names are ordered by the second dimension of
        their tensor spec. The first name receives the cropped image and the
        second the full image, matching the pair from ``transform_dataset``.
        """
        input_signature = concrete_func.structured_input_signature[1]
        input_names = []
        for name, tensor_spec in input_signature.items():
            input_names.append((tensor_spec.shape[1], name))
        input_names.sort()

        all_probs = []
        for batch in dataset:
            input_dict = {input_names[0][1]: batch[0],
                          input_names[1][1]: batch[1]}
            preds = concrete_func(**input_dict)
            all_probs.append(preds['output_0'])

        return tf.concat(all_probs, axis=0).numpy()

    def predict(eval_funcs, images, sv_coords, batch_size=256, window_coords=None):
        """Average the predictions of all models and group them by SV.

        Returns a dict mapping each SV tuple to a list of
        ``(probabilities, window_coord)`` pairs, where ``probabilities`` is
        the first six columns of the model-averaged output for that window.
        ``window_coords`` must be supplied (one entry per image).
        """
        images = convert2TF(images, batch_size)
        prob_pool = np.stack([predict_core(eval_func, images) for eval_func in eval_funcs])
        prob_mean = prob_pool.mean(axis=0)[:, :6]
        sv_probs = defaultdict(list)
        for prob, sv, window in zip(prob_mean, sv_coords, window_coords):
            sv_probs[sv].append((prob, window))

        return sv_probs

    def find_max_array_element(data_list):
        """Return ``(array, position, label)`` for the entry with the largest probability.

        *data_list* holds ``(probability_array, position)`` pairs. ``label``
        is the entry of ``SV_labels`` at the argmax of the winning array.
        Ties keep the earliest entry; ``None`` is returned for an empty list.
        """
        SV_labels = ['++', '+-', '-+', '--', '++/--', '+-/-+']
        max_tmp = -np.inf
        max_element = None
        for array, position in data_list:
            current_max = np.max(array)
            if current_max > max_tmp:
                max_tmp = current_max
                max_element = (array, position, SV_labels[np.argmax(array)])
        return max_element

    def find_best_prediction(sv_results):
        """Reduce the per-window predictions to one record per SV.

        Each record is ``(chrom1, chrom2, pos1, pos2, label, max_prob,
        hic_pos1, hic_pos2)``, where the last two values are the centre
        coordinates of the best-scoring window.
        """
        sv_evaluated = []
        for sv, predictions in sv_results.items():
            c1, c2, p1, p2 = sv
            max_element = find_max_array_element(predictions)
            if max_element is not None:
                array, position, current_label = max_element
                max_prob = np.max(array)
                sv_evaluated.append((c1, c2, p1, p2, current_label, max_prob,
                                     position[2], position[3]))
        return sv_evaluated

    # --- 0. Setup Logging ---
    log_file_path = os.path.splitext(args.output_file)[0] + '.log'

    # The root logger passes everything on; the handlers filter at INFO.
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    file_handler = logging.FileHandler(log_file_path, mode='w')
    file_handler.setLevel(logging.INFO)

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter(
        fmt='%(name)-25s %(levelname)-7s @ %(asctime)s: %(message)s',
        datefmt='%m/%d/%y %H:%M:%S'
    )
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    # Record every argument so that a run can be reproduced from its log.
    arg_list_str = ['# ARGUMENT LIST:']
    for key, value in sorted(vars(args).items()):
        arg_key_formatted = key.replace('_', ' ').capitalize()
        arg_list_str.append(f'# {arg_key_formatted:<25} = {value}')
    logger.info('\n' + '\n'.join(arg_list_str))

    # --- 1. Data and Model Loading ---
    logger.info(f"Loading SV coordinates from: {args.sv_file}")
    sv_bychrom = parse_SVs(args.sv_file)
    logger.info(f"Loading pre-trained models from: {args.model_path}")
    models = load_models(args.model_path)
    if not models:
        # Fail before the expensive matrix work; with no model the ensemble
        # averaging further down cannot produce a result.
        logger.error(f"No EagleC2 models found under: {args.model_path}")
        sys.exit(1)

    # Empty entries (e.g. from a trailing comma) are ignored.
    resolutions = [int(r) for r in args.resolutions.split(',') if r.strip()]

    # Value passed as `balance` to cooler and to calculate_expected.
    balance = {'ICE': 'weight', 'CNV': 'sweight', 'Raw': False}[args.balance_type]

    # NOTE(review): expected values are computed for this fixed chromosome
    # list only; SVs on other chromosomes are presumably unsupported - confirm.
    chroms = ['chr'+str(i) for i in range(1, 23)] + ['chrX']
    eval_res = {}
    for res in resolutions:
        logger.info(f'Resolution {res}')
        uri = f'{args.mcool}::resolutions/{res}'
        logger.info(f"Loading .mcool file: {uri}")
        clr = cooler.Cooler(uri)

        # --- 2. Core Computation ---
        logger.info("Calculating expected values for chromosomes...")
        max_bins = max(200, 2000000 // res)
        exp_dict = calculate_expected(clr, chroms, balance, max_bins, nproc=args.nproc)

        logger.info("Extracting images around SV coordinates...")
        buffer_size = res * args.search_buffer
        images, sv_coords, window_coords = collect_images_core(
            clr, sv_bychrom, balance, exp_dict, buff=buffer_size
        )

        if images.size == 0:
            # Skip only this resolution, so results already obtained (or
            # still to come) at other resolutions are not thrown away.
            logger.warning(
                f"Failed to extract any valid images from the input coordinates "
                f"at resolution {res}; skipping this resolution."
            )
            continue
        logger.info(f"Successfully extracted {len(images)} images for analysis.")
        logger.info("Evaluating the probability of each SV...")
        sv_probs = predict(models, images, sv_coords,
                           batch_size=args.batch_size, window_coords=window_coords)

        # --- 3. Result Processing ---
        logger.info("Filtering for the best prediction for each SV...")
        eval_res[res] = find_best_prediction(sv_probs)

    if not any(eval_res.values()):
        logger.warning("No SV could be evaluated at any resolution; the output will only contain the header.")

    # --- 4. Save Results ---
    with open(f'{args.output_file}.SV_evaluate.txt', 'w') as out:
        out.write('\t'.join(['chrom1', 'pos1', 'chrom2', 'pos2', 'strand', 'probability', 'hic_pos1', 'hic_pos2', 'resolution']) + '\n')
        for res, svs in eval_res.items():
            for c1, c2, p1, p2, strand, prob, hic_p1, hic_p2 in svs:
                out.write(f"{c1}\t{p1}\t{c2}\t{p2}\t{strand}\t{prob:.4g}\t{hic_p1}\t{hic_p2}\t{res}\n")
    logger.info("Processing complete!")


# Run the command-line workflow only when this file is executed as a script.
if __name__ == '__main__':
    main()