#!/usr/bin/env python3
import argparse
import os
import shutil
import subprocess
import sys
from multiprocessing import Pool


def dir_path(string):
    """Validate that ``string`` names an existing directory.

    Used as an ``argparse`` ``type=`` callback for the data folder argument.

    Args:
        string: Path to check.

    Returns:
        ``string`` unchanged when ``os.path.isdir`` accepts it.

    Raises:
        NotADirectoryError: If the path is not a directory; the offending
            path is the exception's only argument.
    """
    if not os.path.isdir(string):
        raise NotADirectoryError(string)
    return string


def execute_py(args):
    """Run a Python solution with a data file on stdin and capture its stdout.

    The single-tuple parameter lets this function be mapped directly over a
    list of jobs by a worker pool.

    Args:
        args: ``(filename, data_file)`` pair: the path of the Python script
            to run and the path of the file fed to it as standard input.

    Returns:
        ``(data_file, out)`` where ``out`` is the raw bytes the script wrote
        to standard output.

    Raises:
        RuntimeError: If the script exits with a non-zero status. The message
            names the data file and, when available, includes the script's
            stderr output.
    """
    filename, data_file = args
    # Prefer the historically pinned interpreter, but fall back to the one
    # running this script so machines without a `python3.9` binary still work.
    interpreter = shutil.which('python3.9') or sys.executable
    # The context manager closes the input handle once the child has finished;
    # only the file descriptor is handed to the child, so no decoding happens.
    with open(data_file, 'rb') as stdin:
        process = subprocess.run(
            [interpreter, filename],
            stdin=stdin,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    if process.returncode != 0:
        message = f'Solution ran into an error\n\nFile: {data_file}'
        # Surface the child's diagnostics instead of silently dropping them.
        details = process.stderr.decode(errors='replace').strip()
        if details:
            message += f'\n\n{details}'
        raise RuntimeError(message)
    return data_file, process.stdout

def execute_cpp(args):
    """Run a compiled solution with a data file on stdin and capture its stdout.

    The single-tuple parameter lets this function be mapped directly over a
    list of jobs by a worker pool.

    Args:
        args: ``(filename, data_file)`` pair: the path of the executable to
            run and the path of the file fed to it as standard input.

    Returns:
        ``(data_file, out)`` where ``out`` is the raw bytes the executable
        wrote to standard output.

    Raises:
        RuntimeError: If the executable exits with a non-zero status. The
            message names the data file and, when available, includes the
            executable's stderr output.
    """
    filename, data_file = args
    # The context manager closes the input handle once the child has finished;
    # only the file descriptor is handed to the child, so no decoding happens.
    with open(data_file, 'rb') as stdin:
        process = subprocess.run(
            [filename],
            stdin=stdin,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    if process.returncode != 0:
        message = f'Solution ran into an error\n\nFile: {data_file}'
        # Surface the child's diagnostics instead of silently dropping them.
        details = process.stderr.decode(errors='replace').strip()
        if details:
            message += f'\n\n{details}'
        raise RuntimeError(message)
    return data_file, process.stdout

def compile_cpp(filename):
    """Compile a C++ source file with ``g++``.

    The executable is written to ``filename`` with a trailing ``.cpp``
    removed.

    Args:
        filename: Path of the C++ source file.

    Raises:
        RuntimeError: If ``g++`` exits with a non-zero status; the message
            embeds the compiler's stderr output.
    """
    binary = filename.removesuffix('.cpp')
    result = subprocess.run(['g++', filename, '-o', binary], capture_output=True)
    if result.returncode == 0:
        return
    raise RuntimeError(f'Compilation ran into an error\n\n{result.stderr.decode()}')

def main(arguments):
    """Generate a ``.out`` file for every ``.in`` file under a data folder.

    The solution given with ``--sol`` is compiled first when it is a ``.cpp``
    file, then run once per input file across a pool of worker processes.
    Each result is written beside its input, with the ``.in`` suffix replaced
    by ``.out``, and a progress line is printed after every file.

    Args:
        arguments: Command-line arguments, without the program name.

    Returns:
        None, so passing the result to ``sys.exit`` reports success.

    Raises:
        Exception: If the solution file ends in neither ``.cpp`` nor ``.py``.
        RuntimeError: Propagated from compilation or from a failing run.
    """
    parser = argparse.ArgumentParser(description='Generate test cases')
    parser.add_argument('folder', help="Data folder", type=dir_path)
    parser.add_argument('-s', '--sol', help='Solution source', required=True, type=argparse.FileType('r'))
    parser.add_argument('-w', '--worker', help='Number of worker processes', type=int, default=2)

    args = parser.parse_args(arguments)

    pathname = os.path.abspath(args.sol.name)
    # FileType only serves to validate that the solution is readable; the
    # handle itself is never read, so release it right away. `-` maps to
    # sys.stdin, which must stay open.
    if args.sol is not sys.stdin:
        args.sol.close()

    # Match on the real suffix: a substring test would also accept names such
    # as `sol.cpp.bak` or `notes.py.txt`.
    if pathname.endswith('.cpp'):
        compile_cpp(pathname)
        execute = execute_cpp
        # compile_cpp writes the binary next to the source, minus `.cpp`.
        exec_name = pathname.removesuffix('.cpp')
    elif pathname.endswith('.py'):
        execute = execute_py
        # The interpreter needs the script path itself, extension included.
        exec_name = pathname
    else:
        raise Exception("Unknown file format")

    datas_in = []
    for root, _dirs, files in os.walk(args.folder):
        for name in files:
            # Only true `.in` files; `'.in' in name` also matched names like
            # `data.index`, whose output name then kept the whole input name.
            if name.endswith('.in'):
                datas_in.append((exec_name, os.path.join(root, name)))

    total = len(datas_in)
    with Pool(processes=args.worker) as p:
        results = p.imap_unordered(execute, datas_in)
        for cnt, (data_file, data_out) in enumerate(results, start=1):
            out_file = data_file.removesuffix('.in') + '.out'
            # Write the captured bytes verbatim: no decode/encode round trip
            # that could fail on non-UTF-8 output or rewrite line endings.
            with open(out_file, 'wb') as f:
                f.write(data_out)
            print('Generated %d / %d' % (cnt, total))


# Script entry point: forward the command-line arguments (program name
# excluded) to main() and use its return value as the process exit status.
if __name__ == '__main__':
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)
