import sys
import numpy as np
import pandas as pd
from os import path

# ======================
# Configuration Handling
# ======================

# Define paths to key files and directories
simulation_data = path.join('config', 'simulations.csv')  # Table of simulation conditions
configfile: path.join('config', 'params.yml')       # Configuration file for Snakemake

# Retrieve essential parameters from the config
number_frames = config.get("number_frames", 100)  # Default to 100 if not set
replicates = config.get("replicates", 1)  # Default to 1 if not set

# Determine the analysis mode from the target rule named on the command line:
#   a) 'PPi'      - protein-protein interaction (fallback when no mode is named)
#   b) 'protein'  - single protein
#   c) 'molecule' - small molecule
# The mode is the first command-line argument that names one of these target
# rules. Picking "the second non-flag argument" by position selected the wrong
# token for invocations such as `-c8 molecule` (-> 'PPi') or
# `molecule --cores 8` (-> '8'), and its bare `except:` hid the failure.
# NOTE(review): this relies on the target being present in sys.argv of the
# process that parses this file; confirm that holds for cluster/remote jobs.
MODES = ('PPi', 'protein', 'molecule')
mode = next((arg for arg in sys.argv[1:] if arg in MODES), 'PPi')


# ==================================
# Random Seed Generation for Replicates
# ==================================
# One MD seed per replicate, drawn from NumPy's legacy global RNG after seeding
# it with config["seed"] (default 23). The same config therefore always yields
# the same seeds, and so the same '{complex}/{mutation}/{seed}/...' output paths.
# Keep the call order below unchanged, or the seed values (and paths) change.
seed = config.get("seed", 23)
np.random.seed(seed)  # Ensure reproducibility
# Integers in [100, 1000), drawn independently (with replacement).
# NOTE(review): duplicate seeds are possible when replicates > 1, which would
# map two replicates onto the same output directory.
seeds = np.random.randint(100, 1000, size=replicates)
representative_seed = seeds[0] # First seed is the representative replicate for surface interaction analysis

# ==================================
# Simulation Data Preprocessing
# ==================================
# Read the simulation conditions. In 'molecule' mode every condition is paired
# with every ligand listed in config/ligands.csv.

simulations_df = pd.read_csv(simulation_data)

if mode == 'molecule':
    ligs = pd.read_csv(path.join('config', 'ligands.csv'))
    # Cartesian product: each ligand is combined with each simulation condition
    simulations_df = ligs.merge(simulations_df, how="cross")
    simulations_df['sdf'] = 'ligands/' + simulations_df['ligand'] + ".sdf"

# Row key '<target>_<ligand>_<mutation_all>'; the rules look rows up with
# simulations_df.loc[f'{complex}_{mutation}'].
simulations_df['complex'] = simulations_df['target'] + '_' + simulations_df['ligand']
simulations_df['name'] = simulations_df['complex'] + '_' + simulations_df['mutation_all']
simulations_df = simulations_df.set_index('name')

print(simulations_df)

# Distinct wildcard values fed to the expand() calls of the target rules
complexes = simulations_df['complex'].unique()
mutations = simulations_df['mutation_all'].unique()

rule molecule:
    # Target rule for small-molecule mode ('molecule').
    # NOTE(review): as the first rule in the file this is Snakemake's default
    # target when none is named, yet `mode` then falls back to 'PPi' and the
    # ligand table is not merged in.
    # NOTE(review): expand() takes the full complexes x mutations x seeds product,
    # so every (complex, mutation) combination must be a row of simulations_df.
    input:
        'results/metaReport.html',
        #'results/fingerprints/interactions.parquet',
        #expand('results/interactionSurface/{complex}.{mutation}.interaction.pdb', complex=complexes, mutation=mutations),
        #expand('{complex}/{mutation}/{seed}/fingerprint/fingerprint.parquet',complex=complexes,mutation=mutations,seed=seeds),
        #expand('{complex}/{mutation}/{seed}/analysis/RMSF.html', complex=complexes, mutation=mutations, seed=seeds),
        expand('{complex}/{mutation}/{seed}/frames/lig_{i}.pdb', i=range(number_frames), complex=complexes, mutation=mutations, seed=seeds),
        expand('{complex}/{mutation}/{seed}/frames/rec_{i}.pdb', i=range(number_frames), complex=complexes, mutation=mutations, seed=seeds),
        expand('{complex}/{mutation}/{seed}/po-sco/{i}.txt', i=range(number_frames), complex=complexes, mutation=mutations, seed=seeds),
        'results/posco/posco_interactions.parquet',
        # No rule in this Snakefile produces 'results/posco/posco.html', so that
        # request made this target fail with a MissingInputException. Request the
        # PoSCo plots that PoScoHeatMap and PoScoBarplot build (as PPi does).
        'results/posco/lig_heatmap.svg',
        'results/posco/lig_barplot.svg',
        expand('{complex}/{mutation}/{seed}/aquaduct/aquaduct.pse',complex=complexes, mutation=mutations, seed=seeds)


rule PPi:
    # Target rule for protein-protein interaction mode. It requests:
    # - the ProLIF interaction table and the meta report;
    # - interaction-surface PDBs;
    # - per-replicate fingerprints and RMSF reports;
    # - per-frame ligand/receptor PDBs and po-sco outputs;
    # - the PoSCo table and ligand-side plots;
    # - the RMSF summary figure.
    # NOTE(review): expand() takes the full complexes x mutations x seeds product,
    # so every (complex, mutation) combination must be a row of simulations_df.
    input:
        'results/fingerprints/interactions.parquet',
        'results/metaReport.html',
        expand('results/interactionSurface/{complex}.{mutation}.interaction.pdb', complex=complexes, mutation=mutations),
        expand('{complex}/{mutation}/{seed}/fingerprint/fingerprint.parquet',complex=complexes,mutation=mutations,seed=seeds),
        expand('{complex}/{mutation}/{seed}/analysis/RMSF.html', complex=complexes, mutation=mutations, seed=seeds),
        expand('{complex}/{mutation}/{seed}/frames/lig_{i}.pdb', i=range(number_frames), complex=complexes, mutation=mutations, seed=seeds),
        expand('{complex}/{mutation}/{seed}/frames/rec_{i}.pdb', i=range(number_frames), complex=complexes, mutation=mutations, seed=seeds),
        expand('{complex}/{mutation}/{seed}/po-sco/{i}.txt', i=range(number_frames), complex=complexes, mutation=mutations, seed=seeds),
        'results/posco/posco_interactions.parquet',
        'results/posco/lig_heatmap.svg',
        'results/posco/lig_barplot.svg',
        'results/rmsf/rmsf.svg'
        #expand('{complex}/{mutation}/{seed}/aquaduct/aquaduct.pse',complex=complexes, mutation=mutations, seed=seeds)

rule protein:
    # Target rule for single-protein mode: centred trajectories, RMSF reports
    # and aquaduct PyMOL sessions for every complex x mutation x seed.
    input:
        expand('{complex}/{mutation}/{seed}/MD/traj_center.dcd', complex=complexes, mutation=mutations, seed=seeds),
        expand('{complex}/{mutation}/{seed}/analysis/RMSF.html', complex=complexes, mutation=mutations, seed=seeds),
        expand('{complex}/{mutation}/{seed}/aquaduct/aquaduct.pse',complex=complexes, mutation=mutations, seed=seeds)


rule metaReport:
    # Build the run-level HTML meta report from params.yml and simulations.csv
    # with metaReport.py. The result is registered in the Snakemake report
    # under category "Meta".
    # NOTE(review): the caption reuses RMSF.rst; confirm a dedicated caption
    # file was not intended.
    input:
        params='config/params.yml',
        sims='config/simulations.csv'
    output:
        report('results/metaReport.html',caption="RMSF.rst",category="Meta"),
    shell:
        """
        metaReport.py  --param {input.params}  \
                       --sims {input.sims}  \
                       --output {output}  \
        """

rule Mutagensis:
    # Create '{complex}/{mutation}/mutation.pdb' for one condition:
    # - for the wild type ('WT') it is a copy of the input structure;
    # - otherwise 1_mutation.py writes a FoldX mutant file, FoldX BuildModel
    #   runs on a WT.pdb copy, and the resulting WT_1.pdb is moved into place.
    # `|| true` discards FoldX's exit status, so a FoldX failure surfaces as the
    # final `mv` failing; the FoldX stdout is in the rule's log file.
    # NOTE(review): the shell text is kept byte-identical on purpose; editing it
    # may trigger Snakemake's code-change rerun of every mutagenesis job and,
    # via mutation.pdb, of the downstream MD jobs.
    input:
        # Not referenced by the shell command; as inputs, edits to either file
        # make this rule's outputs out of date.
        'config/params.yml',
        'config/simulations.csv'
    output:
        '{complex}/{mutation}/mutation.pdb'
    params:
        out_dir=directory('{complex}/{mutation}'),
        # Structure from the 'input' column of simulations.csv for this condition
        pdb= lambda wildcards: simulations_df.loc[f'{wildcards.complex}_{wildcards.mutation}']['input'],
        # The versioned FoldX executable name needs updating every year; set the
        # `foldx_executable` config key instead of editing this file.
        foldX=config.get('foldx_executable', 'foldx_20251231')
    log:
        '{complex}/{mutation}/mutation.log'
    shell:
        """
        if [[ {wildcards.mutation} = WT ]]
        then    # No mutagenisis
            cp {params.pdb} {output}
        else    # Mutate
            1_mutation.py --mutation {wildcards.mutation} --ligand {params.pdb} --output {wildcards.complex}/{wildcards.mutation}/mutant_file.txt
            cp {params.pdb} {params.out_dir}/WT.pdb
            {params.foldX} --command=BuildModel --pdb-dir="{params.out_dir}" --pdb=WT.pdb --mutant-file={wildcards.complex}/{wildcards.mutation}/mutant_file.txt --output-dir="{params.out_dir}" --rotabaseLocation ~/tools/foldx/rotabase.txt > {log} || true
            mv {params.out_dir}/WT_1.pdb {output}
        fi
        """


def get_sdf_input(wildcards):
    """Return the ligand SDF input for the MD rule of one condition.

    In 'molecule' mode the path is read from the ``sdf`` column of
    ``simulations_df`` at row ``f"{complex}_{mutation}"``. In every other mode
    no SDF is required, so an empty list is returned. Snakemake input functions
    must return a str or a list of str; the previous ``None`` does not qualify.

    Args:
        wildcards: Snakemake wildcards object exposing ``complex`` and ``mutation``.

    Returns:
        The SDF path (str) in 'molecule' mode, otherwise ``[]``.
    """
    if mode != "molecule":
        return []  # Don't require an sdf input
    return simulations_df.loc[f'{wildcards.complex}_{wildcards.mutation}', 'sdf']


rule MD:
    # Run the MD simulation of one (complex, mutation, seed) with 2_MD.py on one
    # GPU. Writes the final frame, a temporary HDF5 trajectory (removed once no
    # pending job needs it), run statistics, metadynamics output and the
    # effective parameters.
    input:
        md_settings=ancient('config/params.yml'),
        pdb='{complex}/{mutation}/mutation.pdb',
        # Must be an input function: `wildcards` only exists when Snakemake
        # evaluates inputs for a concrete job, so referencing it directly here
        # raised a NameError at parse time in 'molecule' mode. Outside that mode
        # the input is an empty list (no SDF required).
        # NOTE(review): input.sdf is not passed to 2_MD.py below; confirm how
        # the script obtains the ligand.
        sdf=lambda wildcards: get_sdf_input(wildcards) if mode == "molecule" else [],
    output:
        topo = '{complex}/{mutation}/{seed}/MD/frame_end.cif',
        traj=temp('{complex}/{mutation}/{seed}/MD/trajectory.h5'),
        stats='{complex}/{mutation}/{seed}/MD/MDStats.csv',
        metadynamics='{complex}/{mutation}/{seed}/MD/metadynamics/metadynamics.txt',
        params='{complex}/{mutation}/{seed}/MD/params.yml'
    resources:
        gpu=1
    priority:
        3
    shell:
        """
        2_MD.py --pdb {input.pdb} \
                --topo {output.topo} \
                --traj {output.traj} \
                --md_settings {input.md_settings} \
                --params {output.params} \
                --seed {wildcards.seed} \
                --stats {output.stats} \
                --metadynamics {output.metadynamics}
        """

rule centerMDTraj:
    # Run 4_centerMDTraj.py on one replicate to write traj_center.dcd and
    # topo_center.pdb (the centred trajectory and matching topology, per the
    # names). It consumes MD's trajectory.h5, which the MD rule marks temp().
    input:
        topo = '{complex}/{mutation}/{seed}/MD/frame_end.cif',
        traj='{complex}/{mutation}/{seed}/MD/trajectory.h5',
    output:
        topo_center = '{complex}/{mutation}/{seed}/MD/topo_center.pdb',
        traj_center='{complex}/{mutation}/{seed}/MD/traj_center.dcd',
    priority:
        3
    shell:
        """
        4_centerMDTraj.py   --topo {input.topo} \
                            --traj {input.traj} \
                            --traj_center {output.traj_center} \
                            --topo_center {output.topo_center}
        """


rule aquaduct:
    # Run valve_run (in the 'aquaduct' conda env) for one replicate. sed fills
    # the config template: PREFIXINPUT becomes the replicate directory
    # ({params}) and PREFIXOUT becomes the aquaduct/ output directory.
    # NOTE(review): topo_center/traj_center are not referenced in the shell;
    # presumably the template locates them under PREFIXINPUT. Confirm.
    # NOTE(review): 6_visualize_results.py is declared as an output but is not
    # written by the shell text itself; presumably valve_run creates it.
    # NOTE(review): aquaductPymol writes aquaduct.pse inside this rule's
    # directory() output. Snakemake may reject outputs nested in another job's
    # directory output; check with the Snakemake version in use.
    input:
        topo_center = '{complex}/{mutation}/{seed}/MD/topo_center.pdb',
        traj_center='{complex}/{mutation}/{seed}/MD/traj_center.dcd',
        aquaduct_template='config/TEMPLATE_aquaduct.txt',
    output:
        aquaduct='{complex}/{mutation}/{seed}/aquaduct/aquaduct.txt',
        aquaduct_folder=directory('{complex}/{mutation}/{seed}/aquaduct/'),
        pymol_cmd='{complex}/{mutation}/{seed}/aquaduct/6_visualize_results.py',
    priority:
        10
    params:
        directory('{complex}/{mutation}/{seed}')
    conda:
        "aquaduct"
    shell:
        """
        sed  's|PREFIXINPUT|{params}|g' {input.aquaduct_template} > {output.aquaduct} &&
        sed  -i 's|PREFIXOUT|{output.aquaduct_folder}|g' {output.aquaduct} &&
        valve_run -c {output.aquaduct}
        """

rule aquaductPymol:
    # Execute the generated 6_visualize_results.py with python3 and save the
    # resulting PyMOL session as aquaduct.pse.
    # NOTE(review): `params` is not used by the shell command.
    input:
        '{complex}/{mutation}/{seed}/aquaduct/6_visualize_results.py',
    output:
        '{complex}/{mutation}/{seed}/aquaduct/aquaduct.pse',
    priority:
        10
    params:
        directory('{complex}/{mutation}/{seed}')
    shell:
        """
        python3 {input} --save-session {output}
        """

rule DescriptiveTrajAnalysis:
    # Per-replicate descriptive analysis with
    # 3.1_ExplorativeTrajectoryAnalysis.py. It produces:
    # - RMSF (HTML) and RMSD (SVG) report items;
    # - the raw RMSD table (CSV);
    # - a B-factor PDB;
    # - a figure of the MD statistics read from MDStats.csv.
    # NOTE(review): topo is the un-centred frame_end.cif while traj is the
    # centred DCD; confirm their atom ordering matches.
    # NOTE(review): params.number_frames is not used by the shell command, and
    # config.get() here has no default (None when the key is unset).
    input:
        topo='{complex}/{mutation}/{seed}/MD/frame_end.cif',
        traj='{complex}/{mutation}/{seed}/MD/traj_center.dcd',
        stats='{complex}/{mutation}/{seed}/MD/MDStats.csv'
    output:
        rmsf= report('{complex}/{mutation}/{seed}/analysis/RMSF.html',caption="RMSF.rst",category="RMSF",labels={"Complex": "{complex}", "Mutation": "{mutation}"}),
        rmsd= report('{complex}/{mutation}/{seed}/analysis/RMSD.svg',caption="RMSF.rst",category="RMSD",labels={"Complex": "{complex}", "Mutation": "{mutation}"}),
        rmsd_raw= '{complex}/{mutation}/{seed}/analysis/RMSD.csv',
        bfactors='{complex}/{mutation}/{seed}/analysis/bfactors.pdb',
        stats='{complex}/{mutation}/{seed}/analysis/Stats.svg',
    params:
        number_frames = config.get('number_frames')
    shell:
        """
        3.1_ExplorativeTrajectoryAnalysis.py --topo {input.topo} \
                                                   --traj {input.traj} \
                                                   --stats {input.stats} \
                                                   --rmsf {output.rmsf} \
                                                   --bfactors {output.bfactors} \
                                                   --rmsd {output.rmsd} \
                                                   --rmsd_raw {output.rmsd_raw} \
                                                   --fig_stats {output.stats}
        """

rule calculateRMSF:
    # Compute RMSF for every complex x mutation x seed trajectory with
    # 3.2_Compute_RMSF.py and store it in one parquet table. Topologies and
    # trajectories are passed as two space-separated lists built by expand()
    # in the same order.
    input:
        topo = expand('{complex}/{mutation}/{seed}/MD/frame_end.cif',complex=complexes, mutation=mutations, seed=seeds),
        traj = expand('{complex}/{mutation}/{seed}/MD/traj_center.dcd',  complex=complexes, mutation=mutations, seed=seeds), 
    output:
        'results/rmsf/rmsf.parquet'
    shell:
        """
        3.2_Compute_RMSF.py --topo {input.topo} \
                            --traj {input.traj} \
                            --output {output}
        """

rule visualiseRMSF:
    # Plot the combined RMSF table (results/rmsf/rmsf.parquet) to an SVG
    # with 3.3_Visualize_RMSF.py.
    input:
        'results/rmsf/rmsf.parquet'
    output:
        'results/rmsf/rmsf.svg'
    shell:
        """
        3.3_Visualize_RMSF.py --input {input} \
                              --output {output}
        """


rule ExtractithFrameFromTrajectory:
    # Write frame {i} of one replicate's centred trajectory as separate ligand
    # and receptor PDB files with 5.1_Posco_ExtractLastFrames.py. These files
    # are the per-frame inputs of the posco rule.
    # NOTE(review): params.sequence is a path handed to the script but not a
    # declared output, so Snakemake neither tracks nor cleans it. If the script
    # writes it, all frame jobs of one replicate target the same file.
    input:
        topo='{complex}/{mutation}/{seed}/MD/frame_end.cif',
        traj='{complex}/{mutation}/{seed}/MD/traj_center.dcd'
    output:
        lig_frame='{complex}/{mutation}/{seed}/frames/lig_{i}.pdb',
        rec_frame='{complex}/{mutation}/{seed}/frames/rec_{i}.pdb'
    params:
        sequence='{complex}/{mutation}/{seed}/frames/sequence.parquet'
    threads: 1
    shell:
        """
        5.1_Posco_ExtractLastFrames.py  --topo {input.topo} \
                                        --traj {input.traj} \
                                        --frame {wildcards.i} \
                                        --lig_frame {output.lig_frame} \
                                        --rec_frame {output.rec_frame} \
                                        --sequence {params.sequence}
        """


# PosCo
rule posco:
    # Run po-sco (with -b) on one frame's receptor/ligand PDB pair and capture
    # its stdout as {i}.txt.
    input:
        lig_frames='{complex}/{mutation}/{seed}/frames/lig_{i}.pdb',
        rec_frames='{complex}/{mutation}/{seed}/frames/rec_{i}.pdb',
    output:
        '{complex}/{mutation}/{seed}/po-sco/{i}.txt'
    threads: 1
    shell:
        """
        po-sco {input.rec_frames} {input.lig_frames} -b  > {output}
        """

# PosCo
rule concatPosco:
    # Combine every per-frame po-sco text output (all frames, complexes,
    # mutations and seeds) into one parquet table with 5.2_Posco_TransformDF.py.
    input:
        posco = expand('{complex}/{mutation}/{seed}/po-sco/{i}.txt', i=range(number_frames), complex=complexes, mutation=mutations, seed=seeds),
    output:
        'results/posco/posco_interactions.parquet'
    shell:
        """5.2_Posco_TransformDF.py --input {input.posco} --output {output}"""


# PosCo
rule PoScoHeatMap:
    # Draw ligand-side and receptor-side interaction heatmaps (SVG) from the
    # combined PoSCo parquet table with 5.4_Posco_Heatmap.py.
    input:
        data='results/posco/posco_interactions.parquet'
    output:
        lig_heatmap='results/posco/lig_heatmap.svg',
        rec_heatmap='results/posco/rec_heatmap.svg',
    shell:
        """
        5.4_Posco_Heatmap.py --input {input.data} --ligand_interaction {output.lig_heatmap} --receptor_interaction {output.rec_heatmap}
        """

# PosCo
# TODO: does not work yet when different complexes carry different mutations!!
rule PoScoBarplot:
    # Draw ligand-side and receptor-side interaction bar plots (SVG) from the
    # combined PoSCo parquet table with 5.5_Posco_Barplot.py.
    input:
        'results/posco/posco_interactions.parquet',  
    output:
        lig_barplot='results/posco/lig_barplot.svg',
        rec_barplot='results/posco/rec_barplot.svg',
    shell:
        """
        5.5_Posco_Barplot.py --input {input} \
                             --ligand_interaction {output.lig_barplot} \
                             --receptor_interaction {output.rec_barplot}
        """

rule interactionFingerprint:
    # Compute one replicate's ProLIF interaction fingerprint (parquet report
    # item) with 7_interactionFingerprint.py, using {threads} threads.
    input:
        topo='{complex}/{mutation}/{seed}/MD/frame_end.cif',
        traj='{complex}/{mutation}/{seed}/MD/traj_center.dcd',
    output:
        report('{complex}/{mutation}/{seed}/fingerprint/fingerprint.parquet', labels=({"Name": "Interaction Analysis Prolif", "Type": "List"}),caption="RMSF.rst",category="Interaction Fingerprint")
    params:
        # Module-level number_frames (defaults to 100). The previous
        # config.get('number_frames') had no default and passed
        # "--n_frames None" when the key was missing from params.yml.
        number_frames = number_frames,
    threads: 4
    shell:
        """
        7_interactionFingerprint.py --topo {input.topo} \
                                    --traj  {input.traj} \
                                    --output {output} \
                                    --n_frames {params.number_frames} \
                                    --threads {threads} \
                                    --complex {wildcards.complex} \
                                    --mutation {wildcards.mutation} \
                                    --seed {wildcards.seed}
        """

rule GlobalFingerprintAnalysis:
    # Merge all per-replicate fingerprints into one interaction table and an
    # HTML figure, both registered as report items, using
    # 8_GlobalFinterprintAnalysis.py.
    input:
        fingerprints=expand('{complex}/{mutation}/{seed}/fingerprint/fingerprint.parquet', complex=complexes, mutation=mutations, seed=seeds),
    output:
        interactions = report('results/fingerprints/interactions.parquet',labels=({"Name": "Interaction Analysis Prolif", "Type": "List"}),caption="RMSF.rst",category="Interaction Fingerprint"),
        figure = report('results/fingerprints/fingerprints.html',labels=({"Name": "Interaction Analysis Prolif", "Type": "Plot"}),caption="RMSF.rst",category="Interaction Fingerprint"),
    params:
        # Module-level number_frames (defaults to 100), so the same frame count
        # is used as for the per-replicate fingerprints even when params.yml
        # omits the key (config.get() without a default yielded None).
        n_frames = number_frames
    shell:
        """
        8_GlobalFinterprintAnalysis.py  --fingerprints {input.fingerprints} \
                                        --interactions {output.interactions} \
                                        --n_frames {params.n_frames} \
                                        --figure {output.figure}
        """

rule InteractionSurface:
    # Map PoSCo interactions onto one complex/mutation for the representative
    # replicate. 10_InteractionSurface.py runs first, then `pymol -cQ` executes
    # the declared .pml output to produce the PyMOL session and image report
    # items.
    # NOTE(review): final_frame and the output paths are not passed to the
    # script explicitly; confirm it derives them from --complex/--mutation/--seed.
    input:
        final_frame = f'{{complex}}/{{mutation}}/{representative_seed}/MD/topo_center.pdb',
        interactions= 'results/posco/posco_interactions.parquet',
    output:
        bfactor_pdbs = 'results/interactionSurface/{complex}.{mutation}.interaction.pdb',
        pymol_cmd = 'results/interactionSurface/{complex}.{mutation}.pml',
        pymol = report('results/interactionSurface/{complex}.{mutation}.final.pse',caption="RMSF.rst",category="PyMol",labels=({"Complex": "{complex}", "Mutation": "{mutation}", "Type": "PyMol"})),
        surface= report('results/interactionSurface/{complex}.{mutation}.png',caption="RMSF.rst",category="PyMol", labels=({"Complex": "{complex}", "Mutation": "{mutation}", "Type": "Image"}))
    params:
        # Reuse the module-level representative seed (the one in the
        # final_frame path) rather than re-deriving seeds[0] here.
        representative_seed = representative_seed,
    shell:
        """
        10_InteractionSurface.py --interactions {input.interactions} \
                                 --seed {params.representative_seed} \
                                 --mutation {wildcards.mutation} \
                                 --complex {wildcards.complex}
        pymol -cQ {output.pymol_cmd}
        """

onsuccess:
    # Choose the report flavour from the target rule named on the command line
    # ('protein' takes precedence over 'molecule'; anything else means PPi),
    # announce it, and build the report with `squeeze`.
    if 'protein' in sys.argv:
        # Only a single protein has been analyzed
        report_target, description = 'protein', 'single protein'
    elif 'molecule' in sys.argv:
        report_target, description = 'molecule', 'small molecule'
    else:
        report_target, description = 'PPi', 'protein-protein Interaction'
    print(f"Workflow finished, no error. Report for {description} will be generated")
    shell(f"squeeze {report_target} --report report.html")
