#!/usr/bin/env python
import sys
import optparse
import logging
import time
import multiprocessing as mp
from fbilr import utils
   

def _find_best_hit(barcodes, targets, confident_ed):
    """Return the barcode alignment with the lowest edit distance.

    Args:
        barcodes: iterable of (name, forward sequence, reverse sequence).
        targets: sequence of (sequence, offset) pairs, searched in order for
            every barcode. The offset is added to the coordinates returned
            by ``utils.align`` so they refer to the whole read.
        confident_ed: as soon as the best edit distance is <= this value the
            search stops and that hit is returned.

    Returns:
        A tuple (name, direction, start, end, edit distance), where direction
        is "F" or "R", or None when nothing was aligned.
    """
    best = None
    for bc_name, bc_seq_f, bc_seq_r in barcodes:
        for seq, offset in targets:
            for direction, bc_seq in (("F", bc_seq_f), ("R", bc_seq_r)):
                # Skip barcodes that cannot fit in the searched sequence.
                # (This guard used to compare the barcode length with itself
                # and therefore never triggered.)
                if len(bc_seq) > len(seq):
                    continue
                x, y, ed = utils.align(bc_seq, seq)
                # Strict "<" keeps the earliest candidate on ties.
                if best is None or ed < best[4]:
                    best = (bc_name, direction, x + offset, y + offset, ed)
                    if ed <= confident_ed:
                        return best
    return best


def _classify_location(start, end, length):
    """Return "H", "T" or "M" for a hit relative to the read midpoint."""
    mid = length / 2
    if end < mid:
        return "H"
    if start > mid:
        return "T"
    return "M"


def _build_hit(best, length, discard_middle):
    """Turn a best-hit tuple into the 6-field hit record.

    Returns [bc, direction, location, start, end, ed]; every field is None
    when there is no hit, or when *discard_middle* is set and the hit spans
    the read midpoint.
    """
    if best is None:
        return [None] * 6
    bc_name, direction, start, end, ed = best
    location = _classify_location(start, end, length)
    if discard_middle and location == "M":
        return [None] * 6
    return [bc_name, direction, location, start, end, ed]


def worker(edge_sequences_list, barcodes_list, mode, confident_ed):
    """Search every barcode set in the edge sequences of a batch of reads.

    Args:
        edge_sequences_list: iterable of (read length, head sequence,
            tail sequence). Tail coordinates are shifted by
            ``length - len(tail sequence)`` to express them on the read.
        barcodes_list: list of barcode sets, each an iterable of
            (name, forward sequence, reverse sequence).
        mode: "SE" reports one hit per barcode set, searching head and tail
            together. "PE" reports two hits per barcode set (head, then
            tail) and discards hits that span the read midpoint.
        confident_ed: edit distance at or below which a hit is accepted
            immediately without trying the remaining barcodes.

    Returns:
        One list per read, containing one 6-field hit record
        [bc, direction, location, start, end, ed] per reported hit
        (all None when nothing was found).

    Raises:
        ValueError: if *mode* is neither "SE" nor "PE". Previously such a
            mode silently produced an empty result.
    """
    if mode not in ("SE", "PE"):
        raise ValueError("Unsupported mode %r: expected 'SE' or 'PE'." % (mode,))

    hits_list = []
    for length, seq_head, seq_tail in edge_sequences_list:
        head = (seq_head, 0)
        tail = (seq_tail, length - len(seq_tail))

        hits = []
        for barcodes in barcodes_list:
            if mode == "SE":
                best = _find_best_hit(barcodes, (head, tail), confident_ed)
                hits.append(_build_hit(best, length, False))
            else:
                for target in (head, tail):
                    best = _find_best_hit(barcodes, (target,), confident_ed)
                    hits.append(_build_hit(best, length, True))
        hits_list.append(hits)

    return hits_list


def process_results(fw, reads, hits_list, ignore_name=False, include_read=False):
    """Write one tab-separated line per read to *fw*, then flush it.

    Each line holds the read name (or "." when *ignore_name* is set), the
    sequence length, six fields per hit (bc, direction, location, start,
    end, ed) and, when *include_read* is set, the name, sequence, "+" and
    quality of the read. A hit whose barcode is None is written as
    ".", ".", ".", -1, -1, -1.

    Args:
        fw: writable text stream.
        reads: iterable of (name, sequence, quality) records.
        hits_list: iterable of per-read hit lists, paired with *reads*.
        ignore_name: replace the leading read name by ".".
        include_read: append the read record to the line.
    """
    no_hit = [".", ".", ".", -1, -1, -1]

    for (name, seq, qua), hits in zip(reads, hits_list):
        fields = ["." if ignore_name else name, len(seq)]

        for hit in hits:
            bc, direction, location, start, end, ed = hit
            if bc is None:
                fields += no_hit
            else:
                fields += [bc, direction, location, start, end, ed]

        if include_read:
            fields += [name, seq, "+", qua]

        fw.write("\t".join(str(field) for field in fields) + "\n")

    fw.flush()


def run_pipeline(f_read, f_barcode_list, f_outfile, 
                 width=200, mode="SE", threads=1, 
                 confident_edit_distance=0,
                 ignore_name=False, include_read=False):
    """Search barcodes in all reads of a file and write the result table.

    Reads are loaded in batches of 1000, each batch is handed to ``worker``
    in a process pool, and finished batches are written with
    ``process_results`` in submission order. At most ``threads * 2``
    batches are in flight at any time.

    Args:
        f_read: path passed to ``utils.load_batch``.
        f_barcode_list: list of paths, each passed to ``utils.load_barcodes``.
        f_outfile: output path; None or "-" writes to standard output.
        width: edge width passed to ``utils.cut_edge_sequence``.
        mode: "SE" or "PE", forwarded to ``worker``.
        threads: number of worker processes.
        confident_edit_distance: forwarded to ``worker``.
        ignore_name: forwarded to ``process_results``.
        include_read: forwarded to ``process_results``.

    Raises:
        ValueError: if *mode* is neither "SE" nor "PE".
        Exception: any exception raised inside a worker is re-raised here.
    """
    # Fail fast: an unknown mode would otherwise yield an empty output file.
    if mode not in ("SE", "PE"):
        raise ValueError("Unsupported mode %r: expected 'SE' or 'PE'." % (mode,))

    barcodes_list = [utils.load_barcodes(p) for p in f_barcode_list]

    to_stdout = f_outfile is None or f_outfile == "-"
    fw = sys.stdout if to_stdout else open(f_outfile, "w+")

    reads_per_batch = 1000
    max_submitted_batch = threads * 2
    # FIFO of [batch index, reads, AsyncResult]; only the oldest entry is
    # ever written, so the output keeps the input order.
    submitted_task_list = []

    def write_oldest_if_ready():
        """Write the oldest submitted batch if it finished; report success."""
        if not submitted_task_list:
            return False
        index, batch_reads, result = submitted_task_list[0]
        if not result.ready():
            return False
        # get() re-raises whatever the worker raised, so no separate
        # successful() assertion is needed (asserts vanish under -O).
        process_results(fw, batch_reads, result.get(), ignore_name, include_read)
        submitted_task_list.pop(0)
        logging.info("Processed batch %d", index)
        return True

    pool = None
    try:
        pool = mp.Pool(threads, maxtasksperchild=100)

        for i, reads in enumerate(utils.load_batch(f_read, reads_per_batch)):
            while True:
                # Give priority to writing a finished batch.
                write_oldest_if_ready()

                # Do not submit too many tasks.
                if len(submitted_task_list) < max_submitted_batch:
                    edge_seqs_list = [utils.cut_edge_sequence(read[1], width) for read in reads]
                    args = (edge_seqs_list, barcodes_list, mode, confident_edit_distance)
                    submitted_task_list.append([i, reads, pool.apply_async(worker, args)])
                    logging.info("Submitted batch %d", i)
                    break
                time.sleep(0.01)

        # Drain the batches that are still running.
        while submitted_task_list:
            if not write_oldest_if_ready():
                time.sleep(0.01)
    finally:
        # Release the pool and the output even when a batch failed.
        if pool is not None:
            pool.close()
            pool.terminate()
            pool.join()
        if to_stdout:
            # Standard output belongs to the interpreter: flush, do not close.
            fw.flush()
        else:
            fw.close()
    

def _build_option_parser():
    """Create the optparse parser describing the command-line interface."""
    parser = optparse.OptionParser(usage="%prog [options] reads.fq.gz")

    parser.add_option("-b", "--barcodes", dest="barcodes", metavar="PATH",
                      help="Path to barcode list in FASTA format. Multiple files separated by commas. [%default]")

    group = optparse.OptionGroup(parser, "Search")

    group.add_option("-t", "--threads", dest="threads", type="int", default=1, metavar="INT",
                      help="Number of threads used. [%default]")

    # type="int" is required: without it a user-supplied value stays a string
    # and breaks both the "%d" log line and the numeric comparison downstream.
    group.add_option("-c", "--confident-edit-distance", dest="confident_edit_distance", type="int", default=0, metavar="INT",
                      help="Edit distance threshold for confident barcode. When edit distance <= threshold, stop to find other barcodes. [%default]")

    group.add_option("-w", "--width", dest="width", type="int", default=200, metavar="INT",
                      help="Find barcode sequences within WIDTH of the boundary. [%default]")

    # Restrict the mode to the two supported values so a typo is reported
    # at parse time instead of producing an empty output.
    group.add_option("-m", "--mode", dest="mode", type="choice", choices=("SE", "PE"), default="SE", metavar="STR",
                      help="Running mode: SE or PE. [%default]")

    parser.add_option_group(group)

    group = optparse.OptionGroup(parser, "Outputs")

    group.add_option("-o", "--outfile", dest="outfile", default=None, metavar="PATH",
                      help="Path of output matrix file. [%default]")

    group.add_option("-i", "--ignore-name", dest="ignore_name", action="store_true", default=False,
                     help="Ignore read name to save storage. [%default]")

    group.add_option("-q", "--include-read", dest="include_read", action="store_true", default=False,
                     help="Include fastq record in the output. [%default]")

    parser.add_option_group(group)

    return parser


def main():
    """Command-line entry point: parse options, run the pipeline, log timing.

    Exits with status 1 when the read file argument or the barcode option is
    missing; optparse itself exits with status 2 on invalid option values.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s: %(message)s')

    options, args = _build_option_parser().parse_args()

    # Exactly one positional argument: the read file.
    if len(args) != 1:
        logging.critical("Please provided input fastq file.")
        sys.exit(1)
    f_read = args[0]

    if options.barcodes is None:
        logging.critical("Please provide input barcode file.")
        sys.exit(1)
    f_barcode_list = options.barcodes.split(",")

    logging.info("The input read file is %s", f_read)
    logging.info("The input barcode files is %s", ", ".join(f_barcode_list))
    logging.info("The width is %d", options.width)
    logging.info("The confident edit distance is %d", options.confident_edit_distance)
    logging.info("The search mode is %s", options.mode)
    logging.info("Ignore name is %s", options.ignore_name)
    logging.info("Include read is %s", options.include_read)

    t1 = time.time()
    run_pipeline(f_read=f_read,
                 f_barcode_list=f_barcode_list,
                 f_outfile=options.outfile,
                 width=options.width,
                 mode=options.mode,
                 threads=options.threads,
                 confident_edit_distance=options.confident_edit_distance,
                 ignore_name=options.ignore_name,
                 include_read=options.include_read)
    elapsed = time.time() - t1

    # Split the elapsed seconds into hours, minutes and seconds.
    h, remainder = divmod(elapsed, 3600)
    m, s = divmod(remainder, 60)
    logging.info("Spent %dh%dm%.2fs" % (h, m, s))
    

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

