#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import argparse
import logging
import bean as be
from bean.plotting.allele_stats import plot_n_alleles_per_guide, plot_n_guides_per_edit
import matplotlib.pyplot as plt

plt.style.use("default")
# Log INFO and above to stderr. No `filemode` is passed: logging.basicConfig
# only honours it together with `filename`, so it would be a no-op with `stream`.
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)-5s @ %(asctime)s:\n\t %(message)s \n",
    datefmt="%a, %d %b %Y %H:%M:%S",
    stream=sys.stderr,
)
# Short aliases for the root-logger helpers used throughout this script.
# NOTE(review): `error` is bound to logging.critical, not logging.error — confirm intended.
error = logging.critical
warn = logging.warning
debug = logging.debug
info = logging.info


def parse_args(argv=None):
    """Parse the bean-filter command-line arguments.

    Prints the bean-filter banner to stdout, then parses the options.
    Cross-option validation and derived defaults are handled by `check_args`.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` calls behave exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    print(
        r"""
    _ _         
  /  \ '\       __ _ _ _           
  |   \  \     / _(_) | |_ ___ _ _ 
   \   \  |   |  _| | |  _/ -_) '_|
    `.__|/    |_| |_|_|\__\___|_|  
    """
    )
    print("bean-filter: filter alleles")
    parser = argparse.ArgumentParser(
        prog="allele_filter",
        description="Filter alleles based on edit position in spacer and frequency across samples.",
    )
    parser.add_argument(
        "bdata_path",
        type=str,
        help="Input ReporterScreen file of which allele will be filtered out.",
    )
    parser.add_argument(
        "--output-prefix",
        "-o",
        type=str,
        default=None,
        help="Output prefix for log and ReporterScreen file with allele assignment",
    )
    parser.add_argument(
        "--plasmid-path",
        "-p",
        type=str,
        help="Plasmid ReporterScreen object path. If provided, alleles are filtered based on if a nucleotide edit is more significantly enriched in sample compared to the plasmid data. Negative control data where no edit is expected can be fed in instead of plasmid library.",
    )
    parser.add_argument(
        "--edit-start-pos",
        "-s",
        type=int,
        help="0-based start position (inclusive) of edit relative to the start of guide spacer.",
    )
    parser.add_argument(
        "--edit-end-pos",
        "-e",
        type=int,
        help="0-based end position (exclusive) of edit relative to the start of guide spacer.",
    )
    parser.add_argument(
        "--jaccard-threshold",
        "-j",
        type=float,
        help="Jaccard Index threshold when the alleles are mapped to the most similar alleles. In each filtering step, allele counts of filtered out alleles will be mapped to the most similar allele only if they have Jaccard Index of shared edit higher than this threshold.",
        default=0.3,
    )
    parser.add_argument(
        "--filter-window",
        "-w",
        help="Only consider edit within window provided by (edit-start-pos, edit-end-pos). If this flag is not provided, `--edit-start-pos` and `--edit-end-pos` flags are ignored.",
        action="store_true",
    )
    parser.add_argument(
        "--filter-target-basechange",
        "-b",
        help="Only consider target edit (stored in bdata.uns['target_base_change'])",
        action="store_true",
    )
    parser.add_argument(
        "--translate", "-t", help="Translate alleles", action="store_true"
    )
    parser.add_argument(
        "--translate-fasta",
        "-f",
        type=str,
        help="fasta file with exon positions. If not provided, LDLR hg19 coordinates will be used.",
        default=None,
    )
    parser.add_argument(
        "--filter-allele-proportion",
        "-ap",
        type=float,
        default=None,
        help="If provided, alleles that exceed `filter_allele_proportion` in at least `filter-sample-proportion` of samples will be retained.",
    )
    parser.add_argument(
        "--filter-allele-count",
        "-ac",
        type=int,
        default=5,
        help="Minimum allele count used together with `--filter-allele-proportion`: alleles that exceed `filter_allele_proportion` AND `filter_allele_count` in at least `filter-sample-proportion` of samples will be retained. Ignored if `--filter-allele-proportion` is not provided.",
    )
    parser.add_argument(
        "--filter-sample-proportion",
        "-sp",
        type=float,
        default=0.2,
        help="If `filter_allele_proportion` is provided, alleles that exceed `filter_allele_proportion` in at least `filter-sample-proportion` of samples will be retained.",
    )
    return parser.parse_args(argv)


def check_args(args):
    """Validate parsed arguments and fill in derived defaults, in place.

    - Derives ``args.output_prefix`` from ``args.bdata_path`` when it is not
      given: everything from the last ".h5ad" onward is dropped and
      "_alleleFiltered" is appended.
    - With ``--filter-window``, requires at least one of ``--edit-start-pos`` /
      ``--edit-end-pos`` and fills the missing one with 0 (start) or 20 (end).
    - Range-checks the proportion thresholds that are set.

    Args:
        args: Namespace returned by `parse_args`; modified in place.

    Raises:
        ValueError: if ``--filter-window`` is set without either edit position,
            or if a set proportion threshold lies outside [0, 1].
    """
    if args.output_prefix is None:
        args.output_prefix = args.bdata_path.rsplit(".h5ad", 1)[0] + "_alleleFiltered"
    info(f"Saving results to {args.output_prefix}")
    if args.filter_window:
        if args.edit_start_pos is None and args.edit_end_pos is None:
            raise ValueError(
                "Invalid arguments: --filter-window option set but none of --edit-start-pos and --edit-end-pos specified."
            )
        if args.edit_start_pos is None:
            warn(
                "--filter-window option set but --edit-start-pos not provided. Using 0 as its value."
            )
            args.edit_start_pos = 0
        if args.edit_end_pos is None:
            warn(
                "--filter-window option set but --edit-end-pos not provided. Using 20 as its value."
            )
            args.edit_end_pos = 20
    # --filter-allele-proportion defaults to None, meaning the proportion filter
    # is disabled; comparing None with numbers would raise TypeError, so only
    # range-check thresholds that are actually set.
    if args.filter_allele_proportion is not None and not (
        0 <= args.filter_allele_proportion <= 1
    ):
        raise ValueError(
            "Invalid arguments: filter-allele-proportion should be in range [0, 1]."
        )
    if args.filter_sample_proportion is not None and not (
        0 <= args.filter_sample_proportion <= 1
    ):
        raise ValueError(
            "Invalid arguments: filter-sample-proportion should be in range [0, 1]."
        )


def _cached_allele_df_keys(args, bdata):
    """Return the allele-table keys of filtering steps already stored in ``bdata.uns``.

    Rebuilds the ``.uns`` key names that the filtering steps in the
    ``__main__`` block produce for the current ``args``. This naming must be
    kept in sync with that code. The function returns "allele_counts" followed
    by the leading run of those keys that exist in ``bdata.uns``.
    """
    candidate_keys = ["allele_counts"]
    if args.plasmid_path is not None:
        candidate_keys.append("sig_allele_counts")
    candidate_keys.append(f"{candidate_keys[-1]}_spacer")
    if args.filter_window:
        candidate_keys.append(
            f"{candidate_keys[-1]}_{args.edit_start_pos}_{args.edit_end_pos}"
        )
    if args.filter_target_basechange:
        candidate_keys.append(
            f"{candidate_keys[-1]}_{bdata.base_edited_from}.{bdata.base_edited_to}"
        )
    if args.translate:
        candidate_keys.append(f"{candidate_keys[-1]}_translated")
    if args.filter_allele_proportion is not None:
        candidate_keys.append(
            f"{candidate_keys[-1]}_prop{args.filter_allele_proportion}_{args.filter_sample_proportion}"
        )
    present_keys = candidate_keys[:1]
    for key in candidate_keys[1:]:
        # Each key extends the previous one, so stop at the first missing step.
        if key not in bdata.uns:
            break
        present_keys.append(key)
    return present_keys


if __name__ == "__main__":
    args = parse_args()
    check_args(args)
    bdata = be.read_h5ad(args.bdata_path)
    # Ordered `.uns` keys of the allele tables produced by each filtering step;
    # every step reads the table under the last key and appends its own key.
    allele_df_keys = ["allele_counts"]
    info(
        f"Starting from .uns['allele_counts'] with {len(bdata.uns['allele_counts'])} alleles."
    )
    if os.path.exists(f"{args.output_prefix}.h5ad"):
        # Reuse a previous run's output instead of re-filtering. Recover its
        # filtering-step keys so the stats plot and log cover every step.
        info(f"Loading previously filtered result from {args.output_prefix}.h5ad...")
        bdata = be.read_h5ad(f"{args.output_prefix}.h5ad")
        allele_df_keys = _cached_allele_df_keys(args, bdata)
    else:
        if args.plasmid_path is not None:
            info(
                "Filtering significantly more edited nucleotide per guide compared to plasmid library..."
            )
            plasmid_adata = be.read_h5ad(args.plasmid_path)
            # Drop plasmid allele rows whose allele renders as an empty string
            # (presumably the unedited allele).
            plasmid_adata.uns[allele_df_keys[-1]] = plasmid_adata.uns[
                allele_df_keys[-1]
            ].loc[plasmid_adata.uns[allele_df_keys[-1]].allele.map(str) != "", :]

            _q_val_each, sig_allele_df = be.an.filter_alleles.filter_alleles(
                bdata, plasmid_adata, filter_each_sample=True, run_parallel=True
            )
            bdata.uns["sig_allele_counts"] = sig_allele_df.reset_index(drop=True)
            allele_df_keys.append("sig_allele_counts")
            info(f"Filtered down to {len(bdata.uns['sig_allele_counts'])} alleles.")

        info("Filtering out edits outside spacer position...")
        # NOTE(review): this step uses the fixed spacer window [0, 20) and a fixed
        # jaccard_threshold of 0.2 rather than --jaccard-threshold; confirm intended.
        bdata.uns[f"{allele_df_keys[-1]}_spacer"] = bdata.filter_allele_counts_by_pos(
            rel_pos_start=0,
            rel_pos_end=20,
            rel_pos_is_reporter=False,
            map_to_filtered=True,
            allele_uns_key=allele_df_keys[-1],
            jaccard_threshold=0.2,
        ).reset_index(drop=True)
        info(
            f"Filtered down to {len(bdata.uns[f'{allele_df_keys[-1]}_spacer'])} alleles."
        )
        allele_df_keys.append(f"{allele_df_keys[-1]}_spacer")

        if args.filter_window:
            info(
                f"Filtering out edits based on relatvie position in spacer: 0-based [{args.edit_start_pos},{args.edit_end_pos})..."
            )
            filtered_key = (
                f"{allele_df_keys[-1]}_{args.edit_start_pos}_{args.edit_end_pos}"
            )
            bdata.uns[filtered_key] = bdata.filter_allele_counts_by_pos(
                rel_pos_start=args.edit_start_pos,
                rel_pos_end=args.edit_end_pos,
                rel_pos_is_reporter=False,
                map_to_filtered=True,
                allele_uns_key=allele_df_keys[-1],
                jaccard_threshold=args.jaccard_threshold,
            ).reset_index(drop=True)
            allele_df_keys.append(filtered_key)
            info(f"Filtered down to {len(bdata.uns[filtered_key])} alleles.")

        if args.filter_target_basechange:
            filtered_key = (
                f"{allele_df_keys[-1]}_{bdata.base_edited_from}.{bdata.base_edited_to}"
            )
            info(f"Filtering out non-{bdata.uns['target_base_change']} edits...")
            bdata.uns[filtered_key] = bdata.filter_allele_counts_by_base(
                bdata.base_edited_from,
                bdata.base_edited_to,
                map_to_filtered=False,
                allele_uns_key=allele_df_keys[-1],
            ).reset_index(drop=True)
            info(f"Filtered down to {len(bdata.uns[filtered_key])} alleles.")
            allele_df_keys.append(filtered_key)

        if args.translate:
            info(
                "Translating alleles..."
            )  # TODO: Check & document custom fasta file for translation
            filtered_key = f"{allele_df_keys[-1]}_translated"
            bdata.uns[filtered_key] = be.translate_allele_df(
                bdata.uns[allele_df_keys[-1]], fasta_file=args.translate_fasta
            ).rename(columns={"allele": "aa_allele"})
            allele_df_keys.append(filtered_key)
            info(f"Filtered down to {len(bdata.uns[filtered_key])} alleles.")

        if args.filter_allele_proportion is not None:
            info(
                f"Filtering alleles for those have allele fraction {args.filter_allele_proportion} in at least {args.filter_sample_proportion*100}% of samples..."
            )
            filtered_key = f"{allele_df_keys[-1]}_prop{args.filter_allele_proportion}_{args.filter_sample_proportion}"
            bdata.uns[filtered_key] = be.an.filter_alleles.filter_allele_prop(
                bdata,
                allele_df_keys[-1],
                allele_prop_thres=args.filter_allele_proportion,
                allele_count_thres=args.filter_allele_count,
                sample_prop_thres=args.filter_sample_proportion,
                map_to_filtered=True,
                retain_max=True,
                # Column 1 of the allele table is used as the allele column;
                # assumes a fixed column layout. TODO confirm.
                allele_col=bdata.uns[allele_df_keys[-1]].columns[1],
                distribute=True,
                jaccard_threshold=args.jaccard_threshold,
            )
            allele_df_keys.append(filtered_key)
            info(f"Filtered down to {len(bdata.uns[filtered_key])} alleles.")
        info("Done filtering!")
        info(
            f"Saving ReporterScreen with filtered alleles at {args.output_prefix}.h5ad..."
        )
        bdata.write(f"{args.output_prefix}.h5ad")

    info("Plotting allele stats for each filtering step...")
    # squeeze=False keeps `ax` 2-D even when there is a single filtering step;
    # otherwise plt.subplots(1, 2) returns a 1-D array and ax[i, 0] fails.
    fig, ax = plt.subplots(
        len(allele_df_keys), 2, figsize=(6, 3 * len(allele_df_keys)), squeeze=False
    )
    for i, key in enumerate(allele_df_keys):
        allele_col = bdata.uns[key].columns[1]
        plot_n_alleles_per_guide(bdata, key, allele_col, ax[i, 0])
        plot_n_guides_per_edit(bdata, key, allele_col, ax[i, 1])
    plt.tight_layout()
    plt.savefig(f"{args.output_prefix}.filtered_allele_stats.pdf", bbox_inches="tight")
    plt.close(fig)
    info(
        f"Saving plotting result and log at {args.output_prefix}.[filtered_allele_stats.pdf, filter_log.txt]."
    )
    # One line per filtering step: "<uns key>\t<number of alleles>".
    with open(f"{args.output_prefix}.filter_log.txt", "w") as out_log:
        for key in allele_df_keys:
            out_log.write(f"{key}\t{len(bdata.uns[key])}\n")
