#!python

'''
Usage:
    sqs-consume --config <configfile> --source <source_name> [--msgtype (sns | sqs)]
    sqs-consume --version
'''

'''
+mdoc+

sqs-consume connects to an Amazon SQS queue, pulls messages down from that queue, and forwards
each message to a user-defined handler function.

+mdoc+
'''


import os, sys
import json
import time
import datetime
from multiprocessing import Process, Pool
from unittest import result
from collections import namedtuple
import boto3
import docopt
from snap import snap, common
from mercury import journaling as jrnl
from sh import git

from mercury.mlog import mlog, mlog_err

# Release number reported by `--version`; show_version() appends the git description.
VERSION_NUM = '0.5.2'
# Number of worker processes in the message-handling pool created by main().
DEFAULT_POOLSIZE = 5
# Message envelope type used unless `sns` is given on the command line.
DEFAULT_MESSAGE_TYPE = 'sqs'

# Value returned by msg_handler_wrapper_func and consumed by handle_ok:
# the handler's result plus what is needed to delete the handled message.
HandlerStatus = namedtuple(
    'HandlerStatus',
    ['sqs_client', 'queue_url', 'rcpt_handle', 'result'],
)


def show_version():
    '''Return the version string in the form "<VERSION_NUM>[<git description>]".

    The bracketed part is the whitespace-trimmed output of
    `git describe --always`, invoked through the `sh` wrapper.
    '''
    described = git.describe('--always')
    commit_id = described.strip()
    return f'{VERSION_NUM}[{commit_id}]'


def msg_handler_wrapper_func(sqs_client, queue_url, msg_handler_func, event_type, event_payload, receipt_handle, service_tbl):
    '''Invoke the user-defined handler and package its result for the success callback.

    msg_handler_func is called as
    msg_handler_func(event_type, event_payload, receipt_handle, service_tbl).
    Its return value is bundled with sqs_client, queue_url and receipt_handle
    into a HandlerStatus, which is everything handle_ok needs to delete the
    message. Exceptions raised by the handler are not caught here.
    '''
    outcome = msg_handler_func(event_type, event_payload, receipt_handle, service_tbl)
    return HandlerStatus(sqs_client, queue_url, receipt_handle, outcome)


def handle_ok(handler_status):
    '''Delete a successfully handled message from its queue.

    handler_status must expose `sqs_client`, `queue_url` and `rcpt_handle`
    (as HandlerStatus does). The handler's own `result` is not inspected.
    '''
    client = handler_status.sqs_client
    delete_request = {
        'QueueUrl': handler_status.queue_url,
        'ReceiptHandle': handler_status.rcpt_handle,
    }
    client.delete_message(**delete_request)


def handle_error(exception):
    '''Log an exception reported for a queued message-handling task.

    Used as the pool's error callback; the exception is passed to mlog_err
    and nothing is re-raised.
    '''
    description = 'Error processing queued message.'
    mlog_err(exception, description)


def _create_sqs_client(source_config):
    '''Create a boto3 SQS client from one queue-source config section.

    Reads `region` and `auth_method` from source_config. With auth_method
    "profile" the session is built from the profile named in `profile_name`;
    with "iam" a boto3.Session() with no arguments is used.

    Raises Exception for any other auth_method.
    '''
    region = source_config['region']
    auth_method = source_config['auth_method']

    if auth_method == 'profile':
        session = boto3.Session(profile_name=source_config['profile_name'])
    elif auth_method == 'iam':
        session = boto3.Session()
    else:
        raise Exception(f'Unsupported auth method: {auth_method}')

    return session.client('sqs', region_name=region)


def _extract_event_payloads(message, source_msg_type):
    '''Yield the event payload of each record carried by one received message.

    The message body is parsed as JSON. For source_msg_type "sns" the body is
    treated as an envelope whose "Message" field holds the event document as
    a JSON string; otherwise the body itself is the event document.

    This is a generator, so records are parsed and yielded one at a time: a
    malformed record raises only after the preceding ones have been yielded.
    '''
    event_document = json.loads(message['Body'])
    if source_msg_type == 'sns':
        event_document = json.loads(event_document['Message'])

    for record in event_document['Records']:
        # The payload is stored under the key named by the second ':'-separated
        # part of eventSource (presumably values such as "aws:s3" -- confirm
        # against the event producers in use).
        service_name = record['eventSource'].split(':')[1]
        yield record[service_name]


def main(args):
    '''Poll an SQS queue forever and dispatch each record to the configured handler.

    args is the docopt argument dictionary. With `--version` the version
    string is printed and the function returns. Otherwise the YAML config
    named by `<configfile>` is loaded and the entry `sources[<source_name>]`
    supplies: region, auth_method (plus profile_name for "profile"),
    polling_interval_seconds, queue_url, max_msgs_per_cycle, handler and
    event_type. `globals.project_home` is appended to sys.path and the handler
    is loaded from `globals.consumer_module`.

    Each record of each received message is handed to a worker pool of
    DEFAULT_POOLSIZE processes. When a handler call succeeds the message is
    deleted from the queue; when it raises, the error is logged.

    Raises Exception when the source is not defined in the config or the
    auth method is unsupported. Does not return in normal operation.
    '''
    if args['--version']:
        print(show_version())
        return

    configfile = args['<configfile>']
    yaml_config = common.read_config_file(configfile)

    source_msg_type = DEFAULT_MESSAGE_TYPE
    if args.get('sns'):
        source_msg_type = 'sns'

    source_name = args['<source_name>']
    if not yaml_config['sources'].get(source_name):
        raise Exception(f'No queue source "{source_name}" defined. Please check your config file.')

    service_tbl = snap.initialize_services(yaml_config)
    source_config = yaml_config['sources'][source_name]

    polling_interval = int(source_config['polling_interval_seconds'])

    print('############## SPINNING UP SQS client', file=sys.stderr)
    sqs = _create_sqs_client(source_config)

    queue_url = common.load_config_var(source_config['queue_url'])
    max_messages = int(source_config['max_msgs_per_cycle'])
    msg_handler_name = source_config['handler']
    event_type = source_config['event_type']

    # The handler module lives in the user's project, so that directory must be
    # importable before the handler is loaded.
    project_dir = common.load_config_var(yaml_config['globals']['project_home'])
    sys.path.append(project_dir)

    msg_handler_module = yaml_config['globals']['consumer_module']
    msg_handler_func = common.load_class(msg_handler_name, msg_handler_module)

    def delete_handled_message(handler_status):
        # Pool callbacks run in the parent process, so the parent's own client
        # is attached here instead of being shipped through the worker.
        handle_ok(handler_status._replace(sqs_client=sqs))

    current_time = datetime.datetime.now().isoformat()
    print('### checking SQS queue for messages at %s...' % current_time, file=sys.stderr)

    with Pool(DEFAULT_POOLSIZE) as pool:

        # loop forever
        while True:
            response = sqs.receive_message(
                QueueUrl=queue_url,
                AttributeNames=['SentTimestamp'],
                MaxNumberOfMessages=max_messages,
                MessageAttributeNames=['All'],
                # Seconds a received message stays hidden from subsequent
                # receive requests.
                VisibilityTimeout=30,
                # Seconds the call waits for a message to arrive before
                # returning an empty result.
                WaitTimeSeconds=1,
            )

            inbound_msgs = response.get('Messages') or []
            if not inbound_msgs:
                time.sleep(polling_interval)
                continue

            msg_count_log = jrnl.CountLog()

            for message in inbound_msgs:

                with jrnl.counter('handle inbound messages', msg_count_log):
                    receipt_handle = message['ReceiptHandle']
                    current_time = datetime.datetime.now().isoformat()
                    print('### spawning message processor at %s...' % current_time, file=sys.stderr)

                    # NOTE(review): all records of a message share its receipt
                    # handle, so the message is deleted as soon as the first
                    # record's handler succeeds.
                    for event_payload in _extract_event_payloads(message, source_msg_type):
                        # Arguments and return values of pool tasks are pickled,
                        # and boto3 clients are not reliably picklable; pass None
                        # for sqs_client and let the callback supply the client.
                        pool.apply_async(
                            msg_handler_wrapper_func,
                            (None, queue_url, msg_handler_func, event_type,
                             event_payload, receipt_handle, service_tbl),
                            callback=delete_handled_message,
                            error_callback=handle_error,
                        )

                        mlog('### Queued message-handling subprocess.')

            mlog(msg_count_log.readout, message_count=msg_count_log.op_data['handle inbound messages'])


if __name__ == '__main__':
    # Parse the command line against the usage text in the module docstring
    # (the first string literal of this file), then run the consumer.
    args = docopt.docopt(__doc__)
    main(args)
