#!/usr/bin/env python

# Created on Thu Apr 30 15:15:30 2015

# Author: XiaoTao Wang

from __future__ import division
import argparse, sys, logging, logging.handlers, hicpeaks

# Version string of the installed hicpeaks package; reported by the
# "-v/--version" command-line option.
currentVersion = hicpeaks.__version__

def getargs():
    """Build the command-line parser and parse ``sys.argv[1:]``.

    When the program is started without any argument, ``-h`` is substituted
    so that the help message is printed (argparse exits after printing it).

    ``-O/--output``, ``-p/--path`` and ``--ww`` are declared as required:
    the caller opens the output path for writing, opens the Cooler URI and
    takes ``min()`` of the donut widths, so leaving any of them unset could
    only end in a late ``TypeError``. With ``required=True`` argparse reports
    the missing option up front together with the usage line. ``-h`` and
    ``-v`` are unaffected because their actions exit during parsing.

    Returns
    -------
    args : argparse.Namespace
        The parsed options.
    commands : list of str
        The raw argument list that was parsed (``['-h']`` if none was given).
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(usage = '%(prog)s <-O output> [options]',
                                     description = 'Local Peak Calling for Hi-C Data',
                                     formatter_class = argparse.ArgumentDefaultsHelpFormatter)
    
    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s',currentVersion]),
                        help='Print version number and exit.')

    # Output
    parser.add_argument('-O', '--output', required = True,
                        help = 'Output peak file path.')
    parser.add_argument('--logFile', default = 'HICCUPS.log', help = 'Logging file name.')
    
    group_1 = parser.add_argument_group(title = 'Relate to Hi-C data:')
    group_1.add_argument('-p', '--path', required = True,
                         help = 'Cooler URI.')
    group_1.add_argument('-C', '--chroms', nargs = '*', default = ['#', 'X'],
                        help = 'List of chromosome labels. Only Hi-C data within the specified '
                        'chromosomes will be included. Specially, "#" stands for chromosomes '
                        'with numerical labels. "--chroms" with zero argument will include '
                        'all chromosome data.')
    
    ## About the algorithm
    group_2 = parser.add_argument_group(title = 'Algorithm Parameters:')
    # NOTE(review): --pw is forwarded untouched to the peak caller; it is
    # probably mandatory as well, but that cannot be confirmed from this
    # file, so it is left optional.
    group_2.add_argument('--pw', type=int, nargs='+',
                         help='''List of the peak widths.''')
    group_2.add_argument('--ww', type=int, nargs='+', required=True,
                         help='''List of the donut widths.''')
    group_2.add_argument('--maxww', type = int, default = 20, help = 'Maximum donut width.')
    group_2.add_argument('--siglevel', type = float, default = 0.1, help = 'Significant Level.')
    group_2.add_argument('--sumq', type = float, default = 0.01,
                         help = '''During the additional filtering, original peak pixels would be
                         filtered out if there are no other peak pixels located in its neighborhood
                         and the sum of its 2 q-values is greater than this threshold.''')
    group_2.add_argument('--double-fold', type = float, default = 1.75,
                         help = '''Besides FDR control, pyHICCUPS removes all peak pixels that
                         doesn't show this minimum fold enrichment with respect to the expected
                         values for both backgrounds.''')
    group_2.add_argument('--single-fold', type = float, default = 2,
                         help = '''pyHICCUPS only remains peak pixels which have at least this fold
                         enrichment over either the donut expected value or the lower-left expected
                         value.''')
    group_2.add_argument('--use-raw', action = 'store_true',
                        help = '''When specified, peak pixels will be sorted by raw observed interaction
                        frequencies during local clustering. By default, ICE-corrected values will be used.''')
    group_2.add_argument('--min-marginal-peaks', type = int, default = 3,
                        help = '''Minimum marginal number of peaks when detecting peak anchors.''')
    group_2.add_argument('--only-anchors', action = 'store_true',
                         help = '''When specified, either of the peak loci must be an anchor.''')
    group_2.add_argument('--maxapart', type = int, default = 10000000, help = 'Maximum genomic'
                         ' distance between two involved loci.')
    group_2.add_argument('--nproc', type = int, default = 1, help = 'Number of worker processes.')
    
    ## Parse the command-line arguments
    # sys.argv[1:] is a fresh list, so appending to it does not touch sys.argv
    commands = sys.argv[1:]
    if not commands:
        commands.append('-h')
    args = parser.parse_args(commands)
    
    return args, commands

def run():
    """Command-line entry point: call peaks on every selected chromosome.

    Steps performed:

    1. Parse the command line with ``getargs``.
    2. Configure the root logger with a console handler and a rotating file
       handler (``args.logFile``), then log the Python version and the full
       argument list.
    3. Open the Cooler at ``args.path`` and select the chromosomes that match
       ``args.chroms`` (an empty list selects everything, ``'#'`` selects
       labels made of digits only).
    4. Run ``worker`` on each selected chromosome, serially when
       ``args.nproc == 1`` and through a ``multiprocess`` pool otherwise.
    5. Write one tab-separated line per reported pixel to ``args.output``.

    Returns ``None``.
    """
    # Parse Arguments
    args, commands = getargs()
    # argparse exits by itself while parsing "-h/--help"; the guard is kept
    # so that nothing heavy is set up should that ever change.
    if commands[0] in ['-h', '--help']:
        return

    ## Root Logger Configuration
    logger = logging.getLogger()
    # Logger Level
    logger.setLevel(logging.DEBUG)
    console = logging.StreamHandler()
    filehandler = logging.handlers.RotatingFileHandler(args.logFile,
                                                       maxBytes = 200000,
                                                       backupCount = 5)
    # Set level for Handlers
    console.setLevel('INFO')
    filehandler.setLevel('INFO')
    # Customizing Formatter
    formatter = logging.Formatter(fmt = '%(name)-21s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt = '%m/%d/%y %H:%M:%S')

    console.setFormatter(formatter)
    filehandler.setFormatter(formatter)
    # Add Handlers
    logger.addHandler(console)
    logger.addHandler(filehandler)

    logger.info('Python Version: {}'.format(sys.version.split()[0]))

    ## Logging for argument setting
    arglist = ['# ARGUMENT LIST:',
               '# Output file = {}'.format(args.output),
               '# Cooler URI = {}'.format(args.path),
               '# Chromosomes = {}'.format(args.chroms),
               '# Peak window width = {}'.format(args.pw),
               '# Donut width = {}'.format(args.ww),
               '# Maximum donut width = {}'.format(args.maxww),
               '# Significant Level = {}'.format(args.siglevel),
               '# Sum of 2 q-values = {}'.format(args.sumq),
               '# Double fold threshold = {}'.format(args.double_fold),
               '# Single fold threshold = {}'.format(args.single_fold),
               '# Use Raw IF in clustering = {0}'.format(args.use_raw),
               '# Minimum marginal peaks = {0}'.format(args.min_marginal_peaks),
               '# Only remain anchors = {0}'.format(args.only_anchors),
               '# Maximum Genomic distance = {}'.format(args.maxapart),
               '# Number of Processes = {}'.format(args.nproc)
               ]

    argtxt = '\n'.join(arglist)
    logger.info('\n'+argtxt)

    # Package Dependencies (imported late so that "-h"/"-v" stay fast)
    import cooler
    from multiprocess import Pool
    import numpy as np
    from scipy import sparse
    from hicpeaks.callers import hiccups

    def chrom_label(name):
        """Return *name* without a leading 'chr' prefix, if it has one."""
        # str.lstrip('chr') would strip any leading run of the characters
        # 'c', 'h' and 'r' (e.g. 'chrrDNA' -> 'DNA'), not the prefix itself.
        return name[3:] if name.startswith('chr') else name

    def worker(tuple_arg):
        """Call peaks on one chromosome.

        *tuple_arg* packs, in order: the Cooler object, the chromosome name,
        peak widths, donut widths, maximum donut width, significance level,
        q-value sum threshold, double-fold and single-fold thresholds,
        maximum genomic distance, bin size, the use-raw flag, the minimum
        marginal peak count and the only-anchors flag.

        Returns ``(label, pixel_table)`` where *label* is the chromosome
        name without its 'chr' prefix and *pixel_table* is whatever
        ``hiccups`` returns.
        """
        (Lib, key, pw, ww, maxww, siglevel, sumq, dfold, sfold, maxapart,
         resolution, use_raw, mmp, only_anchor) = tuple_arg
        H = Lib.matrix(balance=False, sparse=True).fetch(key)
        cHeatMap = Lib.matrix(balance=True, sparse=True).fetch(key)
        # Customize Sparse Matrix ...
        chromLen = H.shape[0]
        # Number of diagonals to keep; taken from this task's own parameters
        # rather than from the enclosing ``args``.
        num = maxapart // resolution + maxww + 1
        Diags = [H.diagonal(i) for i in np.arange(num)]
        M = sparse.diags(Diags, np.arange(num), format='csr')
        x = np.arange(min(ww), num)
        # IR[i]: mean of the non-NaN balanced values on the i-th diagonal
        IR = {}
        cDiags = []
        for i in x:
            diag = cHeatMap.diagonal(i)
            mask = np.isnan(diag)
            notnan = diag[np.logical_not(mask)]
            IR[i] = notnan.mean()
            diag[mask] = 0
            cDiags.append(diag)
        cM = sparse.diags(cDiags, x, format='csr')

        # Release the full matrices before the peak calling itself
        del H, cHeatMap

        # Bias of a bin is the reciprocal of its balancing weight; bins
        # whose weight is 0 or NaN keep a bias of 0.
        tmp = Lib.bins().fetch(key)['weight'].values
        mask = np.logical_not((tmp==0) | np.isnan(tmp))
        biases = np.zeros_like(tmp)
        biases[mask] = 1/tmp[mask]

        key = chrom_label(key)

        pixel_table = hiccups(M, cM, biases, biases, IR, chromLen, Diags, cDiags, num, key,
                              pw=pw, ww=ww, sig=siglevel, sumq=sumq, maxww=maxww, maxapart=maxapart,
                              double_fold=dfold, single_fold=sfold, res=resolution, use_raw=use_raw,
                              min_marginal_peaks=mmp, onlyanchor=only_anchor)

        return key, pixel_table

    logger.info('Loading Hi-C data ...')
    Lib = cooler.Cooler(args.path)
    resolution = Lib.binsize

    # One parameter tuple per chromosome selected by --chroms
    Params = []
    for key in Lib.chromnames:
        chromlabel = chrom_label(key)
        if ((not args.chroms) or (chromlabel.isdigit() and '#' in args.chroms) or (chromlabel in args.chroms)):
            Params.append((Lib, key, args.pw, args.ww, args.maxww, args.siglevel, args.sumq, args.double_fold,
                           args.single_fold, args.maxapart, resolution, args.use_raw, args.min_marginal_peaks,
                           args.only_anchors))

    head = '\t'.join(['chromLabel', 'loc_1', 'loc_2', 'centroid_x', 'centroid_y','radius','IF', 'D-Enrichment',
                      'D-pvalue', 'D-qvalue', 'LL-Enrichment', 'LL-pvalue', 'LL-qvalue']) + '\n'
    lineFormat = '{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6:.3g}\t{7:.3g}\t{8:.3g}\t{9:.3g}\t{10:.3g}\t{11:.3g}\t{12:.3g}\n'

    logger.info('Calling Peaks ...')
    # The context manager closes the output file even if a worker fails
    with open(args.output, 'w') as OF:
        OF.write(head)

        if args.nproc == 1:
            results = map(worker, Params)
        else:
            pool = Pool(args.nproc)
            try:
                results = pool.map(worker, Params)
            finally:
                # Always shut the worker processes down, error or not
                pool.close()
                pool.join()

        for key, pixel_table in results:
            for pixel in pixel_table:
                contents = (key,) + pixel + pixel_table[pixel]
                OF.write(lineFormat.format(*contents))

    logger.info('Done!')
    

# Run the peak caller only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    run()
