#!/usr/bin/env python

# Created on Tue Aug 21 19:53:19 2018
# Author: XiaoTao Wang

## Required modules

import argparse, sys, os, hicpeaks

# Package version string; reported by the -v/--version option in getargs().
currentVersion = hicpeaks.__version__

def getargs():
    """Build the command-line parser and parse ``sys.argv``.

    Returns
    -------
    args : argparse.Namespace
        The parsed options.
    commands : list of str
        The argument list that was parsed; ``['-h']`` when the script is run
        without any arguments, so that the help text is printed instead.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Visualize peak calls on heatmap.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s',currentVersion]),
                        help='Print version number and exit.')

    # Output
    # The options marked required=True are used unconditionally by run()
    # (savefig, cooler.Cooler, fetch, bin arithmetic) and fail with obscure
    # errors when left as None. argparse checks required options only after
    # parsing, so -h / -v keep working on their own.
    parser.add_argument('-O', '--output', required=True, help='Output png file name.')
    parser.add_argument('--dpi', default=300, type=int,
                        help='''The resolution in dots per inch of the output figure.''')

    # Input
    parser.add_argument('-p', '--path', required=True,
                        help = 'URI string pointing to a cooler under specific resolution.')
    parser.add_argument('-I', '--loop-file', help='Loop file outputed by pyHICCUPS or pyBHFDR.')
    parser.add_argument('-C', '--chrom', required=True,
                        help='Chromosome label of your anticipated region.')
    parser.add_argument('-S', '--start', type=int, required=True,
                        help='Start site (bp) of the region.')
    parser.add_argument('-E', '--end', type=int, required=True,
                        help='End site (bp) of the region.')
    parser.add_argument('--skip-rows', default=0, type=int,
                        help='''Number of leading lines in the loop file to skip.''')
    parser.add_argument('--correct', action='store_true',
                        help='''Whether or not plot ICE-corrected heatmap.''')
    parser.add_argument('--vmin', type=float,
                        help='''The minimum value that the colorbar covers.''')
    parser.add_argument('--vmax', type=float,
                        help='''The maximum value that the colorbar covers.''')
    # The flag *suppresses* the coordinate labels (see run()); the old help
    # text suggested the opposite.
    parser.add_argument('--nolabel', action='store_true',
                        help='''Hide the genomic coordinate tick labels on both axes.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        commands.append('-h')
    args = parser.parse_args(commands)

    # argparse has no cross-option validation; reject an empty or reversed
    # region here instead of producing an empty matrix later in run().
    if args.end <= args.start:
        parser.error('--end must be greater than --start.')

    return args, commands

def properU(pos):
    """Format a base-pair coordinate as a compact megabase/kilobase label.

    ``pos`` is truncated with ``int()`` and split into whole megabases plus
    the remaining whole kilobases; anything below 1 kb is dropped. Examples:
    ``1500000 -> '1M500K'``, ``2000000 -> '2M'``, ``250000 -> '250K'``.
    """
    mega, rest = divmod(int(pos), 1000000)
    kilo = rest // 1000

    if mega == 0:
        # Sub-megabase coordinate: kilobases only (may be '0K').
        return '%dK' % kilo
    if mega > 0 and kilo > 0:
        return '%dM%dK' % (mega, kilo)
    # Whole megabases (also reached for negative inputs).
    return '%dM' % mega

def caxis_H(ax):
    """
    Apply the tick layout used for heatmap axes.

    Ticks are drawn only on the bottom x-axis and the left y-axis. Both axes
    use 12-pt labels, 5-pt tick marks and a 7-pt gap between tick and label.
    """
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    ax.tick_params(axis='both', length=5, pad=7, labelsize=12)

def _strip_chr(label):
    """Return ``label`` as a str with one leading ``'chr'`` prefix removed.

    ``str.lstrip('chr')`` would strip *any* run of the characters c/h/r
    (e.g. ``'h1' -> '1'``), so only the exact prefix is removed here.
    """
    label = str(label)
    return label[3:] if label.startswith('chr') else label


def _read_loops(loop_file, chrom, res, skip_rows=0):
    """Load the loops on one chromosome from a whitespace-delimited loop file.

    Parameters
    ----------
    loop_file : str
        Path to the loop file. Columns 1-3 are read as chromosome, loc1 and
        loc2. When data rows have exactly 4 columns, the 4th is read as the
        resolution (bp) each loop was called at; otherwise it is ignored.
    chrom : str
        Chromosome to keep. It is compared with the file's first column after
        removing a leading ``'chr'`` from both sides, so files written with or
        without the prefix both match.
    res : int
        Resolution assigned to every loop when the file has no resolution
        column.
    skip_rows : int
        Number of leading lines to skip (e.g. a header).

    Returns
    -------
    loc1, loc2, resarr : numpy.ndarray of int
        Parallel 1-D arrays with one entry per loop on ``chrom``.
    """
    import numpy as np

    # Peek at the first *data* line to detect the column layout. Lines are
    # skipped and blank / '#' lines ignored the same way np.loadtxt does, so a
    # header removed via --skip-rows does not decide the layout.
    fields = []
    with open(loop_file) as source:
        for lineno, line in enumerate(source):
            if lineno < skip_rows:
                continue
            stripped = line.strip()
            if stripped and not stripped.startswith('#'):
                fields = stripped.split()
                break

    if not fields:
        empty = np.zeros(0, dtype=int)
        return empty, empty.copy(), empty.copy()

    has_res = len(fields) == 4
    names = ['chr', 'loc1', 'loc2'] + (['res'] if has_res else [])
    # Builtin int instead of np.int, which was removed in NumPy 1.24. The
    # label width is large enough for scaffold names: 'U5' truncated them, so
    # those labels never matched the query chromosome.
    loop_type = np.dtype({'names': names,
                          'formats': ['U64'] + [int] * (len(names) - 1)})
    # ndmin=1 keeps a single-row file a 1-D array rather than a 0-d record.
    loops = np.loadtxt(loop_file, dtype=loop_type, skiprows=skip_rows,
                       usecols=list(range(len(names))), ndmin=1)

    target = _strip_chr(chrom)
    keep = np.array([_strip_chr(c) == target for c in loops['chr']], dtype=bool)
    loops = loops[keep]

    if has_res:
        resarr = loops['res'].astype(int)
    else:
        resarr = np.full(loops.size, res, dtype=int)
    return loops['loc1'].astype(int), loops['loc2'].astype(int), resarr


def _locate_peak_pixels(M, loc1, loc2, resarr, start, end, res):
    """Return a boolean mask with one True pixel per loop inside the region.

    A loop is kept when ``loc1 >= start`` and ``loc2 + resarr < end``. Each
    anchor at its own resolution may span several bins of the matrix
    resolution ``res``. Among the candidate pixels that fall inside ``M``,
    the one with the largest value is marked, always in the upper triangle.
    """
    import numpy as np

    mask = np.zeros(M.shape, dtype=bool)
    offset = start // res
    n = M.shape[0]
    for x, y, orires in zip(loc1, loc2, resarr):
        # Resolution stays paired with its own loop. The original filtered
        # the loops but not resarr, so the two arrays went out of step.
        if not (x >= start and y + orires < end):
            continue
        # Locate the peak pixel at the plotting resolution.
        rows = range(x // res, int(np.ceil((x + orires) / float(res))))
        cols = range(y // res, int(np.ceil((y + orires) / float(res))))
        best = None
        for i in rows:
            for j in cols:
                st, et = i - offset, j - offset
                # Lower bound too: a negative index would silently wrap around.
                if not (0 <= st < n and 0 <= et < n):
                    continue
                if best is None or M[st, et] > M[best]:
                    best = (st, et)
        if best is not None:
            si, ei = best
            mask[min(si, ei), max(si, ei)] = True
    return mask


def run():
    """Plot a cooler region as a heatmap, with loop pixels highlighted.

    Options come from ``getargs()``. The region ``--chrom:--start-end`` is
    snapped to bin boundaries and extracted from the cooler at ``--path``
    (balanced when ``--correct`` is set). If ``--loop-file`` is given, the
    loop pixels are masked and drawn in the colormap's "bad" color. The
    figure is saved to ``--output``.
    """
    # Parse Arguments
    args, commands = getargs()
    # Heavy imports are deferred so -h / -v respond instantly.
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    import numpy as np
    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend; set before importing pyplot
    import cooler
    import matplotlib.pyplot as plt
    from matplotlib.colors import LinearSegmentedColormap

    cmap = LinearSegmentedColormap.from_list('interaction',
                                             ['#FFFFFF','#FFDFDF','#FF7575','#FF2626','#F70000'])
    # Masked (loop) pixels are rendered in this color.
    cmap.set_bad('#2672a1')

    # Load Cooler
    lib = cooler.Cooler(args.path)
    chrom = args.chrom
    res = lib.binsize

    # Snap the region to bin boundaries.
    start = args.start // res * res
    end = args.end // res * res
    if end <= start:
        sys.exit('Error: region {0}:{1}-{2} is narrower than one bin ({3} bp).'.format(
            chrom, args.start, args.end, res))

    # Extract matrix; unbalanced bins come back as NaN and are plotted as 0.
    M = lib.matrix(balance=args.correct, sparse=False).fetch((chrom, start, end))
    M[np.isnan(M)] = 0

    nonzero = M[np.nonzero(M)]
    vmin = 0 if args.vmin is None else args.vmin
    if args.vmax is not None:
        vmax = args.vmax
    elif nonzero.size:
        vmax = np.percentile(nonzero, 95)
    else:
        # All-zero region: percentile of an empty array raises, so fall back
        # to an arbitrary non-degenerate color range.
        vmax = vmin + 1

    if args.loop_file is None:
        peak_mask = np.zeros(M.shape, dtype=bool)
    else:
        loc1, loc2, resarr = _read_loops(args.loop_file, chrom, res, args.skip_rows)
        peak_mask = _locate_peak_pixels(M, loc1, loc2, resarr, start, end, res)
    M = np.ma.array(M, mask=peak_mask)

    # Plot: the heatmap axes height is scaled so the axes box is square in inches.
    size = (8, 7.3)
    width, left = 0.7, 0.1
    bottom = 0.1
    height = width * size[0] / size[1]

    fig = plt.figure(figsize=size)
    ax = fig.add_axes([left, bottom, width, height])
    sc = ax.imshow(M, cmap = cmap, aspect = 'auto', interpolation = 'none',
                   vmax = vmax, vmin = vmin)
    if args.nolabel:
        ax.tick_params(axis='both', bottom=False, top=False, left=False, right=False,
                       labelbottom=False, labeltop=False, labelleft=False, labelright=False)
    else:
        interval = (end - start) // res
        ticks = np.linspace(0, interval, 6).astype(int)
        # Label each tick with the coordinate of the bin it actually sits on.
        # A separate linspace over bp disagrees with the truncated positions
        # whenever interval is not a multiple of 5.
        labels = [properU(start + t * res) for t in ticks]
        ax.set_xticks(list(ticks))
        ax.set_xticklabels(labels)
        ax.set_yticks(list(ticks))
        ax.set_yticklabels(labels)
        caxis_H(ax)

    ## Colorbar
    cax = fig.add_axes([left + width + 0.03, bottom, 0.03, height])
    fig.colorbar(sc, cax=cax)

    plt.savefig(args.output, bbox_inches='tight', dpi=args.dpi)
    plt.close()


# Script entry point: run the plotting pipeline when executed directly.
if __name__ == '__main__':
    run()