#!/usr/bin/env python

from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from concrete.uuid.ttypes import UUID
from concrete.metadata.ttypes import AnnotationMetadata
from concrete.structure.ttypes import (
    Parse, DependencyParse, TokenTagging, TokenList, LatticePath
)
from inspect import isroutine, getmembers
from concrete.util.redis_io import communication_deep_copy
from concrete.util.file_io import CommunicationReader, CommunicationWriterTGZ
from concrete.util.concrete_uuid import AnalyticUUIDGeneratorFactory
from concrete.util.unnone import lun
import json

import logging


try:
    _TEXT_TYPES = (str, unicode)  # noqa: F821 -- Python 2
except NameError:  # Python 3 has no separate unicode type
    _TEXT_TYPES = (str,)

try:
    _INTEGER_TYPES = (int, long)  # noqa: F821 -- Python 2
except NameError:  # Python 3 has no separate long type
    _INTEGER_TYPES = (int,)

# Values of these types hold no UUIDs, so they are never yielded.  On
# Python 2, long must be listed explicitly: it is not an int subclass
# there, and its numerator/real attributes are themselves longs, so a
# traversal that descends into yielded values would never bottom out.
_LEAF_VALUE_TYPES = _INTEGER_TYPES + (float, bytes) + _TEXT_TYPES

# Public attribute names that are struct machinery rather than data fields.
_NON_FIELD_NAMES = frozenset(['thrift_spec', 'read', 'write', 'validate'])


def filtered_getmembers(obj):
    '''
    Yield (name, value) pairs for the attributes of obj that may contain
    UUIDs, in dir() order.

    Skipped: names starting with '_', the names 'thrift_spec', 'read',
    'write' and 'validate', routines (methods/functions), and values that
    are integers, floats, bytes or text.  Everything else -- including
    None, lists, sets, dicts and nested objects -- is yielded.
    '''
    for name in dir(obj):
        if name.startswith('_') or name in _NON_FIELD_NAMES:
            continue
        value = getattr(obj, name)
        if isroutine(value) or isinstance(value, _LEAF_VALUE_TYPES):
            continue
        yield (name, value)


def has_minimal_uuids(obj):
    '''
    Report whether obj is one of the annotation types (Parse,
    DependencyParse, TokenTagging, TokenList, LatticePath) for which
    obj.uuid and obj.metadata are the only places a UUID may appear.
    '''
    minimal_types = (Parse, DependencyParse, TokenTagging, TokenList,
                     LatticePath)
    return isinstance(obj, minimal_types)


class UUIDClustering(object):
    '''
    Group the locations ("paths") at which each UUID string occurs within
    an object graph such as a Communication.

    A path is a tuple with one element per level of nesting: a plain
    attribute name, or a (name, 'list', index) / (name, 'dict', key)
    triple when the step passes through a list element or dict value.
    self.clusters maps each uuidString to the set of paths holding it.
    UUIDs stored directly as list elements or dict values are recorded
    too, so references kept in list fields are part of the clustering.

    Raises ValueError (from the constructor) if a set-valued attribute is
    encountered.
    '''

    def __init__(self, comm):
        self.clusters = dict()
        self._search(comm)

    def hashable_clusters(self):
        '''
        Return the clusters as a set of frozensets of paths, discarding the
        UUID strings themselves.  Two objects whose UUIDs link the same
        locations together compare equal, whatever the UUID values are.

        frozensets are used instead of sorted tuples because path elements
        mix strings and tuples, which Python 3 cannot order.
        '''
        return set(frozenset(c) for c in self.clusters.values())

    def _search(self, obj, prefix=None):
        'Record every UUID reachable from obj, with paths under prefix.'
        if prefix is None:
            prefix = ()

        for (k, v) in filtered_getmembers(obj):
            if isinstance(v, list):
                for (i, elt) in enumerate(v):
                    self._visit(elt, prefix + ((k, 'list', i),))
            elif isinstance(v, set):
                raise ValueError('UUIDClustering does not support sets')
            elif isinstance(v, dict):
                for (kk, vv) in v.items():
                    self._visit(vv, prefix + ((k, 'dict', kk),))
            else:
                self._visit(v, prefix + (k,))

    def _visit(self, value, path):
        'Record value at path if it is a UUID; otherwise search inside it.'
        if isinstance(value, UUID):
            self._add_uuid_field(value.uuidString, path)
        else:
            self._search(value, path)

    def _add_uuid_field(self, u, f):
        'Add path f to the cluster of uuidString u.'
        self.clusters.setdefault(u, set()).add(f)


class UUIDCompressor(object):
    '''
    Rewrite the UUIDs of a Communication using a compressible scheme.

    compress() deep-copies the communication and works in two passes.
    The first pass replaces the UUID in every field named "uuid" with one
    drawn from a per-tool generator; the tool comes from the nearest
    enclosing AnnotationMetadata.tool.  The second pass rewrites every
    other UUID (a reference) through the resulting old -> new mapping.
    This includes UUIDs held directly as list elements or dict values.

    State describing the most recent compress() call:
      augf -- AnalyticUUIDGeneratorFactory built for the copied comm
      augs -- dict mapping tool to the generator created for it
      uuid_map -- dict mapping old uuidString to new uuidString
    '''

    def compress(self, comm):
        'Return a deep copy of comm with compressed UUIDs.'

        cc = communication_deep_copy(comm)
        self.augf = AnalyticUUIDGeneratorFactory(cc)
        self.augs = dict()
        self.uuid_map = dict()

        # Define every new UUID first, so the reference pass can resolve a
        # reference regardless of where its target sits in the traversal.
        self._compress_uuids(cc)
        self._compress_uuid_refs(cc)

        return cc

    def _compress_uuids(self, obj, tool=None):
        'Generate new UUIDs in "uuid" fields and save mapping'

        tool = self._get_tool(obj, tool)

        if hasattr(obj, 'uuid'):
            if isinstance(obj.uuid, UUID):
                obj.uuid = self._gen_uuid(obj.uuid, tool)
            else:
                logging.warning('uuid not instance of UUID')

        if has_minimal_uuids(obj):
            return  # no other field named "uuid" under obj

        for (k, v) in filtered_getmembers(obj):
            if not isinstance(v, UUID):  # we already took care of "uuid"
                self._apply(lambda elt: self._compress_uuids(elt, tool), v)

    def _compress_uuid_refs(self, obj, tool=None):
        '''
        Update UUID references (not in "uuid" fields) using saved mapping.

        UUIDs held directly in list elements or dict values are replaced
        in place.  Set contents are only searched, not replaced.
        Raises KeyError (via _map_ref) for a reference whose target did
        not receive a new UUID.
        '''

        tool = self._get_tool(obj, tool)

        if has_minimal_uuids(obj):
            # just need to check grandchildren of metadata
            if hasattr(obj, 'metadata'):
                if isinstance(obj.metadata, AnnotationMetadata):
                    self._compress_uuid_refs(obj.metadata, tool)
                else:
                    logging.warning('metadata not instance of'
                                    ' AnnotationMetadata')
            return

        for (k, v) in filtered_getmembers(obj):
            if isinstance(v, UUID):
                if k != 'uuid':  # "uuid" fields were rewritten in pass one
                    setattr(obj, k, self._map_ref(v))
            elif isinstance(v, list):
                for (i, elt) in enumerate(v):
                    if isinstance(elt, UUID):
                        v[i] = self._map_ref(elt)
                    else:
                        self._compress_uuid_refs(elt, tool)
            elif isinstance(v, dict):
                for (kk, vv) in list(v.items()):
                    if isinstance(vv, UUID):
                        v[kk] = self._map_ref(vv)
                    else:
                        self._compress_uuid_refs(vv, tool)
            else:
                self._apply(lambda elt: self._compress_uuid_refs(elt, tool), v)

    def _map_ref(self, ref):
        '''
        Return a new UUID whose uuidString is self.uuid_map's image of
        ref.uuidString; raises KeyError if ref's target was not mapped.
        '''
        return UUID(uuidString=self.uuid_map[ref.uuidString])

    def _get_tool(self, obj, tool=None):
        '''
        Return tool for this object, given the parent tool;
        update self.augs
        '''

        # Objects without AnnotationMetadata inherit the parent's tool.
        if hasattr(obj, 'metadata'):
            if isinstance(obj.metadata, AnnotationMetadata):
                tool = obj.metadata.tool
            else:
                logging.warning('metadata not instance of AnnotationMetadata')
        if tool not in self.augs:
            self.augs[tool] = self.augf.create()
        return tool

    def _gen_uuid(self, old_uuid, tool):
        '''
        Return a new UUID for the provided tool, using self.augs;
        update self.uuid_map.

        Raises ValueError if old_uuid was already mapped.  The check runs
        before a UUID is drawn, so a failing call leaves the generator
        untouched.
        '''

        if old_uuid.uuidString in self.uuid_map:
            raise ValueError('encountered UUID %s twice, aborting' %
                             old_uuid.uuidString)
        new_uuid = self.augs[tool].next()
        self.uuid_map[old_uuid.uuidString] = new_uuid.uuidString
        return new_uuid

    @classmethod
    def _apply(cls, f, x):
        '''
        Apply f to the members of x if it is a basic container type,
        otherwise apply f to x directly.
        '''

        if isinstance(x, (list, set)):
            for elt in x:
                f(elt)
        elif isinstance(x, dict):
            for elt in x.values():
                f(elt)
        else:
            f(x)


def compress_uuids(input_path, output_path, verify=False, uuid_map_path=None):
    '''
    Read communications from input_path, compress their UUIDs with
    UUIDCompressor, and write the results with CommunicationWriterTGZ to
    output_path.

    :param input_path: path read by CommunicationReader
    :param output_path: path written by CommunicationWriterTGZ
    :param verify: if True, compare the UUIDClustering of each input and
        output communication and raise Exception on mismatch; if False,
        only warn when the old and new UUID counts differ
    :param uuid_map_path: optional path of a text file that receives one
        "old_uuid new_uuid" line per mapped UUID of every communication
    :raises Exception: if verify is True and a communication fails
        verification
    '''
    reader = CommunicationReader(input_path, add_references=False)
    writer = CommunicationWriterTGZ(output_path)
    uuid_map_file = None

    # Always close the writer and the map file, even when verification
    # fails part-way: the archive is presumably only finalized on close.
    try:
        if uuid_map_path is not None:
            uuid_map_file = open(uuid_map_path, 'w')

        uc = UUIDCompressor()

        for (i, (comm, _)) in enumerate(reader):
            new_comm = uc.compress(comm)

            logging.info('compressed %s (%d analytics, %d uuids) (%d/?)',
                         comm.id, len(uc.augs), len(uc.uuid_map), i + 1)

            if uuid_map_file is not None:
                for (old_uuid, new_uuid) in sorted(uc.uuid_map.items(),
                                                   key=lambda p: str(p[1])):
                    uuid_map_file.write('%s %s\n' % (old_uuid, new_uuid))

            # Keys of uuid_map are distinct by construction; comparing with
            # the number of distinct values detects new-UUID collisions.
            num_old_uuids = len(uc.uuid_map)
            num_new_uuids = len(set(uc.uuid_map.values()))

            if verify:
                c1 = UUIDClustering(comm).hashable_clusters()
                c2 = UUIDClustering(new_comm).hashable_clusters()

                if num_old_uuids == num_new_uuids and c1 == c2:
                    logging.info('verified %s (%d uuid instances, %d uuids)',
                                 comm.id, sum(len(c) for c in c1), len(c1))
                else:
                    logging.error('%s failed verification', comm.id)
                    logging.error('uuid counts: %d -> %d',
                                  num_old_uuids, num_new_uuids)
                    logging.error('verified number of uuids: %d -> %d',
                                  len(c1), len(c2))
                    logging.error(
                        'verified number of uuid instances: %d -> %d',
                        sum(map(len, c1)), sum(map(len, c2)))
                    raise Exception('%s failed verification' % comm.id)
            elif num_old_uuids != num_new_uuids:
                logging.warning('uuid counts are not the same (%d -> %d)',
                                num_old_uuids, num_new_uuids)

            writer.write(new_comm)
    finally:
        if uuid_map_file is not None:
            uuid_map_file.close()
        writer.close()


def main():
    '''
    Command-line entry point.

    Parses arguments, maps a path of "-" to /dev/fd/0 (input) or
    /dev/fd/1 (output), configures logging, and calls compress_uuids.
    '''
    parser = ArgumentParser(
        description='Read a concrete tarball and write it back out, rewriting'
                    ' UUIDs with compressible UUID scheme',
        formatter_class=ArgumentDefaultsHelpFormatter,
    )
    parser.set_defaults(log_level='INFO')
    parser.add_argument('input_path', type=str,
                        help='Input tarball path (- for stdin)')
    parser.add_argument('output_path', type=str,
                        help='Output tarball path (- for stdout)')
    parser.add_argument('--log-level', type=str,
                        choices=('DEBUG', 'INFO', 'WARNING', 'ERROR'),
                        help='Logging verbosity level (to stderr)')
    parser.add_argument('--verify', action='store_true',
                        help='Verify within-communication links are satisfied'
                             ' after conversion')
    parser.add_argument('--uuid-map-path', type=str,
                        help='Output path of UUID map')
    args = parser.parse_args()

    # /dev/fd paths do not exist on Windows; stdin/stdout pipelines there
    # are not supported.
    input_path, output_path = (
        std_path if user_path == '-' else user_path
        for (user_path, std_path) in ((args.input_path, '/dev/fd/0'),
                                      (args.output_path, '/dev/fd/1'))
    )

    logging.basicConfig(
        format='%(asctime)-15s %(levelname)s: %(message)s',
        level=args.log_level,
    )

    compress_uuids(input_path, output_path,
                   uuid_map_path=args.uuid_map_path, verify=args.verify)


if __name__ == '__main__':
    # Run the command-line tool only when executed as a script.
    main()
