#!/usr/bin/python

import keyword
import optparse
import os
import sys
from collections import defaultdict
from StringIO import StringIO
from sys import stdin, stdout

from cprotobuf.plugin_pb import CodeGeneratorRequest, CodeGeneratorResponse, FieldDescriptorProto, FileDescriptorSet

# Maps every non-message field type constant to the type-name string that
# typename() quotes into the generated Field(...) declarations.  Each
# constant is looked up as FieldDescriptorProto.TYPE_<NAME IN UPPER CASE>.
CPP_TYPES = dict(
    (getattr(FieldDescriptorProto, 'TYPE_' + label.upper()), label)
    for label in (
        'double', 'float',
        'int32', 'int64',
        'uint32', 'uint64',
        'fixed32', 'fixed64',
        'sfixed32', 'sfixed64',
        'sint32', 'sint64',
        'bool', 'string', 'bytes', 'enum',
    )
)


# Extra names that always get a trailing underscore, in addition to the
# reserved words of the running Python interpreter.
keywords = ['static', 'class']


def fieldname(s):
    """Return a name for proto field `s` that is safe as a Python attribute.

    Names listed in `keywords`, and any reserved word of the interpreter
    running the generator (``from``, ``pass``, ``global``, ...), get a
    trailing underscore; a field assignment such as ``from = Field(...)``
    would otherwise make the generated module a SyntaxError.  All other
    names are returned unchanged.
    """
    if s in keywords or keyword.iskeyword(s):
        return s + '_'
    return s


def convert_default_value(v, t):
    """Turn a descriptor's textual default `v` into Python source text.

    `t` is the field's FieldDescriptorProto.TYPE_* constant.  The type is
    checked before the text, so that a *string* field whose default happens
    to be "true" or "false" is emitted as a quoted string rather than as a
    boolean.

    Returns the text to place after ``default=`` in the generated code.
    """
    if t == FieldDescriptorProto.TYPE_STRING:
        return repr(v)
    if t == FieldDescriptorProto.TYPE_BYTES:
        # NOTE(review): descriptor.proto documents bytes defaults as
        # C-escaped text; this keeps the existing behaviour of quoting that
        # text verbatim -- confirm whether it should be unescaped first.
        return repr(bytes(v))
    if t == FieldDescriptorProto.TYPE_BOOL:
        return {'true': 'True', 'false': 'False'}.get(v, v)
    if t in (FieldDescriptorProto.TYPE_DOUBLE, FieldDescriptorProto.TYPE_FLOAT):
        # Bare inf / -inf / nan are not valid Python expressions.
        if v in ('inf', '-inf', 'nan'):
            return "float('%s')" % v
    return v


def real_message_name(type_name):
    """Drop the first two dot-separated components of `type_name`.

    ``'.pkg.Msg'`` becomes ``'Msg'`` and ``'.a.b.C'`` becomes ``'b.C'``.
    A name containing fewer than two dots (for example ``'.Msg'``, as used
    when no package is specified) is returned unchanged, leading dot
    included.
    """
    _, _, after_first = type_name.partition('.')
    _, second_dot, remainder = after_first.partition('.')
    if second_dot:
        return remainder
    return type_name


def typename(type_name, type, seen_messages):
    """Return the source text used as the first argument of ``Field(...)``.

    Non-message types yield their CPP_TYPES name in single quotes.  Message
    types yield the bare class name when it is already in `seen_messages`,
    and the quoted name otherwise.
    """
    if type != FieldDescriptorProto.TYPE_MESSAGE:
        return "'%s'" % CPP_TYPES[type]

    message = real_message_name(type_name)
    # Without a package the name still carries its leading dot.
    if message[0] == '.':
        message = message[1:]

    if message in seen_messages:
        return message
    return "'%s'" % message


def write_field(fp, desc, seen_messages):
    """Write one ``name = Field(type, number, ...)`` line for `desc` to `fp`.

    `seen_messages` is passed to typename() to decide whether a message
    type is referenced by bare name or by quoted string.

    The leading part of the line is formatted on its own and the keyword
    arguments are appended afterwards, so a default value containing ``%``
    (e.g. the string ``'50%'``) is never interpreted as a format directive.
    """
    ftype = typename(desc.type_name, desc.type, seen_messages)
    fname = fieldname(desc.name)
    pieces = ['    %-15s = Field(%s,\t%d' % (fname, ftype, desc.number)]

    # LABEL_REQUIRED needs no extra argument.
    if desc.label == FieldDescriptorProto.LABEL_OPTIONAL:
        pieces.append('required=False')
    elif desc.label == FieldDescriptorProto.LABEL_REPEATED:
        pieces.append('repeated=True')
        if desc.options.packed:
            pieces.append('packed=True')

    if desc.default_value:
        pieces.append('default=%s' % convert_default_value(desc.default_value, desc.type))

    fp.write(', '.join(pieces) + ')\n')


def write_message(fp, message_descriptor, seen_messages):
    """Write the ``ProtoEntity`` subclass for `message_descriptor` to `fp`.

    The class body holds the message's enums (as indented constants)
    followed by one line per field; a message without fields gets ``pass``.
    A blank line is written after the class.  `seen_messages` is forwarded
    to write_field().

    Output goes through ``fp.write`` rather than the Python 2-only
    ``print >>fp`` statement, which produces the same text but also works
    on interpreters where ``print`` is a function.
    """
    fp.write('class %s(ProtoEntity):\n' % message_descriptor.name)

    for enum_descriptor in message_descriptor.enum_type:
        write_enum(fp, enum_descriptor, 4)

    if message_descriptor.field:
        for field_descriptor in message_descriptor.field:
            write_field(fp, field_descriptor, seen_messages)
    else:
        fp.write('    pass\n')

    fp.write('\n')


def write_enum(fp, descriptor, indent=0):
    """Write the values of enum `descriptor` to `fp` as ``NAME=number`` lines.

    A ``# enum <name>`` comment line comes first.  Every line is prefixed
    with `indent` spaces, so the constants can sit at module level
    (indent 0) or inside a class body.

    Uses ``fp.write`` instead of the Python 2-only ``print >>fp`` statement;
    the text written is the same.
    """
    pad = ' ' * indent
    lines = ['%s# enum %s' % (pad, descriptor.name)]
    lines.extend('%s%s=%d' % (pad, value.name, value.number)
                 for value in descriptor.value)
    fp.write('\n'.join(lines) + '\n')


def sort_messages(desces):
    """Yield message descriptors with dependencies before their dependents.

    A message depends on every *other* message in `desces` that one of its
    fields uses as its type.  Messages are yielded in declaration order
    whenever that respects the dependencies; if the remaining messages all
    wait on each other (a reference cycle), they are yielded in declaration
    order as they are.

    Type names of package-less files keep a leading dot after
    real_message_name(); it is stripped here, as typename() does, so that
    such dependencies are recognised too.
    """
    pending = list(desces)
    known = set(desc.name for desc in pending)

    # Message name -> names of the sibling messages it must come after.
    unmet = {}
    for desc in pending:
        wanted = set()
        for field_desc in desc.field:
            if field_desc.type != FieldDescriptorProto.TYPE_MESSAGE:
                continue
            type_name = real_message_name(field_desc.type_name)
            if type_name.startswith('.'):
                type_name = type_name[1:]
            # Self-references never block a message from being emitted.
            if type_name in known and type_name != desc.name:
                wanted.add(type_name)
        unmet[desc.name] = wanted

    emitted = set()
    while pending:
        for position, desc in enumerate(pending):
            if unmet[desc.name] <= emitted:
                break
        else:
            # Nothing is ready: the rest form a cycle, so flush them as-is.
            for desc in pending:
                yield desc
            return

        del pending[position]
        emitted.add(desc.name)
        yield desc


def main(proto_files):
    """Generate one ``<package>_pb.py`` module per package in `proto_files`.

    File descriptors are grouped by their ``package``; each group becomes
    one entry in the returned CodeGeneratorResponse, named after the package
    with dots replaced by underscores (``protos_pb.py`` when the package is
    empty).  Each module contains the enums and messages of every file in
    the group.

    The set of already-written message names is kept per output module:
    a bare class name may only be referenced inside the module that defines
    it, so names written to another package's module must not count as seen.
    """
    response = CodeGeneratorResponse()

    packages = defaultdict(list)
    for desc in proto_files:
        packages[desc.package].append(desc)

    for pkg, desces in packages.items():
        f = response.file.add()
        f.name = '%s_pb.py' % (pkg or 'protos').replace('.', '_')

        seen_messages = set()
        fp = StringIO()
        fp.write('# coding: utf-8\n')
        fp.write('from cprotobuf import ProtoEntity, Field\n')

        for desc in desces:
            fp.write('# file: %s\n' % desc.name)
            for descriptor in desc.enum_type:
                write_enum(fp, descriptor)
            for descriptor in sort_messages(desc.message_type):
                write_message(fp, descriptor, seen_messages)
                # Only now may later fields refer to this class by bare name.
                seen_messages.add(descriptor.name)

        f.content = fp.getvalue()

    return response

if __name__ == '__main__':
    if len(sys.argv) > 1:
        # Stand-alone mode: each argument is a serialized FileDescriptorSet;
        # the generated modules are written into the output directory.
        parser = optparse.OptionParser()
        parser.add_option("-d", "--output_directory", dest="output_directory",
                          help="directory to write output modules.", metavar="DIRECTORY", default='.')
        opt, args = parser.parse_args()

        for input_file in args:
            request = FileDescriptorSet()
            # Context managers close the handles even if parsing or
            # generation fails part-way through.
            with open(input_file, 'rb') as source:
                request.ParseFromString(source.read())
            response = main(request.file)
            for generated in response.file:
                target = os.path.join(opt.output_directory, generated.name)
                with open(target, 'w') as output:
                    output.write(generated.content)
    else:
        # Plugin mode: read a CodeGeneratorRequest from stdin and write the
        # serialized CodeGeneratorResponse to stdout.
        request = CodeGeneratorRequest()
        request.ParseFromString(stdin.read())
        response = main(request.proto_file)
        stdout.write(response.SerializeToString())
