#!python

#################
# paired_end_read_merger
#
# Description:
# Merges paired-end reads into one fused read based on their alignment.
# Can also sort unpaired and unaligned reads into separate output files.
#
# Note: Unpaired reads can be written to a separate file or by default the same output file.
#       Unaligned reads are skipped unless specified with --unaligned.
#       
# dependencies:
#  - python==3.10.12
#  - pysam==0.23.0
#  - biopython==1.83
#  - numpy==1.26.3
#  - progress==1.6
#
# Author:
# Michal Okoniewski
# Ivan Topolsky
#
#################

import argparse
import sys
import logging
from pathlib import Path
from progress.bar import Bar
from progress.spinner import Spinner

from itertools import groupby
import more_itertools
import pysam

from smallgenomeutilities._version import __version__


## SAM preparation, perhaps for a separate rule in Snakemake
## SAM file needs a header with @SQ, thus:
# samtools view -h -T reference.fasta -t reference.fasta.fai  L001.bam > L001.sam
# sort the SAM according to the QNAME
# samtools sort -O sam -n L001.sam > L001.sorted.sam

####################################################################################
## micro-API for getting back to cigar, sequence and quality from the list
####################################################################################


### reverse direction: list of 5-element tuples back into aligned read
def get_aligned_read_from_tuples(llo):
    """Rebuild sequence, qualities and CIGAR tuples from extended aligned pairs.

    Only elements 2 (base), 3 (quality) and 4 (CIGAR op) of each tuple in
    ``llo`` are read. Consecutive tuples sharing the same CIGAR op are
    collapsed into one ``(op, length)`` entry.

    Returns a 3-tuple ``(sequence string, qualities tuple, cigar tuple)``.
    """
    seq_chunks = []
    quals = ()
    cigar_ops = ()
    # one run per stretch of identical CIGAR operations
    for op, run in groupby(llo, key=lambda entry: entry[4]):
        if op == pysam.CDEL:
            # deletions count towards the CIGAR but add no bases/qualities
            run_length = sum(1 for _ in run)
            cigar_ops += ((op, run_length),)
            continue
        # TODO handle PAD

        # transpose the (base, quality) column pair of this run
        bases, run_quals = zip(*((entry[2], entry[3]) for entry in run))

        seq_chunks.append("".join(bases))
        quals += run_quals
        cigar_ops += ((op, len(bases)),)
    return "".join(seq_chunks), quals, cigar_ops


####################################################################################
## "microAPI" to operate on the lists of tuples from get_aligned_pairs()
####################################################################################


# get first valid position
def get_first_valid_pos(ll):
    """Return element 1 of the first entry of ``ll`` where it is not None.

    Returns None when ``ll`` is empty or no entry has a non-None element 1.
    """
    for entry in ll:
        ref_pos = entry[1]
        if ref_pos is not None:
            return ref_pos
    return None


# get last valid position
def get_last_valid_pos(ll):
    """Return element 1 of the last entry of ``ll`` where it is not None.

    Walks ``ll`` backwards and delegates to ``get_first_valid_pos``; ``ll``
    must therefore support ``reversed()``.
    """
    backwards = reversed(ll)
    return get_first_valid_pos(backwards)


# split into two lists at reference position
def split_before_refpos(ll, pos):
    """Split ``ll`` at most once, just before the entry whose element 1 equals ``pos``.

    Returns the iterator produced by ``more_itertools.split_before``.
    """

    def at_refpos(entry):
        return entry[1] == pos

    return more_itertools.split_before(ll, at_refpos, maxsplit=1)


def split_after_refpos(ll, pos):
    """Split ``ll`` at most once, just after the entry whose element 1 equals ``pos``.

    Returns the iterator produced by ``more_itertools.split_after``.
    """

    def at_refpos(entry):
        return entry[1] == pos

    return more_itertools.split_after(ll, at_refpos, maxsplit=1)


# advance over insertion
def advance_insert(t, liter):
    """Collect consecutive entries whose element 1 is None, starting at ``t``.

    ``t`` is the current entry (or None) and ``liter`` the iterator that
    yields the entries following it. Returns ``(next_entry, collected)``
    where ``next_entry`` is the first entry with a non-None element 1 (or
    None when the iterator ran out) and ``collected`` the skipped entries.
    """
    collected = []
    current = t
    while current is not None and current[1] is None:
        collected.append(current)
        current = next(liter, None)
    return current, collected


####################################################################################
# utility: query length from CIGAR
####################################################################################


def query_len(cigar_string):
    """
    Given a CIGAR string, return the number of bases consumed from the
    query sequence.

    Only the operations M, I, S, = and X count towards the result; every
    other operation is parsed but contributes nothing. An empty string
    yields 0.

    Raises ValueError when the string is malformed: an operation without a
    preceding length, or trailing digits without an operation.
    """
    read_consuming_ops = ("M", "I", "S", "=", "X")
    result = 0
    digits = ""
    for char in cigar_string:
        if char.isdigit():
            digits += char
            continue
        if not digits:
            raise ValueError(
                f"Malformed CIGAR {cigar_string!r}: operation {char!r} has no length"
            )
        length = int(digits)
        digits = ""
        if char in read_consuming_ops:
            result += length
    if digits:
        # a dangling length must not leak StopIteration (which would silently
        # terminate an enclosing iteration); report it explicitly instead
        raise ValueError(
            f"Malformed CIGAR {cigar_string!r}: length {digits} has no operation"
        )
    return result


####################################################################################
# main algorithm
####################################################################################


def reconcile_Overlap_tupples(ll1, ll2, name=""):
    """Merge two lists of extended aligned-pair tuples covering the same overlap.

    Both lists are walked in lockstep. Per entry pair:
      * both on the reference and both on the read: the entry with the
        strictly higher quality (element 3) is kept, ties go to ``ll2``;
      * both on the reference but one is a deletion (element 0 is None):
        the deletion is kept;
      * insertions (element 1 is None): each side is advanced independently
        until back on the reference and the strictly longer insertion is
        kept, ties go to ``ll2``.

    The walk stops as soon as either list is exhausted. ``name`` is not used
    by the active code (only by the commented-out assertions).
    """
    merged = []
    it_a, it_b = iter(ll1), iter(ll2)
    a = next(it_a, None)
    b = next(it_b, None)
    # NOTE no indel should happen at the edge of the overlap, hence we stop
    # when *either* side runs out
    while not (a is None or b is None):
        if a[1] is not None and b[1] is not None:
            # both on the reference
            # assert a[1] == b[1], f"Got out of sync at {a} and {b}"
            if a[0] is None:
                # deletion always wins
                winner = a
            elif b[0] is None:
                winner = b
            elif a[3] > b[3]:
                # plain PHRED comparison (higher is better)
                winner = a
            else:
                winner = b
            merged.append(winner)

            a = next(it_a, None)
            b = next(it_b, None)

        a_in_insert = a is not None and a[1] is None
        b_in_insert = b is not None and b[1] is None
        if a_in_insert or b_in_insert:
            # scan each side independently until it is back on the reference
            a, ins_a = advance_insert(a, it_a)
            b, ins_b = advance_insert(b, it_b)
            # longer insert wins
            merged.extend(ins_a if len(ins_a) > len(ins_b) else ins_b)
    # assert a is None, f"ll1 not consumed {name}"
    # assert b is None, f"ll2 not consumed {name}"
    return merged


### extending the get_aligned_pairs tuples into 5-element ones, with nucleotide, quality and cigar op, while skipping soft-clipping
def get_aligned_pairs_extended(r1):
    """Return the aligned pairs of ``r1`` extended with base, quality and CIGAR op.

    Each returned tuple is ``(read offset, reference position, base,
    quality, cigar op)``. Soft-clipped and PAD entries are dropped; entries
    without a read offset (deletions) carry ``None`` as base and quality.

    Raises AssertionError when an entry without read offset is not a
    deletion, or an entry without reference position is neither PAD nor an
    insertion.
    """
    bases = r1.query_sequence
    quals = r1.query_qualities
    extended = []
    for pair in r1.get_aligned_pairs(with_cigar=True):
        # read offset, reference position (both 0-based), originating cigar op
        qpos, rpos, op = pair[0], pair[1], pair[2]
        if op == pysam.CSOFT_CLIP:
            # throw out soft clip
            continue
        if qpos is not None and rpos is not None:
            extended.append((qpos, rpos, bases[qpos], quals[qpos], op))
        if qpos is None:  # DEL
            assert op == pysam.CDEL, f"R1: {r1.query_name} -- Is not a deletion: {pair}"
            extended.append((qpos, rpos, None, None, op))
        if rpos is None:  # INS
            if op == pysam.CPAD:
                # HACK we remove PADding
                # TODO support PAD for Viloca
                continue
            assert (
                op == pysam.CINS
            ), f"R1: {r1.query_name} -- Is not a insertion: {pair}"
            extended.append((qpos, rpos, bases[qpos], quals[qpos], op))
    return extended


####################################################################################
# fusion function on the level of a single alignment pair
####################################################################################


def read_fusion(r1, r2, header, qfiller=0):
    """Fuse the two aligned mates ``r1`` and ``r2`` into a single alignment.

    A new ``pysam.AlignedSegment`` is created on ``header`` with the query
    name and reference name of ``r1``, flag 0 and the larger of the two
    template lengths. The mates are then ordered by start position and
    their extended aligned-pair lists are combined:

      * ends meet exactly: the lists are concatenated;
      * overlap: the overlapping stretch is merged with
        ``reconcile_Overlap_tupples``;
      * gap: the gap is filled with "N" bases of quality ``qfiller``.

    When the first reference-aligned position of the leftmost mate is not
    strictly before that of the other mate, no fusion is attempted: the
    mate with the longer query is returned instead, with its flag set to 0
    (the input read object itself is modified and returned).

    In every case the returned segment gets the lower of both mapping
    qualities and its mate fields are reset.
    """
    outr = pysam.AlignedSegment(header=header)
    outr.query_name = r1.query_name
    outr.flag = (
        0  # not paired, super nice alignment ;) we ignore for now existing flags
    )
    outr.reference_name = r1.reference_name
    outr.template_length = max(
        r1.template_length, r2.template_length
    )  # the reverse one is negative, (with most sequencers/aligners)
    # geometry
    if r1.pos > r2.pos:  # swap so that r1 is the mate starting first
        (r1, r2) = (r2, r1)

    outr.pos = r1.pos
    # gap start and end
    # gs - the base after the R1 alignment
    ll1 = get_aligned_pairs_extended(r1)
    rst1 = get_first_valid_pos(ll1)
    gs = get_last_valid_pos(ll1) + 1
    # ge - the base before the R2
    ll2 = get_aligned_pairs_extended(r2)
    rst2 = get_first_valid_pos(ll2)
    ge = rst2 - 1

    # print(f"R1: {r1.query_name}\tgap start and end: {gs} {ge}", file=sys.stderr)
    if rst1 < rst2:  # checking the beginnings of lists
        llo = []
        if (ge - gs) == -1:  # no gap in fact, ends meet
            llo = ll1 + ll2
        elif (ge - gs) < -1:  # when no gap, but an overlap
            # NOTE no indel should happen at the edge of the overlap
            # NOTE(review): both unpackings need the split helpers to yield
            # exactly two lists. That presumably does not hold when R2 ends
            # at or before the last position of R1 (gs - 1 is then the last
            # entry of ll2, or absent from it) -- confirm against the
            # more_itertools version in use.
            llbef, llo1 = split_before_refpos(ll1, rst2)
            llo2, llaft = split_after_refpos(ll2, gs - 1)

            llo = llbef + reconcile_Overlap_tupples(llo1, llo2, r1.query_name) + llaft
        else:  # proper gap
            llo = (
                ll1
                + [
                    # filler entry: element 0 holds the reference position of
                    # the gap base and element 1 a dummy -1; only base,
                    # quality and cigar op are read back when rebuilding
                    (p, -1, "N", qfiller, pysam.CMATCH)
                    for p in range(gs, rst2)
                ]
                + ll2
            )
        # tuple assignment goes left to right: sequence, then qualities, then cigar
        outr.query_sequence, outr.query_qualities, outr.cigartuples = (
            get_aligned_read_from_tuples(llo)
        )
    else:  # first aligned base of r2 is not after the first aligned base of r1
        if r1.query_length > r2.query_length:
            r1.flag = 0
            outr = (
                r1  # brutally returning just R1, likely soft clipping seriously trimmed
            )
        else:
            r2.flag = 0
            outr = (
                r2  # brutally returning just R2, likely soft clipping seriously trimmed
            )

    # applies to the fused segment as well as to a returned original mate
    outr.mapping_quality = min(r1.mapping_quality, r2.mapping_quality)
    # the result is a single read: clear the mate information
    outr.next_reference_name = "*"
    outr.next_reference_start = 0
    outr.pnext = 0
    # print(r1.qname + "  " + r2.qname, file=sys.stderr)
    # print(outr.seq, file=sys.stderr)
    return outr


def _write_single_read(read, sam_unpaired, sam_unaligned):
    """Write a read that has no mate to the matching output file.

    A read with a reference end goes to ``sam_unpaired``; a read without
    one goes to ``sam_unaligned`` when that output exists and is dropped
    otherwise. Returns ``(unpaired_written, unaligned_written)`` as 0/1
    counts.
    """
    if read.reference_end is not None:
        sam_unpaired.write(read)
        return 1, 0
    if sam_unaligned:
        sam_unaligned.write(read)
        return 0, 1
    return 0, 0


def fuse_reads(fname, fname_sam_fused_output, fname_ref, qfiller=0, fname_unpaired=None, fname_unaligned=None):
    """Fuse all read pairs of a name-sorted SAM file and write the result.

    Consecutive reads with the same query name are treated as a pair and
    merged with ``read_fusion``.

    Parameters:
      fname: input SAM file, or "-" for stdin.
      fname_sam_fused_output: output SAM file for fused pairs, or "-" for
        stdout.
      fname_ref: accepted for interface compatibility; not used by this
        function.
      qfiller: quality given to the "N" bases that fill a gap between mates.
      fname_unpaired: output SAM file for aligned reads without a mate;
        None writes them to the main output.
      fname_unaligned: output SAM file for unaligned reads; when not given
        these reads are skipped.

    Raises AssertionError when the CIGAR of a fused read does not account
    for the length of its sequence.

    All files opened here are closed before returning, also when an error
    interrupts the processing.
    """
    logging.info(f"Starting processing {fname}")
    ifp = open(fname, "r") if fname != "-" else sys.stdin
    samfile = None
    sam_out = None
    sam_unpaired = None
    sam_unaligned = None
    progress_bar = None

    c = 0
    ooo = 0
    unal = 0
    unpaired = 0
    prev = None

    try:
        samfile = pysam.AlignmentFile(ifp, "r")
        sam_out = pysam.Samfile(
            fname_sam_fused_output if fname_sam_fused_output != "-" else sys.stdout,
            "w",
            header=samfile.header,
        )

        # Setup unpaired output - default to same as main output if not specified
        if fname_unpaired is None:
            sam_unpaired = sam_out
        else:
            sam_unpaired = pysam.Samfile(
                fname_unpaired,
                "w",
                header=samfile.header,
            )

        # Setup unaligned output if specified
        if fname_unaligned:
            sam_unaligned = pysam.Samfile(
                fname_unaligned,
                "w",
                header=samfile.header,
            )

        # Progress is tracked in bytes of input when the size is known,
        # with a spinner as fallback (stdin or empty file)
        isize = 0
        if logging.getLogger().level <= logging.INFO:
            if fname != "-":
                isize = Path(fname).stat().st_size
                if isize > 0:
                    progress_bar = Bar(
                        "Processing read pairs",
                        max=isize,
                        suffix="%(percent)u%% (%(index)u bytes) [%(elapsed_td)s/%(eta_td)s]",
                    )
            if progress_bar is None:
                progress_bar = Spinner("Processing read pairs ")

        for read in samfile.fetch():
            # not currently holding a previous pair member
            if prev is None:
                prev = read
                continue

            # out-of-order or unpaired read (different name)
            if prev.query_name != read.query_name:
                logging.debug(f"Read {prev.query_name} has no pair, writing as unpaired")
                wrote_unpaired, wrote_unaligned = _write_single_read(
                    prev, sam_unpaired, sam_unaligned
                )
                unpaired += wrote_unpaired
                unal += wrote_unaligned

                # the current read may still be the first mate of the next pair
                prev = read
                ooo += 1
                continue

            # for properly aligned, skipping unaligned
            if read.reference_end is None or prev.reference_end is None:
                logging.debug(f"Read pair {prev.query_name} has at least one unaligned mate")

                # Only the unaligned mate(s) are written, and only when an
                # output was requested; an aligned mate of such a pair is
                # not written anywhere
                if sam_unaligned:
                    if prev.reference_end is None:
                        sam_unaligned.write(prev)
                    if read.reference_end is None:
                        sam_unaligned.write(read)

                unal += 1
                prev = None
                continue

            fused = read_fusion(r1=prev, r2=read, header=samfile.header, qfiller=qfiller)
            c += 1
            # sanity check: the rebuilt CIGAR must cover the rebuilt sequence
            cl = query_len(fused.cigarstring)
            ql = fused.query_length
            assert (
                cl == ql
            ), f"CIGAR length and query length not matching {fused.cigarstring} : {fused.query_sequence}"
            sam_out.write(fused)

            # refresh the progress display every 1024 fused pairs
            if progress_bar and c & 0x3FF == 0:
                if isize > 0:
                    progress_bar.goto(ifp.tell())
                else:
                    progress_bar.next()

            prev = None

        # Handle the last read if it exists
        if prev is not None:
            logging.debug(f"Final read {prev.query_name} has no pair, writing as unpaired")
            wrote_unpaired, wrote_unaligned = _write_single_read(
                prev, sam_unpaired, sam_unaligned
            )
            unpaired += wrote_unpaired
            unal += wrote_unaligned
    finally:
        if progress_bar:
            progress_bar.finish()

        # Release every handle opened above; the unpaired output may be the
        # main output itself and must then be closed only once
        if sam_unaligned is not None:
            sam_unaligned.close()
        if sam_unpaired is not None and sam_unpaired is not sam_out:
            sam_unpaired.close()
        if sam_out is not None:
            sam_out.close()
        if samfile is not None:
            samfile.close()
        if ifp is not sys.stdin:
            ifp.close()

    logging.info(f"Finished fusion: {c} pairs fused")
    logging.info(f"{unpaired} unpaired reads written")
    if ooo:
        logging.warning(f"{ooo} reads not grouped in pairs")
        logging.warning(
            "Either no pair exist or reads were not sorted alphabetically by name"
        )
    if unal:
        logging.warning(f"{unal} unaligned reads")
        if fname_unaligned:
            logging.info(f"Unaligned reads written to {fname_unaligned}")
        else:
            logging.info("Unaligned reads were skipped (use --unaligned to write them)")


def parse_args(argv=None):
    """Set up the parsing of command-line arguments.

    ``argv`` is the list of argument strings to parse; when None (the
    default) argparse falls back to the arguments of the running process.

    Returns the populated ``argparse.Namespace``. The positional input file
    is stored under ``FILE`` as a one-element list.
    """
    parser = argparse.ArgumentParser(
        description="Merges paired-end reads to one fused reads based on alignment.",
        epilog="SAM file need to be sorted by QNAME (not POS) and need @SQ in header:\n\n"
        "samtools view -h -T reference.fasta -t reference.fasta.fai  L001.bam > L001.sam &&\n\n"
        "samtools sort -O sam -n L001.sam > L001.sorted.sam",
        # keep the line breaks of the epilog so that the two example
        # commands are not re-wrapped into a single paragraph
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "-v",
        "--version",
        action="version",
        version=f"%(prog)s {__version__}",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose debugging output",
    )
    parser.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="Suppress all output except errors",
    )
    parser.add_argument(
        "-f",
        "--ref",
        metavar="FASTA",
        required=False,
        default=None,
        type=str,
        dest="reference",
        help="reference file used during alignment",
    )
    parser.add_argument(
        "-qn",
        "--quality-n",
        metavar="PHRED",
        required=False,
        default=0,
        type=int,
        dest="qfiller",
        help="PHRED quality to use when filling gap with n. (e.g. 0 or 2)",
    )
    parser.add_argument(
        "-o",
        "--output",
        metavar="SAM",
        required=False,
        default="-",
        type=str,
        dest="output",
        help="file to write merged read-pairs to",
    )
    parser.add_argument(
        "--unpaired",
        metavar="SAM",
        required=False,
        default=None,
        type=str,
        dest="unpaired",
        help="file to write unpaired reads to (defaults to same as output)",
    )
    parser.add_argument(
        "--unaligned",
        metavar="SAM",
        required=False,
        default=None,
        type=str,
        dest="unaligned",
        help="file to write unaligned reads to (if not specified, unaligned reads are skipped)",
    )
    parser.add_argument(
        "FILE", nargs=1, metavar="SAM", help="input SAM file (sorted by QNAME)"
    )

    return parser.parse_args(argv)


def main():
    """Command-line entry point: parse arguments, set up logging, fuse reads."""
    args = parse_args()

    # --quiet takes precedence over --verbose; INFO is the default level
    if args.quiet:
        log_level = logging.ERROR
    elif args.verbose:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    logging.basicConfig(format="%(levelname)s: %(message)s", level=log_level)

    # FILE is declared with nargs=1, hence a one-element list
    input_sam = args.FILE[0]
    fuse_reads(
        fname=input_sam,
        fname_sam_fused_output=args.output,
        fname_ref=args.reference,
        qfiller=args.qfiller,
        fname_unpaired=args.unpaired,
        fname_unaligned=args.unaligned,
    )


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
