#!/usr/bin/env python3

import argparse
import pysam
import os
import numpy as np


def parse_args():
    """Build the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace: parsed options with the attributes
        ``input_files`` (list of str), ``reference`` (str),
        ``min_coverage`` (int), ``offset`` (comma-separated str),
        ``patientIDs`` (comma-separated str or None), ``thrds`` (int),
        ``right_offset`` (bool) and ``outfile`` (str).

    Raises:
        SystemExit: raised by argparse when a required option (``-i`` or
            ``-r``) is missing or an option value cannot be parsed.
    """

    parser = argparse.ArgumentParser(description="Script to extract coverage windows for ShoRAH",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    required_named = parser.add_argument_group('required named arguments')
    # -i must be enforced by argparse: main() indexes input_files[0]
    # unconditionally, so a missing value would otherwise surface as an
    # opaque TypeError instead of a usage message.
    required_named.add_argument("-i", required=True, nargs='+', metavar='INPUT', dest='input_files',
                                help="File containing coverage per locus per sample. Samples are expected as columns and loci as rows. Alternatively, BAM file(s) to extract coverage")
    # The region length is only known when a fasta file is given, which is
    # why BAM input needs the fasta form of this option.
    required_named.add_argument("-r", required=True, metavar='REF', dest='reference',
                                help="Either a fasta file containing a reference sequence or the reference name of the region/chromosome of interest. A fasta file is required if BAM files are provided as input. If a fasta file is given and has more than one sequence, only the first one is considered.")

    parser.add_argument("-c", required=False, default=100, metavar='INT', dest='min_coverage',
                        type=int, help="Minimum read depth for reporting variants per locus and sample")
    parser.add_argument("-w", required=False, default='201', metavar='len1,len2,...', dest='offset',
                        help="Offset used by ShoRAH to construct overlapping windows")
    parser.add_argument("-N", required=False, default=None, metavar='name1,name2,...', dest="patientIDs",
                        help="Patient/sample identifiers as comma separated strings")
    parser.add_argument("-t", required=False, default=1, metavar='INT', dest='thrds', type=int,
                        help="Number of threads")
    parser.add_argument("-e", required=False, action='store_true', default=False, dest='right_offset',
                        help="Indicate whether a more liberal offset is to be applied on the right of each interval")
    parser.add_argument("-o", required=False, default="coverage_intervals.tsv", metavar='OUTPUT', dest='outfile',
                        help="Output file name")

    return parser.parse_args()


def get_coverage(args):
    """Tabulate the pileup depth of one BAM file per reference position.

    The single tuple parameter lets the function be used directly with
    ``Pool.map``.

    Args:
        args: tuple ``(bamfile, region_len)`` holding the path of the BAM
            file and the number of positions to allocate.

    Returns:
        numpy.ndarray: integer array of length ``region_len``; entry ``i``
        is the ``n`` value of the pileup column at position ``i`` and 0 for
        positions no pileup column reports.
    """
    bam_path, n_positions = args
    depth = np.zeros(n_positions, dtype=int)

    # NOTE(review): pileup() is called without a region, so presumably columns
    # of every reference in the BAM are visited and written by position alone
    # - confirm inputs are aligned to a single reference of length region_len.
    # NOTE(review): pysam's default pileup settings apply here - confirm they
    # are appropriate for very deep coverage.
    with pysam.AlignmentFile(bam_path, 'rb') as bam:
        for column in bam.pileup():
            depth[column.pos] = column.n

    return depth


def nonzero_intervals(x, offset, start=None, right_offset=False):
    """Locate runs of non-zero values in ``x`` and shrink them by ``offset``.

    Args:
        x: one-dimensional sequence or array of per-locus values; zeros
            separate the runs.
        offset: integer window length. Runs shorter than ``offset`` are
            dropped; the remaining runs have ``offset`` added to their start.
        start: optional coordinate of ``x[0]``; when given it is added to all
            run boundaries before the offset is applied.
        right_offset: if true, only ``int(offset * 2 / 3)`` is subtracted
            from the run end instead of the full ``offset``.

    Returns:
        numpy.ndarray: integer array of shape ``(k, 2)`` with one
        ``(start, end)`` row per retained run. ``k`` is 0 when there are no
        such runs, including when ``x`` is empty.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Same shape as the "no runs found" result below, so callers can
        # iterate over the return value unconditionally.
        return np.empty((0, 2), dtype=int)

    # Indices where x switches between zero and non-zero; +1 points at the
    # first element after each switch.
    edges, = np.nonzero(np.diff(x == 0))
    boundaries = [edges + 1]

    # A run touching either end of x has no switch there, so close it
    # explicitly.
    if x[0] != 0:
        boundaries.insert(0, [0])
    if x[-1] != 0:
        boundaries.append([x.size])
    boundaries = np.concatenate(boundaries)

    if start is not None:
        boundaries = boundaries + start

    # Boundaries alternate: run start, index one past the run end, ...
    run_starts = boundaries[::2]
    run_ends = boundaries[1::2]

    # NOTE: ShoRAH adds or subtracts 3 * (window-length / window-shift), in
    #       order to ensure that every locus in the target region is covered
    #       by 3 windows. By default, window-shift is set to 3, meaning that
    #       overlapping windows are shifted by window-length / 3. In this
    #       setting the offset is equivalent to the window length.

    # Keep only runs that are at least as long as the window length.
    keep = (run_ends - run_starts) >= offset

    # Add offset to the starting position.
    run_starts = run_starts[keep] + offset

    # Subtract offset from the ending position.
    if right_offset:
        # Often the last position with high coverage is not covered by any
        # of the windows, hence the smaller shift on the right.
        run_ends = run_ends[keep] - int(offset * 2 / 3)
    else:
        run_ends = run_ends[keep] - offset

    # NOTE(review): runs shorter than 2 * offset yield start > end here;
    # presumably the consumer re-expands intervals by the offset - confirm.
    return np.vstack((run_starts, run_ends)).T


def main():
    """Compute coverage intervals per sample and write them to a TSV file.

    Coverage is read either from a tab-separated table (first column loci,
    remaining columns samples, one header line) or computed from BAM files.
    For each sample, loci below the minimum coverage are zeroed, the
    remaining runs are converted to intervals with ``nonzero_intervals`` and
    one line ``<sample>\\t<ref>:<start>-<end>,...`` is written to the output
    file.

    Raises:
        ValueError: if BAM input is given without a fasta reference, or if
            the number of sample identifiers or offsets does not match the
            number of samples.
    """
    args = parse_args()

    # Decide whether -r names a fasta file or a bare reference name. The
    # extension has to match exactly: a substring test would treat any
    # reference name containing e.g. "fa" as a file.
    fasta_extensions = ('fasta', 'fas', 'fa')
    _, dot, reference_extension = args.reference.rpartition(".")
    is_fasta = bool(dot) and reference_extension.lower() in fasta_extensions

    if is_fasta:
        # Only the first sequence is considered; release the file handle as
        # soon as its name and length have been read.
        with pysam.FastaFile(args.reference) as reference:
            reference_name = reference.references[0]
            region_len = reference.lengths[0]
    else:
        reference_name = args.reference
        region_len = None

    # Load input file
    input_extension = args.input_files[0].split(".")[-1].lower()

    if input_extension == "tsv":
        coverage_file = args.input_files[0]
        # ndmin=2 keeps a table with a single locus two-dimensional so the
        # column slicing below keeps working.
        coverage = np.loadtxt(coverage_file, dtype=int,
                              delimiter='\t', skiprows=1, ndmin=2)

        loci = coverage[:, 0]
        coverage = coverage[:, 1:]
        num_samples = coverage.shape[1]

        if args.patientIDs is None:

            # Read patient identifiers from the input file
            with open(coverage_file, 'r') as fin:
                first_line = fin.readline()

            patientIDs = first_line.rstrip().split()
            # Comments are output with a hashtag. Remove hashtag and identifier of first
            # column, which corresponds to positions
            patientIDs = patientIDs[2:]
            if len(patientIDs) != num_samples:
                # Header does not line up with the data columns: fall back
                # to numeric identifiers.
                patientIDs = np.arange(num_samples)
        else:
            patientIDs = args.patientIDs.split(",")

    else:
        if region_len is None:
            raise ValueError(
                'A reference file should be provided when the coverage is to be determined for the reads')

        from multiprocessing import Pool

        num_samples = len(args.input_files)
        # Extract number of reads per locus, one worker task per BAM file
        args_list = [(bamfile, region_len) for bamfile in args.input_files]
        with Pool(processes=args.thrds) as pool:
            res = pool.map(get_coverage, args_list)

        # Samples as columns, loci as rows (same layout as the TSV input)
        coverage = np.vstack(res).T
        loci = np.arange(region_len)

        if args.patientIDs is None:
            patientIDs = np.arange(num_samples)
        else:
            patientIDs = args.patientIDs.split(",")

    if len(patientIDs) != num_samples:
        raise ValueError(
            'Number of patient/sample identifiers do not match number of columns in input file.')

    # One offset per sample; a single value applies to every sample.
    offsets = [int(value) for value in args.offset.split(",")]
    if len(offsets) == 1:
        offsets = offsets * num_samples
    if len(offsets) != num_samples:
        raise ValueError(
            'Number of input values do not match number of columns in input file.')

    with open(args.outfile, "wt") as outfile:

        for idx in range(num_samples):
            # Identify loci with coverage below threshold and discard those
            # read counts
            mask = coverage[:, idx] < args.min_coverage
            coverage[mask, idx] = 0

            if len(loci) == 0:
                # Nothing to report, and loci[0] would not exist.
                intervals = []
            else:
                intervals = nonzero_intervals(
                    coverage[:, idx], offsets[idx], loci[0], right_offset=args.right_offset)

            outfile.write("{}\t{}\n".format(patientIDs[idx],
                                            ','.join("{}:{}-{}".format(reference_name, x[0], x[1]) for x in intervals)))

# Run the command-line workflow only when executed as a script, not on import.
if __name__ == '__main__':
    main()
