#!/usr/bin/env python3
"""Merge a minos VCF with a GVCF at certain positions (driven by the catalogue).
Will only include null calls from the GVCF.
"""
import argparse
import tempfile
from pathlib import Path

import grumpy

from vcf_subset import subset_vcf

def fetch_minos_positions(minos_path: Path, min_dp: int) -> set[int]:
    """Collect the positions already called in a minos VCF.

    These are the positions which should be excluded when pulling rows
    from the GVCF.

    Args:
        minos_path (Path): Path to the minos VCF file.
        min_dp (int): Minimum DP to consider a call valid (passed through to grumpy).
    Returns:
        set[int]: The positions to exclude.
    """
    minos_vcf = grumpy.VCFFile(minos_path.as_posix(), False, min_dp)
    calls_by_position = minos_vcf.calls
    excluded: set[int] = set()
    for pos in calls_by_position:
        for call in calls_by_position[pos]:
            if call.call_type != grumpy.AltType.DEL:
                excluded.add(pos)
                continue
            # A deletion covers more than one base, so exclude every base it
            # spans rather than just its start position
            excluded.update(range(pos, pos + len(call.alt)))
    return excluded

def check_gvcf_row(row: str, min_dp: int) -> bool:
    """Check if a GVCF row is just a null call (and so should be included).

    The row is written to a throwaway single-row VCF so that grumpy can parse
    it; every call grumpy reports for it must be a null call.

    Args:
        row (str): The VCF row (a trailing newline is tolerated).
        min_dp (int): Minimum DP to consider a call valid.
    Returns:
        bool: True if the row is _just_ null calls, False otherwise.
    """
    # A private temporary directory (rather than a fixed filename in the
    # working directory) means concurrent runs cannot clobber each other's
    # scratch file, and the file is removed even if parsing raises.
    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch_vcf = Path(scratch_dir) / "gvcf_row.vcf"
        with open(scratch_vcf, "w") as f:
            f.write("##fileformat=VCFv4.2\n")
            f.write("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tsample\n")
            # Normalise the line ending so a row that already ends in a
            # newline does not produce a blank record line
            f.write(row.rstrip("\n") + "\n")

        vcf = grumpy.VCFFile(scratch_vcf.as_posix(), False, min_dp)
        calls_by_position = vcf.calls
        return all(
            call.call_type == grumpy.AltType.NULL
            for position in calls_by_position
            for call in calls_by_position[position]
        )


def _headers_containing(headers, tag: str) -> list:
    """Return the header lines which contain `tag`, preserving their order."""
    return [header for header in headers if tag in header]


def _rename_dp_to_cov(fields: list) -> None:
    """Rename the FORMAT key `DP` to `COV` in place for one split VCF row.

    The GVCF doesn't have a COV-like field, so DP is used in its place. Only
    an exact `DP` key in the FORMAT column (index 8) is renamed, so other
    columns and keys which merely start with `DP` are left untouched. Rows
    which already have a `COV` key, or which have no FORMAT column, are not
    modified.
    """
    if len(fields) <= 8:
        return
    keys = fields[8].split(":")
    if "COV" in keys or "DP" not in keys:
        return
    fields[8] = ":".join("COV" if key == "DP" else key for key in keys)


def main() -> None:
    """Parse the command line and write the merged VCF to `--output`.

    The output holds the minos headers (plus any FORMAT/INFO/FILTER headers
    only present in the GVCF), every minos row, and then the GVCF rows at
    resistant positions not already called by minos which are null calls.

    Raises:
        FileNotFoundError: If any of the input files does not exist.
        FileExistsError: If the output file already exists.
        ValueError: If the minos VCF has no #CHROM header line.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--minos_vcf", help="The minos VCF filepath", required=True)
    parser.add_argument("--gvcf", help="The GVCF filepath", required=True)
    parser.add_argument("--resistant-positions", help="Path to list of resistant sites", required=True)
    parser.add_argument("--min_dp", help="Minimum DP to consider a call valid", type=int, default=3)
    parser.add_argument("--output", help="The output VCF file path", required=True)
    args = parser.parse_args()

    minos_path = Path(args.minos_vcf)
    gvcf_path = Path(args.gvcf)
    resistant_positions_path = Path(args.resistant_positions)
    output_path = Path(args.output)

    # Sanity checking arguments
    if not minos_path.exists() or not gvcf_path.exists() or not resistant_positions_path.exists():
        raise FileNotFoundError("One or more of the input files does not exist!")
    if output_path.exists():
        raise FileExistsError("Output file already exists!")

    # Read in the resistant positions (one per line), tolerating blank lines
    # such as a trailing newline at the end of the file
    with open(resistant_positions_path) as f:
        resistant_positions = {int(line) for line in f if line.strip()}

    minos_positions = fetch_minos_positions(minos_path, args.min_dp)
    to_fetch = sorted(resistant_positions - minos_positions)
    print(f"Fetching {len(to_fetch)} positions from the GVCF")

    # Pull out these positions from the GVCF
    gvcf_headers, subset = subset_vcf(gvcf_path.as_posix(), to_fetch)

    # NOTE(review): an empty position list appears to return every minos row,
    # as all of `minos_values` is copied to the output - confirm in vcf_subset
    minos_headers, minos_values = subset_vcf(minos_path.as_posix(), [])

    # Pull out header parts to catch parts which need adding
    minos_format = _headers_containing(minos_headers, "##FORMAT")
    minos_info = _headers_containing(minos_headers, "##INFO")
    minos_filter = _headers_containing(minos_headers, "##FILTER")

    missing_format = [
        header for header in _headers_containing(gvcf_headers, "##FORMAT") if header not in minos_format
    ]
    missing_info = [
        header for header in _headers_containing(gvcf_headers, "##INFO") if header not in minos_info
    ]
    missing_filter = [
        header for header in _headers_containing(gvcf_headers, "##FILTER") if header not in minos_filter
    ]

    structured_tags = ("##FORMAT", "##INFO", "##FILTER", "#CHROM")
    minos_misc_headers = [
        header for header in minos_headers if not any(tag in header for tag in structured_tags)
    ]

    chrom_lines = _headers_containing(minos_headers, "#CHROM")
    if not chrom_lines:
        raise ValueError("The minos VCF has no #CHROM header line!")
    chrom_line = chrom_lines[0]

    # Output header order: misc, merged FORMAT, merged INFO, merged FILTER,
    # then the column header line
    merged_headers = [
        *minos_misc_headers,
        *minos_format,
        *missing_format,
        *minos_info,
        *missing_info,
        *minos_filter,
        *missing_filter,
        chrom_line,
    ]

    with open(output_path, "w") as f:
        f.writelines(header + "\n" for header in merged_headers)

        # Minos rows - track their positions in a set for cheap lookups below
        written_positions = set()
        for row in minos_values:
            written_positions.add(int(row.split("\t")[1]))
            f.write(row + "\n")

        # GVCF rows
        for row in subset:
            fields = row.split("\t")
            if int(fields[1]) in written_positions:
                # We should already be filtering out positions, but in cases of dels, the start can slip through
                # Catch duplicates in these cases
                continue

            _rename_dp_to_cov(fields)

            # GVCF doesn't explicitly call filter passes, so ensure the calls are picked up
            if fields[6] == ".":
                fields[6] = "PASS"

            merged_row = "\t".join(fields)
            if check_gvcf_row(merged_row, args.min_dp):
                # Only include null calls
                f.write(merged_row + "\n")


if __name__ == "__main__":
    main()