#!/usr/bin/env python2.7

import logging
from skbio.tree import TreeNode
import os
import sys
import argparse
from string import split as _

# Prefer an installed tree2tax package; when none is importable, fall back to
# the source tree that sits one directory above this script.
try:
    import tree2tax
except ImportError:
    sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)),'..'))
    # Retry now that the parent directory is on the path. Without this the
    # name `tree2tax` is never bound on the fallback path (the `from` import
    # below binds only the class names), and the later use of
    # tree2tax.__version__ fails with a NameError.
    import tree2tax
from tree2tax.tree2tax import Tree2Tax, TaxonomyFunctions



# Command line interface. The tree, the thresholds and the output file are
# all mandatory; --debug is an optional flag.
parser = argparse.ArgumentParser(
    description='''--- tree2tax %s --- partitions a tree into clades separated by a given distance threshold''' % tree2tax.__version__)
parser.add_argument(
    '-t', '--tree',
    required=True,
    help='newick format tree file to partition')
# One or more floats, given space separated after the option.
parser.add_argument(
    '-d', '--thresholds',
    required=True,
    nargs='+',
    type=float,
    help='thresholds at which to partition the tree, space separated')
parser.add_argument(
    '-o', '--output_taxonomy',
    required=True,
    help='output the taxonomy to this file')
parser.add_argument(
    '--debug',
    action="store_true",
    help='output debug information')

args = parser.parse_args()

# Log at DEBUG level when --debug was given, otherwise at INFO level.
log_level = logging.DEBUG if args.debug else logging.INFO
logging.basicConfig(level=log_level)

# Load the tree named on the command line and report how many tips it has.
logging.info("Reading tree file..")
tree = TreeNode.read(args.tree)
tip_count = tree.count(tips=True)
logging.info("Read in tree with %s tips", tip_count)

# Cluster the tree at each of the requested thresholds.
logging.info("Clustering..")
threshold_and_clusters = Tree2Tax().named_clusters_for_several_thresholds(tree, args.thresholds)

# Reverse in place so that higher taxonomic levels are displayed first,
# followed by the lower ones.
threshold_and_clusters.reverse()

# One-letter prefix written in front of each level of the output taxonomy,
# in the order the levels are written.
threshold_names = 'K P C O F G S'.split()

# Each threshold level needs its own prefix. Fail up front with a clear
# message instead of hitting an IndexError part-way through writing the
# output file.
if len(threshold_and_clusters) > len(threshold_names):
    parser.error(
        "at most %d thresholds are supported (one for each of %s), but %d were given" % (
            len(threshold_names),
            ' '.join(threshold_names),
            len(threshold_and_clusters)))

# Report, for every threshold, how many clusters were found and how many of
# those contain just a single tip.
for threshold_clusters in threshold_and_clusters:
    clusters = threshold_clusters.clusters
    logging.info("Found %i clusters for threshold %s",
                 len(clusters), threshold_clusters.threshold)

    num_singleton_clades = sum(
        1 for named_clade in clusters if len(named_clade.tips) == 1)

    logging.info("Of these clusters, %s contained only a single sequence",
                 num_singleton_clades)
    
# Write one line per tip: the tip name, a tab, then one "<prefix>__<name>"
# entry per threshold level joined by "; ".
output_file_name = args.output_taxonomy
with open(output_file_name,'w') as f:
    for tip in threshold_and_clusters[0].each_tip():
        f.write(tip.name)
        f.write("\t")

        level_entries = []
        # Cluster this tip fell into at the previous (higher) level. Carried
        # forward so each cluster is looked up only once per tip and level.
        previous_cluster = None
        for i, tc in enumerate(threshold_and_clusters):
            cluster = tc.tip_to_cluster(tip)
            entry = threshold_names[i] + '__'

            if i != 0:
                # Work out if there is any missing taxonomic info between this node and the last one
                missings = TaxonomyFunctions().missing_taxonomy(tree,
                                                               cluster.lca_node,
                                                               previous_cluster.lca_node)
                # don't count the node that is already recorded in the other
                # part of the taxonomy, if that is recorded
                # NOTE(review): taxonomy_from_node_name is called on the class
                # while missing_taxonomy is called on an instance -- confirm
                # that taxonomy_from_node_name is a static method.
                if TaxonomyFunctions.taxonomy_from_node_name(cluster.lca_node.name) and len(missings) > 0:
                    entry += '%s..' % '.'.join(missings)
                elif len(missings) > 1:
                    entry += '%s..' % '.'.join(missings[:-1])

            entry += cluster.condensed_name()
            level_entries.append(entry)
            previous_cluster = cluster

        f.write("; ".join(level_entries))
        f.write("\n")

logging.info("Finished writing new taxonomy to %s", output_file_name)
        
