#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import time
import pathlib
import click
import random

from opengate.exception import fatal, colored, color_ok, color_error
from opengate_core.testsDataSetup import check_tests_data_folder

# Settings handed to click.command: accept "-h" in addition to "--help".
CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]}


def _torch_is_available():
    """Return True when the optional ``torch`` package can be imported."""
    try:
        import torch  # noqa: F401
    except Exception:
        # Best effort, like the previous bare ``except``: any error raised
        # while importing means "not usable". KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        return False
    return True


def _select_test_files(tests_dir, has_torch):
    """Return the sorted names of the test scripts of *tests_dir* to run.

    Work-in-progress files and (when *has_torch* is False) the tests that
    need torch are reported on stdout; every other excluded file is skipped
    silently.
    """
    torch_tests = [
        "test034_gan_phsp_linac.py",
        "test038_gan_phsp_spect_gan_my.py",
        "test038_gan_phsp_spect_gan_aa.py",
        "test038_gan_phsp_spect_gan_se.py",
        "test038_gan_phsp_spect_gan_ze.py",
        "test040_gan_phsp_pet_gan.py",
        "test043_garf.py",
        "test045_speedup_all_wip.py",
        "test047_gan_vox_source_cond.py",
    ]

    # Tests skipped on Windows only (currently none is active).
    windows_wrong_tests = [
        # "test014_engine_2.py",
        # "test060_PhsSource_ParticleName_direct.py",
        # "test060_PhsSource_ParticleName_fromPHS_PDGCode.py",
        # "test060_PhsSource_ParticleName_fromPHS_ParticleName.py",
        # "test060_PhsSource_rotation.py",
        # "test060_PhsSource_translation.py",
        # "test043_garf5_region_MT_subproc.py",
        # "test043_garf2_region_subproc.py.py",
        # "test061_user_event_info.py",
    ]

    # A file whose name contains any of these substrings is never run.
    excluded_markers = (
        "visu",
        "OLD",
        "old",
        ".log",
        "all_tes",
        "_base",
        "test045_speedup",
        "_helpers",
    )
    on_windows = os.name == "nt"

    selected = []
    for name in sorted(os.listdir(tests_dir)):
        if not os.path.isfile(os.path.join(tests_dir, name)):
            continue
        if "wip" in name:
            print(f"Ignoring: {name:<40} (Work In Progress) ")
            continue
        if any(marker in name for marker in excluded_markers):
            continue
        if "test" not in name or ".py" not in name:
            continue
        if on_windows and ("_mt" in name or name in windows_wrong_tests):
            continue
        if not has_torch and name in torch_tests:
            print(f"Ignoring: {name:<40} (Torch is not available) ")
            continue
        selected.append(name)
    return selected


def _print_log(log_path):
    """Print the content of *log_path* on stdout (portable stand-in for ``cat``)."""
    try:
        with open(log_path, encoding="utf-8", errors="replace") as log_file:
            content = log_file.read()
    except OSError as err:
        print(f"Cannot read log file {log_path}: {err}")
        return
    print(content, end="")


@click.command(context_settings=CONTEXT_SETTINGS)
@click.option("--test_id", "-i", default="all", help="Start test from this number")
@click.option(
    "--random_tests",
    "-r",
    is_flag=True,
    default=False,
    help="Start the last 10 tests and 1/4 of the others randomly",
)
def go(test_id, random_tests):
    """Run the test scripts one after the other and report OK or FAILED.

    The standard output of every test is written to a log file; the log of
    a failing test is printed. The last line printed is True when all the
    tests passed, False otherwise.
    """
    # Locate the folder that contains the test scripts.
    bin_dir = pathlib.Path(__file__).parent.resolve()
    if "src" in os.listdir(bin_dir):
        mypath = os.path.join(bin_dir, "../tests/src")
    else:
        import opengate.tests

        mypath = os.path.join(
            pathlib.Path(opengate.tests.__file__).resolve().parent, "../tests/src"
        )

    print("Look for tests in: " + mypath)

    if not check_tests_data_folder():
        return False

    files = _select_test_files(mypath, _torch_is_available())

    if test_id != "all":
        first_id = int(test_id)
        kept = []
        for name in files:
            # The test number is read from characters 4-6 of the file name,
            # i.e. names are expected to look like "testNNN_...".
            if int(name[4:7]) >= first_id:
                kept.append(name)
            else:
                print(f"Ignoring: {name:<40} (< {first_id}) ")
        files = kept
    elif random_tests:
        # Always keep the last 10 tests and add a random quarter of the
        # others. The sample size is computed from the actual number of
        # "other" tests so that it cannot become negative when fewer than
        # 10 tests are available.
        recent, older = files[-10:], files[:-10]
        prob = 0.25
        files = sorted(recent + random.sample(older, int(prob * len(older))))

    print(f"Running {len(files)} tests")
    print("-" * 70)

    # The shell redirection below fails when the log folder is missing.
    log_dir = os.path.join(os.path.dirname(mypath), "log")
    os.makedirs(log_dir, exist_ok=True)

    failure = False

    for name in files:
        start = time.time()
        # Flush so the name of the running test is visible while it runs.
        print(f"Running: {name:<46}  ", end="", flush=True)
        script = os.path.join(mypath, name)
        log = os.path.join(os.path.dirname(mypath), f"log/{name}.log")
        # Paths are quoted so that folders containing spaces are supported.
        r = os.system(f'python "{script}" > "{log}"')
        failed = False
        if r == 0:
            print(colored.stylize(" OK", color_ok), end="")
        elif r == 2:
            # this is probably a Ctrl+C, so we stop
            fatal("Stopped by user")
        else:
            print(colored.stylize(" FAILED !", color_error), end="")
            failed = True
            failure = True
        end = time.time()
        print(f"   {end - start:5.1f} s     {log:<65}")
        if failed:
            # Show the log once the status line is complete.
            _print_log(log)

    print(not failure)


# --------------------------------------------------------------------------
if __name__ == "__main__":
    # go is a click command: calling it without arguments lets click parse
    # the command-line options (--test_id / --random_tests).
    go()
