#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
%prog [options] [FASTQ-file]

Trim low-quality ends from sequences given in the FASTQ file.
If the name of the file is not given, standard input is read.

The resulting FASTQ file is written to standard output.

The algorithm is the same as the one used by BWA:
- Subtract the cutoff value from all qualities.
- Compute partial sums from all indices to the end of the sequence.
- Trim sequence at the index at which the sum is minimal.
"""
from __future__ import print_function, division

import sys
from optparse import OptionParser
from random import sample
from collections import defaultdict

from sqt.io.fasta import FastqReader, writefastq
from sqt.qualtrim import quality_trim_index as trim_index

__author__ = "Marcel Martin"


def main():
	"""Run the quality-trimming command-line tool.

	Parses the command line, reads FASTQ records from the named file (or
	from standard input when no file name is given), cuts every record at
	the index returned by ``trim_index`` for the chosen cutoff and writes
	the shortened records to standard output.

	The cutoff in use, a summary of how many bases were trimmed and, with
	``--histogram``, a table of removed-end lengths (sorted by length) are
	written to standard error.
	"""
	parser = OptionParser(usage=__doc__)
	parser.add_option("-q", "--cutoff", type=int, default=10,
		help="Quality cutoff (default: %default)")
	parser.add_option("--histogram", action="store_true", default=False,
		help="Print a histogram of the length of removed ends")
	parser.add_option("-c", "--colorspace", action="store_true", default=False,
		help="Assume input files are in color space and that the sequences contain an initial primer base")

	(options, args) = parser.parse_args()
	if len(args) > 1:
		parser.error("Need the name of at most one FASTQ file.")
	infile = open(args[0]) if args else sys.stdin

	# Maps "number of removed quality values" -> "number of reads".
	histogram = defaultdict(int)
	total_bases = 0
	total_trimmed = 0
	try:
		fastqfile = FastqReader(infile, options.colorspace)

		print("using a cutoff of", options.cutoff, file=sys.stderr)
		for sequence in fastqfile:
			desc = sequence.name
			seq = sequence.sequence
			qualities = sequence.qualities

			index = trim_index(qualities, options.cutoff)
			removed = len(qualities) - index
			total_bases += len(qualities)
			total_trimmed += removed
			if options.histogram:
				histogram[removed] += 1
			qualities = qualities[:index]
			# In color space one extra leading character of the sequence
			# is kept (the initial primer base named in the option help).
			if options.colorspace:
				seq = seq[:index+1]
			else:
				seq = seq[:index]
			writefastq(sys.stdout, [(desc, seq, qualities)])
	finally:
		# Close only a file opened here; standard input is left open.
		if infile is not sys.stdin:
			infile.close()

	# Skip the summary for empty input to avoid dividing by zero.
	if total_bases > 0:
		print("{} of {} bases trimmed ({:.2%})".format(total_trimmed, total_bases, total_trimmed/total_bases), file=sys.stderr)
	if options.histogram:
		# Sorted so the table reads in order of removed length instead of
		# dict iteration order.
		for length, count in sorted(histogram.items()):
			print(length, count, sep="\t", file=sys.stderr)


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
	main()
