#!/usr/bin/env python

# Created on Thu Apr 30 15:15:30 2015

# Author: XiaoTao Wang

from __future__ import division
import argparse, sys, logging, logging.handlers, hicpeaks

import numpy as np
from scipy import sparse

# Version string of the installed hicpeaks package; shown by -v/--version
currentVersion = hicpeaks.__version__

def getargs():
    """Build the command-line interface and parse ``sys.argv``.

    If the program is started without any argument, ``-h`` is substituted so
    that argparse prints the help text and exits.

    Returns
    -------
    tuple
        ``(args, commands)`` where ``args`` is the parsed
        ``argparse.Namespace`` and ``commands`` is the list of argument
        strings that was handed to the parser.
    """
    cli = argparse.ArgumentParser(
        usage='%(prog)s <-O output> [options]',
        description='Local Peak Calling for Hi-C Data',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Version
    cli.add_argument('-v', '--version', action='version',
                     version=' '.join(('%(prog)s', currentVersion)),
                     help='Print version number and exit.')

    # Output
    cli.add_argument('-O', '--output', help='Output peak file path.')
    cli.add_argument('--logFile', default='HICCUPS.log', help='Logging file name.')

    # Options describing the input Hi-C data
    data_opts = cli.add_argument_group(title='Relate to Hi-C data:')
    data_opts.add_argument(
        '-p', '--path',
        help='URI string pointing to a cooler under specific resolution.')
    data_opts.add_argument(
        '-C', '--chroms', nargs='*', default=['#', 'X'],
        help='List of chromosome labels. Only Hi-C data within the specified '
             'chromosomes will be included. Specially, "#" stands for chromosomes '
             'with numerical labels. "--chroms" with zero argument will include '
             'all chromosome data.')

    # Options tuning the algorithm: (flag, type, default, help), registered
    # in the order listed so the generated help keeps the same layout.
    algo_opts = cli.add_argument_group(title='Algorithm Parameters:')
    algo_specs = [
        ('--pw', int, 2,
         'Width of the interaction '
         'region surrounding the peak. According to experience, we set it'
         ' to 1 at 20 kb, 2 at 10 kb, and 4 at 5 kb.'),
        ('--ww', int, 5,
         'Width of the donut '
         'sampled. Set it to 3 at 20 kb, 5 at 10 kb, and 7 at 5 kb.'),
        ('--maxww', int, 20, 'Maximum donut width.'),
        ('--siglevel', float, 0.1, 'Significant Level.'),
        ('--sumq', float, 0.08,
         '''During the additional filtering, original peak pixels would be
                         filtered out if there are no other peak pixels located in its neighborhood
                         and the sum of its 4 q-values is greater than this threshold.'''),
        ('--maxapart', int, 5000000,
         'Maximum genomic'
         ' distance between two involved loci.'),
        ('--nproc', int, 1, 'Number of worker processes.'),
    ]
    for flag, kind, default, text in algo_specs:
        algo_opts.add_argument(flag, type=kind, default=default, help=text)

    # Fall back to the help screen when no argument was supplied
    commands = sys.argv[1:] or ['-h']
    args = cli.parse_args(commands)

    return args, commands

def run():
    """Command-line entry point: set up logging, call peaks, write the result.

    The function parses the command line with ``getargs``, configures the
    root logger (console plus a rotating log file), loads the cooler given by
    ``--path``, runs ``pcaller`` for every selected chromosome (optionally in
    a pool of ``--nproc`` worker processes) and writes one tab-separated line
    per reported pixel to the ``--output`` file.

    A chromosome is selected when ``--chroms`` is empty, when its name is in
    ``--chroms``, or when its name is purely numeric and ``'#'`` is in
    ``--chroms``.
    """
    # Parse Arguments
    args, commands = getargs()
    # Skip logging setup and the heavy imports when only help was requested
    if commands[0] in ['-h', '--help']:
        return

    ## Root Logger Configuration
    logger = logging.getLogger()
    # Logger Level
    logger.setLevel(10)
    console = logging.StreamHandler()
    filehandler = logging.handlers.RotatingFileHandler(args.logFile,
                                                       maxBytes = 200000,
                                                       backupCount = 5)
    # Set level for Handlers
    console.setLevel('INFO')
    filehandler.setLevel('INFO')
    # Customizing Formatter
    formatter = logging.Formatter(fmt = '%(name)-14s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt = '%m/%d/%y %H:%M:%S')

    console.setFormatter(formatter)
    filehandler.setFormatter(formatter)
    # Add Handlers
    logger.addHandler(console)
    logger.addHandler(filehandler)

    ## Logging for argument setting
    arglist = ['# ARGUMENT LIST:',
               '# Output file = {}'.format(args.output),
               '# Cooler URI = {}'.format(args.path),
               '# Chromosomes = {}'.format(args.chroms),
               '# Peak window width = {}'.format(args.pw),
               '# Donut width = {}'.format(args.ww),
               '# Maximum donut width = {}'.format(args.maxww),
               '# Significant Level = {}'.format(args.siglevel),
               '# Sum of 4 q-values = {}'.format(args.sumq),
               '# Maximum Genomic distance = {}'.format(args.maxapart),
               '# Number of Processes = {}'.format(args.nproc)
               ]

    argtxt = '\n'.join(arglist)
    logger.info('\n'+argtxt)

    # Package Dependencies (imported late so that "-h" stays fast)
    import cooler
    from multiprocess import Pool
    from sklearn import isotonic

    def worker(tuple_arg):
        """Call peaks on one chromosome; returns (key, Donuts, LL, VB, HB)."""
        Lib, key, pw, ww, siglevel, sumq, maxww, maxapart, resolution = tuple_arg
        logger.info('Chromosome %s ...', key)
        H = Lib.matrix(balance=False, sparse=True).fetch(key)
        cHeatMap = Lib.matrix(balance=True, sparse=True).fetch(key)
        # Customize Sparse Matrix ...
        chromLen = H.shape[0]
        # Number of diagonals kept: the distance limit in bins plus the
        # largest window width, plus the main diagonal. The values come from
        # the task tuple rather than from the enclosing ``args``.
        num = maxapart // resolution + maxww + 1
        Diags = [H.diagonal(i) for i in np.arange(num)]
        M = sparse.diags(Diags, np.arange(num), format='csr')
        x = np.arange(ww, num)
        y = []
        cDiags = []
        for i in x:
            diag = cHeatMap.diagonal(i)
            mask = np.isnan(diag)
            notnan = diag[np.logical_not(mask)]
            # Mean of the non-NaN balanced values at this offset
            y.append(notnan.mean())
            diag[mask] = 0
            cDiags.append(diag)
        cM = sparse.diags(cDiags, x, format='csr')
        # Isotonic fit of the per-offset mean against the offset
        IR = isotonic.IsotonicRegression(increasing='auto')
        IR.fit(x, y)

        del H, cHeatMap

        # biases = 1 / weight wherever the weight is not exactly zero
        tmp = Lib.bins().fetch(key)['weight'].values
        mask = np.logical_not(tmp==0)
        biases = np.zeros_like(tmp)
        biases[mask] = 1/tmp[mask]

        Donuts, LL, VB, HB = pcaller(M, cM, biases, IR, chromLen, Diags, cDiags, num, key,
                                    pw=pw, ww=ww, sig=siglevel, sumq=sumq, maxww=maxww, maxapart=maxapart,
                                    res=resolution)

        return key, Donuts, LL, VB, HB

    logger.info('Loading Hi-C data ...')
    Lib = cooler.Cooler(args.path)
    resolution = Lib.binsize

    # One task tuple per selected chromosome
    Params = []
    for key in Lib.chromnames:
        if ((not args.chroms) or (key.isdigit() and '#' in args.chroms) or (key in args.chroms)):
            Params.append((Lib, key, args.pw, args.ww, args.siglevel, args.sumq, args.maxww, args.maxapart, resolution))

    pool = None
    if args.nproc == 1:
        map_ = map
    else:
        pool = Pool(args.nproc)
        map_ = pool.map

    logger.info('Calling Peaks ...')
    head = '\t'.join(['chromLabel', 'loc_1', 'loc_2', 'IF', 'D-Enrichment', 'D-pvalue', 'D-qvalue',
                      'LL-Enrichment', 'LL-pvalue', 'LL-qvalue', 'V-Enrichment', 'V-pvalue', 'V-qvalue',
                      'H-Enrichment', 'H-pvalue', 'H-qvalue']) + '\n'
    lineFormat = '{0}\t{1}\t{2}\t{3:.4g}\t{4:.4g}\t{5:.4g}\t{6:.4g}\t{7:.4g}\t{8:.4g}\t{9:.4g}\t{10:.4g}\t{11:.4g}\t{12:.4g}\t{13:.4g}\t{14:.4g}\t{15:.4g}\n'

    try:
        # Text mode: the lines written are str, which a binary-mode handle
        # rejects on Python 3. The with-block closes the file on errors too.
        with open(args.output, 'w') as OF:
            OF.write(head)
            for key, Donuts, LL, VB, HB in map_(worker, Params):
                for i in Donuts:
                    # i is the (loc_1, loc_2) key; the donut entry contributes
                    # all four values, the other filters drop their first one.
                    contents = (key,) + i + Donuts[i] + LL[i][1:] + VB[i][1:] + HB[i][1:]
                    OF.write(lineFormat.format(*contents))
    finally:
        # Release the worker processes once all results have been consumed
        if pool is not None:
            pool.close()
            pool.join()

    logger.info('Done!')

def pcaller(M, cM, biases, IR, chromLen, Diags, cDiags, num, chrom, pw=2, ww=5, sig=0.1, sumq=0.08, maxww=10,
            maxapart=2000000, res=10000):
    """Call peak pixels on one chromosome with four local background filters.

    For every nonzero pixel of ``M`` whose offset from the main diagonal is
    between ``ww`` and ``maxapart // res`` bins, four local backgrounds
    (keys ``'K'``, ``'Y'``, ``'V'``, ``'H'``; logged as donut, lower-left,
    vertical and horizontal) are accumulated from shifted copies of the
    diagonals. The window width grows from ``ww`` up to ``maxww`` for pixels
    whose lower-left window holds fewer than 16 raw reads. Each background
    yields a corrected expected value that is tested with a Poisson model and
    Benjamini-Hochberg correction inside lambda chunks. Pixels passing all
    four filters are kept if the sum of their four q-values is at most
    ``sumq`` or if DBSCAN (``eps=1.1``, ``min_samples=2``) puts them in a
    cluster.

    Parameters
    ----------
    M : sparse matrix
        Observed values; indexed as ``M[xi, yi]`` and used for gap detection
        (rows whose sum is zero).
    cM : sparse matrix
        Not used by this function; kept for interface compatibility.
    biases : array
        Per-bin factors; expected values are multiplied by
        ``biases[xi] * biases[yi]``.
    IR : object
        Fitted model whose ``predict`` gives the expected value for each
        diagonal offset in ``arange(ww, num)``; negative predictions are
        clipped to 0.
    chromLen : int
        Number of bins, i.e. the dimension of the square matrices.
    Diags : list of arrays
        ``num`` diagonals of the observed matrix; ``Diags[i]`` must have
        length ``chromLen - i``.
    cDiags : list of arrays
        Diagonals for offsets ``ww .. num - 1``; ``cDiags[i]`` must have
        length ``chromLen - (ww + i)``.
    num : int
        Number of diagonals in ``Diags``.
    chrom : str
        Chromosome label, used in log messages only.
    pw, ww, maxww : int
        Peak half-width, initial window half-width and largest window
        half-width. ``maxww`` must be positive.
    sig : float
        Significance level for the q-value cut-off.
    sumq : float
        Threshold on the sum of the four q-values.
    maxapart : int
        Largest distance between the two loci, in the same unit as ``res``.
    res : int
        Bin size; bin indices are multiplied by it to build the output keys.

    Returns
    -------
    Donuts, LL, VB, HB : dict
        One dict per filter, keyed by ``(bin_1 * res, bin_2 * res)`` and
        holding ``(observed, fold, p-value, q-value)`` tuples.
    """
    # Necessary Modules
    from scipy.stats import poisson
    from statsmodels.sandbox.stats.multicomp import multipletests
    from sklearn.cluster import dbscan
    
    logger = logging.getLogger()

    # more codes for lower memory
    # use reference instead of creating new arrays
    # Every diagonal is zero-padded by maxww on both sides once; narrower
    # paddings are slices (views) of these reference arrays.
    extDiags_ref = []
    for i in range(num):
        OneDArray = Diags[i]
        extODA = np.zeros(chromLen - i + maxww*2)
        extODA[maxww:-maxww] = OneDArray
        extDiags_ref.append(extODA)
    
    extDiags = {maxww: extDiags_ref}
    for w in range(ww, maxww):
        temp = []
        for i in range(num):
            delta = maxww-w
            extODA = extDiags_ref[i][delta:-delta]
            temp.append(extODA)
        extDiags[w] = temp
    
    # Expected value per offset, constant along each diagonal
    x = np.arange(ww, num)
    predictE = IR.predict(x)
    predictE[predictE < 0] = 0
    EDiags = []
    for i in range(x.size):
        OneDArray = np.ones(chromLen - x[i]) * predictE[i]
        EDiags.append(OneDArray)
    
    EM = sparse.diags(EDiags, x, format = 'csr')

    extCDiags_ref = []
    extEDiags_ref = []
    for i in range(x.size):
        extODA_C = np.zeros(chromLen - x[i] + maxww*2)
        extODA_C[maxww:-maxww] = cDiags[i]
        extCDiags_ref.append(extODA_C)
        extODA_E = np.zeros(chromLen - x[i] + maxww*2)
        extODA_E[maxww:-maxww] = EDiags[i]
        extEDiags_ref.append(extODA_E)
    
    extCDiags = {maxww: extCDiags_ref}
    extEDiags = {maxww: extEDiags_ref}
    for w in range(ww, maxww):
        tempC = []
        tempE = []
        for i in range(x.size):
            delta = maxww - w
            extODA_C = extCDiags_ref[i][delta:-delta]
            tempC.append(extODA_C)
            extODA_E = extEDiags_ref[i][delta:-delta]
            tempE.append(extODA_E)
        extCDiags[w] = tempC
        extEDiags[w] = tempE
    
    ps = 2 * pw + 1 # Peak Size
                
    ## Peak Calling ...    
    xi, yi = M.nonzero()
    Mask = ((yi - xi) >= ww) & ((yi - xi) <= (maxapart // res))
    xi = xi[Mask]
    yi = yi[Mask]
    flocals = ['K', 'Y', 'V', 'H'] # V <-> H
    bSV = {}; bEV = {}
    for fl in flocals:
        bSV[fl] = np.zeros(xi.size)
        bEV[fl] = np.zeros(xi.size)
    
    logger.info('Chrom:{0}, Observed Contact Number: {1}'.format(chrom, xi.size))
    
    # RefIdx: indices of pixels still waiting for a wide enough window
    RefIdx = np.arange(xi.size)
    RefMask = np.ones(xi.size, dtype = bool)
    
    iniNum = xi.size
    
    logger.info('Chrom:{0}, Four local neighborhoods, four expected matrices ...'.format(chrom))
    for w in range(ww, maxww + 1):
        if RefIdx.size == 0:
            # Nothing left to evaluate; going on would divide by zero when
            # the valid ratio is computed below.
            break
        ws = 2 * w + 1
        bS = {}; bE = {}
        for fl in flocals:
            bS[fl] = sparse.csr_matrix((chromLen, chromLen))
            bE[fl] = sparse.csr_matrix((chromLen, chromLen))
        Reads = sparse.csr_matrix((chromLen, chromLen))
        logger.info('Chrom:{0},    Current window width: {1}'.format(chrom, w))
        # Sets of (row, col) positions inside the ws x ws window
        P1 = set([(i,j) for i in range(w-pw, ps+w-pw) for j in range(w-pw, ps+w-pw)]) # Center Peak Region
        P_1 = set([(i,j) for i in range(w+1, ws) for j in range(w)])
        P_2 = set([(i,j) for i in range(w+1, ps+w-pw) for j in range(w-pw, w)])
        P2 = P_1 - P_2 # Lower-left Region
        P_v_1 = set([(i,j) for i in range(w-1, w+2) for j in range(w-pw)]) # higher interaction
        P_v_2 = set([(i,j) for i in range(w-1, w+2) for j in range(ps+w-pw, ws)])
        P_h_1 = set([(i,j) for i in range(ps+w-pw, ws) for j in range(w-1, w+2)]) # higher interaction
        P_h_2 = set([(i,j) for i in range(w-pw) for j in range(w-1, w+2)])

        # For every window position (i, j) collect the slices of the padded
        # diagonals that make up the matrix shifted to that position.
        ss = range(ws)
        Pool_Diags = {}
        Pool_EDiags = {}
        Pool_cDiags = {}
        for i in ss:
            for j in ss:
                Pool_Diags[(i,j)] = []
                Pool_EDiags[(i,j)] = []
                Pool_cDiags[(i,j)] = []
                for oi in range(num):
                    if oi + i - j >= 0:
                        starti = i
                        endi = i + chromLen - (oi + i - j)
                    else:
                        starti = i - (oi + i - j)
                        endi = starti + chromLen + (oi + i - j)
                    Pool_Diags[(i,j)].append(extDiags[w][oi][starti:endi])
                for oi in range(x.size):
                    if x[oi] + i - j >= 0:
                        starti = i
                        endi = i + chromLen - (x[oi] + i - j)
                    else:
                        starti = i - (x[oi] + i - j)
                        endi = starti + chromLen + (x[oi] + i - j)
                    Pool_EDiags[(i,j)].append(extEDiags[w][oi][starti:endi])
                    Pool_cDiags[(i,j)].append(extCDiags[w][oi][starti:endi])
        
        # Accumulate the shifted matrices into the four backgrounds
        for key in Pool_Diags:
            cDiags_matrix = sparse.diags(Pool_cDiags[key], x + (key[0] - key[1]), format = 'csr')
            EDiags_matrix = sparse.diags(Pool_EDiags[key], x + (key[0] - key[1]), format = 'csr')
            if (key[0] != w) and (key[1] != w) and (key not in P1) and (key not in P2):
                bS['K'] = bS['K'] + cDiags_matrix
                bE['K'] = bE['K'] + EDiags_matrix
            if key in P2:
                bS['K'] = bS['K'] + cDiags_matrix
                bE['K'] = bE['K'] + EDiags_matrix
                bS['Y'] = bS['Y'] + cDiags_matrix
                bE['Y'] = bE['Y'] + EDiags_matrix
                Reads = Reads + sparse.diags(Pool_Diags[key], np.arange(num) + (key[0] - key[1]), format = 'csr')
            if key in P_v_1:
                bS['V'] = bS['V'] + cDiags_matrix
                bE['V'] = bE['V'] + EDiags_matrix
            if key in P_v_2:
                bS['V'] = bS['V'] - cDiags_matrix
                bE['V'] = bE['V'] - EDiags_matrix
            if key in P_h_1:
                bS['H'] = bS['H'] + cDiags_matrix
                bE['H'] = bE['H'] + EDiags_matrix
            if key in P_h_2:
                bS['H'] = bS['H'] - cDiags_matrix
                bE['H'] = bE['H'] - EDiags_matrix
                
        Txi = xi[RefIdx]
        Tyi = yi[RefIdx]
        RNums = np.array(Reads[Txi, Tyi]).ravel()
        # A pixel is settled once its lower-left window holds >= 16 reads
        EIdx = RefIdx[RNums >= 16]
        logger.info('Chrom:{0},    Valid Contact Number: {1}'.format(chrom, EIdx.size))
        Valid_Ratio = EIdx.size/float(iniNum)
        logger.info('Chrom:{0},    Valid Contact Ratio: {1:.3f}'.format(chrom, Valid_Ratio))
        Exi = xi[EIdx]
        Eyi = yi[EIdx]
        for fl in flocals:
            bSV[fl][EIdx] = np.array(bS[fl][Exi, Eyi]).ravel()
            bEV[fl][EIdx] = np.array(bE[fl][Exi, Eyi]).ravel()
                
        RefIdx = RefIdx[RNums < 16]
            
        iniNum = RefIdx.size
        
        if Valid_Ratio < 0.3:
            logger.info('Chrom:{0},    Ratio of valid contact is too small, break the loop ...'.format(chrom))
            break
        
        logger.info('Chrom:{0},    Continue ...'.format(chrom))
        logger.info('Chrom:{0},    {1} Contacts will get into next loop ...'.format(chrom, RefIdx.size))
    
    # Pixels that never reached 16 reads are discarded
    RefMask[RefIdx] = False
    
    Mask = (bEV['K'] != 0) & (bEV['Y'] != 0) & (bEV['V'] != 0) & (bEV['H'] != 0) & RefMask
    xi = xi[Mask]
    yi = yi[Mask]
    bRV = {}
    for fl in flocals:
        bRV[fl] = bSV[fl][Mask] / bEV[fl][Mask]
    
    bR = {}
    for fl in flocals:
        bR[fl] = sparse.lil_matrix((chromLen, chromLen))
        bR[fl][xi, yi] = bRV[fl]
    
    ## Corrected Expected Matrix
    cEM = {}
    for fl in flocals:
        cEM[fl] = EM.multiply(bR[fl].tocsr())
    
    logger.info('Chrom:{0}, Poisson Models and Benjamini-Hochberg Correcting for lambda chunks ...'.format(chrom))
    Description = {'K': 'Donut backgrounds', 'Y': 'Lower-left backgrounds',
                   'V': 'Vertical backgrounds', 'H': 'Horizontal backgrounds'}
    xpos = {}; ypos = {}; Ovalues = {}; Evalues = {}
    Fold = {}; pvalues = {}; qvalues = {}
    # Bins whose row of M sums to zero
    gaps = set(np.where(np.array(M.sum(axis=1)).ravel() == 0)[0])
    for fl in flocals:
        logger.info('Chrom:{0},    {1} ...'.format(chrom, Description[fl]))
        xi, yi = cEM[fl].nonzero()
        Evalues[fl] = np.array(cEM[fl][xi, yi]).ravel() * biases[xi] * biases[yi]
        Mask = (Evalues[fl] > 0)
        Evalues[fl] = Evalues[fl][Mask]
        xi = xi[Mask]
        yi = yi[Mask]
        Ovalues[fl] = np.array(M[xi, yi]).ravel()
        Fold[fl] =  Ovalues[fl] / Evalues[fl]
        logger.info('Chrom:{0},    Valid contact number: {1}'.format(chrom, xi.size))
        
        pvalue = np.ones(xi.size)
        qvalue = np.ones(xi.size)
        
        logger.info('Chrom:{0},    Lambda chunking ...'.format(chrom))
        # No expected values means no chunks; skip the call so an empty
        # array never reaches lambdachunk.
        chunks = lambdachunk(Evalues[fl]) if Evalues[fl].size > 0 else []
        logger.info('Chrom:{0},    Number of chunks: {1}'.format(chrom, len(chunks)))
        for chunk in chunks:
            logger.debug('Chrom:{0},        lv: {1:.4g}, rv: {2:.4g}, Num: {3}'.format(chrom, chunk[0], chunk[1], chunk[2].size))
            if chunk[2].size > 0:
                # The chunk's upper edge serves as the Poisson rate
                Poiss = poisson(chunk[1])
                logger.debug('Chrom:{0},        Assign P values ...'.format(chrom))
                chunkP = 1 - Poiss.cdf(Ovalues[fl][chunk[2]])
                pvalue[chunk[2]] = chunkP
                logger.debug('Chrom:{0},        Multiple testing ...'.format(chrom))
                cResults = multipletests(chunkP, alpha = sig, method = 'fdr_bh')
                cP = cResults[1] # Corrected Pvalue
                qvalue[chunk[2]] = cP
            else:
                logger.debug('Chrom:{0},        Skipping ...'.format(chrom))
        
        reject = qvalue <= sig
        qvalue = qvalue[reject]
        pvalue = pvalue[reject]
        Ovalues[fl] = Ovalues[fl][reject]
        Evalues[fl] = Evalues[fl][reject]
        Fold[fl] = Fold[fl][reject]
        xi = xi[reject]
        yi = yi[reject]
        
        logger.info('Chrom:{0},    Remove Gap Effects ...'.format(chrom))
        
        if len(gaps) > 0:
            # Keep only pixels with no gap bin in the neighbourhood of either
            # coordinate.
            fIdx = []
            for i in range(xi.size):
                lower = (xi[i] - 5) if (xi[i] > 5) else 0
                upper = (xi[i] + 5) if ((xi[i] + 5) < chromLen) else (chromLen - 1)
                cregion_1 = range(lower, upper)
                lower = (yi[i] - 5) if (yi[i] > 5) else 0
                upper = (yi[i] + 5) if ((yi[i] + 5) < chromLen) else (chromLen - 1)
                cregion_2 = range(lower, upper)
                cregion = set(cregion_1) | set(cregion_2)
                intersect = cregion & gaps
                if len(intersect) == 0:
                    fIdx.append(i)
        
            xi = xi[fIdx]
            yi = yi[fIdx]
            Ovalues[fl] = Ovalues[fl][fIdx]
            pvalue = pvalue[fIdx]
            qvalue = qvalue[fIdx]
            Fold[fl] = Fold[fl][fIdx]
            Evalues[fl] = Evalues[fl][fIdx]
        
        xpos[fl] = xi
        ypos[fl] = yi
        pvalues[fl] = pvalue
        qvalues[fl] = qvalue
    
    logger.info('Chrom:{0}, Combine four local filters ...'.format(chrom))
    
    preDonuts = dict(zip(zip(xpos['K'], ypos['K']), zip(Ovalues['K'], Fold['K'], pvalues['K'], qvalues['K'])))
    preLL = dict(zip(zip(xpos['Y'], ypos['Y']), zip(Ovalues['Y'], Fold['Y'], pvalues['Y'], qvalues['Y'])))
    preVB = dict(zip(zip(xpos['V'], ypos['V']), zip(Ovalues['V'], Fold['V'], pvalues['V'], qvalues['V'])))
    preHB = dict(zip(zip(xpos['H'], ypos['H']), zip(Ovalues['H'], Fold['H'], pvalues['H'], qvalues['H'])))
    
    # Only pixels that passed all four filters remain candidates
    commonPos = set(preDonuts.keys()) & set(preLL.keys()) & set(preVB.keys()) & set(preHB.keys())
    Donuts = {}; LL = {}; VB = {}; HB = {}
    candi = sorted(commonPos)
    SUMQ = np.array([preDonuts[pos][3]+preLL[pos][3]+preVB[pos][3]+preHB[pos][3] for pos in candi])
    candi = np.array(candi)
    mask = SUMQ <= sumq
    # Default: no candidate belongs to a cluster. Clustering is only run
    # when there are more than two candidates, so this default keeps the
    # combination below defined for 0, 1 or 2 candidates.
    incluster = np.zeros(len(candi), dtype = bool)
    if len(candi) > 2:
        _, labels = dbscan(candi, eps=1.1, min_samples=2)
        incluster = labels >= 0
    mask = mask | incluster
    candi = candi[mask]
    for pos in candi:
        key = (pos[0]*res, pos[1]*res)
        Donuts[key] = preDonuts[tuple(pos)]
        LL[key] = preLL[tuple(pos)]
        VB[key] = preVB[tuple(pos)]
        HB[key] = preHB[tuple(pos)]
    
    return Donuts, LL, VB, HB

def lambdachunk(E):
    """Partition expected values into logarithmically spaced chunks.

    The first chunk covers ``(0, 1]``; chunk ``i >= 2`` covers
    ``(2**((i-2)/3), 2**((i-1)/3)]``. Intervals are closed on the right so
    that a value lying exactly on a chunk edge (1, 2, 4, ...) is assigned to
    a chunk, and so that the largest value is always covered. At least one
    chunk is returned for non-empty input, even when every value is below 1.

    Parameters
    ----------
    E : array
        One-dimensional array of expected values.

    Returns
    -------
    list of tuple
        ``(lv, rv, idx)`` per chunk, where ``lv``/``rv`` are the chunk edges
        and ``idx`` is the array of positions in ``E`` with
        ``lv < E <= rv``. An empty list is returned for empty input.
    """
    E = np.asarray(E)
    if E.size == 0:
        return []

    emax = E.max()
    numbin = 1
    if emax > 1:
        # Plain int() replaces the np.int alias, which newer NumPy removed
        numbin = int(np.ceil(np.log2(emax) * 3 + 1))
        # Guard against rounding leaving the last edge just below the maximum
        if np.power(2, ((numbin - 1)/3.)) < emax:
            numbin += 1

    Pool = []
    for i in range(1, numbin + 1):
        if i == 1:
            lv = 0; rv = 1
        else:
            lv = np.power(2, ((i - 2)/3.))
            rv = np.power(2, ((i - 1)/3.))
        idx = np.where((E > lv) & (E <= rv))[0]
        Pool.append((lv, rv, idx))

    return Pool
    

# Run the peak caller only when this file is executed as a script
if __name__ == '__main__':
    run()
