#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""%prog [options] <A.(bam|sam)> <operation> <B.(bam|sam)> <outfile.bam>

Performs a set operation on two SAM/BAM files and outputs a BAM file.
Resulting file will have the same header as file A.

WARNING: Implementation is neither very fast nor memory efficient.

Possible operations:
  union:        Output union of A and B, abort with error if
                different lines for same read are encountered.
  intersection: Output intersection of A and B, abort with error if
                different lines for same read are encountered.
  setminus:     Outputs all reads in A that are not in B.
  symdiff:      Output all reads in A or B but not in both."""
from __future__ import print_function, division
from sqt import HelpfulOptionParser
import sys
from pysam import Samfile

__author__ = "Tobias Marschall"


def SamOrBam(name, mode='r'):
	"""Open `name` through Samfile, choosing the mode from the file name.

	If `name` ends with '.bam', a 'b' is appended to `mode`; otherwise
	`mode` is passed on unchanged. Returns whatever Samfile returns.
	"""
	is_bam = name.endswith('.bam')
	effective_mode = (mode + 'b') if is_bam else mode
	return Samfile(name, effective_mode)


def remove_suffix(s):
	"""Return `s` with everything from its last '/' onwards cut off.

	Strings that contain no '/' are returned unchanged.
	"""
	# rpartition yields an empty separator when no '/' is present.
	head, slash, _tail = s.rpartition('/')
	if not slash:
		return s
	return head


def nop(s):
	"""Identity function: hand back `s` untouched.

	Serves as the "no renaming" counterpart of remove_suffix.
	"""
	return s


def dict_of_reads(reads, exclude_unmapped, rename):
	"""Index reads by their (possibly renamed) query name.

	reads            -- iterable of read objects exposing `qname` and
	                    `is_unmapped`
	exclude_unmapped -- if true, reads whose `is_unmapped` is set are skipped
	rename           -- callable applied to each `qname` to obtain the key

	Returns a dict mapping name -> read.
	Raises Exception if two retained reads end up with the same name.
	"""
	d = dict()
	for read in reads:
		if exclude_unmapped and read.is_unmapped: continue
		name = rename(read.qname)
		# Membership test via `in`: dict.has_key() no longer exists in Python 3.
		if name in d:
			raise Exception("Duplicate read in input file (%s)"%name)
		d[name] = read
	return d


def union(A, B,outfile, exclude_unmapped_A, exclude_unmapped_B, rename):
	"""Write every read that occurs in A or in B to outfile.

	Reads of A are written in the order they are iterated, followed by the
	reads of B whose (renamed) name did not occur in A. B is loaded into
	memory completely before A is processed.

	A, B               -- iterables of reads exposing `qname`, `is_unmapped`
	                      and `compare(other)`
	outfile            -- object with a `write(read)` method
	exclude_unmapped_A -- if true, skip reads of A with `is_unmapped` set
	exclude_unmapped_B -- if true, skip reads of B with `is_unmapped` set
	rename             -- callable mapping a `qname` to the name used for
	                      matching reads between the two files

	Raises Exception if a name occurs twice within one file. If a read is
	present in both files but `compare` returns a non-zero value, both
	records are printed to stderr and the process exits with status 1.
	"""
	readsB = dict_of_reads(B, exclude_unmapped_B, rename)
	readnamesA = set()
	for read in A:
		if exclude_unmapped_A and read.is_unmapped: continue
		name = rename(read.qname)
		if name in readnamesA:
			# Same exception type as dict_of_reads uses for duplicates in B.
			raise Exception("Duplicate read in input file (%s)"%name)
		readnamesA.add(name)
		if name in readsB:
			if read.compare(readsB[name]) != 0:
				print('Content mismatch for read %s:'%name, file=sys.stderr)
				print('File A:',read, file=sys.stderr)
				print('File B:',readsB[name], file=sys.stderr)
				sys.exit(1)
			# Already emitted from A; drop it so it is not written a second time.
			readsB.pop(name)
		outfile.write(read)
	# Whatever is left in readsB occurs only in B.
	for read in readsB.values():
		outfile.write(read)


def intersection(A,B,outfile,exclude_unmapped_A,exclude_unmapped_B,rename):
	"""Write the reads of A whose (renamed) name also occurs in B.

	B is loaded into memory completely before A is processed; the record
	written for a shared read is the one from A.

	A, B               -- iterables of reads exposing `qname`, `is_unmapped`
	                      and `compare(other)`
	outfile            -- object with a `write(read)` method
	exclude_unmapped_A -- if true, skip reads of A with `is_unmapped` set
	exclude_unmapped_B -- if true, skip reads of B with `is_unmapped` set
	rename             -- callable mapping a `qname` to the name used for
	                      matching reads between the two files

	Raises Exception if a name occurs twice within one file. If a read is
	present in both files but `compare` returns a non-zero value, both
	records are printed to stderr and the process exits with status 1.
	"""
	readsB = dict_of_reads(B, exclude_unmapped_B, rename)
	readnamesA = set()
	for read in A:
		if exclude_unmapped_A and read.is_unmapped: continue
		name = rename(read.qname)
		if name in readnamesA:
			# Same exception type as dict_of_reads uses for duplicates in B.
			raise Exception("Duplicate read in input file (%s)"%name)
		readnamesA.add(name)
		if name in readsB:
			if read.compare(readsB[name]) != 0:
				print('Content mismatch for read %s:'%name, file=sys.stderr)
				print('File A:',read, file=sys.stderr)
				print('File B:',readsB[name], file=sys.stderr)
				sys.exit(1)
			outfile.write(read)


def setminus(A,B,outfile,exclude_unmapped_A,exclude_unmapped_B,rename):
	"""Write the reads of A whose (renamed) name does not occur in B.

	The names of B are collected first; if exclude_unmapped_B is true,
	reads of B with `is_unmapped` set do not contribute a name. Reads of A
	with `is_unmapped` set are skipped when exclude_unmapped_A is true.
	No duplicate checking is performed.
	"""
	names_in_b = set()
	for candidate in B:
		if exclude_unmapped_B and candidate.is_unmapped:
			continue
		names_in_b.add(rename(candidate.qname))
	for read in A:
		if exclude_unmapped_A and read.is_unmapped:
			continue
		if rename(read.qname) in names_in_b:
			continue
		outfile.write(read)


def symdiff(A,B,outfile,exclude_unmapped_A,exclude_unmapped_B,rename):
	"""Write the reads whose (renamed) name occurs in exactly one of A and B.

	B is loaded into a dict keyed by name first. Reads of A that have no
	partner in B are written as they are iterated; afterwards the reads of
	B that were not matched by any read of A are written.

	A, B               -- iterables of reads exposing `qname` and
	                      `is_unmapped`
	outfile            -- object with a `write(read)` method
	exclude_unmapped_A -- if true, skip reads of A with `is_unmapped` set
	exclude_unmapped_B -- if true, skip reads of B with `is_unmapped` set
	rename             -- callable mapping a `qname` to the name used for
	                      matching reads between the two files

	Unlike union/intersection, neither duplicates nor content mismatches
	are checked: a later read of B with the same name replaces an earlier
	one.
	"""
	readsB = {}
	for read in B:
		if exclude_unmapped_B and read.is_unmapped: continue
		readsB[rename(read.qname)] = read
	for read in A:
		if exclude_unmapped_A and read.is_unmapped: continue
		name = rename(read.qname)
		# `in` / .values() instead of has_key() / itervalues(), which do not
		# exist on Python 3 dicts.
		if name in readsB:
			# Present in both files: belongs to neither side of the difference.
			readsB.pop(name)
		else:
			outfile.write(read)
	for read in readsB.values():
		outfile.write(read)


class Counter:
	"""Stand-in for an output file that only tallies how often write() is called."""

	# Class-level default; the first write() creates the per-instance value.
	count = 0

	def write(self, x):
		"""Ignore `x` and raise this instance's tally by one."""
		self.count = self.count + 1


def main():
	"""Command-line entry point.

	Parses the options and positional arguments, opens the two input files,
	runs the requested set operation and either writes the resulting reads
	to the output file or, with -c, prints how many reads would be written.

	Prints the help text and exits with status 1 if the number of
	positional arguments is wrong (three with -c, four otherwise); prints an
	error and exits with status 1 if the operation name is not recognized.
	"""
	parser = HelpfulOptionParser(usage=__doc__)
	parser.add_option("-s", action="store_true", dest="sam_output", default=False,
		help="Output SAM file instead of BAM file")
	parser.add_option("-c", action="store_true", dest="just_count", default=False,
		help="Do not write output file but give the number of reads that would be written. Name of output filename need (and must) not be given in this case.")
	parser.add_option("-U", action="store_true", dest="exclude_unmapped_A", default=False,
		help="Exclude unmapped reads from file A")
	parser.add_option("-V", action="store_true", dest="exclude_unmapped_B", default=False,
		help="Exclude unmapped reads from file B")
	parser.add_option("-r", action="store_true", dest="remove_name_suffix", default=False,
		help="Remove trailing \"/*\" from read names. Useful if one mapper appends \"/1\" and another does not.")
	(options, args) = parser.parse_args()
	if len(args) != 3 + (0 if options.just_count else 1):
		parser.print_help()
		sys.exit(1)
	operation = args[1]
	if operation not in ['union','intersection','setminus','symdiff']:
		print('Error: Unrecognized operation "%s"'%operation, file=sys.stderr)
		sys.exit(1)
	rename = remove_suffix if options.remove_name_suffix else nop
	# Remember every file that was opened so that all of them are closed
	# explicitly, also when the operation raises or calls sys.exit().
	opened = []
	try:
		A = SamOrBam(args[0])
		opened.append(A)
		B = SamOrBam(args[2])
		opened.append(B)
		if options.just_count:
			# Counter has no close(); it is deliberately not tracked in `opened`.
			outfile = Counter()
		else:
			outfile = Samfile(args[3], 'wh' if options.sam_output else 'wb', template=A)
			opened.append(outfile)
		# `operation` was validated above, so this only ever resolves to one
		# of the four set-operation functions defined in this module.
		globals()[operation](A,B,outfile,options.exclude_unmapped_A,options.exclude_unmapped_B,rename)
		if options.just_count:
			print(outfile.count)
	finally:
		# Close in reverse order of opening (output file first).
		for f in reversed(opened):
			f.close()


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
	main()
