#!python

'''
Usage:
    sqs-consume --config <configfile> --source <source_name>
    sqs-consume --version
'''

'''
+mdoc+

sqs-consume connects to an Amazon SQS queue, pulls messages down from that queue, and forwards
each message to a user-defined handler function.

+mdoc+
'''


import os, sys
import json
import time
import datetime
from multiprocessing import Process
import boto3
import docopt
from snap import snap, common
from mercury import journaling as jrnl
from sh import git

from mercury.mlog import mlog, mlog_err

# Release number of this script; show_version() combines it with the git revision.
VERSION_NUM = '0.5.2'

def show_version():
    '''Return the version string in the form "<VERSION_NUM>[<git revision>]".

    The revision is the output of `git describe --always` (run through the `sh`
    wrapper) with surrounding whitespace trimmed.
    '''
    revision = git.describe('--always').strip()
    return f'{VERSION_NUM}[{revision}]'


def _create_sqs_client(source_config, region):
    '''Build a boto3 SQS client for one queue source.

    source_config: the dict for one entry under `sources` in the config file.
        Its `auth_method` must be 'profile' (session built from the boto3 profile
        named by `profile_name`) or 'iam' (session built with no arguments).
    region: region name passed to the client.

    Raises Exception if `auth_method` has any other value.
    '''
    auth_method = source_config['auth_method']

    if auth_method == 'profile':
        session = boto3.Session(profile_name=source_config['profile_name'])
    elif auth_method == 'iam':
        session = boto3.Session()
    else:
        raise Exception(f'Unsupported auth method: {auth_method}')

    return session.client('sqs', region_name=region)


def _extract_event_payloads(message):
    '''Return the list of event payloads carried by one SQS message.

    The message 'Body' is parsed as JSON and must contain a 'Records' list. For each
    record, the second ':'-separated field of its 'eventSource' names the key that
    holds the payload (for example, an eventSource of 'aws:s3' selects record['s3']).

    Raises whatever the parsing raises (JSON decode error, KeyError, IndexError) when
    the message does not have this shape.
    '''
    message_body = json.loads(message['Body'])

    payloads = []
    for record in message_body['Records']:
        service_name = record['eventSource'].split(':')[1]
        payloads.append(record[service_name])
    return payloads


def _dispatch_message(sqs, queue_url, message, msg_handler_func, event_type, service_tbl):
    '''Start one handler subprocess per record in `message`, then delete the message.

    Each subprocess runs
    msg_handler_func(event_type, event_payload, receipt_handle, service_tbl).

    All records are parsed before any subprocess is started, so a malformed message
    spawns nothing and is left on the queue. The message is deleted exactly once,
    after every record has been dispatched (a message with an empty 'Records' list is
    deleted without spawning anything).

    NOTE: deletion happens as soon as the subprocesses are started; it does not wait
    for the handlers to finish.

    Returns the list of Process objects that were started (possibly empty).
    '''
    receipt_handle = message['ReceiptHandle']
    current_time = datetime.datetime.now().isoformat()
    print('### spawning message processor at %s...' % current_time, file=sys.stderr)

    started = []
    try:
        for event_payload in _extract_event_payloads(message):
            p = Process(target=msg_handler_func,
                        args=(event_type, event_payload, receipt_handle, service_tbl))
            p.start()
            started.append(p)
            print('### Queued message-handling subprocess.', file=sys.stderr)

        # Delete received message from queue -- once per message, not once per record.
        sqs.delete_message(
            QueueUrl=queue_url,
            ReceiptHandle=receipt_handle
        )
        print('### Received and deleted message with receipt: %s' % receipt_handle, file=sys.stderr)

    except Exception as err:
        # Deliberately broad: one bad message must not bring down the consumer loop.
        print('!!! Error processing message with receipt: %s' % receipt_handle, file=sys.stderr)
        print('', file=sys.stderr)
        print(err, file=sys.stderr)

    return started


def main(args):
    '''Entry point: read the config, connect to SQS, and poll the queue forever.

    args: the docopt-parsed argument dict. Reads '--version', '<configfile>' and
        '<source_name>'.

    With '--version' set, prints the version string and returns. Otherwise the
    function does not return: it repeatedly receives messages from the configured
    queue and hands each one to _dispatch_message, sleeping for the configured
    polling interval whenever the queue yields nothing.

    Config keys read: from sources[<source_name>]: region, polling_interval_seconds,
    auth_method (and profile_name), queue_url, max_msgs_per_cycle, handler,
    event_type; from globals: project_home, consumer_module.

    Raises Exception if the named source is not defined in the config file or its
    auth method is unsupported.
    '''
    if args['--version']:
        print(show_version())
        return

    configfile = args['<configfile>']
    yaml_config = common.read_config_file(configfile)

    source_name = args['<source_name>']
    if not yaml_config['sources'].get(source_name):
        raise Exception(f'No queue source "{source_name}" defined. Please check your config file.')

    service_tbl = snap.initialize_services(yaml_config)
    source_config = yaml_config['sources'][source_name]

    # Create SQS client
    region = source_config['region']
    polling_interval = int(source_config['polling_interval_seconds'])

    print('############## SPINNING UP SQS client', file=sys.stderr)
    sqs = _create_sqs_client(source_config, region)

    queue_url = common.load_config_var(source_config['queue_url'])
    max_messages = int(source_config['max_msgs_per_cycle'])
    msg_handler_name = source_config['handler']
    event_type = source_config['event_type']

    # Make the project directory importable so the handler module can be loaded.
    project_dir = common.load_config_var(yaml_config['globals']['project_home'])
    sys.path.append(project_dir)

    msg_handler_module = yaml_config['globals']['consumer_module']
    msg_handler_func = common.load_class(msg_handler_name, msg_handler_module)

    child_procs = []

    current_time = datetime.datetime.now().isoformat()
    print('### checking SQS queue for messages at %s...' % current_time, file=sys.stderr)

    # loop forever
    while True:
        # Drop handles to children that have already exited; without this the list
        # would grow for as long as the consumer runs.
        child_procs = [p for p in child_procs if p.is_alive()]

        # Receive message from SQS queue
        response = sqs.receive_message(
            QueueUrl=queue_url,
            AttributeNames=[
                'SentTimestamp'
            ],
            MaxNumberOfMessages=max_messages,
            MessageAttributeNames=[
                'All'
            ],
            # VisibilityTimeout (integer) -- The duration (in seconds) that the received
            # messages are hidden from subsequent retrieve requests after being retrieved
            # by a ReceiveMessage request.
            VisibilityTimeout=30,
            # WaitTimeSeconds (integer) -- The duration (in seconds) for which the call
            # waits for a message to arrive in the queue before returning. If a message is
            # available, the call returns sooner than WaitTimeSeconds. If no messages are
            # available and the wait time expires, the call returns successfully with an
            # empty list of messages.
            WaitTimeSeconds=1
        )

        inbound_msgs = response.get('Messages') or []
        if not inbound_msgs:
            time.sleep(polling_interval)
            continue

        # jrnl.counter is used below as a context manager. The original source also
        # carried this sample of decorator-style usage (kept for reference, unverified):
        #     @counter('call annotated function', global_count_log)
        #     def some_func():
        #         pass
        msg_count_log = jrnl.CountLog()

        for message in inbound_msgs:
            with jrnl.counter('handle inbound messages', msg_count_log):
                child_procs.extend(
                    _dispatch_message(sqs, queue_url, message,
                                      msg_handler_func, event_type, service_tbl)
                )

        mlog(msg_count_log.readout, message_count=msg_count_log.op_data['handle inbound messages'])


if __name__ == '__main__':
    # Parse the command line against the usage text in the module docstring,
    # then hand the resulting argument dict to main().
    main(docopt.docopt(__doc__))
