#!python
# created on Wed Jul 16 2025

## Required modules
import argparse, sys, eaglec

# Version of the installed eaglec package; reported by the -v/--version flag.
currentVersion = eaglec.__version__

def getargs():
    """
    Build the command-line parser for filtering EagleC2 SV calls and parse
    ``sys.argv``.

    When no arguments are given, ``-h`` is injected so the help text is
    printed instead of an error. ``--input-file``, ``--output``,
    ``--res-cutoffs`` and ``--res-list`` are all required; if any is missing,
    argparse reports a usage error and exits. ``-h`` and ``-v`` still work on
    their own, because argparse handles them (and exits) before it checks
    required options.

    Returns
    -------
    tuple (argparse.Namespace, list of str)
        The parsed arguments and the raw argument list that was parsed.
    """
    parser = argparse.ArgumentParser(
        description='''Filters the predicted SVs based on probability values.''',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s', currentVersion]),
                        help='Print version number and exit.')

    # Input/Output. Both are required: without them the filter would try to
    # open None or write to a file literally named "None".
    parser.add_argument('-i', '--input-file', required=True,
                        help='Path to a .txt file generated by the predictSV command.')
    parser.add_argument('-o', '--output', required=True,
                        help='Output file name.')

    # Resolution-specific cutoffs. These are the only filtering criteria, so
    # both are required; run() pairs them up and checks that the lengths match.
    parser.add_argument('--res-cutoffs', nargs='+', type=str, required=True,
                        help='''A comma-separated list of probability cutoffs, each
                        corresponding to a resolution specified by --res-list.''')
    parser.add_argument('--res-list', nargs='+', type=str, required=True,
                        help='''Comma-separated list of resolutions.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # No arguments at all: show the help text rather than a
        # "required arguments missing" error.
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands


def parse_eaglec2_file(filepath):
    """
    Load an EagleC2 predictSV ``.txt`` file into a de-duplicated SV pool.

    The first line is treated as the header. Every following tab-separated
    row with at least 13 columns is read as: chrom1, pos1, chrom2, pos2,
    six probability values, the original resolution and the fine-mapped
    resolution (columns 11-12). Any further columns are kept only as part of
    the raw line. Rows with fewer than 13 columns are skipped.

    Breakpoints are normalised so that ``(chrom1, pos1) <= (chrom2, pos2)``
    (chromosomes compared as plain strings). Rows that normalise to the same
    key (which includes both resolutions) overwrite earlier ones.

    Parameters
    ----------
    filepath : str
        Path to the input file.

    Returns
    -------
    tuple (dict, str)
        ``pool`` maps ``(c1, c2, p1, p2, res_origin, res_finemapped)`` to a
        dict with keys ``'line'`` (the raw row, always newline-terminated),
        ``'probs'``, ``'res_origin'`` and ``'res_finemapped'``. ``header`` is
        the first line with trailing whitespace removed.

    Raises
    ------
    SystemExit
        With status 1 if ``filepath`` does not exist.
    ValueError
        If a row has a non-numeric position, probability or resolution. The
        message includes the file path and line number.
    """
    pool = {}
    try:
        source = open(filepath, 'r')
    except FileNotFoundError:
        print(f"Error: Input file '{filepath}' not found.", file=sys.stderr)
        sys.exit(1)

    with source:
        header = source.readline().rstrip()  # Read header

        for line_num, line in enumerate(source, 2):  # line 1 is the header
            # Strip only the line terminator: a bare rstrip() would also drop
            # trailing empty tab-separated fields and under-count columns.
            parts = line.rstrip('\r\n').split('\t')
            if len(parts) < 13:
                continue

            try:
                c1, p1 = parts[0], int(parts[1])
                c2, p2 = parts[2], int(parts[3])
                probs = [float(x) for x in parts[4:10]]
                res_origin = int(parts[10])
                res_finemapped = int(parts[11])
            except ValueError as err:
                raise ValueError(
                    f"{filepath}:{line_num}: malformed SV record: {err}"
                ) from err

            # Order the two breakpoints so that (c1, p1) <= (c2, p2); this
            # gives A->B and B->A the same key (covers both the c1 > c2 case
            # and the same-chromosome p1 > p2 case).
            if (c1, p1) > (c2, p2):
                c1, c2, p1, p2 = c2, c1, p2, p1

            if not line.endswith('\n'):
                # Last line of a file without a trailing newline; without this
                # it would be glued to the next row when rows are re-sorted.
                line += '\n'

            row_key = (c1, c2, p1, p2, res_origin, res_finemapped)
            pool[row_key] = {'line': line, 'probs': probs,
                             'res_origin': res_origin,
                             'res_finemapped': res_finemapped}

    return pool, header


def run():
    """
    Main entry point: parse arguments, filter SV calls, and write output.

    An SV is kept when the maximum of its probability values is at least the
    cutoff paired (via --res-list / --res-cutoffs) with its fine-mapped
    resolution. SVs whose fine-mapped resolution is not listed are dropped.
    The kept rows are written to --output after the input header, sorted by
    pool key so the output is deterministic.

    Missing or malformed options, and a length mismatch between --res-list
    and --res-cutoffs, print an error to stderr and exit with status 1
    before the input file is read.
    """
    args, commands = getargs()
    if commands[0] in ('-h', '-v', '--help', '--version'):
        # argparse normally handles (and exits on) these itself; this guard
        # just ensures no filtering runs for an info-only request.
        return

    if not args.input_file or not args.output:
        _fail("Error: Both --input-file and --output are required.")
    if not args.res_cutoffs or not args.res_list:
        _fail("Error: --res-cutoffs and --res-list must be given together.")

    # Use every token of the nargs='+' options, not just the first one,
    # so values are never silently ignored.
    res_list = _split_csv_tokens(args.res_list, int, '--res-list')
    res_cutoffs = _split_csv_tokens(args.res_cutoffs, float, '--res-cutoffs')
    if not res_list:
        _fail("Error: --res-list must contain at least one resolution.")
    if len(res_list) != len(res_cutoffs):
        _fail("Error: The number of resolutions and cutoffs must match.")
    res_detail_cutoffs = dict(zip(res_list, res_cutoffs))

    sv_pool, header = parse_eaglec2_file(args.input_file)
    print(f"Loaded {len(sv_pool)} SVs")

    print(f"Applying resolution-specific cutoffs: {res_detail_cutoffs}")
    svs_filtered = _filter_by_resolution(sv_pool, res_detail_cutoffs)
    print(f"Found {len(svs_filtered)} SVs after applying filters.")

    _write_filtered(args.output, header, svs_filtered)


def _fail(message):
    """Print *message* to stderr and exit with status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _split_csv_tokens(tokens, convert, option):
    """
    Convert the comma-separated values of a ``nargs='+'`` option.

    All tokens are joined with commas before splitting, so ``5000,10000``,
    ``5000 10000`` and ``5000, 10000`` all yield the same values. Empty
    fields are ignored. Exits with status 1 if ``convert`` rejects a value.
    """
    raw = ','.join(tokens)
    fields = [field.strip() for field in raw.split(',')]
    try:
        return [convert(field) for field in fields if field]
    except ValueError:
        _fail(f"Error: Invalid value in {option}: {raw}")


def _filter_by_resolution(sv_pool, res_detail_cutoffs):
    """
    Return ``{row_key: raw_line}`` for SVs whose maximum probability is at
    least the cutoff for their fine-mapped resolution. SVs at a resolution
    missing from *res_detail_cutoffs* are dropped.
    """
    kept = {}
    for row_key, data in sv_pool.items():
        cutoff = res_detail_cutoffs.get(data['res_finemapped'])
        if cutoff is not None and max(data['probs']) >= cutoff:
            kept[row_key] = data['line']
    return kept


def _write_filtered(output_filepath, header, svs_filtered):
    """Write *header* and then the filtered rows, sorted by key, to a file."""
    with open(output_filepath, 'w') as out:
        out.write(header + '\n')
        for row_key in sorted(svs_filtered):
            line = svs_filtered[row_key]
            # A final input line may lack '\n'; without this it would be
            # concatenated with the next row after sorting.
            out.write(line if line.endswith('\n') else line + '\n')

if __name__ == "__main__":
    # Launch the CLI only when executed as a script, not when imported.
    run()