#!python

from Bio import motifs
import math
import os
import sys
import argparse
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas
import glob
import argparse
from pathlib import Path
import pybio
import gzip
import scanRBP

# Absolute path of the directory that contains this script file.
scanRBP_path = Path(__file__).parent.absolute()

# Loaded eCLIP peak tracks, keyed by (cell_type, protein) tuples; filled by read_CLIP().
CLIP_data = {}

def read_CLIP():
    """Load every "bed narrowPeak" entry listed in eCLIP_data/metadata.tsv into CLIP_data.

    The metadata file is tab-separated with a header row; each data row is mapped
    onto the header names. Only rows whose "File format" column equals
    "bed narrowPeak" are loaded. The protein name is the "Experiment target"
    value with "-human" removed, the cell type is "Biosample term name".
    Blank or truncated rows that carry no "File format" value are skipped.

    Returns nothing; CLIP_data is updated in place.
    """
    # The context manager guarantees the handle is closed even if a row fails to load.
    with open("eCLIP_data/metadata.tsv", "rt") as f:
        header = f.readline().rstrip("\r\n").split("\t")
        for line in f:
            record = dict(zip(header, line.rstrip("\r\n").split("\t")))
            # .get(): a blank/short line yields a record without this column.
            if record.get("File format") != "bed narrowPeak":
                continue
            fbase = record["File accession"]
            fbed = f"eCLIP_data/{fbase}.bed"
            protein = record["Experiment target"].replace("-human", "")
            cell_type = record["Biosample term name"]
            # NOTE(review): the resulting path is "CLIP_data/eCLIP_data/<accession>.bed";
            # the doubled folder prefix is kept as in the original - confirm it is intended.
            CLIP_data[(cell_type, protein)] = pybio.data.Bed(f"CLIP_data/{fbed}")

# Usage text printed for -help/-h and whenever the positional arguments do not
# match one of the supported invocation forms.
help_0 = """Usage: scanRBP sequence output [options]
    * one sequence provided on the command line, generates output.tab.gz (+ output.png/pdf with -heatmap)

Usage: scanRBP filename.fasta [options]
    * one heatmap/matrix will be generated per sequence from the FASTA file
    * output files are named pwm_<sequence id>, using the sequence ids provided in the fasta file

Usage: scanRBP search search_term
    * returns list of proteins available matching search_term

Options:
     -annotate                Annotate each heatmap cell with the number
     -xlabels                 Display sequence (x-labels), default False
     -protein TARDBP.K562.00  Only analyze binding for the provided protein (default: analyze binding for all proteins in scanRBP database)
     -cumulative              Take only one protein (-protein) and plot binding across all sequences provided in the FASTA file
     -figsize "(10,20)"       Change matplotlib/seaborn figure size for the heatmap, example width=10, height=20
     -fontscale               Change font scale, default 0.2
     -heatmap title           Make heatmap (png+pdf) with title
     -vlines "10,-10"         Draw dashed vertical lines at these x positions on the heatmap (negative values count from the sequence end), default none
     -output_folder folder    Store all results to the output folder (default: current folder)
     -nonzero                 All negative vector values are set to 0, not enabled by default
     -CLIP                    Use actual CLIP data, do not use PWM
     -force                   If output files already exist, overwrite them, otherwise do not process
     -version                 Print version and exit
"""

# Command-line interface. add_help=False because -h/-help/--help are handled
# manually (the custom help_0 text is printed instead of argparse's own help).
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, add_help=False)
# Positional arguments: a sequence + output name, a FASTA file name, or a sub-command with its argument.
parser.add_argument('commands', help="command(s) to run", nargs='*')
parser.add_argument("-help", "-h", "--help", action="store_true")
parser.add_argument("-annotate", "--annotate", action="store_true", default=False)
parser.add_argument("-xlabels", "--xlabels", action="store_true", default=False)
parser.add_argument("-protein", "--protein", default=False)
parser.add_argument("-figsize", "--figsize", default=False)
parser.add_argument("-fontscale", "--fontscale", default=0.2, type=float)
# NOTE(review): -title is parsed but not read anywhere in the visible code.
parser.add_argument("-title", "--title", default=False)
# store_true with default=True: the value is True whether or not the flag is given.
parser.add_argument("-matrix", "--matrix", action="store_true", default=True)
parser.add_argument("-CLIP", "--CLIP", action="store_true", default=False)
parser.add_argument("-nonzero", "--nonzero", action="store_true", default=False)
parser.add_argument("-heatmap", "--heatmap", default=False)
# -force is documented as a plain flag, but it used to require a value, so a bare
# "-force" aborted with "expected one argument". nargs="?" + const=True accepts the
# bare flag while still accepting the old "-force VALUE" form.
parser.add_argument("-force", "--force", nargs="?", const=True, default=False)
parser.add_argument("-output_folder", "--output_folder", default=".")
parser.add_argument("-cumulative", "--cumulative", action="store_true", default=False)
parser.add_argument("-version", "--version", help="Print version", action="store_true")
# Kept as a string here; it is turned into a list of x positions when a heatmap is drawn.
parser.add_argument("-vlines", "--vlines", default="[]", type=str)

args = parser.parse_args()

# Startup banner, printed on every run: package version, the config file name
# reported by scanRBP.config and the configured data folder, then a blank line.
print(f"scanRBP | v{scanRBP.version}, https://github.com/grexor/scanRBP")
print(f"scanRBP | config file: {scanRBP.config.config_fname()}")
print(f"scanRBP | data folder: {scanRBP.config.data_folder}")
print()

# find all proteins with search
def find_proteins(search):
    """Case-insensitive substring search over scanRBP.database.proteins.

    A record matches when the search term occurs in its scan id, in its
    "description" field or in its "protein" field.

    Returns a pair of parallel lists: (matching scan ids, matching records).
    """
    needle = search.lower()
    matched_ids = []
    matched_records = []
    for scan_id, record in scanRBP.database.proteins.items():
        searchable = (
            record["protein"].lower(),
            record["description"].lower(),
            scan_id.lower(),
        )
        if any(needle in text for text in searchable):
            matched_ids.append(scan_id)
            matched_records.append(record)
    return matched_ids, matched_records

# -version: the startup banner has already shown the version, so just stop here.
if args.version:
    sys.exit(0)

# Show the usage text on explicit request, or when the number of positional
# arguments is anything other than one or two.
if args.help or not (1 <= len(args.commands) <= 2):
    print(help_0)
    sys.exit(0)

# converts BED files to bedGraph files (eCLIP datasets)
def bed_bedgraph(fname_in, fname_out):
    """Placeholder for the BED -> bedGraph conversion; no records are converted yet.

    Opens the gzipped input `fname_in` for reading and creates (or truncates) the
    gzipped output `fname_out`, leaving it empty. Both handles are closed before
    returning. Returns None.
    """
    # Context managers so neither gzip handle is left open (the original never closed them).
    with gzip.open(fname_in), gzip.open(fname_out, "wb"):
        # TODO: the actual conversion is not implemented.
        return

def make_bedgraph():
    """Print the list of gzipped BED files matching ../data/CLIP_bed/*.bed.gz.

    The pattern is relative to the current working directory. Nothing is
    converted here; the matches are only listed.
    """
    print(glob.glob("../data/CLIP_bed/*.bed.gz"))

def process_CLIP(chr, strand, start, stop, output_fname):
    """Write the TARDBP eCLIP signal over one genomic interval to a tab-separated file.

    chr, strand, start, stop: interval passed unchanged to the track's get_vector().
    output_fname: output name; every ".png"/".pdf" occurrence is removed and
        "_CLIP.tab" is appended. The file is written into args.output_folder
        (module-level command-line namespace), which is created if missing.

    The table has a single row labelled "K562.TARDBP.0"; columns are the positions
    start..stop inclusive, in reverse order for the "-" strand.
    NOTE(review): this presumes get_vector() returns exactly one value per position
    of that range - confirm against pybio.
    """
    # read_CLIP() registers tracks under (cell_type, protein) tuples, so the plain
    # "tardbp" key used originally is never present after read_CLIP(). Look up the
    # tuple that corresponds to the "K562.TARDBP.0" row label first and keep the
    # legacy key only as a fallback.
    # NOTE(review): assumes the metadata spells them "K562" / "TARDBP-human" - confirm.
    clip_key = ("K562", "TARDBP")
    track = CLIP_data[clip_key] if clip_key in CLIP_data else CLIP_data["tardbp"]
    scores = track.get_vector(chr, strand, start, stop)

    positions = list(range(start, stop + 1))
    if strand == "-":
        positions.reverse()

    basename = output_fname.replace(".png", "").replace(".pdf", "")
    data = pandas.DataFrame([scores], index=["K562.TARDBP.0"], columns=positions)
    # exist_ok avoids the check-then-create race of exists() followed by makedirs().
    os.makedirs(args.output_folder, exist_ok=True)
    data.to_csv(args.output_folder + "/" + basename + "_CLIP.tab", sep="\t")

def scan(seq):
    """Score `seq` against every matrix in scanRBP.pwm.pssm and return a DataFrame.

    Rows are labelled with the scan ids, columns with the letters of `seq`.
    -inf and negative scores are replaced by 0, and rows shorter than the
    sequence are right-padded with zeros up to len(seq).
    """
    target_len = len(seq)
    row_labels = []
    score_rows = []
    for scan_id, matrix in scanRBP.pwm.pssm.items():
        # One pass is enough: -inf fails the ">= 0" test just like any negative score.
        row = [value if value >= 0 else 0 for value in matrix.calculate(seq)]
        # A non-positive repeat count yields an empty list, so full-length rows are untouched.
        row.extend([0] * (target_len - len(row)))
        row_labels.append(scan_id)
        score_rows.append(row)
    return pandas.DataFrame(score_rows, index=row_labels, columns=list(seq))

def process(seq, output_fname, args, title=None):
    """Score one sequence against the PWM set, save the matrix and optionally a heatmap.

    seq: sequence handed to each matrix's calculate(); its letters label the columns.
    output_fname: output name; "+"/"-" are spelled out as "plus"/"minus", every
        ".png"/".pdf" occurrence is removed and any directory part is dropped.
    args: parsed command-line namespace (reads output_folder, force, protein,
        nonzero, heatmap, fontscale, annotate, xlabels, figsize, vlines).
    title: heatmap title; when None, "<output_fname>: <args.heatmap>" is used.

    Writes <output_folder>/<basename>.tab.gz and, when args.heatmap is set,
    <basename>.png and <basename>.pdf next to it.

    Returns True when the .tab.gz file already exists and args.force is falsy
    (nothing is recomputed); otherwise returns None.
    """
    output_fname = output_fname.replace("+", "plus").replace("-", "minus")
    basename = os.path.basename(output_fname.replace(".png", "").replace(".pdf", ""))
    output_final = f"{args.output_folder}/{basename}.tab.gz"
    if os.path.exists(output_final) and not args.force:
        return True

    heatmap_data = []
    heatmap_columns = []
    for scan_id, pssm in scanRBP.pwm.pssm.items():
        # Restrict to the requested protein(s), if any were given.
        if args.protein and scan_id not in args.protein:
            continue
        scores = [x if x != -math.inf else 0 for x in pssm.calculate(seq)]
        if args.nonzero:
            scores = [x if x >= 0 else 0 for x in scores]
        # calculate() can return fewer values than len(seq); pad the tail with zeros
        # so every row lines up with the sequence.
        if len(scores) < len(seq):
            scores = scores + [0] * (len(seq) - len(scores))
        assert len(seq) == len(scores)
        heatmap_data.append(scores)
        heatmap_columns.append(scan_id)

    heatmap_rows = list(seq)
    data = pandas.DataFrame(heatmap_data, index=heatmap_columns, columns=heatmap_rows)
    # exist_ok avoids the check-then-create race of exists() followed by makedirs().
    os.makedirs(args.output_folder, exist_ok=True)
    data.to_csv(output_final, sep="\t")

    if args.heatmap is not False:
        # No plt.figure() here: clustermap creates its own figure, and a figure opened
        # beforehand was never closed (one leaked blank figure per processed sequence).
        # NOTE(review): the second sns.set() passes no font, so it presumably resets the
        # font chosen by the first call, and "ticks" replaces "dark" - call order is kept
        # as it was so rendering does not change; confirm the intent.
        sns.set(font="Arial")
        sns.set(font_scale=args.fontscale)
        sns.set_style("dark")
        sns.set_style("ticks")

        rdgn = sns.diverging_palette(h_neg=130, h_pos=10, s=99, l=55, sep=3, as_cmap=True)
        optional_params = {
            "linewidths": 0.0,
            "col_cluster": False,
            "annot": args.annotate,
            "center": 0,
            "fmt": ".0f",
            "yticklabels": True,
            "xticklabels": args.xlabels,
            "cmap": rdgn,
            "cbar_pos": None,
        }
        if args.figsize is not False:
            # NOTE(security): eval() executes whatever expression is passed to -figsize;
            # tolerable only because the value comes from the local command line.
            optional_params["figsize"] = eval(args.figsize)
        fig = sns.clustermap(data, **optional_params)
        fig.ax_col_dendrogram.set_visible(False)
        ax = fig.ax_heatmap

        # Accept both "[5,10]" and "5,10" for -vlines by adding missing brackets.
        vlines = args.vlines
        if not vlines.startswith("["):
            vlines = "[" + vlines
        if not vlines.endswith("]"):
            vlines = vlines + "]"
        # NOTE(security): same caveat as for -figsize above.
        vlines = eval(vlines)
        for vline_x in vlines:
            # Negative positions count back from the end of the sequence.
            if vline_x < 0:
                vline_x = len(seq) + vline_x
            ax.vlines([vline_x], *ax.get_ylim(), linestyle='dashed', label='center', colors=["#bbbbbb"], linewidth=1)

        if title is None:
            plt.title(output_fname + ": " + args.heatmap)
        else:
            plt.title(title)
        plt.tight_layout()
        plt.savefig(f"{args.output_folder}/{basename}.png", dpi=300)
        plt.savefig(f"{args.output_folder}/{basename}.pdf")
        plt.close()

def process_cumulative(seqs, protein, output_fname, args, title=None):
    """Score many equal-length sequences against one protein's matrix and save the result.

    seqs: list of (sequence id, sequence) pairs; all sequences must have the same length.
    protein: key into scanRBP.pwm.pssm selecting the matrix to use.
    output_fname: output name; "+"/"-" are spelled out as "plus"/"minus", every
        ".png"/".pdf" occurrence is removed and any directory part is dropped.
    args: parsed command-line namespace (reads output_folder, nonzero, heatmap,
        fontscale, annotate, xlabels, figsize, vlines).
    title: heatmap title; when None, "<output_fname>: <args.heatmap>" is used.

    Writes <output_folder>/pwm_<basename>.tab.gz (one row per sequence id, columns
    labelled with the letters of the last sequence) and, when args.heatmap is set,
    <output_folder>/<basename>.png (no PDF is written in this mode).

    Prints a message and exits with status 1 when `seqs` is empty or the sequences
    differ in length.
    """
    # Without any sequence the code below has nothing to label the columns with
    # (this used to surface as a NameError), so stop with a clear message instead.
    if not seqs:
        print("scanRBP | no sequences provided, nothing to process")
        sys.exit(1)

    pssm = scanRBP.pwm.pssm[protein]
    heatmap_data = []
    heatmap_columns = []
    seq_len = None
    for seq_id, seq in seqs:
        if seq_len is None:
            seq_len = len(seq)
        elif len(seq) != seq_len:
            print("scanRBP | please provide a FASTA file with sequences of equal length")
            sys.exit(1)
        scores = [x if x != -math.inf else 0 for x in pssm.calculate(seq)]
        if args.nonzero:
            scores = [x if x >= 0 else 0 for x in scores]
        # calculate() can return fewer values than len(seq); pad the tail with zeros.
        if len(scores) < len(seq):
            scores = scores + [0] * (len(seq) - len(scores))
        assert len(seq) == len(scores)
        heatmap_data.append(scores)
        heatmap_columns.append(seq_id)

    # Column labels come from the last sequence read (all sequences share one length).
    heatmap_rows = list(seq)
    output_fname = output_fname.replace("+", "plus").replace("-", "minus")
    basename = os.path.basename(output_fname.replace(".png", "").replace(".pdf", ""))
    data = pandas.DataFrame(heatmap_data, index=heatmap_columns, columns=heatmap_rows)
    # exist_ok avoids the check-then-create race of exists() followed by makedirs().
    os.makedirs(args.output_folder, exist_ok=True)
    data.to_csv(f"{args.output_folder}/pwm_{basename}.tab.gz", sep="\t")

    if args.heatmap is not False:
        # No plt.figure() here: clustermap creates its own figure, and a figure opened
        # beforehand was never closed (one leaked blank figure per call).
        # NOTE(review): the second sns.set() passes no font, so it presumably resets the
        # font chosen by the first call, and "ticks" replaces "dark" - call order is kept
        # as it was so rendering does not change; confirm the intent.
        sns.set(font="Arial")
        sns.set(font_scale=args.fontscale)
        sns.set_style("dark")
        sns.set_style("ticks")

        rdgn = sns.diverging_palette(h_neg=130, h_pos=10, s=99, l=55, sep=3, as_cmap=True)
        optional_params = {
            "linewidths": 0.0,
            "col_cluster": False,
            "annot": args.annotate,
            "center": 0,
            "fmt": ".0f",
            "yticklabels": True,
            "xticklabels": args.xlabels,
            "cmap": rdgn,
            "cbar_pos": None,
        }
        if args.figsize is not False:
            # NOTE(security): eval() executes whatever expression is passed to -figsize;
            # tolerable only because the value comes from the local command line.
            optional_params["figsize"] = eval(args.figsize)
        fig = sns.clustermap(data, **optional_params)
        fig.ax_col_dendrogram.set_visible(False)
        ax = fig.ax_heatmap

        # Accept both "[5,10]" and "5,10" for -vlines by adding missing brackets.
        vlines = args.vlines
        if not vlines.startswith("["):
            vlines = "[" + vlines
        if not vlines.endswith("]"):
            vlines = vlines + "]"
        # NOTE(security): same caveat as for -figsize above.
        vlines = eval(vlines)
        for vline_x in vlines:
            # Negative positions count back from the end of the sequence.
            if vline_x < 0:
                vline_x = seq_len + vline_x
            ax.vlines([vline_x], *ax.get_ylim(), linestyle='dashed', label='center', colors=["#bbbbbb"], linewidth=1)

        if title is None:
            plt.title(output_fname + ": " + args.heatmap)
        else:
            plt.title(title)
        plt.tight_layout()
        plt.savefig(f"{args.output_folder}/{basename}.png", dpi=300)
        plt.close()

# Command dispatch: "config" and "search" sub-commands, FASTA input, the "bed"
# helper, or a single sequence given directly on the command line.
if len(args.commands)>0:

    if args.commands[0]=="config":
        if len(args.commands)>1:
            scanRBP.config.init(args.commands[1])
        sys.exit(0)

    if args.commands[0]=="search":
        if len(args.commands)>1:
            proteins, proteins_data = find_proteins(args.commands[1])
            if len(proteins)>0:
                print("[scanRBP] Found proteins in the scanRBP database:")
                table = []
                for rec in proteins_data:
                    table.append([rec['scan_id'], rec['protein'], rec['tissue'], rec['description'], rec['source']])
                df = pandas.DataFrame(table, columns = ['scan_id', 'protein', 'tissue', 'description', 'source'])
                print(df.to_string(index=False))
            else:
                print(f"[scanRBP] Found no proteins in the scanRBP database with provided search '{args.commands[1]}'")
        else:
            print("[scanRBP] please provide search term (protein name)")
        sys.exit(0)

    # Resolve -protein into a list of scan ids. Done after the config/search
    # commands so that an unmatched -protein cannot pre-empt them.
    if args.protein:
        requested_protein = args.protein
        resolved_proteins = []
        # The -cumulative error message advertises "-protein name1,name2...",
        # so each comma-separated term is searched on its own.
        for protein_term in requested_protein.split(","):
            protein_term = protein_term.strip()
            if not protein_term:
                continue
            for matched_id in find_proteins(protein_term)[0]:
                if matched_id not in resolved_proteins:
                    resolved_proteins.append(matched_id)
        # An empty match list is falsy and would silently switch the run to
        # "analyze all proteins" (or do nothing in cumulative mode); stop instead.
        if not resolved_proteins:
            print(f"scanRBP | no protein in the scanRBP database matches '{requested_protein}'")
            sys.exit(1)
        args.protein = resolved_proteins

    if args.commands[0].lower().endswith((".fasta", ".fa")):
        # File name without its last extension; used to name the cumulative outputs.
        basename = ".".join(args.commands[0].split(".")[:-1])
        data = pybio.data.Fasta(args.commands[0])
        seqs = []
        while data.read():
            # First whitespace-delimited token of the FASTA header, prefixed with "pwm_".
            seq_name = "pwm_" + data.id.lstrip(" ").split(" ")[0]
            if not args.cumulative:
                process(data.sequence, seq_name, args, title=seq_name)
            seqs.append((data.id.lstrip(" "), data.sequence))
        if args.cumulative:
            if not args.protein:
                print("scanRBP | please provide a specific protein (or list of proteins) with '-protein name1,name2...' to compute cumulative binding across all sequences in the FASTA file for each of the provided proteins")
                sys.exit(1)
            for protein in args.protein:
                process_cumulative(seqs, protein, f"{basename}_{protein}", args, title=f"{basename}_{protein}")
    elif args.commands[0]=="bed":
        make_bedgraph()
    else:
        # Single sequence mode needs both the sequence and an output name.
        if len(args.commands)>=2:
            process(args.commands[0], args.commands[1], args)
        else:
            print(help_0)
