#!python
import copy
from pathlib import Path
from typing import Optional, Dict, Tuple, Final, Union

from myqueue.task import Task
from myqueue.workflow import run
from omegaconf import DictConfig, OmegaConf

from curator.utils import read_user_config, register_resolvers

# Executed at import time, before any config is loaded.
# NOTE(review): presumably registers curator's custom OmegaConf resolvers so that
# interpolations in the user config can be resolved - confirm in curator.utils.
register_resolvers()

# TODO: support restarting all jobs

# Name of the top-level directory used by each workflow stage
train_path: Final[str] = 'train'
simulate_path: Final[str] = 'simulate'
select_path: Final[str] = 'select'
label_path: Final[str] = 'label'

# Keys inferred by the workflow: model_path, init_traj, pool_set, al_info, start_index
# User-specified arguments that may differ from the workflow defaults: read_traj (init_traj), image_index (start_index)

# Config keys whose values are treated as file-system paths
path_set = {
    'run_path',
    'datapath',
    'train_path',
    'val_path',
    'test_path',
    'model_path',
    'load_model',       # load a specific model instead of the one trained in the workflow
    'init_traj',
    'read_traj',        # load a specific initial trajectory instead of the one generated in the workflow
}

def resolve_paths(config: Union[DictConfig, dict], base_dir='.', path_set=path_set):
    '''
    Rewrite the path entries of ``config`` as absolute paths, in place.

    ``config`` is walked recursively. For every key found in ``path_set``:
        - a string value is joined onto ``base_dir`` and resolved
        - a list value (plain ``list`` or OmegaConf ``ListConfig``) has every
          entry joined onto ``base_dir`` and resolved
        - any other value is left untouched
    Mapping values stored under keys that are not in ``path_set`` are
    processed recursively with the same ``base_dir`` and ``path_set``.

    Args:
        config: mapping that is updated in place.
        base_dir: directory that relative paths are interpreted against.
        path_set: collection of key names whose values are paths.
    '''
    for key, value in config.items():
        if key in path_set:
            if isinstance(value, str):
                config[key] = str((Path(base_dir) / value).resolve())
            elif isinstance(value, list) or OmegaConf.is_list(value):
                # OmegaConf stores lists as ListConfig, which is not a ``list``
                # subclass, so it needs its own check to reach this branch.
                config[key] = [str((Path(base_dir) / entry).resolve()) for entry in value]
        elif isinstance(value, (dict, DictConfig)):
            resolve_paths(value, base_dir, path_set)

def train(
    deps: list[Task],
    config: DictConfig,
    iteration: int = 0,
) -> Tuple[list[Task], list[str]]:
    '''
    Submit one ``curator-train`` task per model section in ``config``.

    ``config`` must contain a ``general`` section; every other top-level key
    names one model. A model is skipped while ``iteration`` is smaller than its
    ``start_iteration`` (default 0). For each remaining model:
        - Merge the model section on top of ``general``
        - Create ``train/iter_<iteration>/<name>`` and save ``train.yaml`` there
        - Submit the task with the ``resources`` section as keyword arguments

    Args:
        deps: tasks every train task depends on.
        config: train section of the user config; it is not modified.
        iteration: index of the current workflow iteration.

    Returns:
        ``(tasks, model_paths)``: the submitted tasks and, per model, the
        ``<run_path>/model_path`` string passed on to later stages.
    '''
    tasks = []
    model_paths = []              # collect models for simulation and active learning
    arguments = ['cfg=train.yaml']

    # DictConfig.copy() is shallow: the nested model sections would be shared
    # with the caller, and popping 'start_iteration' below would remove it for
    # every later workflow iteration. Work on a deep copy instead.
    config = copy.deepcopy(config)

    # get general keys
    general = config.pop('general')

    # load multiple models
    for name, job_config in config.items():
        start_iteration = job_config.pop('start_iteration', 0)
        if iteration < start_iteration:
            continue

        # load parameters, create run directory, and save user_cfg file
        job_config = OmegaConf.merge(general, job_config)
        cfg = read_user_config(job_config, config_name='train.yaml')
        run_path = Path(train_path) / f'iter_{iteration}' / name
        run_path.mkdir(parents=True, exist_ok=True)
        cfg.run_path = str(run_path.resolve())
        model_paths.append(cfg.run_path + '/model_path')

        # parse node resources
        job_resources = cfg.pop('resources')

        # TODO: load old model
        if iteration > 0:
            # a user-specified 'load_model' wins over the previous iteration's model
            model_path = cfg.pop('load_model', None)
            # NOTE(review): this directory only exists if the model was also
            # trained in the previous iteration - confirm for start_iteration > 0.
            old_model_path = Path(train_path) / f'iter_{iteration-1}' / name / 'model_path'
            cfg.model_path = model_path or str(old_model_path.resolve())

        # save config file
        OmegaConf.save(cfg, run_path / 'train.yaml', resolve=True)

        tasks.append(run(
            shell='curator-train',
            deps=deps,
            args=arguments,
            folder=run_path,
            name='train',
            **job_resources,
        ))
    return tasks, model_paths

def simulate(
    deps: list[Task],
    model_path: list[str],
    config: DictConfig,
    iteration: int = 0,
) -> Tuple[Dict[str, Task], Dict[str, list]]:
    '''
    Submit one ``curator-simulate`` task per simulation section in ``config``.

    ``config`` must contain a ``general`` section; every other top-level key
    names one simulation. A simulation is skipped while ``iteration`` is smaller
    than its ``start_iteration`` (default 0). For each remaining simulation:
        - Merge the simulation section on top of ``general``
        - Create ``simulate/iter_<iteration>/<name>`` and save ``simulate.yaml`` there
        - Submit the task with the ``resources`` section as keyword arguments

    Args:
        deps: tasks every simulate task depends on.
        model_path: model paths written to ``cfg.model_path``.
        config: simulate section of the user config; it is not modified.
        iteration: index of the current workflow iteration.

    Returns:
        ``(tasks, pool_path)``, both keyed by simulation name. Each
        ``pool_path`` entry is a list holding ``simulator.out_traj`` and, when
        it can be read, ``simulator.uncertainty.save_uncertain_atoms``.
    '''
    tasks = {}
    pool_path = {}                       # collect pool data set for active learning selection
    arguments = ['cfg=simulate.yaml']

    # DictConfig.copy() is shallow: the nested simulation sections would be
    # shared with the caller, and popping 'start_iteration' below would remove
    # it for every later workflow iteration. Work on a deep copy instead.
    config = copy.deepcopy(config)

    # get general keys
    general = config.pop('general')

    # run multiple simulations
    for name, job_config in config.items():
        start_iteration = job_config.pop('start_iteration', 0)
        if iteration < start_iteration:
            continue

        # load parameters, create run directory, and save user_cfg file
        job_config = OmegaConf.merge(general, job_config)
        cfg = read_user_config(job_config, config_name='simulate.yaml')
        run_path = Path(simulate_path) / f'iter_{iteration}' / name
        run_path.mkdir(parents=True, exist_ok=True)
        cfg.run_path = str(run_path.resolve())

        # parse node resources
        job_resources = cfg.pop('resources')

        # TODO: load old model, init_traj, load compiled model
        cfg.model_path = model_path

        # load user specified arguments: read_traj, image_index
        if iteration > 0:
            # Without a user-specified 'read_traj', continue from the previous
            # iteration: the same out_traj string with the iteration directory
            # swapped (presumably out_traj contains the run path - confirm).
            previous_traj = cfg.simulator.out_traj.replace(f'iter_{iteration}', f'iter_{iteration-1}')
            init_traj = cfg.simulator.pop('read_traj', previous_traj)
            start_index = cfg.simulator.pop('image_index', -1)  # use last image if not specified
            cfg.simulator.init_traj = init_traj
            cfg.simulator.start_index = start_index

        pool_path[name] = [cfg.simulator.out_traj]
        try:
            uncertain_atoms = cfg.simulator.uncertainty.save_uncertain_atoms
        except Exception:
            # Best effort: the uncertainty output is optional, so a config
            # without it simply contributes no extra pool file.
            pass
        else:
            pool_path[name].append(uncertain_atoms)

        OmegaConf.save(cfg, run_path / 'simulate.yaml', resolve=True)

        tasks[name] = run(
            shell='curator-simulate',
            deps=deps,
            args=arguments,
            folder=run_path,
            name=name,
            **job_resources,
        )
    return tasks, pool_path

def select(
    deps: Dict[str, Task],
    model_path: list[str],
    pool_path: Dict[str, list],
    config: DictConfig,
    iteration: int = 0,
) -> Tuple[Dict[str, Task], Dict[str, str]]:
    '''
    Submit one ``curator-select`` task per system section in ``config``.

    ``config`` must contain a ``general`` section; every other top-level key
    names one system. A system is skipped while ``iteration`` is smaller than
    its ``start_iteration`` (default 0). For each remaining system:
        - Merge the system section on top of ``general``
        - Create ``select/iter_<iteration>/<name>`` and save ``select.yaml`` there
        - Submit the task with the ``resources`` section as keyword arguments

    Args:
        deps: tasks keyed by system name; each select task depends on the
            entry with its own name.
        model_path: model paths written to ``cfg.model_path``.
        pool_path: pool files keyed by system name, written to ``cfg.pool_set``.
        config: select section of the user config; it is not modified.
        iteration: index of the current workflow iteration.

    Returns:
        ``(tasks, al_info)``, both keyed by system name. Each ``al_info``
        entry is the ``<run_path>/selected.json`` string of that system.

    Raises:
        KeyError: if a selected system has no entry in ``deps`` or ``pool_path``.
    '''
    tasks = {}
    al_info = {}
    arguments = ['cfg=select.yaml']

    # DictConfig.copy() is shallow: the nested system sections would be shared
    # with the caller, and popping 'start_iteration' below would remove it for
    # every later workflow iteration. Work on a deep copy instead.
    config = copy.deepcopy(config)

    # get general keys
    general = config.pop('general')

    # selection for multiple systems
    for name, job_config in config.items():
        start_iteration = job_config.pop('start_iteration', 0)
        if iteration < start_iteration:
            continue

        # load parameters, create run directory, and save user_cfg file
        job_config = OmegaConf.merge(general, job_config)
        cfg = read_user_config(job_config, config_name='select.yaml')
        run_path = Path(select_path) / f'iter_{iteration}' / name
        run_path.mkdir(parents=True, exist_ok=True)
        cfg.run_path = str(run_path.resolve())

        # parse node resources
        job_resources = cfg.pop('resources')

        # TODO: load old model and get pool_set and al_info
        cfg.model_path = model_path
        cfg.pool_set = pool_path[name]
        al_info[name] = cfg.run_path + '/selected.json'

        OmegaConf.save(cfg, run_path / 'select.yaml', resolve=True)

        tasks[name] = run(
            shell='curator-select',
            deps=[deps[name]],
            args=arguments,
            folder=run_path,
            name=name,
            **job_resources,
        )
    return tasks, al_info

def label(
    deps: Dict[str, Task],
    pool_path: Dict[str, list],
    al_info: Dict[str, str],
    config: DictConfig,
    iteration: int = 0,
) -> list[Task]:
    '''
    Submit ``curator-label`` tasks for each system section in ``config``.

    ``config`` must contain a ``general`` section; every other top-level key
    names one system. A system is skipped while ``iteration`` is smaller than
    its ``start_iteration`` (default 0). For each remaining system:
        - Merge the system section on top of ``general``
        - Create the run directory and save ``label.yaml`` there
        - Submit the task with the ``resources`` section as keyword arguments
    When ``cfg.split_jobs`` is not None, one task is submitted per index in
    ``range(cfg.split_jobs)``, each in its own ``<name>/<index>`` directory
    with ``cfg.job_order`` set to that index. Otherwise a single task runs in
    ``label/iter_<iteration>/<name>``.

    Args:
        deps: tasks keyed by system name; each label task depends on the
            entry with its own name.
        pool_path: pool files keyed by system name, written to ``cfg.pool_set``.
        al_info: selection result paths keyed by system name, written to
            ``cfg.al_info``.
        config: label section of the user config; it is not modified.
        iteration: index of the current workflow iteration.

    Returns:
        The submitted label tasks.

    Raises:
        KeyError: if a labelled system has no entry in ``deps``, ``pool_path``
            or ``al_info``.
    '''
    tasks = []
    arguments = ['cfg=label.yaml']

    # DictConfig.copy() is shallow: the nested system sections would be shared
    # with the caller, and popping 'start_iteration' below would remove it for
    # every later workflow iteration. Work on a deep copy instead.
    config = copy.deepcopy(config)

    # get general keys
    general = config.pop('general')

    # labelling for multiple systems
    for name, job_config in config.items():
        start_iteration = job_config.pop('start_iteration', 0)
        if iteration < start_iteration:
            continue

        # load parameters
        job_config = OmegaConf.merge(general, job_config)
        cfg = read_user_config(job_config, config_name='label.yaml')

        # parse node resources
        job_resources = cfg.pop('resources')

        # TODO: get atoms that need to be labelled, possibly overall datapath in training
        cfg.pool_set = pool_path[name]
        cfg.al_info = al_info[name]

        # split jobs if needed: one (run directory, job index) pair per task
        system_path = Path(label_path) / f'iter_{iteration}' / name
        if cfg.split_jobs is not None:
            jobs = [(system_path / f'{i}', i) for i in range(cfg.split_jobs)]
        else:
            jobs = [(system_path, None)]

        for run_path, job_order in jobs:
            run_path.mkdir(parents=True, exist_ok=True)
            if job_order is not None:
                cfg.job_order = job_order
            cfg.run_path = str(run_path.resolve())
            OmegaConf.save(cfg, run_path / 'label.yaml', resolve=True)
            tasks.append(run(
                shell='curator-label',
                deps=[deps[name]],
                args=arguments,
                folder=run_path,
                name=name,
                **job_resources,
            ))
    return tasks


def workflow(cfg='user_cfg.yaml', iterations: int = 10):
    '''
    Build the active learning task graph: train -> simulate -> select -> label.

    The user config is loaded from ``cfg`` and its path entries are made
    absolute relative to the config's ``run_path`` ('.' if absent). The four
    stages are then chained ``iterations`` times; the label tasks of one
    iteration are the dependencies of the train tasks of the next.

    Args:
        cfg: path of the user config file; it must provide the ``train``,
            ``simulate``, ``select`` and ``label`` sections.
        iterations: number of active learning iterations to schedule.
    '''
    config = OmegaConf.load(cfg)
    # NOTE(review): 'run_path' is itself a path key, so a relative top-level
    # run_path is also joined onto itself here; it is not read again below.
    resolve_paths(config, base_dir=config.get('run_path', '.'))
    label_tasks = []        # the first train stage has no dependencies

    for iteration in range(iterations):
        train_tasks, model_path = train(
            deps=label_tasks,
            config=config.train,
            iteration=iteration,
        )

        simulate_tasks, pool_path = simulate(
            deps=train_tasks,
            model_path=model_path,
            config=config.simulate,
            iteration=iteration,
        )

        select_tasks, al_info = select(
            deps=simulate_tasks,
            model_path=model_path,
            pool_path=pool_path,
            config=config.select,
            iteration=iteration,
        )

        label_tasks = label(
            deps=select_tasks,
            pool_path=pool_path,
            al_info=al_info,
            config=config.label,
            iteration=iteration,
        )