Cherry-pick: make descriptor_pool.FindFileContainingSymbol find extensions (#2962)

* Use PyUnicode_AsEncodedString() instead of PyUnicode_AsEncodedObject()

* Cherry-pick the fix that lets descriptor_pool.FindFileContainingSymbol look up files by extension name.
This commit is contained in:
Jie Luo 2017-04-10 16:37:57 -07:00 committed by Feng Xiao
parent e91caa1f19
commit 899460c9cb
4 changed files with 45 additions and 8 deletions

View File

@ -127,6 +127,9 @@ class DescriptorPool(object):
self._service_descriptors = {}
self._file_descriptors = {}
self._toplevel_extensions = {}
# TODO(jieluo): Remove _file_desc_by_toplevel_extension when
# FieldDescriptor.file is added in code gen.
self._file_desc_by_toplevel_extension = {}
# We store extensions in two two-level mappings: The first key is the
# descriptor of the message being extended, the second key is the extension
# full name or its tag number.
@ -170,7 +173,7 @@ class DescriptorPool(object):
raise TypeError('Expected instance of descriptor.Descriptor.')
self._descriptors[desc.full_name] = desc
self.AddFileDescriptor(desc.file)
self._AddFileDescriptor(desc.file)
def AddEnumDescriptor(self, enum_desc):
"""Adds an EnumDescriptor to the pool.
@ -185,7 +188,7 @@ class DescriptorPool(object):
raise TypeError('Expected instance of descriptor.EnumDescriptor.')
self._enum_descriptors[enum_desc.full_name] = enum_desc
self.AddFileDescriptor(enum_desc.file)
self._AddFileDescriptor(enum_desc.file)
def AddServiceDescriptor(self, service_desc):
"""Adds a ServiceDescriptor to the pool.
@ -251,6 +254,23 @@ class DescriptorPool(object):
file_desc: A FileDescriptor.
"""
self._AddFileDescriptor(file_desc)
# TODO(jieluo): This is a temporary solution for FieldDescriptor.file.
# Remove it when FieldDescriptor.file is added in code gen.
for extension in file_desc.extensions_by_name.itervalues():
self._file_desc_by_toplevel_extension[
extension.full_name] = file_desc
def _AddFileDescriptor(self, file_desc):
"""Adds a FileDescriptor to the pool, non-recursively.
If the FileDescriptor contains messages or enums, the caller must explicitly
register them.
Args:
file_desc: A FileDescriptor.
"""
if not isinstance(file_desc, descriptor.FileDescriptor):
raise TypeError('Expected instance of descriptor.FileDescriptor.')
self._file_descriptors[file_desc.name] = file_desc
@ -313,12 +333,18 @@ class DescriptorPool(object):
except KeyError:
pass
try:
return self._file_desc_by_toplevel_extension[symbol]
except KeyError:
pass
# Try nested extensions inside a message.
message_name, _, extension_name = symbol.rpartition('.')
try:
scope = self.FindMessageTypeByName(message_name)
assert scope.extensions_by_name[extension_name]
return scope.file
message = self.FindMessageTypeByName(message_name)
assert message.extensions_by_name[extension_name]
return message.file
except KeyError:
raise KeyError('Cannot find a file containing %s' % symbol)

View File

@ -63,6 +63,9 @@ from google.protobuf import symbol_database
class DescriptorPoolTest(unittest.TestCase):
def setUp(self):
# TODO(jieluo): Make the pool created from serialized_pb behave the same
# as the generated pool.
# TODO(jieluo): More test coverage for the generated pool.
self.pool = descriptor_pool.DescriptorPool()
self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
factory_test1_pb2.DESCRIPTOR.serialized_pb)
@ -128,6 +131,12 @@ class DescriptorPoolTest(unittest.TestCase):
self.assertEqual('google/protobuf/internal/factory_test2.proto',
file_desc4.name)
# Tests the generated pool.
assert descriptor_pool.Default().FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.one_more_field')
assert descriptor_pool.Default().FindFileContainingSymbol(
'google.protobuf.python.internal.another_field')
def testFindFileContainingSymbolFailure(self):
with self.assertRaises(KeyError):
self.pool.FindFileContainingSymbol('Does not exist')

View File

@ -779,7 +779,7 @@ PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) {
encoded_string = arg; // Already encoded.
Py_INCREF(encoded_string);
} else {
encoded_string = PyUnicode_AsEncodedObject(arg, "utf-8", NULL);
encoded_string = PyUnicode_AsEncodedString(arg, "utf-8", NULL);
}
} else {
// In this case field type is "bytes".

View File

@ -445,8 +445,6 @@ void Generator::PrintFileDescriptor() const {
printer_->Outdent();
printer_->Print(")\n");
printer_->Print("_sym_db.RegisterFileDescriptor($name$)\n", "name",
kDescriptorKey);
printer_->Print("\n");
}
@ -999,6 +997,10 @@ void Generator::FixForeignFieldsInDescriptors() const {
for (int i = 0; i < file_->extension_count(); ++i) {
AddExtensionToFileDescriptor(*file_->extension(i));
}
// TODO(jieluo): Move this registration to PrintFileDescriptor() once
// FieldDescriptor.file is added to the generated file.
printer_->Print("_sym_db.RegisterFileDescriptor($name$)\n", "name",
kDescriptorKey);
printer_->Print("\n");
}