#!python

'''
Usage:
    sqs-consume --config <configfile> --source <source_name> [--msgtype (sns | sqs)]
    sqs-consume --version
'''

'''
+mdoc+

sqs-consume connects to an Amazon SQS queue, pulls messages down from that queue, and forwards
each message to a user-defined handler function.

+mdoc+
'''


from logging import Handler
import multiprocessing
import os, sys
import json
import time
import datetime

from multiprocessing import Process, Pool, Queue
import queue
from unittest import result
from collections import namedtuple
import boto3
import docopt
from snap import snap, common
from mercury import journaling as jrnl
from sh import git

from mercury.mlog import mlog, mlog_err

# Release number reported by `--version`; show_version() appends the git revision.
VERSION_NUM = '0.5.2'
# Number of worker processes in the handler pool created by main().
DEFAULT_POOLSIZE = 5
# Message envelope assumed unless `sns` is selected on the command line.
DEFAULT_MESSAGE_TYPE = 'sqs'

# Record returned by the handler wrapper for each processed event: the queue the
# message came from, the message's receipt handle, and the handler's return value.
HandlerStatus = namedtuple('HandlerStatus', ['queue_url', 'rcpt_handle', 'result'])


# Successful HandlerStatus records are placed here by handle_ok() and consumed
# by main(), which deletes the corresponding messages from SQS.
completion_queue = Queue()


def show_version():
    """Return the version string in the form `<VERSION_NUM>[<git revision>]`.

    The revision is the whitespace-stripped output of `git describe --always`.
    """
    revision = git.describe('--always').strip()
    return f'{VERSION_NUM}[{revision}]'


#
# Loading the target function here is suboptimal (and will definitely make us eat slightly more latency), but
# (a) Python refuses to pickle the passed function, making the wrapper-strategy impossible, and
# (b) in context it shouldn't matter too much, because the kinds of events we're processing are S3 file uploads.
# Note that any pipeline needing to process "wire" events directly should use a different strategy.
#
def msg_handler_wrapper_func(queue_url, msg_handler_func, event_type, event_payload, receipt_handle, yaml_config):
    """Run one user handler invocation and package its outcome.

    Initializes the service table from `yaml_config`, calls `msg_handler_func`
    with the event type, payload, receipt handle and that service table, and
    returns a HandlerStatus carrying the queue URL, the receipt handle and the
    handler's return value.
    """
    # Services are initialized on every call, before the handler runs.
    services = snap.initialize_services(yaml_config)
    return HandlerStatus(
        queue_url=queue_url,
        rcpt_handle=receipt_handle,
        result=msg_handler_func(event_type, event_payload, receipt_handle, services),
    )


def handle_ok(handler_status):
    """Success callback for pool tasks: queue the HandlerStatus for later deletion from SQS."""
    # The module-level queue is only read here, so no `global` declaration is required.
    completion_queue.put(handler_status)


def handle_error(exception):
    """Error callback for pool tasks: log the exception raised by a handler invocation."""
    failure_detail = 'Error processing queued message.'
    mlog_err(exception, detail=failure_detail)


def _delete_completed_messages(sqs):
    """Delete from SQS every message whose handler has reported success.

    Drains `completion_queue` without blocking; each HandlerStatus found there
    names the queue URL and receipt handle of a message to delete.
    """
    while not completion_queue.empty():
        try:
            # don't block, we'll get it on the next pass if it is not available yet
            handler_status = completion_queue.get(False)
        except queue.Empty:
            break

        sqs.delete_message(
            QueueUrl=handler_status.queue_url,
            ReceiptHandle=handler_status.rcpt_handle
        )
        mlog(f'@@@ Deleted message with handle: {handler_status.rcpt_handle} at {datetime.datetime.now().isoformat()}')


def main(args):
    """Poll the configured SQS queue forever and dispatch each record to the user handler.

    Parameters:
        args: the docopt argument dictionary built from the module usage string.

    With `--version`, prints the version string and returns. Otherwise reads the
    config file, builds an SQS client for the named source, loads the handler
    function, and loops: receive messages, submit one pool task per record, then
    delete every message whose handler has completed successfully.

    Raises:
        Exception: if the named source is not defined in the config file, or if
            its `auth_method` is neither 'profile' nor 'iam'.
    """
    if args['--version']:
        print(show_version())
        return

    configfile = args['<configfile>']
    yaml_config = common.read_config_file(configfile)

    source_msg_type = 'sns' if args.get('sns') else DEFAULT_MESSAGE_TYPE

    source_name = args['<source_name>']
    if not yaml_config['sources'].get(source_name):
        raise Exception(f'No queue source {source_name} defined. Please check your config file.')

    source_config = yaml_config['sources'][source_name]

    # Create SQS client
    region = source_config['region']
    polling_interval = int(source_config['polling_interval_seconds'])

    auth_method = source_config['auth_method']
    if auth_method == 'profile':
        session = boto3.Session(profile_name=source_config['profile_name'])
    elif auth_method == 'iam':
        session = boto3.Session()
    else:
        raise Exception(f'Unsupported auth method: {auth_method}')

    sqs = session.client('sqs', region_name=region)

    queue_url = common.load_config_var(source_config['queue_url'])
    max_messages = int(source_config['max_msgs_per_cycle'])
    msg_handler_name = source_config['handler']
    event_type = source_config['event_type']

    # Make the user's project importable so the handler module can be loaded.
    project_dir = common.load_config_var(yaml_config['globals']['project_home'])
    sys.path.append(project_dir)

    msg_handler_module = yaml_config['globals']['consumer_module']
    msg_handler_func = common.load_class(msg_handler_name, msg_handler_module)

    current_time = datetime.datetime.now().isoformat()
    print('### checking SQS queue for messages at %s...' % current_time, file=sys.stderr)

    with Pool(DEFAULT_POOLSIZE) as pool:

        # loop forever
        while True:
            # Receive message from SQS queue
            response = sqs.receive_message(
                QueueUrl=queue_url,
                AttributeNames=[
                    'SentTimestamp'
                ],
                MaxNumberOfMessages=max_messages,
                MessageAttributeNames=[
                    'All'
                ],
                # VisibilityTimeout (integer) -- The duration (in seconds) that the received messages
                # are hidden from subsequent retrieve requests after being retrieved by a ReceiveMessage request.
                VisibilityTimeout=30,
                # WaitTimeSeconds (integer) -- The duration (in seconds) for which the call waits for a message
                # to arrive in the queue before returning.
                # If a message is available, the call returns sooner than WaitTimeSeconds. If no messages are
                # available and the wait time expires, the call returns successfully with an empty list of messages.
                WaitTimeSeconds=1
            )

            inbound_msgs = response.get('Messages') or []

            if inbound_msgs:
                msg_count_log = jrnl.CountLog()

                for message in inbound_msgs:

                    with jrnl.counter('handle inbound messages', msg_count_log):
                        receipt_handle = message['ReceiptHandle']
                        message_body = json.loads(message['Body'])

                        if source_msg_type == 'sns':
                            # SNS wraps the event: the real payload is a JSON string under 'Message'.
                            message_body_data = json.loads(message_body['Message'])
                        else:
                            message_body_data = message_body

                        for record in message_body_data['Records']:
                            # eventSource has the form '<prefix>:<service>'; the record's payload
                            # is stored under the service name.
                            service_name = record['eventSource'].split(':')[1]
                            event_payload = record[service_name]

                            pool.apply_async(msg_handler_wrapper_func,
                                             (queue_url, msg_handler_func, event_type, event_payload, receipt_handle, yaml_config),
                                             callback=handle_ok,
                                             error_callback=handle_error)

                mlog(msg_count_log.readout, message_count=msg_count_log.op_data['handle inbound messages'])

            # Drain completions on every cycle, including idle ones. Skipping this step
            # after an empty poll would leave finished messages undeleted until new
            # traffic arrived, so they could be redelivered once the visibility
            # timeout expired.
            _delete_completed_messages(sqs)

            if not inbound_msgs:
                time.sleep(polling_interval)


if __name__ == '__main__':
    # Parse the command line against the usage text in the module docstring, then run.
    main(docopt.docopt(__doc__))
