#!/usr/bin/env python3
#	kartfire - Test framework to consistently run submission files
#	Copyright (C) 2023-2024 Johannes Bauer
#
#	This file is part of kartfire.
#
#	kartfire is free software; you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation; this program is ONLY licensed under
#	version 3 of the License, later versions are explicitly excluded.
#
#	kartfire is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with kartfire; if not, write to the Free Software
#	Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#	Johannes Bauer <JohannesBauer@gmx.de>

import os
import base64
import json
import sys
import time
import subprocess
import contextlib
import tempfile
import stat
import collections

# The configuration filename is a mandatory first command line argument;
# without it there is nothing to run, so complain on stderr and exit with
# status 1.
if len(sys.argv) <= 1:
	sys.stderr.write(f"{sys.argv[0]}: No argument given.\n")
	raise SystemExit(1)

class ContainerTestrunner():
	"""Run a setup script and batches of testcases as described by a JSON config.

	The configuration has a "meta" dictionary (paths, limits, debug flag)
	and a "testbatches" list; every testbatch is a list of testcase
	dictionaries.
	"""

	def __init__(self, config_filename: str):
		"""Load the configuration and create the directory meta["local_dut_dir"].

		Args:
			config_filename: Path of the JSON configuration file.
		"""
		with open(config_filename) as f:
			self._config = json.load(f)
		with contextlib.suppress(FileExistsError):
			os.makedirs(self.meta["local_dut_dir"])
		self._tc_results = { }

	def logmsg(self, msg: str):
		"""Print msg on stderr, but only if meta["debug"] is set to a true value."""
		if self.meta.get("debug"):
			print(msg, file = sys.stderr)

	@property
	def meta(self) -> dict:
		"""The "meta" section of the configuration."""
		return self._config["meta"]

	@property
	def testbatches(self) -> list:
		"""The "testbatches" section of the configuration."""
		return self._config["testbatches"]

	@property
	def testbatch_count(self):
		"""Number of configured testbatches."""
		return len(self.testbatches)

	@staticmethod
	def _tail(data: bytes, limit: int) -> bytes:
		"""Return at most the last `limit` bytes of `data`.

		A plain data[-limit:] slice hands back the *whole* buffer when limit
		is 0 (because -0 == 0), which would defeat the limit. Non-positive
		limits are therefore handled explicitly and yield an empty result.
		"""
		if limit <= 0:
			return bytes()
		return data[-limit : ]

	def _exec(self, cmd, timeout):
		"""Run a command with captured output and summarize what happened.

		Failures to start or finish the command (timeout, permission
		problems, other OS errors) do not raise; they are reported via the
		"exception_msg" entry of the result instead.

		Args:
			cmd: Argument list passed to subprocess.run(); cmd[0] is the executable.
			timeout: Timeout in seconds passed to subprocess.run().

		Returns:
			A dictionary with the keys
			"stdout" / "stderr": trailing bytes of the captured output, at
				most meta["limit_stdout_bytes"] each,
			"stdout_length" / "stderr_length": lengths before truncation,
			"returncode": return code, or None if the command did not
				run to completion,
			"runtime_secs": wall clock duration of the attempt,
			"timeout": True if the timeout expired,
			"exception_msg": description of the problem, or None.
		"""
		t0 = time.time()
		had_timeout = False
		exception_msg = None
		stdout = bytes()
		stderr = bytes()
		returncode = None
		try:
			result = subprocess.run(cmd, stdout = subprocess.PIPE, stderr = subprocess.PIPE, timeout = timeout)
			stdout = result.stdout
			stderr = result.stderr
			returncode = result.returncode
			if returncode == -9:
				# A negative return code means the child was terminated by
				# that signal number; 9 is reported as out-of-memory here
				# (presumably the memory limit of the container kills the
				# process -- confirm against the container setup).
				exception_msg = "Out of memory, process terminated."
		except subprocess.TimeoutExpired as e:
			had_timeout = True
			exception_msg = f"{e.__class__.__name__} when trying execution: {str(e)}"
		except PermissionError as e:
			# Try to hint at a missing execution bit. The stat() call can
			# fail itself (e.g., when a parent directory is not accessible);
			# then fall back to a generic message instead of letting the
			# secondary error crash the whole runner.
			try:
				perms = stat.S_IMODE(os.stat(cmd[0]).st_mode)
			except OSError:
				exception_msg = f"{e.__class__.__name__} when trying execution of {cmd[0]}: {str(e)}"
			else:
				missing_exec = (perms & 0o111) == 0
				exception_msg = f"{e.__class__.__name__} when trying execution of {cmd[0]} (permissions {perms:#o}{' - missing execution bit!' if missing_exec else ''}): {str(e)}"
		except OSError as e:
			exception_msg = f"{e.__class__.__name__} when trying execution: {str(e)}"
		runtime_secs = time.time() - t0

		# Note that stderr is truncated using the same limit as stdout.
		limit = self.meta["limit_stdout_bytes"]
		return {
			"stdout": self._tail(stdout, limit),
			"stderr": self._tail(stderr, limit),
			"stdout_length": len(stdout),
			"stderr_length": len(stderr),
			"returncode": returncode,
			"runtime_secs": runtime_secs,
			"timeout": had_timeout,
			"exception_msg": exception_msg,
		}

	def _unpack_files(self):
		"""Extract the tar file meta["local_testcase_tar_file"] into meta["local_dut_dir"].

		Raises:
			subprocess.CalledProcessError: If tar exits with a non-zero status.
		"""
		subprocess.check_call([ "tar", "-x", "-C", self.meta["local_dut_dir"], "-f", self.meta["local_testcase_tar_file"] ])

	def _run_setup(self):
		"""Run the setup executable, if there is one.

		Returns:
			The _exec() result dictionary, or None if the setup file
			meta["setup_name"] does not exist inside meta["local_dut_dir"].
		"""
		setup_exec = f"{self.meta['local_dut_dir']}/{self.meta['setup_name']}"
		if os.path.exists(setup_exec):
			return self._exec([ setup_exec ], timeout = self.meta["max_setup_time_secs"])

	def _run_testbatch(self, testbatch: list[dict]):
		"""Run the solution executable once for a whole batch of testcases.

		The testcase data of the batch is written as JSON (testcase name ->
		testcase data, below the key "testcases") to
		meta["local_testcase_filename"]; that filename is the only argument
		given to the solution. The timeout is the sum of the
		"runtime_allowance_secs" values of all testcases in the batch.

		Args:
			testbatch: Testcase dictionaries with the keys "name",
				"testcase_data" and "runtime_allowance_secs".

		Returns:
			A dictionary with the testcase names ("testcases") and the
			_exec() result dictionary ("results").
		"""
		with open(self.meta["local_testcase_filename"], "w") as f:
			local_json_content = {
				"testcases": { testcase["name"]: testcase["testcase_data"] for testcase in testbatch },
			}
			json.dump(local_json_content, f)
		solution_exec = f"{self.meta['local_dut_dir']}/{self.meta['solution_name']}"
		total_allowance = sum(testcase["runtime_allowance_secs"] for testcase in testbatch)
		testbatch_results = self._exec([ solution_exec, self.meta["local_testcase_filename"] ], timeout = total_allowance)
		return {
			"testcases": [ testcase["name"] for testcase in testbatch ],
			"results": testbatch_results,
		}

	def run(self):
		"""Unpack the files, run the setup and then every testbatch in order.

		Returns:
			A dictionary with the setup result ("setup", may be None), one
			entry per testbatch ("testcase_results") and the total wall
			clock duration ("runtime_secs").
		"""
		t0 = time.time()
		result = {
			"setup": None,
			"testcase_results": [ ],
		}
		self._unpack_files()
		result["setup"] = self._run_setup()
		for (testbatch_no, testbatch) in enumerate(self.testbatches, 1):
			self.logmsg(f"Running testbatch {testbatch_no} / {self.testbatch_count} with {len(testbatch)} testcases: {testbatch}")
			result_data = self._run_testbatch(testbatch)
			self.logmsg(f"Finished testbatch {testbatch_no} / {self.testbatch_count}: {result_data}")
			result["testcase_results"].append(result_data)
		result["runtime_secs"] = time.time() - t0
		return result

def serializer(obj):
	"""Fallback encoder for json.dumps() for values JSON cannot represent natively.

	Args:
		obj: The object json.dumps() could not serialize on its own.

	Returns:
		Base64 encoded text for bytes objects, the member's value for
		enum members.

	Raises:
		TypeError: If obj is neither a bytes object nor an enum member.
	"""
	# enum is not among the module level imports of this file; without this
	# import the Enum check below fails with a NameError for every object
	# that is not a bytes instance.
	import enum

	if isinstance(obj, bytes):
		return base64.b64encode(obj).decode("ascii")
	elif isinstance(obj, enum.Enum):
		return obj.value
	else:
		raise TypeError(obj)

# Script entry point: the first command line argument names the JSON
# configuration file; the collected results are written to stdout as a
# single JSON document, with serializer() handling non-JSON values.
runner = ContainerTestrunner(sys.argv[1])
outcome = runner.run()
sys.stdout.write(json.dumps(outcome, default = serializer) + "\n")
