#!python
# created on Wed Jul 16 2025

## Required modules
import argparse, sys, logging, eaglec, os

# Version string of the installed eaglec package.
# NOTE(review): not referenced anywhere else in this file as shown -- possibly
# intended for a --version option; confirm before removing.
currentVersion = eaglec.__version__

def getargs():
    """
    Build the command-line parser for filtering EagleC2 SV calls and parse sys.argv.

    Exactly one filtering strategy is required (mutually exclusive group):
    ``--res-cutoffs`` together with ``--res-list``, or ``--high-low-cutoffs``.

    ``--res-cutoffs`` / ``--res-list`` accept comma-separated values, separate
    tokens, or a mix (``0.5,0.9`` and ``0.5 0.9`` are equivalent). Their
    individual values are validated here: they must be numeric and non-empty,
    the counts must be equal, and no resolution may repeat. They are then
    normalised to a one-element list holding a single comma-joined string,
    e.g. ``['0.5,0.9']``.

    Returns:
        argparse.Namespace: the parsed and validated options.

    Exits via ``parser.error`` (status 2) on missing or inconsistent options.
    """
    parser = argparse.ArgumentParser(
        description='''Filter SV predictions from EagleC2 output based on various criteria.
                       You must choose either resolution-specific cutoffs OR high/low resolution cutoffs, not both.''',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Input/Output
    # Required: without it the input path would be None and open() would crash later.
    parser.add_argument('-i', '--input-file', required=True,
                        help='Path to the input EagleC2 SV calls file.')
    parser.add_argument('-O', '--output-prefix',
                        default='EagleC2.SV_calls.filtered',
                        help='Prefix for the output filtered SV calls file.')

    # Exactly one filtering strategy must be selected.
    filter_group = parser.add_mutually_exclusive_group(required=True)

    # Strategy 1: resolution-specific cutoffs (--res-list is validated manually below).
    filter_group.add_argument('--res-cutoffs', nargs='+', type=str,
                              help='Comma-separated list of probability cutoffs corresponding '
                                   'to resolutions specified by --res-list. E.g., "0.5,0.65,0.99"')
    parser.add_argument('--res-list', nargs='+', type=str,
                              help='Comma-separated list of resolutions to keep --res-cutoffs. '
                                   'Must be used with --res-cutoffs and have the same number of values. '
                                   'E.g., "5000,10000,50000"')

    # Strategy 2: high/low resolution cutoffs.
    filter_group.add_argument('--high-low-cutoffs', action='store_true',
                              help='Use the default high/low resolution cutoffs (--high-res-cutoff, --low-res-cutoff). '
                                   'This flag must be set to enable these cutoffs.')
    parser.add_argument('--high-res-cutoff', type=float, default=0.5,
                        help='Probability cutoff for SVs refined at resolutions <= 10000. Only used with --high-low-cutoffs.')
    parser.add_argument('--low-res-cutoff', type=float, default=0.99,
                        help='Probability cutoff for SVs refined at resolutions > 10000 '
                             'and where original and refined resolutions are the same. Only used with --high-low-cutoffs.')

    args = parser.parse_args()

    if args.res_cutoffs or args.res_list:
        if not (args.res_cutoffs and args.res_list):
            parser.error("Both --res-cutoffs and --res-list must be provided together.")

        def split_values(tokens, convert, option):
            # Flatten tokens on commas, dropping empty items left by stray commas,
            # and check every item converts with `convert`.
            values = [item.strip() for token in tokens
                      for item in token.split(',') if item.strip()]
            if not values:
                parser.error(f"{option} requires at least one value.")
            for value in values:
                try:
                    convert(value)
                except ValueError:
                    parser.error(f"Invalid value {value!r} for {option}.")
            return values

        cutoffs = split_values(args.res_cutoffs, float, '--res-cutoffs')
        resolutions = split_values(args.res_list, int, '--res-list')

        # Compare individual values, not CLI tokens: "0.5,0.9" is one token but two values.
        if len(cutoffs) != len(resolutions):
            parser.error("Number of values in --res-cutoffs must match --res-list.")
        # A repeated resolution would silently keep only its last cutoff.
        if len({int(res) for res in resolutions}) != len(resolutions):
            parser.error("Each resolution in --res-list may appear only once.")

        # One comma-joined element, so consumers that read only element [0]
        # still see every value.
        args.res_cutoffs = [','.join(cutoffs)]
        args.res_list = [','.join(resolutions)]

    return args


def parse_eaglec2_file(filepath):
    """
    Load an EagleC2 SV calls file into a dict keyed by normalised breakpoints.

    The first line is returned (trailing whitespace stripped) as the header.
    Every later line is split on tabs and parsed as follows:

    * Rows with fewer than 13 fields are skipped.
    * Rows whose gap label (column 12) is not one of
      0,0 / 0,1 / 1,0 / 1,1 / 0,2 / 2,0 are skipped.
    * In the remaining rows, columns 0-3 are chrom1, pos1, chrom2, pos2.
      Columns 4-9 are parsed as floats ('probs'). Columns 10 and 11 are the
      original and fine-mapped resolutions.

    Breakpoint pairs are ordered by (chromosome, position) so that a call
    listed as B->A gets the same key as A->B. A later row with the same key
    overwrites an earlier one.

    Args:
        filepath: path to the input file.

    Returns:
        (pool, header): ``pool`` maps
        ``(c1, c2, p1, p2, res_origin, res_finemapped)`` to a dict with keys
        'line' (the raw input line, always newline-terminated), 'probs',
        'res_origin' and 'res_finemapped'.

    Raises:
        ValueError: a numeric column of a kept row cannot be parsed; the
            message names the file and line number.

    Exits with status 1 (message on stderr) if the file does not exist.
    """
    valid_gap_labels = {'0,0', '0,1', '1,0', '1,1', '0,2', '2,0'}
    pool = {}

    try:
        source = open(filepath, 'r')
    except FileNotFoundError:
        print(f"Error: Input file '{filepath}' not found.", file=sys.stderr)
        sys.exit(1)

    with source:
        header = source.readline().rstrip()

        for line_num, line in enumerate(source, 2):  # line 1 was the header
            parts = line.rstrip().split('\t')
            if len(parts) < 13:
                continue

            # Check the gap label first: rows that are discarded anyway
            # should not abort the run on malformed numeric columns.
            gap_label = parts[12]
            if gap_label not in valid_gap_labels:
                continue

            try:
                p1, p2 = int(parts[1]), int(parts[3])
                probs = [float(x) for x in parts[4:10]]
                res_origin = int(parts[10])
                res_finemapped = int(parts[11])
            except ValueError as err:
                raise ValueError(f"{filepath}, line {line_num}: {err}") from err

            c1, c2 = parts[0], parts[2]
            # Order breakpoints (chromosome first, then position) for a stable key.
            if (c1, p1) > (c2, p2):
                c1, c2, p1, p2 = c2, c1, p2, p1

            # The file's last line may lack a newline. Add one so rows that are
            # later written out in a different order never run together.
            if not line.endswith('\n'):
                line += '\n'

            row_key = (c1, c2, p1, p2, res_origin, res_finemapped)
            pool[row_key] = {'line': line, 'probs': probs,
                             'res_origin': res_origin,
                             'res_finemapped': res_finemapped}

    return pool, header


def _split_values(tokens, convert, option):
    """Flatten CLI tokens that may each hold comma-separated values.

    ``['5000,10000']``, ``['5000', '10000']`` and ``['5000,', '10000']`` all
    give ``[5000, 10000]`` with ``convert=int``. Empty items left by stray
    commas are dropped. Exits with an error naming *option* if an item
    cannot be converted.
    """
    values = []
    for token in tokens:
        for item in token.split(','):
            item = item.strip()
            if not item:
                continue
            try:
                values.append(convert(item))
            except ValueError:
                sys.exit(f"Error: invalid value {item!r} for {option}.")
    return values


def _resolution_cutoffs(res_tokens, cutoff_tokens):
    """Pair --res-list values with --res-cutoffs values as {resolution: cutoff}.

    Exits with an error if either side is empty or the counts differ. Without
    this check, ``zip`` would silently drop the unmatched tail.
    """
    res_list = _split_values(res_tokens, int, '--res-list')
    res_cutoffs = _split_values(cutoff_tokens, float, '--res-cutoffs')
    if not res_list or not res_cutoffs:
        sys.exit("Error: --res-cutoffs and --res-list each need at least one value.")
    if len(res_list) != len(res_cutoffs):
        sys.exit("Error: Number of values in --res-cutoffs must match --res-list.")
    return dict(zip(res_list, res_cutoffs))


def _filter_res_specific(sv_pool, res_detail_cutoffs):
    """Keep SVs whose fine-mapped resolution has a cutoff that max(probs) meets.

    SVs at a fine-mapped resolution missing from *res_detail_cutoffs* are dropped.
    Returns {row_key: raw line}.
    """
    kept = {}
    for row_key, data in sv_pool.items():
        cutoff = res_detail_cutoffs.get(data['res_finemapped'])
        if cutoff is not None and max(data['probs']) >= cutoff:
            kept[row_key] = data['line']
    return kept


def _filter_high_low(sv_pool, high_res_cutoff, low_res_cutoff, res_boundary=10000):
    """Apply the high/low resolution cutoffs and return {row_key: raw line}.

    * fine-mapped resolution <= res_boundary: max(probs) must reach high_res_cutoff;
    * fine-mapped resolution > res_boundary and equal to the original
      resolution: max(probs) must reach low_res_cutoff;
    * any other SV (> res_boundary, fine-mapped to a different resolution) is
      kept without a probability check.
    """
    kept = {}
    for row_key, data in sv_pool.items():
        prob_max = max(data['probs'])
        res_finemapped = data['res_finemapped']
        if res_finemapped <= res_boundary:
            if prob_max < high_res_cutoff:
                continue
        elif res_finemapped == data['res_origin']:
            if prob_max < low_res_cutoff:
                continue
        kept[row_key] = data['line']
    return kept


def run():
    """
    Parse arguments, filter the EagleC2 SV calls and write the kept rows.

    Gap-label filtering always happens while parsing the input. On top of
    that, either the resolution-specific cutoffs (--res-cutoffs/--res-list)
    or the high/low resolution cutoffs (--high-low-cutoffs) are applied.
    Kept rows are written to ``<output-prefix>.txt``, preceded by the input
    header and sorted by their normalised key.
    """
    args = getargs()

    # Resolve the filtering strategy before reading the (possibly large) input,
    # so option errors are reported immediately.
    res_detail_cutoffs = None
    if args.res_cutoffs and args.res_list:
        res_detail_cutoffs = _resolution_cutoffs(args.res_list, args.res_cutoffs)
    elif not args.high_low_cutoffs:
        # Should not happen given getargs' required mutually exclusive group.
        sys.exit("No filtering strategy selected. Exiting.")

    sv_pool, header = parse_eaglec2_file(args.input_file)
    print(f"Loaded {len(sv_pool)} SVs")

    if res_detail_cutoffs is not None:
        print(f"Applying resolution-specific cutoffs: {res_detail_cutoffs}")
        svs_filtered = _filter_res_specific(sv_pool, res_detail_cutoffs)
    else:
        print(f"Applying general high/low resolution cutoffs (High: {args.high_res_cutoff}, Low: {args.low_res_cutoff})")
        svs_filtered = _filter_high_low(sv_pool, args.high_res_cutoff, args.low_res_cutoff)

    print(f"Found {len(svs_filtered)} SVs after applying filters.")

    output_filepath = f"{args.output_prefix}.txt"
    with open(output_filepath, 'w') as out:
        out.write(header + '\n')
        for row_key in sorted(svs_filtered):  # sorted for deterministic output
            line = svs_filtered[row_key]
            # An input last line without a trailing newline must not merge with the next row.
            out.write(line if line.endswith('\n') else line + '\n')


if __name__ == '__main__':
    # Entry point when executed as a script; importing this module does not run the filter.
    run()