#!/Users/bsummers/.pyenv/versions/3.6.2/bin/python3.6
# -*- coding: utf-8 -*-
""" TcEx Data Module """
import base64
import collections
import json
import operator
import os
import re
import sys
import traceback
import urllib3

from past.builtins import basestring
from six import string_types

from tcex import TcEx

# Disable urllib3 warning messages (intended to silence SSL warnings) for this process.
urllib3.disable_warnings()

# Python 2 only: reload sys so that setdefaultencoding is available again, then
# make utf-8 the default encoding.
if sys.version_info[0] == 2:
    reload(sys)
    sys.setdefaultencoding('utf-8')

# Module-level TcEx instance used by the TcData class and the __main__ block below.
tcex = TcEx()

# These libraries are not listed as dependencies because they are only required for
# local testing. When an import fails, tcex.exit(1, ...) is called with an install hint.
try:
    from deepdiff import DeepDiff
except ImportError:
    tcex.exit(1, 'Could not import DeepDiff module (try "pip install deepdiff").')
try:
    import jmespath
except ImportError:
    tcex.exit(1, 'Could not import jmespath module (try "pip install jmespath").')

# Command line arguments. Supported --action values: clear, stage, validate
# (see the dispatch in the __main__ block).
tcex.parser.add_argument('--action', help='The action to perform.', required=True)
tcex.parser.add_argument(
    '--data_file', help='The file containing the data to stage.', required=True)
# --data_owner is optional; tc_clear/tc_stage use it as the fallback owner when a
# data entry carries no 'data_owner' key.
tcex.parser.add_argument(
    '--data_owner', help='The owner for staging ThreatConnect data.', required=False)
# Argument object exposed by TcEx (presumably parsed on access - confirm against TcEx docs).
args = tcex.args


class TcData(object):
    """Manage testing Data"""

    def __init__(self, arg_data):
        """ """
        self._args = arg_data
        self.data = []
        # match variable name with modifiers
        self._vars_match = re.compile(
            r'#([A-Za-z]+)'  # match literal (#App) at beginning of String
            r':([\d]+)'  # app id (:7979)
            r':([A-Za-z0-9_\.\-\[\]]+)'  # variable name (:variable_name)
            r'!(StringArray|BinaryArray|KeyValueArray'  # variable type (array)
            r'|TCEntityArray|TCEnhancedEntityArray'  # variable type (array)
            r'|String|Binary|KeyValue|TCEntity|TCEnhancedEntity'  # variable type
            r'|(?:(?!String)(?!Binary)(?!KeyValue)'  # non matching for custom
            r'(?!TCEntity)(?!TCEnhancedEntity)'  # non matching for custom
            r'[A-Za-z0-9_-]+))'  # variable type (custom)
        )
        # Define common comparison operators
        self._operators = {
            'dd': self.deep_diff,
            'deep-diff': self.deep_diff,
            'eq': operator.eq,
            'ew': self.data_endswith,
            'ends-with': self.data_endswith,
            'ge': operator.ge,
            'gt': operator.gt,
            'jc': self.json_compare,
            'json-compare': self.json_compare,
            'kva': self.data_kva_compare,
            'key-value-compare': self.data_kva_compare,
            'in': self.data_in_db,
            'in-db-data': self.data_in_db,
            'in-user-data': self.data_in_user,
            'ni': self.data_not_in,
            'not-in': self.data_not_in,
            'it': self.data_it,  # is type
            'is-type': self.data_it,
            'lt': operator.lt,
            'le': operator.le,
            'ne': operator.ne,
            'se': self.data_string_compare,
            'string-compare': self.data_string_compare,
            'sw': self.data_startswith,
            'starts-with': self.data_startswith
        }
        self.load_data(arg_data.data_file)

    def clear(self):
        """Clear every entry of the loaded data file.

        Entries of type ``redis`` and ``redis_array`` have their variable removed via
        ``redis_clear``; ``threatconnect`` entries are handed to ``tc_clear``. Entries
        with any other ``data_type`` are ignored.
        """
        for entry in self.data:
            # entries from older data files have no data_type and are treated as redis
            kind = entry.get('data_type', 'redis')
            if kind == 'threatconnect':
                tcex.log.debug('Clear ThreatConnect Data')
                self.tc_clear(entry)
            elif kind == 'redis' or kind == 'redis_array':
                if kind == 'redis':
                    tcex.log.debug('Clear Redis Data')
                else:
                    tcex.log.debug('Clear Redis Array')
                self.redis_clear(entry.get('variable'))

    @staticmethod
    def data_endswith(db_data, user_data):
        """Validate data ends with user data"""
        if db_data.endswith(user_data):
            return True
        return False

    @staticmethod
    def data_in_db(db_data, user_data):
        """Validate db data in user data"""
        if db_data in user_data:
            return True
        return False

    @staticmethod
    def data_in_user(db_data, user_data):
        """Validate user data in db data"""
        if user_data in db_data:
            return True
        return False

    @staticmethod
    def data_it(db_data, user_type):
        """Check whether db data is of the type named by ``user_type``.

        Args:
            db_data: The value whose type is checked.
            user_type: Type name, matched case-insensitively: array, binary, bytes,
                dict, entity, list, str or string. ``None``, 'null' and 'none' match
                only when ``db_data`` is ``None``.

        Returns:
            bool: True when the type matches; False otherwise, including for an
            unrecognised type name.
        """
        type_map = {
            'array': (list),
            'binary': (bytes),
            'bytes': (bytes),
            'dict': (dict),
            'entity': (dict),
            'list': (list),
            'str': (string_types),
            'string': (string_types)
        }
        if user_type is None:
            return db_data is None

        type_name = user_type.lower()
        if type_name in ['null', 'none']:
            return db_data is None

        expected = type_map.get(type_name)
        if expected is None:
            # unknown type name never matches
            return False
        return isinstance(db_data, expected)

    @staticmethod
    def data_kva_compare(db_data, user_data):
        """Validate key/value data in KeyValueArray"""
        for kv_data in db_data:
            if kv_data.get('key') == user_data.get('key'):
                if kv_data.get('value') == user_data.get('value'):
                    return True
        return False

    @staticmethod
    def data_not_in(db_data, user_data):
        """Validate data not in user data"""
        if db_data not in user_data:
            return True
        return False

    @staticmethod
    def data_startswith(db_data, user_data):
        """Validate data starts with user data"""
        if db_data.startswith(user_data):
            return True
        return False

    @staticmethod
    def data_string_compare(db_data, user_data):
        """Validate string removing all white space before comparison"""
        db_data = ''.join(db_data.split())
        user_data = ''.join(user_data.split())
        if operator.eq(db_data, user_data):
            return True
        return False

    @staticmethod
    def deep_diff(db_data, user_data):
        """Compare db data and user data with DeepDiff, ignoring order.

        Returns:
            bool: True when DeepDiff reports no difference. False when a difference
            is found (it is logged at info level) or when the DeepDiff name is not
            available.
        """
        try:
            difference = DeepDiff(db_data, user_data, ignore_order=True)
        except NameError:
            tcex.log.warning(u'Could not find DeepDiff module.')
            return False

        if not difference:
            return True
        tcex.log.info(u'Diff: {}'.format(difference))
        return False

    def json_compare(self, db_data, user_data):
        """Compare two values with ``deep_diff`` after JSON-decoding any string input.

        Args:
            db_data: Value read from the DB; decoded with ``json.loads`` when a string.
            user_data: Expected value; decoded with ``json.loads`` when a string.

        Returns:
            bool: The result of ``deep_diff`` on the (possibly decoded) values.
        """
        decoded = [
            json.loads(value) if isinstance(value, (basestring)) else value
            for value in (db_data, user_data)
        ]
        return self.deep_diff(decoded[0], decoded[1])

    def load_data(self, data_file):
        """Load the data file"""
        data_file = os.path.abspath(data_file)
        if os.path.isfile(data_file):  # double check file exist
            f = open(data_file, 'r')
            data_array = json.load(f)
            f.close()
        else:
            tcex.log.error(u'Could not open data file ({}).'.format(data_file))
            tcex.exit(1)
        self.data = data_array

    @staticmethod
    def path_data(data, path):
        """Get data from a dictionary/list using a jmespath expression.

        Args:
            data: A dict/list, or a string containing JSON. A string is decoded with
                ``json.loads`` first so the expression runs against the parsed
                structure rather than the raw text.
            path (str): The jmespath expression to evaluate.

        Returns:
            The result of the jmespath search; dicts created by jmespath use
            ``collections.OrderedDict``.

        If a string value cannot be decoded, ``tcex.exit(1, ...)`` is called.
        """
        if isinstance(data, string_types):
            # convert string into list/dict before using expression
            try:
                data = json.loads(data)
            except ValueError as e:
                tcex.exit(1, 'String value ({}) could not JSON serialized ({}).'.format(data, e))
        # get updated user data using jmespath expression
        expression = jmespath.compile(path)
        return expression.search(data, jmespath.Options(dict_cls=collections.OrderedDict))

    @staticmethod
    def redis_clear(variable):
        """Clear Data in Redis"""
        if variable is not None:
            tcex.log.debug(u'Clearing redis variable {}.'.format(variable))
            tcex.playbook.delete(variable)

    @staticmethod
    def redis_stage(variable, data):
        """Create a playbook variable from staged data.

        Args:
            variable (str): Variable name; a name ending in ``Binary`` or
                ``BinaryArray`` causes the data to be base64-decoded first.
            data: The value to store. An int is converted to its string form.
        """
        payload = str(data) if isinstance(data, int) else data

        if variable.endswith('Binary'):
            payload = base64.b64decode(payload)
        elif variable.endswith('BinaryArray'):
            # decode each entry of the array individually
            payload = [base64.b64decode(document) for document in payload]

        tcex.log.info(u'Creating variable {}'.format(variable))
        tcex.playbook.create(variable, payload)

    def redis_validate(self, db_data, user_data, oper):
        """Compare db data with user data using the named operator.

        Args:
            db_data: Value read from the DB.
            user_data: Expected value.
            oper (str): Key into ``self._operators`` selecting the comparison.

        Logs success when the comparison is truthy; otherwise (including an unknown
        operator) calls ``tcex.exit(1, ...)``.
        """

        def _sorted_if_list(value):
            """Return a sorted copy of a list when sortable, else the value as-is."""
            if not isinstance(value, (list)):
                return value
            try:
                return sorted(value)
            except TypeError:
                tcex.log.warning('Could not sort list')
                return value

        # playbooks don't support int values, so compare their string form
        if isinstance(db_data, int):
            db_data = str(db_data)
        if isinstance(user_data, int):
            user_data = str(user_data)

        tcex.log.debug(u'-> DB Data: ({}), Type: [{}]'.format(db_data, type(db_data)))
        tcex.log.debug(u'- Operator: ({}) -'.format(oper))
        tcex.log.debug(u'<- Validation Data: ({}), Type: [{}]'.format(user_data, type(user_data)))

        # sort lists so simple comparisons ignore ordering
        db_data = _sorted_if_list(db_data)
        user_data = _sorted_if_list(user_data)

        # compare the data
        passed = oper in self._operators and self._operators[oper](db_data, user_data)
        if passed:
            tcex.log.info(u'Validation was successful')
        else:
            tcex.exit(1, u'Validation failed')

    def stage(self):
        """Stage every entry of the loaded data file.

        ``redis`` entries are written with ``redis_stage``; ``threatconnect`` entries
        are handed to ``tc_stage``; ``redis_array`` entries read each listed source
        variable (optionally narrowed by a jmespath ``path``) and store the collected
        values in the output variable. Other ``data_type`` values are ignored.
        """
        for entry in self.data:
            # entries from older data files have no data_type and are treated as redis
            kind = entry.get('data_type', 'redis')
            if kind == 'redis':
                tcex.log.debug('Stage Redis Data')
                self.redis_stage(entry.get('variable'), entry.get('data'))
            elif kind == 'threatconnect':
                tcex.log.debug('Stage ThreatConnect Data')
                self.tc_stage(entry)
            elif kind == 'redis_array':
                tcex.log.debug('Stage Redis Array')
                merged = []
                target = entry.get('variable')
                target_type = tcex.playbook.variable_type(target)
                for source in entry.get('data', {}).get('variables', []):
                    source_variable = source.get('value')
                    source_path = source.get('path')
                    value = tcex.playbook.read(source_variable)
                    if source_path is not None:
                        value = self.path_data(value, source_path)
                    merged.append(value)
                # create merged variable
                tcex.playbook.create_data_types[target_type](target, merged)

    def tc_clear(self, data):
        """Delete the ThreatConnect resource described by a data entry.

        Args:
            data (dict): Data-file entry. ``data`` holds the resource description,
                ``variable`` (optional) is cleared via ``redis_clear`` and
                ``data_owner`` overrides the ``--data_owner`` argument.

        Indicators are identified by the first non-None value yielded by
        ``resource.indicators``; groups are resolved by name through
        ``tc_group_name_to_id``. Any other type logs a warning and returns.
        """
        owner = data.get('data_owner', args.data_owner)
        resource_data = data.get('data')
        variable = data.get('variable')
        if variable is not None:
            self.redis_clear(variable)

        # parse resource_data
        resource_type = resource_data.get('type')
        tcex.log.debug('resource data {}'.format(resource_data))

        if resource_type in tcex.indicator_types:
            targets = []
            # temp resource used only to extract the indicator value from the json
            lookup = tcex.resource(resource_type)
            for candidate in lookup.indicators(resource_data):
                value = candidate.get('value')
                if value is not None:
                    targets = [value]
                    break
        elif resource_type in tcex.group_types:
            targets = self.tc_group_name_to_id(
                resource_data.get('name'), resource_type, owner)
        else:
            tcex.log.warning('No Resource ID found to delete.')
            return

        for target in targets:
            request_resource = tcex.resource(resource_type)
            request_resource.resource_id(target)
            request_resource.http_method = 'DELETE'
            request_resource.owner = owner
            result = request_resource.request()

            if result.get('status') == 'Success':
                continue
            warn = u'Failed deleting resource type "{}" with value "{}" ({}).'
            tcex.log.warning(warn.format(
                resource_type, target, result.get('response').text))

    @staticmethod
    def tc_create_associations(resource_data):
        """Associate two previously staged entities with each other.

        Args:
            resource_data (dict): Holds ``entity1`` and ``entity2``, the playbook
                variables from which the two entities are read.

        Indicators are addressed by their ``value``, all other types by ``id``.
        Entities with different ``ownerName`` values are not associated; an error
        is logged instead. A failed request is logged as a warning.
        """

        # resource 1
        entity1 = tcex.playbook.read(resource_data.get('entity1'))
        entity1_id = entity1.get('id')
        entity1_owner = entity1.get('ownerName')
        entity1_type = entity1.get('type')
        if entity1_type in tcex.indicator_types:
            entity1_id = entity1.get('value')

        # resource 2
        entity2 = tcex.playbook.read(resource_data.get('entity2'))
        entity2_id = entity2.get('id')
        # the owner must come from entity2 itself, otherwise the check below never fires
        entity2_owner = entity2.get('ownerName')
        entity2_type = entity2.get('type')
        if entity2_type in tcex.indicator_types:
            entity2_id = entity2.get('value')

        if entity1_owner != entity2_owner:
            tcex.log.error('Can not associate resource across owners.')
            return

        resource1 = tcex.resource(entity1_type)
        resource1.http_method = 'POST'
        resource1.owner = entity1_owner
        resource1.resource_id(entity1_id)

        resource2 = tcex.resource(entity2_type)
        resource2.resource_id(entity2_id)

        a_resource = resource1.associations(resource2)
        response = a_resource.request()
        if response.get('status') != 'Success':
            tcex.log.warning(u'Failed associating "{}:{}" with "{}:{}" ({}).'.format(
                entity1_type, entity1_id, entity2_type, entity2_id, response.get('response').text))

    @staticmethod
    def tc_create_attribute(attribute_type, attribute_value, resource):
        """Add Attribute to a Resource"""
        attribute_data = {
            'type': str(attribute_type),
            'value': str(attribute_value)
        }
        # handle default description and source
        if attribute_type in ['Description', 'Source']:
            attribute_data['displayed'] = True

        attrib_resource = resource.attributes()
        attrib_resource.body = json.dumps(attribute_data)
        attrib_resource.http_method = 'POST'

        # add the attribute
        a_response = attrib_resource.request()
        if a_response.get('status') != 'Success':
            tcex.log.warning(u'Failed adding attribute type "{}" with value "{}" ({}).'.format(
                attribute_type, attribute_value, a_response.get('response').text))

    @staticmethod
    def tc_create_security_label(label, resource):
        """Add a Tag to a Resource"""
        sl_resource = resource.security_labels(label)
        sl_resource.http_method = 'POST'
        sl_response = sl_resource.request()
        if sl_response.get('status') != 'Success':
            tcex.log.warning(u'Failed adding security label "{}" ({}).'.format(
                label, sl_response.get('response').text))

    @staticmethod
    def tc_create_tag(tag, resource):
        """Add a tag to a resource.

        Args:
            tag: Tag name; passed through ``tcex.safetag`` before use.
            resource: Resource object whose ``tags(...)`` sub-resource receives the
                POST request.

        A failed request is logged as a warning.
        """
        request_resource = resource.tags(tcex.safetag(tag))
        request_resource.http_method = 'POST'
        result = request_resource.request()
        if result.get('status') == 'Success':
            return
        tcex.log.warning(u'Failed adding tag "{}" ({}).'.format(
            tag, result.get('response').text))

    @staticmethod
    def tc_group_name_to_id(name, resource_type, owner):
        """Convert a group name to the list of matching group ids.

        Args:
            name: Group name used as an ``name = <name>`` filter.
            resource_type: Group resource type to query.
            owner: Owner to query in; left unset on the resource when None.

        Returns:
            list: The ``id`` of every returned group. Empty when the request yields
            no data; entries without an id are skipped so callers never receive
            None as an id.
        """
        tcex.log.debug(u'coverting {} to ids'.format(name))
        resource = tcex.resource(resource_type)
        resource.add_filter('name', '=', name)
        if owner is not None:
            resource.owner = owner

        results = resource.request()
        data = results.get('data')
        if data is None:
            # no data at all (e.g. nothing matched): there is nothing to convert
            data = []
        elif not isinstance(data, list):
            data = [data]

        resource_ids = []
        for group in data:
            group_id = group.get('id')
            if group_id is not None:
                resource_ids.append(group_id)
        tcex.log.debug(u'name_to_ids count: {}'.format(len(resource_ids)))

        return resource_ids

    def tc_stage(self, data):
        """Stage ThreatConnect data described by a data entry.

        Args:
            data (dict): Data-file entry. ``data`` holds the resource description
                (its ``type``, ``attribute``, ``security_label`` and ``tag`` keys are
                popped before the remainder is sent as the request body),
                ``variable`` names the playbook variable that receives the created
                entity and ``data_owner`` overrides the ``--data_owner`` argument.

        ``Association`` entries are delegated to ``tc_create_associations``.
        Indicator and group types are created with a POST request; on success the
        entity is written to the playbook variable and attributes, security labels
        and tags are added. A failed create request is logged as a warning. Any
        other type logs an error.
        """
        owner = data.get('data_owner', args.data_owner)
        resource_data = data.get('data')
        variable = data.get('variable')

        # parse resource_data
        resource_type = resource_data.pop('type')

        if resource_type == 'Association':
            self.tc_create_associations(resource_data)
            return

        if resource_type not in tcex.indicator_types and resource_type not in tcex.group_types:
            tcex.log.error('Unsupported resource type {}.'.format(resource_type))
            return

        # metadata is sent in separate requests, so strip it from the create body
        attributes = resource_data.pop('attribute', [])
        security_labels = resource_data.pop('security_label', [])
        tags = resource_data.pop('tag', [])

        resource = tcex.resource(resource_type)
        resource.http_method = 'POST'
        resource.owner = owner
        resource.url = args.tc_api_path

        # special case for Email Group Type
        if resource_type == 'Email':
            resource.add_payload('option', 'createVictims')

        tcex.log.debug(u'body: {}'.format(resource_data))
        resource.body = json.dumps(resource_data)

        response = resource.request()
        if response.get('status') != 'Success':
            # report the failure like the other tc_create_* helpers do
            tcex.log.warning(u'Failed creating resource type "{}" ({}).'.format(
                resource_type, getattr(response.get('response'), 'text', None)))
            return

        # add resource id; stays None when the response carries no usable value
        resource_id = None
        if resource_type in tcex.indicator_types:
            for r_data in resource.indicators(response.get('data')):
                if r_data.get('value') is not None:
                    resource_id = r_data.get('value')
                    break
        elif resource_type in tcex.group_types:
            resource_id = response.get('data', {}).get('id')
        tcex.log.debug('resource_id: {}'.format(resource_id))
        if resource_id is None:
            tcex.log.error(
                u'Could not determine resource id for type "{}"; '
                u'attributes, security labels and tags will not be added.'.format(resource_type))
        else:
            resource.resource_id(resource_id)

        entity = tcex.playbook.json_to_entity(
            response.get('data'), resource.value_fields, resource.name, resource.parent)
        tcex.log.debug(u'Creating Entity: {} ({})'.format(variable, entity[0]))
        tcex.playbook.create_tc_entity(variable, entity[0])

        if resource_id is None:
            # metadata requests need the resource id to address the resource
            return

        # update metadata
        for attribute_data in attributes:
            self.tc_create_attribute(
                attribute_data.get('type'), attribute_data.get('value'), resource)
        for label_data in security_labels:
            self.tc_create_security_label(label_data.get('name'), resource)
        for tag_data in tags:
            self.tc_create_tag(tag_data.get('name'), resource)

    def validate(self):
        """Validate every ``redis`` entry of the loaded data file.

        For each entry the expected value comes from ``data`` (read from the
        playbook when it is a variable reference) and the actual value from the
        playbook variable named by ``variable``; either may be narrowed with a
        jmespath expression (``data_path`` / ``variable_path``). The two are
        compared through ``redis_validate`` with the entry's ``operator`` (default
        ``eq``). Variables ending in ``Binary`` or ``BinaryArray`` are only logged,
        not compared. Entries with another ``data_type`` are skipped.
        """
        for entry in self.data:
            # entries from older data files have no data_type and are treated as redis
            if entry.get('data_type', 'redis') != 'redis':
                continue

            # expected value: literal data or a reference to another variable
            expected = entry.get('data')
            expected_path = entry.get('data_path')  # jmespath expression
            if isinstance(expected, string_types) and re.match(self._vars_match, expected):
                expected = tcex.playbook.read(expected)
            if expected_path is not None:
                expected = self.path_data(expected, expected_path)

            # actual value stored for the variable
            variable = entry.get('variable')
            actual = tcex.playbook.read(variable)
            actual_path = entry.get('variable_path')
            if actual_path is not None:
                actual = self.path_data(actual, actual_path)

            # validate if possible
            if variable.endswith(('Binary', 'BinaryArray')):
                tcex.log.debug(u'-> DB Data: ({}), Type: [{}]'.format(
                    'Excluding Binary Data Output', type(actual)))
                continue
            self.redis_validate(actual, expected, entry.get('operator', 'eq'))


if __name__ == '__main__':
    try:
        tcd = TcData(args)
        # supported actions: log banner and the handler that performs the action
        actions = {
            'clear': ('TcData Clear', tcd.clear),
            'stage': ('TcData Stage', tcd.stage),
            'validate': ('TcData Validate', tcd.validate),
        }
        selected = actions.get(args.action)
        if selected is None:
            tcex.log.info(u'Invalid action provided: {}'.format(args.action))
            tcex.exit(1)
        else:
            banner, handler = selected
            tcex.log.info(banner)
            handler()
        tcex.exit()
    except Exception:
        # TODO: Update this, possibly raise
        tcex.exit(1, u'Generic Failure ({}).'.format(traceback.format_exc()))
