#!/usr/bin/env python3

import pysam
import argparse
import os
import numpy as np
from multiprocessing import Pool

'''
input:  alignment file(s) as BAM file(s) and either a reference sequence in
        FASTA format or, if only one genetic region is of interest, the name of
        the reference. An index for each BAM file should exist in the same
        directory as the corresponding BAM file.
output: tab-separated file with minority alleles as rows and samples as
        columns. Loci are reported using 0-based indexing.
        The majority base/amino acid per locus is identified by summing
        relative abundances across samples.
'''


class AlignedRead():
    """
    Wrapper around an aligned read exposing one character per reference
    position spanned by the alignment.
    """

    def __init__(self, read):
        # presumably a pysam.AlignedSegment; anything exposing cigartuples,
        # query_sequence, reference_start and reference_end works
        self.read = read

    def get_alignment_sequence(self):
        """
        Read sequenced bases over the reference span of the alignment.

        Soft-clipped bases and insertions are excluded. Deletions w.r.t. the
        reference are reported as '-', and reference skips (CIGAR 'N') as
        'N', so that the result has exactly one character per position
        returned by get_alignment_positions().
        """
        read = self.read
        query = read.query_sequence
        pieces = []
        idx = 0

        # cigartuples: cigar string encoded as a list of (operation, length)
        for op, length in read.cigartuples:
            if op in (0, 7, 8):
                # Alignment match (M), sequence match (=) or mismatch (X):
                # consume both query and reference
                pieces.append(query[idx:idx + length])
                idx += length
            elif op == 2:
                # Deletion (D): consumes reference only, so the query index
                # is not advanced
                pieces.append('-' * length)
            elif op == 3:
                # Reference skip (N): consumes reference only. The read does
                # not cover these loci; 'N' is outside the counted alphabet
                # and is therefore filtered out downstream, while keeping the
                # sequence aligned with get_alignment_positions()
                pieces.append('N' * length)
            elif op in (1, 4):
                # Insertion (I) or soft clip (S): consume query only
                idx += length
            # Hard clips (5) and padding (6) consume neither query nor
            # reference and are ignored

        return ''.join(pieces)

    def get_alignment_positions(self):
        """0-based reference positions in [reference_start, reference_end)."""
        return np.arange(self.read.reference_start, self.read.reference_end)


class CheckPath(argparse.Action):
    """
    argparse action for an output directory: creates it (including parents)
    when missing and checks that it is a writable directory.

    Errors are raised as argparse.ArgumentError so that argparse reports them
    as a usage error instead of a traceback.
    """

    def __call__(self, parser, namespace, values, option_string=None):

        path = values

        if not os.path.exists(path):
            try:
                # exist_ok guards against the directory appearing between
                # the existence check and the creation
                os.makedirs(path, exist_ok=True)
            except OSError as err:
                raise argparse.ArgumentError(
                    self, "CheckPath:{0} could not be created ({1})".format(path, err)) from err

        # An existing regular file must not be accepted as output directory
        if not os.path.isdir(path) or not os.access(path, os.W_OK):
            raise argparse.ArgumentError(
                self, "CheckPath:{0} is not a writable directory".format(path))

        setattr(namespace, self.dest, path)


def parse_args(argv=None):
    """
    Set up the parsing of command-line arguments and parse them.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. When None (the default), argparse falls back to
        sys.argv[1:], so calling parse_args() behaves as before.

    Returns
    -------
    argparse.Namespace
        Parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Script to extract minority alleles per sample",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    requiredNamed = parser.add_argument_group('required named arguments')
    requiredNamed.add_argument("-r", required=True, metavar='FASTA', dest='reference',
                               help="Either a fasta file containing a reference sequence or the reference name of the region/chromosome of interest. The latter is expected if a region is specified")
    parser.add_argument("-s", "--start", required=False, default=None, metavar='INT', dest='start',
                        type=int, help="Starting position of the region of interest, 0-based indexing")
    parser.add_argument("-e", "--end", required=False, default=None, metavar='INT', dest='end',
                        type=int, help="Ending position of the region of interest, 0-based indexing. Note a half-open interval is used, i.e., [start:end)")
    parser.add_argument("-p", "--config", required=False, default=None, metavar='file.config', dest='frames',
                        help="Report minority aminoacids - a .config file specifying reading frames expected")
    parser.add_argument("-c", required=False, default=100, metavar='INT', dest='min_coverage',
                        type=int, help="Minimum read depth for reporting variants per locus and sample")
    parser.add_argument("-N", required=False, default=None, metavar='name1,name2,...', dest="patientIDs",
                        help="Patient/sample identifiers as comma separated strings")
    parser.add_argument("-t", required=False, default=1, metavar='INT', dest='thrds', type=int,
                        help="Number of threads")
    parser.add_argument("-f", "--freqs", required=False, action='store_true', dest='freqs',
                        help="Indicates whether or not all frequencies should be stored")
    parser.add_argument("-d", required=False, action='store_true', dest='coverage',
                        help="Indicates whether coverage per locus should be output")
    # Note: the default is not passed through CheckPath (argparse only runs
    # actions for values given on the command line)
    parser.add_argument("-o", required=False, default=os.getcwd(), action=CheckPath, metavar='PATH', dest='outdir',
                        help="Output directory")
    parser.add_argument("FILES", nargs='+', metavar='BAM', help="BAM file(s)")

    return parser.parse_args(argv)


def ascii2idx(sequence):
    """
    Replace, in place, the ASCII codes of the nucleotide alphabet by their
    alphabet index, and return the same array.

        symbol  ascii  index
        'A'     65     0
        'C'     67     1
        'G'     71     2
        'T'     84     3
        '-'     45     4

    Codes of any other symbol are left untouched.
    """
    for index, symbol in enumerate('ACGT-'):
        sequence[sequence == ord(symbol)] = index
    return sequence


def get_counts(args):
    """
    Count alphabet symbols per reference locus for a single BAM file.

    Parameters
    ----------
    args : tuple
        (bamfile, reference_name, start, end, region_len, alphabet_len),
        packed in a single tuple so the function can be used with Pool.map.
        `start` and `end` are 0-based and both inclusive (main passes
        args.end - 1); both are None when the whole reference is processed.

    Returns
    -------
    numpy.ndarray
        Flat float array of length region_len * alphabet_len. Entry
        (locus * alphabet_len + symbol_index) is the number of reads showing
        that symbol at that locus, with loci relative to `start` when a
        region is given.
    """
    bamfile, reference_name, start, end, region_len, alphabet_len = args

    nt_counts = np.zeros(shape=(region_len * alphabet_len))
    # pysam's fetch() takes a half-open interval [start, end) while `end` is
    # inclusive here; without +1, reads overlapping only the last locus of the
    # region would be skipped.
    fetch_end = None if end is None else end + 1
    # Counts are indexed relative to the start of the region
    offset = 0 if start is None else start

    with pysam.AlignmentFile(bamfile, 'rb') as alnfile:

        for read in alnfile.fetch(reference=reference_name, start=start, end=fetch_end):
            # Unmapped reads (which fetch() may return when they are placed at
            # their mate's position) have no CIGAR, and records without a
            # stored SEQ have no bases: neither can be counted.
            if read.is_unmapped or read.query_sequence is None:
                continue

            # Fetch returns all reads which cover a specific region. However,
            # all their positions - including positions outside the region
            # of interest - are returned
            aligned_read = AlignedRead(read)
            alignment_positions = aligned_read.get_alignment_positions()
            alignment_sequence = np.array(
                aligned_read.get_alignment_sequence(), dtype='c').view(np.uint8)

            if start is not None and end is not None:
                # Keep only positions inside the region [start, end]
                inside = (alignment_positions >= start) & (
                    alignment_positions <= end)
                alignment_positions = alignment_positions[inside]
                alignment_sequence = alignment_sequence[inside]

            alignment_sequence = ascii2idx(alignment_sequence)

            # Drop symbols that are not in the alphabet (e.g. 'N'); their codes
            # were left untouched by ascii2idx and are >= alphabet_len
            in_alphabet = alignment_sequence < alphabet_len
            alignment_positions = alignment_positions[in_alphabet]
            alignment_sequence = alignment_sequence[in_alphabet]

            # Positions within one read are unique, so fancy-index += is exact
            idxs_array = (alignment_positions - offset) * \
                alphabet_len + alignment_sequence
            nt_counts[idxs_array] += 1

    return nt_counts


def main():
    """
    Command-line entry point.

    Counts alphabet symbols per locus in every input BAM file (in parallel,
    using a pool of -t worker processes), converts counts into per-sample
    relative frequencies and writes to the output directory:
      - minority_variants.tsv: non-majority symbols with a non-zero frequency
        in at least one sample (NaN for samples below -c coverage at that
        locus);
      - cohort_consensus.fasta: the majority symbol per covered locus;
      - frequencies.npy (with -f): all symbol frequencies per covered locus;
      - coverage.tsv (with -d): raw read depth per locus and sample.
    All loci are reported as absolute 0-based positions.
    """
    args = parse_args()

    # Validate the region of interest. User input is checked with explicit
    # exits rather than assert, which is removed under `python -O`.
    if args.start is not None and args.end is None:
        raise SystemExit('Minority variants are extracted from a region of interest. An ending position was expected')

    if args.end is not None and args.start is None:
        print('Starting position was expected. Setting it to 0')
        args.start = 0

    if args.start is not None and args.end <= args.start:
        raise SystemExit('Ending position must be larger than starting position')

    if args.frames is None:
        # Nucleotides - Alphabet = {A, C, G, T, -}, including deletions
        alphabet = np.array(['A', 'C', 'G', 'T', '-'])
    else:
        alphabet = np.array(['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
                             'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'])
        print("Option not implemented yet")
        raise SystemExit(0)
    alphabet_len = alphabet.size

    num_samples = len(args.FILES)

    if args.patientIDs is not None:
        patientIDs = args.patientIDs.split(",")
        if len(patientIDs) != num_samples:
            raise SystemExit('Number of patient/sample identifiers do not match number of BAM files')

    if args.start is None:
        # Whole reference: only one reference sequence expected in the FASTA
        start = end = None
        reference = pysam.FastaFile(args.reference)
        try:
            if reference.nreferences != 1:
                raise SystemExit('Only one reference sequence expected')
            region_len = reference.lengths[0]
            reference_name = reference.references[0]
            cohort_consensus = reference.fetch(
                reference=reference_name).lower()
        finally:
            reference.close()
    else:
        start = args.start
        end = args.end - 1  # inclusive last position of the region
        region_len = end - start + 1
        reference_name = args.reference
        cohort_consensus = 'n' * region_len

    # Offset converting absolute positions into row indices of the counts
    offset = 0 if start is None else start

    args_list = [(bamfile, reference_name, start, end, region_len,
                  alphabet_len) for bamfile in args.FILES]
    with Pool(processes=args.thrds) as pool:
        res = pool.map(get_counts, args_list)

    # Rows are (locus, symbol) pairs: row = locus * alphabet_len + symbol;
    # columns are samples
    nt_counts = np.vstack(res).T
    counts_3d = nt_counts.reshape(region_len, alphabet_len, num_samples)

    coverage_raw = counts_3d.sum(axis=1).astype(int)
    # Samples with coverage below threshold at a locus: discard their counts
    low_cov = coverage_raw < args.min_coverage
    coverage = np.where(low_cov, 0, coverage_raw)
    freqs_3d = np.zeros_like(counts_3d)
    np.divide(counts_3d, coverage[:, np.newaxis, :], out=freqs_3d,
              where=~low_cov[:, np.newaxis, :])
    nt_freqs = freqs_3d.reshape(region_len * alphabet_len, num_samples)

    pos = np.arange(region_len) + offset

    # Exclude loci for which all samples report zero counts
    loci = pos[np.sum(coverage, axis=1) > 0]
    variant_loci = np.zeros(shape=loci.size * alphabet_len, dtype=bool)
    minor_variants_freqs = np.zeros(shape=(variant_loci.size, num_samples))

    # Write matrix of frequencies per locus and per sample
    if args.freqs:
        symbols_tile = np.tile(alphabet, loci.size)
        loci_tile = np.repeat(loci, alphabet_len)
        rows = ((loci_tile - offset) * alphabet_len +
                np.tile(np.arange(alphabet_len), loci.size))
        aux = nt_freqs[rows]

        type_string = 'i8,U8,' + ','.join('f8' for _ in range(num_samples))
        out = np.zeros(loci_tile.size, dtype=type_string)
        fields = out.dtype.names
        # Absolute positions, consistent with the other output files
        out[fields[0]] = loci_tile
        out[fields[1]] = symbols_tile
        for j_idx in range(num_samples):
            out[fields[2 + j_idx]] = aux[:, j_idx]

        np.save(os.path.join(args.outdir, 'frequencies.npy'), out)

    # Mutable copy of the consensus; avoids rebuilding the string per locus
    consensus = list(cohort_consensus)

    # Extract minority variants
    for i, locus in enumerate(loci):
        rel_locus = locus - offset

        # Frequencies per symbol at this locus, for all samples
        idx_locus = rel_locus * alphabet_len
        nt_freqs_locus = nt_freqs[idx_locus:(idx_locus + alphabet_len), ]

        # Identify variants: store 'True' if the sum of symbol frequencies
        # across samples is larger than 0
        idx_array = i * alphabet_len
        rows_i = slice(idx_array, idx_array + alphabet_len)
        variant_loci[rows_i] = np.sum(nt_freqs_locus, axis=1) > 0
        minor_variants_freqs[rows_i, ] = nt_freqs_locus

        # Samples for which the current locus doesn't reach 'min_coverage'
        mask = coverage[rel_locus, ] == 0
        minor_variants_freqs[rows_i, mask] = np.nan

        # Majority symbol: highest summed frequency over covered samples
        idx_major = np.sum(nt_freqs_locus[:, ~mask], axis=1).argmax()

        # Store 'False' for the majority variant
        variant_loci[idx_array + idx_major] = False
        consensus[rel_locus] = alphabet[idx_major]

    cohort_consensus = ''.join(consensus)

    # Instantiate arrays
    minor_variants = np.tile(alphabet, loci.size)
    loci = np.repeat(loci, alphabet_len)

    # Exclude symbols with zero counts for all samples, as well as majority
    # symbols
    minor_variants = minor_variants[variant_loci]
    minor_variants_freqs = minor_variants_freqs[variant_loci, ]
    loci = loci[variant_loci]

    # Write to output file
    if args.patientIDs is None:
        patientIDs = "\t".join(str(x) for x in np.arange(num_samples))
    else:
        patientIDs = "\t".join(patientIDs)

    with open(os.path.join(args.outdir, 'minority_variants.tsv'), 'w') as outfile:
        outfile.write("# pos\tvariant\t" + patientIDs + "\n")
        for i_idx in range(loci.size):
            outfile.write('{:d}\t{:s}'.format(
                loci[i_idx], minor_variants[i_idx]))
            for j_idx in range(num_samples):
                outfile.write('\t{:.6e}'.format(
                    minor_variants_freqs[i_idx, j_idx]))
            outfile.write('\n')

    # Write to output file cohort-consensus. Consensus is with respect to the
    # reference given with -r
    with open(os.path.join(args.outdir, 'cohort_consensus.fasta'), 'w') as outfile:
        if start is None:
            outfile.write(">{}\n".format(reference_name))
        else:
            outfile.write(">{}:{}-{}\n".format(reference_name, start, end))
        outfile.write(cohort_consensus)

    # Write to output file coverage per locus
    if args.coverage:
        out = np.concatenate((pos[:, np.newaxis], coverage_raw), axis=1)
        header = "pos\t" + patientIDs

        np.savetxt(os.path.join(args.outdir, 'coverage.tsv'), out, fmt='%d',
                   delimiter='\t', header=header)


# Run the command-line interface only when executed as a script, not on import
if __name__ == '__main__':
    main()
