#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
%prog [options] <FASTA>

Returns statistics on a given FASTA file.
"""
from __future__ import print_function, division

import sys
from collections import defaultdict

from sqt import HelpfulOptionParser
from sqt.io.fasta import FastaReader

__author__ = "Tobias Marschall"

def main():
	"""Parse the command line and print statistics on a FASTA file.

	By default three sections are printed, each preceded by a header line:
	the total sequence length, the character distribution and the lengths
	of the individual sequences. Giving any of the options -L, -c or -l
	restricts the output to the selected section(s), which are then printed
	without header lines. With -u, characters are converted to upper case
	before being counted.

	The FASTA file name is taken from the first positional argument; if it
	is missing, the problem is reported through ``parser.error``.
	"""
	parser = HelpfulOptionParser(usage=__doc__)
	parser.add_option("-l", action="store_true", dest="length", default=False,
		help="Output only lengths of all individual sequences.")
	parser.add_option("-L", action="store_true", dest="total_length", default=False,
		help="Output only total sequence length.")
	parser.add_option("-c", action="store_true", dest="char_dist", default=False,
		help="Output only character composition.")
	parser.add_option("-u", action="store_true", dest="upcase", default=False,
		help="Convert all characters to upper-case when printing character distribution.")
	(options, args) = parser.parse_args()
	if not args:
		parser.error("Sorry, need a FASTA file as parameter!")

	# Without any of the "only ..." switches, every section is printed.
	output_all = not (options.char_dist or options.total_length or options.length)
	fastaname = args[0]

	char_dist = defaultdict(int)
	length_stat = []
	total_length = 0

	# Counting characters is the expensive part, so it is skipped entirely
	# when the character distribution will not be printed.
	count_chars = output_all or options.char_dist

	# Read FASTA file and gather statistics
	for record in FastaReader(fastaname):
		# Only the first whitespace-separated field of the header is used as
		# id. An empty or all-whitespace header has no fields; use an empty
		# id in that case instead of failing with an IndexError.
		name_fields = record.name.split()
		shortname = name_fields[0] if name_fields else ''
		seq_length = len(record.sequence)
		length_stat.append((shortname, seq_length))
		total_length += seq_length
		if count_chars:
			for c in record.sequence:
				char_dist[c.upper() if options.upcase else c] += 1

	# Output statistics
	if output_all:
		print('Total length:')
	if output_all or options.total_length:
		print(total_length)
	if output_all:
		print('\nCharacter distribution: <char> <count> <percentage>')
	if output_all or options.char_dist:
		# sorted() works on Python 2 and 3 alike; on Python 3, dict.keys()
		# returns a view object that has no .sort() method.
		for c in sorted(char_dist):
			n = char_dist[c]
			# Every counted character also contributed to total_length, so
			# total_length is non-zero whenever this loop body runs.
			print(c, n, n*100.0/total_length)
	if output_all:
		print('\nLengths of individual sequences: <id> <length>')
	if output_all or options.length:
		for name, length in length_stat:
			print(name,length)


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
	main()
