#!/usr/bin/env python
# import datetime
import warnings
warnings.filterwarnings('ignore')

import datetime
import argparse
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # '2'

import numpy as np
import pandas as pd

from pathlib import Path

from learna_tools.liblearna.design_rna_inference_only import design_rna, get_distance_metric
from learna_tools.liblearna.data.parse_dot_brackets import parse_dot_brackets, parse_local_design_data

from learna_tools.liblearna.agent import NetworkConfig, get_network, AgentConfig, get_agent_fn
from learna_tools.liblearna.environment import RnaDesignEnvironment, RnaDesignEnvironmentConfig

from learna_tools.data import read_task_description

from learna_tools.visualization import plot_df_with_varna, display_varna_plot, plot_sequence_logo

# Command-line interface of the design script. Options are grouped by purpose;
# commented-out arguments are options that are currently disabled.
parser = argparse.ArgumentParser()

# Data
parser.add_argument(
    "--input_file", type=Path, help="Path to sequence to run on"
)
parser.add_argument("--data_dir", default="data", help="Data directory")
parser.add_argument("--dataset", type=Path, help="Available: eterna, rfam_taneda")
parser.add_argument(
    "--target_structure_ids",
    default=None,
    required=False,
    type=int,
    nargs="+",
    help="List of target structure ids to run on",
)

# Model
# NOTE(review): --restore_path and --stop_learning are parsed here, but the
# design_rna call in this file passes restore_path=None and stop_learning=False
# -- confirm whether these options are meant to be forwarded.
parser.add_argument("--restore_path", default='models/350_0_0', type=Path, help="From where to load model")
parser.add_argument("--stop_learning", action="store_true", help="Stop learning")
parser.add_argument("--agent", type=str, help="Select the agent, available choices:trpo, ppo, random")  # TRPO doesn't work, error in optimizers/solvers/conjugate_gradient.py IndexedSlice does not have attribute get_shape

# Timeout behaviour
parser.add_argument("--timeout", type=int, help="Maximum time to run")

# Hyperparameters
parser.add_argument("--learning_rate", type=float, default=0.000589774576537135, help="Learning rate to use")
parser.add_argument(
    "--mutation_threshold", type=int, default=5, help="Enable MUTATION with set threshold"
)
parser.add_argument(
    "--reward_exponent", default=10.760044595015092, type=float, help="Exponent for reward shaping"
)
parser.add_argument(
    "--state_radius", default=10, type=int, help="Radius around current site"
)
parser.add_argument(
    "--conv_sizes", type=int, default=[0, 0], nargs="+", help="Size of conv kernels"
)
parser.add_argument(
    "--conv_channels", type=int, default=[17, 24], nargs="+", help="Channel size of conv"
)
parser.add_argument(
    "--num_fc_layers", type=int, default=2, help="Number of FC layers to use"
)
parser.add_argument(
    "--fc_units", type=int, default=12, help="Number of units to use per FC layer"
)
parser.add_argument(
    "--batch_size", type=int, default=247, help="Batch size for ppo agent"
)
parser.add_argument(
    "--entropy_regularization", type=float, default=4.464128875341124e-07, help="The output entropy"
)
parser.add_argument(
    "--restart_timeout", type=int, default=None, help="Time after which to restart the agent"
)
parser.add_argument("--lstm_units", default=20, type=int, help="The number of lstm units")
parser.add_argument("--num_lstm_layers", default=0, type=int, help="The number of lstm layers")
parser.add_argument("--embedding_size", default=17, type=int, help="The size of the embedding")
# parser.add_argument("--gc_tolerance", default=0.04, type=float, help="The tolerance of the gc-content")
# parser.add_argument("--desired_gc", default=0.5, type=float, help="The desired gc-content of the solution")
# parser.add_argument("--gc_improvement_step", action="store_true", help="Control the gc-content of the solution")
# parser.add_argument("--gc_reward", action="store_true", help="Include gc-content into reward function")
# parser.add_argument("--gc_weight", default=1.0, type=float, help="The weighting factor for the gc-content constraint")
# parser.add_argument("--num_actions", default=4, type=int, help="The number of actions that the agent chooses from")
# parser.add_argument("--keep_sequence", default="fully", type=str, help="How much of the sequence of targets for local design is kept: fully, partially, no")
parser.add_argument("--reward_function", type=str, default='structure_only', help="Decide if hamming distance is computed based on the folding only or also on the sequence parts")
# parser.add_argument("--training_data", default="random", type=str, help="Choose the training data for local design: random sequences, motif based sequences")
# parser.add_argument("--local_design", action="store_true", help="Choose if agent should do RNA local Design")
# parser.add_argument("--predict_pairs", action="store_true", help="Choose if Actions are used to directly predict watson-crick base pairs")
parser.add_argument("--state_representation", type=str, default='n-gram', help="Choose between n-gram and sequence_progress to show the nucleotides already placed in the state")
parser.add_argument("--data_type", type=str, default='random-sort', help="Choose type of training data, random motifs or motifs with balanced brackets")
# parser.add_argument("--sequence_constraints", type=str, default='-', help="Perform local design with knowledge about sequence")
parser.add_argument("--encoding", type=str, default='tuple', help="Choose encoding of the structure and sequence constraints: 'tuple', 'sequence_bias', 'multihot'. Note: multihot only without embedding, requires conv (will be 2D)")
parser.add_argument(
    "--embedding_indices", type=int, default=21, help="Number of indices for the embedding"
)
# parser.add_argument("--variable_length", action="store_true", help="Output candidates of variable length")
parser.add_argument("--min_length", type=int, default=0, help="Minimum candidate length")
parser.add_argument("--max_length", type=int, default=0, help="Maximum candidate length")
parser.add_argument("--min_dot_extension", type=int, default=0, help="Minimum size of extension for dot representation O")
parser.add_argument(
    "--desired_gc", type=float, default=None, help="The desired GC content of the candidate solutions"
)
parser.add_argument(
    "--gc_tolerance", type=float, default=0.01, help="The tolerance for the GC content of the candidate solutions"
)
# parser.add_argument("--control_gc", action="store_true", help="Decide whether to control the GC content of the solution")
parser.add_argument("--algorithm", type=str, default='rnafold', help="Choose a folding algorithm for predictions. Available: rnafold, e2efold, linearfold")
# parser.add_argument("--rna_id", type=str, default='1', help="provide the rna id for writing files s.g. for contrafold")
parser.add_argument("--distance_metric", type=str, default='hamming', help="Choose a distance metric. Available: hamming, levenshtein, wl")
# parser.add_argument("--structure_only", action="store_true", help="Choose if state only considers structure parts of the target")

# Output, plotting and design modes
parser.add_argument("--num_solutions", type=int, default=1, help="Number of optimal solutions")
parser.add_argument("--plot_structure", action="store_true", help="Plot each structure with varna")
parser.add_argument("--show_plots", action="store_true", help="Show each plot generated")
parser.add_argument("--resolution", type=str, default='8.0', help="Resolution for structure plots")
parser.add_argument("--plot_logo", action="store_true", help="Plot sequence information as logo")
parser.add_argument("--show_logo", action="store_true", help="Show each sequence logo generated")
parser.add_argument("--working_dir", type=str, default='working_dir', help="the working_dir")
parser.add_argument("--cm_design", action="store_true", help="Activate CM design")
parser.add_argument("--cm_path", type=str, default='rfam_cms/Rfam.cm', help="path to the CM database")
parser.add_argument("--cm_name", type=str, default='RF00008', help="Specific cm to design for")
parser.add_argument("--seed", type=int, default=42, help="Seed for task sampling")
parser.add_argument("--rri_design", action="store_true", help="Activate RRI design")
parser.add_argument("--rri_target", type=str, default='UUUAAAUUAAAAAAUCAUAGAAAAAGUAUCGUUUGAUACUUGUGAUUAUACUCAGUUAUACAGUAUCUUAAGGUGUUAUUAAUAGUGGUGAGGAGAAUUUAUGAAGCUUUUCAAAAGCUUGCUUGUGGCACCUGCAACUCUUGGUCUUUUAGCACCAAUGACCGCUACUGCUAAU', help="Target for RRI design")
# The two thresholds are forwarded to design_rna as report_cm_threshold and
# report_rri_threshold respectively.
parser.add_argument("--cm_threshold", type=float, default=10.0, help="Threshold for reporting CM designs")
parser.add_argument("--rri_threshold", type=float, default=10.0, help="Threshold for reporting RRI designs")
parser.add_argument("--show_all_designs", action="store_true", help="Show all designs")
# parser.add_argument("--hamming_tolerance", type=int, default=0, help="Seed for task sampling")
parser.add_argument("--no_shared_agent", action="store_true", help="Do not share the agent")
parser.add_argument("--plotting_dir", type=str, default=None, help="select target directory for saving plots.")
parser.add_argument("--results_dir", type=str, default=None, help="select target directory for saving results. If not set, results are not saved.")
parser.add_argument("--output_format", type=str, default='pickle', help="select an output format for the results. Options: pickle, csv, fasta")


# Parse the command line; every configuration object below is built from it.
args = parser.parse_args()

# No command-line switch sets the hamming tolerance; it is pinned to 0 here.
args.hamming_tolerance = 0

# Flags derived from other options rather than given directly.
use_conv2d = args.encoding == "multihot"      # 2D convolutions only for multihot
gc_controlled = args.desired_gc is not None   # GC control needs a target value

# Architecture of the policy network.
network_config = NetworkConfig(
    # convolutional part
    conv_sizes=args.conv_sizes,  # radius * 2 + 1
    conv_channels=args.conv_channels,
    conv2d=use_conv2d,
    # recurrent part
    num_lstm_layers=args.num_lstm_layers,
    lstm_units=args.lstm_units,
    # fully connected part
    num_fc_layers=args.num_fc_layers,
    fc_units=args.fc_units,
    # embedding
    embedding_size=args.embedding_size,
    embedding_indices=args.embedding_indices,
)

# Agent type and its optimisation hyperparameters.
agent_config = AgentConfig(
    agent=args.agent,
    learning_rate=args.learning_rate,
    batch_size=args.batch_size,
    entropy_regularization=args.entropy_regularization,
)

# Environment settings. local_design, predict_pairs and sequence_constraints
# are fixed values in this script; everything else comes from the CLI.
env_config = RnaDesignEnvironmentConfig(
    # fixed settings
    local_design=True,
    predict_pairs=True,
    sequence_constraints='-',
    # reward and state
    mutation_threshold=args.mutation_threshold,
    reward_exponent=args.reward_exponent,
    reward_function=args.reward_function,
    state_radius=args.state_radius,
    state_representation=args.state_representation,
    encoding=args.encoding,
    data_type=args.data_type,
    # candidate length
    min_length=args.min_length,
    max_length=args.max_length,
    min_dot_extension=args.min_dot_extension,
    # GC content
    control_gc=gc_controlled,
    desired_gc=args.desired_gc,
    gc_tolerance=args.gc_tolerance,
    # folding
    algorithm=args.algorithm,
    # CM design
    cm_design=args.cm_design,
    cm_path=args.cm_path,
    cm_name=args.cm_name,
    # RRI design
    rri_design=args.rri_design,
    rri_target=args.rri_target,
    # misc
    working_dir=args.working_dir,
    seed=args.seed,
)

# Load the design tasks; an input file is mandatory.
if not args.input_file:
    raise UserWarning('Error: Please provide an input file.')
data = read_task_description(args.input_file)

# Represent missing values as None instead of NaN.
data = data.replace(np.nan, None)

# A GC target given on the command line is written to the 'globgc' column of
# every task. Compare against None rather than relying on truthiness so that
# an explicit target of 0.0 is honoured as well, consistent with how
# control_gc is derived for the environment config.
if args.desired_gc is not None:
    data.loc[:, 'globgc'] = args.desired_gc

# Make sure every column consumed below exists; absent ones are filled with None.
for col in ('Id', 'str', 'seq', 'globgc', 'globen'):
    if col not in data.columns:
        data.loc[:, col] = None

# One (id, structure, sequence, gc, energy) tuple per task row.
dot_brackets = list(zip(
    data['Id'],
    data['str'],
    data['seq'],
    data['globgc'],
    data['globen'],
))

ids = [task[0] for task in dot_brackets]

# Ids whose solutions have already been reported; filled while designing.
processed_ids = []

# Explicit check instead of `assert`, which is stripped when running with -O.
if len(ids) != len(set(ids)):
    raise ValueError('Ids must be unique')

def return_intermediate_solutions():
    """Yield ``(target_id, solutions)`` pairs from ``design_rna``, once per id.

    Every id that is yielded is appended to the module-level ``processed_ids``
    list; later results from ``design_rna`` for an id already in that list are
    skipped.

    NOTE(review): ``restore_path`` and ``stop_learning`` are passed as fixed
    values (``None`` / ``False``) rather than taken from the command line --
    confirm that this is intended.
    """
    design_kwargs = dict(
        timeout=args.timeout,
        restore_path=None,
        stop_learning=False,
        restart_timeout=args.restart_timeout,
        network_config=network_config,
        agent_config=agent_config,
        env_config=env_config,
        num_solutions=args.num_solutions,
        hamming_tolerance=args.hamming_tolerance,
        share_agent=not args.no_shared_agent,
        num_cores=1,  # args.num_cores,
        show_all_designs=args.show_all_designs,
        report_cm_threshold=args.cm_threshold,
        report_rri_threshold=args.rri_threshold,
    )
    for target_id, found in design_rna(dot_brackets, **design_kwargs):
        if target_id in processed_ids:
            # Already reported this target; ignore repeated results.
            continue
        processed_ids.append(target_id)
        yield target_id, found
# sols = design_rna(
#         dot_brackets,
#         timeout=args.timeout,
#         restore_path=None,
#         stop_learning=False,
#         restart_timeout=args.restart_timeout,
#         network_config=network_config,
#         agent_config=agent_config,
#         env_config=env_config,
#         num_solutions=args.num_solutions,
#         hamming_tolerance=args.hamming_tolerance,
#     )
# predictions = pd.DataFrame(preds)
# print(sols)
            
# Timestamp shared by the names of all files written during this run.
current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Report, plot and store the solutions of each target as soon as they arrive.
for i, sols in return_intermediate_solutions():
    solutions = pd.DataFrame(sols)
    if solutions.empty:
        # Printed in red via ANSI escape codes.
        print("\033[91m" + f'WARNING: No solutions found for target {i}' + "\033[0m")
        continue

    # With a hamming tolerance, order by distance; otherwise by time.
    sort_key = 'hamming_distance' if args.hamming_tolerance > 0 else 'time'
    solutions = solutions.sort_values(by=sort_key).reset_index(drop=True)

    print()
    print('Solutions for target structure', i)
    print()
    print(solutions.to_markdown())

    # The plot directory is only created once there is something to report.
    plotting_dir = None
    if args.plotting_dir is not None:
        plotting_dir = Path(args.plotting_dir)
        plotting_dir.mkdir(exist_ok=True, parents=True)

    # Base name for all files of this target; never modified below so that
    # plots and result files of one target share the same prefix.
    name = f'{i}_{current_time}'

    if args.plot_structure:
        if plotting_dir is not None:
            plot_df_with_varna(solutions, show=args.show_plots, name=name, resolution=args.resolution, plot_dir=plotting_dir)
        else:
            plot_df_with_varna(solutions, show=args.show_plots, resolution=args.resolution, name=name)

    if args.plot_logo:
        # --show_logo was parsed but never used; honour it in addition to
        # --show_plots, which keeps the previous behaviour intact.
        show_logo = args.show_plots or args.show_logo
        for n, group in solutions.groupby('length'):
            if plotting_dir is not None:
                # Build the per-length name from the base name each time.
                # Appending to `name` itself made the suffixes pile up across
                # length groups and leak into the results file name.
                logo_name = f'{name}_length_{n}'
                plot_sequence_logo(group, show=show_logo, plotting_dir=plotting_dir, name=logo_name)
            else:
                plot_sequence_logo(group, show=show_logo, name=name)

    if args.results_dir is not None:
        results_dir = Path(args.results_dir)
        results_dir.mkdir(exist_ok=True, parents=True)
        if args.output_format == 'pickle':
            solutions.to_pickle(results_dir / f'{name}.pkl')
        elif args.output_format == 'csv':
            solutions.to_csv(results_dir / f'{name}.csv')
        elif args.output_format == 'fasta':
            # One record per solution: header, sequence, then structure.
            with open(results_dir / f'{name}.fasta', 'w') as f:
                for j, row in solutions.iterrows():
                    f.write(f'>{j}\n{row["sequence"]}\n{row["structure"]}\n')
        else:
            raise ValueError(f'Unknown output format {args.output_format}')

    # Tell the user which targets are still being worked on.
    remaining = [str(task_id) for task_id in ids if task_id not in processed_ids]
    if remaining:
        print()
        print('Continue with predictions for ids:', ', '.join(remaining))
        print()

