#!/usr/bin/env python
import os
import sys
import logging
from copy import deepcopy
from functools import partial
import numpy as np
import pandas as pd

from tqdm.auto import tqdm
import torch
import pyro
import pyro.infer
import pyro.optim
import pickle as pkl
import argparse
import bean.model.model as m
from bean.model.readwrite import write_result_table
from bean.preprocessing.data_class import (
    VariantSortingScreenData,
    VariantSortingReporterScreenData,
    TilingSortingReporterScreenData,
)
from bean.preprocessing.utils import (
    _obtain_effective_edit_rate,
    _obtain_n_guides_alleles_per_variant,
)

import bean as be

# Configure the root logger once at import time: INFO and above go to stderr
# using a timestamped, multi-line record format (DEBUG records are dropped).
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)-5s @ %(asctime)s:\n\t %(message)s \n",
    datefmt="%a, %d %b %Y %H:%M:%S",
    stream=sys.stderr,
    # NOTE(review): per the logging docs `filemode` is only used together with
    # `filename`; with `stream` it presumably has no effect — confirm before removing.
    filemode="w",
)
# Short module-level aliases for the root-logger helpers used in this script.
error = logging.critical  # note: `error(...)` logs at CRITICAL level, not ERROR
warn = logging.warning
debug = logging.debug
info = logging.info
# Seed pyro's random number generators with a fixed value when the module loads.
pyro.set_rng_seed(101)

# Lookup from model label (see identify_model_guide) to the screen-data class
# that main() instantiates for that model. Labels sharing a class are grouped;
# key order is the same as in the flat spelling of this table.
DATACLASS_DICT = {
    **dict.fromkeys(
        (
            "Normal",
            "MixtureNormal",
            "_MixtureNormal+Acc",  # TODO: old
            "MixtureNormal+Acc",
        ),
        VariantSortingReporterScreenData,
    ),
    "MixtureNormalConstPi": VariantSortingScreenData,
    **dict.fromkeys(
        ("MultiMixtureNormal", "MultiMixtureNormal+Acc"),
        TilingSortingReporterScreenData,
    ),
}


def run_inference(
    model, guide, data, initial_lr=0.01, gamma=0.1, num_steps=2000, autoguide=False
):
    """Fit ``model``/``guide`` to ``data`` with stochastic variational inference.

    The global pyro parameter store is cleared first, then ``num_steps`` SVI
    steps are run with a ``ClippedAdam`` optimizer and a ``Trace_ELBO`` loss.
    The loss is printed every 100 steps.

    Args:
        model: Model callable handed to ``pyro.infer.SVI``.
        guide: Guide callable handed to ``pyro.infer.SVI``.
        data: Object passed to every ``svi.step`` call.
        initial_lr: Initial learning rate (``lr`` of ``ClippedAdam``).
        gamma: Total learning-rate decay over the run. The optimizer's ``lrd``
            is set to ``gamma ** (1 / num_steps)``, so a per-step multiplicative
            decay ends at ``initial_lr * gamma`` after ``num_steps`` steps.
        num_steps: Number of SVI steps. Must be non-zero, since it is used as
            a divisor when computing ``lrd``.
        autoguide: Unused in this function; kept so existing callers still work.

    Returns:
        dict with ``"loss"`` (list of per-step loss values) and ``"params"``
        (the global pyro param store object itself, not a copy — callers that
        run another fit afterwards should copy it first).

    Raises:
        ValueError: If any SVI step raises ``ValueError``. The current param
            store is first pickled to ``tmp_result.pkl`` in the working
            directory, and the original exception is chained as the cause.
    """
    pyro.clear_param_store()
    lrd = gamma ** (1 / num_steps)
    svi = pyro.infer.SVI(
        model=model,
        guide=guide,
        optim=pyro.optim.ClippedAdam({"lr": initial_lr, "lrd": lrd}),
        loss=pyro.infer.Trace_ELBO(),
    )

    losses = []

    try:
        for t in tqdm(range(num_steps)):
            loss = svi.step(data)
            if t % 100 == 0:
                print(f"loss {loss} @ iter {t}")
            losses.append(loss)
    except ValueError as exc:
        error(
            "Error occurred during fitting. Saving temporary output at tmp_result.pkl."
        )
        # Keep whatever was fitted so far so the failed run can be inspected.
        with open("tmp_result.pkl", "wb") as handle:
            pkl.dump({"param": pyro.get_param_store()}, handle)

        # Chain the original exception so its traceback is not lost.
        raise ValueError(
            f"Fitting halted for command: {' '.join(sys.argv)} with following error: \n {exc}"
        ) from exc
    return {
        "loss": losses,
        "params": pyro.get_param_store(),
    }


def _get_guide_target_info(bdata, target_col=None):
    """Build a per-target annotation table from ``bdata.guides``.

    The result is indexed by the target column and contains:

    * every guide column whose name starts with ``"target_"`` and whose value
      is constant within each target (one distinct value per target), and
    * ``edit_rate_mean`` / ``edit_rate_std``: mean and standard deviation of
      the ``edit_rate`` column over the guides of each target.

    Args:
        bdata: Object exposing a ``guides`` DataFrame that has the target
            column and an ``edit_rate`` column.
        target_col: Name of the target column in ``bdata.guides``. When
            ``None`` (the default) the module-level ``args.target_col`` set by
            the script entry point is used, which is the historical behavior.

    Returns:
        pandas.DataFrame with one row per target, in order of first appearance
        in ``bdata.guides``.
    """
    if target_col is None:
        # Backward-compatible fallback to the CLI namespace bound at module
        # level when this file is run as a script.
        target_col = args.target_col
    guide_info = bdata.guides.copy()
    edit_rate_info = (
        guide_info[[target_col, "edit_rate"]]
        .groupby(target_col, sort=False)
        .agg({"edit_rate": ["mean", "std"]})
    )
    # Flatten the (column, statistic) MultiIndex down to the statistic name.
    edit_rate_info.columns = edit_rate_info.columns.get_level_values(1)
    edit_rate_info = edit_rate_info.rename(
        columns={"mean": "edit_rate_mean", "std": "edit_rate_std"}
    )
    # Loop-invariant: number of distinct targets, computed once.
    n_targets = len(guide_info[target_col].unique())
    # A "target_*" column is target-level when pairing it with the target
    # column yields exactly one distinct row per target. Cheap name checks run
    # first so drop_duplicates is only evaluated for candidate columns.
    target_level_cols = [
        col
        for col in guide_info.columns
        if col != target_col
        and col.startswith("target_")
        and len(guide_info[[target_col, col]].drop_duplicates()) == n_targets
    ]
    target_info = (
        guide_info[[target_col] + target_level_cols]
        .drop_duplicates()
        .set_index(target_col, drop=True)
    )
    target_info = target_info.join(edit_rate_info)
    return target_info


def parse_args(argv=None):
    """Print the bean-run banner and parse the command-line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, which is the original
            behavior; passing a list allows programmatic use.

    Returns:
        argparse.Namespace holding the two positionals (``mode``,
        ``bdata_path``) and all optional flags defined below. No cross-argument
        validation is done here; see ``check_args``.
    """
    print(
        r"""
    _ _       
  /  \ '\                 
  |   \  \      _ _ _  _ _ _  
   \   \  |    | '_| || | ' \ 
    `.__|/     |_|  \_,_|_||_|
    """
    )
    print("bean-run: Run model to identify targeted variants and their impact.")
    parser = argparse.ArgumentParser(description="Run model on data.")
    parser.add_argument(
        "mode",
        type=str,
        help="[variant, tiling]- Screen type whether to run variant or tiling screen model.",
    )
    parser.add_argument("bdata_path", type=str, help="Path of an ReporterScreen object")
    parser.add_argument(
        "--rep-pi",
        "-r",
        action="store_true",
        default=False,
        help="Fit replicate specific scaling factor. Recommended to set as True if you expect variable editing activity across biological replicates.",
    )
    parser.add_argument(
        "--uniform-edit",
        "-p",
        action="store_true",
        default=False,
        help="Assume uniform editing rate for all guides.",
    )
    parser.add_argument(
        "--scale-by-acc",
        action="store_true",
        default=False,
        help="Scale guide editing efficiency by the target loci accessibility",
    )
    parser.add_argument(
        "--acc-bw-path",
        type=str,
        default=None,
        help="Accessibility .bigWig file to be used to assign accessibility of guides.",
    )
    parser.add_argument(
        "--acc-col",
        type=str,
        default=None,
        help="Column name in bdata.guides that specify raw ATAC-seq signal.",
    )
    parser.add_argument(
        "--const-pi",
        default=False,
        action="store_true",
        help="Use constant pi provided in --guide-activity-col (instead of fitting from reporter data)",
    )
    parser.add_argument(
        "--condition-col",
        default="bin",
        type=str,
        help="Column key in `bdata.samples` that describes experimental condition.",
    )
    parser.add_argument(
        "--target-col",
        default="target",
        type=str,
        help="Column key in `bdata.guides` that describes the target element of each guide.",
    )
    parser.add_argument(
        "--guide-activity-col",
        "-a",
        type=str,
        default=None,
        help="Column in ReporterScreen.guides DataFrame showing the editing rate estimated via external tools",
    )
    parser.add_argument(
        "--outdir",
        "-o",
        default=None,
        type=str,
        help="Directory to save the run result.",
    )
    parser.add_argument(
        "--result-suffix",
        default="",
        type=str,
        help="Suffix of the output files",
    )
    parser.add_argument(
        "--sorting-bin-upper-quantile-col",
        "-uq",
        help="Column name with upper quantile values of each sorting bin in [Reporter]Screen.samples (or AnnData.var)",
        default="upper_quantile",
    )
    parser.add_argument(
        "--sorting-bin-lower-quantile-col",
        "-lq",
        help="Column name with lower quantile values of each sorting bin in [Reporter]Screen.samples (or AnnData var)",
        default="lower_quantile",
    )
    parser.add_argument("--cuda", action="store_true", default=False, help="run on GPU")
    parser.add_argument(
        "--sample-mask-col",
        type=str,
        default=None,
        help="Name of the column indicating the sample mask in [Reporter]Screen.samples (or AnnData.var). Sample is ignored if the value in this column is 0. This can be used to mask out low-quality samples.",
    )
    parser.add_argument(
        "--fit-negctrl",
        action="store_true",
        default=False,
        help="Fit the shared negative control distribution to normalize the fitted parameters",
    )
    parser.add_argument(
        "--negctrl-col",
        type=str,
        default="target_group",
        help="Column in bdata.obs specifying if a guide is negative control. If the `bdata.guides[negctrl_col].tolower() == negctrl_col_value`, it is treated as negative control guide.",
    )
    parser.add_argument(
        "--negctrl-col-value",
        type=str,
        default="negctrl",
        help="Column value in bdata.guides specifying if a guide is negative control. If the `bdata.guides[negctrl_col].tolower() == negctrl_col_value`, it is treated as negative control guide.",
    )
    parser.add_argument(
        "--repguide-mask",
        type=str,
        default="repguide_mask",
        help="n_replicate x n_guide mask to mask the outlier guides. screen.uns[repguide_mask] will be used.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Optionally use GPU if provided valid GPU device name (ex. cuda:0)",
    )
    parser.add_argument(
        "--ignore-bcmatch",
        action="store_true",
        default=False,
        help="If provided, even if the screen object has .X_bcmatch, ignore the count when fitting.",
    )
    parser.add_argument(
        "--allele-df-key",
        type=str,
        default=None,
        help="screen.uns[allele_df_key] will be used as the allele count.",
    )
    parser.add_argument(
        "--splice-site-path",
        type=str,
        default=None,
        help="Path to splicing site",
    )
    parser.add_argument(
        "--control-guide-tag",
        type=str,
        default="CONTROL",
        help="If this string is in guide name, treat each guide separately not to mix the position. Used for negative controls.",
    )
    parser.add_argument(
        "--dont-fit-noise",  # TODO: add check args
        action="store_true",
        help="If provided, the guide is built with noise fitting turned off.",
    )
    parser.add_argument(
        "--load-existing",  # TODO: add check args
        action="store_true",
        help="Load existing .pkl file if present.",
    )

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)


def check_args(args):
    """Validate the parsed CLI arguments and fill in derived defaults.

    ``args`` is modified in place and also returned:

    * with ``--scale-by-acc``, one of ``--acc-col`` / ``--acc-bw-path`` is
      required; if both are given ``acc_bw_path`` is reset to ``None``;
    * a missing ``--outdir`` defaults to the directory of ``bdata_path``
      (``"."`` when the path has no directory component);
    * ``mode`` must be ``"variant"`` or ``"tiling"``, and tiling mode requires
      ``--allele-df-key``;
    * with ``--fit-negctrl`` the input file is read and must contain at least
      20 guides whose ``negctrl_col`` value equals ``negctrl_col_value``
      (case-insensitive).

    Args:
        args: Namespace produced by ``parse_args``.

    Returns:
        The same namespace object, after normalization.

    Raises:
        ValueError: If any of the checks above fails.
    """
    if args.scale_by_acc:
        if args.acc_col is None and args.acc_bw_path is None:
            raise ValueError(
                "--scale-by-acc not accompanied by --acc-col nor --acc-bw-path to use. Pass either one."
            )
        if args.acc_col is not None and args.acc_bw_path is not None:
            warn(
                "Both --acc-col and --acc-bw-path is specified. --acc-bw-path is ignored."
            )
            args.acc_bw_path = None
    if args.outdir is None:
        # dirname() is "" for a bare file name; an empty outdir would later be
        # turned into a path under the filesystem root, so use the cwd instead.
        args.outdir = os.path.dirname(args.bdata_path) or "."
    if args.mode == "tiling":
        if args.allele_df_key is None:
            raise ValueError(
                "--allele-df-key not provided for tiling screen. Feed in the key then allele counts in screen.uns[allele_df_key] will be used."
            )
    elif args.mode != "variant":
        raise ValueError(
            "Invalid mode provided. Select either 'variant' or 'tiling'."
        )  # TODO: change this into discrete modes via argparse
    if args.fit_negctrl:
        # The file is read here only to count negative controls.
        bdata = be.read_h5ad(args.bdata_path)
        # astype(str) keeps the comparison from crashing on NaN / non-string
        # entries in the negative-control column.
        n_negctrl = (
            bdata.guides[args.negctrl_col].astype(str).str.lower()
            == args.negctrl_col_value.lower()
        ).sum()
        if n_negctrl < 20:
            raise ValueError(
                f"Not enough negative control guide in the input data: {n_negctrl}. Check your input arguments."
            )
    return args


def identify_model_guide(args):
    """Pick the model label, model callable and guide callable for this run.

    Selection order:

    1. ``mode == "tiling"``: MultiMixtureNormal (label gets ``+Acc`` with
       ``--scale-by-acc``).
    2. ``--uniform-edit``: Normal model.
    3. ``--const-pi``: MixtureNormalConstPi model (needs
       ``--guide-activity-col``).
    4. otherwise: MixtureNormal (label gets ``+Acc`` with ``--scale-by-acc``
       and a leading ``_`` with ``--dont-fit-noise``).

    Args:
        args: Parsed CLI namespace; reads ``mode``, ``scale_by_acc``,
            ``ignore_bcmatch``, ``dont_fit_noise``, ``uniform_edit``,
            ``const_pi`` and ``guide_activity_col``.

    Returns:
        Tuple ``(model_label, model, guide)`` where ``model`` and ``guide`` are
        callables, some of them pre-bound with keyword options via ``partial``.

    Raises:
        ValueError: If ``--uniform-edit`` is combined with
            ``--guide-activity-col``, or ``--const-pi`` is given without
            ``--guide-activity-col``.
    """
    # Use logical negation: `~` on a Python bool yields -1 / -2, both truthy,
    # which would make --ignore-bcmatch and --dont-fit-noise no-ops.
    use_bcmatch = not args.ignore_bcmatch
    fit_noise = not args.dont_fit_noise
    acc_tag = "+Acc" if args.scale_by_acc else ""

    if args.mode == "tiling":
        info("Using Mixture Normal model...")
        return (
            f"MultiMixtureNormal{acc_tag}",
            partial(
                m.MultiMixtureNormalModel,
                scale_by_accessibility=args.scale_by_acc,
                use_bcmatch=use_bcmatch,
            ),
            partial(
                m.MultiMixtureNormalGuide,
                scale_by_accessibility=args.scale_by_acc,
                fit_noise=fit_noise,
            ),
        )
    if args.uniform_edit:
        if args.guide_activity_col is not None:
            raise ValueError(
                "Can't use the guide activity column while constraining uniform edit."
            )
        info("Using Normal model...")
        return (
            "Normal",
            partial(m.NormalModel, use_bcmatch=use_bcmatch),
            m.NormalGuide,
        )
    if args.const_pi:
        # Constant pi is taken from the guide activity column, so the column
        # must be present (the check used to be inverted).
        if args.guide_activity_col is None:
            raise ValueError(
                "--guide-activity-col to be used as constant pi is not provided."
            )
        info("Using Mixture Normal model with constant weight ...")
        return (
            "MixtureNormalConstPi",
            partial(m.MixtureNormalConstPiModel, use_bcmatch=use_bcmatch),
            m.MixtureNormalGuide,
        )
    info(
        f"Using Mixture Normal model {'with accessibility normalization' if args.scale_by_acc else ''}..."
    )
    # NOTE(review): with --dont-fit-noise but without --scale-by-acc the label
    # becomes "_MixtureNormal", which has no entry in DATACLASS_DICT — confirm
    # whether that combination is meant to be supported.
    return (
        f"{'_' if args.dont_fit_noise else ''}MixtureNormal{acc_tag}",
        partial(
            m.MixtureNormalModel,
            scale_by_accessibility=args.scale_by_acc,
            use_bcmatch=use_bcmatch,
        ),
        partial(
            m.MixtureNormalGuide,
            scale_by_accessibility=args.scale_by_acc,
            fit_noise=fit_noise,
        ),
    )


def main(args):
    """Run the full bean-run pipeline for one input screen.

    Steps: select the tensor type (CPU/GPU), create the output directory
    ``<outdir>/bean_run_result.<input basename without extension>``, pick the
    model via ``identify_model_guide``, load and preprocess the screen, wrap it
    in the data class registered in ``DATACLASS_DICT``, build the per-target
    annotation table, fit the model (or load a previously pickled result),
    optionally fit the negative-control model, then pickle the parameters and
    write the result table.

    Args:
        args: Namespace from ``parse_args`` after ``check_args``.

    Raises:
        ValueError: In variant mode, if the target column contains nulls.
    """
    if args.cuda:
        # NOTE(review): the visible GPU index is hard-coded to "1" — confirm
        # this is intended rather than derived from --device.
        os.environ["CUDA_VISIBLE_DEVICES"] = "1"
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        torch.set_default_tensor_type(torch.FloatTensor)
    # os.path.join keeps the path relative when outdir is "" (plain string
    # concatenation would produce "/bean_run_result..." in the root directory).
    prefix = os.path.join(
        args.outdir,
        "bean_run_result." + os.path.basename(args.bdata_path).rsplit(".", 1)[0],
    )
    os.makedirs(prefix, exist_ok=True)
    model_label, model, guide = identify_model_guide(args)
    bdata = be.read_h5ad(args.bdata_path)
    # Guide index as stored in the input file, captured before any reordering
    # below; it is handed to write_result_table.
    guide_index = bdata.guides.index
    info("Done loading data. Preprocessing...")
    bdata.samples["rep"] = bdata.samples["rep"].astype("category")
    bdata.guides = bdata.guides.loc[:, ~bdata.guides.columns.duplicated()].copy()
    if args.mode == "variant":
        if bdata.guides[args.target_col].isnull().any():
            raise ValueError(
                f"Some target column (bdata.guides[{args.target_col}]) value is null. Check your input file."
            )
        # Order guides by their target value.
        bdata = bdata[bdata.guides[args.target_col].argsort(), :]
    ndata = DATACLASS_DICT[model_label](
        screen=bdata,
        device=args.device,
        repguide_mask=args.repguide_mask,
        sample_mask_column=args.sample_mask_col,
        accessibility_col=args.acc_col,
        accessibility_bw_path=args.acc_bw_path,
        use_const_pi=args.const_pi,
        condition_column=args.condition_col,
        allele_df_key=args.allele_df_key,
        control_guide_tag=args.control_guide_tag,
        target_col=args.target_col,
    )
    if args.mode == "variant":
        target_info_df = _get_guide_target_info(ndata.screen)
    else:
        splice_sites = None
        if args.splice_site_path is not None:
            splice_sites = pd.read_csv(args.splice_site_path).pos
        target_info_df = be.an.translate_allele.annotate_edit(
            pd.DataFrame(pd.Series(ndata.edit_index))
            .reset_index()
            .rename(columns={"index": "edit"}),
            splice_sites=splice_sites,
        )
        target_info_df["effective_edit_rate"] = _obtain_effective_edit_rate(ndata).cpu()
        target_info_df["n_guides"] = _obtain_n_guides_alleles_per_variant(ndata).cpu()

    guide_info_df = ndata.screen.guides
    del bdata

    info(f"Running inference for {model_label}...")

    # NOTE(review): this path ignores --result-suffix although the pickle is
    # written with the suffix below — confirm which file is meant to be loaded.
    existing_result_path = f"{prefix}/{model_label}.result.pkl"
    if args.load_existing and os.path.exists(existing_result_path):
        # The pickle is read from the run's own output directory; only use
        # --load-existing on result files you trust.
        with open(existing_result_path, "rb") as handle:
            param_history_dict = pkl.load(handle)
    else:
        if args.load_existing:
            # --load-existing is documented as "if present": fall back to
            # fitting instead of failing when there is nothing to load.
            warn(
                f"--load-existing was given but {existing_result_path} does not exist. Running inference instead."
            )
        # Deep copy: run_inference returns the live global param store, which
        # the negative-control fit below clears.
        param_history_dict = deepcopy(run_inference(model, guide, ndata))
        if args.fit_negctrl:
            negctrl_model = m.ControlNormalModel
            negctrl_guide = m.ControlNormalGuide
            # astype(str) keeps the match from crashing on NaN / non-string
            # entries in the negative-control column.
            negctrl_idx = np.where(
                guide_info_df[args.negctrl_col].astype(str).str.lower()
                == args.negctrl_col_value.lower()
            )[0]
            info(f"Fitting negative control model on {len(negctrl_idx)} guides...")
            ndata_negctrl = ndata[negctrl_idx]
            param_history_dict["negctrl"] = run_inference(
                negctrl_model, negctrl_guide, ndata_negctrl
            )

    outfile_path = (
        f"{prefix}/bean_element[sgRNA]_result.{model_label}{args.result_suffix}.csv"
    )
    info(f"Done running inference. Writing result at {outfile_path}...")
    with open(f"{prefix}/{model_label}.result{args.result_suffix}.pkl", "wb") as handle:
        try:
            pkl.dump(param_history_dict, handle)
        except TypeError as exc:
            # Best effort: a pickling failure must not prevent the result table
            # from being written. (Exceptions have no `.message` in Python 3.)
            error(f"Could not pickle the fitted parameters: {exc}")
    guide_acc = getattr(ndata, "guide_accessibility", None)
    if guide_acc is not None:
        guide_acc = guide_acc.cpu().numpy()
    write_result_table(
        target_info_df,
        param_history_dict,
        model_label=model_label,
        prefix=f"{prefix}/",
        suffix=args.result_suffix,
        guide_index=guide_index,
        guide_acc=guide_acc,
    )
    info("Done!")


# Script entry point: parse the CLI, validate/normalize the arguments, run.
if __name__ == "__main__":
    # `args` is bound at module level here; _get_guide_target_info looks up the
    # global name `args`, so this variable must not be renamed or made local.
    args = parse_args()
    args = check_args(args)
    main(args)
