from snakemake.utils import min_version
from pathlib import Path

import subprocess

from rich.console import Console
console = Console()

min_version("6.0")

SNAKE_DIR = Path(workflow.basedir)
SCRIPT_DIR = SNAKE_DIR/ "scripts"
ENV_DIR = SNAKE_DIR / "envs"
CONFIG_DIR = SNAKE_DIR.parent / "config"
LOG_DIR = Path("logs")
WARNINGS_DIR = LOG_DIR/"warnings"
WARNINGS_DIR.mkdir(parents=True, exist_ok=True)


# Defaults for config
ORTHOLOG_MIN_SEQS_DEFAULT = 5
ORTHOLOG_MIN_TAXA_DEFAULT = 5
MINIMUM_TRIMMED_ALIGNMENT_LENGTH_CDS_DEFAULT = 501
MINIMUM_TRIMMED_ALIGNMENT_LENGTH_PROTEINS_DEFAULT = 167
MAX_TRIMMED_PROPORTION_DEFAULT = 0.5
INFER_TREE_WITH_PROTEIN_SEQS_DEFAULT = True
INFER_TREE_WITH_CDS_SEQS_DEFAULT = False
USE_ORTHOFISHER_DEFAULT = False
BOOTSTRAP_STRING_DEFAULT = "-bb 1000"
MODEL_STRING_DEFAULT = "-m TEST"
SUPERMATRIX_OUTGROUP_DEFAULT = ""
TRANSLATION_TABLE_DEFAULT = "1"
IGNORE_NON_VALID_FILES_DEFAULT = False
PROTEIN_INPUT_DEFAULT = False

# Get the path to the default config file.
# This can be overridden on the command line
configfile: CONFIG_DIR / "config.yml"

report: "report/workflow.rst"


##### load rules #####
include: "rules/intake.smk"
include: "rules/orthofisher.smk"
include: "rules/orthofinder.smk"
include: "rules/alignment.smk"
include: "rules/supermatrix.smk"
include: "rules/gene_tree.smk"
include: "rules/supertree.smk"
include: "rules/summary.smk"
include: "rules/report.smk"


onstart:
    # Print some environment info
    print("Workflow directories:")
    path_vars = (
        'SNAKE_DIR',
        'SCRIPT_DIR',
        'CONFIG_DIR',
        'LOG_DIR',
    )
    for name in path_vars:
        print(f"\t{name:20s} ➡  {str(globals()[name])}")

    print("Environment:")
    shell = lambda cmd: subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode().rstrip()
    print(f"\t{shell('python --version'):20s} ➡  {shell('which python')}")
    print(f"\t{shell('conda --version'):20s} ➡  {shell('which conda')}")
    print(f"\t{' '.join(('snakemake', shell('snakemake --version'))):20s} ➡  {shell('which snakemake')}")

    # Ensure all required directories exist
    for name in path_vars:
        globals()[name].mkdir(exist_ok=True)
        assert globals()[name].exists()

def get_final_reports(wildcards):
    reports = []
    if config.get("infer_tree_with_protein_seqs", INFER_TREE_WITH_PROTEIN_SEQS_DEFAULT):
        reports.append(rules.report.output[0].format(alignment_type="protein"))
    
    if config.get("infer_tree_with_cds_seqs", INFER_TREE_WITH_CDS_SEQS_DEFAULT):
        reports.append(rules.report.output[0].format(alignment_type="cds"))

    if not config.get("infer_tree_with_protein_seqs", INFER_TREE_WITH_PROTEIN_SEQS_DEFAULT) and not config.get("infer_tree_with_cds_seqs", INFER_TREE_WITH_CDS_SEQS_DEFAULT):
        reports.append(rules.report.output[0].format(alignment_type="protein"))
    
    return reports


rule all:
    input: get_final_reports
    shell:
        """
        rm -rf results/orthofinder/tmp/
        rm -rf results/orthofinder/output/WorkingDirectory/
        rm -rf results/orthofisher/output/hmmsearch_output/
        rm -rf results/orthofisher/output/scog/
        """

def print_warning(text, warning_style=None):
    if not warning_style:
        warning_style = "bold white on red"  
    
    longest = max(text.split("\n"), key=len)

    console.print("-"*len(longest), style=warning_style)
    console.print(text, style=warning_style)
    console.print("-"*len(longest), style=warning_style)

onsuccess:
    warning_message = ""
    count = 0
    for warnings_file in WARNINGS_DIR.glob("*.txt"):
        temp = warnings_file.read_text()
        warning = temp.split("\n")[0]
        if warning:
            count += 1
            if count > 1:
                warning_message += "\n"
            warning_message += "WARNING " + str(count) + " --> " + str(warning)
    if len(warning_message) > 0:
        print_warning("There are warnings. Please check the warnings tab in the report or the warnings folder in the log.")
        print_warning(warning_message, warning_style="bold white")
    else:
        print_warning("Program ended without warnings", warning_style="bold white")

