#!python

'''
Usage:
    quasr [-p] --config <configfile> --job <jobname> [--params=<name:value>...]
    quasr [-p] --config <configfile> --jobs [--file <filename>]
    quasr --config <configfile> --list [-v]

Options:
    -p --preview      preview mode
    -v --verbose      verbose job listing
'''

# QUASR: Quality Assurance SQL Runner

import os, sys
import json
import re
import docopt
from collections import namedtuple
from snap import snap, common
from mercury import journaling as jrnl
from uuid import UUID


PARAM_DELIMITER = ':'
PARAM_TOKEN_DELIMITER = ','

integer_rx = re.compile(r'^[0-9]+$')
float_rx = re.compile(r'^[0-9]*.[0-9]+$')
string_rx = re.compile(r'^[a-zA-Z0-9_]+$')
uuid_rx = re.compile(r'^[a-f0-9]{8}-?[a-f0-9]{4}-?4[a-f0-9]{3}-?[89ab][a-f0-9]{3}-?[a-f0-9]{12}\Z', re.I)


TypeFormat = namedtuple('TypeFormat', 'class_object regex')

type_formats = {
    'str': TypeFormat(class_object=str, regex=string_rx),
    'float': TypeFormat(class_object=float, regex=float_rx),
    'int': TypeFormat(class_object=int, regex=integer_rx),
    'uuid': TypeFormat(class_object=UUID, regex=uuid_rx)
}

Slot = namedtuple('Slot', 'name datatype')
 # We are standardizing on three categories of DataCondition: "note", "flag", and "error".
 # A "note" indicates something which the reader of the report may want to know about,
 # but is not consequential in terms of data quality, and therefore is safe to ignore.
 #
 # A "flag" is a condition which requires human attention, but does not necessarily mean
 # that the dataset under analysis contains an error.
 #
 # An "error" means that some property of the analyzed dataset violates expected standards. 

DataCondition = namedtuple('DataCondition', 'category message')

class QaSqlNode(object):
    def __init__(self, query_template):
        self.sql_template = query_template
        self.input_slots = {}
        self.output_slots = {}


    def add_input_slot(self, name, classname):
        self.input_slots[name] = Slot(name, classname)


    def add_output_slot(self, name, classname):
        self.output_slots[name] = Slot(name, classname)


    def validate_input_types(self, **kwargs):
        for name, _ in self.input_slots.items():
            if kwargs.get(name) is None:
                raise Exception('No data supplied for input slot "%s".' % name)

        for name, value in kwargs.items():
            input_slot = self.input_slots.get(name)
            if not input_slot:
                raise Exception('No input slot "%s" registered with this QA node.')

            if value.__class__.__name__ != input_slot.datatype:
                can_convert = True
                type_format = type_formats.get(input_slot.datatype)
                print(f'attempting to convert input value {value} to type {type_format.class_object}...')
                try:
                    type_format.class_object(str(value))
                except:
                    can_convert = False

                if not can_convert:
                    raise Exception('The input parameter "%s" is not compatible with the registered slot type "%s".' 
                                    % (name, input_slot.datatype))

    def validate_outputs(self, **kwargs):
        for slot_name, output_slot in self.output_slots.items():
            if not kwargs.get(slot_name):
                raise Exception('This QA job node contains an output slot "%s", but no such output field was generated.' % slot_name)
            

    def generate_query(self, **kwargs):
        kwreader = common.KeywordArgReader(*self.input_slots.keys())
        kwreader.read(**kwargs)
        self.validate_input_types(**kwargs)
        return self.sql_template.format(**kwargs)


class QaJobRunner(object):
    def __init__(self, yaml_config):
        self.qa_nodes = {}
        self.config = yaml_config
        print('### initializing services...', file=sys.stderr)
        self.service_registry = common.ServiceObjectRegistry(snap.initialize_services(yaml_config))
        print('### services spun up.', file=sys.stderr)
        self.analyzers = {}

        job_config_group = yaml_config['jobs']
        for nodename in job_config_group:
            template = self._load_template(job_config_group[nodename]['sql_template'])
            inputs =  self._load_input_values(job_config_group[nodename]['inputs'] or [])
            outputs = self._load_output_values(job_config_group[nodename]['outputs'] or [])

            if job_config_group[nodename].get('analyzer'):
                logic_module_name = self.config['globals']['qa_logic_module']
                analyzer_func = common.load_class(job_config_group[nodename].get('analyzer'),
                                                  logic_module_name)
                self.analyzers[nodename] = analyzer_func                                                  
            node = QaSqlNode(template)

            for input in inputs:                
                node.add_input_slot(input['name'], input['type'])

            for output in outputs:                
                node.add_output_slot(output['name'], output['type'])

            self.qa_nodes[nodename] = node

    def _lookup_analyzer(self, jobname):
        return self.analyzers.get(jobname)


    def _load_input_values(self, dict_array):
        input_values = []
        for input in dict_array:
            input_values.append({'name': input['name'], 'type': input['type']})
        return input_values


    def _load_output_values(self, dict_array):
        output_values = []
        for output in dict_array:
             output_values.append({'name': output['name'], 'type': output['type']})
        return output_values


    def _load_template(self, template_name):        
        if template_name.startswith('(') and template_name.endswith(')'):
            # a template name in parentheses means that the template itself 
            # is in the config file
            template_name = template_name.lstrip('(').rstrip(')')
            if not self.config.get('templates'):
                raise Exception('bad internal template ref: the configfile has no "templates" section.')
            if not self.config['templates'].get(template_name):
                raise Exception('no template registered under the alias "%s" in the "templates" config section.'
                                % template_name)
            return self.config['templates'][template_name]
        else:
            # otherwise the template is in the module referenced in the 
            # globals section as template_module
            
            # this is not really loading a class -- the underlying logic just pulls the named
            # attribute of the loaded Python module using getattr(). 
            # TODO: create a load_module_object() method, so as not to mislead
            return common.load_class(template_name, self.config['globals']['template_module'])
        

    def _lookup_node(self, job_name):
        node = self.qa_nodes.get(job_name)
        if not node:
            raise Exception('no job registered under alias "%s".' % job_name)
        return node


    def get_sql_template(self, job_name):
        node = self._lookup_node(job_name)        
        return node.sql_template


    def inputs(self, job_name):
        node = self._lookup_node(job_name)
        result = {}
        for name, slot in node.input_slots.items():
            result[name] = slot.datatype
        return result


    def convert_input_params(self, job_name, **kwargs):
        node = self._lookup_node(job_name)
        job_params = {}
        slots = node.input_slots
        for param_name, param in kwargs.items():
            input_slot = slots.get(param_name)
            if not input_slot:
                raise Exception('no input slot registered with name "%s".' % param_name)
            
            type_format = type_formats.get(input_slot.datatype)
            if not type_format:
                raise Exception('input datatype "%s" is not supported.' % input_slot.datatype)
            
            if not type_format.regex.match(param):
                raise Exception('the format of input param "%s" (value: "%s") does not match type %s.'
                                % (param_name, param, input_slot.datatype))

            job_params[param_name] = type_format.class_object(param)

        return job_params


    def validate_outputs(self, raw_outputs, job_name):
        missing_fields = []
        node = self._lookup_node(job_name)
        for key, _ in node.output_slots.items():
            if raw_outputs.get(key) is None:
                missing_fields.append(key)
        
        if len(missing_fields):
            raise Exception('the generated outputs for job node "%s" are missing the fields: %s' 
                            % (job_name, missing_fields))


    def analyze(self, raw_outputs, job_name):
        result = []
        analyzer_func = self._lookup_analyzer(job_name)
        if analyzer_func:
            for condition in analyzer_func(raw_outputs, self.service_registry):
                result.append({'type': condition.category, 'message': condition.message})
        return result


    def run(self, job_name, preview_mode=False, **kwargs):
        # execute query -- must return record generator
        # call the job node's generate_outputs() method and pass in generator
        # if there is an analyzer attached to this job, pass the outputs to it
        # return the raw and analyzed outputs
        node = self._lookup_node(job_name)
        query = node.generate_query(**kwargs)
        result = {}
        if preview_mode:
            print(query)
        else:
            print('### running QA job "%s".' % job_name, file=sys.stderr)
            job_config_group = self.config['jobs']
            exec_func_name = job_config_group[job_name]['executor_function']
            output_builder_func_name = job_config_group[job_name]['builder_function']
            # this is not a class being loaded, but a pythonic generator function
            query_executor = common.load_class(exec_func_name,  self.config['globals']['qa_logic_module'])

            # load the function which builds outputs from the query results
            output_builder_func = common.load_class(output_builder_func_name, self.config['globals']['qa_logic_module'])
            raw_outputs =  output_builder_func(query_executor(query, self.service_registry, **kwargs))
            self.validate_outputs(raw_outputs, job_name)
            result['job_output'] = raw_outputs
            result['analysis_ouptut'] = self.analyze(raw_outputs, job_name)

        return result


def read_params_from_args(arg_dict):
    params = {}
    if not arg_dict.get('--params'):
        return params

    param_string = args['--params'][0]
    param_tokens = [pstring.lstrip().rstrip() for pstring in param_string.split(PARAM_TOKEN_DELIMITER)]
    for token in param_tokens:
        if PARAM_DELIMITER not in token:
            raise Exception('the optional --params string must be in the form of name1:value1,...nameN:valueN')
        name = token.split(PARAM_DELIMITER)[0].lstrip().rstrip()
        value = token.split(PARAM_DELIMITER)[1].lstrip().rstrip()
        params[name] = value
    return params


def parse_param_string(input_param_string):
    '''parse an input string of the format 

    <name1>:<value1>,<name2>:<value2>,...

    and return a dictionary of name-value pairs
    '''

    params = {}
    tokens = input_param_string.split(',')
    for tok in tokens:
        param_tokens = tok.split(':')
        if len(param_tokens) != 2:
            raise Exception('Each parameter string in a job file must be of the format <name>:<value>')

        name = param_tokens[0]
        value = param_tokens[1]
        params[name] = value

    return params


def main(args):    
    configfile = args['<configfile>']
    yaml_config = common.read_config_file(configfile)
    preview_mode = False
    timelog = jrnl.TimeLog()

    if args['--preview'] == True:
        preview_mode = True
        print('### Running in preview mode.', file=sys.stderr)

    if args['--list']:
        jobs = []
        if args['--verbose']:
            for job in yaml_config['jobs']:
                input_params = yaml_config['jobs'][job]['inputs'] or []
                output_params = yaml_config['jobs'][job]['outputs'] or []
                jobs.append({
                    'name': job,
                    'inputs': input_params,
                    'outputs': output_params
                })
        else:
            for job in yaml_config['jobs']:
                jobs.append(job)

        print(json.dumps(jobs))
        return

    project_dir = common.load_config_var(yaml_config['globals']['project_home'])
    sys.path.append(project_dir)
    job_runner = QaJobRunner(yaml_config)

    if args['--job']:
        jobname = args['<jobname>']       
        input_params = read_params_from_args(args)

        print('### running in single-job mode.', file=sys.stderr)
        print('### target job: %s' % jobname, file=sys.stderr)
        
        job_params = job_runner.convert_input_params(jobname, **input_params)
        print('### job parameters:',  file=sys.stderr)
        print(job_params, file=sys.stderr)

        if preview_mode:
            job_runner._lookup_node(jobname).validate_input_types(**input_params)
            print(job_runner.get_sql_template(jobname).format(**input_params))
        else:
            print(json.dumps(job_runner.run(jobname, **job_params)))
        return

    elif args['--jobs']:
        job_strings = []
        if args['--file']:
            jobfile_name = args['<filename>']
            with open(jobfile_name) as f:
                for line in f:
                    job_strings.append(line.lstrip().rstrip())
        else:
            for line in sys.stdin:
                if sys.hexversion < 0x03000000:
                    line = line.decode('utf-8')
                job_strings.append(line.lstrip().rstrip())

            for job_string in job_strings:
                # Each job string is in the format <job_name> <param1>:<value1>,...
                job_tokens = job_string.split(' ')
                if len(job_tokens) != 2:
                    print('Each job string must be in the format <job_name> <param1>:<value1>,...')
                    return

                jobname = job_tokens[0]
                param_string = job_tokens[1]
                raw_params = parse_param_string(param_string)
                input_params = job_runner.convert_input_params(jobname, **raw_params)

                job_runner._lookup_node(jobname).validate_input_types(**input_params)
                if preview_mode:
                    print(job_runner.get_sql_template(jobname).format(**input_params))
                else:
                    with jrnl.stopwatch('qa_job', timelog):
                        print(json.dumps(job_runner.run(jobname, **input_params)))
                    print(common.jsonpretty(timelog.readout), file=sys.stderr)
        return


if __name__ == '__main__':
    args = docopt.docopt(__doc__)
    main(args)

