#!python

## Required modules
import argparse, sys, logging, eaglec, os, time, tempfile

# Version string of the installed eaglec package; reported by the -v/--version option.
currentVersion = eaglec.__version__

def getargs():
    """Parse the command-line arguments of the SV evaluation script.

    The SV list, the mcool file and the output prefix are mandatory: the
    workflow cannot run without them, so argparse reports a missing one
    up front instead of letting the run fail later on a ``None`` path.

    When the script is invoked without any argument, ``-h`` is appended so
    that the help text is shown (argparse then exits).

    Returns
    -------
    tuple
        ``(args, commands)`` where ``args`` is the parsed
        ``argparse.Namespace`` and ``commands`` is the list of raw
        command-line tokens that was parsed.
    """
    parser = argparse.ArgumentParser(
        description="Evaluates a predefined list of SVs using EagleC2 models.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # --- Input/Output Arguments ---
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s', currentVersion]),
                        help='Print version number and exit.')

    parser.add_argument('-i', '--sv-file', required=True,
                        help="Path to a .txt file containing SV breakpoint coordinates for evaluation. ")
    parser.add_argument('-m', '--mcool', required=True,
                        help="Path to a mcool file containing contact matrices at multiple resolutions.")
    parser.add_argument('-O', '--output-prefix', required=True,
                        help='''Prefix of the output file names''')
    parser.add_argument('--model-path', default='EagleC2-models',
                        help="Path to the folder containing pretrained models.")

    # --- Analysis Arguments ---
    parser.add_argument('--resolutions', default='5000,10000,25000',
                        help="Resolutions at which the contact matrices will be used for SV evaluation.")
    parser.add_argument('--balance-type', default='ICE', choices=['ICE', 'CNV', 'Raw'],
                        help="Normalization method.")
    parser.add_argument('--search-buffer', type=int, default=3,
                        help="Buffer size (in bins) around each SV breakpoint pair used during the search.")

    # --- Performance Arguments ---
    parser.add_argument('-p', '--nproc', type=int, default=8,
                        help="Maximum number of processes for parallel computation.")
    parser.add_argument('--batch-size', type=int, default=256,
                        help="Batch size for model prediction.")

    commands = sys.argv[1:]
    if not commands:
        # No arguments at all: fall back to printing the help message.
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands



def _setup_logging(log_file_path):
    """Attach a file handler and a stdout handler to the root logger.

    Both handlers record INFO and above and share one message format; the
    log file is opened in write mode. Returns the root logger.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter(
        fmt='%(name)-25s %(levelname)-7s @ %(asctime)s: %(message)s',
        datefmt='%m/%d/%y %H:%M:%S'
    )
    handlers = (
        logging.FileHandler(log_file_path, mode='w'),
        logging.StreamHandler(sys.stdout),
    )
    for handler in handlers:
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


def _parse_sv_file(path):
    """Read SV breakpoint pairs from a whitespace-delimited text file.

    The first four columns of each data line are ``chrom1 pos1 chrom2 pos2``;
    any further columns are ignored. Blank lines and lines starting with
    ``#`` are skipped.

    Each pair is normalized so that ``chrom1 <= chrom2`` (string order) and,
    for intra-chromosomal pairs, ``pos1 <= pos2``. The second rule matters
    because candidate bins with ``bin2 - bin1 < 5`` are discarded later on:
    without it an intra-chromosomal SV listed with ``pos1 > pos2`` would
    never be evaluated.

    Returns
    -------
    dict
        Maps ``(chrom1, chrom2)`` to a list of ``(pos1, pos2)`` tuples.

    Raises
    ------
    ValueError
        If a data line has fewer than four columns or a position is not an
        integer.
    """
    coords = {}
    with open(path, 'r') as source:
        for lineno, line in enumerate(source, 1):
            fields = line.split()
            if not fields or fields[0].startswith('#'):
                continue
            if len(fields) < 4:
                raise ValueError(
                    'Line {0} of {1}: expected at least 4 columns '
                    '(chrom1 pos1 chrom2 pos2), found {2}'.format(lineno, path, len(fields))
                )
            c1, p1, c2, p2 = fields[:4]
            p1, p2 = int(p1), int(p2)
            if c1 > c2 or (c1 == c2 and p1 > p2):
                c1, c2 = c2, c1
                p1, p2 = p2, p1
            coords.setdefault((c1, c2), []).append((p1, p2))

    return coords


def _build_candidates(sv_bychrom, resolutions, chroms, buf):
    """Expand every SV into the set of bin pairs to be scored.

    For each resolution, every breakpoint pair is converted to bin indices
    and surrounded by a ``(2*buf + 1) x (2*buf + 1)`` window. Chromosome
    pairs with a chromosome missing from ``chroms`` are skipped. Bin pairs
    with a negative index (window running past the chromosome start) and
    intra-chromosomal pairs closer than 5 bins are dropped.

    Returns a ``defaultdict(dict)``: ``{res: {(chrom1, chrom2): set((bin1, bin2))}}``.
    """
    from collections import defaultdict

    known = set(chroms)
    offsets = range(-buf, buf + 1)
    by_res = defaultdict(dict)
    for res in resolutions:
        for (c1, c2), positions in sv_bychrom.items():
            if c1 not in known or c2 not in known:
                continue

            bins = set()
            for p1, p2 in positions:
                b1, b2 = p1 // res, p2 // res
                for i in offsets:
                    for j in offsets:
                        xi, yi = b1 + i, b2 + j
                        if xi < 0 or yi < 0:
                            # the window ran past the start of the chromosome
                            continue
                        if c1 == c2 and yi - xi < 5:
                            continue
                        bins.add((xi, yi))

            by_res[res][(c1, c2)] = bins

    return by_res


def _evaluate(cache_folder, models, batch_size=256):
    """Score the cached images with every model and keep the top class.

    Batches are read from ``cache_folder`` through
    ``get_queue(..., pattern='collect*.pkl')``; each batch item is indexed
    as ``(image, coords)`` with ``coords`` unpacking to
    ``(c1, p1, c2, p2, res)``. The outputs of all models are averaged and
    only the first six columns, matching the six SV labels, are kept.

    Returns
    -------
    dict
        ``{res: {(c1, c2, p1, p2): [sv_label, probability]}}``.
    """
    import numpy as np
    from eaglec.predictCore import get_queue, convert2TF

    SV_labels = ['++', '+-', '-+', '--', '++/--', '+-/-+']
    predictions = {}
    for data in get_queue(cache_folder, maxn=100000, pattern='collect*.pkl'):
        images = np.r_[[d[0] for d in data]]
        images = convert2TF(images, batch_size)
        coords = [d[1] for d in data]
        prob_pool = np.stack([model.predict(images) for model in models])
        prob_mean = prob_pool.mean(axis=0)[:, :6]
        for i, prob in enumerate(prob_mean):
            c1, p1, c2, p2, res = coords[i]
            maxi = prob.argmax()
            predictions.setdefault(res, {})[(c1, c2, p1, p2)] = [SV_labels[maxi], prob[maxi]]

    return predictions


def _match_predictions(sv_bychrom, predictions, resolutions, buf):
    """Pick, per input SV and resolution, the best-scoring nearby prediction.

    The same ``(2*buf + 1) x (2*buf + 1)`` bin window used to generate the
    candidates is scanned; the prediction with the highest probability wins
    (ties are broken by the larger bin coordinates, then by the label).

    Returns a list of ``(c1, p1, c2, p2, sv_label, prob, res)`` tuples; SVs
    with no prediction at a resolution contribute no row for it.
    """
    offsets = range(-buf, buf + 1)
    collect = []
    for (c1, c2), positions in sv_bychrom.items():
        for p1, p2 in positions:
            for res in resolutions:
                if res not in predictions:
                    continue
                res_preds = predictions[res]
                b1, b2 = p1 // res, p2 // res
                hits = []
                for i in offsets:
                    for j in offsets:
                        xi, yi = b1 + i, b2 + j
                        k = (c1, c2, xi, yi)
                        if k in res_preds:
                            sv_label, prob = res_preds[k]
                            hits.append((prob, xi * res, yi * res, sv_label))

                if hits:
                    prob, _, _, sv_label = max(hits)
                    collect.append((c1, p1, c2, p2, sv_label, prob, res))

    return collect


def main():
    """Evaluate a user-supplied list of SVs with the EagleC2 models.

    Steps: parse the arguments, configure logging, read the SV list,
    compute the expected values at every requested resolution, collect the
    images around each breakpoint pair, run the models, and write
    ``<output_prefix>.txt`` with the best prediction per SV and resolution.

    Intermediate files are written to a temporary folder under ``.eaglec2``
    in the current working directory; the folder is removed when the run
    ends, whether it succeeded or not.
    """
    args, commands = getargs()
    # argparse normally exits by itself on help/version; kept as a safeguard.
    if commands[0] in ('-h', '-v', '--help', '--version'):
        return

    # Heavy dependencies are imported lazily so that help/version stay fast.
    import shutil
    import cooler
    import joblib
    from eaglec.utilities import calculate_expected
    from eaglec.extractMatrix import collect_images
    from eaglec.predictCore import load_models

    # --- 0. Setup Logging ---
    logger = _setup_logging(os.path.splitext(args.output_prefix)[0] + '.log')

    arg_lines = ['# ARGUMENT LIST:']
    for key, value in sorted(vars(args).items()):
        label = key.replace('_', ' ').capitalize()
        arg_lines.append(f'# {label:<25} = {value}')
    logger.info('\n' + '\n'.join(arg_lines))

    # --- 1. Inputs ---
    resolutions = [int(r) for r in args.resolutions.split(',')]
    buf = args.search_buffer
    # name of the balancing weight passed on to the eaglec helpers
    balance = {'ICE': 'weight', 'CNV': 'sweight', 'Raw': False}[args.balance_type]

    # Read the SV list first so that a bad file fails before the expensive steps.
    sv_bychrom = _parse_sv_file(args.sv_file)
    logger.info(f"Found {sum(len(v) for v in sv_bychrom.values())} SVs to evaluate")

    clr = cooler.Cooler('{0}::resolutions/{1}'.format(args.mcool, resolutions[0]))
    chroms = list(clr.chromnames)

    # --- 2. Cache folder ---
    cache_root = os.path.abspath(os.path.expanduser('.eaglec2'))
    # exist_ok avoids a race when several runs start in the same directory
    os.makedirs(cache_root, exist_ok=True)
    stamp = time.strftime('%Y%m%d%H%M%S')
    cache_folder = tempfile.mkdtemp(suffix=stamp, dir=cache_root)

    try:
        # expected values
        expected_values = {}
        logger.info('Calculating the expected values ...')
        for res in resolutions:
            logger.info('Resolution {0}'.format(res))
            max_bins = max(200, 2000000 // res)
            clr = cooler.Cooler('{0}::resolutions/{1}'.format(args.mcool, res))
            expected_values[res] = calculate_expected(clr, chroms, balance, max_bins, nproc=args.nproc)
        joblib.dump(expected_values, os.path.join(cache_folder, 'expected_values.pkl'))

        # candidate bin pairs at each resolution
        by_res = _build_candidates(sv_bychrom, resolutions, chroms, buf)

        # collect images
        total_n = collect_images(args.mcool, by_res, expected_values, balance, cache_folder,
                                 w=15, entropy_cutoff=1, nproc=args.nproc)
        logger.info(f"Totally collected {total_n} images for evaluation.")

        # load models
        logger.info(f"Loading pre-trained models from: {args.model_path}")
        models = load_models(args.model_path)

        # model evaluation
        original_predictions = _evaluate(cache_folder, models, batch_size=args.batch_size)
    finally:
        # the cache is only needed during the run; do not let it pile up
        shutil.rmtree(cache_folder, ignore_errors=True)

    # --- 3. Match the predictions back to the input SVs ---
    collect = _match_predictions(sv_bychrom, original_predictions, resolutions, buf)
    logger.info(f"Found {len(collect)} matching SVs across all resolutions.")

    # --- 4. Save Results ---
    with open('{0}.txt'.format(args.output_prefix), 'w') as out:
        out.write('\t'.join(['chrom1', 'pos1', 'chrom2', 'pos2', 'strand', 'probability', 'resolution']) + '\n')
        for c1, p1, c2, p2, strand, prob, res in collect:
            out.write(f"{c1}\t{p1}\t{c2}\t{p2}\t{strand}\t{prob:.4g}\t{res}\n")
    logger.info("Processing complete!")


# Run the command-line workflow only when this file is executed as a script.
if __name__ == '__main__':
    main()