cherrypick descriptor_pool.FindFileContainingSymbol by extensions (#2962)
* Use PyUnicode_AsEncodedString() instead of PyUnicode_AsEncodedObject() * Cherrypick the fix descriptor_pool.FindFileContainingSymbol by extensions.
This commit is contained in:
parent
e91caa1f19
commit
899460c9cb
@ -127,6 +127,9 @@ class DescriptorPool(object):
|
||||
self._service_descriptors = {}
|
||||
self._file_descriptors = {}
|
||||
self._toplevel_extensions = {}
|
||||
# TODO(jieluo): Remove _file_desc_by_toplevel_extension when
|
||||
# FieldDescriptor.file is added in code gen.
|
||||
self._file_desc_by_toplevel_extension = {}
|
||||
# We store extensions in two two-level mappings: The first key is the
|
||||
# descriptor of the message being extended, the second key is the extension
|
||||
# full name or its tag number.
|
||||
@ -170,7 +173,7 @@ class DescriptorPool(object):
|
||||
raise TypeError('Expected instance of descriptor.Descriptor.')
|
||||
|
||||
self._descriptors[desc.full_name] = desc
|
||||
self.AddFileDescriptor(desc.file)
|
||||
self._AddFileDescriptor(desc.file)
|
||||
|
||||
def AddEnumDescriptor(self, enum_desc):
|
||||
"""Adds an EnumDescriptor to the pool.
|
||||
@ -185,7 +188,7 @@ class DescriptorPool(object):
|
||||
raise TypeError('Expected instance of descriptor.EnumDescriptor.')
|
||||
|
||||
self._enum_descriptors[enum_desc.full_name] = enum_desc
|
||||
self.AddFileDescriptor(enum_desc.file)
|
||||
self._AddFileDescriptor(enum_desc.file)
|
||||
|
||||
def AddServiceDescriptor(self, service_desc):
|
||||
"""Adds a ServiceDescriptor to the pool.
|
||||
@ -251,6 +254,23 @@ class DescriptorPool(object):
|
||||
file_desc: A FileDescriptor.
|
||||
"""
|
||||
|
||||
self._AddFileDescriptor(file_desc)
|
||||
# TODO(jieluo): This is a temporary solution for FieldDescriptor.file.
|
||||
# Remove it when FieldDescriptor.file is added in code gen.
|
||||
for extension in file_desc.extensions_by_name.itervalues():
|
||||
self._file_desc_by_toplevel_extension[
|
||||
extension.full_name] = file_desc
|
||||
|
||||
def _AddFileDescriptor(self, file_desc):
|
||||
"""Adds a FileDescriptor to the pool, non-recursively.
|
||||
|
||||
If the FileDescriptor contains messages or enums, the caller must explicitly
|
||||
register them.
|
||||
|
||||
Args:
|
||||
file_desc: A FileDescriptor.
|
||||
"""
|
||||
|
||||
if not isinstance(file_desc, descriptor.FileDescriptor):
|
||||
raise TypeError('Expected instance of descriptor.FileDescriptor.')
|
||||
self._file_descriptors[file_desc.name] = file_desc
|
||||
@ -313,12 +333,18 @@ class DescriptorPool(object):
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return self._file_desc_by_toplevel_extension[symbol]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# Try nested extensions inside a message.
|
||||
message_name, _, extension_name = symbol.rpartition('.')
|
||||
try:
|
||||
scope = self.FindMessageTypeByName(message_name)
|
||||
assert scope.extensions_by_name[extension_name]
|
||||
return scope.file
|
||||
message = self.FindMessageTypeByName(message_name)
|
||||
assert message.extensions_by_name[extension_name]
|
||||
return message.file
|
||||
|
||||
except KeyError:
|
||||
raise KeyError('Cannot find a file containing %s' % symbol)
|
||||
|
||||
|
@ -63,6 +63,9 @@ from google.protobuf import symbol_database
|
||||
class DescriptorPoolTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# TODO(jieluo): Should make the pool which is created by
|
||||
# serialized_pb same with generated pool.
|
||||
# TODO(jieluo): More test coverage for the generated pool.
|
||||
self.pool = descriptor_pool.DescriptorPool()
|
||||
self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
|
||||
factory_test1_pb2.DESCRIPTOR.serialized_pb)
|
||||
@ -128,6 +131,12 @@ class DescriptorPoolTest(unittest.TestCase):
|
||||
self.assertEqual('google/protobuf/internal/factory_test2.proto',
|
||||
file_desc4.name)
|
||||
|
||||
# Tests the generated pool.
|
||||
assert descriptor_pool.Default().FindFileContainingSymbol(
|
||||
'google.protobuf.python.internal.Factory2Message.one_more_field')
|
||||
assert descriptor_pool.Default().FindFileContainingSymbol(
|
||||
'google.protobuf.python.internal.another_field')
|
||||
|
||||
def testFindFileContainingSymbolFailure(self):
|
||||
with self.assertRaises(KeyError):
|
||||
self.pool.FindFileContainingSymbol('Does not exist')
|
||||
|
@ -779,7 +779,7 @@ PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) {
|
||||
encoded_string = arg; // Already encoded.
|
||||
Py_INCREF(encoded_string);
|
||||
} else {
|
||||
encoded_string = PyUnicode_AsEncodedObject(arg, "utf-8", NULL);
|
||||
encoded_string = PyUnicode_AsEncodedString(arg, "utf-8", NULL);
|
||||
}
|
||||
} else {
|
||||
// In this case field type is "bytes".
|
||||
|
@ -445,8 +445,6 @@ void Generator::PrintFileDescriptor() const {
|
||||
|
||||
printer_->Outdent();
|
||||
printer_->Print(")\n");
|
||||
printer_->Print("_sym_db.RegisterFileDescriptor($name$)\n", "name",
|
||||
kDescriptorKey);
|
||||
printer_->Print("\n");
|
||||
}
|
||||
|
||||
@ -999,6 +997,10 @@ void Generator::FixForeignFieldsInDescriptors() const {
|
||||
for (int i = 0; i < file_->extension_count(); ++i) {
|
||||
AddExtensionToFileDescriptor(*file_->extension(i));
|
||||
}
|
||||
// TODO(jieluo): Move this register to PrintFileDescriptor() when
|
||||
// FieldDescriptor.file is added in generated file.
|
||||
printer_->Print("_sym_db.RegisterFileDescriptor($name$)\n", "name",
|
||||
kDescriptorKey);
|
||||
printer_->Print("\n");
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user