#!/usr/bin/env python

# Created on Fri Jun 29 19:26:12 2018
# Author: XiaoTao Wang

## Required modules

import argparse, sys, os, hicpeaks

# Version string of the installed hicpeaks package; printed by ``-v/--version``.
currentVersion = hicpeaks.__version__

def getargs():
    """Build the command-line parser and parse ``sys.argv``.

    When the script is invoked with no arguments at all, ``-h`` is injected
    so the help text is printed.

    Returns
    -------
    args : argparse.Namespace
        Parsed options.
    commands : list of str
        The raw argument list that was parsed (``sys.argv[1:]``, or
        ``['-h']`` when that was empty).
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Perform Aggregate Peak Analysis (APA) on
                                     a peak list generated by our pyHICCUPS or pyBHFDR.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s', currentVersion]),
                        help='Print version number and exit.')

    # Output
    # -O/-p/-I are used unconditionally by the analysis, so they are required:
    # a missing one is reported here instead of failing after the matrix work.
    # (-h and -v exit during parsing, before required options are checked.)
    parser.add_argument('-O', '--output', required=True, help='Output png file name.')
    parser.add_argument('--dpi', default=300, type=int,
                        help='''The resolution in dots per inch of the output figure.''')

    # Input
    parser.add_argument('-p', '--path', required=True,
                        help='URI string pointing to a cooler under specific resolution.')
    parser.add_argument('-I', '--loop-file', required=True,
                        help='Loop file outputed by pyHICCUPS or pyBHFDR.')
    parser.add_argument('-S', '--skip-rows', default=0, type=int,
                        help='''Number of leading lines in the loop file to skip.''')
    parser.add_argument('-U', '--useICE', action='store_true',
                        help='''Whether or not use ICE-corrected matrix.''')
    parser.add_argument('-M', '--min-dis', default=10, type=int,
                        help='''We only examine peak calls where the peak loci are separated by at
                        least this number of bins.''')
    parser.add_argument('-W', '--window', default=5, type=int,
                        help='''Width of the window in APA analysis.''')
    parser.add_argument('-C', '--corner-size', default=3, type=int,
                        help='''Lower-/upper-corner size of the resulted APA matrix.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # No arguments at all: show the help text rather than a usage error.
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands


def _locate_peak_bins(M, x, y, orires, res):
    """Map one peak onto the strongest matching pixel of a contact matrix.

    A peak reported at resolution *orires* covers ``[x, x + orires)`` on one
    axis and ``[y, y + orires)`` on the other. Every bin of size *res* that
    overlaps those intervals is a candidate. The candidate pixel with the
    largest value in *M* is chosen. On ties, the first one wins, scanning the
    x-bins in the outer loop and the y-bins in the inner loop.

    Parameters
    ----------
    M : scipy.sparse matrix
        Contact matrix at resolution *res*. Only ``M.shape[0]`` is used for
        bounds checking, so the matrix is assumed square (intra-chromosomal).
    x, y : int
        Start coordinates (bp) of the two peak anchors.
    orires : int
        Resolution (bp) at which the peak was originally called.
    res : int
        Bin size (bp) of *M*.

    Returns
    -------
    tuple of int or None
        ``(i, j)`` bin indices ordered so that ``i <= j``, or ``None`` when
        no candidate bin lies inside *M*.
    """
    import math

    n = M.shape[0]
    # ``-(-a // b)`` is an exact integer ceiling. Capping the stop at ``n``
    # drops candidate bins that fall past the end of the matrix.
    rows = range(x // res, min(-(-(x + orires) // res), n))
    cols = range(y // res, min(-(-(y + orires) // res), n))

    best, best_val = None, None
    for i in rows:
        for j in cols:
            val = M[i, j]
            # Balanced matrices can contain NaN pixels. Rank them below every
            # real value: otherwise a NaN scanned first would beat all later
            # candidates, because ``v > nan`` is always False.
            if math.isnan(val):
                val = float('-inf')
            if best is None or val > best_val:
                best, best_val = (i, j), val

    if best is None:
        return None
    i, j = best
    return (i, j) if i <= j else (j, i)


def run():
    """Run Aggregate Peak Analysis (APA) from the command line.

    Reads the loop file and maps every loop onto the cooler matrix at the
    cooler's own bin size. It then collects the submatrices around the loops
    with ``apa_submatrix``, scores the stack with ``apa_analysis``, and saves
    the averaged heatmap to ``--output``.
    """
    # Parse Arguments
    args, commands = getargs()
    # argparse already exits on -h/-v. This guard additionally keeps the
    # heavy imports below from running for those invocations.
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    import numpy as np
    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend; the figure is only saved
    import matplotlib.pyplot as plt
    import cooler
    from hicpeaks.apa import apa_analysis, apa_submatrix
    from hicpeaks.utilities import _parse_peakfile, find_chrom_pre

    ## extract Hi-C matrix
    hic_pool = cooler.Cooler(args.path)
    res = hic_pool.binsize
    correct = args.useICE

    # consistent chromosome label: prefix used by the cooler's chromosome names
    pre = find_chrom_pre(hic_pool.chromnames)

    peaks = _parse_peakfile(args.loop_file, args.skip_rows)
    apa = []
    for c in peaks:
        chrom = pre + c
        M = hic_pool.matrix(balance=correct, sparse=True).fetch(chrom).tocsr()

        # locate exact pos at given resolution
        pos = []
        for p in peaks[c]:
            orires = int(p[-1])  # last field: resolution the peak was called at
            loc = _locate_peak_bins(M, p[0], p[1], orires, res)
            if loc is not None:
                pos.append(loc)

        apa.extend(apa_submatrix(M, pos, w=args.window, t=args.min_dis))

    # An empty stack cannot be averaged or plotted. Stop with a clear message
    # instead of a traceback from deep inside the analysis/plotting code.
    if not apa:
        sys.exit('No APA submatrix could be extracted from the loop file; '
                 'nothing to aggregate.')

    apa = np.r_[apa]
    avg, score, z, pvalue, maxi = apa_analysis(apa, w=args.window, cw=args.corner_size)
    plt.imshow(avg, cmap=plt.cm.Reds, vmax=maxi, interpolation='none')
    plt.title('APA score = {0:.3g}, p-value = {1:.3g} ({2})'.format(score, pvalue, apa.shape[0]))
    plt.colorbar()
    plt.savefig(args.output, dpi=args.dpi)
    plt.close()

# Run the command-line entry point when this file is executed as a script.
if __name__ == '__main__':
    run()



