import re
import sys
import os
from os.path import join
import shutil


def message(mes):
    """Print an informational line to stderr, prefixed with ``|---- ``.

    ``mes`` must be a string; it is concatenated to the prefix as-is.
    """
    print("|---- " + mes, file=sys.stderr)

def errormes(mes):
    """Write an error line to stderr, prefixed with ``| ERROR ----``.

    ``mes`` must be a string; it follows the prefix directly (no separator)
    and is terminated by a newline.
    """
    line = "".join(("| ERROR ----", mes, "\n"))
    sys.stderr.write(line)

# Pipeline configuration: every config[...] lookup below is served by this file.
configfile : "config.yaml"

# Input/output locations and helper-script directory.
datadir = config["fastq"]
output_directory = config["output"]
scriptdir = config["Scripts"]

# Reference resources, named exactly as in config.yaml.
host_db = config["host_db"]
rRNA_bact = config["rRNA_bact"]
rRNA_host = config["rRNA_host"]
base_nr = config["base_nr"]
base_taxo = config["base_taxo"]
base_nt = config["base_nt"]
base_taxo_nt = config["base_taxo_nt"]

# Fastq naming scheme: `ext` is the common file suffix, `ext_R1` is the tag
# placed between the sample name and `ext` for first-mate files (see the
# glob patterns below). `ext_R2` is only read here.
ext = config["ext"]
ext_R1 = config["ext_R1"]
ext_R2 = config["ext_R2"]

# Thread counts read from the configuration (none is consumed in this part
# of the workflow).
threads_default = config["threads_default"]
threads_Map_On_host = config["threads_Map_On_host"]
threads_Map_On_bacteria = config["threads_Map_On_bacteria"]
threads_Megahit_Assembly = config["threads_Megahit_Assembly"]
threads_Map_On_Assembly = config["threads_Map_On_Assembly"]
threads_Blast_contigs_on_nr = config["threads_Blast_contigs_on_nr"]
threads_Blast_contigs_on_nt = config["threads_Blast_contigs_on_nt"]

# NOTE(review): this path ("logsclust/") differs from the directory created
# below ("cluster_log/") - confirm which one the cluster profile expects.
logclust = os.getcwd() + "/logsclust/"

message(str(logclust))

# Discover input files. `datadir` is concatenated directly with the pattern,
# so it presumably ends with a path separator - verify in config.yaml.
READS, = glob_wildcards(datadir + "{readfile}" + ext)
SAMPLES, = glob_wildcards(datadir + "{sample}" + ext_R1 + ext)


SNAKEMAKE_DIR = os.path.dirname(workflow.snakefile)
# Standard-library call: does not depend on the `snakemake` package name
# being bound in the Snakefile namespace, and is a no-op if the directory
# already exists.
os.makedirs("cluster_log/", exist_ok=True)


NBSAMPLES = len(SAMPLES)
NBREADS = len(READS)
RUN = config["run"]
message(str(NBSAMPLES)+" samples  will be analysed")
message(str(NBREADS)+" fastq files  will be processed")
message("Run name: "+RUN)

# Paired-end sanity check: exactly two fastq files are expected per sample.
if NBREADS != 2*NBSAMPLES:
    errormes("Please provide two reads file per sample")
    # Exit with a non-zero status so callers and schedulers see the failure.
    sys.exit(1)

# Aggregating target rule: it has no outputs of its own and only lists the
# files the pipeline must deliver. Being the first rule of the workflow, it
# is what Snakemake builds when no explicit target is requested.
rule final:
    input:
        # Per-sample read-processing statistics.
        expand(
            "{outdir}/logs/01_read_processing/Stats_contaminent_{smp}.txt",
            outdir=output_directory, smp=SAMPLES
        ),
        # Run-level assembly statistics.
        f"{output_directory}/logs/02_assembly_results/logsAssembly/{RUN}_assembly_stats.txt",
        # Per-sample results of mapping on the assembly.
        expand(
            "{outdir}/logs/03_mapping_on_assembly/{smp}_insert_size_metrics_{run}.txt",
            outdir=output_directory, run=RUN, smp=SAMPLES
        ),
        expand(
            "{outdir}/logs/03_mapping_on_assembly/03_CountsMapping/{smp}_coverage_{run}.txt",
            outdir=output_directory, run=RUN, smp=SAMPLES
        ),
        # Run-level contig count table.
        f"{output_directory}/04_blast_output/03_coverage/count_contigs_raw_{RUN}.csv",
        # Per-sample raw coverage.
        expand(
            "{outdir}/05_depth/00_Coverage_raw/{smp}_coverage_raw_{run}.txt",
            outdir=output_directory, run=RUN, smp=SAMPLES
        ),
        # Run-level taxonomy check and final deliverables.
        f"{output_directory}/04_blast_output/05_taxonomy_nt/intergreted_vir_check_{RUN}.csv",
        f"{output_directory}/06_final_data/hosts_lineage_{RUN}.csv",
        f"{output_directory}/06_final_data/stats_run_{RUN}.csv",
        f"{output_directory}/06_final_data/{RUN}_range_10x_coverage.cov",
        f"{output_directory}/06_final_data/clustering_Report_{RUN}.html"

############################### Read processing #################################################
include: "snakefiles/read_processing.snake"

############################### Assembly ########################################################
include: "snakefiles/assembly.snake"

############################ Mapping on assembly ################################################
include: "snakefiles/map_on_assembly.snake"

############################ blast information & extract viral contig ###########################
include: "snakefiles/blast_taxid.snake"

############################ Calculate depth for count table ####################################
include: "snakefiles/depth.snake"


# Build the per-run statistics CSV (stats_run_<RUN>.csv) by running the
# create_results_doc_new.py helper script on the fastq directory, the read
# file suffixes, the run name and tables produced by upstream rules.
# The rules referenced through `rules.<name>` are not defined in the visible
# part of this file; presumably they come from the included snakefiles.
rule Create_logs_report:
    input:
        # `lin` is not passed to the script; listing it here only makes this
        # rule depend on get_nt_lineage_from_taxids.
        lin = rules.get_nt_lineage_from_taxids.output.lin,
        lineage = rules.complete_taxo.output.lineage,
        by_seq = rules.Join_seq_acc_taxo_nr.output.taxo,
        count_t = rules.Build_array_coverage_nr.output.by_seq
    output:
        # The doubled braces keep `{RUN}` as a Snakemake wildcard instead of
        # interpolating the module-level RUN into the f-string.
        f"{output_directory}/06_final_data/stats_run_{{RUN}}.csv"
    params:
        files=datadir,
        output_directory = output_directory,
        # NOTE(review): `ext` is never referenced by the shell command below,
        # which uses the module-level ext_R1/ext_R2/ext instead.
        ext=ext_R1+ext,
        # Plain concatenation: assumes scriptdir ends with a path separator
        # - TODO confirm against config.yaml.
        script=scriptdir+"create_results_doc_new.py"
    envmodules: config['module_file']
    # NOTE(review): `{RUN}` in the command is not written as
    # `{wildcards.RUN}`, so it presumably resolves to the module-level RUN
    # rather than the wildcard matched by the output - confirm.
    shell:
        """
        python {params.script} {params.files} {ext_R1} {ext_R2} {ext} {RUN} {input.by_seq}   {input.lineage}    {input.count_t}  {output} {params.output_directory}
        """

# Filter the blast hit table down to the lines matching the viral ids and
# append a YES/NO column derived from the integrated-virus check table.
#
# Shell pipeline, step by step:
#   1. grep -Fw -f: keep the lines of `blast_5_hit` that contain, as a whole
#      word, any of the fixed strings listed in `virald_ids`.
#   2. first awk (fields split on tab, comma or space): reads `interg_virus`
#      first and stores a["<first field>"]="YES" for every line but its first
#      one; then reads the grep output from stdin ("-"), appends a[$1]
#      ("YES" or empty) as a new last field to every stdin line except the
#      first, and prints each stdin line comma-separated where a field was
#      added.
#      The `NR==1 ? "is_vir"` branch cannot be taken inside the FNR!=NR
#      block, and the `1;` before `print` is a no-op expression statement.
#      NOTE(review): if `interg_virus` has no lines at all, FNR==NR stays
#      true for stdin and nothing is printed - confirm this cannot happen.
#   3. second awk: any line whose last field is empty (or evaluates to 0)
#      gets that field replaced by "NO"; every line is printed.
#   4. sed -i: on each line, the first occurrence of "tax_id" becomes
#      "tax_id,not_vir" (presumably this only hits the header line).
#      NOTE(review): the header is named "not_vir" while the values are
#      YES for ids found in `interg_virus` - confirm the intended meaning.
#
# The doubled braces in the command are Snakemake escapes for literal awk
# braces. The command is a regular (non-raw) string, so `\t` is presumably
# already a tab character when awk receives it.
rule Add_vir_to_blastn:
    input:
        virald_ids = rules.extract_viral_ids.output.taxo,
        blast_5_hit = rules.get_nr_lineage_from_taxids.output.blst,
        interg_virus = rules.intergreted_vir_check.output.tmp,
    output:
        # `{{RUN}}` leaves `{RUN}` as a Snakemake wildcard in the path.
        blast_add_vir = f"{output_directory}/06_final_data/Blast_5_hit_{{RUN}}_add_vir.csv"
    envmodules: config['module_file']
    shell:
        """
        grep -Fw -f {input.virald_ids} {input.blast_5_hit}  | awk -F"[\t, ]" 'FNR==NR&&NR!=1{{a[$1]="YES"}} FNR!=NR{{OFS=",";if(FNR!=1){{$(NF+1)= (NR==1 ? "is_vir" : a[$1])}}1;print}}' {input.interg_virus} -  |awk -F',' '!$NF {{OFS=","; $NF="NO" }}1' > {output.blast_add_vir}
        sed -i "s/tax_id/tax_id,not_vir/" {output.blast_add_vir}
        """

# Render the clustering report (clustering_Report_<RUN>.html) by calling
# rmarkdown::render on clustering_Report.Rmd. All input tables, the output
# directory, the two read-count thresholds and the run name are handed to
# the R Markdown document through its `params` list.
rule Clustering_OTU:
    input:
        blast_file = rules.Add_vir_to_blastn.output.blast_add_vir,
        taxo_file = rules.get_nr_lineage_from_taxids.output.lin5c,
        host_file = rules.join_host_tax.output.host_lineage,
        count_file = rules.Build_array_coverage_nr.output.by_seq,
        cov10_file = rules.results_depth.output.range_10x,
        coverage_stat_file = rules.results_depth.output.stats_deph
    output:
        # Declared with directory() so Snakemake tracks the whole folder as
        # an output; its path is passed to the Rmd as `output_directory`.
        directory = directory(f"{output_directory}/06_final_data/{{RUN}}_OTU/"),
        # `{{RUN}}` leaves `{RUN}` as a Snakemake wildcard in both paths.
        report = f"{output_directory}/06_final_data/clustering_Report_{{RUN}}.html"
    params:
        # Joined with an explicit "/" between scriptdir and the file name.
        script= f'{scriptdir}/clustering_Report.Rmd',
        # Thresholds forwarded unchanged to the Rmd; their exact effect is
        # defined there, not in this file.
        min_read_by_sample = 1,
        min_read_by_contig = 1,
        # Module-level RUN, not the `RUN` wildcard matched by the outputs.
        run_name = RUN
    envmodules: config['module_file']
    shell:
        """
        Rscript -e 'rmarkdown::render("{params.script}", output_file="{output.report}", quiet=TRUE, params = list(output_directory = "{output.directory}", blast_file = "{input.blast_file}", taxo_file = "{input.taxo_file}", host_file = "{input.host_file}", count_file = "{input.count_file}", cov10_file = "{input.cov10_file}", coverage_stat_file = "{input.coverage_stat_file}", min_read_by_sample = {params.min_read_by_sample}, min_read_by_contig = {params.min_read_by_contig} ,run_name = "{params.run_name}"))'
        """
