#!/usr/bin/env python
# coding: utf-8
# TST test
# (c) 2012-2016 Dalton Serey, UFCG

from __future__ import print_function
from __future__ import unicode_literals
from collections import Counter
import os
import re
import sys
import json
import glob
import time
import shlex
import signal
import string
import codecs
import unicodedata
import argparse
import difflib
import hashlib

from tst.jsonfile import JsonFile
import tst
from tst.utils import _assert
from tst.utils import cprint, to_unicode
from tst.colors import *

from subprocess import Popen, PIPE

# Default value of the `timeout` parameter of the TestRun.run* methods;
# it is handed to signal.alarm(), hence expressed in seconds.
TIMEOUT_DEFAULT = 2

# Interpreter used by io tests when the configuration has no 'run' entry.
PYTHON = 'python2.7'

# Refuse to run on interpreters older than 2.7.
if sys.version_info < (2,7):
    sys.stderr.write('tst.py: requires python 2.7 or later\n')
    sys.exit(1)

# True when file descriptors 0 and 1 report different stat results;
# color() emits no escape sequences while this is set.
REDIRECTED = os.fstat(0) != os.fstat(1)

# Maps a test status name to the one-character code shown in summaries.
# For io tests the code is looked up from the result status (see
# unpack_results); run_iotest also searches the subject's stderr for
# these names to classify a failed run.
STATUS_CODE = {
    # TST test codes
    'DefaultError': 'e',
    'Timeout': 't',
    'Success': '.',
    'QuasiSuccess': '*',
    'AllTokensSequence': '@',
    'AllTokensMultiset': '&',
    'MissingTokens': '%',
    'Fail': 'F',
    'ScriptTestError': '!',
    'Inconclusive': '?',
    'UNKNOWN': '-',
    'NoInterpreterError': 'X',

    # Python ERROR codes
    'AttributeError': 'a',
    'SyntaxError': 's',
    'EOFError': 'o',
    'ZeroDivisionError': 'z',
    'IndentationError': 'i',
    'IndexError': 'x',
    'ValueError': 'v',
    'TypeError': 'y',
    'NameError': 'n',
}


# Pre-coloured message for an invalid spec file.
# NOTE(review): not referenced anywhere in this part of the file.
TSTJSONFAILMSG = LRED + "Spec file is not a valid config object" + RESET

def abort(msg, code=1):
    """Print `msg` in light red on stderr and exit with status `code`."""
    encoded = msg.encode('utf-8')
    cprint(LRED, encoded, file=sys.stderr)
    sys.exit(code)


def data2json(data):
    """Serialize `data` as human-readable JSON.

    The output is indented by two spaces, uses ',' and ': ' as
    separators and keeps non-ASCII characters unescaped.
    """
    pretty = dict(indent=2, separators=(',', ': '), ensure_ascii=False)
    return json.dumps(data, **pretty)


def unpack_results(run_method):
    """Decorate a TestRun run method so io results always get a summary.

    After `run_method` returns, and only when `self.testcase.type` is
    'io', the wrapper stores in `self.result['summary']` the STATUS_CODE
    entry for `self.result['status']`. Results of any other test type
    are left untouched.

    The wrapper keeps the metadata of `run_method` (name, docstring) and
    passes its return value through to the caller.
    """
    import functools

    @functools.wraps(run_method)
    def wrapper(self, *args, **kwargs):
        outcome = run_method(self, *args, **kwargs)

        if self.testcase.type == 'io':
            self.result['summary'] = STATUS_CODE[self.result['status']]

        return outcome

    return wrapper


def parse_test_report(data):
    """Extract `(summary, feedback)` from the output of a test script.

    Two formats are understood:

    - JSON: text starting with '{' that parses as an object. The result
      is its 'summary' entry and its optional 'feedback' entry;
      `(None, None)` if there is no 'summary'.
    - Plain text: the first line is the summary and everything from the
      third line on is the feedback (the second line is skipped). A
      first line containing a space is not a summary, so the result is
      `(None, None)`.

    Text that starts with '{' but is not valid JSON is handled as plain
    text. `feedback` is None when the report has none.
    """
    if data and data[0] == '{':
        # assume it's json
        try:
            report = json.loads(data)
        except ValueError:
            # not json after all: fall through to the plain text format
            pass
        else:
            if 'summary' not in report:
                return None, None
            return report['summary'], report.get('feedback')

    parts = data.split('\n', 2)
    # TODO: Spaces should be accepted as summary lines.
    #       Improve the specification of a valid summary line.
    if ' ' in parts[0]:
        return None, None

    summary = parts[0]
    feedback = parts[2] if len(parts) > 2 else None
    return summary, feedback


class TestRun:
    """Run one test case against one subject and record the outcome.

    The outcome is accumulated in `self.result`, a dict that always has
    'type' and 'name' and, once a run method has finished, 'status' and
    'summary'. The remaining keys depend on the test type and on how
    far the run got.
    """

    def __init__(self, subject, testcase):
        self.subject = subject
        self.testcase = testcase
        self.result = {}
        self.result['type'] = self.testcase.type
        self.result['name'] = self.testcase.name

    def run(self, timeout=TIMEOUT_DEFAULT):
        """Run the test, giving the external process `timeout` seconds.

        Raises AssertionError if the test case type is neither 'io'
        nor 'script'.
        """
        if self.testcase.type == 'io':
            self.run_iotest(timeout=timeout)

        elif self.testcase.type == 'script':
            self.run_script(timeout=timeout)

        else:
            raise AssertionError('unknown test type')

    @staticmethod
    def _includes(container, contained):
        """Tell whether multiset `contained` is included in `container`.

        Both arguments are Counters. The counts are compared one by one
        because Counter defines no multiset ordering in Python 2 (the
        plain dict comparison would be used instead).
        """
        return all(container[key] >= count
                   for key, count in contained.items())

    @unpack_results
    def run_script(self, timeout=TIMEOUT_DEFAULT):
        """Run the test script and derive the result from its report.

        A '{}' in the script is replaced by the subject filename and
        the word 'python' by 'python3'. The report is read from stderr
        or, if stderr holds no summary, from stdout (see
        parse_test_report). The status is 'Timeout' if the script is
        still running after `timeout` seconds, 'ScriptTestError' if it
        exits with a non-zero status, 'Inconclusive' if no summary is
        found, 'Success' if the summary consists of dots only and
        'Fail' otherwise.
        """
        if "{}" in self.testcase.script:
            cmd_str = self.testcase.script.format(self.subject.filename)
        else:
            cmd_str = self.testcase.script
        cmd_str = re.sub(r'\bpython\b', 'python3', cmd_str)
        command = shlex.split(cmd_str.encode('utf-8'))
        self.result['command'] = cmd_str

        process = Popen(command, stdout=PIPE, stderr=PIPE)
        signal.alarm(timeout)
        try:
            stdout, stderr = map(to_unicode, process.communicate())
            signal.alarm(0) # reset the alarm
            process.wait()
        except CutTimeOut: # program didn't stop within expected time
            process.terminate()
            self.result['status'] = 'Timeout'
            self.result['summary'] = 't'
            return

        # collect test data
        self.result['exit_status'] = process.returncode
        self.result['stderr'] = stderr # comment out to remove from report
        self.result['stdout'] = stdout # comment out to remove from report

        # check for correct termination of the test script
        if self.result['exit_status'] != 0:
            self.result['status'] = 'ScriptTestError'
            self.result['summary'] = '!'
            return

        # collect report from either stderr or stdout
        summary, feedback = parse_test_report(stderr)
        if not summary:
            summary, feedback = parse_test_report(stdout)

        if not summary:
            self.result['status'] = 'Inconclusive'
            self.result['summary'] = '?'
            return

        self.result['summary'] = summary
        if summary == len(summary) * '.':
            self.result['status'] = 'Success'
        else:
            self.result['status'] = 'Fail'

    @unpack_results
    def run_iotest(self, timeout=TIMEOUT_DEFAULT):
        """Feed the test input to the subject and classify its output.

        The subject runs through the command that the 'run' entry of
        the configuration assigns to its file extension, or through
        PYTHON when the configuration has no 'run' entry. Only
        `self.result['status']` is set here; the decorator derives the
        summary from it.
        """
        # record the test data first, so that every result (including
        # the early NoInterpreterError exit) carries it
        self.result['input'] = self.testcase.input
        self.result['output'] = self.testcase.output
        self.result['cases'] = self.testcase.cases
        self.result['sha1'] = self.testcase.sha1

        # define command
        config = tst.get_config()
        if config.get('run'):
            # use run option
            run = config['run']
            ext = self.subject.filename.split('.')[-1]
            if ext not in run:
                self.result['status'] = 'NoInterpreterError'
                return
            cmd_str = "%s %s" % (run[ext], self.subject.filename)
        else:
            # default is running through python
            cmd_str = '%s %s' % (PYTHON, self.subject.filename)

        command = shlex.split(cmd_str.encode('utf-8'))

        # encode test input data
        if self.testcase.input:
            input_data = self.testcase.input.encode('utf-8')
        else:
            input_data = ''

        # loop until running the test
        while True:
            process = Popen(command, stdin=PIPE, stdout=PIPE, stderr=PIPE)
            signal.alarm(timeout)
            try:

                # run subject as external process
                process_data = process.communicate(input=input_data)
                stdout, stderr = map(to_unicode, process_data)

                # collect output data
                self.result['stdout'] = stdout # comment out to remove from report
                self.result['stderr'] = stderr # comment out to remove from report

                # reset alarm for future use
                signal.alarm(0)
                process.wait()

                # leave loop
                break

            except CutTimeOut:
                # timeout... give up
                process.terminate()
                self.result['status'] = 'Timeout'
                return

            except OSError:
                # external error... must run again!
                process.terminate()

        # 1. Did an ERROR occur during execution?
        if process.returncode != 0:

            # set default status
            self.result['status'] = 'DefaultError'

            # try use error code from stderr, if available
            # NOTE(review): every STATUS_CODE name is searched, not only
            # the Python error names, and the first hit in dict order
            # wins -- confirm this is intended.
            for error_code in STATUS_CODE:
                if error_code in stderr:
                    self.result['status'] = error_code
                    break

            return

        # 2. Was it a PERFECTLY SUCCESSFUL run?
        preprocessed_stdout = preprocess(stdout, self.testcase.ignore)
        if self.testcase.preprocessed_output == preprocessed_stdout:
            self.result['status'] = 'Success'
            return

        # 3. Was it a NORMALIZED SUCCESSFUL run?
        normalized_stdout = preprocess(stdout, DEFAULT_OPS)
        if self.testcase.normalized_output == normalized_stdout:
            self.result['status'] = 'QuasiSuccess'
            return

        # 4. Doesn't the testcase specify TOKENS?
        if not self.testcase.tokens:
            self.result['status'] = 'Fail'
            return

        # set flags to use with methods re.*
        flags = re.DOTALL|re.UNICODE
        if 'case' in self.testcase.ignore:
            flags = flags|re.IGNORECASE

        # 5. Does the output have ALL EXPECTED TOKENS IN SEQUENCE?
        regex = '(.*)%s(.*)' % ('(.*)'.join(self.testcase.tokens))
        if re.match(regex, preprocessed_stdout, flags=flags):
            self.result['status'] = 'AllTokensSequence'
            return

        # 6. Does the output have the TOKENS MULTISET?
        regex = '|'.join(self.testcase.tokens)
        found = re.findall(regex, preprocessed_stdout, flags=flags)
        found_ms = Counter(found)
        tokens_ms = Counter(self.testcase.tokens)
        if self._includes(found_ms, tokens_ms):
            self.result['status'] = 'AllTokensMultiset'
            return

        # 7. Does the output have a proper SUBSET OF THE TOKENS?
        if found_ms and self._includes(tokens_ms, found_ms):
            self.result['status'] = 'MissingTokens'
            return

        # 8. otherwise...
        self.result['status'] = 'Fail'
          

class TestSubject:
    """A file under test together with the test runs performed on it."""

    def __init__(self, filename):
        self.filename = filename
        self._io_results = ''
        self.analyzer_results = ''
        self.testruns = []
        # caches, both dropped whenever a test run is added
        self._results = None    # list of result dicts
        self._summaries = None  # dict: join_io flag -> list of summaries

    def results(self):
        """Return the result dicts of all test runs, in run order."""
        if self._results is None:
            self._results = [tr.result for tr in self.testruns]

        return self._results

    def add_testrun(self, testrun):
        """Register `testrun` and invalidate the cached views."""
        self.testruns.append(testrun)
        self._results = None
        self._summaries = None

    def feedbacks(self):
        """Return the non-empty 'feedback' entries of the results."""
        return [tr['feedback'] for tr in self.results() if tr.get('feedback')]

    def summaries(self, join_io=False):
        """Return the list of test summaries.

        With `join_io`, the summaries of all io tests are concatenated
        into one string placed first in the list, followed by the
        summaries of the other tests. Otherwise there is one entry per
        test, in run order.
        """
        join_io = bool(join_io)
        if self._summaries is None:
            self._summaries = {}

        if join_io not in self._summaries:
            iosummaries = []
            others = []
            for tr in self.results():
                if join_io and tr['type'] == 'io':
                    iosummaries.append(tr['summary'])
                else:
                    others.append(tr['summary'])

            if iosummaries:
                others.insert(0, ''.join(iosummaries))

            self._summaries[join_io] = others

        return self._summaries[join_io]

    def verdict(self):
        """Return 'success' if every summary character is '.', else 'fail'."""
        return 'success' if all(c == '.' for c in ''.join(self.summaries())) else 'fail'

    def summary(self):
        """Return all test summaries concatenated in run order."""
        status_codes = [tr['summary'] for tr in self.results()]
        return ''.join(status_codes)



class TestCase():
    """A test read from a specification file.

    `test` is the dict describing the test. A test is of type 'script'
    when it says so or when it defines a script/command; otherwise it
    is an 'io' test. Besides the raw fields, the instance holds the
    expected output preprocessed with the test's own `ignore` operators
    (`preprocessed_output`) and with DEFAULT_OPS (`normalized_output`),
    and the preprocessed `tokens`.

    The `test` dict itself is never modified.
    """

    def __init__(self, test):

        # get data from tst.json (explicit nulls count as missing)
        self.name = test.get('name')
        self.input = test.get('input', '')
        self.output = test.get('output') or ''
        self.cases = test.get('cases', None)
        self.sha1 = test.get('sha1', None)
        self.tokens = test.get('tokens') or []
        self.ignore = test.get('ignore') or []
        self.script = test.get('script')
        self.command = test.get('command')
        self.files = test.get('files')
        self.type = test.get('type') or ('script' if self.script or self.command else 'io')
        if self.type == 'script':
            _assert(self.command or self.script, "script tests must have a command")
            _assert(not self.command or not self.script, "script tests cannot have both script and command")
            _assert(not self.input and not self.output, "script tests cannot have input/output")
            _assert(not self.tokens, "script tests cannot have tokens")
            self.script = self.script or self.command

        # convert tokens to a list of strings, if necessary
        if isinstance(self.tokens, basestring):
            self.tokens = self.tokens.split()

        # convert ignore to a list of strings, if necessary
        if isinstance(self.ignore, basestring):
            self.ignore = self.ignore.split()

        # identify tokens within the expected output
        if not self.tokens and '{{' in self.output:
            p = r'{{(.*?)}}' # r'{{.*?}}|\[\[.*?\]\]'
            self.tokens = re.findall(p, self.output)
            # remove tokens' markup from expected output
            self.output = re.sub(p, lambda m: m.group(1), self.output)

        # preprocess individual tokens (into a new list, so that the
        # list owned by the specification is left alone)
        self.tokens = [preprocess(token, self.ignore) for token in self.tokens]

        # set up preprocessed output
        self.preprocessed_output = preprocess(self.output, self.ignore)

        # set up normalized output
        self.normalized_output = preprocess(self.output, DEFAULT_OPS)

        # setup expression to match tokens (so we do it once for all subjects)
        self.tokens_expression = '|'.join(self.tokens) if self.tokens else ''



def preprocess(text, operator_names):
    """Apply the operators named in `operator_names` to `text`.

    'punctuation' implies 'whites', and the list ['all'] stands for
    every operator in OPERATOR. Operators are applied in alphabetical
    order of their names, which leaves 'whites' for last.
    """
    names = operator_names

    # punctuation only works properly together with whites
    if 'punctuation' in names and 'whites' not in names:
        names = names + ['whites']

    # 'all' is shorthand for every known operator
    if names == ['all']:
        names = OPERATOR.keys()

    every_name_known = all(name in OPERATOR for name in names)
    _assert(every_name_known, "unknown operator in ignore")

    # resolve all operators up front, then apply them in sorted order
    pipeline = [OPERATOR[name] for name in sorted(names)]
    result = text
    for operator in pipeline:
        result = operator(result)

    return result


def squeeze_whites(text):
    """Collapse each run of whitespace within a line into one space.

    Leading and trailing whitespace of every line is dropped; the lines
    are joined back with '\\n' (no trailing newline is kept).
    """
    squeezed = []
    for line in text.splitlines():
        words = line.strip().split()
        squeezed.append(' '.join(words))
    return '\n'.join(squeezed)


def remove_linebreaks(text):
    """Return `text` with every line break replaced by one space."""
    # TODO: use carefully! it substitutes linebreaks for ' '
    lines = text.splitlines()
    return ' '.join(lines)


def drop_whites(text):
    """Return `text` with every whitespace character deleted."""
    # TODO: Use carefully! deletes all whites
    #       string.whitespace = '\t\n\x0b\x0c\r '
    deletions = {ord(white): None for white in string.whitespace}
    return text.translate(deletions)


def punctuation_to_white(text):
    """Return `text` with every punctuation character turned into ' '."""
    # WARNING: preprocess() silently adds 'whites' if punctuation is used
    # TODO: Use carefully! it substitutes punctuation for ' '
    # TODO: The specification is wrong!
    #       Should punctuation become spaces? That duplicates some
    #       whites. Should it be deleted? That merges tokens into a
    #       single one. Should the behavior be mixed (punctuation
    #       surrounded by white deleted, the rest replaced)?
    # For now, it should be used with whites to work properly.
    replacements = {}
    for mark in string.punctuation:
        replacements[ord(mark)] = u' '
    return text.translate(replacements)


def squeeze_all_whites(text):
    """Strip `text` and join its lines with single spaces.

    Whitespace inside the lines is kept as it is.
    """
    trimmed = text.strip()
    return ' '.join(trimmed.splitlines())


def strip_blanks(text):
    """Return the whitespace-separated words of `text` joined by ' '."""
    words = text.split()
    return ' '.join(words)


def strip_accents(text):
    """Return `text` reduced to its ASCII characters, accents removed.

    The text is decomposed (NFKD) so that accented letters become a
    base letter followed by combining marks; every non-ASCII character
    is then dropped. UTF-8 encoded byte strings are decoded first. The
    result is always a text (unicode) string.
    """
    if isinstance(text, bytes):
        text = text.decode('utf-8')

    nkfd_form = unicodedata.normalize('NFKD', text)
    only_ascii = nkfd_form.encode('ASCII', 'ignore')
    return only_ascii.decode('ascii')


# Maps the operator names accepted by preprocess() (the names found in a
# test's `ignore` setting) to the text transformations they stand for.
# preprocess() applies the selected operators in alphabetical order of
# these names.
OPERATOR = {
    'case': string.lower,
    'accents': strip_accents,
    'extra_whites': squeeze_whites,
    'linebreaks': remove_linebreaks, # not default
    'punctuation': punctuation_to_white, # not default
    'whites': drop_whites, # not default
}
# Operators used to compute the normalized output of tests and subjects.
DEFAULT_OPS = ['case', 'accents', 'extra_whites']

# Raised by alarm_handler; the TestRun run methods catch it to detect
# that the external process outlived its time limit.
class CutTimeOut(Exception): pass

def alarm_handler(signum, frame):
    # signal handler: turn the alarm into a CutTimeOut exception
    raise CutTimeOut

# Module-level side effect: from import time on, SIGALRM (armed through
# signal.alarm) is delivered to alarm_handler.
signal.signal(signal.SIGALRM, alarm_handler)


class StatusLine:
    """A one-line progress message, rewritten in place on stderr.

    The line is only shown when stdout is a terminal; otherwise set()
    and clear() do nothing.
    """

    # width assumed when `stty size` cannot tell the terminal size
    FALLBACK_WIDTH = 80

    def __init__(self):
        self.lastline = ''
        self.columns = 0
        self.terminal = sys.stdout.isatty()
        if not self.terminal:
            return

        self.stty_size = self._read_stty_size()
        try:
            # `stty size` prints "<rows> <columns>"
            width = int(self.stty_size.split()[1])
        except (IndexError, ValueError):
            # no usable answer (e.g. stdin is not a terminal)
            width = self.FALLBACK_WIDTH

        # keep a margin of 10 columns, but room for at least 1 character
        self.columns = max(width - 10, 1)

    @staticmethod
    def _read_stty_size():
        """Return the output of `stty size` ('' if it cannot be run)."""
        try:
            pipe = os.popen('stty size', 'r')
        except OSError:
            return ''
        try:
            return pipe.read()
        finally:
            pipe.close()

    def _erase(self):
        """Overwrite the line currently shown with blanks."""
        sys.stderr.write('\r%s\r' % ((1+len(self.lastline)) * ' '))

    def set(self, line):
        """Show `line` (in green, cut to the terminal width)."""
        if not self.terminal:
            return
        green = '\033[92m'
        reset = '\033[0m'
        visible = line[:self.columns]
        self._erase()
        sys.stderr.write(green + visible + reset)
        sys.stderr.flush()
        # remember only the visible text: the colour codes take no room
        self.lastline = visible

    def clear(self):
        """Erase the line currently shown, if any."""
        if not self.terminal:
            return
        self._erase()
        sys.stderr.flush()
        self.lastline = ''


def indent(text):
    """Return `text` with every line prefixed by four spaces.

    Lines are joined back with '\\n'; a trailing newline is not kept.
    """
    indented = []
    for line in text.splitlines():
        indented.append("    %s" % line)
    return "\n".join(indented)


def color(color, text):
    """Wrap `text` in the `color` escape sequence and RESET.

    When REDIRECTED is set, both are replaced by empty strings so that
    no escape sequence is added.
    """
    prefix, suffix = ("", "") if REDIRECTED else (color, RESET)
    return prefix + text + suffix


def _report_debug_io(res):
    """Print the debug section of one unsuccessful io test result."""
    name = '`%s`: ' % res['name'] if res['name'] else ''
    print(color(LCYAN, "\n## %s`%s` (`%s`)" % (name, res['status'], res['summary'])))

    # not every result carries all the data (e.g. a run that was given
    # up early has no stdout), so missing entries are shown as empty
    test_input = res.get('input') or ''
    if res['status'] == 'Timeout':
        print('- input')
        print(color(LCYAN, indent(test_input)))
        return

    expected = res.get('output') or ''
    observed = res.get('stdout') or ''

    print('### input:')
    print(color(LCYAN, indent(test_input)))
    print(indent('(%s)' % color(LCYAN, repr(test_input)[1:])))

    print('### output:')
    print(color(LCYAN, indent(expected)))
    print(indent('(%s)' % color(LCYAN, repr(expected)[1:])))

    print('### output diff')
    differ = difflib.Differ()
    for e in differ.compare(observed.splitlines(), expected.splitlines()):
        if e[0] == '+':
            line = indent(color(LGREEN, "+ ") + color('\033[1;32;100m', e[2:]))
        elif e[0] == '-':
            line = indent(color(LRED, "- ") + color('\033[1;31;100m', e[2:]))
        else:
            line = indent(color('\033[2m', e))
        print(line.encode('utf-8'))


def _report_debug_script(res):
    """Print the debug section of one unsuccessful script test result."""
    name = '`%s`: ' % res['name'] if res['name'] else ''
    print(color(LCYAN, "\n## %s`%s` (`%s`)" % (name, res['status'], res['summary'])))

    # a script that timed out has neither stderr nor stdout recorded
    report = res.get('stderr') or res.get('stdout')
    if report:
        print('### test script: `%s`' % res['command'])
        print(indent(report))


def _report_debug(subjects):
    """Print a section for every subject with unsuccessful tests."""
    for subject in subjects:
        if all(r['status'] == 'Success' for r in subject.results()):
            continue

        print(color(LBLUE, "# %s: `%s`" % (subject.filename, subject.summary())))
        for res in subject.results():
            if res['status'] == 'Success':
                continue

            if res['type'] == 'io':
                _report_debug_io(res)
            elif res['type'] == 'script':
                _report_debug_script(res)

        print()


def report_results(subjects, output):
    """Print the results of `subjects` in the `output` format.

    Formats:
      debug    readable sections for the tests that did not succeed
      worker   JSON object mapping each filename to its joined summary
      json     JSON list with the full results of every subject
      tests    JSON tests built from the inputs and the observed
               outputs of one single subject
      summary  one verdict line per subject

    Any other value (such as 'raw') prints nothing.
    """
    # NOTE: a second `elif output == 'debug'` branch (a JSON report of
    # the failed tests) used to follow the first one; it could never
    # run and was dropped.
    if output == 'debug':
        _report_debug(subjects)

    elif output == 'worker':
        data = {
            subject.filename: {
                "summary": "".join(subject.summaries(join_io=True))
            }
            for subject in subjects
        }

        print(data2json(data).encode('utf-8'))

    elif output == 'json':
        data = [{
            'subject': subject.filename,
            'tests': subject.results(),
            'summary': "".join(subject.summaries(join_io=True)),
            'verdict': subject.verdict()
        } for subject in subjects]

        print(data2json(data).encode('utf-8'))

    elif output == 'tests':
        _assert(len(subjects) == 1, "-o tests requires one single subjec")
        data = {
            "tests": [{
                "input": test['input'],
                "output": test['stdout'],
                "cases": test['cases'],
                "sha1": test.get('sha1', None),
            } for test in subjects[0].results()]
        }

        print(data2json(data).encode('utf-8'))

    elif output == 'summary':
        cwd = os.path.basename(os.path.normpath(os.getcwd()))
        print("%s:" % cwd)
        for subject in subjects:
            print("    %s: %s" % (subject.filename, subject.verdict()))

    return


def parse_cli():
    """Parse the command line.

    Returns the tuple `(filenames, timeout, output, spec_file,
    extra_tests)`. A single filename argument that does not name an
    existing path is used as a pattern: every path matching
    `*<argument>*` is selected. `extra_tests` is the list of
    comma-separated patterns given with -x.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '-t', '--timeout',
        type=int,
        default=5,
        help='stop execution at TIMEOUT seconds',
    )
    parser.add_argument(
        '-o', '--output',
        choices=['raw', 'debug', 'json', 'tests', 'summary', 'worker'],
        help='set output type',
    )
    parser.add_argument(
        '-s', '--spec-file',
        type=str,
        default=None,
        help='read specification from file (default: tst.yaml or tst.json)',
    )
    parser.add_argument(
        '-x', '--extra-tests',
        type=str,
        default='*.yaml,*.json',
        help='specify files to search for additional tests',
    )
    parser.add_argument('filename', nargs='*', default=[''])
    args = parser.parse_args()

    # determine filenames
    given = args.filename
    if len(given) != 1:
        filenames = given
    elif os.path.exists(given[0]):
        filenames = [given[0]]
    else:
        filenames = glob.glob('*%s*' % given[0])

    # identify extra tests patterns
    if args.extra_tests:
        extra_tests = args.extra_tests.split(",")
    else:
        extra_tests = []

    return filenames, args.timeout, args.output, args.spec_file, extra_tests


def debug(msg=''):
    """Write a debug line with the current directory name and `msg` to stderr."""
    here = os.path.normpath(os.getcwd())
    dirname = os.path.basename(here)
    text = "tst: debug: dir='%s', msg='%s'" % (dirname, msg)
    print(text.encode('utf-8'), file=sys.stderr)


def tests_report(summary):
    num_tests = len(summary)
    successes = sum(1 for c in summary if c == '.')
    fails = sum(1 for c in summary if c in '*F')
    errors = num_tests - successes - fails

    return num_tests, successes, fails, errors


def main():
    """Run every test found against every selected subject file.

    Tests are read from the specification file and from the files
    matching the extra tests patterns. Subjects are the files given in
    the command line, restricted to the accepted extensions and minus
    the files the specification asks to ignore. Results are reported
    in the chosen output format ('raw' lines go to stderr as soon as
    each subject is done).
    """
    # parse command line data
    filenames, timeout, output_fmt, specfile, patterns2scan = parse_cli()

    # read specification file
    tstjson = tst.read_specification(specfile, verbose=True)
    specfile = tstjson.filename

    # check for files required by specification
    for pattern in tstjson.get('require', []):
        _assert(glob.glob(pattern), "Missing required files for this test: %s" % pattern)

    # identify files to scan for tests
    files2scan = [specfile]
    for pattern in patterns2scan:
        for filename in glob.glob(pattern):
            if filename not in files2scan:
                files2scan.append(filename)

    # read tests
    tests = []
    for filename in files2scan:
        cprint(LCYAN, "Reading %s" % filename, end='')
        try:
            testsfile = JsonFile(filename, array2map="tests")
            cprint(LGREEN, " (%s tests)" % len(testsfile.get("tests", [])))
            tests.extend(testsfile["tests"])
        except KeyError:
            # file has no tests: nothing to add
            pass
        except tst.jsonfile.CorruptedJsonFile:
            _msg = " %s(corrupted json)" % LRED
            cprint(YELLOW, _msg.encode('utf-8'))

    # make sure there are tests
    _assert(tests, '0 tests found')

    # filter filenames based on extensions and ignore_files; the
    # extensions come from the specification or, failing that, from
    # the configured run commands or, failing that too, default to py
    config = tst.get_config()
    extensions = tstjson.get('extensions')
    if not extensions:
        extensions = config['run'].keys() if 'run' in config else ['py']
    ignore_files = tstjson.get('ignore', [])
    filenames = [f for f in filenames if any(f.endswith(e) for e in extensions) and f not in ignore_files]
    filenames.sort()
    _assert(filenames, 'No files to test')

    # read subjects and testcases
    subjects = [TestSubject(fn) for fn in filenames]
    testcases = [TestCase(t) for t in tests]

    # define output_fmt (use user's preference if any)
    output_fmt = output_fmt or config.get('default_output_format') or 'raw'

    # run every test case for every subject
    status = StatusLine()
    tests2go = len(subjects) * len(testcases)
    tests_performed = 0
    for subject in subjects:
        for testcase in testcases:
            status.set('%d of %d tests' % (tests_performed, tests2go))
            testrun = TestRun(subject, testcase)
            testrun.run(timeout=timeout)
            subject.add_testrun(testrun)
            status.clear()
            tests_performed += 1

        if output_fmt == 'raw':
            summary = ''.join(subject.summaries(join_io=True))
            line = '%s %s' % (summary, subject.filename)
            print(line.encode('utf-8'), file=sys.stderr)

    report_results(subjects, output_fmt)


# Script entry point.
if __name__ == '__main__':
    # `--one-line-help` as first argument: print a one-line description
    # of this command and exit without running anything.
    if len(sys.argv) > 1 and sys.argv[1] == '--one-line-help':
        print('run tests specified in tst.json')
        sys.exit(0)

    main()
