#!/usr/bin/env python

from thrift.Thrift import TType
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from concrete.uuid.ttypes import UUID
from concrete.metadata.ttypes import AnnotationMetadata
from inspect import isroutine
from concrete.util.mem_io import communication_deep_copy
from concrete.util.file_io import CommunicationReader, CommunicationWriterTGZ
from concrete.util.concrete_uuid import AnalyticUUIDGeneratorFactory
from concrete.util.unnone import lun

import logging

def filtered_getmembers(obj):
    '''
    Generate key-value pairs of object members that may contain UUIDs.
    Over-generate, but filter the output enough that concrete objects
    can be traversed recursively using this function without leading to
    stack overflows or infinite loops.

    Skipped members: names starting with an underscore; the thrift
    plumbing members thrift_spec, read, write and validate; routines
    (methods, functions); numbers; and strings.  Every numbers.Number is
    skipped (not only int/float) because e.g. a Python 2 long or a Decimal
    exposes attributes such as .real of its own type, which would
    otherwise be recursed into forever.
    '''
    import numbers

    try:
        text_types = (str, unicode)  # Python 2
    except NameError:
        # Python 3: there is no `unicode`; text is str, binary is bytes.
        text_types = (str, bytes)

    skipped_names = ('thrift_spec', 'read', 'write', 'validate')

    for k in dir(obj):
        if k.startswith('_') or k in skipped_names:
            continue
        v = getattr(obj, k)
        if (isroutine(v) or isinstance(v, numbers.Number)
                or isinstance(v, text_types)):
            continue
        yield (k, v)

# Thrift field types whose values can (transitively) contain UUID structs;
# fields of any other declared type are skipped by fast_filtered_getmembers.
FILTERED_TTYPES = {TType.STRUCT, TType.LIST, TType.MAP, TType.SET}

def fast_filtered_getmembers(obj):
    '''
    Thrift-specific, faster counterpart of filtered_getmembers.

    Rather than inspecting dir(obj), read the field descriptions from
    obj.thrift_spec and generate (field name, field value) pairs only for
    fields whose declared type (entry index 1) is in FILTERED_TTYPES.
    None entries in thrift_spec are skipped, and an object without a
    thrift_spec attribute yields nothing.
    '''
    if not hasattr(obj, 'thrift_spec'):
        return

    for field_spec in obj.thrift_spec:
        if field_spec is None:
            continue
        if field_spec[1] not in FILTERED_TTYPES:
            continue
        field_name = field_spec[2]
        yield (field_name, getattr(obj, field_name))


class UUIDClustering(object):
    '''
    Group every UUID occurrence inside a concrete communication by UUID
    string.  Each occurrence is identified by its path from the root
    communication: a tuple whose steps are member names, ('list', index)
    for list elements and ('dict', key) for dict values.  A cluster is
    therefore the set of paths at which one UUID string appears, whether
    as an identifier or as a reference.
    '''

    def __init__(self, comm):
        # UUID string -> set of path tuples at which it occurs
        self._clusters = {}
        self._search(comm)

    def hashable_clusters(self):
        '''
        Return the clusters with their UUID-string labels dropped, as a
        set of tuples of sorted paths.  Two UUIDClusterings c1 and c2
        describe the same UUID structure (their communications differ at
        most in the UUID strings themselves) exactly when:

            c1.hashable_clusters() == c2.hashable_clusters()
        '''
        return {tuple(sorted(paths)) for paths in self._clusters.values()}

    def _search(self, obj, prefix=()):
        '''
        Recursively walk obj and record each UUID found under its path.

        prefix is the path from the root object to obj; every recursive
        call extends it by exactly one step (member name, ('list', i) or
        ('dict', k)), so each nested member gets a unique, hashable path.
        Non-container objects are expanded with filtered_getmembers.

        Raises ValueError if a set is encountered.
        '''
        if isinstance(obj, UUID):
            self._add_uuid_field(obj.uuidString, prefix)
            return

        if isinstance(obj, list):
            children = ((('list', idx), item) for (idx, item) in enumerate(obj))
        elif isinstance(obj, set):
            raise ValueError('UUIDClustering does not support sets')
        elif isinstance(obj, dict):
            children = ((('dict', key), val) for (key, val) in obj.items())
        else:
            children = filtered_getmembers(obj)

        for (step, child) in children:
            self._search(child, prefix + (step,))

    def _add_uuid_field(self, u, f):
        '''
        Add path f (the route from the root communication to one UUID
        object) to the cluster for UUID string u, creating the cluster on
        first sight of u.
        '''
        self._clusters.setdefault(u, set()).add(f)


class UUIDCompressor(object):
    '''
    Rewrite the UUIDs of a concrete communication using compressible
    analytic UUIDs produced by AnalyticUUIDGeneratorFactory.

    A UUID stored in a member named "uuid" is treated as a definition and
    receives a freshly generated UUID; any other UUID is treated as a
    reference and is rewritten through the old -> new mapping recorded
    while processing definitions.  New UUIDs come from one generator per
    annotation tool (the nearest enclosing metadata.tool), or from a
    single generator when single_analytic is True.
    '''

    def __init__(self, single_analytic=False):
        self.single_analytic = single_analytic
        # Per-communication state.  compress() resets it; it is also set
        # here so the attributes exist (and are empty) before first use.
        self.augf = None
        self.augs = dict()      # tool (or None) -> analytic UUID generator
        self.uuid_map = dict()  # old uuidString -> new uuidString

    def compress(self, comm):
        '''
        Return a deep copy of comm with compressed UUIDs.

        The rewrite is applied to the copy made by
        communication_deep_copy, not to comm itself.  Afterwards
        self.augs and self.uuid_map describe this communication's
        conversion.
        '''
        cc = communication_deep_copy(comm)
        self.augf = AnalyticUUIDGeneratorFactory(cc)
        self.augs = dict()
        self.uuid_map = dict()

        # Two passes: a reference can only be rewritten once the "uuid"
        # member it points at has been assigned its new value.
        self._compress_uuids(cc)
        self._compress_uuid_refs(cc)

        return cc

    def _compress_uuids(self, obj, name_is_uuid=False, tool=None):
        '''
        Generate new UUIDs in "uuid" fields and save mapping.

        obj: current node of the traversal.
        name_is_uuid: True if obj is the value of a member named "uuid".
        tool: tool inherited from the enclosing object (see _get_tool).
        '''
        tool = self._get_tool(obj, tool)

        if name_is_uuid:
            if isinstance(obj, UUID):
                obj.uuidString = self._gen_uuid(obj, tool)
            else:
                logging.warning('uuid not instance of UUID')

        if not isinstance(obj, UUID):  # we already took care of "uuid"
            self._apply(
                lambda elt, elt_name_is_uuid: self._compress_uuids(
                    elt, name_is_uuid=elt_name_is_uuid, tool=tool),
                obj)

    def _compress_uuid_refs(self, obj, name_is_uuid=False, tool=None):
        '''
        Update UUID references (not in "uuid" fields) using saved mapping.

        Raises KeyError if a reference holds a UUID string that no "uuid"
        member defined (i.e. it is missing from self.uuid_map).
        '''
        tool = self._get_tool(obj, tool)

        if isinstance(obj, UUID):
            if not name_is_uuid:
                obj.uuidString = self.uuid_map[obj.uuidString]
        else:
            self._apply(
                lambda elt, elt_name_is_uuid: self._compress_uuid_refs(
                    elt, name_is_uuid=elt_name_is_uuid, tool=tool),
                obj)

    def _get_tool(self, obj, tool=None):
        '''
        Return tool for this object, given the parent tool;
        update self.augs.

        An object with a metadata member overrides the inherited tool with
        metadata.tool.  When single_analytic is set, the tool is always
        None so that one generator serves the whole communication.  A
        generator is created for any tool not yet in self.augs.
        '''
        if hasattr(obj, 'metadata'):
            if isinstance(obj.metadata, AnnotationMetadata):
                tool = obj.metadata.tool
            else:
                logging.warning('metadata not instance of AnnotationMetadata')
        if self.single_analytic:
            tool = None
        if tool not in self.augs:
            self.augs[tool] = self.augf.create()
        return tool

    def _gen_uuid(self, old_uuid, tool):
        '''
        Return a new UUID string for the provided tool, using self.augs;
        update self.uuid_map.

        Raises ValueError if old_uuid's string was already mapped, i.e. the
        same UUID is defined twice.
        '''
        if old_uuid.uuidString in self.uuid_map:
            raise ValueError('encountered UUID %s twice, aborting' %
                             old_uuid.uuidString)
        # Builtin next() works on iterators under both Python 2 and 3,
        # unlike the Python-2-only .next() method.
        new_uuid = next(self.augs[tool])
        self.uuid_map[old_uuid.uuidString] = new_uuid.uuidString
        return new_uuid.uuidString

    @classmethod
    def _apply(cls, f, x):
        '''
        Apply f to the members of x.

        For a list or set, call f(elt, False) on every element; for a
        dict, on every value (keys are not visited).  For anything else,
        call f(v, k == 'uuid') for each (k, v) member reported by
        fast_filtered_getmembers (thrift struct/container fields only).
        '''
        if isinstance(x, (list, set)):
            for elt in x:
                f(elt, False)
        elif isinstance(x, dict):
            for elt in x.values():
                f(elt, False)
        else:
            for (k, v) in fast_filtered_getmembers(x):
                f(v, k == 'uuid')


def _verify_compressed_comm(comm, new_comm, num_old_uuids, num_new_uuids):
    '''
    Check that new_comm (the compressed copy of comm) has the same UUID
    structure as comm and that the old -> new UUID mapping was one-to-one.
    Log the outcome; raise Exception if verification fails.
    '''
    c1 = UUIDClustering(comm).hashable_clusters()
    c2 = UUIDClustering(new_comm).hashable_clusters()

    # Verification is c1 == c2; also check UUID map lengths are the same
    # as a sanity check (equal counts mean no two old UUIDs were mapped to
    # the same new UUID).
    if num_old_uuids == num_new_uuids and c1 == c2:
        logging.info('verified %s (%d uuid instances, %d uuids)'
                     % (comm.id, sum(len(c) for c in c1), len(c1)))
        return

    logging.error('%s failed verification' % comm.id)
    logging.error('uuid counts: %d -> %d'
                  % (num_old_uuids, num_new_uuids))
    logging.error('verified number of uuids: %d -> %d'
                  % (len(c1), len(c2)))
    logging.error('verified number of uuid instances: %d -> %d'
                  % (sum(map(len, c1)), sum(map(len, c2))))
    raise Exception('%s failed verification' % comm.id)


def compress_uuids(input_path, output_path, verify=False, uuid_map_path=None,
                   single_analytic=False):
    '''
    Read communications from input_path, compress their UUIDs with
    UUIDCompressor and write the results with CommunicationWriterTGZ to
    output_path.

    verify: compare each result's UUID structure with the original's
        (UUIDClustering) and raise Exception on mismatch; otherwise only
        warn when two old UUIDs were mapped to the same new one.
    uuid_map_path: if not None, write "old new" UUID lines for every
        communication to this file, sorted by new UUID.
    single_analytic: passed through to UUIDCompressor.

    The writer and the UUID-map file are always closed, even on error.
    '''
    reader = CommunicationReader(input_path, add_references=False)
    writer = CommunicationWriterTGZ(output_path)
    uuid_map_file = None
    try:
        if uuid_map_path is not None:
            uuid_map_file = open(uuid_map_path, 'w')

        uc = UUIDCompressor(single_analytic=single_analytic)

        for (i, (comm, _)) in enumerate(reader):
            new_comm = uc.compress(comm)

            logging.info('compressed %s (%d analytics, %d uuids) (%d/?)'
                         % (comm.id, len(uc.augs), len(uc.uuid_map), i+1))

            if uuid_map_file is not None:
                for (old_uuid, new_uuid) in sorted(uc.uuid_map.items(),
                                                   key=lambda p: str(p[1])):
                    uuid_map_file.write('%s %s\n' % (old_uuid, new_uuid))

            # Dict keys are unique already; the values need not be.
            num_old_uuids = len(uc.uuid_map)
            num_new_uuids = len(set(uc.uuid_map.values()))

            if verify:
                _verify_compressed_comm(comm, new_comm,
                                        num_old_uuids, num_new_uuids)
            elif num_old_uuids != num_new_uuids:
                logging.warning('uuid counts are not the same (%d -> %d)'
                                % (num_old_uuids, num_new_uuids))

            writer.write(new_comm)
    finally:
        # Closing the writer presumably finalizes the archive (e.g. the
        # gzip trailer); without it the output tarball may be truncated.
        try:
            writer.close()
        finally:
            if uuid_map_file is not None:
                uuid_map_file.close()


def main():
    '''
    Command-line entry point: parse arguments, configure logging to
    stderr and run compress_uuids.  "-" as the input or output path means
    stdin or stdout respectively.
    '''
    parser = ArgumentParser(
        formatter_class=ArgumentDefaultsHelpFormatter,
        description='Read a concrete tarball and write it back out, rewriting'
                    ' UUIDs with compressible UUID scheme',
    )
    parser.set_defaults(log_level='INFO')
    parser.add_argument('input_path', type=str,
                        help='Input tarball path (- for stdin)')
    parser.add_argument('output_path', type=str,
                        help='Output tarball path (- for stdout)')
    parser.add_argument('--log-level', type=str,
                        choices=('DEBUG', 'INFO', 'WARNING', 'ERROR'),
                        help='Logging verbosity level (to stderr)')
    parser.add_argument('--verify', action='store_true',
                        help='Verify within-communication links are satisfied'
                             ' after conversion')
    parser.add_argument('--single-analytic', action='store_true',
                        help='Generate all UUIDs of a communication from a'
                             ' single analytic UUID generator instead of one'
                             ' generator per annotation tool')
    parser.add_argument('--uuid-map-path', type=str,
                        help='Output path of UUID map')
    ns = parser.parse_args()

    # Won't work on Windows... but that use case is very unlikely
    input_path = '/dev/fd/0' if ns.input_path == '-' else ns.input_path
    output_path = '/dev/fd/1' if ns.output_path == '-' else ns.output_path

    logging.basicConfig(
        level=ns.log_level,
        format='%(asctime)-15s %(levelname)s: %(message)s'
    )

    compress_uuids(input_path, output_path, verify=ns.verify,
                   single_analytic=ns.single_analytic,
                   uuid_map_path=ns.uuid_map_path)


if __name__ == '__main__':
    # Only run the command-line tool when executed as a script, not on import.
    main()
