#!/usr/bin/env python

import pysam
import argparse, sys
import math, time, re
from collections import Counter
from argparse import RawTextHelpFormatter

__author__ = "Colby Chiang (cc2qe@virginia.edu)"
__version__ = "$Revision: 0.0.4 $"
__date__ = "$Date: 2016-04-05 09:07 $"

# --------------------------------------
# define functions

def get_args():
    """Parse the svtyper command line.

    Returns:
        argparse.Namespace holding the parsed options. When -i/--input_vcf
        is not given, input_vcf is set to sys.stdin if stdin is not a
        terminal (i.e. input is piped); otherwise the help text is printed
        and the program exits with status 1.
    """
    parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter, description="\
svtyper\n\
author: " + __author__ + "\n\
version: " + __version__ + "\n\
description: Compute genotype of structural variants based on breakpoint depth")
    parser.add_argument('-B', '--bam', type=str, required=True, help='BAM file(s), comma-separated if genotyping multiple BAMs')
    parser.add_argument('-S', '--split_bam', type=str, required=False, help='split-read bam file for sample, comma-separated if genotyping multiple BAMs')
    parser.add_argument('-i', '--input_vcf', type=argparse.FileType('r'), default=None, help='VCF input (default: stdin)')
    parser.add_argument('-o', '--output_vcf', type=argparse.FileType('w'), default=sys.stdout, help='output VCF to write (default: stdout)')
    parser.add_argument('-f', '--splflank', type=int, required=False, default=20, help='min number of split read query bases flanking breakpoint on either side [20]')
    parser.add_argument('-F', '--discflank', type=int, required=False, default=20, help='min number of discordant read query bases flanking breakpoint on either side. (should not exceed read length) [20]')
    parser.add_argument('--split_weight', type=float, required=False, default=1, help='weight for split reads [1]')
    parser.add_argument('--disc_weight', type=float, required=False, default=1, help='weight for discordant paired-end reads [1]')
    parser.add_argument('-n', dest='num_samp', type=int, required=False, default=1000000, help='number of pairs to sample from BAM file for building insert size distribution [1000000]')
    parser.add_argument('-M', action='store_true', dest='legacy', required=False, help='split reads are flagged as secondary, not supplementary. For compatibility with legacy BWA-MEM "-M" flag')
    parser.add_argument('--debug', action='store_true', help='debugging verbosity')

    args = parser.parse_args()

    # no -i given: read the VCF from stdin, but only if something is piped in
    if args.input_vcf is None:
        if sys.stdin.isatty():
            parser.print_help()
            # sys.exit is always available; the bare exit() builtin is only
            # injected by the site module
            sys.exit(1)
        args.input_vcf = sys.stdin
    return args

class Vcf(object):
    """In-memory VCF header.

    Holds the file format, reference, sample names and the INFO, ALT and
    FORMAT meta-information lines, and can re-emit them as a header string.
    """

    # splits the body of a ##KEY=<...> line on commas that are not inside a
    # double-quoted string
    _META_FIELD_RE = re.compile(r'(?:[^,\"]|\"[^\"]*\")+')

    def __init__(self):
        self.file_format = 'VCFv4.2'
        # self.fasta = fasta
        self.reference = ''
        self.sample_list = []
        self.info_list = []
        self.format_list = []
        self.alt_list = []
        # GT is always defined so every record can carry a genotype
        self.add_format('GT', 1, 'String', 'Genotype')

    @classmethod
    def _parse_meta_fields(cls, line):
        """Return a dict of field name -> raw value for a ##KEY=<...> line.

        The body is taken between the first '<' and the LAST '>', so a '>'
        inside a quoted Description does not truncate it. Each field is split
        on its first '=' only, so values containing '=' are kept whole.
        Quotes around values are left in place (the Info/Alt/Format
        constructors strip them).
        """
        body = line[line.find('<') + 1:line.rfind('>')]
        fields = {}
        for token in cls._META_FIELD_RE.findall(body):
            key, _, value = token.partition('=')
            fields[key.strip()] = value
        return fields

    def add_header(self, header):
        """Load header lines (each starting with '#') into this object.

        Recognises ##fileformat, ##reference, ##INFO, ##ALT and ##FORMAT
        lines and the #CHROM column line, from which sample names are taken.
        Other lines are ignored. Fields of INFO/FORMAT/ALT lines are looked
        up by name, so optional extras such as Source/Version are tolerated.
        """
        for line in header:
            key = line.split('=')[0]
            if key == '##fileformat':
                self.file_format = line.rstrip().split('=', 1)[1]
            elif key == '##reference':
                self.reference = line.rstrip().split('=', 1)[1]
            elif key == '##INFO':
                f = self._parse_meta_fields(line)
                self.add_info(f['ID'], f['Number'], f['Type'], f['Description'])
            elif key == '##ALT':
                f = self._parse_meta_fields(line)
                self.add_alt(f['ID'], f['Description'])
            elif key == '##FORMAT':
                f = self._parse_meta_fields(line)
                self.add_format(f['ID'], f['Number'], f['Type'], f['Description'])
            elif line.startswith('#') and not line.startswith('##'):
                # column header line: samples start at the 10th column
                self.sample_list = line.rstrip().split('\t')[9:]

    # return the VCF header
    def get_header(self):
        """Return the full header (meta lines + column line) without a trailing newline."""
        header = '\n'.join(['##fileformat=' + self.file_format,
                            '##fileDate=' + time.strftime('%Y%m%d'),
                            '##reference=' + self.reference] + \
                           [i.hstring for i in self.info_list] + \
                           [a.hstring for a in self.alt_list] + \
                           [f.hstring for f in self.format_list] + \
                           ['\t'.join([
                               '#CHROM',
                               'POS',
                               'ID',
                               'REF',
                               'ALT',
                               'QUAL',
                               'FILTER',
                               'INFO',
                               'FORMAT'] + \
                                      self.sample_list
                                  )])
        return header

    def add_info(self, id, number, type, desc):
        """Add an INFO definition unless one with the same ID exists."""
        if id not in [i.id for i in self.info_list]:
            inf = self.Info(id, number, type, desc)
            self.info_list.append(inf)

    def add_alt(self, id, desc):
        """Add an ALT definition unless one with the same ID exists."""
        if id not in [a.id for a in self.alt_list]:
            alt = self.Alt(id, desc)
            self.alt_list.append(alt)

    def add_format(self, id, number, type, desc):
        """Add a FORMAT definition unless one with the same ID exists."""
        if id not in [f.id for f in self.format_list]:
            fmt = self.Format(id, number, type, desc)
            self.format_list.append(fmt)

    def add_sample(self, name):
        self.sample_list.append(name)

    # get the VCF column index of a sample
    # NOTE: this is zero-based, like python arrays
    def sample_to_col(self, sample):
        return self.sample_list.index(sample) + 9

    class Info(object):
        """One ##INFO definition; hstring is its header line."""
        def __init__(self, id, number, type, desc):
            self.id = str(id)
            self.number = str(number)
            self.type = str(type)
            self.desc = str(desc)
            # strip the double quotes around the string if present
            if self.desc.startswith('"') and self.desc.endswith('"'):
                self.desc = self.desc[1:-1]
            self.hstring = '##INFO=<ID=' + self.id + ',Number=' + self.number + ',Type=' + self.type + ',Description=\"' + self.desc + '\">'

    class Alt(object):
        """One ##ALT definition; hstring is its header line."""
        def __init__(self, id, desc):
            self.id = str(id)
            self.desc = str(desc)
            # strip the double quotes around the string if present
            if self.desc.startswith('"') and self.desc.endswith('"'):
                self.desc = self.desc[1:-1]
            self.hstring = '##ALT=<ID=' + self.id + ',Description=\"' + self.desc + '\">'

    class Format(object):
        """One ##FORMAT definition; hstring is its header line."""
        def __init__(self, id, number, type, desc):
            self.id = str(id)
            self.number = str(number)
            self.type = str(type)
            self.desc = str(desc)
            # strip the double quotes around the string if present
            if self.desc.startswith('"') and self.desc.endswith('"'):
                self.desc = self.desc[1:-1]
            self.hstring = '##FORMAT=<ID=' + self.id + ',Number=' + self.number + ',Type=' + self.type + ',Description=\"' + self.desc + '\">'

class Variant(object):
    """One VCF data line, with its INFO values and per-sample genotypes.

    Sample names and INFO/FORMAT definitions are shared (by reference) with
    the Vcf object passed to the constructor.
    """
    def __init__(self, var_list, vcf):
        """Parse a VCF data line that has already been split on tabs.

        Args:
            var_list: list of column strings. A missing FORMAT column is
                filled in as "GT" (the list is modified in place).
            vcf: Vcf object providing sample_list, info_list, format_list
                and sample_to_col().

        Exits with status 1 if the line has fewer than 8 columns.
        """
        # validate the column count before reading any column, so a short
        # line reports this error instead of raising IndexError
        if len(var_list) < 8:
            sys.stderr.write('\nError: VCF file must have at least 8 columns\n')
            sys.exit(1)
        if len(var_list) < 9:
            var_list.append("GT")

        self.chrom = var_list[0]
        self.pos = int(var_list[1])
        self.var_id = var_list[2]
        self.ref = var_list[3]
        self.alt = var_list[4]
        # a missing QUAL ('.') is stored as 0
        if var_list[5] == '.':
            self.qual = 0
        else:
            self.qual = float(var_list[5])
        self.filter = var_list[6]
        self.sample_list = vcf.sample_list
        self.info_list = vcf.info_list
        self.format_list = vcf.format_list
        self.active_formats = list()
        self.gts = dict()

        # make a genotype for each sample; samples without a column on this
        # line get a missing genotype
        format_keys = var_list[8].split(':')
        for s in self.sample_list:
            try:
                sample_field = var_list[vcf.sample_to_col(s)]
            except IndexError:
                self.gts[s] = Genotype(self, s, './.')
                continue
            sample_values = sample_field.split(':')
            self.gts[s] = Genotype(self, s, sample_values[0])
            # import the existing FORMAT fields
            for key, value in zip(format_keys, sample_values):
                self.gts[s].set_format(key, value)

        # INFO: split each entry on its first '=' only, so values containing
        # '=' are kept whole; bare keys (flags) map to True
        self.info = dict()
        for entry in var_list[7].split(';'):
            key, sep, value = entry.partition('=')
            self.info[key] = value if sep else True

    def set_info(self, field, value):
        """Set an INFO value; exits with status 1 if the field is not in the header."""
        if field in [i.id for i in self.info_list]:
            self.info[field] = value
        else:
            sys.stderr.write('\nError: invalid INFO field, \"' + field + '\"\n')
            sys.exit(1)

    def get_info(self, field):
        return self.info[field]

    def get_info_string(self):
        """Return the INFO column: header-defined fields in header order, or '.' if none."""
        i_list = list()
        for info_field in self.info_list:
            if info_field.id in self.info:
                if info_field.type == 'Flag':
                    i_list.append(info_field.id)
                else:
                    i_list.append('%s=%s' % (info_field.id, self.info[info_field.id]))
        # an empty INFO column is not valid VCF; use the missing-value marker
        return ';'.join(i_list) if i_list else '.'

    def get_format_string(self):
        """Return the FORMAT column: active fields in header order."""
        f_list = list()
        for f in self.format_list:
            if f.id in self.active_formats:
                f_list.append(f.id)
        return ':'.join(f_list)

    def genotype(self, sample_name):
        """Return the Genotype for a sample; writes an error and returns None if unknown."""
        if sample_name in self.sample_list:
            return self.gts[sample_name]
        else:
            sys.stderr.write('\nError: invalid sample name, \"' + sample_name + '\"\n')

    def get_var_string(self):
        """Return the full tab-separated VCF line (no trailing newline)."""
        s = '\t'.join(map(str,[
            self.chrom,
            self.pos,
            self.var_id,
            self.ref,
            self.alt,
            '%0.2f' % self.qual,
            self.filter,
            self.get_info_string(),
            self.get_format_string(),
            '\t'.join(self.genotype(s).get_gt_string() for s in self.sample_list)
        ]))
        return s

class Genotype(object):
    """FORMAT values of one sample at one Variant."""
    def __init__(self, variant, sample_name, gt):
        self.format = dict()
        self.variant = variant
        # previously accepted but discarded; kept for callers that need it
        self.sample_name = sample_name
        self.set_format('GT', gt)

    def set_format(self, field, value):
        """Set a FORMAT value and mark the field active on the parent variant.

        Active fields are kept in the same order as the header's format_list.
        Exits with status 1 if the field is not defined in the header.
        """
        format_ids = [f.id for f in self.variant.format_list]
        if field not in format_ids:
            sys.stderr.write('\nError: invalid FORMAT field, \"' + field + '\"\n')
            sys.exit(1)
        self.format[field] = value
        active = self.variant.active_formats
        if field not in active:
            active.append(field)
            # sort it to be in the same order as the format_list in header
            active.sort(key=format_ids.index)

    def get_format(self, field):
        return self.format[field]

    def get_gt_string(self):
        """Return this sample's column: values of the variant's active fields.

        Floats are written with two decimals; fields with no value are '.'.
        """
        g_list = list()
        for f in self.variant.active_formats:
            if f in self.format:
                value = self.format[f]
                if isinstance(value, float):
                    g_list.append('%0.2f' % value)
                else:
                    g_list.append(value)
            else:
                g_list.append('.')
        return ':'.join(map(str, g_list))

# log10 of a binomial coefficient, computed without forming the (possibly huge) coefficient itself
def log_choose(n, k):
    """Return log10 of the binomial coefficient C(n, k).

    The value is accumulated as a running sum of logarithms, so the
    coefficient itself is never formed.

    Args:
        n: number of trials (non-negative integer).
        k: number of successes (integer).

    Returns:
        log10(C(n, k)) as a float. Out-of-range k (k < 0 or k > n) is not
        rejected: the loop is empty and 0.0 is returned.
    """
    r = 0.0
    # C(n, k) == C(n, n - k); iterate over the smaller of the two
    if k * 2 > n:
        k = n - k

    # range (not xrange) runs under both Python 2 and Python 3; the
    # accumulation order is unchanged so results are bit-identical
    for d in range(1, k + 1):
        r += math.log(n, 10)
        r -= math.log(d, 10)
        n -= 1

    return r

# return the log10 likelihoods of the hom-ref, het and hom-alt genotypes given ref and alt read counts
def bayes_gt(ref, alt, is_dup):
    """Binomial log10 likelihoods of the three genotypes.

    Args:
        ref: number of reads supporting the reference allele.
        alt: number of reads supporting the alternate allele.
        is_dup: use the allele fractions for non-destructive events such as
            duplications instead of the default ones.

    Returns:
        Tuple (hom_ref, het, hom_alt) of log10 likelihoods.
    """
    # expected fraction of alt reads under hom_ref, het and hom_alt
    if is_dup:
        expected_alt_fractions = [0.01, 0.3, 0.5]
    else:
        expected_alt_fractions = [0.01, 0.5, 0.9]

    # the binomial-coefficient term is the same for every genotype
    log_coeff = log_choose(ref + alt, alt)

    return tuple(
        log_coeff + alt * math.log(frac, 10) + ref * math.log(1 - frac, 10)
        for frac in expected_alt_fractions
    )

# return the mate read's alignment end coordinate (its 5' end when the mate is on the reverse strand), using the MC SAM tag when present
# biobambam stores a numeric coordinate in the MC tag, while samblaster provides the mate CIGAR string
def get_mate_5prime(bam, read):
    """Return the end coordinate of *read*'s mate alignment.

    Resolution order:
      1. MC tag == '*'   -> None.
      2. MC tag is an integer (biobambam style) -> that integer.
      3. MC tag is a CIGAR string (samblaster style) -> read.pnext plus the
         number of reference bases the CIGAR consumes.
      4. No MC tag -> bam.mate(read).aend (requires a lookup in the BAM).
    """
    try:
        mc = read.opt('MC')  # the mate CIGAR string (or coordinate)
    except KeyError:
        return bam.mate(read).aend
    if mc == '*':
        return None
    # try to find from a coordinate (biobambam style)
    try:
        return int(mc)
    except ValueError:
        pass
    # otherwise get it from a CIGAR string (samblaster style); per the SAM
    # spec, M, D, N, = and X all consume reference bases
    ref_consuming = ('M', 'D', 'N', '=', 'X')
    ops = re.findall(r'[MIDNSHPX=]+', mc)
    lengths = [int(x) for x in re.findall(r'[^MIDNSHPX=]+', mc)]
    p = read.pnext
    for op, length in zip(ops, lengths):
        if op in ref_consuming:
            p += length
    return p

def get_mate_mapq(bam, read):
    """Return the mapping quality of *read*'s mate.

    Uses the MQ tag when present ('*' yields None); otherwise falls back to
    looking the mate up in *bam* and reading its mapq.
    """
    try:
        mate_quality = read.opt('MQ')  # the mate mapq score
    except KeyError:
        return bam.mate(read).mapq
    return None if mate_quality == '*' else mate_quality

# calculate the probability that a read is concordant at a deletion breakpoint,
# given the putative deletion size and insert distribution of the library.
def p_concordant(read_ospan, var_length, ins_dens):
    """Posterior probability that a read pair is concordant.

    Args:
        read_ospan: observed outer span (insert size) of the pair.
        var_length: putative deletion size.
        ins_dens: mapping of insert size -> density for the library;
            sizes absent from the mapping are treated as density 0.

    Returns:
        P(concordant), using a fixed 0.95 prior and comparing the density at
        read_ospan with the density at read_ospan - var_length, or None when
        both densities are zero.
    """
    conc_prior = 0.95 # a priori probability that a read-pair is concordant
    disc_prior = 1 - conc_prior
    # .get(..., 0) matches Counter indexing for missing sizes and also
    # works when ins_dens is a plain dict
    dens_conc = ins_dens.get(read_ospan, 0)
    dens_disc = ins_dens.get(read_ospan - var_length, 0)
    try:
        p = float(dens_conc) * conc_prior / (conc_prior * dens_conc + disc_prior * dens_disc)
    except ZeroDivisionError:
        p = None
    return p

def count_pairedend(chrom,
                    pos,
                    ci,
                    mate_chrom,
                    mate_pos,
                    mate_ci,
                    o1,
                    o2,
                    svtype,
                    sample,
                    z,
                    discflank):
    """Count concordant and discordant read pairs at one breakend.

    Reads are fetched from sample.bam within sample.get_fetch_flank(z) bases
    of pos: upstream of pos when o1 == '+', downstream when o1 == '-', and
    on both sides for svtype 'INV'.

    Args:
        chrom, pos: chromosome and position of this breakend.
        ci: confidence interval of pos (not read in this function).
        mate_chrom, mate_pos: chromosome and position of the partner breakend.
        mate_ci: confidence interval of mate_pos, indexed [0] and [1].
        o1, o2: orientation ('+' or '-') of this breakend and its mate.
        svtype: 'DEL', 'DUP', 'INV' or 'BND'.
        sample: Sample providing bam, lib_dict, active_libs, get_lib() and
            get_fetch_flank().
        z: number of library standard deviations used for windows.
        discflank: minimum number of read bases required beyond the
            breakpoint.

    Returns:
        Tuple (conc_counter, disc_counter, conc_scaled_counter,
        disc_scaled_counter): raw counts, and counts weighted by the
        concordance probability (where used) and by 1 - 10**(-MAPQ/10) of
        both the read and its mate.

    NOTE(review): reads the module-level name `legacy`, which is neither a
    parameter nor defined here; presumably set from the -M option elsewhere
    in the file — confirm.
    """
    conc_counter = 0
    disc_counter = 0
    conc_scaled_counter = 0
    disc_scaled_counter = 0

    # fetch window size: largest (mean + z * sd) over the sample's libraries
    fetch_flank = sample.get_fetch_flank(z)

    # don't count paired-end reads if deletion is smaller than 2 * standard dev of library
    # (returns all-zero counts if this holds for ANY library)
    if svtype == 'DEL':
        for lib in sample.lib_dict.values():
            if abs(mate_pos - pos) < 2 * lib.sd:
                return (conc_counter, disc_counter, conc_scaled_counter, disc_scaled_counter)

    # ---- reads upstream of pos (forward-strand reads pointing at the breakpoint)
    if o1 == '+' or svtype == 'INV':
        # survey for concordant read pairs
        reads0 = []
        for read in sample.bam.fetch(chrom, max(pos - (fetch_flank), 0), pos):
            if read.is_unmapped: continue
            reads0.append(read)
            lib = sample.get_lib(read.opt('RG')) # get the read's library
            # keep forward reads with a reverse mate on the same chromosome
            # whose pair spans pos by at least discflank bases
            if (read.is_reverse
                or not read.mate_is_reverse
                or (legacy and read.is_secondary)
                or (not legacy and read.flag & 2048 == 2048)
                or read.is_unmapped
                or read.mate_is_unmapped
                or read.is_duplicate
                or read.pos + discflank > pos
                or read.pnext + lib.read_length - discflank < pos
                or read.tid != read.rnext
                or lib not in sample.active_libs
                ):
                continue
            else:
                mate_mapq = get_mate_mapq(sample.bam, read)
                ospan = get_mate_5prime(sample.bam, read) - read.pos
                ispan1 = read.pos + discflank
                ispan2 = get_mate_5prime(sample.bam, read) - discflank - 1 # speed up here
                ispan = ispan2 - ispan1
                prob_conc = p_concordant(ospan, abs(mate_pos - pos), lib.dens)

                if ispan2 > pos and prob_conc is not None: # bail if neither concordant or discordant
                    conc_counter += 1
                    conc_scaled_counter += prob_conc * (1-10**(-read.mapq/10.0)) * (1-10**(-mate_mapq/10.0))
                    # if a deletion, iterate the discordants for these too
                    if svtype == 'DEL' and ispan2 >= mate_ci[0]:
                        # disc_p = (1 - prob_conc) * (1-10**(-read.mapq/10.0)) * (1-10**(-mate_mapq/10.0))
                        # if disc_p > 0.5:
                        #     print ispan2, mate_ci[0]
                        #     print read
                        disc_counter += 1
                        disc_scaled_counter += (1 - prob_conc) * (1-10**(-read.mapq/10.0)) * (1-10**(-mate_mapq/10.0))

        # now look at discordants for corresponding SV types
        if svtype != 'DEL':
            # reuse the mapped reads collected above; if none were collected,
            # fetch the same window again
            readiter = reads0 or sample.bam.fetch(chrom, max(pos - (fetch_flank), 0), pos)
            for read in readiter:
                lib = sample.get_lib(read.opt('RG')) # get the read's library
                if (read.is_reverse
                    or (legacy and read.is_secondary)
                    or (not legacy and read.flag & 2048 == 2048)
                    or read.is_unmapped
                    or read.mate_is_unmapped
                    or read.is_duplicate
                    or read.pos + discflank > pos
                    or sample.bam.getrname(read.rnext) != mate_chrom
                    or lib not in sample.active_libs
                    ):
                    continue

                mate_mapq = get_mate_mapq(sample.bam, read) # move this for speed
                mate_5prime = get_mate_5prime(sample.bam, read)
                if svtype == 'DUP':
                    if not read.mate_is_reverse or read.pnext > read.pos or read.tid != read.rnext:
                        continue
                    ispan1 = read.pos + discflank
                    ispan2 = mate_5prime - discflank - 1
                    if ispan1 < pos and ispan2 < ispan1 and ispan2 >= mate_ci[0] and ispan2 <= (mate_ci[1] + (lib.mean + lib.sd * z)):
                        # print "+-", read
                        disc_counter += 1
                        disc_scaled_counter += (1-10**(-read.mapq/10.0)) * (1-10**(-mate_mapq/10.0))
                elif svtype == 'INV':
                    if read.mate_is_reverse:
                        continue
                    ispan1 = read.pos + discflank
                    ispan2 = read.pnext + discflank
                    if ispan1 < pos and ispan2 <= mate_ci[1] and ispan2 >= (mate_ci[0] - (lib.mean + lib.sd * z)):
                        # print "++", read
                        disc_counter += 1
                        disc_scaled_counter += (1-10**(-read.mapq/10.0)) * (1-10**(-mate_mapq/10.0))
                elif svtype == 'BND':
                    if o2 == '-':
                        if not read.mate_is_reverse or read.pnext + lib.read_length - discflank < mate_pos:
                            continue
                        ispan1 = read.pos + discflank
                        ispan2 = mate_5prime - discflank - 1
                        if ispan1 < pos and ispan2 >= mate_ci[0] and ispan2 <= (mate_ci[1] + (lib.mean + lib.sd * z)):
                            # print "+- BND", read
                            disc_counter += 1
                            disc_scaled_counter += (1-10**(-read.mapq/10.0)) * (1-10**(-mate_mapq/10.0))
                    if o2 == '+':
                        if read.mate_is_reverse or read.pnext + discflank > mate_pos:
                            continue
                        ispan1 = read.pos + discflank
                        ispan2 = read.pnext + discflank
                        if ispan1 < pos and ispan2 <= mate_ci[1] and ispan2 >= (mate_ci[0] - (lib.mean + lib.sd * z)):
                            disc_counter += 1
                            disc_scaled_counter += (1-10**(-read.mapq/10.0)) * (1-10**(-mate_mapq/10.0))

    # ---- reads downstream of pos (reverse-strand reads pointing at the breakpoint)
    reads1 = []
    if o1 == '-' or svtype == 'INV':
        # survey for concordant read pairs
        for read in sample.bam.fetch(chrom, max(pos, 0), pos + (fetch_flank)):
            if read.is_unmapped: continue
            reads1.append(read)
            lib = sample.get_lib(read.opt('RG')) # get the read's library
            # keep reverse reads with a forward mate on the same chromosome
            # whose pair spans pos by at least discflank bases
            if (not read.is_reverse
                or read.mate_is_reverse
                or (legacy and read.is_secondary)
                or (not legacy and read.flag & 2048 == 2048)
                or read.is_unmapped
                or read.mate_is_unmapped
                or read.is_duplicate
                or read.aend - discflank < pos
                or read.pnext + discflank > pos
                or read.tid != read.rnext
                or lib not in sample.active_libs
                ):
                continue
            else:
                mate_mapq = get_mate_mapq(sample.bam, read)
                ospan = read.aend - read.pnext
                ispan1 = read.aend - discflank - 1
                ispan2 = read.pnext + discflank
                ispan = ispan1 - ispan2
                prob_conc = p_concordant(ospan, abs(mate_pos - pos), lib.dens)

                if ispan2 < pos and prob_conc is not None: # bail if neither concordant or discordant
                    conc_counter += 1
                    conc_scaled_counter += prob_conc * (1-10**(-read.mapq/10.0)) * (1-10**(-mate_mapq/10.0))
                    if svtype == 'DEL' and ispan2 <= mate_ci[1]:
                        # disc_p = (1 - prob_conc) * (1-10**(-read.mapq/10.0)) * (1-10**(-mate_mapq/10.0))
                        # if disc_p > 0.5:
                        #     print ispan1, mate_ci[1]
                        #     print read
                        disc_counter += 1
                        disc_scaled_counter += (1 - prob_conc) * (1-10**(-read.mapq/10.0)) * (1-10**(-mate_mapq/10.0))

        # now look at discordants for corresponding SV types
        if svtype != 'DEL':
            # reuse the mapped reads collected above; if none were collected,
            # fetch the same window again
            readiter = reads1 or sample.bam.fetch(chrom, max(pos, 0), pos + (fetch_flank))
            for read in readiter:
                lib = sample.get_lib(read.opt('RG')) # get the read's library
                if (not read.is_reverse
                    or (legacy and read.is_secondary)
                    or (not legacy and read.flag & 2048 == 2048)
                    or read.is_unmapped
                    or read.mate_is_unmapped
                    or read.is_duplicate
                    or read.aend - discflank < pos
                    or sample.bam.getrname(read.rnext) != mate_chrom
                    or lib not in sample.active_libs
                    ):
                    continue

                mate_mapq = get_mate_mapq(sample.bam, read) # move this for speed
                mate_5prime = get_mate_5prime(sample.bam, read)

                if svtype == 'DUP':
                    if read.mate_is_reverse or read.pnext < read.pos or read.tid != read.rnext:
                        continue
                    ispan1 = read.aend - discflank - 1
                    ispan2 = read.pnext + discflank
                    if ispan1 > pos and ispan2 > ispan1 and ispan2 <= mate_ci[1] and ispan2 >= (mate_ci[0] - (lib.mean + lib.sd * z)):
                        # print "-+", read
                        disc_counter += 1
                        disc_scaled_counter += (1-10**(-read.mapq/10.0)) * (1-10**(-mate_mapq/10.0))
                elif svtype == 'INV':
                    if not read.mate_is_reverse:
                        continue
                    ispan1 = read.aend - discflank - 1
                    ispan2 = mate_5prime - discflank - 1
                    if ispan1 > pos and ispan2 >= mate_ci[0] and ispan2 <= (mate_ci[1] + (lib.mean + lib.sd * z)):
                        # print "--", read
                        disc_counter += 1
                        disc_scaled_counter += (1-10**(-read.mapq/10.0)) * (1-10**(-mate_mapq/10.0))
                elif svtype == 'BND':
                    if o2 == '-':
                        if not read.mate_is_reverse or read.pnext + lib.read_length - discflank < mate_pos:
                            continue
                        ispan1 = read.aend - discflank - 1
                        ispan2 = mate_5prime - discflank - 1
                        if ispan1 > pos and ispan2 >= mate_ci[0] and ispan2 <= (mate_ci[1] + (lib.mean + lib.sd * z)):
                            disc_counter += 1
                            disc_scaled_counter += (1-10**(-read.mapq/10.0)) * (1-10**(-mate_mapq/10.0))
                    if o2 == '+':
                        if read.mate_is_reverse or read.pnext + discflank > mate_pos:
                            continue
                        ispan1 = read.aend - discflank - 1
                        ispan2 = read.pnext + discflank
                        if ispan1 > pos and ispan2 <= mate_ci[1] and ispan2 >= (mate_ci[0] - (lib.mean + lib.sd * z)):
                            disc_counter += 1
                            disc_scaled_counter += (1-10**(-read.mapq/10.0)) * (1-10**(-mate_mapq/10.0))

    return (conc_counter, disc_counter, conc_scaled_counter, disc_scaled_counter)

# get the number of entries in the set
def countRecords(myCounter):
    """Return the total number of observations in a value -> count counter."""
    return sum(myCounter[key] for key in myCounter)

# median is approx 50th percentile, except when it is between
# two values in which case it's the mean of them.
def median(myCounter):
    """Return the median of a value -> count counter.

    When the middle falls exactly between two distinct values, the mean of
    those two values is returned. Entries with non-positive counts are
    ignored.

    Raises:
        ValueError: if the counter holds no positive counts.
    """
    # sorted distinct values that were actually observed
    valueList = sorted(x for x in myCounter if myCounter[x] > 0)
    if not valueList:
        raise ValueError('median() of an empty counter')
    numEntries = sum(myCounter[x] for x in valueList)

    # the ordinal position of the middle element; with an even number of
    # entries this lies between two elements
    limit = 0.5 * numEntries

    runEntries = 0  # entries consumed so far
    i = 0           # index of the next value in valueList
    v = valueList[0]
    while runEntries < limit:
        v = valueList[i]
        runEntries += myCounter[v]
        i += 1
    if runEntries == limit:
        # exactly half the entries are <= v, so the median straddles v and
        # the next value (which must exist, as the remaining half is > 0)
        return (v + valueList[i]) / 2.0
    return v

# calculate upper median absolute deviation
def upper_mad(myCounter, myMedian):
    """Return the median absolute deviation of values above *myMedian*.

    Only values strictly greater than the median contribute. If there are
    none (e.g. every observation equals the median), 0 is returned instead
    of taking the median of an empty set.
    """
    residCounter = Counter()
    for x, count in myCounter.items():
        if x > myMedian:
            residCounter[x - myMedian] += count
    if not residCounter:
        return 0
    return median(residCounter)

# sum of value * count over all entries in the counter
def sumRecords(myCounter):
    """Return sum(value * count) over a value -> count counter, as a float."""
    total = 0.0
    for value, count in myCounter.items():
        total += value * float(count)
    return total

# calculate the arithmetic mean of the values in a counter,
# weighting each value by its number of occurrences
def mean(myCounter):
    """Return the count-weighted arithmetic mean of a value -> count counter.

    Raises ZeroDivisionError when the counter holds no observations.
    """
    # total observations = sum of the occurrence counts
    total_observations = countRecords(myCounter)
    return sumRecords(myCounter) / total_observations

def stdev(myCounter):
    """Return the population standard deviation of a value -> count counter.

    Computed as sqrt(sum(count * (value - mean)**2) / total_count).
    Raises ZeroDivisionError when the counter holds no observations.
    """
    total_observations = countRecords(myCounter)
    mu = mean(myCounter)

    squared_deviation = 0.0
    for value, count in myCounter.items():
        squared_deviation += count * (value - mu)**2

    variance = float(squared_deviation) / total_observations
    return variance**(0.5)

# holds a library's insert size and read length information
class Library(object):
    """One sequencing library (a set of read groups) within a BAM file.

    After the calc_* methods run, holds the library's prevalence in the BAM,
    its maximum read length, its insert-size histogram and density, and the
    mean and standard deviation of the insert size.
    """
    def __init__(self, bam, name, num_samp):
        self.bam = bam                  # alignment file the reads come from
        self.name = name                # library name
        self.num_samp = num_samp        # number of pairs sampled for the insert histogram
        self.readgroups = list()        # read-group IDs belonging to this library
        self.read_length = int()
        self.hist = dict()              # insert size -> count
        self.dens = dict()              # insert size -> fraction of pairs
        self.mean = float()
        self.sd = float()
        self.prevalence = float()

    def add_readgroup(self, rg):
        self.readgroups.append(rg)

    # calculate the library's prevalence in the BAM file
    def calc_lib_prevalence(self):
        """Set self.prevalence to the fraction of the first 100,000 reads of
        the BAM whose RG tag belongs to this library.

        A BAM with no reads yields 0.0 instead of dividing by zero.
        """
        max_count = 100000
        lib_counter = 0
        read_counter = 0

        for read in self.bam.fetch():
            if read_counter == max_count:
                break
            if read.opt('RG') in self.readgroups:
                lib_counter += 1
            read_counter += 1

        if read_counter:
            self.prevalence = float(lib_counter) / read_counter
        else:
            self.prevalence = 0.0

    # get read length
    def calc_read_length(self):
        """Set self.read_length to the longest query length among the
        library's first reads (about 10,000 are examined)."""
        max_rl = 0
        counter = 0
        num_samp = 10000
        for read in self.bam.fetch():
            if read.opt('RG') not in self.readgroups:
                continue
            if read.qlen > max_rl:
                max_rl = read.qlen
            if counter == num_samp:
                break
            counter += 1
        self.read_length = max_rl

    # generate empirical histogram of the sample's insert size distribution
    # TODO(CC): also remove values beyond z standard deviations
    def calc_insert_hist(self):
        """Build the insert-size histogram from up to self.num_samp pairs.

        Uses forward-strand reads with a reverse-strand mate, both mapped,
        positive TLEN, from this library's read groups, skipping
        supplementary alignments (or secondary ones when the module-level
        `legacy` flag is set). Values more than 10 upper-MADs above the
        median are then discarded, and mean/sd are computed on the rest.

        NOTE(review): reads the module-level name `legacy`, defined outside
        this class — confirm it is set before this is called.
        """
        counter = 0
        mads = 10

        # valueCounts[x] is the number of times insert size x was observed
        valueCounts = Counter()
        for read in self.bam:
            if (read.is_reverse
                or not read.mate_is_reverse
                or read.is_unmapped
                or read.mate_is_unmapped
                or (legacy and read.is_secondary)
                or (not legacy and read.flag & 2048 == 2048)
                or read.tlen <= 0
                or read.opt('RG') not in self.readgroups):
                continue
            valueCounts[read.tlen] += 1
            counter += 1
            if counter == self.num_samp:
                break

        # remove outliers on the high side only
        med = median(valueCounts)
        u_mad = upper_mad(valueCounts, med)
        for x in [x for x in list(valueCounts) if x > med + mads * u_mad]:
            del valueCounts[x]

        self.hist = valueCounts
        self.mean = mean(self.hist)
        self.sd = stdev(self.hist)

    # calculate the density curve for an insert size histogram
    def calc_insert_density(self):
        """Set self.dens to a Counter of insert size -> fraction of pairs."""
        # the total is loop-invariant; computing it once avoids O(n^2) work
        total = float(countRecords(self.hist))
        dens = Counter()
        for i in self.hist:
            dens[i] = self.hist[i] / total
        self.dens = dens


# holds each sample's BAM and library information
class Sample(object):
    """A sample's BAM (plus optional split-read BAM) and its libraries.

    Libraries are grouped by the LB field of the BAM header's RG entries.
    A library whose prevalence is below min_lib_prevalence is dropped from
    lib_dict. Each remaining library has its read length, insert histogram
    and density computed, and is listed in active_libs.
    """
    def __init__(self, bam, spl_bam, num_samp, min_lib_prevalence):
        # sample name is taken from the SM field of the first read group
        self.name = bam.header['RG'][0]['SM']
        self.bam = bam
        self.spl_bam = spl_bam

        self.lib_dict = dict()
        self.active_libs = []

        # parse library design
        self.rg_to_lib = dict()
        for r in self.bam.header['RG']:
            # read groups without an LB field share the library named ''
            lib_name = r.get('LB', '')

            # add the new library
            if lib_name not in self.lib_dict:
                self.lib_dict[lib_name] = Library(self.bam, lib_name, num_samp)
            self.rg_to_lib[r['ID']] = self.lib_dict[lib_name]
            self.lib_dict[lib_name].add_readgroup(r['ID'])

        # execute calculations
        for name in list(self.lib_dict):
            lib = self.lib_dict[name]
            lib.calc_lib_prevalence()
            # delete library if it constitutes less than min fraction of BAM file
            if lib.prevalence < min_lib_prevalence:
                self.lib_dict.pop(name)
                continue
            lib.calc_read_length()
            lib.calc_insert_hist()
            lib.calc_insert_density()
            self.active_libs.append(lib)

    # get the maximum fetch flank for reading the BAM file
    def get_fetch_flank(self, z):
        """Return the largest (mean + z * sd) insert size over the libraries."""
        return max(lib.mean + (lib.sd * z) for lib in self.lib_dict.values())

    # return the library object for a specified read group
    def get_lib(self, readgroup):
        return self.rg_to_lib[readgroup]

# primary function
def sv_genotype(vcf_file,
                bam_string,
                spl_bam_string,
                vcf_out,
                splflank,
                discflank,
                split_weight,
                disc_weight,
                num_samp,
                debug):
    '''
    Genotype every structural variant of a VCF against one or more BAMs.

    vcf_file:       iterable of VCF lines, header lines first
    bam_string:     comma-separated full BAM paths, one per sample
    spl_bam_string: comma-separated split-read BAM paths in the same order
                    as bam_string, or None to skip split-read evidence
    vcf_out:        writable file that receives the genotyped VCF; it is
                    closed before returning
    splflank:       min aligned bases on each side of a position for a read
                    to count as reference split-read support
    discflank:      min discordant read query bases flanking the breakpoint
                    (passed to count_pairedend)
    split_weight:   weight of split-read evidence in QR/QA
    disc_weight:    weight of paired-end evidence in QR/QA
    num_samp:       pairs sampled per BAM for the insert size distribution
    debug:          print per-sample evidence tallies to stdout

    Records whose SVTYPE is not BND, DEL, DUP or INV are written through
    unchanged with a warning. BND records are genotyped once both mates
    have been read; a breakend whose mate never appears is not written.
    '''
    # parse the comma separated inputs
    bam_list = [pysam.Samfile(b, 'rb') for b in bam_string.split(',')]
    if spl_bam_string is not None:
        spl_bam_list = [pysam.Samfile(b, 'rb') for b in spl_bam_string.split(',')]
        if len(bam_list) != len(spl_bam_list):
            sys.stderr.write('\nError: Number of full BAMs and splitters BAMs differ.\n')
            sys.exit(1)
    else:
        spl_bam_list = [None] * len(bam_list)
        sys.stderr.write('Warning: Splitters BAM file (-S) not found. Performance will suffer.\n')

    min_lib_prevalence = 1e-3 # only consider libraries that constitute at least this fraction of the BAM
    sample_list = [Sample(bam, spl_bam, num_samp, min_lib_prevalence)
                   for bam, spl_bam in zip(bam_list, spl_bam_list)]

    z = 3 # number of insert size SDs, passed to count_pairedend
    padding = 30 # this is the left/right distance for fetching splitters
    split_slop = 3 # amount of slop around breakpoint to count splitters
    in_header = True
    header = []
    breakend_dict = {} # cache to hold unmatched generic breakends for genotyping
    vcf = Vcf()

    # read input VCF
    for line in vcf_file:
        if in_header:
            if line[0] == '#':
                header.append(line)
                continue
            # first record: the header is complete
            in_header = False
            _write_genotype_header(vcf, header, sample_list, vcf_out)

        var = Variant(line.rstrip().split('\t'), vcf)
        svtype = var.info['SVTYPE']

        if svtype == 'BND':
            # genotype generic breakends once both mates have been seen
            if var.info['MATEID'] not in breakend_dict:
                breakend_dict[var.var_id] = var
                continue
            var2 = var
            var = breakend_dict[var.info['MATEID']]
            chromA = var.chrom
            chromB = var2.chrom
            posA = var.pos
            posB = var2.pos
            # confidence intervals
            ciA = [posA + ci for ci in map(int, var.info['CIPOS'].split(','))]
            ciB = [posB + ci for ci in map(int, var2.info['CIPOS'].split(','))]
            # infer the strands from the alt allele
            o1 = '+' if var.alt[-1] in ('[', ']') else '-'
            o2 = '+' if var2.alt[-1] in ('[', ']') else '-'
        elif svtype in ('DEL', 'DUP', 'INV'):
            chromA = var.chrom
            chromB = var.chrom
            posA = var.pos
            posB = int(var.get_info('END'))
            # confidence intervals
            ciA = [posA + ci for ci in map(int, var.info['CIPOS'].split(','))]
            ciB = [posB + ci for ci in map(int, var.info['CIEND'].split(','))]
            o1, o2 = {'DEL': ('+', '-'),
                      'DUP': ('-', '+'),
                      'INV': ('+', '+')}[svtype]
        else:
            # without a known orientation the breakends cannot be genotyped;
            # previously o1/o2 were left unset (NameError) or stale from the
            # preceding record
            sys.stderr.write('Warning: SVTYPE %s at variant %s is not supported; writing it unchanged.\n'
                             % (svtype, var.var_id))
            vcf_out.write(var.get_var_string() + '\n')
            continue

        # increment the negative strand values (note position in VCF should be the base immediately left of the breakpoint junction)
        if o1 == '-': posA += 1
        if o2 == '-': posB += 1

        # QUAL is the sum of the per-sample SQ values; reset it once per
        # record so that a sample without evidence cannot wipe out the
        # contribution of samples processed before it
        var.qual = 0

        for sample in sample_list:
            # Breakend A: splitters, then paired-end discordants/concordants
            split_a = _count_breakend_splitters(sample, chromA, posA, o1, padding, splflank)
            (conc_counter_a,
             disc_counter_a,
             conc_scaled_counter_a,
             disc_scaled_counter_a) = count_pairedend(chromA, posA, ciA,
                                                      chromB, posB, ciB,
                                                      o1, o2,
                                                      svtype,
                                                      sample,
                                                      z, discflank)

            # Breakend B: same evidence with the breakends swapped
            split_b = _count_breakend_splitters(sample, chromB, posB, o2, padding, splflank)
            (conc_counter_b,
             disc_counter_b,
             conc_scaled_counter_b,
             disc_scaled_counter_b) = count_pairedend(chromB, posB, ciB,
                                                      chromA, posA, ciA,
                                                      o2, o1,
                                                      svtype,
                                                      sample,
                                                      z, discflank)

            # tally up the splitters
            sr_ref_a, sr_spl_a, sr_ref_scaled_a, sr_spl_scaled_a = \
                _tally_breakend_splitters(split_a, posA, split_slop)
            sr_ref_b, sr_spl_b, sr_ref_scaled_b, sr_spl_scaled_b = \
                _tally_breakend_splitters(split_b, posB, split_slop)

            if debug:
                print('--------------------')
                print(sample.name)
                print('sr_a (ref, alt) %s %s' % (sr_ref_a, sr_spl_a))
                print('pe_a (ref, alt) %s %s' % (conc_counter_a, disc_counter_a))
                print('sr_b (ref, alt) %s %s' % (sr_ref_b, sr_spl_b))
                print('pe_b (ref, alt) %s %s' % (conc_counter_b, disc_counter_b))
                print('sr_a_scaled (ref, alt) %s %s' % (sr_ref_scaled_a, sr_spl_scaled_a))
                print('pe_a_scaled (ref, alt) %s %s' % (conc_scaled_counter_a, disc_scaled_counter_a))
                print('sr_b_scaled (ref, alt) %s %s' % (sr_ref_scaled_b, sr_spl_scaled_b))
                print('pe_b_scaled (ref, alt) %s %s' % (conc_scaled_counter_b, disc_scaled_counter_b))

            # merge the breakend support (raw counts; only used for debugging parity)
            split_ref = 0 # set these to zero unless there are informative alt bases for the ev type
            disc_ref = 0
            split_alt = sr_spl_a + sr_spl_b
            if split_alt > 0:
                split_ref = sr_ref_a + sr_ref_b
            disc_alt = disc_counter_a + disc_counter_b
            if disc_alt > 0:
                disc_ref = conc_counter_a + conc_counter_b
            if split_alt == 0 and disc_alt == 0:
                split_ref = sr_ref_a + sr_ref_b
                disc_ref = conc_counter_a + conc_counter_b

            # MAPQ-scaled support: reference evidence of a type only counts
            # when that type has informative alt evidence, unless there is no
            # alt evidence at all
            split_scaled_ref = 0
            disc_scaled_ref = 0
            split_scaled_alt = sr_spl_scaled_a + sr_spl_scaled_b
            if int(split_scaled_alt) > 0:
                split_scaled_ref = sr_ref_scaled_a + sr_ref_scaled_b
            disc_scaled_alt = disc_scaled_counter_a + disc_scaled_counter_b
            if int(disc_scaled_alt) > 0:
                disc_scaled_ref = conc_scaled_counter_a + conc_scaled_counter_b
            if int(split_scaled_alt) == 0 and int(disc_scaled_alt) == 0: # if no alt alleles, set reference
                split_scaled_ref = sr_ref_scaled_a + sr_ref_scaled_b
                disc_scaled_ref = conc_scaled_counter_a + conc_scaled_counter_b

            gt = var.genotype(sample.name)
            if split_scaled_alt + split_scaled_ref + disc_scaled_alt + disc_scaled_ref > 0:
                # get bayesian classifier
                is_dup = svtype == 'DUP'
                QR = int(split_weight * split_scaled_ref) + int(disc_weight * disc_scaled_ref)
                QA = int(split_weight * split_scaled_alt) + int(disc_weight * disc_scaled_alt)
                gt_lplist = bayes_gt(QR, QA, is_dup)
                gt_idx = gt_lplist.index(max(gt_lplist))

                # print log probabilities of homref, het, homalt
                if debug:
                    print(gt_lplist)

                # set the sample specific fields
                gt.set_format('GL', ','.join(['%.0f' % x for x in gt_lplist]))
                gt.set_format('DP', int(split_scaled_ref + split_scaled_alt + disc_scaled_ref + disc_scaled_alt))
                gt.set_format('RO', int(split_scaled_ref + disc_scaled_ref))
                gt.set_format('AO', int(split_scaled_alt + disc_scaled_alt))
                gt.set_format('QR', QR)
                gt.set_format('QA', QA)
                gt.set_format('RS', int(split_scaled_ref))
                gt.set_format('AS', int(split_scaled_alt))
                gt.set_format('RP', int(disc_scaled_ref))
                gt.set_format('AP', int(disc_scaled_alt))
                try:
                    gt.set_format('AB', '%.2g' % (QA / float(QR + QA)))
                except ZeroDivisionError:
                    gt.set_format('AB', '.')

                # assign genotypes: normalize the log10 likelihoods
                gt_sum = 0
                for lp in gt_lplist:
                    try:
                        gt_sum += 10**lp
                    except OverflowError:
                        pass
                if gt_sum > 0:
                    gt_sum_log = math.log(gt_sum, 10)
                    sample_qual = abs(-10 * (gt_lplist[0] - gt_sum_log)) # phred-scaled probability site is non-reference in this sample
                    p_wrong = 1 - (10**gt_lplist[gt_idx] / 10**gt_sum_log)
                    if p_wrong == 0:
                        phred_gq = 200 # cap: log10(0) is undefined
                    else:
                        phred_gq = abs(-10 * math.log(p_wrong, 10))
                    gt.set_format('GQ', int(phred_gq))
                    gt.set_format('SQ', sample_qual)
                    var.qual += sample_qual
                    if gt_idx == 1:
                        gt.set_format('GT', '0/1')
                    elif gt_idx == 2:
                        gt.set_format('GT', '1/1')
                    elif gt_idx == 0:
                        gt.set_format('GT', '0/0')
                else:
                    gt.set_format('GQ', '.')
                    gt.set_format('SQ', '.')
                    gt.set_format('GT', './.')
            else:
                # no evidence at all for this sample
                gt.set_format('GT', './.')
                gt.set_format('GQ', '.')
                gt.set_format('SQ', '.')
                gt.set_format('GL', '.')
                gt.set_format('DP', 0)
                gt.set_format('AO', 0)
                gt.set_format('RO', 0)
                gt.set_format('AS', 0)
                gt.set_format('RS', 0)
                gt.set_format('AP', 0)
                gt.set_format('RP', 0)
                gt.set_format('QR', 0)
                gt.set_format('QA', 0)
                gt.set_format('AB', '.')

        # after all samples have been processed, write
        vcf_out.write(var.get_var_string() + '\n')
        if svtype == 'BND':
            # the second mate shares the first mate's QUAL and genotypes
            var2.qual = var.qual
            var2.active_formats = var.active_formats
            var2.genotype = var.genotype
            vcf_out.write(var2.get_var_string() + '\n')

    # a VCF without records never reaches the header write inside the loop
    if in_header:
        _write_genotype_header(vcf, header, sample_list, vcf_out)
    vcf_out.close()

    return


def _write_genotype_header(vcf, header, sample_list, vcf_out):
    '''Load the input header into vcf, add svtyper's FORMAT fields and the BAM samples, and write it.'''
    vcf.add_header(header)
    vcf.add_format('GQ', 1, 'Integer', 'Genotype quality')
    vcf.add_format('SQ', 1, 'Float', 'Phred-scaled probability that this site is variant (non-reference in this sample')
    vcf.add_format('GL', 'G', 'Float', 'Genotype Likelihood, log10-scaled likelihoods of the data given the called genotype for each possible genotype generated from the reference and alternate alleles given the sample ploidy')
    vcf.add_format('DP', 1, 'Integer', 'Read depth')
    vcf.add_format('RO', 1, 'Integer', 'Reference allele observation count, with partial observations recorded fractionally')
    vcf.add_format('AO', 'A', 'Integer', 'Alternate allele observations, with partial observations recorded fractionally')
    vcf.add_format('QR', 1, 'Integer', 'Sum of quality of reference observations')
    vcf.add_format('QA', 'A', 'Integer', 'Sum of quality of alternate observations')
    vcf.add_format('RS', 1, 'Integer', 'Reference allele split-read observation count, with partial observations recorded fractionally')
    vcf.add_format('AS', 'A', 'Integer', 'Alternate allele split-read observation count, with partial observations recorded fractionally')
    vcf.add_format('RP', 1, 'Integer', 'Reference allele paired-end observation count, with partial observations recorded fractionally')
    vcf.add_format('AP', 'A', 'Integer', 'Alternate allele paired-end observation count, with partial observations recorded fractionally')
    vcf.add_format('AB', 'A', 'Float', 'Allele balance, fraction of observations from alternate allele, QA/(QR+QA)')

    # add the samples in the BAM files to the VCF output
    for sample in sample_list:
        if sample.name not in vcf.sample_list:
            vcf.add_sample(sample.name)

    # write the output header
    vcf_out.write(vcf.get_header() + '\n')


def _mapq_weight(mapq):
    '''Return 1 - 10^(-mapq/10), the weight a read of this MAPQ contributes.'''
    return 1 - 10 ** (-mapq / 10.0)


def _count_breakend_splitters(sample, chrom, pos, strand, padding, splflank):
    '''
    Count reference-spanning reads and split reads around one breakend.

    Returns (ref_counter, spl_counter, ref_scaled_counter, spl_scaled_counter),
    Counters keyed by position; the scaled ones sum _mapq_weight instead of 1.
    All four are empty when the sample has no split-read BAM.
    '''
    ref_counter = Counter()
    spl_counter = Counter()
    ref_scaled_counter = Counter()
    spl_scaled_counter = Counter()
    if sample.spl_bam is None:
        return ref_counter, spl_counter, ref_scaled_counter, spl_scaled_counter

    start = max(pos - padding, 0)
    end = pos + padding + 1

    for ref_read in sample.bam.fetch(chrom, start, end):
        if ref_read.is_duplicate or ref_read.is_unmapped:
            continue
        weight = _mapq_weight(ref_read.mapq)
        # every p in this range has at least splflank aligned bases on both sides
        for p in xrange(ref_read.pos + splflank, ref_read.aend + 1 - splflank):
            ref_counter[p] += 1
            ref_scaled_counter[p] += weight

    for spl_read in sample.spl_bam.fetch(chrom, start, end):
        if spl_read.is_duplicate or spl_read.is_unmapped:
            continue
        # the junction is at the clipped end of the aligned part (CIGAR op 0 == M)
        if strand == '+' and spl_read.cigar[0][0] == 0:
            junction = spl_read.aend
        elif strand == '-' and spl_read.cigar[-1][0] == 0:
            junction = spl_read.pos + 1
        else:
            continue
        spl_counter[junction] += 1
        spl_scaled_counter[junction] += _mapq_weight(spl_read.mapq)

    return ref_counter, spl_counter, ref_scaled_counter, spl_scaled_counter


def _tally_breakend_splitters(counters, pos, split_slop):
    '''
    Summarize splitter counters within pos +/- split_slop.

    Reference counts are averaged over the window (the raw one rounded to an
    int); split counts are summed. Returns (ref, spl, ref_scaled, spl_scaled).
    '''
    ref_counter, spl_counter, ref_scaled_counter, spl_scaled_counter = counters
    window = xrange(pos - split_slop, pos + split_slop + 1)
    width = float(2 * split_slop + 1)
    sr_ref = int(round(sum(ref_counter[p] for p in window) / width))
    sr_spl = sum(spl_counter[p] for p in window)
    sr_ref_scaled = sum(ref_scaled_counter[p] for p in window) / width
    sr_spl_scaled = sum(spl_scaled_counter[p] for p in window)
    return sr_ref, sr_spl, sr_ref_scaled, sr_spl_scaled

# --------------------------------------
# main function

def main():
    '''
    Command-line entry point: parse the arguments and genotype the input VCF.

    sv_genotype closes the output VCF itself; the input VCF is closed here
    once genotyping has finished.
    '''
    args = get_args()

    # publish the -M flag as a module-level global (presumably read by
    # helpers elsewhere in this module)
    global legacy
    legacy = args.legacy

    sv_genotype(vcf_file=args.input_vcf,
                bam_string=args.bam,
                spl_bam_string=args.split_bam,
                vcf_out=args.output_vcf,
                splflank=args.splflank,
                discflank=args.discflank,
                split_weight=args.split_weight,
                disc_weight=args.disc_weight,
                num_samp=args.num_samp,
                debug=args.debug)

    args.input_vcf.close()

# initialize the script
if __name__ == '__main__':
    import errno
    try:
        sys.exit(main())
    except IOError as e:
        # writing into a closed pipe (e.g. output piped to `head`) fails with
        # EPIPE; treat that as a normal exit and re-raise any other I/O error
        if e.errno != errno.EPIPE:
            raise
