#!/usr/bin/env python

# Created on Tue Aug 21 19:53:19 2018
# Author: XiaoTao Wang

## Required modules

import argparse, sys, os, hicpeaks

# Version string of the installed hicpeaks package, reported by -v/--version
currentVersion = hicpeaks.__version__

def getargs():
    """Build the command-line parser and parse ``sys.argv``.

    When no argument is given at all, ``-h`` is substituted so that the usage
    message is printed. ``-h``/``--help`` and ``-v``/``--version`` make
    argparse print its message and exit inside this function.

    Returns
    -------
    args : argparse.Namespace
        The parsed options.
    commands : list of str
        The raw argument list that was parsed (``sys.argv[1:]``, or ``['-h']``
        when it was empty).
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Visualize peak calls on heatmap.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s',currentVersion]),
                        help='Print version number and exit.')

    # Output
    # The options marked required below are all consumed unconditionally by
    # run(); declaring them here turns a late, obscure traceback into an
    # immediate usage error.
    parser.add_argument('-O', '--output', required=True,
                        help='Output png file name.')
    parser.add_argument('--dpi', default=500, type=int,
                        help='''The resolution in dots per inch of the output figure.''')

    # Input
    parser.add_argument('-p', '--path', required=True,
                        help = 'Cooler URI.')
    parser.add_argument('-I', '--loop-file', help='Loop file in bedpe format.')
    parser.add_argument('-C', '--chrom', required=True,
                        help='Chromosome label of your anticipated region.')
    parser.add_argument('-S', '--start', type=int, required=True,
                        help='Start site (bp) of the region.')
    parser.add_argument('-E', '--end', type=int, required=True,
                        help='End site (bp) of the region.')
    parser.add_argument('--skip-rows', default=0, type=int,
                        help='''Number of leading lines in the loop file to skip.''')
    parser.add_argument('--clr-weight-name', default='weight',
                        help='''The name of the weight column in your Cooler URI for normalizing
                        the contact signals. Specify it to "raw" if you want to plot the raw signals.''')
    parser.add_argument('--vmin', type=float,
                        help='''The minimum value that the colorbar covers.''')
    parser.add_argument('--vmax', type=float,
                        help='''The maximum value that the colorbar covers.''')
    parser.add_argument('--colormap-name', default='traditional',
                        help='''Name of the colormap in matplotlib. The special value
                        "traditional" selects the built-in white-to-red colormap.''')
    parser.add_argument('--marker-size', default=10, type=int, help='''Marker sizes.''')
    parser.add_argument('--marker-color', default='#1F78B4', help='''Marker Colors''')
    parser.add_argument('--marker-alpha', default=1, type=float,
                        help='''The alpha blending value of loop markers, between 0
                        (transparent) and 1 (opaque)''')
    parser.add_argument('--marker-linewidth', default=0.5, type=float,
                        help='''Marker line widths.''')
    parser.add_argument('--nolabel', action='store_true',
                        help='''Do not add genomic coordinate labels to the heatmap.''')
    parser.add_argument('--log', action='store_true',
                        help='''Map contact signals to colors on a logarithmic scale.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # A bare invocation prints the help message instead of failing
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands

def print_coordinate(pos):
    """Format a genomic position (bp) as a megabase label.

    Exact multiples of one megabase are written without decimals (``'25M'``);
    every other position is written with two decimals (``'1.50M'``).
    """
    whole_mb, leftover = divmod(pos, 1000000)
    if leftover:
        return '{0:.2f}M'.format(pos/1000000)
    return '{0}M'.format(whole_mb)

def run():
    """Plot the contact heatmap of one region and mark loop calls on it.

    Steps, all driven by the options returned from ``getargs()``:

    1. Open the Cooler at ``--path`` and fetch the dense matrix of
       ``--chrom:--start-–end`` (both coordinates are floored to the bin
       size); NaN entries are set to 0.
    2. Draw the matrix with ``imshow``. Unless given, ``vmin`` is the smallest
       nonzero value and ``vmax`` the 93rd percentile of the nonzero values.
    3. If ``--loop-file`` is given, mark each loop at the highest-valued pixel
       covered by its two anchors, mirrored on both sides of the diagonal.
    4. Optionally add coordinate labels, add a colorbar, and save the figure
       to ``--output``.
    """
    # Parse Arguments
    args, commands = getargs()
    # Skip the heavy imports below when only help/version output is requested
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    import numpy as np
    import matplotlib
    matplotlib.use('Agg')  # backend is selected before pyplot is imported
    import cooler
    import matplotlib.pyplot as plt
    from matplotlib.colors import LinearSegmentedColormap
    from hicpeaks.utilities import _parse_peakfile
    from matplotlib.colors import LogNorm

    # "traditional" is the built-in white-to-red map; any other value is
    # handed to matplotlib as a colormap name
    if args.colormap_name == 'traditional':
        cmap = LinearSegmentedColormap.from_list(
            'interaction',
            ['#FFFFFF','#FFDFDF','#FF7575','#FF2626','#F70000'])
    else:
        cmap = args.colormap_name

    # balance=False plots raw signals, otherwise the named weight column is used
    if args.clr_weight_name.lower() == 'raw':
        correct = False
    else:
        correct = args.clr_weight_name

    # Load Cooler
    Lib = cooler.Cooler(args.path)
    chrom, start, end = args.chrom, args.start, args.end

    # Extract matrix (region boundaries are floored to bin boundaries)
    res = Lib.binsize
    start = start//res * res
    end = end//res * res
    M = Lib.matrix(balance=correct, sparse=False).fetch((chrom,start,end))
    M[np.isnan(M)] = 0

    # Color range: derived from the nonzero pixels unless set explicitly
    nonzero = M[np.nonzero(M)]
    vmin = nonzero.min() if args.vmin is None else args.vmin
    vmax = np.percentile(nonzero, 93) if args.vmax is None else args.vmax

    # plot heatmap
    size = (2.2, 2)
    fig = plt.figure(figsize=size)
    width = 0.7; Left = 0.1
    HB = 0.1; HH = width * size[0] / size[1]
    ax = fig.add_axes([Left, HB, width, HH])
    # The limits go either into the norm (log scale) or directly to imshow
    if args.log:
        scale = {'norm': LogNorm(vmin=vmin, vmax=vmax)}
    else:
        scale = {'vmin': vmin, 'vmax': vmax}
    sc = ax.imshow(M, cmap=cmap, aspect='auto', interpolation='none', **scale)

    # Remember the heatmap limits; they are restored after the markers are drawn
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()

    # extract and plot chromatin loops
    # NOTE(review): str.lstrip('chr') removes any leading run of the characters
    # 'c', 'h' and 'r', not only the literal prefix. Kept unchanged because the
    # keys returned by _parse_peakfile are presumably normalized the same way
    # -- confirm against hicpeaks.utilities before changing.
    chrom = chrom.lstrip('chr')
    if args.loop_file is not None:
        # Read loop data
        loops = _parse_peakfile(args.loop_file, skip=args.skip_rows)
        try:
            loops = loops[chrom]
        except KeyError:
            # No loop record on this chromosome: still draw the heatmap
            loops = []

        first_bin = start // res    # bin index of the left edge of the window
        n_bins = M.shape[0]
        marker_x, marker_y = [], []
        for xs, xe, ys, ye in loops:
            # Locate the peak pixel at given resolution: the highest-valued
            # pixel among the bins covered by the two anchors that falls
            # inside the plotted window
            row_bins = range(xs//res, int(np.ceil(xe/float(res))))
            col_bins = range(ys//res, int(np.ceil(ye/float(res))))
            si, ei = None, None
            for i in row_bins:
                st = i - first_bin
                if not (0 <= st < n_bins):
                    continue
                for j in col_bins:
                    et = j - first_bin
                    if not (0 <= et < n_bins):
                        continue
                    if si is None or M[st,et] > M[si,ei]:
                        si, ei = st, et
            if si is not None:
                # Mark the pixel and its mirror image across the diagonal
                marker_x.extend((si, ei))
                marker_y.extend((ei, si))

        if marker_x:
            ax.scatter(marker_x, marker_y, s=args.marker_size, c='none',
                       marker='o', edgecolors=args.marker_color,
                       alpha=args.marker_alpha,
                       linewidths=args.marker_linewidth)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    # add coordinates
    ax.tick_params(axis='both', bottom=False, top=False, left=False, right=False,
                   labelbottom=False, labeltop=False, labelleft=False, labelright=False)
    for spine in ['right', 'top', 'bottom', 'left']:
        ax.spines[spine].set_linewidth(0.9)

    if not args.nolabel:
        fontsize=6
        offset = 0.02 * (xmax - xmin)
        ax.text(xmin, ymin+offset, print_coordinate(start), va='top', ha='left', fontsize=fontsize)
        ax.text(xmax, ymin+offset, print_coordinate(end), va='top', ha='right', fontsize=fontsize)
        ax.text(-offset, ymax, print_coordinate(start), rotation=90, va='top', ha='right', fontsize=fontsize)
        ax.text(-offset, ymin, print_coordinate(end), rotation=90, va='bottom', ha='right', fontsize=fontsize)
        ax.text((xmin+xmax)/2, ymin+2*offset, 'chr'+chrom, va='top', ha='center', fontsize=fontsize)
        ax.text(-2*offset, (ymin+ymax)/2, 'chr'+chrom, rotation=90, va='center', ha='right', fontsize=fontsize)

    ## Colorbar
    ax = fig.add_axes([Left+width+0.04, 0.72, 0.03, 0.15])
    fig.colorbar(sc, cax=ax, ticks=[vmin, vmax], format='%.3g')
    ax.tick_params(labelsize=5)

    plt.savefig(args.output, bbox_inches='tight', dpi=args.dpi)
    plt.close()


# Run the tool only when this file is executed as a script, not when imported
if __name__ == '__main__':
    run()