#!/usr/bin/env python
# Module metadata. main() prints __version__, __license__ and __authors__
# when the script is invoked with --version.
__license__ = "MIT"
__version__ = "0.9.3"
__authors__ = ["Marvin Jens","Alex Robertson"]
__email__ = "mjens@mit.edu"

import sys
import itertools
import numpy as np
import copy
import time
import os
import logging
import ska_kmers
 
def read_freqs(src):
    """
    Read kmer background frequencies from a previous run on input.

    The file is expected to be tab-separated with the frequency in the
    second column (the layout main() writes as 'rel_freqs.<k>mer.txt').
    Lines starting with '#' and blank lines are skipped.

    Parameters
    ----------
    src : str
        Path of the frequency table to read.

    Returns
    -------
    numpy.ndarray
        1-D float32 array holding the second column of every data line,
        in file order.

    Raises
    ------
    IndexError
        If a data line has fewer than two tab-separated columns.
    ValueError
        If the second column cannot be parsed as a float.
    """
    freqs = []
    # 'with' guarantees the handle is closed even if parsing fails
    with open(src) as handle:
        for line in handle:
            if line.startswith('#'):
                continue

            line = line.rstrip()
            if not line:
                # tolerate empty lines (e.g. a trailing newline at end of file)
                continue

            parts = line.split('\t')
            freqs.append(float(parts[1]))

    return np.array(freqs, dtype=np.float32)



def stream_counts(k, seqm, num_iterations=10, convergence=0.5, background=None):
    """
    Iteratively estimate k-mer weights for a set of sequences.

    Counts k-mer occurrences once via ska_kmers.seq_set_kmer_count(), then
    calls ska_kmers.seq_set_SKA() up to `num_iterations` times, feeding each
    pass the weights produced by the previous one. Iteration stops early
    once the largest absolute change of any single weight between two
    consecutive passes drops below `convergence`.

    Parameters
    ----------
    k : int
        k-mer length; all weight/frequency vectors have 4**k entries.
    seqm :
        Sequence data, passed through unchanged to the ska_kmers routines.
    num_iterations : int
        Maximum number of passes.
    convergence : float
        Threshold on the maximum absolute weight change.
    background : sequence of float, optional
        Background value per k-mer. None or an empty sequence selects a
        uniform background of ones.

    Returns
    -------
    (current_weights, kmer_freqs)
        The weights after the last pass (all ones if no pass was run) and
        the k-mer counts scaled by 4**k / total count, as float32.
    """
    logger = logging.getLogger("stream_counts()")
    # k-mer strings are only needed to label the log output below
    kmers = list(yield_kmers(k))
    t0 = time.time()
    kmer_counts = ska_kmers.seq_set_kmer_count(seqm,k)
    t1 = time.time()
    logger.debug("counted kmer occurrences in {0:.1f} seconds".format( (t1-t0) ) )

    current_weights = np.ones(4 ** k, dtype=np.float32)
    N = kmer_counts.sum()
    # scaled so that a k-mer occurring at the uniform expectation gets 1.0
    kmer_freqs = np.array(kmer_counts, dtype=np.float32) * (4**k)/N

    # None replaces the former mutable default ([]); an empty sequence is
    # still accepted and means "uniform background", as before.
    if background is None or not len(background):
        background = np.ones(4 ** k, dtype=np.float32)
    else:
        background = np.array(background, dtype=np.float32)

    # Only the previous pass is ever compared against, so keep just that one
    # array instead of the full history of 4**k-sized vectors.
    previous_weights = None
    for iteration_i in range(num_iterations):
        new_weights = ska_kmers.seq_set_SKA(seqm, current_weights, background, k)

        # Hand the next pass a copy so the array kept for comparison stays
        # untouched (presumably guards against in-place modification by
        # seq_set_SKA -- TODO confirm).
        current_weights = copy.copy(new_weights)

        if previous_weights is not None:
            delta = new_weights - previous_weights

            change = np.fabs(delta)
            mean_change = change.mean()
            max_i = change.argmax()
            max_change = change[max_i]

            logger.info("#iteration {0}: mean_change={1} max_change={2} for '{3}' ({4} -> {5})".format(iteration_i, mean_change, max_change, kmers[max_i], previous_weights[max_i], new_weights[max_i]))
            if max_change < convergence:
                logger.info("reached convergence after {0} iterations".format(iteration_i))
                break

        previous_weights = new_weights

    return current_weights, kmer_freqs

        

def _write_kmer_table(path, header, k, values):
    """Write `header`, then one tab-separated '<kmer> <value>' line per k-mer, to `path`."""
    # 'with' closes the file even if a write fails part-way through
    with open(path, 'w') as of:
        of.write(header)
        for kmer, value in zip(yield_kmers(k), values):
            of.write('%s\t%g\n' % (kmer, value))


def main():
    """
    Command-line entry point.

    Parses the options, sets up logging to the console and to
    '<output>/run.log', loads all reads from the file given as first
    positional argument, and for every k from --min-k to --max-k
    (inclusive) runs stream_counts() and writes two tables into the
    output directory:

      SKA_weights.<k>mer.txt  -- weight per k-mer
      rel_freqs.<k>mer.txt    -- relative frequency per k-mer

    If --background is given it is treated as a directory containing
    'rel_freqs.<k>mer.txt' files, which are loaded with read_freqs().

    Exits via parser.error() if no reads file is given, and with status 1
    if the reads file contains no sequences.
    """
    from optparse import OptionParser
    usage = "usage: %prog [options] <reads_file> OR cat <reads_file> | %prog [options] /dev/stdin"

    parser = OptionParser(usage=usage)
    parser.add_option("-k","--min-k",dest="min_k",default=3,type=int,help="min kmer size (default=3)")
    parser.add_option("-K","--max-k",dest="max_k",default=8,type=int,help="max kmer size (default=8)")

    parser.add_option("-n","--n-passes",dest="n_passes",default=10,type=int,help="max number of passes (default=10)")
    parser.add_option("-c","--convergence",dest="convergence",default=0.5,type=float,help="convergence is reached when max. change in absolute weight is below this value (default=0.5)")
    parser.add_option("-B","--background", dest="background", default="", help="path to file with background (input) kmer abundances in the library")
    parser.add_option("-o","--output",dest="output",default=".",help="path where results are to be stored")
    parser.add_option("","--debug",dest="debug",default=False, action="store_true",help="SWITCH: activate debug output")
    parser.add_option("","--version",dest="version",default=False, action="store_true",help="SWITCH: show version information and quit")
    options,args = parser.parse_args()

    if options.version:
        # single-argument print() calls produce identical output on Python 2 and 3
        print(__version__)
        print(__license__)
        print("by " + ", ".join(__authors__))
        sys.exit(0)

    # prepare output path
    out_path = os.path.abspath(options.output)
    if not os.path.exists(out_path):
        os.makedirs(out_path)

    # set up logging: console via basicConfig, plus a fresh run.log per run
    log_path = os.path.join(out_path,"run.log")
    if options.debug:
        lvl = logging.DEBUG
    else:
        lvl = logging.INFO

    FORMAT = '%(asctime)-20s\t%(levelname)s\t%(name)s\t%(message)s'
    logging.basicConfig(level=lvl, format=FORMAT)
    root = logging.getLogger('')
    fh = logging.FileHandler(filename=log_path, mode='w')
    fh.setFormatter(logging.Formatter(FORMAT))
    root.addHandler(fh)

    logger = logging.getLogger("SKA")
    logger.info("called as '{0}'".format(" ".join(sys.argv)) )

    if not args:
        # parser.error() prints the usage message and terminates the process
        parser.error("missing argument: need <reads_file> (or use /dev/stdin)")

    t0 = time.time()
    # load and keep all sequences in memory (numerically A=0,...T=3 )
    with open(args[0]) as reads:
        seqm = ska_kmers.read_raw_seqs_chunked(reads, chunklines=2000000)
    t1 = time.time()

    if len(seqm) == 0:
        # seqm[0] below would otherwise fail with an unhelpful IndexError
        logger.error("no sequences found in '{0}'".format(args[0]))
        sys.exit(1)

    logger.info("read {0:.3f}M sequences of length {1} in {2:.1f} seconds".format(len(seqm)/1E6, len(seqm[0]), (t1-t0) ) )

    for k in range(options.min_k, options.max_k+1):
        logger.info("streaming {0}-mers".format(k))

        if options.background:
            bg_path = os.path.join(options.background,"rel_freqs.{0}mer.txt".format(k))
            background = read_freqs(bg_path)
            logger.info("loaded background kmer frequencies from {0}".format(bg_path))
        else:
            background = []

        SKA_weights, kmer_freqs = stream_counts(
            k,
            seqm,
            num_iterations=options.n_passes,
            convergence=options.convergence,
            background=background
        )

        _write_kmer_table(
            os.path.join(out_path, 'SKA_weights.{0}mer.txt'.format(k)),
            '# kmer\tweight\n',
            k,
            SKA_weights,
        )
        _write_kmer_table(
            os.path.join(out_path, 'rel_freqs.{0}mer.txt'.format(k)),
            '# kmer\trel_freq\n',
            k,
            kmer_freqs,
        )

 
def yield_kmers(k):
    """
    Generate every string of length k over the alphabet 'ACGT'.

    Strings are produced lazily, in the order itertools.product() emits
    them, which for this sorted alphabet is alphabetical order
    ('AA', 'AC', 'AG', 'AT', 'CA', ... for k=2).
    """
    letter_tuples = itertools.product('ACGT', repeat=k)
    for word in (''.join(letters) for letters in letter_tuples):
        yield word

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
