#!python

'''
Usage:
    segmentr --config <configfile> --job-name <jobname> --jobfile <filename> --segments=<segname>,... [--segment-params=<n,v>...] [--limit=<max_records>]
    segmentr --config <configfile> --job-name <jobname> --jobfile <filename> --segment-count=<number> [--segment-params=<n,v>...] [--limit=<max_records>]

'''

''' +mdoc+

segmentr performs a user-designated function on a (potentially very large) collection of records
by splitting those records into segments and processing those segments in parallel.

'''


import os, sys
import uuid
import time
from collections import namedtuple
import threading
import queue
from multiprocessing import Process, Pipe, JoinableQueue
import json
import docopt
import yaml
from snap import snap, common
from uhashring import HashRing
from mercury import utils


def segment_datafile(filename, segment_size, limit = -1):
    '''Lazily read a JSON-lines file and yield its records in fixed-size batches.

    Each non-blank line of the file is decoded with json.loads(). Records are
    collected into lists of up to <segment_size> entries; the final list may be
    shorter if the file runs out. Blank (whitespace-only) lines are skipped and
    do not count toward the batch size.

    Parameters:
        filename:     path of the file to read, one JSON document per line.
        segment_size: maximum number of records per yielded list; must be >= 1.
        limit:        maximum number of *segments* (yielded lists) to produce.
                      Any value that the running segment count never reaches
                      (such as the default of -1) means "no limit"; 0 yields nothing.
                      NOTE(review): the CLI usage text labels --limit as
                      <max_records>, but this function counts segments, not
                      records -- confirm which one is intended.

    Yields:
        Lists of decoded records.

    Raises:
        Exception: if segment_size is less than 1 (raised on first iteration,
                   since this is a generator).
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    '''

    print(f'segmenting datafile with limit {limit}')

    if segment_size < 1:
        raise Exception('segment size must be greater than 0.')

    with open(filename) as f:
        segment_count = 0
        while segment_count != limit:
            segment = []
            while len(segment) < segment_size:
                line = f.readline()
                if not line:
                    # readline() returns '' only at end of file; stop reading.
                    break
                if not line.strip():
                    # A blank line is not a record; json.loads() would reject it.
                    continue
                segment.append(json.loads(line))

            print(f'segment size is {len(segment)}')

            # An empty batch means the file is exhausted.
            if not segment:
                break
            segment_count += 1

            yield segment


# Immutable description of one configured job:
#   processor_func -- the callable that is run for the job
#   params         -- dict of the job's name/value parameters
JobSpec = namedtuple('JobSpec', ['processor_func', 'params'])


def load_jobs(proc_module_name, yaml_config: dict) -> dict:
    '''Build a table of JobSpec objects from the "jobs" section of the config.

    For every entry under yaml_config['jobs'], the name given by its
    "processor" key is resolved via common.load_class() against
    <proc_module_name>, and the optional "job_params" list of
    {name: ..., value: ...} entries is flattened into a plain dict.

    Parameters:
        proc_module_name: module name handed to common.load_class().
        yaml_config:      parsed configuration; must contain a "jobs" mapping.

    Returns:
        dict mapping each job name to its JobSpec.

    Raises:
        KeyError: if "jobs", a job's "processor", or a parameter's
                  "name"/"value" key is missing.
    '''
    registry = {}

    for alias, settings in yaml_config['jobs'].items():
        func_name = settings['processor']

        # Later entries with a repeated name overwrite earlier ones.
        job_params = dict(
            (entry['name'], entry['value'])
            for entry in settings.get('job_params', [])
        )

        loaded_func = common.load_class(func_name, proc_module_name)
        registry[alias] = JobSpec(processor_func=loaded_func, params=job_params)

    return registry
 

def main(args):
    '''Run the named job over a datafile using one worker process per segment.

    Steps, in order:
      1. Load the config file and initialize the service registry.
      2. Look up the requested job in the config's job table.
      3. Derive the segment names, either from --segments or by generating
         "segment_<i>" for i in range(--segment-count).
      4. Start one Process per segment, each given its own JoinableQueue,
         the service registry, and the merged parameter dict.
      5. Read the datafile in batches and route every record to a worker's
         queue through a HashRing keyed on the segment names.
      6. Close every queue, whether or not step 5 succeeded.

    Parameters:
        args: docopt-style argument mapping. Keys read here: '<configfile>',
              '<jobname>', '<filename>', '--limit', '--segments',
              '--segment-count' and '--segment-params'.

    Raises:
        Exception: if no job is registered under the requested job name.

    NOTE(review): worker processes are started but never joined or signalled
    here; presumably the processor function decides when to exit -- confirm.
    '''

    configfile = args['<configfile>']
    yaml_config = common.read_config_file(configfile)
    service_tbl = snap.initialize_services(yaml_config)
    services = common.ServiceObjectRegistry(service_tbl)

    limit = -1
    if args.get('--limit'):
        limit = int(args['--limit'])

    job_name = args['<jobname>']

    processor_module_name = yaml_config['globals']['processor_module']

    job_tbl = load_jobs(processor_module_name, yaml_config)
    job_spec = job_tbl.get(job_name)
    if not job_spec:
        raise Exception(f'No job registered under the alias {job_name}. Please check your configfile.')

    if args.get('--segments'):
        segment_names = args['--segments'][0].split(',')
    else:
        num_segments = int(args['--segment-count'])
        segment_names = [f'segment_{i}' for i in range(num_segments)]

    # Records are read in batches whose size equals the number of segments.
    segment_size = len(segment_names)
    if args.get('--segment-params'):
        segment_params = utils.parse_cli_params(args['--segment-params'])
    else:
        segment_params = dict()

    # Parameters from the job's config entry override same-named CLI parameters.
    segment_params.update(job_spec.params)

    segment_table = dict()

    for name in segment_names:
        segment_config = dict()
        segment_config.update(segment_params)
        segment_config['segment_id'] = name
        send_channel = JoinableQueue()

        segment_config['send_channel'] = send_channel
        # NOTE(review): the worker receives the shared segment_params dict, not
        # this per-segment segment_config, so 'segment_id' never reaches the
        # processor function -- confirm that is intended.
        segment_config['worker'] = Process(target=job_spec.processor_func,
                                           args=(send_channel,
                                                 services,
                                                 segment_params))
        segment_config['worker'].start()
        segment_table[name] = segment_config

    try:
        segment_worker_pool = HashRing(segment_table)
        json_filename = args['<filename>']

        # Read the input records from the datafile, N records at a time
        # (where N is our segment size).
        for segment in segment_datafile(json_filename, segment_size, limit):
            print(f'+++reading a segment of size {segment_size}')

            # Route each record in the batch to the queue of the worker
            # selected by the hash ring.
            print('+++ stepping through segment...')
            for record in segment:
                # NOTE(review): built-in hash() of a str is salted per process
                # unless PYTHONHASHSEED is set, so the record-to-worker mapping
                # is stable within a run but not across runs -- confirm that
                # cross-run stability is not required.
                segment_worker_node_id = segment_worker_pool.get_node(hash(json.dumps(record)))
                segment_table[segment_worker_node_id]['send_channel'].put(record)

        print('done.')

    finally:
        # Always tell each queue that this process will put no more data on
        # it, including when reading or dispatching failed part-way through.
        for segment_config in segment_table.values():
            segment_config['send_channel'].close()


if __name__ == '__main__':
    # Script entry point: parse the command line against the usage text held
    # in the module docstring, then pass the parsed arguments to main().
    args = docopt.docopt(__doc__)
    main(args)