#!python
# -*- coding: utf-8 -*-

#  PINTS: Peak Identifier for Nascent Transcripts Starts
#  Copyright (c) 2019-2021 Li Yao at the Yu Lab.
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.

# @Author: Li Yao
# @Date: 5/7/21
import os
import sys
import argparse
import logging
import warnings

try:
    import numpy as np
    import pandas as pd
    from pybedtools import BedTool
    from pints.io_engine import get_read_signal, get_coverage_bw, log_assert
    from collections import namedtuple
except ImportError as e:
    missing_package = str(e).replace("No module named '", "").replace("'", "")
    sys.exit("Please install %s first!" % missing_package)
# Turn every warning into an exception for the whole process
warnings.filterwarnings("error")

# The process id doubles as the default prefix for files created by this script
DEFAULT_PREFIX = str(os.getpid())
_LOG_FILE = os.path.join(os.getcwd(), '%s.log' % DEFAULT_PREFIX)
# Messages go both to `<pid>.log` in the current working directory and to the default stream handler
logging.basicConfig(
    level=logging.INFO,
    datefmt="%d-%b-%y %H:%M:%S",
    format="%(name)s - %(asctime)s - %(levelname)s: %(message)s",
    handlers=[logging.FileHandler(_LOG_FILE), logging.StreamHandler()],
)
logger = logging.getLogger("PINTS - BoundaryExtender")


def _read_peaks(peak_file, chromosomes):
    """
    Load a headerless, tab-separated peak file and keep rows on the given chromosomes

    Parameters
    ----------
    peak_file : str
        Path to the peak file, chromosome names are expected in the first column
    chromosomes : list of str
        Chromosomes to keep

    Returns
    -------
    peaks : pd.DataFrame
        Independent copy of the retained rows
    """
    peaks = pd.read_csv(peak_file, sep="\t", header=None)
    # copy() detaches the subset from the parsed frame, so adding columns later cannot
    # trigger SettingWithCopyWarning (this module escalates warnings to exceptions)
    return peaks.loc[peaks[0].isin(chromosomes), :].copy()


def _extend_paired_peaks(peaks, chromosome, pl_cov, mn_cov, extension):
    """
    Locate summits and extended element boundaries for paired peaks on one chromosome

    Parameters
    ----------
    peaks : pd.DataFrame
        Peak calls (columns 0, 1, 2: chromosome, start, end). Columns `pl_summit`, `mn_summit`,
        `element_start` and `element_end` must exist; they are updated in place
    chromosome : str
        Only rows on this chromosome are processed
    pl_cov : array-like
        Plus strand signal, indexed by position
    mn_cov : array-like
        Minus strand signal, indexed by position
    extension : tuple
        BPs subtracted from the left anchor and added to the right anchor

    Returns
    -------
    None
    """
    sub = peaks.loc[peaks[0] == chromosome, [1, 2]]
    if sub.empty:
        return

    pl_summits, mn_summits, starts, ends = [], [], [], []
    for peak_start, peak_end in zip(sub[1], sub[2]):
        pcov = pl_cov[peak_start:peak_end]
        mcov = mn_cov[peak_start:peak_end]
        # all positions sharing the maximum signal are reported as summits
        cpls = np.where(pcov == pcov.max())[0] + peak_start
        cmns = np.where(mcov == mcov.max())[0] + peak_start
        pl_summits.append(",".join(str(x) for x in cpls))
        mn_summits.append(",".join(str(x) for x in cmns))

        # When 30% of the highest signal exceeds 5, the anchor is taken from the positions with
        # signal > 5 and the peak boundary; otherwise the outermost summit is used.
        # For non-negative coordinates the last position with signal > 5 is always below
        # peak_end (and the first one never below peak_start), so these branches currently
        # resolve to the peak boundaries themselves.
        # NOTE(review): the original comment described "peaks with 1/3 of the highest peak and
        # >5 reads"; confirm whether the relative threshold was meant to be applied here.
        if np.nanmax(pcov) * 0.3 > 5:
            cpl = max(np.where(pcov > 5)[0][-1] + peak_start, peak_end)
        else:
            cpl = cpls[-1]
        if np.nanmax(mcov) * 0.3 > 5:
            cmn = min(np.where(mcov > 5)[0][0] + peak_start, peak_start)
        else:
            cmn = cmns[0]

        starts.append(min(cpl, cmn) - extension[0])
        ends.append(max(cpl, cmn) + extension[1])

    # one bulk assignment per column instead of one DataFrame write per peak
    peaks.loc[sub.index, "pl_summit"] = pl_summits
    peaks.loc[sub.index, "mn_summit"] = mn_summits
    peaks.loc[sub.index, "element_start"] = starts
    peaks.loc[sub.index, "element_end"] = ends


def _finalize_elements(peaks, columns, id_prefix):
    """
    Select output columns, clamp negative starts to 0 and append an ID column

    Parameters
    ----------
    peaks : pd.DataFrame
        Peaks with `element_start` and `element_end` filled in
    columns : list
        Columns to keep, in output order
    id_prefix : str
        Prefix of the generated IDs; the row's index label is appended

    Returns
    -------
    elements : pd.DataFrame
    """
    elements = peaks.loc[:, columns].copy()
    elements["element_start"] = elements["element_start"].astype(int)
    elements["element_end"] = elements["element_end"].astype(int)
    elements.loc[elements["element_start"] < 0, "element_start"] = 0
    elements["ID"] = [id_prefix + str(i) for i in elements.index]
    return elements


def _remove_files(paths):
    """Best-effort removal of temporary files; failures are logged instead of raised."""
    for path in paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning("Failed to remove temporary file %s: %s", path, e)


def main(input_file, layout, div_file, bid_file, single_file, divergent_extension=(60, 60),
         unidirectional_extension=(60, 60), promoter_bed=None, **kwargs):
    """
    Extend boundaries

    For each peak file, a new file is written next to it, with `.bed` in the path replaced by
    `_element_<tag>bp.bed` (`<tag>` is derived from `divergent_extension`). If `promoter_bed`
    is given, elements not overlapping any promoter are additionally saved as
    `_element_<tag>bp_e.bed`. Unidirectional elements overlapping divergent or bidirectional
    elements are discarded.

    Parameters
    ----------
    input_file : str or tuple
        Path to a bam file, or a pair (plus strand bigwig, minus strand bigwig)
    layout : str or None
        Layout of the bam file, required when `input_file` is a bam file
    div_file : str
        Path to divergent peaks PINTS called
    bid_file : str
        Path to bidirectional peaks PINTS called
    single_file : str
        Path to unidirectional peaks PINTS called
    divergent_extension : tuple
        BPs to be extended for both divergent and bidirectional peaks
    unidirectional_extension : tuple
        BPs to be extended for unidirectional peaks
    promoter_bed : str or None
        Path to a bed file which defines promoter regions
    **kwargs
        Ignored

    Returns
    -------
    None
    """
    # dict.fromkeys drops duplicates but keeps the argument order, so the tag (and the output
    # file names) is the same on every run; the iteration order of a set of str is not
    de_tag = "_".join(dict.fromkeys(map(str, divergent_extension)))
    parent_path = os.path.dirname(div_file)
    if isinstance(input_file, str):
        log_assert(layout is not None,
                   "Please specify which type of experiment this data was generated from with --exp-type", logger)
        pl, mn, _ = get_read_signal(input_file, loc_prime=layout, chromosome_startswith="chr",
                                    output_dir=parent_path, output_prefix=str(os.getpid()))
    else:
        bw_pl, bw_mn = input_file[0], input_file[1]
        # a plain path counts as one file; len() of a str would compare character counts
        n_pl = 1 if isinstance(bw_pl, str) else len(bw_pl)
        n_mn = 1 if isinstance(bw_mn, str) else len(bw_mn)
        log_assert(n_pl == n_mn,
                   "Must provide the same amount of bigwig files for both strands", logger)

        pl, mn, _ = get_coverage_bw(bw_pl=bw_pl, bw_mn=bw_mn,
                                    chromosome_startswith="chr",
                                    output_dir=parent_path,
                                    output_prefix=str(os.getpid()))

    try:
        # signal from both strands is needed to refine a peak
        chromosomes = [c for c in pl if c in mn]
        if len(chromosomes) != len(pl) or len(chromosomes) != len(mn):
            logger.warning("Some chromosomes have signal on only one strand, peaks on them will be skipped")

        div = _read_peaks(div_file, chromosomes)
        bid = _read_peaks(bid_file, chromosomes)
        single = _read_peaks(single_file, chromosomes)
        for paired in (div, bid):
            # summits are stored as comma-separated strings, so start from a string column
            paired["pl_summit"] = ""
            paired["mn_summit"] = ""
            paired["element_start"] = 0
            paired["element_end"] = 0
        single["summit"] = 0

        for chromosome in chromosomes:
            pl_cov = np.load(pl[chromosome], allow_pickle=True)
            mn_cov = np.load(mn[chromosome], allow_pickle=True)
            _extend_paired_peaks(div, chromosome, pl_cov, mn_cov, divergent_extension)
            _extend_paired_peaks(bid, chromosome, pl_cov, mn_cov, divergent_extension)

        # unidirectional elements are defined as:
        # peak region boundaries defined by PINTS
        # go 300bp further on one side (we assume the opposite peak should be within 300 bp):
        # to the left for peaks on the + strand, to the right otherwise,
        # then add the requested extension on both sides to define the whole element
        on_plus = np.where(single[5] == "+", 300, 0)
        single["element_start"] = single[1] - unidirectional_extension[0] - on_plus
        single["element_end"] = single[2] + unidirectional_extension[1] + (300 - on_plus)

        div = _finalize_elements(div, [0, "element_start", "element_end", 3, 4, 5], "Divergent")
        bid = _finalize_elements(bid, [0, "element_start", "element_end", 3, 4, 5], "Bidirectional")
        single = _finalize_elements(single, [0, "element_start", "element_end", 3, 4, 5, 15], "Single")

        def element_path(peak_file, suffix=""):
            return peak_file.replace(".bed", "_element_{de_tag}bp{suffix}.bed".format(de_tag=de_tag,
                                                                                      suffix=suffix))

        div.to_csv(element_path(div_file), sep="\t", index=False, header=False)
        bid.to_csv(element_path(bid_file), sep="\t", index=False, header=False)
        single_obj = BedTool(single.to_csv(sep="\t", index=False, header=False), from_string=True)
        div_obj = BedTool(div.to_csv(sep="\t", index=False, header=False), from_string=True)
        bid_obj = BedTool(bid.to_csv(sep="\t", index=False, header=False), from_string=True)
        # drop unidirectional elements which overlap any divergent / bidirectional element
        single_obj = single_obj.intersect(div_obj, v=True)
        single_obj = single_obj.intersect(bid_obj, v=True)

        if promoter_bed is not None:
            promoter_bed_obj = BedTool(promoter_bed)
            BedTool.from_dataframe(div).intersect(promoter_bed_obj, v=True).saveas(
                element_path(div_file, "_e"))
            BedTool.from_dataframe(bid).intersect(promoter_bed_obj, v=True).saveas(
                element_path(bid_file, "_e"))
            single_obj.intersect(promoter_bed_obj, v=True).saveas(
                element_path(single_file, "_e"))

        # must come last: the promoter filtering above still reads the file behind single_obj
        single_obj.moveto(element_path(single_file))
    finally:
        # temporary per-chromosome signal files are removed even if refinement failed
        _remove_files(list(pl.values()) + list(mn.values()))


def _build_parser():
    """Create the command line parser of the boundary extender."""
    parser = argparse.ArgumentParser(description="Element boundary refinement")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--bam-files", action="store", dest="bam_files", nargs="*",
                       type=str, required=False,
                       help="input bam file, if you want to use bigwig files, please use --bw-pl and --bw-mn")
    group.add_argument("--bw-pl", action="store", dest="bw_pl", nargs="*",
                       type=str, required=False,
                       help="Bigwig file(s) for plus strand, use together with --bw-mn instead of --bam-files")
    parser.add_argument("--bw-mn", action="store", dest="bw_mn", nargs="*",
                        type=str, required=False,
                        help="Bigwig file(s) for minus strand, use together with --bw-pl instead of --bam-files")
    parser.add_argument("--exp-type", action="store", default=("CoPRO", ), dest="bam_parser",
                        choices=("CoPRO", "GROcap", "PROcap", "CAGE", "NETCAGE", "csRNAseq", "PROseq", "GROseq",
                                 "R_5", "R_3", "R1_5", "R1_3", "R2_5", "R2_3"),
                        help="Type of experiment, acceptable values are: CoPRO/GROcap/GROseq/PROcap/PROseq, or if you "
                             "know the position of RNA ends which you're interested on the reads, you can specify "
                             "R_5, R_3, R1_5, R1_3, R2_5 or R2_3", nargs="+", required=False)
    # the three peak files are needed for every sample, so let argparse report a missing one
    parser.add_argument("--divergent-files", help="Divergent peak call(s)", nargs="+", required=True)
    parser.add_argument("--bidirectional-files", help="Bidirectional peak call(s)", nargs="+", required=True)
    parser.add_argument("--unidirectional-files", help="Unidirectional peak call(s)", nargs="+", required=True)
    parser.add_argument("--promoter-bed", default=None,
                        help="Promoter bed, if specified, will create a separate bed file for distal elements")
    parser.add_argument("--save-to", help="save elements to")
    parser.add_argument("--div-ext-left", action="store", dest="div_ext_left", nargs="+",
                        type=int, required=False, default=(60,), help="divergent extension left")
    parser.add_argument("--div-ext-right", action="store", dest="div_ext_right", nargs="+",
                        type=int, required=False, default=(60,), help="divergent extension right")
    parser.add_argument("--single-ext-left", action="store", dest="single_ext_left", nargs="+",
                        type=int, required=False, default=(60, ), help="single extension left")
    parser.add_argument("--single-ext-right", action="store", dest="single_ext_right", nargs="+",
                        type=int, required=False, default=(60, ), help="single extension right")
    return parser


def _find_peak_file(files, index, element_type):
    """
    Return the last file whose name contains `_<index>_<element_type>.bed`, or None

    Parameters
    ----------
    files : list of str
        Candidate peak files
    index : int
        1-based sample index
    element_type : str
        One of divergent_peaks, bidirectional_peaks and unidirectional_peaks

    Returns
    -------
    peak_file : str or None
    """
    key = "_{index}_{et}.bed".format(index=index, et=element_type)
    matched = [f for f in files if key in f]
    return matched[-1] if matched else None


def _run_cli(argv=None):
    """
    Parse command line arguments and refine the elements of every sample

    Parameters
    ----------
    argv : list of str or None
        Arguments to parse, None means the arguments of the current process

    Returns
    -------
    None
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    if (args.bw_pl is None) != (args.bw_mn is None):
        parser.error("both of the two arguments --bw-pl --bw-mn are required")

    use_bam = args.bam_files is not None
    inputs = args.bam_files if use_bam else args.bw_pl
    counts = {len(inputs), len(args.divergent_files), len(args.bidirectional_files),
              len(args.unidirectional_files)}
    if not use_bam:
        counts.add(len(args.bw_mn))
    if len(counts) != 1:
        parser.error("Number of peak calls from different categories should match")

    # either one experiment type for all bam files, or one per bam file
    if use_bam and 1 < len(args.bam_parser) < len(args.bam_files):
        parser.error("You need to provide one or paired exp-type for bam-files")

    # explicit errors instead of assert, which is stripped when Python runs with -O
    if len(args.div_ext_left) != len(args.div_ext_right):
        parser.error("--div-ext-left and --div-ext-right need the same number of values")
    if len(args.single_ext_left) != len(args.single_ext_right):
        parser.error("--single-ext-left and --single-ext-right need the same number of values")

    n_extensions = len(args.div_ext_left)
    single_left = list(args.single_ext_left)
    single_right = list(args.single_ext_right)
    if len(single_left) < n_extensions:
        # a single value (e.g. the default) is reused for every divergent extension
        if len(single_left) != 1:
            parser.error("Provide either one or as many single extensions as divergent extensions")
        single_left = single_left * n_extensions
        single_right = single_right * n_extensions

    candidates = (
        ("divergent_peaks", "--divergent-files", args.divergent_files),
        ("bidirectional_peaks", "--bidirectional-files", args.bidirectional_files),
        ("unidirectional_peaks", "--unidirectional-files", args.unidirectional_files),
    )
    for i in range(1, len(inputs) + 1):
        calls = []
        for element_type, option, files in candidates:
            peak_file = _find_peak_file(files, i, element_type)
            if peak_file is None:
                parser.error("Cannot find a file named like *_{index}_{et}.bed in {option}".format(
                    index=i, et=element_type, option=option))
            calls.append(peak_file)
        divergent_calls, bidirectional_calls, unidirectional_calls = calls

        if use_bam:
            input_file = args.bam_files[i - 1]
            layout = args.bam_parser[i - 1] if len(args.bam_parser) > 1 else args.bam_parser[0]
        else:
            input_file = (args.bw_pl[i - 1], args.bw_mn[i - 1])
            layout = None

        for j in range(n_extensions):
            main(input_file, layout,
                 divergent_calls, bidirectional_calls, unidirectional_calls,
                 divergent_extension=(args.div_ext_left[j], args.div_ext_right[j]),
                 unidirectional_extension=(single_left[j], single_right[j]),
                 promoter_bed=args.promoter_bed)


if __name__ == "__main__":
    _run_cli()
