#!/usr/bin/env python2.7

import logging
from skbio.tree import TreeNode, NoLengthError
import os
import sys
import argparse
import tempfile
import random
import subprocess

try:
    import tree2tax
except ImportError:
    sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)),'..'))
from tree2tax.tree2tax import Tree2Tax



# Command-line interface: the tree, the FASTA file holding the sequences named
# by the tree's tips, and how many random tip pairs to sample.
parser = argparse.ArgumentParser(description='''--- reports randomly selected pairs of sequences, their percent identity and tree distance''')
parser.add_argument('-t', '--tree', help='newick format tree file to sample pairs of tips from', required=True)
parser.add_argument('-f', '--fasta', help='fasta file of sequences with corresponding IDs', required=True)
parser.add_argument('-n', '--num_comparisons', type=int, help='number of randomly selected sequence pairs to compare (default: %(default)s)', default=10)
parser.add_argument('--debug', help='output debug information', action="store_true")

args = parser.parse_args()

# --debug raises verbosity from INFO to DEBUG (per-iteration progress messages).
logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
    
    
# Load the tree given on the command line and materialise its tips as a list,
# since random pairs are sampled from that list further down.
logging.info("Reading tree..")
tree = TreeNode.read(args.tree)
tips = list(tree.tips())
logging.info("Read in tree with %s tips", len(tips))

# Sample random pairs of tips and, for each pair, print a tab-separated line:
#   tip1 name, tip2 name, tree distance, blastn percent identity.
#
# Two scratch files hold the extracted sequences for blastn. Managing them
# with `with` guarantees they are removed even when samtools/blastn fails.
num_fail = 0
with tempfile.NamedTemporaryFile() as temp1, \
        tempfile.NamedTemporaryFile() as temp2:
    for i in range(args.num_comparisons):
        logging.debug("%s done", i)
        tip1, tip2 = random.sample(tips, 2)
        # random.sample draws two distinct list positions, so this guard only
        # fires if two entries of `tips` compare equal.
        if tip1 == tip2:
            continue

        # Tree distance between the two tips. A pair is counted as failed when
        # the path between them includes a node without a branch length.
        try:
            tree_distance = tip1.distance(tip2)
        except NoLengthError:
            logging.debug("No tree distance for %s / %s", tip1.name, tip2.name)
            num_fail += 1
            continue

        # Extract each sequence with samtools faidx. Arguments are passed as a
        # list (no shell) so IDs or paths containing characters such as '|',
        # spaces or quotes are handed to samtools verbatim. Opening the
        # scratch file in 'w' mode truncates it, like the shell's '>' would.
        for tip, scratch in ((tip1, temp1), (tip2, temp2)):
            with open(scratch.name, 'w') as out:
                subprocess.check_call(
                    ['samtools', 'faidx', args.fasta, tip.name], stdout=out)

        # Pairwise blastn in tabular format; universal_newlines makes the
        # captured output text rather than bytes on Python 3.
        blast_output = subprocess.check_output(
            ['blastn',
             '-query', temp1.name,
             '-subject', temp2.name,
             '-outfmt', '6',
             '-task', 'blastn'],
            universal_newlines=True)

        # The third whitespace-separated field of the first reported hit is
        # the percent identity. blastn prints nothing when there is no hit,
        # so treat that as a failed comparison instead of raising IndexError.
        fields = blast_output.split()
        if len(fields) < 3:
            logging.debug("No blast hit for %s / %s", tip1.name, tip2.name)
            num_fail += 1
            continue
        perc_id = fields[2]

        # Single parenthesised argument: prints identically on Python 2 and 3.
        print("\t".join([tip1.name,
                         tip2.name,
                         str(tree_distance),
                         perc_id,
                         ]))

logging.info("%s comparisons failed", num_fail)
