#!/usr/bin/env python
from pita.model import get_chrom_models, load_chrom_data
from pita.util import model_to_bed, read_statistics, exons_to_seq, longest_orf
from pita.log import setup_logging
from pita.config import config
from pita.annotationdb import AnnotationDb
from pita.utr import *
import os
import sys
import argparse
import multiprocessing as mp
from functools import partial
import signal
from pkg_resources import require

# Defaults used by the command-line interface below.
DEFAULT_CONFIG = "pita.yaml"
DEFAULT_THREADS = 4
# Version of the installed "pita" distribution, as reported by pkg_resources.
VERSION = require("pita")[0].version

# Command-line interface.
p = argparse.ArgumentParser()

# -c and -v are mutually exclusive. The group is deliberately NOT required:
# with required=True the user is forced to pass -c (or -v), so the default
# configuration file advertised in the help text could never be used.
group = p.add_mutually_exclusive_group()
group.add_argument(
    "-c",
    dest="configfile",
    default=DEFAULT_CONFIG,
    help="configuration file (default: {0})".format(DEFAULT_CONFIG),
)
p.add_argument(
    "-t",
    dest="threads",
    default=DEFAULT_THREADS,
    type=int,
    help="number of threads (default: {0})".format(DEFAULT_THREADS),
)
p.add_argument(
    "-i",
    dest="index_dir",
    default=None,
    help="genome index dir",
)
p.add_argument(
    "-y",
    dest="yaml_file",
    default=None,
    help="dump database to yaml file",
)
p.add_argument(
    "-d",
    dest="debug_level",
    default="INFO",
    help="Debug level",
)
p.add_argument(
    "-r",
    dest="reannotate",
    default=False,
    action="store_true",
    help="reannotate using existing database",
)
group.add_argument(
    "-v",
    dest="version",
    action="store_true",
    help="Print the program version and exit",
)


args = p.parse_args()
# -v short-circuits everything else: report the version and stop.
if args.version:
    print("Version: "+VERSION)
    sys.exit()

configfile = args.configfile

# Without a readable configuration file nothing can be done: explain,
# show the usage text and exit with a non-zero status so that calling
# scripts can detect the failure.
if not os.path.exists(configfile):
    print("Missing config file {}".format(configfile))
    print("")
    p.print_help()
    sys.exit(1)

threads = args.threads
index = args.index_dir
debug_level = args.debug_level

# All output files share the configuration file name minus its extension.
basename = os.path.splitext(configfile)[0]

# Setup logging
logger = setup_logging(basename, debug_level)

# Load config file
config.load(configfile, args.reannotate)

# FASTA output
protein_fh = open("{}.protein.fa".format(basename), "w")
cdna_fh = open("{}.cdna.fa".format(basename), "w")

# BED output for the version without UTR extension
bedOutputName = "{}.notExtended.bed".format(basename)
bedOutput = open(bedOutputName, "w")

line_format = "{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\t{7}\t{8}\t{9}\t{10}\t{11}"
# Track header line, written to standard output.
print('track name="{0}"'.format(basename))
chroms = config.chroms

def init_worker():
    """Make the calling process ignore SIGINT.

    Passed to the worker pool as its initializer, so that Ctrl-C is only
    seen by the parent process, which catches KeyboardInterrupt and
    terminates the pool itself.
    """
    ignore_handler = signal.SIG_IGN
    signal.signal(signal.SIGINT, ignore_handler)

"""
Plan: write the BED output to a temporary file, then run pita_utr on that BED file, and print the results to the console after that.
"""
def print_output(genename, exons, lock=None):
    """Write one gene model to the BED, cDNA and protein output files.

    The BED line produced by ``model_to_bed`` goes to ``bedOutput``, the
    sequence from ``exons_to_seq`` goes to ``cdna_fh`` as a FASTA record,
    and the result of ``longest_orf(exons, do_prot=True)`` goes to
    ``protein_fh`` when it is at least ``config.min_protein_size`` long.
    Every file is flushed before returning.

    :param genename: name used for the BED entry and the FASTA headers
    :param exons: exons of the gene model, as accepted by the pita helpers
    :param lock: optional lock object with ``acquire()``/``release()``;
                 when given it is held for the duration of the writes
    """
    if lock:
        lock.acquire()
    try:
        bedOutput.write("{}\n".format(model_to_bed(exons, genename)))
        bedOutput.flush()

        # Print sequences
        cdna_fh.write(">{0}\n{1}\n".format(genename, exons_to_seq(exons)))

        pep = longest_orf(exons, do_prot=True)
        if len(pep) >= config.min_protein_size:
            protein_fh.write(">{0}\n{1}\n".format(genename, pep))
        cdna_fh.flush()
        protein_fh.flush()
    finally:
        # Always give the lock back, even when a helper or a write fails;
        # otherwise every later caller would block forever.
        if lock:
            lock.release()

def listener(q, lock=None):
    """Consume gene models from the queue and write them out.

    Repeatedly calls ``q.get()``; every message is unpacked into
    ``(genename, exons)`` and handed to ``print_output`` together with
    ``lock``. Returns as soon as the message ``'kill'`` is received.
    """
    # iter() with a sentinel keeps calling q.get() until it returns 'kill'.
    for message in iter(q.get, 'kill'):
        genename, exons = message
        logger.debug("calling print_output for {0}".format(genename))
        print_output(genename, exons, lock)

def annotate_chrom(chrom, conn, q, anno_files, data, repeats, weight, prune, keep, filter_ev, experimental, index, reannotate):
    """Annotate a single chromosome and queue its gene models for output.

    For a connection string starting with "sqlite" a per-chromosome suffix
    (".<chrom>") is appended to ``conn``, and the database is flagged as new
    unless ``reannotate`` is set. Unless reannotating, the chromosome data is
    loaded with ``load_chrom_data``. Every ``(genename, best_exons)`` pair
    yielded by ``get_chrom_models`` is then put on ``q`` as a two-item list.
    """
    uses_sqlite = conn.startswith("sqlite")
    if uses_sqlite:
        # One database per chromosome.
        conn = conn + ".{}".format(chrom)
    # Only a sqlite database that is not being re-used counts as new.
    new = uses_sqlite and not reannotate

    logger.info("Chromosome {0} started".format(chrom))

    if not reannotate:
        load_chrom_data(conn, new, chrom, anno_files, data, index)

    models = get_chrom_models(conn, chrom, weight, repeats, prune, keep, filter_ev, experimental)
    for genename, best_exons in models:
        logger.debug("Putting {0} in print queue".format(genename))
        q.put([genename, best_exons])

    logger.info("Chromosome {0} finished".format(chrom))

# Initialize database
if not args.reannotate:
    db = AnnotationDb(new=True, conn=config.db_conn)

# Run the annotation, either spread over a worker pool or serially.
# The output files are closed in all cases, including on errors.
try:
    if threads > 1:
        logger.info("Starting threaded work")
        manager = mp.Manager()
        lock = manager.Lock()
        q = manager.Queue()
        # init_worker makes the workers ignore SIGINT; Ctrl-C is handled
        # below, in the parent process.
        pool = mp.Pool(threads, init_worker, maxtasksperchild=1)

        try:
            # put listener to work first
            watcher = pool.apply_async(listener, args=(q, lock))

            # do the main work
            partialAnnotate = partial(
                annotate_chrom,
                conn=config.db_conn,
                q=q,
                anno_files=config.anno_files,
                data=config.data,
                repeats=config.repeats,
                weight=config.weight,
                prune=config.prune,
                keep=config.keep,
                filter_ev=config.filter,
                experimental=config.experimental,
                index=index,
                reannotate=args.reannotate,
            )
            pool.map(partialAnnotate, chroms)

            # kill the queue!
            q.put('kill')
            pool.close()
            pool.join()

            # The listener runs asynchronously, so an exception raised while
            # writing output is only stored in its result object. Fetch the
            # result to re-raise it here instead of silently ending up with
            # truncated output files.
            watcher.get()

        except KeyboardInterrupt:
            logger.exception("Caught KeyboardInterrupt, terminating workers")
            pool.terminate()
            pool.join()
    else:
        for chrom in chroms:
            if not args.reannotate:
                load_chrom_data(config.db_conn, True, chrom, config.anno_files, config.data, index)
            for genename, best_exons in get_chrom_models(config.db_conn, chrom, config.weight, repeats=config.repeats, prune=config.prune, keep=config.keep, filter_ev=config.filter, experimental=config.experimental):
                print_output(genename, best_exons)
finally:
    cdna_fh.close()
    protein_fh.close()
    bedOutput.close()

# Optional UTR step: for every data entry whose span is "all", call
# print_updated_bed with the non-extended BED file written above and that
# entry's file name(s).
# NOTE(review): print_updated_bed is presumably provided by the star import
# from pita.utr -- confirm.
if config.pitaUTR:
    for name, fname, span, extend in config.data:
        if span == "all":
            bamFiles = fname
            print_updated_bed(bedOutputName, bamFiles)

# Dump the annotation database as YAML when an output file was requested (-y).
if args.yaml_file:
    # The database is opened first, then the output file, as two context
    # managers on one statement; both are released on exit.
    with AnnotationDb(new=False) as db, open(args.yaml_file, "w") as yaml_fh:
        yaml_fh.write(db.dump_yaml())
