#!python

import argparse, sys, eaglec, os

# Version string of the installed eaglec package; reported by the
# -v/--version command-line option.
currentVersion = eaglec.__version__

def getargs():
    """Build the command-line parser and parse ``sys.argv``.

    When the program is invoked without any argument, ``-h`` is injected so
    that the help message is printed and the program exits.

    Returns
    -------
    args : argparse.Namespace
        The parsed options.
    commands : list of str
        The raw argument list that was parsed (``sys.argv[1:]``, or ``['-h']``
        when no argument was given).

    Notes
    -----
    ``parser.error`` (exit status 2) is triggered when one of the mandatory
    options is missing or when ``--max-value`` lies outside [0, 1].
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Plot a local contact map centered on the
                                     provided SV breakpoint coordinates. For intra-chromosomal
                                     SVs, contact counts will be distance-normalized. All contact
                                     matrices will be min-max scaled to the range [0, 1].''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s', currentVersion]),
                        help='Print version number and exit.')

    # The three options below are consumed unconditionally by the plotting
    # code, so a missing one is reported as a usage error here rather than
    # surfacing later as an unrelated traceback.
    parser.add_argument('--cool-uri', required=True, help='''Cool URI.''')
    parser.add_argument('--breakpoint-coords', required=True,
                        help='''Breakpoint coordinates in the
                        format: chrom1,pos1,chrom2,pos2.''')
    parser.add_argument('--window-width', type=int, default=15, help='''Width of the contact
                        map window, specified in number of bins. For example, setting this
                        to 15 results in a 31x31 window.''')
    parser.add_argument('-O', '--output-figure-name', required=True,
                        help='''Output figure name.''')
    parser.add_argument('--balance-type', default='ICE', choices=['ICE', 'CNV', 'Raw'],
                        help='''Normalization applied to the contact matrix.''')
    parser.add_argument('--max-value', type=float, default=0.99,
                        help='''Maximum value for the heatmap color bar. Must be within
                        the range [0, 1].''')
    parser.add_argument('--dpi', default=800, type=int,
                        help='''Resolution of the output figure in dots per inch.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        commands.append('-h')
    args = parser.parse_args(commands)

    # Enforce the documented range; the chained comparison is also False for
    # NaN, so that is rejected as well.
    if not 0 <= args.max_value <= 1:
        parser.error('--max-value must be within the range [0, 1].')

    return args, commands

def _chrom_label(chrom):
    """Return *chrom* carrying a literal ``chr`` prefix, for use as an axis label."""
    # str.lstrip('chr') removes any leading run of the characters 'c', 'h'
    # and 'r' (not the literal prefix), which mangles names such as
    # 'chrrandom' or 'hs37d5'. Test for the prefix explicitly instead.
    if chrom.startswith('chr'):
        return chrom
    return 'chr' + chrom

def run():
    """Plot the contact map around an SV breakpoint and save it as a figure.

    Steps:

    1. Parse the command-line options with ``getargs``.
    2. Open the cooler at ``--cool-uri`` and fetch the window of
       ``2 * window_width + 1`` bins centered on each breakpoint.
    3. For intra-chromosomal pairs, apply ``distance_normaize_core`` using the
       expected values bundled in the package ``data`` folder; then rescale
       the matrix with ``image_normalize``.
    4. Draw the heatmap, annotate it with the window coordinates and the
       chromosome names, circle the breakpoint, and write the figure to
       ``--output-figure-name``.

    A message is printed (and no figure is written) when the window does not
    fit inside the chromosome or when the extracted matrix is empty.
    """
    # Parse Arguments
    args, commands = getargs()
    # Skip the expensive scientific imports when only help/version was asked
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    import joblib, cooler, matplotlib
    import numpy as np
    import matplotlib.pyplot as plt
    from eaglec.utilities import find_matched_resolution, distance_normaize_core, image_normalize
    from matplotlib.colors import LinearSegmentedColormap

    # Keep text as text (not outlines) in SVG output
    matplotlib.rcParams.update({'text.usetex': False, 'svg.fonttype': 'none'})

    cmap = LinearSegmentedColormap.from_list(
        'interaction',
        ['#FFFFFF', '#FFDFDF', '#FF7575', '#FF2626', '#F70000']
    )

    # Translate the balance type into the ``balance`` argument of cooler:
    # False means raw counts, otherwise the name of the weight column.
    correct = {'Raw': False, 'CNV': 'sweight'}.get(args.balance_type, 'weight')

    c1, p1, c2, p2 = args.breakpoint_coords.split(',')
    p1, p2 = int(p1), int(p2)
    w = args.window_width
    clr = cooler.Cooler(args.cool_uri)
    res = clr.binsize
    chromsize = clr.chromsizes

    # Snap each breakpoint to the start of its bin. cooler returns every bin
    # overlapping the requested interval, so an unaligned position would span
    # 2*w+2 bins and be wrongly reported as lying too close to the chromosome
    # boundary. For bin-aligned input this is a no-op.
    b1, b2 = p1 // res, p2 // res
    s1, s2 = b1 * res, b2 * res

    # expected values
    folder = os.path.join(os.path.split(eaglec.__file__)[0], 'data')
    expected_values = joblib.load(os.path.join(folder, 'expected-values.by-res.pkl'))
    matched_res = find_matched_resolution(expected_values, res)
    exp = expected_values[matched_res]

    # extract and normalize the submatrix (window clipped to the chromosome)
    interval1 = (c1, max(s1-res*w, 0), min(s1+res*w+res, chromsize[c1]))
    interval2 = (c2, max(s2-res*w, 0), min(s2+res*w+res, chromsize[c2]))
    M = clr.matrix(balance=correct, sparse=False).fetch(interval1, interval2)
    M[np.isnan(M)] = 0

    # A clipped window yields fewer than 2*w+1 bins along at least one axis
    if M.shape != (2*w+1, 2*w+1):
        print('At least one breakpoint is too close to the chromosome boundary. Consider lowering the window size.')
        return

    if M.sum() == 0:
        print('No contact counts found in the extracted matrix. Try increasing the window size or switching to a lower-resolution matrix.')
        return

    if c1 == c2:
        M = distance_normaize_core(M, exp, b1, b2, w)
    M = image_normalize(M)

    # plot the heatmap
    fig = plt.figure(figsize = (1.2, 1.2))
    ax = fig.add_subplot(111)
    sc = ax.imshow(M, interpolation='none', cmap=cmap, aspect = 'auto', vmax=args.max_value)
    # Remember the image extent: the annotations below are placed relative to
    # it, and the limits are restored after the scatter call.
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    fontsize = 5
    offset = 0.02 * (xmax - xmin)

    # window coordinates of the second breakpoint (x-axis)
    ax.text(xmin, ymin+offset*2, print_coordinate(s2-res*w), va='top', ha='left', fontsize=fontsize)
    ax.text(xmax, ymin+offset*2, print_coordinate(s2+res*w+res), va='top', ha='right', fontsize=fontsize)
    # window coordinates of the first breakpoint (y-axis)
    ax.text(-offset*2, ymax, print_coordinate(s1-res*w), rotation=90, va='top', ha='right', fontsize=fontsize)
    ax.text(-offset*2, ymin, print_coordinate(s1+res*w+res), rotation=90, va='bottom', ha='right', fontsize=fontsize)
    # chromosome names
    ax.text((xmin+xmax)/2, ymin+7*offset, _chrom_label(c2), va='top', ha='center', fontsize=fontsize)
    ax.text(-8*offset, (ymin+ymax)/2, _chrom_label(c1), rotation=90, va='center', ha='right', fontsize=fontsize)

    # mark SVs positions (the breakpoint bin is the center of the window)
    ax.scatter(w, w, c='none', ec='k', fc='none', marker='o', s=40, alpha=1)
    ax.tick_params(axis='both', bottom=False, top=False, left=False, right=False,
                   labelbottom=False, labeltop=False, labelleft=False, labelright=False)
    for spine in ['right', 'top', 'bottom', 'left']:
        ax.spines[spine].set_linewidth(1)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    plt.savefig(args.output_figure_name, dpi=args.dpi, bbox_inches='tight')
    plt.close()

def print_coordinate(pos):
    """Format a genomic position in megabases, suffixed with ``M``.

    Exact multiples of one million are shown without decimals (``'3M'``);
    any other position is shown with two decimals (``'1.50M'``).
    """
    whole_mb, leftover = divmod(pos, 1000000)
    if leftover != 0:
        # not a whole number of megabases: keep two decimals
        return '{0:.2f}M'.format(pos/1000000)

    return '{0}M'.format(whole_mb)

# Run the plotting command when this file is executed as a script.
if __name__ == '__main__':
    run()