#!/usr/bin/env python

# Created on Fri Jun 29 19:26:12 2018
# Author: XiaoTao Wang

## Required modules

import argparse, sys, os, hicpeaks

# Version string of the installed hicpeaks package; reported by -v/--version
currentVersion = hicpeaks.__version__

def getargs():
    """Build the command-line interface and parse ``sys.argv``.

    When the script is invoked without any argument, ``-h`` is substituted so
    that the usage text is shown. The output file, the cooler URI and the loop
    file are mandatory: they are all consumed unconditionally by ``run()``, so
    a missing one is reported here by argparse instead of failing later in
    downstream code.

    Returns
    -------
    args : argparse.Namespace
        The parsed options.
    commands : list of str
        The raw argument list that was handed to the parser.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Perform Aggregate Peak Analysis (APA).''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s',currentVersion]),
                        help='Print version number and exit.')

    # Output
    parser.add_argument('-O', '--output', required=True, help='Output file name.')
    parser.add_argument('--dpi', default=200, type=int,
                        help='''The resolution in dots per inch of the output figure.''')

    # Input
    parser.add_argument('-p', '--path', required=True,
                        help = 'Cooler URI.')
    parser.add_argument('-I', '--loop-file', required=True,
                        help='Loop file in bedpe format.')
    parser.add_argument('-S', '--skip-rows', default=0, type=int,
                        help='''Number of leading lines in the loop file to skip.''')
    parser.add_argument('-U', '--useICE', action='store_true',
                        help='''Whether or not use ICE-corrected matrix.''')
    parser.add_argument('-M', '--min-dis', default=10, type=int,
                        help='''We only examine peak calls where the peak loci are separated by at
                        least this number of bins.''')
    parser.add_argument('-W', '--window', default=5, type=int,
                        help='''Width of the window in APA analysis.''')
    parser.add_argument('-C', '--corner-size', default=3, type=int,
                        help='''Lower-/upper-corner size of the resulted APA matrix.''')
    parser.add_argument('--colormap-name', default='traditional',
                        help='Name of the colormap in matplotlib.')
    parser.add_argument('--vmax', type=float,
                        help='''The maximum value that the colorbar covers.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # A bare invocation prints the help text rather than an error
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands


def _pick_peak_pixel(M, peak, res):
    """Return the strongest pixel covered by one loop record, as (row, col).

    Parameters
    ----------
    M : matrix-like
        Contact matrix of one chromosome; must support ``M[i, j]`` and
        ``M.shape``.
    peak : sequence
        Loop record. Items 0 and 1 delimit the first anchor, items 2 and 3 the
        second anchor, in the same coordinate unit as ``res``.
    res : int
        Bin size used to convert coordinates into bin indices.

    Returns
    -------
    tuple of int or None
        ``(i, j)`` with ``i <= j``, or ``None`` when no candidate bin pair
        lies inside the matrix.
    """
    import numpy as np  # imported lazily, like the other heavy modules in run()

    nbins = M.shape[0]
    s_l = range(peak[0]//res, int(np.ceil(peak[1]/float(res))))
    e_l = range(peak[2]//res, int(np.ceil(peak[3]/float(res))))

    best, best_val = None, None
    for st in s_l:
        if st >= nbins:
            break  # ranges are ascending, so every later bin is out of bounds too
        for et in e_l:
            if et >= nbins:
                break
            val = M[st, et]
            # Any comparison against NaN is False, so a NaN first candidate
            # (e.g. a masked bin of a balanced matrix) would never be replaced
            # by a plain ">" test; handle that case explicitly.
            if (best is None) or (val > best_val) or \
               (np.isnan(best_val) and not np.isnan(val)):
                best, best_val = (st, et), val

    if best is None:
        return None

    si, ei = best
    return (si, ei) if si < ei else (ei, si)


def run():
    """Command-line entry point: aggregate loop sub-matrices and plot them.

    The loops listed in the loop file are mapped to pixels of the cooler
    matrix, a window around each pixel is collected with ``apa_submatrix``,
    the stack is summarised by ``apa_analysis``, and the averaged matrix is
    written as a heat map to the requested output file.

    Exits with an error message when no sub-matrix could be collected.
    """
    # Parse Arguments
    args, commands = getargs()
    # Improve the performance if you don't want to run it
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    import numpy as np
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import cooler
    from hicpeaks.apa import apa_analysis, apa_submatrix
    from hicpeaks.utilities import _parse_peakfile, find_chrom_pre

    ## extract Hi-C matrix
    hic_pool = cooler.Cooler(args.path)
    res = hic_pool.binsize
    # --min-dis is given in bins; loop coordinates are compared in units of res
    min_span = args.min_dis * res

    # consistent chromosome label
    pre = find_chrom_pre(hic_pool.chromnames)

    peaks = _parse_peakfile(args.loop_file, args.skip_rows)
    apa = []
    for c in peaks:
        chrom = pre + c
        if chrom not in hic_pool.chromsizes:
            continue
        M = hic_pool.matrix(balance=args.useICE, sparse=True).fetch(chrom)
        M = M.tocsr()

        # locate exact pos at given resolution
        pos = []
        for peak in peaks[c]:
            if abs(peak[2] - peak[0]) < min_span:
                continue
            pixel = _pick_peak_pixel(M, peak, res)
            if pixel is not None:
                pos.append(pixel)

        apa.extend(apa_submatrix(M, pos, w=args.window))

    if not apa:
        # Nothing to average: stop with a clear message instead of passing an
        # empty array on to apa_analysis.
        sys.exit('No valid loop sub-matrix could be extracted; nothing to plot.')

    apa = np.r_[apa]
    print(len(apa))
    # The score, z-score and p-value are returned as well but are not drawn
    avg, _score, _z, _pvalue, maxi = apa_analysis(apa, w=args.window, cw=args.corner_size)

    vmax = maxi if args.vmax is None else args.vmax

    if args.colormap_name=='traditional':
        # 'traditional' selects the script's own white-to-red colormap
        from matplotlib.colors import LinearSegmentedColormap
        cmap = LinearSegmentedColormap.from_list('interaction',
                ['#FFFFFF','#ff9292','#ff6767','#F70000'])
    else:
        cmap = args.colormap_name

    plt.imshow(avg, cmap=cmap, vmax=vmax, interpolation='none')
    plt.tick_params(axis='both', bottom=False, top=False, left=False, right=False,
                    labelbottom=False, labeltop=False, labelleft=False, labelright=False)
    plt.colorbar()
    plt.savefig(args.output, dpi=args.dpi, bbox_inches='tight')
    plt.close()

# Run the analysis only when this file is executed as a script
if __name__ == '__main__':
    run()



