#!/usr/bin/env python

import warnings
warnings.filterwarnings('ignore')

import datetime

import argparse
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # '2'

import pandas as pd

from learna_tools.learna.design_rna_inference_only import design_rna
from learna_tools.learna.data.parse_dot_brackets import parse_dot_brackets

from learna_tools.learna.agent import NetworkConfig, get_network, AgentConfig, get_agent_fn
from learna_tools.learna.environment import RnaDesignEnvironment, RnaDesignEnvironmentConfig

from learna_tools.data import read_task_description

from learna_tools.visualization import plot_df_with_varna, display_varna_plot, plot_sequence_logo



if __name__ == "__main__":
    # Command-line entry point. This first part only declares and parses the
    # options; everything below consumes the resulting `args` namespace.
    #
    # `argparse` and `parse_dot_brackets` are already imported at module level,
    # so only `Path` has to be imported here.
    from pathlib import Path

    parser = argparse.ArgumentParser()

    # Data
    parser.add_argument(
        "--target_structure", type=str, help="Structure in dot-bracket notation"
    )
    parser.add_argument(
        "--target_id", type=str, help="Id of the input target structure", required=False
    )
    parser.add_argument(
        "--input_file", type=Path, help="Path to sequence to run on"
    )

    parser.add_argument("--data_dir", default="data", help="Data directory")
    parser.add_argument("--dataset", type=Path, help="Available: eterna, rfam_taneda")
    parser.add_argument(
        "--target_structure_ids",
        default=None,
        required=False,
        type=int,
        nargs="+",
        help="List of target structure ids to run on",
    )

    # NOTE(review): the design_rna call further down passes restore_path=None,
    # so this option appears to be unused by this script -- confirm.
    parser.add_argument("--restore_path", type=Path, default='models/224_0_1', help="From where to load model")
    # Disabled options, kept for reference:
    # parser.add_argument("--stop_learning", action="store_true", help="Stop learning")
    # parser.add_argument("--random_agent", action="store_true", help="Use random agent")

    # Timeout behaviour
    parser.add_argument("--timeout", default=600, type=int, help="Maximum time to run")

    # Hyperparameters
    parser.add_argument("--learning_rate", type=float, default=6.442010833400271e-05, help="Learning rate to use")
    parser.add_argument(
        "--mutation_threshold", type=int, default=5, help="Enable MUTATION with set threshold"
    )
    parser.add_argument(
        "--reward_exponent", default=8.932893783628236, type=float, help="Exponent for reward shaping"
    )
    parser.add_argument(
        "--state_radius", default=29, type=int, help="Radius around current site"
    )
    parser.add_argument(
        "--conv_sizes", type=int, default=[11, 3], nargs="+", help="Size of conv kernels"
    )
    parser.add_argument(
        "--conv_channels", type=int, default=[10, 3], nargs="+", help="Channel size of conv"
    )
    parser.add_argument(
        "--num_fc_layers", type=int, default=1, help="Number of FC layers to use"
    )
    parser.add_argument(
        "--fc_units", type=int, default=52, help="Number of units to use per FC layer"
    )
    parser.add_argument(
        "--batch_size", type=int, default=123, help="Batch size for ppo agent"
    )
    parser.add_argument(
        "--entropy_regularization", type=float, default=0.00015087352506343337, help="The output entropy"
    )
    parser.add_argument(
        "--restart_timeout", type=int, default=1800, help="Time after which to restart the agent"
    )
    parser.add_argument("--lstm_units", type=int, default=3, help="The number of lstm units")
    parser.add_argument("--num_lstm_layers", type=int, default=0, help="The number of lstm layers")
    parser.add_argument("--embedding_size", type=int, default=2, help="The size of the embedding")

    # Solution acceptance
    parser.add_argument("--hamming_tolerance", type=int, default=0, help="Allowed tolerance of Hamming Distance for structure")

    parser.add_argument("--num_solutions", type=int, default=1, help="Number of solutions")
    # Disabled option, kept for reference:
    # parser.add_argument("--diversity_loss", action="store_true", help="Use additional loss for diversity")

    # Plotting and result output
    parser.add_argument("--plot_structure", action="store_true", help="Plot each structure with varna")
    parser.add_argument("--show_plots", action="store_true", help="Show each plot generated")
    parser.add_argument("--resolution", type=str, default='8.0', help="Resolution for structure plots")
    parser.add_argument("--plot_logo", action="store_true", help="Plot sequence information as logo")
    # NOTE(review): this help text is identical to --plot_logo's; confirm the
    # intended wording for --show_logo.
    parser.add_argument("--show_logo", action="store_true", help="Plot sequence information as logo")
    parser.add_argument("--no_shared_agent", action="store_true", help="If Active, agent is not shared across targets")
    parser.add_argument("--plotting_dir", type=str, default=None, help="select target directory for saving plots.")
    parser.add_argument("--results_dir", type=str, default=None, help="select target directory for saving results. Defaults to ./results")
    # `choices` makes argparse reject an unsupported format immediately,
    # instead of the error surfacing only once the first result is written.
    parser.add_argument(
        "--output_format",
        type=str,
        default='pickle',
        choices=['pickle', 'csv', 'fasta'],
        help="select an output format for the results. Options: pickle, csv, fasta",
    )

    args = parser.parse_args()

    # Bundle the parsed command-line values into the three configuration
    # objects that are handed to design_rna further down. Each config receives
    # exactly the keyword arguments listed here, read by name from `args`.
    network_config = NetworkConfig(
        **{
            option: getattr(args, option)
            for option in (
                "conv_sizes",  # radius * 2 + 1
                "conv_channels",
                "num_fc_layers",
                "fc_units",
                "lstm_units",
                "num_lstm_layers",
                "embedding_size",
            )
        }
    )

    # The agent is never the random agent in this script.
    agent_config = AgentConfig(
        **{
            option: getattr(args, option)
            for option in ("learning_rate", "batch_size", "entropy_regularization")
        },
        random_agent=False,
    )

    # The diversity loss is hard-wired off (its CLI switch is commented out).
    env_config = RnaDesignEnvironmentConfig(
        **{
            option: getattr(args, option)
            for option in ("mutation_threshold", "reward_exponent", "state_radius")
        },
        diversity_loss=False,  # args.diversity_loss,
    )

    # Resolve the design targets as a list of (id, dot-bracket) pairs, taken
    # either from --target_structure or from the file given by --input_file.
    if args.target_structure:
        # Single inline structure; the id falls back to 1 when --target_id is
        # missing (or empty).
        target_id = args.target_id if args.target_id else 1
        dot_brackets = [(target_id, args.target_structure)]
    elif args.input_file:
        data = read_task_description(args.input_file)
        # Pairs the 'Id' column with the 'str' column, row by row.
        dot_brackets = list(zip(data['Id'], data['str']))
    else:
        # Exception type kept as in the original so the failure mode is unchanged.
        raise UserWarning('Please either enter a target structure or an input file')
    ids = [target for target, _ in dot_brackets]

    # Ids for which results have already been handed out; filled while iterating
    # over the design results.
    processed_ids = []

    # Explicit check instead of `assert`, so the validation is not stripped when
    # Python runs with -O.
    if len(ids) != len(set(ids)):
        raise ValueError('Ids must be unique')

    def return_intermediate_solutions():
        """Yield ``(target_id, solutions)`` pairs produced by ``design_rna``.

        Each target id is yielded at most once: an id that is already in
        ``processed_ids`` is skipped, and every id is recorded in
        ``processed_ids`` right before it is yielded.

        NOTE(review): ``restore_path`` is passed as ``None`` and not as
        ``args.restore_path`` -- confirm that this is intended.
        """
        design_options = dict(
            timeout=args.timeout,
            restore_path=None,
            stop_learning=False,
            restart_timeout=args.restart_timeout,
            network_config=network_config,
            agent_config=agent_config,
            env_config=env_config,
            num_solutions=args.num_solutions,
            hamming_tolerance=args.hamming_tolerance,
            share_agent=not args.no_shared_agent,
        )
        for target_id, found in design_rna(dot_brackets, **design_options):
            if target_id in processed_ids:
                # Results for this target were already reported.
                continue
            processed_ids.append(target_id)
            yield target_id, found



    # sols = design_rna(
    #         dot_brackets,
    #         timeout=args.timeout,
    #         restore_path=None,
    #         stop_learning=False,
    #         restart_timeout=args.restart_timeout,
    #         network_config=network_config,
    #         agent_config=agent_config,
    #         env_config=env_config,
    #         num_solutions=args.num_solutions,
    #         hamming_tolerance=args.hamming_tolerance,
    #     )

    # predictions = pd.DataFrame(preds)

    # print(sols)
                
    # Timestamp shared by all plot/result file names written during this run.
    current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # The output directories do not depend on the target, so resolve and create
    # them once up front; an unusable path then fails before the design starts.
    plotting_dir = None
    if args.plotting_dir is not None:
        plotting_dir = Path(args.plotting_dir)
        plotting_dir.mkdir(exist_ok=True, parents=True)

    results_dir = None
    if args.results_dir is not None:
        results_dir = Path(args.results_dir)
        results_dir.mkdir(exist_ok=True, parents=True)

    # Report, plot and store the solutions of each target as soon as they arrive.
    for i, sols in return_intermediate_solutions():
        solutions = pd.DataFrame(sols)
        if solutions.empty:
            print("\033[91m" + f'WARNING: No solutions found for target {i}' + "\033[0m")
            continue

        # With a Hamming tolerance the rows are ordered by hamming distance,
        # otherwise by the 'time' column.
        sort_column = 'hamming_distance' if args.hamming_tolerance > 0 else 'time'
        solutions = solutions.sort_values(by=sort_column).reset_index(drop=True)

        print()
        print('Solutions for target structure', i)
        print()
        print(solutions.to_markdown())

        name = f'{i}_{current_time}'

        # NOTE(review): the two plotting helpers are called with different
        # keyword names for the output directory (plot_dir vs plotting_dir), as
        # before -- verify against their signatures. The directory is only
        # passed when --plotting_dir was given, so the helpers' own defaults
        # apply otherwise.
        if args.plot_structure:
            varna_dir = {} if plotting_dir is None else {'plot_dir': plotting_dir}
            plot_df_with_varna(
                solutions,
                show=args.show_plots,
                name=name,
                resolution=args.resolution,
                **varna_dir,
            )

        if args.plot_logo:
            # --show_logo used to be parsed but ignored; the logo is now shown
            # when either --show_plots or --show_logo is set.
            logo_dir = {} if plotting_dir is None else {'plotting_dir': plotting_dir}
            plot_sequence_logo(
                solutions,
                show=args.show_plots or args.show_logo,
                name=name,
                **logo_dir,
            )

        if results_dir is not None:
            if args.output_format == 'pickle':
                solutions.to_pickle(results_dir / f'{name}.pkl')
            elif args.output_format == 'csv':
                solutions.to_csv(results_dir / f'{name}.csv')
            elif args.output_format == 'fasta':
                # One record per solution: header with the row index, then the
                # sequence line and the structure line.
                with open(results_dir / f'{name}.fasta', 'w') as f:
                    for j, row in solutions.iterrows():
                        f.write(f'>{j}\n{row["sequence"]}\n{row["structure"]}\n')
            else:
                raise ValueError(f'Unknown output format {args.output_format}')

        # Tell the user which targets are still outstanding.
        remaining = [str(target) for target in ids if target not in processed_ids]
        if remaining:
            print()
            print('Continue with predictions for ids:', ', '.join(remaining))
            print()


