#!/usr/bin/env python3

import argparse
import pysam
import os
import numpy as np


def parse_args():
    """Build the command-line interface and return the parsed arguments.

    Two options are mandatory (``-i`` coverage table, ``-r`` reference); the
    remaining ones (``-c``, ``-w``, ``-N``, ``-o``) are optional and fall back
    to the defaults declared below. Arguments are read from ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        Namespace with the attributes ``coverage``, ``reference``,
        ``min_coverage``, ``offset``, ``patientIDs`` and ``outfile``.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Script to extract coverage windows for ShoRAH",
    )

    # Mandatory options live in their own group so that the help output lists
    # them separately from the optional ones.
    mandatory = cli.add_argument_group('required named arguments')
    mandatory.add_argument(
        "-i", dest='coverage', metavar='INPUT', required=True,
        help="File containing samples as columns and loci as rows",
    )
    mandatory.add_argument(
        "-r", dest='reference', metavar='REF', required=True,
        help="Either a fasta file containing a reference sequence or the reference name of the region/chromosome of interest. If a fasta file is given and has more than one sequence, only the first one is considered.",
    )

    # Optional options, registered in the order in which they appear in the
    # help output.
    optional_flags = (
        ("-c", {"dest": 'min_coverage', "metavar": 'INT', "type": int,
                "default": 100,
                "help": "Minimum read depth for reporting variants per locus and sample"}),
        ("-w", {"dest": 'offset', "metavar": 'len1,len2,...',
                "default": '201',
                "help": "Offset used by ShoRAH to construct overlapping windows"}),
        ("-N", {"dest": "patientIDs", "metavar": 'name1,name2,...',
                "default": None,
                "help": "Patient/sample identifiers as comma separated strings"}),
        ("-o", {"dest": 'outfile', "metavar": 'OUTPUT',
                "default": "coverage_intervals.tsv",
                "help": "Output file name"}),
    )
    for flag, options in optional_flags:
        cli.add_argument(flag, required=False, **options)

    return cli.parse_args()


def nonzero_intervals(x, offset, start=None):
    """Return the runs of non-zero entries of *x*, trimmed by *offset*.

    Every maximal run of consecutive non-zero values in ``x`` is first
    expressed as a half-open pair of array indices ``[begin, end)``. Runs
    shorter than ``offset`` are discarded; for the remaining ones ``offset``
    is added to ``begin`` and subtracted from ``end``.

    Parameters
    ----------
    x : array-like
        One-dimensional sequence of values; a value of zero marks a position
        that does not belong to any run.
    offset : int
        Minimum run length, and amount by which each kept run is trimmed at
        both ends.
    start : int, optional
        Value added to all run boundaries before trimming, e.g. the
        coordinate that corresponds to index 0 of ``x``.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(n, 2)`` with one ``(begin, end)`` row per
        kept run. The array has shape ``(0, 2)`` when ``x`` is empty or
        contains no run of at least ``offset`` entries.
    """
    x = np.asarray(x)

    # Always hand back an (n, 2) array so callers can iterate over the result
    # without special-casing empty input.
    if x.size == 0:
        return np.empty((0, 2), dtype=int)

    # Positions at which the sequence switches between zero and non-zero;
    # adding 1 yields the index of the first element after each switch.
    is_zero = x == 0
    edges, = np.nonzero(np.diff(is_zero))
    intervals = [edges + 1]

    # A run touching either end of the array has no switch marking it, so its
    # outer boundary is added explicitly.
    if not is_zero[0]:
        intervals.insert(0, [0])
    if not is_zero[-1]:
        intervals.append([x.size])
    intervals = np.concatenate(intervals)

    if start is not None:
        intervals += start

    # Boundaries alternate between the beginning and the end of a run.
    intervals_start = intervals[::2]
    intervals_end = intervals[1::2]

    # NOTE: ShoRAH adds or subtracts 3 * (window-length / window-shift), in
    #       order to ensure that every locus in the target region is covered by
    #       3 windows. By default, window-shift is set to 3, meaning that
    #       overlapping windows are shifted by window-length / 3. In this
    #       setting the offset is equivalent to the window length.

    # Keep only the runs that are at least as long as the window length.
    # Boolean indexing copies, so the in-place trimming below does not touch
    # `intervals`.
    length = intervals_end - intervals_start
    keep = length >= offset
    intervals_start = intervals_start[keep]
    intervals_end = intervals_end[keep]

    # NOTE(review): a run with offset <= length < 2 * offset ends up with
    #       begin > end after trimming. Confirm whether such runs should be
    #       filtered out as well.
    intervals_start += offset
    intervals_end -= offset

    return np.vstack((intervals_start, intervals_end)).T


# MAIN PROGRAM
args = parse_args()

# Get name of the reference. A value whose extension is "fasta" is opened as a
# FASTA file and the name of its first sequence is used; any other value is
# taken verbatim as the reference name.
reference_extension = args.reference.split(".")[-1].lower()

if reference_extension == "fasta":
    # Only the sequence name is needed, so release the file handle right away.
    with pysam.FastaFile(args.reference) as reference:
        reference_name = reference.references[0]
else:
    reference_name = args.reference

# Load input file. The first row is skipped as a header, the first column
# holds the loci and every remaining column holds the read depth of one
# sample. ndmin=2 keeps the table two-dimensional even if it has a single row.
coverage = np.loadtxt(args.coverage, dtype=int, delimiter='\t', skiprows=1,
                      ndmin=2)

loci = coverage[:, 0]
coverage = coverage[:, 1:]
num_samples = coverage.shape[1]

if args.patientIDs is None:

    # Read patient identifiers from the input file
    with open(args.coverage, 'r') as fin:
        first_line = fin.readline()

    patientIDs = first_line.rstrip().split()
    # Comments are output with a hashtag. Remove hashtag and identifier of first
    # column, which corresponds to positions
    patientIDs = patientIDs[2:]
    # Fall back to numeric identifiers if the header cannot be matched to the
    # sample columns.
    if len(patientIDs) != num_samples:
        patientIDs = np.arange(num_samples)
else:
    patientIDs = args.patientIDs.split(",")

if len(patientIDs) != num_samples:
    raise ValueError(
        'Number of patient/sample identifiers do not match number of columns in input file.')

# Offsets are converted up front so that a malformed value is reported before
# the output file is created.
offset = [int(value) for value in args.offset.split(",")]

# A single offset (e.g. the default) applies to every sample.
if len(offset) == 1 and num_samples > 1:
    offset = offset * num_samples

if len(offset) != num_samples:
    raise ValueError(
        'Number of input values do not match number of columns in input file.')


with open(args.outfile, "wt") as outfile:

    for idx, (sample_id, sample_offset) in enumerate(zip(patientIDs, offset)):
        # Identify loci with coverage below threshold and discard those read
        # counts
        mask = coverage[:, idx] < args.min_coverage
        coverage[mask, idx] = 0

        # Each sample is trimmed with its own offset. Row indices are shifted
        # by the first locus to obtain coordinates.
        # NOTE(review): this assumes the loci in the input are consecutive;
        #       confirm against the tool that writes the coverage table.
        intervals = nonzero_intervals(
            coverage[:, idx], sample_offset, loci[0])

        outfile.write("{}\t{}\n".format(sample_id,
                                        ','.join("{}:{}-{}".format(reference_name, x[0], x[1]) for x in intervals)))
