Merge tag 'refs/tags/sync-piper' into sync-stage

# Conflicts:
#	src/google/protobuf/compiler/java/java_message.cc
#	src/google/protobuf/compiler/java/java_message_lite.cc
This commit is contained in:
David L. Jones 2022-02-17 09:54:12 -08:00
commit a6c9a5f69d
70 changed files with 2172 additions and 1799 deletions

View File

@ -1,38 +1,27 @@
package com.google.protobuf; package com.google.protobuf;
import com.google.caliper.BeforeExperiment; import com.google.caliper.BeforeExperiment;
import com.google.caliper.AfterExperiment;
import com.google.caliper.Benchmark; import com.google.caliper.Benchmark;
import com.google.caliper.Param; import com.google.caliper.Param;
import com.google.caliper.api.VmOptions;
import com.google.protobuf.ByteString;
import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import com.google.protobuf.benchmarks.Benchmarks.BenchmarkDataset; import com.google.protobuf.benchmarks.Benchmarks.BenchmarkDataset;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.BufferedWriter;
import java.io.File; import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException; import java.io.IOException;
import java.io.RandomAccessFile; import java.io.RandomAccessFile;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
// Caliper set CICompilerCount to 1 for making sure compilation doesn't run in parallel with itself, /**
// This makes TieredCompilation not working. We just disable TieredCompilation by default. In master * Basic benchmarks for Java protobuf parsing.
// branch this has been disabled by default in caliper: */
// https://github.com/google/caliper/blob/master/caliper-runner/src/main/java/com/google/caliper/runner/target/Jvm.java#L38:14
// But this haven't been added into most recent release.
@VmOptions("-XX:-TieredCompilation")
public class ProtoCaliperBenchmark { public class ProtoCaliperBenchmark {
public enum BenchmarkMessageType { public enum BenchmarkMessageType {
GOOGLE_MESSAGE1_PROTO3 { GOOGLE_MESSAGE1_PROTO3 {
@Override ExtensionRegistry getExtensionRegistry() { return ExtensionRegistry.newInstance(); } @Override
ExtensionRegistry getExtensionRegistry() {
return ExtensionRegistry.newInstance();
}
@Override @Override
Message getDefaultInstance() { Message getDefaultInstance() {
return com.google.protobuf.benchmarks.BenchmarkMessage1Proto3.GoogleMessage1 return com.google.protobuf.benchmarks.BenchmarkMessage1Proto3.GoogleMessage1
@ -40,7 +29,9 @@ public class ProtoCaliperBenchmark {
} }
}, },
GOOGLE_MESSAGE1_PROTO2 { GOOGLE_MESSAGE1_PROTO2 {
@Override ExtensionRegistry getExtensionRegistry() { return ExtensionRegistry.newInstance(); } @Override ExtensionRegistry getExtensionRegistry() {
return ExtensionRegistry.newInstance();
}
@Override @Override
Message getDefaultInstance() { Message getDefaultInstance() {
return com.google.protobuf.benchmarks.BenchmarkMessage1Proto2.GoogleMessage1 return com.google.protobuf.benchmarks.BenchmarkMessage1Proto2.GoogleMessage1
@ -48,7 +39,10 @@ public class ProtoCaliperBenchmark {
} }
}, },
GOOGLE_MESSAGE2 { GOOGLE_MESSAGE2 {
@Override ExtensionRegistry getExtensionRegistry() { return ExtensionRegistry.newInstance(); } @Override
ExtensionRegistry getExtensionRegistry() {
return ExtensionRegistry.newInstance();
}
@Override @Override
Message getDefaultInstance() { Message getDefaultInstance() {
return com.google.protobuf.benchmarks.BenchmarkMessage2.GoogleMessage2.getDefaultInstance(); return com.google.protobuf.benchmarks.BenchmarkMessage2.GoogleMessage2.getDefaultInstance();

View File

@ -18,8 +18,7 @@ static struct PyModuleDef _module = {PyModuleDef_HEAD_INIT,
NULL}; NULL};
extern "C" { extern "C" {
PyMODINIT_FUNC PyMODINIT_FUNC PyInit_libbenchmark_messages() {
PyInit_libbenchmark_messages() {
benchmarks::BenchmarkDataset().descriptor(); benchmarks::BenchmarkDataset().descriptor();
benchmarks::proto3::GoogleMessage1().descriptor(); benchmarks::proto3::GoogleMessage1().descriptor();
benchmarks::proto2::GoogleMessage1().descriptor(); benchmarks::proto2::GoogleMessage1().descriptor();

View File

@ -59,11 +59,8 @@ copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_enum_reflec
copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_enum_util.h" include\google\protobuf\generated_enum_util.h copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_enum_util.h" include\google\protobuf\generated_enum_util.h
copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_message_bases.h" include\google\protobuf\generated_message_bases.h copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_message_bases.h" include\google\protobuf\generated_message_bases.h
copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_message_reflection.h" include\google\protobuf\generated_message_reflection.h copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_message_reflection.h" include\google\protobuf\generated_message_reflection.h
copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_message_table_driven.h" include\google\protobuf\generated_message_table_driven.h
copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_message_table_driven_lite.h" include\google\protobuf\generated_message_table_driven_lite.h
copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_message_tctable_decl.h" include\google\protobuf\generated_message_tctable_decl.h copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_message_tctable_decl.h" include\google\protobuf\generated_message_tctable_decl.h
copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_message_tctable_impl.h" include\google\protobuf\generated_message_tctable_impl.h copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_message_tctable_impl.h" include\google\protobuf\generated_message_tctable_impl.h
copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_message_tctable_impl.inc" include\google\protobuf\generated_message_tctable_impl.inc
copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_message_util.h" include\google\protobuf\generated_message_util.h copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\generated_message_util.h" include\google\protobuf\generated_message_util.h
copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\has_bits.h" include\google\protobuf\has_bits.h copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\has_bits.h" include\google\protobuf\has_bits.h
copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\implicit_weak_message.h" include\google\protobuf\implicit_weak_message.h copy "${PROTOBUF_SOURCE_WIN32_PATH}\..\src\google\protobuf\implicit_weak_message.h" include\google\protobuf\implicit_weak_message.h

View File

@ -37,7 +37,6 @@ else()
add_library(GTest::gmock_main ALIAS gmock_main) add_library(GTest::gmock_main ALIAS gmock_main)
endif() endif()
set(lite_test_protos set(lite_test_protos
google/protobuf/map_lite_unittest.proto google/protobuf/map_lite_unittest.proto
google/protobuf/unittest_import_lite.proto google/protobuf/unittest_import_lite.proto

View File

@ -46,7 +46,8 @@ function doTest(request) {
return response; return response;
} }
if (request.getRequestedOutputFormat() == conformance.WireFormat.TEXT_FORMAT) { if (request.getRequestedOutputFormat() ==
conformance.WireFormat.TEXT_FORMAT) {
response.setSkipped('Text format is not supported as output format.'); response.setSkipped('Text format is not supported as output format.');
return response; return response;
} }

View File

@ -10164,8 +10164,8 @@ namespace Google.Protobuf.Reflection {
/// location. /// location.
/// ///
/// Each element is a field number or an index. They form a path from /// Each element is a field number or an index. They form a path from
/// the root FileDescriptorProto to the place where the definition occurs. For /// the root FileDescriptorProto to the place where the definition occurs.
/// example, this path: /// For example, this path:
/// [ 4, 3, 2, 7, 1 ] /// [ 4, 3, 2, 7, 1 ]
/// refers to: /// refers to:
/// file.message_type(3) // 4, 3 /// file.message_type(3) // 4, 3

View File

@ -305,7 +305,8 @@ public abstract class ByteString implements Iterable<Byte>, Serializable {
ByteIterator latterBytes = latter.iterator(); ByteIterator latterBytes = latter.iterator();
while (formerBytes.hasNext() && latterBytes.hasNext()) { while (formerBytes.hasNext() && latterBytes.hasNext()) {
int result = Integer.valueOf(toInt(formerBytes.nextByte())) int result =
Integer.valueOf(toInt(formerBytes.nextByte()))
.compareTo(toInt(latterBytes.nextByte())); .compareTo(toInt(latterBytes.nextByte()));
if (result != 0) { if (result != 0) {
return result; return result;

View File

@ -63,15 +63,15 @@ final class DescriptorMessageInfoFactory implements MessageInfoFactory {
private static final DescriptorMessageInfoFactory instance = new DescriptorMessageInfoFactory(); private static final DescriptorMessageInfoFactory instance = new DescriptorMessageInfoFactory();
/** /**
* Names that should be avoided (in UpperCamelCase format). * Names that should be avoided (in UpperCamelCase format). Using them causes the compiler to
* Using them causes the compiler to generate accessors whose names * generate accessors whose names collide with methods defined in base classes.
* collide with methods defined in base classes.
* *
* Keep this list in sync with kForbiddenWordList in * <p>Keep this list in sync with kForbiddenWordList in
* src/google/protobuf/compiler/java/java_helpers.cc * src/google/protobuf/compiler/java/java_helpers.cc
*/ */
private static final Set<String> specialFieldNames = private static final Set<String> specialFieldNames =
new HashSet<>(Arrays.asList( new HashSet<>(
Arrays.asList(
// java.lang.Object: // java.lang.Object:
"Class", "Class",
// com.google.protobuf.MessageLiteOrBuilder: // com.google.protobuf.MessageLiteOrBuilder:
@ -83,7 +83,8 @@ final class DescriptorMessageInfoFactory implements MessageInfoFactory {
"AllFields", "AllFields",
"DescriptorForType", "DescriptorForType",
"InitializationErrorString", "InitializationErrorString",
"UnknownFields", // TODO(b/219045204): re-enable
// "UnknownFields",
// obsolete. kept for backwards compatibility of generated code // obsolete. kept for backwards compatibility of generated code
"CachedSize")); "CachedSize"));
@ -148,6 +149,8 @@ final class DescriptorMessageInfoFactory implements MessageInfoFactory {
* *
* <p>This class is thread-safe. * <p>This class is thread-safe.
*/ */
// <p>The code is adapted from the C++ implementation:
// https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/compiler/java/java_helpers.h
static class IsInitializedCheckAnalyzer { static class IsInitializedCheckAnalyzer {
private final Map<Descriptor, Boolean> resultCache = private final Map<Descriptor, Boolean> resultCache =
@ -630,7 +633,8 @@ final class DescriptorMessageInfoFactory implements MessageInfoFactory {
// For example: // For example:
// proto field name = "class" // proto field name = "class"
// java field name = "class__" // java field name = "class__"
// accessor method name = "getClass_()" (so that it does not clash with Object.getClass()) // accessor method name = "getClass_()" (so that it does not clash with
// Object.getClass())
suffix = "__"; suffix = "__";
} else { } else {
// For other proto field names, // For other proto field names,
@ -652,7 +656,8 @@ final class DescriptorMessageInfoFactory implements MessageInfoFactory {
/** /**
* Converts a snake case string into lower camel case. * Converts a snake case string into lower camel case.
* *
* <p>Some examples:</p> * <p>Some examples:
*
* <pre> * <pre>
* snakeCaseToLowerCamelCase("foo_bar") => "fooBar" * snakeCaseToLowerCamelCase("foo_bar") => "fooBar"
* snakeCaseToLowerCamelCase("foo") => "foo" * snakeCaseToLowerCamelCase("foo") => "foo"
@ -668,7 +673,8 @@ final class DescriptorMessageInfoFactory implements MessageInfoFactory {
/** /**
* Converts a snake case string into upper camel case. * Converts a snake case string into upper camel case.
* *
* <p>Some examples:</p> * <p>Some examples:
*
* <pre> * <pre>
* snakeCaseToUpperCamelCase("foo_bar") => "FooBar" * snakeCaseToUpperCamelCase("foo_bar") => "FooBar"
* snakeCaseToUpperCamelCase("foo") => "Foo" * snakeCaseToUpperCamelCase("foo") => "Foo"
@ -684,10 +690,11 @@ final class DescriptorMessageInfoFactory implements MessageInfoFactory {
/** /**
* Converts a snake case string into camel case. * Converts a snake case string into camel case.
* *
* <p>For better readability, prefer calling either * <p>For better readability, prefer calling either {@link #snakeCaseToLowerCamelCase(String)} or
* {@link #snakeCaseToLowerCamelCase(String)} or {@link #snakeCaseToUpperCamelCase(String)}.</p> * {@link #snakeCaseToUpperCamelCase(String)}.
*
* <p>Some examples:
* *
* <p>Some examples:</p>
* <pre> * <pre>
* snakeCaseToCamelCase("foo_bar", false) => "fooBar" * snakeCaseToCamelCase("foo_bar", false) => "fooBar"
* snakeCaseToCamelCase("foo_bar", true) => "FooBar" * snakeCaseToCamelCase("foo_bar", true) => "FooBar"
@ -697,16 +704,15 @@ final class DescriptorMessageInfoFactory implements MessageInfoFactory {
* snakeCaseToCamelCase("fooBar", false) => "fooBar" * snakeCaseToCamelCase("fooBar", false) => "fooBar"
* </pre> * </pre>
* *
* <p>This implementation of this method must exactly match the corresponding * <p>This implementation of this method must exactly match the corresponding function in the
* function in the protocol compiler. Specifically, the * protocol compiler. Specifically, the {@code UnderscoresToCamelCase} function in {@code
* {@code UnderscoresToCamelCase} function in * src/google/protobuf/compiler/java/java_helpers.cc}.
* {@code src/google/protobuf/compiler/java/java_helpers.cc}.</p>
* *
* @param snakeCase the string in snake case to convert * @param snakeCase the string in snake case to convert
* @param capFirst true if the first letter of the returned string should be uppercase. * @param capFirst true if the first letter of the returned string should be uppercase. false if
* false if the first letter of the returned string should be lowercase. * the first letter of the returned string should be lowercase.
* @return the string converted to camel case, with an uppercase or lowercase first * @return the string converted to camel case, with an uppercase or lowercase first character
* character depending on if {@code capFirst} is true or false, respectively * depending on if {@code capFirst} is true or false, respectively
*/ */
private static String snakeCaseToCamelCase(String snakeCase, boolean capFirst) { private static String snakeCaseToCamelCase(String snakeCase, boolean capFirst) {
StringBuilder sb = new StringBuilder(snakeCase.length() + 1); StringBuilder sb = new StringBuilder(snakeCase.length() + 1);

View File

@ -32,6 +32,7 @@
// This file tests that various identifiers work as field and type names even // This file tests that various identifiers work as field and type names even
// though the same identifiers are used internally by the java code generator. // though the same identifiers are used internally by the java code generator.
// LINT: LEGACY_NAMES
syntax = "proto2"; syntax = "proto2";
@ -41,8 +42,9 @@ option java_generic_services = true; // auto-added
option java_package = "com.google.protobuf"; option java_package = "com.google.protobuf";
option java_outer_classname = "TestBadIdentifiersProto"; option java_outer_classname = "TestBadIdentifiersProto";
// Message with field names using underscores that conflict with accessors in the base message class in java. // Message with field names using underscores that conflict with accessors in
// See kForbiddenWordList in src/google/protobuf/compiler/java/java_helpers.cc // the base message class in java. See kForbiddenWordList in
// src/google/protobuf/compiler/java/java_helpers.cc
message ForbiddenWordsUnderscoreMessage { message ForbiddenWordsUnderscoreMessage {
// java.lang.Object // java.lang.Object
optional bool class = 1; optional bool class = 1;
@ -55,13 +57,15 @@ message ForbiddenWordsUnderscoreMessage {
optional bool all_fields = 5; optional bool all_fields = 5;
optional bool descriptor_for_type = 6; optional bool descriptor_for_type = 6;
optional bool initialization_error_string = 7; optional bool initialization_error_string = 7;
optional bool unknown_fields = 8; // TODO(b/219045204): re-enable
// optional bool unknown_fields = 8;
// obsolete. kept for backwards compatibility of generated code // obsolete. kept for backwards compatibility of generated code
optional bool cached_size = 9; optional bool cached_size = 9;
} }
// Message with field names using leading underscores that conflict with accessors in the base message class in java. // Message with field names using leading underscores that conflict with
// See kForbiddenWordList in src/google/protobuf/compiler/java/java_helpers.cc // accessors in the base message class in java. See kForbiddenWordList in
// src/google/protobuf/compiler/java/java_helpers.cc
message ForbiddenWordsLeadingUnderscoreMessage { message ForbiddenWordsLeadingUnderscoreMessage {
// java.lang.Object // java.lang.Object
optional bool _class = 1; optional bool _class = 1;
@ -74,13 +78,15 @@ message ForbiddenWordsLeadingUnderscoreMessage {
optional bool _all_fields = 5; optional bool _all_fields = 5;
optional bool _descriptor_for_type = 6; optional bool _descriptor_for_type = 6;
optional bool _initialization_error_string = 7; optional bool _initialization_error_string = 7;
optional bool _unknown_fields = 8; // TODO(b/219045204): re-enable
// optional bool _unknown_fields = 8;
// obsolete. kept for backwards compatibility of generated code // obsolete. kept for backwards compatibility of generated code
optional bool _cached_size = 9; optional bool _cached_size = 9;
} }
// Message with field names in camel case that conflict with accessors in the base message class in java. // Message with field names in camel case that conflict with accessors in the
// See kForbiddenWordList in src/google/protobuf/compiler/java/java_helpers.cc // base message class in java. See kForbiddenWordList in
// src/google/protobuf/compiler/java/java_helpers.cc
message ForbiddenWordsCamelMessage { message ForbiddenWordsCamelMessage {
// java.lang.Object // java.lang.Object
optional bool class = 1; optional bool class = 1;
@ -93,7 +99,8 @@ message ForbiddenWordsCamelMessage {
optional bool initializationErrorString = 5; optional bool initializationErrorString = 5;
optional bool descriptorForType = 6; optional bool descriptorForType = 6;
optional bool allFields = 7; optional bool allFields = 7;
optional bool unknownFields = 8; // TODO(b/219045204): re-enable
// optional bool unknownFields = 8;
// obsolete. kept for backwards compatibility of generated code // obsolete. kept for backwards compatibility of generated code
optional bool cachedSize = 9; optional bool cachedSize = 9;
} }
@ -166,7 +173,8 @@ message Double {
} }
service TestConflictingMethodNames { service TestConflictingMethodNames {
rpc Override(ForbiddenWordsUnderscoreMessage) returns (ForbiddenWordsUnderscoreMessage); rpc Override(ForbiddenWordsUnderscoreMessage)
returns (ForbiddenWordsUnderscoreMessage);
} }
message TestConflictingFieldNames { message TestConflictingFieldNames {

View File

@ -87,15 +87,11 @@ class Proto2LiteTest {
optionalBool = true optionalBool = true
optionalString = "115" optionalString = "115"
optionalBytes = toBytes("116") optionalBytes = toBytes("116")
optionalGroup = optionalGroup = TestAllTypesLiteKt.optionalGroup { a = 117 }
TestAllTypesLiteKt.optionalGroup { a = 117 }
optionalNestedMessage = nestedMessage { bb = 118 } optionalNestedMessage = nestedMessage { bb = 118 }
optionalForeignMessage = optionalForeignMessage = foreignMessageLite { c = 119 }
foreignMessageLite { c = 119 } optionalImportMessage = ImportMessageLite.newBuilder().setD(120).build()
optionalImportMessage = optionalPublicImportMessage = PublicImportMessageLite.newBuilder().setE(126).build()
ImportMessageLite.newBuilder().setD(120).build()
optionalPublicImportMessage =
PublicImportMessageLite.newBuilder().setE(126).build()
optionalLazyMessage = nestedMessage { bb = 127 } optionalLazyMessage = nestedMessage { bb = 127 }
optionalNestedEnum = NestedEnum.BAZ optionalNestedEnum = NestedEnum.BAZ
optionalForeignEnum = ForeignEnumLite.FOREIGN_LITE_BAZ optionalForeignEnum = ForeignEnumLite.FOREIGN_LITE_BAZ
@ -119,12 +115,8 @@ class Proto2LiteTest {
repeatedBytes.add(toBytes("216")) repeatedBytes.add(toBytes("216"))
repeatedGroup.add(TestAllTypesLiteKt.repeatedGroup { a = 217 }) repeatedGroup.add(TestAllTypesLiteKt.repeatedGroup { a = 217 })
repeatedNestedMessage.add(nestedMessage { bb = 218 }) repeatedNestedMessage.add(nestedMessage { bb = 218 })
repeatedForeignMessage.add( repeatedForeignMessage.add(foreignMessageLite { c = 219 })
foreignMessageLite { c = 219 } repeatedImportMessage.add(ImportMessageLite.newBuilder().setD(220).build())
)
repeatedImportMessage.add(
ImportMessageLite.newBuilder().setD(220).build()
)
repeatedLazyMessage.add(nestedMessage { bb = 227 }) repeatedLazyMessage.add(nestedMessage { bb = 227 })
repeatedNestedEnum.add(NestedEnum.BAR) repeatedNestedEnum.add(NestedEnum.BAR)
repeatedForeignEnum.add(ForeignEnumLite.FOREIGN_LITE_BAR) repeatedForeignEnum.add(ForeignEnumLite.FOREIGN_LITE_BAR)
@ -148,12 +140,9 @@ class Proto2LiteTest {
repeatedBytes += toBytes("316") repeatedBytes += toBytes("316")
repeatedGroup += TestAllTypesLiteKt.repeatedGroup { a = 317 } repeatedGroup += TestAllTypesLiteKt.repeatedGroup { a = 317 }
repeatedNestedMessage += nestedMessage { bb = 318 } repeatedNestedMessage += nestedMessage { bb = 318 }
repeatedForeignMessage += repeatedForeignMessage += foreignMessageLite { c = 319 }
foreignMessageLite { c = 319 } repeatedImportMessage += ImportMessageLite.newBuilder().setD(320).build()
repeatedImportMessage += repeatedLazyMessage += TestAllTypesLiteKt.nestedMessage { bb = 327 }
ImportMessageLite.newBuilder().setD(320).build()
repeatedLazyMessage +=
TestAllTypesLiteKt.nestedMessage { bb = 327 }
repeatedNestedEnum += NestedEnum.BAZ repeatedNestedEnum += NestedEnum.BAZ
repeatedForeignEnum += ForeignEnumLite.FOREIGN_LITE_BAZ repeatedForeignEnum += ForeignEnumLite.FOREIGN_LITE_BAZ
repeatedImportEnum += ImportEnumLite.IMPORT_LITE_BAZ repeatedImportEnum += ImportEnumLite.IMPORT_LITE_BAZ
@ -180,14 +169,12 @@ class Proto2LiteTest {
defaultStringPiece = "424" defaultStringPiece = "424"
defaultCord = "425" defaultCord = "425"
oneofUint32 = 601 oneofUint32 = 601
oneofNestedMessage = oneofNestedMessage = TestAllTypesLiteKt.nestedMessage { bb = 602 }
TestAllTypesLiteKt.nestedMessage { bb = 602 }
oneofString = "603" oneofString = "603"
oneofBytes = toBytes("604") oneofBytes = toBytes("604")
} }
).isEqualTo(
TestUtilLite.getAllLiteSetBuilder().build()
) )
.isEqualTo(TestUtilLite.getAllLiteSetBuilder().build())
} }
@Test @Test
@ -243,7 +230,8 @@ class Proto2LiteTest {
TestAllTypesLiteKt.repeatedGroup { a = 2 } TestAllTypesLiteKt.repeatedGroup { a = 2 }
) )
) )
assertThat(repeatedGroup).isEqualTo( assertThat(repeatedGroup)
.isEqualTo(
listOf( listOf(
TestAllTypesLiteKt.repeatedGroup { a = 1 }, TestAllTypesLiteKt.repeatedGroup { a = 1 },
TestAllTypesLiteKt.repeatedGroup { a = 2 } TestAllTypesLiteKt.repeatedGroup { a = 2 }
@ -254,7 +242,8 @@ class Proto2LiteTest {
TestAllTypesLiteKt.repeatedGroup { a = 3 }, TestAllTypesLiteKt.repeatedGroup { a = 3 },
TestAllTypesLiteKt.repeatedGroup { a = 4 } TestAllTypesLiteKt.repeatedGroup { a = 4 }
) )
assertThat(repeatedGroup).isEqualTo( assertThat(repeatedGroup)
.isEqualTo(
listOf( listOf(
TestAllTypesLiteKt.repeatedGroup { a = 1 }, TestAllTypesLiteKt.repeatedGroup { a = 1 },
TestAllTypesLiteKt.repeatedGroup { a = 2 }, TestAllTypesLiteKt.repeatedGroup { a = 2 },
@ -263,7 +252,8 @@ class Proto2LiteTest {
) )
) )
repeatedGroup[0] = TestAllTypesLiteKt.repeatedGroup { a = 5 } repeatedGroup[0] = TestAllTypesLiteKt.repeatedGroup { a = 5 }
assertThat(repeatedGroup).isEqualTo( assertThat(repeatedGroup)
.isEqualTo(
listOf( listOf(
TestAllTypesLiteKt.repeatedGroup { a = 5 }, TestAllTypesLiteKt.repeatedGroup { a = 5 },
TestAllTypesLiteKt.repeatedGroup { a = 2 }, TestAllTypesLiteKt.repeatedGroup { a = 2 },
@ -273,14 +263,11 @@ class Proto2LiteTest {
) )
repeatedNestedMessage.addAll(listOf(nestedMessage { bb = 1 }, nestedMessage { bb = 2 })) repeatedNestedMessage.addAll(listOf(nestedMessage { bb = 1 }, nestedMessage { bb = 2 }))
assertThat(repeatedNestedMessage).isEqualTo( assertThat(repeatedNestedMessage)
listOf( .isEqualTo(listOf(nestedMessage { bb = 1 }, nestedMessage { bb = 2 }))
nestedMessage { bb = 1 },
nestedMessage { bb = 2 }
)
)
repeatedNestedMessage += listOf(nestedMessage { bb = 3 }, nestedMessage { bb = 4 }) repeatedNestedMessage += listOf(nestedMessage { bb = 3 }, nestedMessage { bb = 4 })
assertThat(repeatedNestedMessage).isEqualTo( assertThat(repeatedNestedMessage)
.isEqualTo(
listOf( listOf(
nestedMessage { bb = 1 }, nestedMessage { bb = 1 },
nestedMessage { bb = 2 }, nestedMessage { bb = 2 },
@ -289,7 +276,8 @@ class Proto2LiteTest {
) )
) )
repeatedNestedMessage[0] = nestedMessage { bb = 5 } repeatedNestedMessage[0] = nestedMessage { bb = 5 }
assertThat(repeatedNestedMessage).isEqualTo( assertThat(repeatedNestedMessage)
.isEqualTo(
listOf( listOf(
nestedMessage { bb = 5 }, nestedMessage { bb = 5 },
nestedMessage { bb = 2 }, nestedMessage { bb = 2 },
@ -301,13 +289,11 @@ class Proto2LiteTest {
repeatedNestedEnum.addAll(listOf(NestedEnum.FOO, NestedEnum.BAR)) repeatedNestedEnum.addAll(listOf(NestedEnum.FOO, NestedEnum.BAR))
assertThat(repeatedNestedEnum).isEqualTo(listOf(NestedEnum.FOO, NestedEnum.BAR)) assertThat(repeatedNestedEnum).isEqualTo(listOf(NestedEnum.FOO, NestedEnum.BAR))
repeatedNestedEnum += listOf(NestedEnum.BAZ, NestedEnum.FOO) repeatedNestedEnum += listOf(NestedEnum.BAZ, NestedEnum.FOO)
assertThat(repeatedNestedEnum).isEqualTo( assertThat(repeatedNestedEnum)
listOf(NestedEnum.FOO, NestedEnum.BAR, NestedEnum.BAZ, NestedEnum.FOO) .isEqualTo(listOf(NestedEnum.FOO, NestedEnum.BAR, NestedEnum.BAZ, NestedEnum.FOO))
)
repeatedNestedEnum[0] = NestedEnum.BAR repeatedNestedEnum[0] = NestedEnum.BAR
assertThat(repeatedNestedEnum).isEqualTo( assertThat(repeatedNestedEnum)
listOf(NestedEnum.BAR, NestedEnum.BAR, NestedEnum.BAZ, NestedEnum.FOO) .isEqualTo(listOf(NestedEnum.BAR, NestedEnum.BAR, NestedEnum.BAZ, NestedEnum.FOO))
)
} }
} }
@ -380,21 +366,15 @@ class Proto2LiteTest {
optionalInt32 = 101 optionalInt32 = 101
optionalString = "115" optionalString = "115"
} }
val modifiedMessage = message.copy { val modifiedMessage = message.copy { optionalInt32 = 201 }
optionalInt32 = 201
}
assertThat(message).isEqualTo( assertThat(message)
TestAllTypesLite.newBuilder() .isEqualTo(
.setOptionalInt32(101) TestAllTypesLite.newBuilder().setOptionalInt32(101).setOptionalString("115").build()
.setOptionalString("115")
.build()
) )
assertThat(modifiedMessage).isEqualTo( assertThat(modifiedMessage)
TestAllTypesLite.newBuilder() .isEqualTo(
.setOptionalInt32(201) TestAllTypesLite.newBuilder().setOptionalInt32(201).setOptionalString("115").build()
.setOptionalString("115")
.build()
) )
} }
@ -402,18 +382,15 @@ class Proto2LiteTest {
fun testOneof() { fun testOneof() {
val message = testAllTypesLite { val message = testAllTypesLite {
oneofString = "foo" oneofString = "foo"
assertThat(oneofFieldCase) assertThat(oneofFieldCase).isEqualTo(TestAllTypesLite.OneofFieldCase.ONEOF_STRING)
.isEqualTo(TestAllTypesLite.OneofFieldCase.ONEOF_STRING)
assertThat(oneofString).isEqualTo("foo") assertThat(oneofString).isEqualTo("foo")
clearOneofField() clearOneofField()
assertThat(hasOneofUint32()).isFalse() assertThat(hasOneofUint32()).isFalse()
assertThat(oneofFieldCase) assertThat(oneofFieldCase).isEqualTo(TestAllTypesLite.OneofFieldCase.ONEOFFIELD_NOT_SET)
.isEqualTo(TestAllTypesLite.OneofFieldCase.ONEOFFIELD_NOT_SET)
oneofUint32 = 5 oneofUint32 = 5
} }
assertThat(message.getOneofFieldCase()) assertThat(message.getOneofFieldCase()).isEqualTo(TestAllTypesLite.OneofFieldCase.ONEOF_UINT32)
.isEqualTo(TestAllTypesLite.OneofFieldCase.ONEOF_UINT32)
assertThat(message.getOneofUint32()).isEqualTo(5) assertThat(message.getOneofUint32()).isEqualTo(5)
} }
@ -536,9 +513,8 @@ class Proto2LiteTest {
this[UnittestLite.oneofStringExtensionLite] = "603" this[UnittestLite.oneofStringExtensionLite] = "603"
this[UnittestLite.oneofBytesExtensionLite] = toBytes("604") this[UnittestLite.oneofBytesExtensionLite] = toBytes("604")
} }
).isEqualTo(
TestUtilLite.getAllLiteExtensionsSet()
) )
.isEqualTo(TestUtilLite.getAllLiteExtensionsSet())
} }
@Test @Test
@ -584,23 +560,16 @@ class Proto2LiteTest {
.isEqualTo(listOf("5", "2", "3", "4")) .isEqualTo(listOf("5", "2", "3", "4"))
this[UnittestLite.repeatedGroupExtensionLite].addAll( this[UnittestLite.repeatedGroupExtensionLite].addAll(
listOf( listOf(repeatedGroupExtensionLite { a = 1 }, repeatedGroupExtensionLite { a = 2 })
repeatedGroupExtensionLite { a = 1 },
repeatedGroupExtensionLite { a = 2 }
)
)
assertThat(this[UnittestLite.repeatedGroupExtensionLite]).isEqualTo(
listOf(
repeatedGroupExtensionLite { a = 1 },
repeatedGroupExtensionLite { a = 2 }
) )
assertThat(this[UnittestLite.repeatedGroupExtensionLite])
.isEqualTo(
listOf(repeatedGroupExtensionLite { a = 1 }, repeatedGroupExtensionLite { a = 2 })
) )
this[UnittestLite.repeatedGroupExtensionLite] += this[UnittestLite.repeatedGroupExtensionLite] +=
listOf( listOf(repeatedGroupExtensionLite { a = 3 }, repeatedGroupExtensionLite { a = 4 })
repeatedGroupExtensionLite { a = 3 }, assertThat(this[UnittestLite.repeatedGroupExtensionLite])
repeatedGroupExtensionLite { a = 4 } .isEqualTo(
)
assertThat(this[UnittestLite.repeatedGroupExtensionLite]).isEqualTo(
listOf( listOf(
repeatedGroupExtensionLite { a = 1 }, repeatedGroupExtensionLite { a = 1 },
repeatedGroupExtensionLite { a = 2 }, repeatedGroupExtensionLite { a = 2 },
@ -609,7 +578,8 @@ class Proto2LiteTest {
) )
) )
this[UnittestLite.repeatedGroupExtensionLite][0] = repeatedGroupExtensionLite { a = 5 } this[UnittestLite.repeatedGroupExtensionLite][0] = repeatedGroupExtensionLite { a = 5 }
assertThat(this[UnittestLite.repeatedGroupExtensionLite]).isEqualTo( assertThat(this[UnittestLite.repeatedGroupExtensionLite])
.isEqualTo(
listOf( listOf(
repeatedGroupExtensionLite { a = 5 }, repeatedGroupExtensionLite { a = 5 },
repeatedGroupExtensionLite { a = 2 }, repeatedGroupExtensionLite { a = 2 },
@ -621,12 +591,12 @@ class Proto2LiteTest {
this[UnittestLite.repeatedNestedMessageExtensionLite].addAll( this[UnittestLite.repeatedNestedMessageExtensionLite].addAll(
listOf(nestedMessage { bb = 1 }, nestedMessage { bb = 2 }) listOf(nestedMessage { bb = 1 }, nestedMessage { bb = 2 })
) )
assertThat(this[UnittestLite.repeatedNestedMessageExtensionLite]).isEqualTo( assertThat(this[UnittestLite.repeatedNestedMessageExtensionLite])
listOf(nestedMessage { bb = 1 }, nestedMessage { bb = 2 }) .isEqualTo(listOf(nestedMessage { bb = 1 }, nestedMessage { bb = 2 }))
)
this[UnittestLite.repeatedNestedMessageExtensionLite] += this[UnittestLite.repeatedNestedMessageExtensionLite] +=
listOf(nestedMessage { bb = 3 }, nestedMessage { bb = 4 }) listOf(nestedMessage { bb = 3 }, nestedMessage { bb = 4 })
assertThat(this[UnittestLite.repeatedNestedMessageExtensionLite]).isEqualTo( assertThat(this[UnittestLite.repeatedNestedMessageExtensionLite])
.isEqualTo(
listOf( listOf(
nestedMessage { bb = 1 }, nestedMessage { bb = 1 },
nestedMessage { bb = 2 }, nestedMessage { bb = 2 },
@ -635,7 +605,8 @@ class Proto2LiteTest {
) )
) )
this[UnittestLite.repeatedNestedMessageExtensionLite][0] = nestedMessage { bb = 5 } this[UnittestLite.repeatedNestedMessageExtensionLite][0] = nestedMessage { bb = 5 }
assertThat(this[UnittestLite.repeatedNestedMessageExtensionLite]).isEqualTo( assertThat(this[UnittestLite.repeatedNestedMessageExtensionLite])
.isEqualTo(
listOf( listOf(
nestedMessage { bb = 5 }, nestedMessage { bb = 5 },
nestedMessage { bb = 2 }, nestedMessage { bb = 2 },
@ -644,18 +615,17 @@ class Proto2LiteTest {
) )
) )
this[UnittestLite.repeatedNestedEnumExtensionLite] this[UnittestLite.repeatedNestedEnumExtensionLite].addAll(
.addAll(listOf(NestedEnum.FOO, NestedEnum.BAR)) listOf(NestedEnum.FOO, NestedEnum.BAR)
)
assertThat(this[UnittestLite.repeatedNestedEnumExtensionLite]) assertThat(this[UnittestLite.repeatedNestedEnumExtensionLite])
.isEqualTo(listOf(NestedEnum.FOO, NestedEnum.BAR)) .isEqualTo(listOf(NestedEnum.FOO, NestedEnum.BAR))
this[UnittestLite.repeatedNestedEnumExtensionLite] += listOf(NestedEnum.BAZ, NestedEnum.FOO) this[UnittestLite.repeatedNestedEnumExtensionLite] += listOf(NestedEnum.BAZ, NestedEnum.FOO)
assertThat(this[UnittestLite.repeatedNestedEnumExtensionLite]).isEqualTo( assertThat(this[UnittestLite.repeatedNestedEnumExtensionLite])
listOf(NestedEnum.FOO, NestedEnum.BAR, NestedEnum.BAZ, NestedEnum.FOO) .isEqualTo(listOf(NestedEnum.FOO, NestedEnum.BAR, NestedEnum.BAZ, NestedEnum.FOO))
)
this[UnittestLite.repeatedNestedEnumExtensionLite][0] = NestedEnum.BAR this[UnittestLite.repeatedNestedEnumExtensionLite][0] = NestedEnum.BAR
assertThat(this[UnittestLite.repeatedNestedEnumExtensionLite]).isEqualTo( assertThat(this[UnittestLite.repeatedNestedEnumExtensionLite])
listOf(NestedEnum.BAR, NestedEnum.BAR, NestedEnum.BAZ, NestedEnum.FOO) .isEqualTo(listOf(NestedEnum.BAR, NestedEnum.BAR, NestedEnum.BAZ, NestedEnum.FOO))
)
} }
} }
@ -726,17 +696,10 @@ class Proto2LiteTest {
@Test @Test
fun testEmptyMessages() { fun testEmptyMessages() {
assertThat( assertThat(testEmptyMessageLite {}).isEqualTo(TestEmptyMessageLite.newBuilder().build())
testEmptyMessageLite {}
).isEqualTo(
TestEmptyMessageLite.newBuilder().build()
)
assertThat( assertThat(testEmptyMessageWithExtensionsLite {})
testEmptyMessageWithExtensionsLite {} .isEqualTo(TestEmptyMessageWithExtensionsLite.newBuilder().build())
).isEqualTo(
TestEmptyMessageWithExtensionsLite.newBuilder().build()
)
} }
@Test @Test
@ -761,7 +724,8 @@ class Proto2LiteTest {
mapInt32Enum[1] = MapEnumLite.MAP_ENUM_FOO_LITE mapInt32Enum[1] = MapEnumLite.MAP_ENUM_FOO_LITE
mapInt32ForeignMessage[1] = foreignMessageLite { c = 1 } mapInt32ForeignMessage[1] = foreignMessageLite { c = 1 }
} }
).isEqualTo( )
.isEqualTo(
TestMapLite.newBuilder() TestMapLite.newBuilder()
.putMapInt32Int32(1, 2) .putMapInt32Int32(1, 2)
.putMapInt64Int64(1L, 2L) .putMapInt64Int64(1L, 2L)
@ -804,13 +768,13 @@ class Proto2LiteTest {
mapInt32Enum.put(1, MapEnumLite.MAP_ENUM_FOO_LITE) mapInt32Enum.put(1, MapEnumLite.MAP_ENUM_FOO_LITE)
assertThat(mapInt32Enum).isEqualTo(mapOf(1 to MapEnumLite.MAP_ENUM_FOO_LITE)) assertThat(mapInt32Enum).isEqualTo(mapOf(1 to MapEnumLite.MAP_ENUM_FOO_LITE))
mapInt32Enum[2] = MapEnumLite.MAP_ENUM_BAR_LITE mapInt32Enum[2] = MapEnumLite.MAP_ENUM_BAR_LITE
assertThat(mapInt32Enum).isEqualTo( assertThat(mapInt32Enum)
mapOf(1 to MapEnumLite.MAP_ENUM_FOO_LITE, 2 to MapEnumLite.MAP_ENUM_BAR_LITE) .isEqualTo(mapOf(1 to MapEnumLite.MAP_ENUM_FOO_LITE, 2 to MapEnumLite.MAP_ENUM_BAR_LITE))
)
mapInt32Enum.putAll( mapInt32Enum.putAll(
mapOf(3 to MapEnumLite.MAP_ENUM_BAZ_LITE, 4 to MapEnumLite.MAP_ENUM_FOO_LITE) mapOf(3 to MapEnumLite.MAP_ENUM_BAZ_LITE, 4 to MapEnumLite.MAP_ENUM_FOO_LITE)
) )
assertThat(mapInt32Enum).isEqualTo( assertThat(mapInt32Enum)
.isEqualTo(
mapOf( mapOf(
1 to MapEnumLite.MAP_ENUM_FOO_LITE, 1 to MapEnumLite.MAP_ENUM_FOO_LITE,
2 to MapEnumLite.MAP_ENUM_BAR_LITE, 2 to MapEnumLite.MAP_ENUM_BAR_LITE,
@ -822,13 +786,13 @@ class Proto2LiteTest {
mapInt32ForeignMessage.put(1, foreignMessageLite { c = 1 }) mapInt32ForeignMessage.put(1, foreignMessageLite { c = 1 })
assertThat(mapInt32ForeignMessage).isEqualTo(mapOf(1 to foreignMessageLite { c = 1 })) assertThat(mapInt32ForeignMessage).isEqualTo(mapOf(1 to foreignMessageLite { c = 1 }))
mapInt32ForeignMessage[2] = foreignMessageLite { c = 2 } mapInt32ForeignMessage[2] = foreignMessageLite { c = 2 }
assertThat(mapInt32ForeignMessage).isEqualTo( assertThat(mapInt32ForeignMessage)
mapOf(1 to foreignMessageLite { c = 1 }, 2 to foreignMessageLite { c = 2 }) .isEqualTo(mapOf(1 to foreignMessageLite { c = 1 }, 2 to foreignMessageLite { c = 2 }))
)
mapInt32ForeignMessage.putAll( mapInt32ForeignMessage.putAll(
mapOf(3 to foreignMessageLite { c = 3 }, 4 to foreignMessageLite { c = 4 }) mapOf(3 to foreignMessageLite { c = 3 }, 4 to foreignMessageLite { c = 4 })
) )
assertThat(mapInt32ForeignMessage).isEqualTo( assertThat(mapInt32ForeignMessage)
.isEqualTo(
mapOf( mapOf(
1 to foreignMessageLite { c = 1 }, 1 to foreignMessageLite { c = 1 },
2 to foreignMessageLite { c = 2 }, 2 to foreignMessageLite { c = 2 },
@ -916,7 +880,8 @@ class Proto2LiteTest {
serializedSize_ = true serializedSize_ = true
by = "foo" by = "foo"
} }
).isEqualTo( )
.isEqualTo(
EvilNamesProto2.newBuilder() EvilNamesProto2.newBuilder()
.setInitialized(true) .setInitialized(true)
.setHasFoo(true) .setHasFoo(true)

View File

@ -58,6 +58,7 @@ import protobuf_unittest.UnittestProto.TestEmptyMessageWithExtensions
import protobuf_unittest.copy import protobuf_unittest.copy
import protobuf_unittest.foreignMessage import protobuf_unittest.foreignMessage
import protobuf_unittest.optionalGroupExtension import protobuf_unittest.optionalGroupExtension
import protobuf_unittest.optionalNestedMessageOrNull
import protobuf_unittest.repeatedGroupExtension import protobuf_unittest.repeatedGroupExtension
import protobuf_unittest.testAllExtensions import protobuf_unittest.testAllExtensions
import protobuf_unittest.testAllTypes import protobuf_unittest.testAllTypes
@ -953,4 +954,16 @@ class Proto2Test {
assertThat(hasDo_()).isFalse() assertThat(hasDo_()).isFalse()
} }
} }
@Test
fun testGetOrNull() {
val noNestedMessage = testAllTypes {}
assertThat(noNestedMessage.optionalNestedMessageOrNull).isEqualTo(null)
val someNestedMessage = testAllTypes {
optionalNestedMessage = TestAllTypesKt.nestedMessage { bb = 118 }
}
assertThat(someNestedMessage.optionalNestedMessageOrNull)
.isEqualTo(TestAllTypesKt.nestedMessage { bb = 118 })
}
} }

View File

@ -44,6 +44,7 @@ import proto3_unittest.UnittestProto3.TestAllTypes
import proto3_unittest.UnittestProto3.TestAllTypes.NestedEnum import proto3_unittest.UnittestProto3.TestAllTypes.NestedEnum
import proto3_unittest.UnittestProto3.TestEmptyMessage import proto3_unittest.UnittestProto3.TestEmptyMessage
import proto3_unittest.copy import proto3_unittest.copy
import proto3_unittest.optionalNestedMessageOrNull
import proto3_unittest.testAllTypes import proto3_unittest.testAllTypes
import proto3_unittest.testEmptyMessage import proto3_unittest.testEmptyMessage
import org.junit.Test import org.junit.Test
@ -86,14 +87,11 @@ class Proto3Test {
assertThat(repeatedString).isEqualTo(listOf("5", "2", "3", "4")) assertThat(repeatedString).isEqualTo(listOf("5", "2", "3", "4"))
repeatedNestedMessage.addAll(listOf(nestedMessage { bb = 1 }, nestedMessage { bb = 2 })) repeatedNestedMessage.addAll(listOf(nestedMessage { bb = 1 }, nestedMessage { bb = 2 }))
assertThat(repeatedNestedMessage).isEqualTo( assertThat(repeatedNestedMessage)
listOf( .isEqualTo(listOf(nestedMessage { bb = 1 }, nestedMessage { bb = 2 }))
nestedMessage { bb = 1 },
nestedMessage { bb = 2 }
)
)
repeatedNestedMessage += listOf(nestedMessage { bb = 3 }, nestedMessage { bb = 4 }) repeatedNestedMessage += listOf(nestedMessage { bb = 3 }, nestedMessage { bb = 4 })
assertThat(repeatedNestedMessage).isEqualTo( assertThat(repeatedNestedMessage)
.isEqualTo(
listOf( listOf(
nestedMessage { bb = 1 }, nestedMessage { bb = 1 },
nestedMessage { bb = 2 }, nestedMessage { bb = 2 },
@ -102,7 +100,8 @@ class Proto3Test {
) )
) )
repeatedNestedMessage[0] = nestedMessage { bb = 5 } repeatedNestedMessage[0] = nestedMessage { bb = 5 }
assertThat(repeatedNestedMessage).isEqualTo( assertThat(repeatedNestedMessage)
.isEqualTo(
listOf( listOf(
nestedMessage { bb = 5 }, nestedMessage { bb = 5 },
nestedMessage { bb = 2 }, nestedMessage { bb = 2 },
@ -114,13 +113,11 @@ class Proto3Test {
repeatedNestedEnum.addAll(listOf(NestedEnum.FOO, NestedEnum.BAR)) repeatedNestedEnum.addAll(listOf(NestedEnum.FOO, NestedEnum.BAR))
assertThat(repeatedNestedEnum).isEqualTo(listOf(NestedEnum.FOO, NestedEnum.BAR)) assertThat(repeatedNestedEnum).isEqualTo(listOf(NestedEnum.FOO, NestedEnum.BAR))
repeatedNestedEnum += listOf(NestedEnum.BAZ, NestedEnum.FOO) repeatedNestedEnum += listOf(NestedEnum.BAZ, NestedEnum.FOO)
assertThat(repeatedNestedEnum).isEqualTo( assertThat(repeatedNestedEnum)
listOf(NestedEnum.FOO, NestedEnum.BAR, NestedEnum.BAZ, NestedEnum.FOO) .isEqualTo(listOf(NestedEnum.FOO, NestedEnum.BAR, NestedEnum.BAZ, NestedEnum.FOO))
)
repeatedNestedEnum[0] = NestedEnum.BAR repeatedNestedEnum[0] = NestedEnum.BAR
assertThat(repeatedNestedEnum).isEqualTo( assertThat(repeatedNestedEnum)
listOf(NestedEnum.BAR, NestedEnum.BAR, NestedEnum.BAZ, NestedEnum.FOO) .isEqualTo(listOf(NestedEnum.BAR, NestedEnum.BAR, NestedEnum.BAZ, NestedEnum.FOO))
)
} }
} }
@ -143,9 +140,8 @@ class Proto3Test {
oneofUint32 = 601 oneofUint32 = 601
clearOneofUint32() clearOneofUint32()
} }
).isEqualTo(
TestAllTypes.newBuilder().build()
) )
.isEqualTo(TestAllTypes.newBuilder().build())
} }
@Test @Test
@ -154,49 +150,32 @@ class Proto3Test {
optionalInt32 = 101 optionalInt32 = 101
optionalString = "115" optionalString = "115"
} }
val modifiedMessage = message.copy { val modifiedMessage = message.copy { optionalInt32 = 201 }
optionalInt32 = 201
}
assertThat(message).isEqualTo( assertThat(message)
TestAllTypes.newBuilder() .isEqualTo(TestAllTypes.newBuilder().setOptionalInt32(101).setOptionalString("115").build())
.setOptionalInt32(101) assertThat(modifiedMessage)
.setOptionalString("115") .isEqualTo(TestAllTypes.newBuilder().setOptionalInt32(201).setOptionalString("115").build())
.build()
)
assertThat(modifiedMessage).isEqualTo(
TestAllTypes.newBuilder()
.setOptionalInt32(201)
.setOptionalString("115")
.build()
)
} }
@Test @Test
fun testOneof() { fun testOneof() {
val message = testAllTypes { val message = testAllTypes {
oneofString = "foo" oneofString = "foo"
assertThat(oneofFieldCase) assertThat(oneofFieldCase).isEqualTo(TestAllTypes.OneofFieldCase.ONEOF_STRING)
.isEqualTo(TestAllTypes.OneofFieldCase.ONEOF_STRING)
assertThat(oneofString).isEqualTo("foo") assertThat(oneofString).isEqualTo("foo")
clearOneofField() clearOneofField()
assertThat(oneofFieldCase) assertThat(oneofFieldCase).isEqualTo(TestAllTypes.OneofFieldCase.ONEOFFIELD_NOT_SET)
.isEqualTo(TestAllTypes.OneofFieldCase.ONEOFFIELD_NOT_SET)
oneofUint32 = 5 oneofUint32 = 5
} }
assertThat(message.getOneofFieldCase()) assertThat(message.getOneofFieldCase()).isEqualTo(TestAllTypes.OneofFieldCase.ONEOF_UINT32)
.isEqualTo(TestAllTypes.OneofFieldCase.ONEOF_UINT32)
assertThat(message.getOneofUint32()).isEqualTo(5) assertThat(message.getOneofUint32()).isEqualTo(5)
} }
@Test @Test
fun testEmptyMessages() { fun testEmptyMessages() {
assertThat( assertThat(testEmptyMessage {}).isEqualTo(TestEmptyMessage.newBuilder().build())
testEmptyMessage {}
).isEqualTo(
TestEmptyMessage.newBuilder().build()
)
} }
@Test @Test
@ -237,7 +216,8 @@ class Proto3Test {
LeadingUnderscore = "foo" LeadingUnderscore = "foo"
option = 1 option = 1
} }
).isEqualTo( )
.isEqualTo(
EvilNamesProto3.newBuilder() EvilNamesProto3.newBuilder()
.setInitialized(true) .setInitialized(true)
.setHasFoo(true) .setHasFoo(true)
@ -350,16 +330,22 @@ class Proto3Test {
@Test @Test
fun testMultipleFiles() { fun testMultipleFiles() {
assertThat( assertThat(com.google.protobuf.kotlin.generator.multipleFilesMessageA {})
com.google.protobuf.kotlin.generator.multipleFilesMessageA {} .isEqualTo(com.google.protobuf.kotlin.generator.MultipleFilesMessageA.newBuilder().build())
).isEqualTo(
com.google.protobuf.kotlin.generator.MultipleFilesMessageA.newBuilder().build()
)
assertThat( assertThat(com.google.protobuf.kotlin.generator.multipleFilesMessageB {})
com.google.protobuf.kotlin.generator.multipleFilesMessageB {} .isEqualTo(com.google.protobuf.kotlin.generator.MultipleFilesMessageB.newBuilder().build())
).isEqualTo( }
com.google.protobuf.kotlin.generator.MultipleFilesMessageB.newBuilder().build()
) @Test
fun testGetOrNull() {
val noNestedMessage = testAllTypes {}
assertThat(noNestedMessage.optionalNestedMessageOrNull).isEqualTo(null)
val someNestedMessage = testAllTypes {
optionalNestedMessage = TestAllTypesKt.nestedMessage { bb = 118 }
}
assertThat(someNestedMessage.optionalNestedMessageOrNull)
.isEqualTo(TestAllTypesKt.nestedMessage { bb = 118 })
} }
} }

View File

@ -1901,7 +1901,7 @@ public class JsonFormat {
return json.getAsString(); return json.getAsString();
} }
private ByteString parseBytes(JsonElement json) throws InvalidProtocolBufferException { private ByteString parseBytes(JsonElement json) {
try { try {
return ByteString.copyFrom(BaseEncoding.base64().decode(json.getAsString())); return ByteString.copyFrom(BaseEncoding.base64().decode(json.getAsString()));
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {

View File

@ -1117,7 +1117,7 @@ public class JsonFormatTest {
+ " \"value\": \"12345\"\n" + " \"value\": \"12345\"\n"
+ "}"); + "}");
assertRoundTripEquals(anyMessage, registry); assertRoundTripEquals(anyMessage, registry);
anyMessage = Any.pack(UInt64Value.newBuilder().setValue(12345).build()); anyMessage = Any.pack(UInt64Value.of(12345));
assertThat(printer.print(anyMessage)) assertThat(printer.print(anyMessage))
.isEqualTo( .isEqualTo(
"{\n" "{\n"
@ -1125,7 +1125,7 @@ public class JsonFormatTest {
+ " \"value\": \"12345\"\n" + " \"value\": \"12345\"\n"
+ "}"); + "}");
assertRoundTripEquals(anyMessage, registry); assertRoundTripEquals(anyMessage, registry);
anyMessage = Any.pack(FloatValue.newBuilder().setValue(12345).build()); anyMessage = Any.pack(FloatValue.of(12345));
assertThat(printer.print(anyMessage)) assertThat(printer.print(anyMessage))
.isEqualTo( .isEqualTo(
"{\n" "{\n"
@ -1133,7 +1133,7 @@ public class JsonFormatTest {
+ " \"value\": 12345.0\n" + " \"value\": 12345.0\n"
+ "}"); + "}");
assertRoundTripEquals(anyMessage, registry); assertRoundTripEquals(anyMessage, registry);
anyMessage = Any.pack(DoubleValue.newBuilder().setValue(12345).build()); anyMessage = Any.pack(DoubleValue.of(12345));
assertThat(printer.print(anyMessage)) assertThat(printer.print(anyMessage))
.isEqualTo( .isEqualTo(
"{\n" "{\n"
@ -1350,20 +1350,16 @@ public class JsonFormatTest {
// Expected. // Expected.
} }
// TODO(xiaofeng): GSON allows trailing comma in arrays even after I set try {
// the JsonReader to non-lenient mode. If we want to enforce strict JSON TestAllTypes.Builder builder = TestAllTypes.newBuilder();
// compliance, we might want to switch to a different JSON parser or mergeFromJson(
// implement one by ourselves. "{\n"
// try { + " \"repeatedInt32\": [12345,]\n"
// TestAllTypes.Builder builder = TestAllTypes.newBuilder(); + "}", builder);
// JsonFormat.merge( assertWithMessage("IOException expected.").fail();
// "{\n" } catch (IOException e) {
// + " \"repeatedInt32\": [12345,]\n" // Expected.
// + "}", builder); }
// fail("Exception is expected.");
// } catch (IOException e) {
// // Expected.
// }
} }
@Test @Test

View File

@ -19,8 +19,8 @@ class Location extends \Google\Protobuf\Internal\Message
* Identifies which part of the FileDescriptorProto was defined at this * Identifies which part of the FileDescriptorProto was defined at this
* location. * location.
* Each element is a field number or an index. They form a path from * Each element is a field number or an index. They form a path from
* the root FileDescriptorProto to the place where the definition. For * the root FileDescriptorProto to the place where the definition occurs.
* example, this path: * For example, this path:
* [ 4, 3, 2, 7, 1 ] * [ 4, 3, 2, 7, 1 ]
* refers to: * refers to:
* file.message_type(3) // 4, 3 * file.message_type(3) // 4, 3
@ -111,8 +111,8 @@ class Location extends \Google\Protobuf\Internal\Message
* Identifies which part of the FileDescriptorProto was defined at this * Identifies which part of the FileDescriptorProto was defined at this
* location. * location.
* Each element is a field number or an index. They form a path from * Each element is a field number or an index. They form a path from
* the root FileDescriptorProto to the place where the definition. For * the root FileDescriptorProto to the place where the definition occurs.
* example, this path: * For example, this path:
* [ 4, 3, 2, 7, 1 ] * [ 4, 3, 2, 7, 1 ]
* refers to: * refers to:
* file.message_type(3) // 4, 3 * file.message_type(3) // 4, 3
@ -185,8 +185,8 @@ class Location extends \Google\Protobuf\Internal\Message
* Identifies which part of the FileDescriptorProto was defined at this * Identifies which part of the FileDescriptorProto was defined at this
* location. * location.
* Each element is a field number or an index. They form a path from * Each element is a field number or an index. They form a path from
* the root FileDescriptorProto to the place where the definition. For * the root FileDescriptorProto to the place where the definition occurs.
* example, this path: * For example, this path:
* [ 4, 3, 2, 7, 1 ] * [ 4, 3, 2, 7, 1 ]
* refers to: * refers to:
* file.message_type(3) // 4, 3 * file.message_type(3) // 4, 3
@ -216,8 +216,8 @@ class Location extends \Google\Protobuf\Internal\Message
* Identifies which part of the FileDescriptorProto was defined at this * Identifies which part of the FileDescriptorProto was defined at this
* location. * location.
* Each element is a field number or an index. They form a path from * Each element is a field number or an index. They form a path from
* the root FileDescriptorProto to the place where the definition. For * the root FileDescriptorProto to the place where the definition occurs.
* example, this path: * For example, this path:
* [ 4, 3, 2, 7, 1 ] * [ 4, 3, 2, 7, 1 ]
* refers to: * refers to:
* file.message_type(3) // 4, 3 * file.message_type(3) // 4, 3

View File

@ -657,8 +657,8 @@ class SecondaryDescriptorFromDescriptorDB(DescriptorPoolTestBase,
self.assertRaisesRegex(KeyError, 'SubMessage', self.assertRaisesRegex(KeyError, 'SubMessage',
self.pool.FindMessageTypeByName, self.pool.FindMessageTypeByName,
'collector.ErrorMessage') 'collector.ErrorMessage')
self.assertRaisesRegex(KeyError, 'SubMessage', self.assertRaisesRegex(KeyError, 'SubMessage', self.pool.FindFileByName,
self.pool.FindFileByName, 'error_file') 'error_file')
with self.assertRaises(KeyError) as exc: with self.assertRaises(KeyError) as exc:
self.pool.FindFileByName('none_file') self.pool.FindFileByName('none_file')
self.assertIn(str(exc.exception), ('\'none_file\'', self.assertIn(str(exc.exception), ('\'none_file\'',

View File

@ -100,9 +100,7 @@ class JsonFormatBase(unittest.TestCase):
def CheckError(self, text, error_message): def CheckError(self, text, error_message):
message = json_format_proto3_pb2.TestMessage() message = json_format_proto3_pb2.TestMessage()
self.assertRaisesRegex( self.assertRaisesRegex(json_format.ParseError, error_message,
json_format.ParseError,
error_message,
json_format.Parse, text, message) json_format.Parse, text, message)
@ -813,8 +811,7 @@ class JsonFormatTest(JsonFormatBase):
self.assertTrue(parsed_message.HasField('message_value')) self.assertTrue(parsed_message.HasField('message_value'))
# Null is not allowed to be used as an element in repeated field. # Null is not allowed to be used as an element in repeated field.
self.assertRaisesRegex( self.assertRaisesRegex(
json_format.ParseError, json_format.ParseError, r'Failed to parse repeatedInt32Value field: '
r'Failed to parse repeatedInt32Value field: '
r'null is not allowed to be used as an element in a repeated field ' r'null is not allowed to be used as an element in a repeated field '
r'at TestMessage.repeatedInt32Value\[1\].', json_format.Parse, r'at TestMessage.repeatedInt32Value\[1\].', json_format.Parse,
'{"repeatedInt32Value":[1, null]}', parsed_message) '{"repeatedInt32Value":[1, null]}', parsed_message)
@ -1019,13 +1016,11 @@ class JsonFormatTest(JsonFormatBase):
def testInvalidMap(self): def testInvalidMap(self):
message = json_format_proto3_pb2.TestMap() message = json_format_proto3_pb2.TestMap()
text = '{"int32Map": {"null": 2, "2": 3}}' text = '{"int32Map": {"null": 2, "2": 3}}'
self.assertRaisesRegex( self.assertRaisesRegex(json_format.ParseError,
json_format.ParseError,
'Failed to parse int32Map field: invalid literal', 'Failed to parse int32Map field: invalid literal',
json_format.Parse, text, message) json_format.Parse, text, message)
text = '{"int32Map": {1: 2, "2": 3}}' text = '{"int32Map": {1: 2, "2": 3}}'
self.assertRaisesRegex( self.assertRaisesRegex(json_format.ParseError,
json_format.ParseError,
(r'Failed to load JSON: Expecting property name' (r'Failed to load JSON: Expecting property name'
r'( enclosed in double quotes)?: line 1'), r'( enclosed in double quotes)?: line 1'),
json_format.Parse, text, message) json_format.Parse, text, message)
@ -1035,8 +1030,7 @@ class JsonFormatTest(JsonFormatBase):
'Failed to parse boolMap field: Expected "true" or "false", not null at ' 'Failed to parse boolMap field: Expected "true" or "false", not null at '
'TestMap.boolMap.key', json_format.Parse, text, message) 'TestMap.boolMap.key', json_format.Parse, text, message)
text = r'{"stringMap": {"a": 3, "\u0061": 2}}' text = r'{"stringMap": {"a": 3, "\u0061": 2}}'
self.assertRaisesRegex( self.assertRaisesRegex(json_format.ParseError,
json_format.ParseError,
'Failed to load JSON: duplicate key a', 'Failed to load JSON: duplicate key a',
json_format.Parse, text, message) json_format.Parse, text, message)
text = r'{"stringMap": 0}' text = r'{"stringMap": 0}'
@ -1057,11 +1051,10 @@ class JsonFormatTest(JsonFormatBase):
text = '{"value": "1970-01-01T00:00:00.0123456789012Z"}' text = '{"value": "1970-01-01T00:00:00.0123456789012Z"}'
self.assertRaisesRegex( self.assertRaisesRegex(
json_format.ParseError, json_format.ParseError,
'nanos 0123456789012 more than 9 fractional digits.', 'nanos 0123456789012 more than 9 fractional digits.', json_format.Parse,
json_format.Parse, text, message) text, message)
text = '{"value": "1972-01-01T01:00:00.01+08"}' text = '{"value": "1972-01-01T01:00:00.01+08"}'
self.assertRaisesRegex( self.assertRaisesRegex(json_format.ParseError,
json_format.ParseError,
(r'Invalid timezone offset value: \+08.'), (r'Invalid timezone offset value: \+08.'),
json_format.Parse, text, message) json_format.Parse, text, message)
# Time smaller than minimum time. # Time smaller than minimum time.
@ -1072,9 +1065,7 @@ class JsonFormatTest(JsonFormatBase):
json_format.Parse, text, message) json_format.Parse, text, message)
# Time bigger than maximum time. # Time bigger than maximum time.
message.value.seconds = 253402300800 message.value.seconds = 253402300800
self.assertRaisesRegex( self.assertRaisesRegex(OverflowError, 'date value out of range',
OverflowError,
'date value out of range',
json_format.MessageToJson, message) json_format.MessageToJson, message)
# Lower case t does not accept. # Lower case t does not accept.
text = '{"value": "0001-01-01t00:00:00Z"}' text = '{"value": "0001-01-01t00:00:00Z"}'
@ -1100,8 +1091,7 @@ class JsonFormatTest(JsonFormatBase):
self.assertRaisesRegex( self.assertRaisesRegex(
json_format.ParseError, json_format.ParseError,
r'Failed to parse value field: ListValue must be in \[\] which is ' r'Failed to parse value field: ListValue must be in \[\] which is '
'1234 at TestListValue.value.', '1234 at TestListValue.value.', json_format.Parse, text, message)
json_format.Parse, text, message)
class UnknownClass(object): class UnknownClass(object):
@ -1119,8 +1109,7 @@ class JsonFormatTest(JsonFormatBase):
self.assertRaisesRegex( self.assertRaisesRegex(
json_format.ParseError, json_format.ParseError,
'Failed to parse value field: Struct must be in a dict which is ' 'Failed to parse value field: Struct must be in a dict which is '
'1234 at TestStruct.value', '1234 at TestStruct.value', json_format.Parse, text, message)
json_format.Parse, text, message)
def testTimestampInvalidStringValue(self): def testTimestampInvalidStringValue(self):
message = json_format_proto3_pb2.TestTimestamp() message = json_format_proto3_pb2.TestTimestamp()
@ -1133,10 +1122,9 @@ class JsonFormatTest(JsonFormatBase):
def testDurationInvalidStringValue(self): def testDurationInvalidStringValue(self):
message = json_format_proto3_pb2.TestDuration() message = json_format_proto3_pb2.TestDuration()
text = '{"value": {"foo": 123}}' text = '{"value": {"foo": 123}}'
self.assertRaisesRegex( self.assertRaisesRegex(json_format.ParseError,
json_format.ParseError, r"Duration JSON value not a string: {u?'foo': 123}",
r"Duration JSON value not a string: {u?'foo': 123}", json_format.Parse, json_format.Parse, text, message)
text, message)
def testFieldMaskInvalidStringValue(self): def testFieldMaskInvalidStringValue(self):
message = json_format_proto3_pb2.TestFieldMask() message = json_format_proto3_pb2.TestFieldMask()
@ -1149,10 +1137,7 @@ class JsonFormatTest(JsonFormatBase):
def testInvalidAny(self): def testInvalidAny(self):
message = any_pb2.Any() message = any_pb2.Any()
text = '{"@type": "type.googleapis.com/google.protobuf.Int32Value"}' text = '{"@type": "type.googleapis.com/google.protobuf.Int32Value"}'
self.assertRaisesRegex( self.assertRaisesRegex(KeyError, 'value', json_format.Parse, text, message)
KeyError,
'value',
json_format.Parse, text, message)
text = '{"value": 1234}' text = '{"value": 1234}'
self.assertRaisesRegex(json_format.ParseError, self.assertRaisesRegex(json_format.ParseError,
'@type is missing when parsing any message at Any', '@type is missing when parsing any message at Any',
@ -1250,9 +1235,7 @@ class JsonFormatTest(JsonFormatBase):
self.assertRaisesRegex( self.assertRaisesRegex(
json_format.ParseError, json_format.ParseError,
r"Value v has unexpected type <class '.*\.UnknownClass'>.", r"Value v has unexpected type <class '.*\.UnknownClass'>.",
json_format.ParseDict, json_format.ParseDict, {'value': UnknownClass()}, message)
{'value': UnknownClass()},
message)
def testMessageToDict(self): def testMessageToDict(self):
message = json_format_proto3_pb2.TestMessage() message = json_format_proto3_pb2.TestMessage()

View File

@ -75,21 +75,19 @@ from google.protobuf.internal import _parameterized
UCS2_MAXUNICODE = 65535 UCS2_MAXUNICODE = 65535
warnings.simplefilter('error', DeprecationWarning) warnings.simplefilter('error', DeprecationWarning)
@_parameterized.named_parameters( @_parameterized.named_parameters(('_proto2', unittest_pb2),
('_proto2', unittest_pb2),
('_proto3', unittest_proto3_arena_pb2)) ('_proto3', unittest_proto3_arena_pb2))
@testing_refleaks.TestCase @testing_refleaks.TestCase
class MessageTest(unittest.TestCase): class MessageTest(unittest.TestCase):
def testBadUtf8String(self, message_module): def testBadUtf8String(self, message_module):
if api_implementation.Type() != 'python': if api_implementation.Type() != 'python':
self.skipTest("Skipping testBadUtf8String, currently only the python " self.skipTest('Skipping testBadUtf8String, currently only the python '
"api implementation raises UnicodeDecodeError when a " 'api implementation raises UnicodeDecodeError when a '
"string field contains bad utf-8.") 'string field contains bad utf-8.')
bad_utf8_data = test_util.GoldenFileData('bad_utf8_string') bad_utf8_data = test_util.GoldenFileData('bad_utf8_string')
with self.assertRaises(UnicodeDecodeError) as context: with self.assertRaises(UnicodeDecodeError) as context:
message_module.TestAllTypes.FromString(bad_utf8_data) message_module.TestAllTypes.FromString(bad_utf8_data)
@ -100,8 +98,7 @@ class MessageTest(unittest.TestCase):
# and doesn't preserve unknown fields, so for proto3 we use a golden # and doesn't preserve unknown fields, so for proto3 we use a golden
# message that doesn't have these fields set. # message that doesn't have these fields set.
if message_module is unittest_pb2: if message_module is unittest_pb2:
golden_data = test_util.GoldenFileData( golden_data = test_util.GoldenFileData('golden_message_oneof_implemented')
'golden_message_oneof_implemented')
else: else:
golden_data = test_util.GoldenFileData('golden_message_proto3') golden_data = test_util.GoldenFileData('golden_message_proto3')
@ -440,8 +437,7 @@ class MessageTest(unittest.TestCase):
except TypeError: except TypeError:
pass pass
self.assertEqual(2, len(msg.repeated_nested_message)) self.assertEqual(2, len(msg.repeated_nested_message))
self.assertEqual([1, 2], self.assertEqual([1, 2], [m.bb for m in msg.repeated_nested_message])
[m.bb for m in msg.repeated_nested_message])
def testInsertRepeatedCompositeField(self, message_module): def testInsertRepeatedCompositeField(self, message_module):
msg = message_module.TestAllTypes() msg = message_module.TestAllTypes()
@ -463,8 +459,8 @@ class MessageTest(unittest.TestCase):
self.assertEqual(5, len(msg.repeated_nested_message)) self.assertEqual(5, len(msg.repeated_nested_message))
self.assertEqual([-1000, 2, -1, 1, 3], self.assertEqual([-1000, 2, -1, 1, 3],
[m.bb for m in msg.repeated_nested_message]) [m.bb for m in msg.repeated_nested_message])
self.assertEqual(str(msg), self.assertEqual(
'repeated_nested_message {\n' str(msg), 'repeated_nested_message {\n'
' bb: -1000\n' ' bb: -1000\n'
'}\n' '}\n'
'repeated_nested_message {\n' 'repeated_nested_message {\n'
@ -497,8 +493,7 @@ class MessageTest(unittest.TestCase):
self.assertEqual(4, len(msg.repeated_int32)) self.assertEqual(4, len(msg.repeated_int32))
msg.repeated_nested_message.MergeFrom(other_msg.repeated_nested_message) msg.repeated_nested_message.MergeFrom(other_msg.repeated_nested_message)
self.assertEqual([1, 2, 3, 4], self.assertEqual([1, 2, 3, 4], [m.bb for m in msg.repeated_nested_message])
[m.bb for m in msg.repeated_nested_message])
def testAddWrongRepeatedNestedField(self, message_module): def testAddWrongRepeatedNestedField(self, message_module):
msg = message_module.TestAllTypes() msg = message_module.TestAllTypes()
@ -543,8 +538,7 @@ class MessageTest(unittest.TestCase):
msg.repeated_nested_message.add(bb=3) msg.repeated_nested_message.add(bb=3)
msg.repeated_nested_message.add(bb=4) msg.repeated_nested_message.add(bb=4)
self.assertEqual([1, 2, 3, 4], self.assertEqual([1, 2, 3, 4], [m.bb for m in msg.repeated_nested_message])
[m.bb for m in msg.repeated_nested_message])
self.assertEqual([4, 3, 2, 1], self.assertEqual([4, 3, 2, 1],
[m.bb for m in reversed(msg.repeated_nested_message)]) [m.bb for m in reversed(msg.repeated_nested_message)])
self.assertEqual([4, 3, 2, 1], self.assertEqual([4, 3, 2, 1],
@ -627,7 +621,8 @@ class MessageTest(unittest.TestCase):
self.assertEqual(message.repeated_nested_message[3].bb, 4) self.assertEqual(message.repeated_nested_message[3].bb, 4)
self.assertEqual(message.repeated_nested_message[4].bb, 5) self.assertEqual(message.repeated_nested_message[4].bb, 5)
self.assertEqual(message.repeated_nested_message[5].bb, 6) self.assertEqual(message.repeated_nested_message[5].bb, 6)
self.assertEqual(str(message.repeated_nested_message), self.assertEqual(
str(message.repeated_nested_message),
'[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]') '[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]')
def testSortingRepeatedCompositeFieldsStable(self, message_module): def testSortingRepeatedCompositeFieldsStable(self, message_module):
@ -642,8 +637,7 @@ class MessageTest(unittest.TestCase):
message.repeated_nested_message.add().bb = 24 message.repeated_nested_message.add().bb = 24
message.repeated_nested_message.add().bb = 10 message.repeated_nested_message.add().bb = 10
message.repeated_nested_message.sort(key=lambda z: z.bb // 10) message.repeated_nested_message.sort(key=lambda z: z.bb // 10)
self.assertEqual( self.assertEqual([13, 11, 10, 21, 20, 24, 33],
[13, 11, 10, 21, 20, 24, 33],
[n.bb for n in message.repeated_nested_message]) [n.bb for n in message.repeated_nested_message])
# Make sure that for the C++ implementation, the underlying fields # Make sure that for the C++ implementation, the underlying fields
@ -651,8 +645,7 @@ class MessageTest(unittest.TestCase):
pb = message.SerializeToString() pb = message.SerializeToString()
message.Clear() message.Clear()
message.MergeFromString(pb) message.MergeFromString(pb)
self.assertEqual( self.assertEqual([13, 11, 10, 21, 20, 24, 33],
[13, 11, 10, 21, 20, 24, 33],
[n.bb for n in message.repeated_nested_message]) [n.bb for n in message.repeated_nested_message])
def testRepeatedCompositeFieldSortArguments(self, message_module): def testRepeatedCompositeFieldSortArguments(self, message_module):
@ -826,7 +819,7 @@ class MessageTest(unittest.TestCase):
self.assertTrue(m.HasField('oneof_uint32')) self.assertTrue(m.HasField('oneof_uint32'))
self.assertFalse(m.HasField('oneof_string')) self.assertFalse(m.HasField('oneof_string'))
m.oneof_string = "" m.oneof_string = ''
self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
self.assertTrue(m.HasField('oneof_string')) self.assertTrue(m.HasField('oneof_string'))
self.assertFalse(m.HasField('oneof_uint32')) self.assertFalse(m.HasField('oneof_uint32'))
@ -973,7 +966,9 @@ class MessageTest(unittest.TestCase):
def testAssignByteStringToUnicodeField(self, message_module): def testAssignByteStringToUnicodeField(self, message_module):
"""Assigning a byte string to a string field should result """Assigning a byte string to a string field should result
in the value being converted to a Unicode string."""
in the value being converted to a Unicode string.
"""
m = message_module.TestAllTypes() m = message_module.TestAllTypes()
m.optional_string = str('') m.optional_string = str('')
self.assertIsInstance(m.optional_string, str) self.assertIsInstance(m.optional_string, str)
@ -1001,8 +996,7 @@ class MessageTest(unittest.TestCase):
with self.assertRaises(NameError) as _: with self.assertRaises(NameError) as _:
m.repeated_int32.extend(a for i in range(10)) # pylint: disable=undefined-variable m.repeated_int32.extend(a for i in range(10)) # pylint: disable=undefined-variable
with self.assertRaises(NameError) as _: with self.assertRaises(NameError) as _:
m.repeated_nested_enum.extend( m.repeated_nested_enum.extend(a for i in range(10)) # pylint: disable=undefined-variable
a for i in range(10)) # pylint: disable=undefined-variable
FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()] FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()]
@ -1179,14 +1173,12 @@ class MessageTest(unittest.TestCase):
pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL) pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL)
def testSortEmptyRepeatedCompositeContainer(self, message_module): def testSortEmptyRepeatedCompositeContainer(self, message_module):
"""Exercise a scenario that has led to segfaults in the past. """Exercise a scenario that has led to segfaults in the past."""
"""
m = message_module.TestAllTypes() m = message_module.TestAllTypes()
m.repeated_nested_message.sort() m.repeated_nested_message.sort()
def testHasFieldOnRepeatedField(self, message_module): def testHasFieldOnRepeatedField(self, message_module):
"""Using HasField on a repeated field should raise an exception. """Using HasField on a repeated field should raise an exception."""
"""
m = message_module.TestAllTypes() m = message_module.TestAllTypes()
with self.assertRaises(ValueError) as _: with self.assertRaises(ValueError) as _:
m.HasField('repeated_int32') m.HasField('repeated_int32')
@ -1226,6 +1218,7 @@ class MessageTest(unittest.TestCase):
def testReleasedNestedMessages(self, message_module): def testReleasedNestedMessages(self, message_module):
"""A case that lead to a segfault when a message detached from its parent """A case that lead to a segfault when a message detached from its parent
container has itself a child container. container has itself a child container.
""" """
m = message_module.NestedTestAllTypes() m = message_module.NestedTestAllTypes()
@ -1271,17 +1264,17 @@ class Proto2Test(unittest.TestCase):
def testFieldPresence(self): def testFieldPresence(self):
message = unittest_pb2.TestAllTypes() message = unittest_pb2.TestAllTypes()
self.assertFalse(message.HasField("optional_int32")) self.assertFalse(message.HasField('optional_int32'))
self.assertFalse(message.HasField("optional_bool")) self.assertFalse(message.HasField('optional_bool'))
self.assertFalse(message.HasField("optional_nested_message")) self.assertFalse(message.HasField('optional_nested_message'))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
message.HasField("field_doesnt_exist") message.HasField('field_doesnt_exist')
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
message.HasField("repeated_int32") message.HasField('repeated_int32')
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
message.HasField("repeated_nested_message") message.HasField('repeated_nested_message')
self.assertEqual(0, message.optional_int32) self.assertEqual(0, message.optional_int32)
self.assertEqual(False, message.optional_bool) self.assertEqual(False, message.optional_bool)
@ -1291,27 +1284,27 @@ class Proto2Test(unittest.TestCase):
message.optional_int32 = 0 message.optional_int32 = 0
message.optional_bool = False message.optional_bool = False
message.optional_nested_message.bb = 0 message.optional_nested_message.bb = 0
self.assertTrue(message.HasField("optional_int32")) self.assertTrue(message.HasField('optional_int32'))
self.assertTrue(message.HasField("optional_bool")) self.assertTrue(message.HasField('optional_bool'))
self.assertTrue(message.HasField("optional_nested_message")) self.assertTrue(message.HasField('optional_nested_message'))
# Set the fields to non-default values. # Set the fields to non-default values.
message.optional_int32 = 5 message.optional_int32 = 5
message.optional_bool = True message.optional_bool = True
message.optional_nested_message.bb = 15 message.optional_nested_message.bb = 15
self.assertTrue(message.HasField(u"optional_int32")) self.assertTrue(message.HasField(u'optional_int32'))
self.assertTrue(message.HasField("optional_bool")) self.assertTrue(message.HasField('optional_bool'))
self.assertTrue(message.HasField("optional_nested_message")) self.assertTrue(message.HasField('optional_nested_message'))
# Clearing the fields unsets them and resets their value to default. # Clearing the fields unsets them and resets their value to default.
message.ClearField("optional_int32") message.ClearField('optional_int32')
message.ClearField(u"optional_bool") message.ClearField(u'optional_bool')
message.ClearField("optional_nested_message") message.ClearField('optional_nested_message')
self.assertFalse(message.HasField("optional_int32")) self.assertFalse(message.HasField('optional_int32'))
self.assertFalse(message.HasField("optional_bool")) self.assertFalse(message.HasField('optional_bool'))
self.assertFalse(message.HasField("optional_nested_message")) self.assertFalse(message.HasField('optional_nested_message'))
self.assertEqual(0, message.optional_int32) self.assertEqual(0, message.optional_int32)
self.assertEqual(False, message.optional_bool) self.assertEqual(False, message.optional_bool)
self.assertEqual(0, message.optional_nested_message.bb) self.assertEqual(0, message.optional_nested_message.bb)
@ -1361,16 +1354,17 @@ class Proto2Test(unittest.TestCase):
msg1 = more_extensions_pb2.TopLevelMessage() msg1 = more_extensions_pb2.TopLevelMessage()
msg2 = more_extensions_pb2.TopLevelMessage() msg2 = more_extensions_pb2.TopLevelMessage()
# Cpp extension will lazily create a sub message which is immutable. # Cpp extension will lazily create a sub message which is immutable.
self.assertEqual(0, msg1.submessage.Extensions[ self.assertEqual(
more_extensions_pb2.optional_int_extension]) 0,
msg1.submessage.Extensions[more_extensions_pb2.optional_int_extension])
self.assertFalse(msg1.HasField('submessage')) self.assertFalse(msg1.HasField('submessage'))
msg2.submessage.Extensions[ msg2.submessage.Extensions[more_extensions_pb2.optional_int_extension] = 123
more_extensions_pb2.optional_int_extension] = 123
# Make sure cmessage and extensions pointing to a mutable message # Make sure cmessage and extensions pointing to a mutable message
# after merge instead of the lazily created message. # after merge instead of the lazily created message.
msg1.MergeFrom(msg2) msg1.MergeFrom(msg2)
self.assertEqual(123, msg1.submessage.Extensions[ self.assertEqual(
more_extensions_pb2.optional_int_extension]) 123,
msg1.submessage.Extensions[more_extensions_pb2.optional_int_extension])
def testGoldenExtensions(self): def testGoldenExtensions(self):
golden_data = test_util.GoldenFileData('golden_message') golden_data = test_util.GoldenFileData('golden_message')
@ -1404,17 +1398,19 @@ class Proto2Test(unittest.TestCase):
# This is still an incomplete proto - so serializing should fail # This is still an incomplete proto - so serializing should fail
self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)
# TODO(haberman): this isn't really a proto2-specific test except that this # TODO(haberman): this isn't really a proto2-specific test except that this
# message has a required field in it. Should probably be factored out so # message has a required field in it. Should probably be factored out so
# that we can test the other parts with proto3. # that we can test the other parts with proto3.
def testParsingMerge(self): def testParsingMerge(self):
"""Check the merge behavior when a required or optional field appears """Check the merge behavior when a required or optional field appears
multiple times in the input."""
multiple times in the input.
"""
messages = [ messages = [
unittest_pb2.TestAllTypes(), unittest_pb2.TestAllTypes(),
unittest_pb2.TestAllTypes(), unittest_pb2.TestAllTypes(),
unittest_pb2.TestAllTypes() ] unittest_pb2.TestAllTypes()
]
messages[0].optional_int32 = 1 messages[0].optional_int32 = 1
messages[1].optional_int64 = 2 messages[1].optional_int64 = 2
messages[2].optional_int32 = 3 messages[2].optional_int32 = 3
@ -1447,14 +1443,15 @@ class Proto2Test(unittest.TestCase):
self.assertEqual(parsing_merge.optional_all_types, merged_message) self.assertEqual(parsing_merge.optional_all_types, merged_message)
self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types, self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
merged_message) merged_message)
self.assertEqual(parsing_merge.Extensions[ self.assertEqual(
unittest_pb2.TestParsingMerge.optional_ext], parsing_merge.Extensions[unittest_pb2.TestParsingMerge.optional_ext],
merged_message) merged_message)
# Repeated fields should not be merged. # Repeated fields should not be merged.
self.assertEqual(len(parsing_merge.repeated_all_types), 3) self.assertEqual(len(parsing_merge.repeated_all_types), 3)
self.assertEqual(len(parsing_merge.repeatedgroup), 3) self.assertEqual(len(parsing_merge.repeatedgroup), 3)
self.assertEqual(len(parsing_merge.Extensions[ self.assertEqual(
len(parsing_merge.Extensions[
unittest_pb2.TestParsingMerge.repeated_ext]), 3) unittest_pb2.TestParsingMerge.repeated_ext]), 3)
def testPythonicInit(self): def testPythonicInit(self):
@ -1467,8 +1464,11 @@ class Proto2Test(unittest.TestCase):
optional_nested_message={'bb': 500}, optional_nested_message={'bb': 500},
optional_foreign_message={}, optional_foreign_message={},
optional_nested_enum='BAZ', optional_nested_enum='BAZ',
repeatedgroup=[{'a': 600}, repeatedgroup=[{
{'a': 700}], 'a': 600
}, {
'a': 700
}],
repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR], repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
default_int32=800, default_int32=800,
oneof_string='y') oneof_string='y')
@ -1848,8 +1848,7 @@ class Proto3Test(unittest.TestCase):
self.assertEqual(True, msg2.map_bool_bool[True]) self.assertEqual(True, msg2.map_bool_bool[True])
self.assertEqual(2, msg2.map_int32_enum[888]) self.assertEqual(2, msg2.map_int32_enum[888])
self.assertEqual(456, msg2.map_int32_enum[123]) self.assertEqual(456, msg2.map_int32_enum[123])
self.assertEqual('{-123: -456}', self.assertEqual('{-123: -456}', str(msg2.map_int32_int32))
str(msg2.map_int32_int32))
def testMapEntryAlwaysSerialized(self): def testMapEntryAlwaysSerialized(self):
msg = map_unittest_pb2.TestMap() msg = map_unittest_pb2.TestMap()
@ -1912,7 +1911,8 @@ class Proto3Test(unittest.TestCase):
self.assertEqual(2, len(msg2.map_int32_foreign_message)) self.assertEqual(2, len(msg2.map_int32_foreign_message))
msg2.map_int32_foreign_message[123].c = 1 msg2.map_int32_foreign_message[123].c = 1
# TODO(jieluo): Fix text format for message map. # TODO(jieluo): Fix text format for message map.
self.assertIn(str(msg2.map_int32_foreign_message), self.assertIn(
str(msg2.map_int32_foreign_message),
('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }')) ('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }'))
def testNestedMessageMapItemDelete(self): def testNestedMessageMapItemDelete(self):
@ -2041,8 +2041,7 @@ class Proto3Test(unittest.TestCase):
# Test when cpp extension cache a map. # Test when cpp extension cache a map.
m1 = map_unittest_pb2.TestMap() m1 = map_unittest_pb2.TestMap()
m2 = map_unittest_pb2.TestMap() m2 = map_unittest_pb2.TestMap()
self.assertEqual(m1.map_int32_foreign_message, self.assertEqual(m1.map_int32_foreign_message, m1.map_int32_foreign_message)
m1.map_int32_foreign_message)
m2.map_int32_foreign_message[123].c = 10 m2.map_int32_foreign_message[123].c = 10
m1.MergeFrom(m2) m1.MergeFrom(m2)
self.assertEqual(10, m2.map_int32_foreign_message[123].c) self.assertEqual(10, m2.map_int32_foreign_message[123].c)
@ -2167,6 +2166,34 @@ class Proto3Test(unittest.TestCase):
for key in int32_foreign_iter: for key in int32_foreign_iter:
pass pass
def testModifyMapEntryWhileIterating(self):
msg = map_unittest_pb2.TestMap()
msg.map_string_string['abc'] = '123'
msg.map_string_string['def'] = '456'
msg.map_string_string['ghi'] = '789'
msg.map_int32_foreign_message[5].c = 5
msg.map_int32_foreign_message[6].c = 6
msg.map_int32_foreign_message[7].c = 7
string_string_keys = list(msg.map_string_string.keys())
int32_foreign_keys = list(msg.map_int32_foreign_message.keys())
keys = []
for key in msg.map_string_string:
keys.append(key)
msg.map_string_string[key] = '000'
self.assertEqual(keys, string_string_keys)
self.assertEqual(keys, list(msg.map_string_string.keys()))
keys = []
for key in msg.map_int32_foreign_message:
keys.append(key)
msg.map_int32_foreign_message[key].c = 0
self.assertEqual(keys, int32_foreign_keys)
self.assertEqual(keys, list(msg.map_int32_foreign_message.keys()))
def testSubmessageMap(self): def testSubmessageMap(self):
msg = map_unittest_pb2.TestMap() msg = map_unittest_pb2.TestMap()
@ -2413,24 +2440,21 @@ class Proto3Test(unittest.TestCase):
msg2.MergeFromString(serialized) msg2.MergeFromString(serialized)
self.assertEqual(msg2.optional_string, u'😍') self.assertEqual(msg2.optional_string, u'😍')
msg = unittest_proto3_arena_pb2.TestAllTypes( msg = unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud001')
optional_string=u'\ud001')
self.assertEqual(msg.optional_string, u'\ud001') self.assertEqual(msg.optional_string, u'\ud001')
def testSurrogatesInPython3(self): def testSurrogatesInPython3(self):
# Surrogates are rejected at setters in Python3. # Surrogates are rejected at setters in Python3.
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
unittest_proto3_arena_pb2.TestAllTypes( unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801\udc01')
optional_string=u'\ud801\udc01')
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
unittest_proto3_arena_pb2.TestAllTypes( unittest_proto3_arena_pb2.TestAllTypes(optional_string=b'\xed\xa0\x81')
optional_string=b'\xed\xa0\x81')
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
unittest_proto3_arena_pb2.TestAllTypes( unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801')
optional_string=u'\ud801')
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
unittest_proto3_arena_pb2.TestAllTypes( unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801\ud801')
optional_string=u'\ud801\ud801')
@testing_refleaks.TestCase @testing_refleaks.TestCase
@ -2441,7 +2465,8 @@ class ValidTypeNamesTest(unittest.TestCase):
tp_name = str(type(msg)).split("'")[1] tp_name = str(type(msg)).split("'")[1]
valid_names = ('Repeated%sContainer' % base_name, valid_names = ('Repeated%sContainer' % base_name,
'Repeated%sFieldContainer' % base_name) 'Repeated%sFieldContainer' % base_name)
self.assertTrue(any(tp_name.endswith(v) for v in valid_names), self.assertTrue(
any(tp_name.endswith(v) for v in valid_names),
'%r does end with any of %r' % (tp_name, valid_names)) '%r does end with any of %r' % (tp_name, valid_names))
parts = tp_name.split('.') parts = tp_name.split('.')
@ -2455,6 +2480,7 @@ class ValidTypeNamesTest(unittest.TestCase):
self.assertImportFromName(pb.repeated_int32, 'Scalar') self.assertImportFromName(pb.repeated_int32, 'Scalar')
self.assertImportFromName(pb.repeated_nested_message, 'Composite') self.assertImportFromName(pb.repeated_nested_message, 'Composite')
@testing_refleaks.TestCase @testing_refleaks.TestCase
class PackedFieldTest(unittest.TestCase): class PackedFieldTest(unittest.TestCase):
@ -2574,5 +2600,6 @@ class OversizeProtosTest(unittest.TestCase):
q.ParseFromString(self.p_serialized) q.ParseFromString(self.p_serialized)
self.assertEqual(self.p.field.payload, q.field.payload) self.assertEqual(self.p.field.payload, q.field.payload)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -50,6 +50,15 @@ from google.protobuf import descriptor
from google.protobuf import text_format from google.protobuf import text_format
from google.protobuf.internal import _parameterized from google.protobuf.internal import _parameterized
try:
# New module in Python 3.9:
import zoneinfo # pylint:disable=g-import-not-at-top
_TZ_JAPAN = zoneinfo.ZoneInfo('Japan')
_TZ_PACIFIC = zoneinfo.ZoneInfo('US/Pacific')
except ImportError:
_TZ_JAPAN = datetime.timezone(datetime.timedelta(hours=9), 'Japan')
_TZ_PACIFIC = datetime.timezone(datetime.timedelta(hours=-8), 'US/Pacific')
class TimeUtilTestBase(_parameterized.TestCase): class TimeUtilTestBase(_parameterized.TestCase):
@ -265,25 +274,37 @@ class TimeUtilTest(TimeUtilTestBase):
message.FromDatetime(naive_end_of_time) message.FromDatetime(naive_end_of_time)
self.assertEqual(naive_end_of_time, message.ToDatetime()) self.assertEqual(naive_end_of_time, message.ToDatetime())
def testDatetimeConversionWithTimezone(self): # Two hours after the Unix Epoch, around the world.
class TZ(datetime.tzinfo): @_parameterized.named_parameters(
('London', [1970, 1, 1, 2], datetime.timezone.utc),
('Tokyo', [1970, 1, 1, 11], _TZ_JAPAN),
('LA', [1969, 12, 31, 18], _TZ_PACIFIC),
)
def testTimezoneAwareDatetimeConversion(self, date_parts, tzinfo):
original_datetime = datetime.datetime(*date_parts, tzinfo=tzinfo) # pylint:disable=g-tzinfo-datetime
def utcoffset(self, _): message = timestamp_pb2.Timestamp()
return datetime.timedelta(hours=1) message.FromDatetime(original_datetime)
self.assertEqual(7200, message.seconds)
self.assertEqual(0, message.nanos)
def dst(self, _): # ToDatetime() with no parameters produces a naive UTC datetime, i.e. it not
return datetime.timedelta(0) # only loses the original timezone information (e.g. US/Pacific) as it's
# "normalised" to UTC, but also drops the information that the datetime
# represents a UTC one.
naive_datetime = message.ToDatetime()
self.assertEqual(datetime.datetime(1970, 1, 1, 2), naive_datetime)
self.assertIsNone(naive_datetime.tzinfo)
self.assertNotEqual(original_datetime, naive_datetime) # not even for UTC!
def tzname(self, _): # In contrast, ToDatetime(tzinfo=) produces an aware datetime in the given
return 'UTC+1' # timezone.
aware_datetime = message.ToDatetime(tzinfo=tzinfo)
message1 = timestamp_pb2.Timestamp() self.assertEqual(original_datetime, aware_datetime)
dt = datetime.datetime(1970, 1, 1, 1, tzinfo=TZ()) self.assertEqual(
message1.FromDatetime(dt) datetime.datetime(1970, 1, 1, 2, tzinfo=datetime.timezone.utc),
message2 = timestamp_pb2.Timestamp() aware_datetime)
dt = datetime.datetime(1970, 1, 1, 0) self.assertEqual(tzinfo, aware_datetime.tzinfo)
message2.FromDatetime(dt)
self.assertEqual(message1, message2)
def testTimedeltaConversion(self): def testTimedeltaConversion(self):
message = duration_pb2.Duration() message = duration_pb2.Duration()
@ -310,82 +331,61 @@ class TimeUtilTest(TimeUtilTestBase):
def testInvalidTimestamp(self): def testInvalidTimestamp(self):
message = timestamp_pb2.Timestamp() message = timestamp_pb2.Timestamp()
self.assertRaisesRegex( self.assertRaisesRegex(
ValueError, ValueError, 'Failed to parse timestamp: missing valid timezone offset.',
'Failed to parse timestamp: missing valid timezone offset.', message.FromJsonString, '')
message.FromJsonString,
'')
self.assertRaisesRegex( self.assertRaisesRegex(
ValueError, ValueError, 'Failed to parse timestamp: invalid trailing data '
'Failed to parse timestamp: invalid trailing data ' '1970-01-01T00:00:01Ztrail.', message.FromJsonString,
'1970-01-01T00:00:01Ztrail.',
message.FromJsonString,
'1970-01-01T00:00:01Ztrail') '1970-01-01T00:00:01Ztrail')
self.assertRaisesRegex( self.assertRaisesRegex(
ValueError, ValueError, 'time data \'10000-01-01T00:00:00\' does not match'
'time data \'10000-01-01T00:00:00\' does not match' ' format \'%Y-%m-%dT%H:%M:%S\'', message.FromJsonString,
' format \'%Y-%m-%dT%H:%M:%S\'', '10000-01-01T00:00:00.00Z')
message.FromJsonString, '10000-01-01T00:00:00.00Z')
self.assertRaisesRegex( self.assertRaisesRegex(
ValueError, ValueError, 'nanos 0123456789012 more than 9 fractional digits.',
'nanos 0123456789012 more than 9 fractional digits.', message.FromJsonString, '1970-01-01T00:00:00.0123456789012Z')
message.FromJsonString,
'1970-01-01T00:00:00.0123456789012Z')
self.assertRaisesRegex( self.assertRaisesRegex(
ValueError, ValueError,
(r'Invalid timezone offset value: \+08.'), (r'Invalid timezone offset value: \+08.'),
message.FromJsonString, message.FromJsonString,
'1972-01-01T01:00:00.01+08',) '1972-01-01T01:00:00.01+08',
self.assertRaisesRegex( )
ValueError, self.assertRaisesRegex(ValueError, 'year (0 )?is out of range',
'year (0 )?is out of range', message.FromJsonString, '0000-01-01T00:00:00Z')
message.FromJsonString,
'0000-01-01T00:00:00Z')
message.seconds = 253402300800 message.seconds = 253402300800
self.assertRaisesRegex( self.assertRaisesRegex(OverflowError, 'date value out of range',
OverflowError,
'date value out of range',
message.ToJsonString) message.ToJsonString)
def testInvalidDuration(self): def testInvalidDuration(self):
message = duration_pb2.Duration() message = duration_pb2.Duration()
self.assertRaisesRegex( self.assertRaisesRegex(ValueError, 'Duration must end with letter "s": 1.',
ValueError,
'Duration must end with letter "s": 1.',
message.FromJsonString, '1') message.FromJsonString, '1')
self.assertRaisesRegex( self.assertRaisesRegex(ValueError, 'Couldn\'t parse duration: 1...2s.',
ValueError,
'Couldn\'t parse duration: 1...2s.',
message.FromJsonString, '1...2s') message.FromJsonString, '1...2s')
text = '-315576000001.000000000s' text = '-315576000001.000000000s'
self.assertRaisesRegex( self.assertRaisesRegex(
ValueError, ValueError,
r'Duration is not valid\: Seconds -315576000001 must be in range' r'Duration is not valid\: Seconds -315576000001 must be in range'
r' \[-315576000000\, 315576000000\].', r' \[-315576000000\, 315576000000\].', message.FromJsonString, text)
message.FromJsonString, text)
text = '315576000001.000000000s' text = '315576000001.000000000s'
self.assertRaisesRegex( self.assertRaisesRegex(
ValueError, ValueError,
r'Duration is not valid\: Seconds 315576000001 must be in range' r'Duration is not valid\: Seconds 315576000001 must be in range'
r' \[-315576000000\, 315576000000\].', r' \[-315576000000\, 315576000000\].', message.FromJsonString, text)
message.FromJsonString, text)
message.seconds = -315576000001 message.seconds = -315576000001
message.nanos = 0 message.nanos = 0
self.assertRaisesRegex( self.assertRaisesRegex(
ValueError, ValueError,
r'Duration is not valid\: Seconds -315576000001 must be in range' r'Duration is not valid\: Seconds -315576000001 must be in range'
r' \[-315576000000\, 315576000000\].', r' \[-315576000000\, 315576000000\].', message.ToJsonString)
message.ToJsonString)
message.seconds = 0 message.seconds = 0
message.nanos = 999999999 + 1 message.nanos = 999999999 + 1
self.assertRaisesRegex( self.assertRaisesRegex(
ValueError, ValueError, r'Duration is not valid\: Nanos 1000000000 must be in range'
r'Duration is not valid\: Nanos 1000000000 must be in range' r' \[-999999999\, 999999999\].', message.ToJsonString)
r' \[-999999999\, 999999999\].',
message.ToJsonString)
message.seconds = -1 message.seconds = -1
message.nanos = 1 message.nanos = 1
self.assertRaisesRegex( self.assertRaisesRegex(ValueError,
ValueError,
r'Duration is not valid\: Sign mismatch.', r'Duration is not valid\: Sign mismatch.',
message.ToJsonString) message.ToJsonString)
@ -713,8 +713,7 @@ class FieldMaskTest(unittest.TestCase):
ValueError, ValueError,
'Fail to print FieldMask to Json string: Path name Foo must ' 'Fail to print FieldMask to Json string: Path name Foo must '
'not contain uppercase letters.', 'not contain uppercase letters.',
well_known_types._SnakeCaseToCamelCase, well_known_types._SnakeCaseToCamelCase, 'Foo')
'Foo')
# Any character after a "_" must be a lowercase letter. # Any character after a "_" must be a lowercase letter.
# 1. "_" cannot be followed by another "_". # 1. "_" cannot be followed by another "_".
# 2. "_" cannot be followed by a digit. # 2. "_" cannot be followed by a digit.
@ -723,20 +722,16 @@ class FieldMaskTest(unittest.TestCase):
ValueError, ValueError,
'Fail to print FieldMask to Json string: The character after a ' 'Fail to print FieldMask to Json string: The character after a '
'"_" must be a lowercase letter in path name foo__bar.', '"_" must be a lowercase letter in path name foo__bar.',
well_known_types._SnakeCaseToCamelCase, well_known_types._SnakeCaseToCamelCase, 'foo__bar')
'foo__bar')
self.assertRaisesRegex( self.assertRaisesRegex(
ValueError, ValueError,
'Fail to print FieldMask to Json string: The character after a ' 'Fail to print FieldMask to Json string: The character after a '
'"_" must be a lowercase letter in path name foo_3bar.', '"_" must be a lowercase letter in path name foo_3bar.',
well_known_types._SnakeCaseToCamelCase, well_known_types._SnakeCaseToCamelCase, 'foo_3bar')
'foo_3bar')
self.assertRaisesRegex( self.assertRaisesRegex(
ValueError, ValueError,
'Fail to print FieldMask to Json string: Trailing "_" in path ' 'Fail to print FieldMask to Json string: Trailing "_" in path '
'name foo_bar_.', 'name foo_bar_.', well_known_types._SnakeCaseToCamelCase, 'foo_bar_')
well_known_types._SnakeCaseToCamelCase,
'foo_bar_')
def testCamelCaseToSnakeCase(self): def testCamelCaseToSnakeCase(self):
self.assertEqual('foo_bar', self.assertEqual('foo_bar',
@ -748,8 +743,7 @@ class FieldMaskTest(unittest.TestCase):
self.assertRaisesRegex( self.assertRaisesRegex(
ValueError, ValueError,
'Fail to parse FieldMask: Path name foo_bar must not contain "_"s.', 'Fail to parse FieldMask: Path name foo_bar must not contain "_"s.',
well_known_types._CamelCaseToSnakeCase, well_known_types._CamelCaseToSnakeCase, 'foo_bar')
'foo_bar')
class StructTest(unittest.TestCase): class StructTest(unittest.TestCase):

View File

@ -194,6 +194,9 @@ class Message(object):
"""Parse serialized protocol buffer data into this message. """Parse serialized protocol buffer data into this message.
Like :func:`MergeFromString()`, except we clear the object first. Like :func:`MergeFromString()`, except we clear the object first.
Raises:
message.DecodeError if the input cannot be parsed.
""" """
self.Clear() self.Clear()
return self.MergeFromString(serialized) return self.MergeFromString(serialized)

View File

@ -420,7 +420,11 @@ PyTypeObject PyBaseDescriptor_Type = {
sizeof(PyBaseDescriptor), // tp_basicsize sizeof(PyBaseDescriptor), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
(destructor)Dealloc, // tp_dealloc (destructor)Dealloc, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare
@ -690,7 +694,11 @@ PyTypeObject PyMessageDescriptor_Type = {
sizeof(PyBaseDescriptor), // tp_basicsize sizeof(PyBaseDescriptor), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
nullptr, // tp_dealloc nullptr, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare
@ -1016,7 +1024,11 @@ PyTypeObject PyFieldDescriptor_Type = {
sizeof(PyBaseDescriptor), // tp_basicsize sizeof(PyBaseDescriptor), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
nullptr, // tp_dealloc nullptr, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare
@ -1172,7 +1184,11 @@ PyTypeObject PyEnumDescriptor_Type = {
sizeof(PyBaseDescriptor), // tp_basicsize sizeof(PyBaseDescriptor), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
nullptr, // tp_dealloc nullptr, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare
@ -1293,7 +1309,11 @@ PyTypeObject PyEnumValueDescriptor_Type = {
sizeof(PyBaseDescriptor), // tp_basicsize sizeof(PyBaseDescriptor), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
nullptr, // tp_dealloc nullptr, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare
@ -1478,7 +1498,11 @@ PyTypeObject PyFileDescriptor_Type = {
sizeof(PyFileDescriptor), // tp_basicsize sizeof(PyFileDescriptor), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
(destructor)file_descriptor::Dealloc, // tp_dealloc (destructor)file_descriptor::Dealloc, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare
@ -1639,7 +1663,11 @@ PyTypeObject PyOneofDescriptor_Type = {
sizeof(PyBaseDescriptor), // tp_basicsize sizeof(PyBaseDescriptor), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
nullptr, // tp_dealloc nullptr, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare
@ -1757,7 +1785,11 @@ PyTypeObject PyServiceDescriptor_Type = {
sizeof(PyBaseDescriptor), // tp_basicsize sizeof(PyBaseDescriptor), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
nullptr, // tp_dealloc nullptr, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare
@ -1881,7 +1913,11 @@ PyTypeObject PyMethodDescriptor_Type = {
sizeof(PyBaseDescriptor), // tp_basicsize sizeof(PyBaseDescriptor), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
nullptr, // tp_dealloc nullptr, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare

View File

@ -550,7 +550,11 @@ PyTypeObject DescriptorMapping_Type = {
sizeof(PyContainer), // tp_basicsize sizeof(PyContainer), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
nullptr, // tp_dealloc nullptr, // tp_dealloc
0, // tp_pkrint #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare
@ -733,7 +737,11 @@ PyTypeObject DescriptorSequence_Type = {
sizeof(PyContainer), // tp_basicsize sizeof(PyContainer), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
nullptr, // tp_dealloc nullptr, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare
@ -875,7 +883,11 @@ static PyTypeObject ContainerIterator_Type = {
sizeof(PyContainerIterator), // tp_basicsize sizeof(PyContainerIterator), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
(destructor)Iterator_Dealloc, // tp_dealloc (destructor)Iterator_Dealloc, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare

View File

@ -694,7 +694,11 @@ PyTypeObject PyDescriptorPool_Type = {
sizeof(PyDescriptorPool), // tp_basicsize sizeof(PyDescriptorPool), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
cdescriptor_pool::Dealloc, // tp_dealloc cdescriptor_pool::Dealloc, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare

View File

@ -372,7 +372,11 @@ PyTypeObject ExtensionDict_Type = {
sizeof(ExtensionDict), // tp_basicsize sizeof(ExtensionDict), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
(destructor)extension_dict::dealloc, // tp_dealloc (destructor)extension_dict::dealloc, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare
@ -439,7 +443,11 @@ PyTypeObject ExtensionIterator_Type = {
sizeof(extension_dict::ExtensionIterator), // tp_basicsize sizeof(extension_dict::ExtensionIterator), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
extension_dict::DeallocExtensionIterator, // tp_dealloc extension_dict::DeallocExtensionIterator, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare

View File

@ -87,7 +87,11 @@ static PyTypeObject _CFieldProperty_Type = {
sizeof(PyMessageFieldProperty), // tp_basicsize sizeof(PyMessageFieldProperty), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
nullptr, // tp_dealloc nullptr, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare

View File

@ -432,12 +432,12 @@ int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
return -1; return -1;
} }
self->version++;
if (v) { if (v) {
// Set item to v. // Set item to v.
reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor, if (reflection->InsertOrLookupMapValue(
map_key, &value); message, self->parent_field_descriptor, map_key, &value)) {
self->version++;
}
if (!PythonToMapValueRef(self, v, reflection->SupportsUnknownEnumValues(), if (!PythonToMapValueRef(self, v, reflection->SupportsUnknownEnumValues(),
&value)) { &value)) {
@ -448,6 +448,7 @@ int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
// Delete key from map. // Delete key from map.
if (reflection->DeleteMapValue(message, self->parent_field_descriptor, if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
map_key)) { map_key)) {
self->version++;
return 0; return 0;
} else { } else {
PyErr_Format(PyExc_KeyError, "Key not present in map"); PyErr_Format(PyExc_KeyError, "Key not present in map");
@ -857,7 +858,11 @@ PyTypeObject MapIterator_Type = {
sizeof(MapIterator), // tp_basicsize sizeof(MapIterator), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
DeallocMapIterator, // tp_dealloc DeallocMapIterator, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare

View File

@ -458,7 +458,11 @@ static PyTypeObject _CMessageClass_Type = {
sizeof(CMessageClass), // tp_basicsize sizeof(CMessageClass), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
message_meta::Dealloc, // tp_dealloc message_meta::Dealloc, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, /* tp_print */
#else
0, /* tp_vectorcall_offset */
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare
@ -1930,7 +1934,7 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
"Error parsing message as the message exceeded the protobuf limit " "Error parsing message as the message exceeded the protobuf limit "
"with type '%s'", "with type '%s'",
self->GetMessageClass()->message_descriptor->full_name().c_str()); self->GetMessageClass()->message_descriptor->full_name().c_str());
return NULL; return nullptr;
} }
// ctx has an explicit limit set (length of string_view), so we have to // ctx has an explicit limit set (length of string_view), so we have to
@ -2735,7 +2739,11 @@ static CMessageClass _CMessage_Type = {{{
sizeof(CMessage), // tp_basicsize sizeof(CMessage), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
(destructor)cmessage::Dealloc, // tp_dealloc (destructor)cmessage::Dealloc, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare

View File

@ -253,7 +253,11 @@ PyTypeObject PyMessageFactory_Type = {
sizeof(PyMessageFactory), // tp_basicsize sizeof(PyMessageFactory), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
message_factory::Dealloc, // tp_dealloc message_factory::Dealloc, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare

View File

@ -160,7 +160,11 @@ PyTypeObject PyUnknownFields_Type = {
sizeof(PyUnknownFields), // tp_basicsize sizeof(PyUnknownFields), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
unknown_fields::Dealloc, // tp_dealloc unknown_fields::Dealloc, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare
@ -317,7 +321,11 @@ PyTypeObject PyUnknownFieldRef_Type = {
sizeof(PyUnknownFieldRef), // tp_basicsize sizeof(PyUnknownFieldRef), // tp_basicsize
0, // tp_itemsize 0, // tp_itemsize
unknown_field::Dealloc, // tp_dealloc unknown_field::Dealloc, // tp_dealloc
0, // tp_print #if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr nullptr, // tp_getattr
nullptr, // tp_setattr nullptr, // tp_setattr
nullptr, // tp_compare nullptr, // tp_compare

View File

@ -145,19 +145,20 @@ class _ServiceBuilder(object):
# instance to the method that does the real CallMethod work. # instance to the method that does the real CallMethod work.
# Making sure to use exact argument names from the abstract interface in # Making sure to use exact argument names from the abstract interface in
# service.py to match the type signature # service.py to match the type signature
def _WrapCallMethod(self, method_descriptor, def _WrapCallMethod(self, method_descriptor, rpc_controller, request, done):
rpc_controller, request, done): return builder._CallMethod(self, method_descriptor, rpc_controller,
return builder._CallMethod(self, method_descriptor, request, done)
rpc_controller, request, done)
def _WrapGetRequestClass(self, method_descriptor): def _WrapGetRequestClass(self, method_descriptor):
return builder._GetRequestClass(method_descriptor) return builder._GetRequestClass(method_descriptor)
def _WrapGetResponseClass(self, method_descriptor): def _WrapGetResponseClass(self, method_descriptor):
return builder._GetResponseClass(method_descriptor) return builder._GetResponseClass(method_descriptor)
builder.cls = cls builder.cls = cls
cls.CallMethod = _WrapCallMethod cls.CallMethod = _WrapCallMethod
cls.GetDescriptor = staticmethod(lambda: builder.descriptor) cls.GetDescriptor = staticmethod(lambda: builder.descriptor)
cls.GetDescriptor.__doc__ = "Returns the service descriptor." cls.GetDescriptor.__doc__ = 'Returns the service descriptor.'
cls.GetRequestClass = _WrapGetRequestClass cls.GetRequestClass = _WrapGetRequestClass
cls.GetResponseClass = _WrapGetResponseClass cls.GetResponseClass = _WrapGetResponseClass
for method in builder.descriptor.methods: for method in builder.descriptor.methods:

View File

@ -1,6 +1,9 @@
#! /usr/bin/env python #! /usr/bin/env python
# #
# See README for usage instructions. # See README for usage instructions.
# pylint:disable=missing-module-docstring
# pylint:disable=g-bad-import-order
from distutils import util from distutils import util
import fnmatch import fnmatch
import glob import glob
@ -10,7 +13,9 @@ import re
import subprocess import subprocess
import sys import sys
import sysconfig import sysconfig
import platform
# pylint:disable=g-importing-member
# pylint:disable=g-multiple-import
# We must use setuptools, not distutils, because we need to use the # We must use setuptools, not distutils, because we need to use the
# namespace_packages option for the "google" package. # namespace_packages option for the "google" package.
@ -24,44 +29,52 @@ from distutils.spawn import find_executable
# Find the Protocol Compiler. # Find the Protocol Compiler.
if 'PROTOC' in os.environ and os.path.exists(os.environ['PROTOC']): if 'PROTOC' in os.environ and os.path.exists(os.environ['PROTOC']):
protoc = os.environ['PROTOC'] protoc = os.environ['PROTOC']
elif os.path.exists("../src/protoc"): elif os.path.exists('../src/protoc'):
protoc = "../src/protoc" protoc = '../src/protoc'
elif os.path.exists("../src/protoc.exe"): elif os.path.exists('../src/protoc.exe'):
protoc = "../src/protoc.exe" protoc = '../src/protoc.exe'
elif os.path.exists("../vsprojects/Debug/protoc.exe"): elif os.path.exists('../vsprojects/Debug/protoc.exe'):
protoc = "../vsprojects/Debug/protoc.exe" protoc = '../vsprojects/Debug/protoc.exe'
elif os.path.exists("../vsprojects/Release/protoc.exe"): elif os.path.exists('../vsprojects/Release/protoc.exe'):
protoc = "../vsprojects/Release/protoc.exe" protoc = '../vsprojects/Release/protoc.exe'
else: else:
protoc = find_executable("protoc") protoc = find_executable('protoc')
def GetVersion(): def GetVersion():
"""Gets the version from google/protobuf/__init__.py """Reads and returns the version from google/protobuf/__init__.py.
Do not import google.protobuf.__init__ directly, because an installed Do not import google.protobuf.__init__ directly, because an installed
protobuf library may be loaded instead.""" protobuf library may be loaded instead.
Returns:
The version.
"""
with open(os.path.join('google', 'protobuf', '__init__.py')) as version_file: with open(os.path.join('google', 'protobuf', '__init__.py')) as version_file:
exec(version_file.read(), globals()) exec(version_file.read(), globals()) # pylint:disable=exec-used
global __version__ return __version__ # pylint:disable=undefined-variable
return __version__
def generate_proto(source, require = True): def GenProto(source, require=True):
"""Invokes the Protocol Compiler to generate a _pb2.py from the given """Generates a _pb2.py from the given .proto file.
.proto file. Does nothing if the output already exists and is newer than
the input.""" Does nothing if the output already exists and is newer than the input.
Args:
source: the .proto file path.
require: if True, exit immediately when a path is not found.
"""
if not require and not os.path.exists(source): if not require and not os.path.exists(source):
return return
output = source.replace(".proto", "_pb2.py").replace("../src/", "") output = source.replace('.proto', '_pb2.py').replace('../src/', '')
if (not os.path.exists(output) or if (not os.path.exists(output) or
(os.path.exists(source) and (os.path.exists(source) and
os.path.getmtime(source) > os.path.getmtime(output))): os.path.getmtime(source) > os.path.getmtime(output))):
print("Generating %s..." % output) print('Generating %s...' % output)
if not os.path.exists(source): if not os.path.exists(source):
sys.stderr.write("Can't find required file: %s\n" % source) sys.stderr.write("Can't find required file: %s\n" % source)
@ -69,78 +82,85 @@ def generate_proto(source, require = True):
if protoc is None: if protoc is None:
sys.stderr.write( sys.stderr.write(
"protoc is not installed nor found in ../src. Please compile it " 'protoc is not installed nor found in ../src. Please compile it '
"or install the binary package.\n") 'or install the binary package.\n')
sys.exit(-1) sys.exit(-1)
protoc_command = [ protoc, "-I../src", "-I.", "--python_out=.", source ] protoc_command = [protoc, '-I../src', '-I.', '--python_out=.', source]
if subprocess.call(protoc_command) != 0: if subprocess.call(protoc_command) != 0:
sys.exit(-1) sys.exit(-1)
def GenerateUnittestProtos(): def GenerateUnittestProtos():
generate_proto("../src/google/protobuf/any_test.proto", False) """Generates protobuf code for unittests."""
generate_proto("../src/google/protobuf/map_proto2_unittest.proto", False) GenProto('../src/google/protobuf/any_test.proto', False)
generate_proto("../src/google/protobuf/map_unittest.proto", False) GenProto('../src/google/protobuf/map_proto2_unittest.proto', False)
generate_proto("../src/google/protobuf/test_messages_proto3.proto", False) GenProto('../src/google/protobuf/map_unittest.proto', False)
generate_proto("../src/google/protobuf/test_messages_proto2.proto", False) GenProto('../src/google/protobuf/test_messages_proto3.proto', False)
generate_proto("../src/google/protobuf/unittest_arena.proto", False) GenProto('../src/google/protobuf/test_messages_proto2.proto', False)
generate_proto("../src/google/protobuf/unittest.proto", False) GenProto('../src/google/protobuf/unittest_arena.proto', False)
generate_proto("../src/google/protobuf/unittest_custom_options.proto", False) GenProto('../src/google/protobuf/unittest.proto', False)
generate_proto("../src/google/protobuf/unittest_import.proto", False) GenProto('../src/google/protobuf/unittest_custom_options.proto', False)
generate_proto("../src/google/protobuf/unittest_import_public.proto", False) GenProto('../src/google/protobuf/unittest_import.proto', False)
generate_proto("../src/google/protobuf/unittest_mset.proto", False) GenProto('../src/google/protobuf/unittest_import_public.proto', False)
generate_proto("../src/google/protobuf/unittest_mset_wire_format.proto", False) GenProto('../src/google/protobuf/unittest_mset.proto', False)
generate_proto("../src/google/protobuf/unittest_no_generic_services.proto", False) GenProto('../src/google/protobuf/unittest_mset_wire_format.proto', False)
generate_proto("../src/google/protobuf/unittest_proto3_arena.proto", False) GenProto('../src/google/protobuf/unittest_no_generic_services.proto', False)
generate_proto("../src/google/protobuf/util/json_format.proto", False) GenProto('../src/google/protobuf/unittest_proto3_arena.proto', False)
generate_proto("../src/google/protobuf/util/json_format_proto3.proto", False) GenProto('../src/google/protobuf/util/json_format.proto', False)
generate_proto("google/protobuf/internal/any_test.proto", False) GenProto('../src/google/protobuf/util/json_format_proto3.proto', False)
generate_proto("google/protobuf/internal/descriptor_pool_test1.proto", False) GenProto('google/protobuf/internal/any_test.proto', False)
generate_proto("google/protobuf/internal/descriptor_pool_test2.proto", False) GenProto('google/protobuf/internal/descriptor_pool_test1.proto', False)
generate_proto("google/protobuf/internal/factory_test1.proto", False) GenProto('google/protobuf/internal/descriptor_pool_test2.proto', False)
generate_proto("google/protobuf/internal/factory_test2.proto", False) GenProto('google/protobuf/internal/factory_test1.proto', False)
generate_proto("google/protobuf/internal/file_options_test.proto", False) GenProto('google/protobuf/internal/factory_test2.proto', False)
generate_proto("google/protobuf/internal/import_test_package/inner.proto", False) GenProto('google/protobuf/internal/file_options_test.proto', False)
generate_proto("google/protobuf/internal/import_test_package/outer.proto", False) GenProto('google/protobuf/internal/import_test_package/inner.proto', False)
generate_proto("google/protobuf/internal/missing_enum_values.proto", False) GenProto('google/protobuf/internal/import_test_package/outer.proto', False)
generate_proto("google/protobuf/internal/message_set_extensions.proto", False) GenProto('google/protobuf/internal/missing_enum_values.proto', False)
generate_proto("google/protobuf/internal/more_extensions.proto", False) GenProto('google/protobuf/internal/message_set_extensions.proto', False)
generate_proto("google/protobuf/internal/more_extensions_dynamic.proto", False) GenProto('google/protobuf/internal/more_extensions.proto', False)
generate_proto("google/protobuf/internal/more_messages.proto", False) GenProto('google/protobuf/internal/more_extensions_dynamic.proto', False)
generate_proto("google/protobuf/internal/no_package.proto", False) GenProto('google/protobuf/internal/more_messages.proto', False)
generate_proto("google/protobuf/internal/packed_field_test.proto", False) GenProto('google/protobuf/internal/no_package.proto', False)
generate_proto("google/protobuf/internal/test_bad_identifiers.proto", False) GenProto('google/protobuf/internal/packed_field_test.proto', False)
generate_proto("google/protobuf/internal/test_proto3_optional.proto", False) GenProto('google/protobuf/internal/test_bad_identifiers.proto', False)
generate_proto("google/protobuf/pyext/python.proto", False) GenProto('google/protobuf/internal/test_proto3_optional.proto', False)
GenProto('google/protobuf/pyext/python.proto', False)
class clean(_clean): class CleanCmd(_clean):
"""Custom clean command for building the protobuf extension."""
def run(self): def run(self):
# Delete generated files in the code tree. # Delete generated files in the code tree.
for (dirpath, dirnames, filenames) in os.walk("."): for (dirpath, unused_dirnames, filenames) in os.walk('.'):
for filename in filenames: for filename in filenames:
filepath = os.path.join(dirpath, filename) filepath = os.path.join(dirpath, filename)
if filepath.endswith("_pb2.py") or filepath.endswith(".pyc") or \ if (filepath.endswith('_pb2.py') or filepath.endswith('.pyc') or
filepath.endswith(".so") or filepath.endswith(".o"): filepath.endswith('.so') or filepath.endswith('.o')):
os.remove(filepath) os.remove(filepath)
# _clean is an old-style class, so super() doesn't work. # _clean is an old-style class, so super() doesn't work.
_clean.run(self) _clean.run(self)
class build_py(_build_py):
class BuildPyCmd(_build_py):
"""Custom build_py command for building the protobuf runtime."""
def run(self): def run(self):
# Generate necessary .proto file if it doesn't exist. # Generate necessary .proto file if it doesn't exist.
generate_proto("../src/google/protobuf/descriptor.proto") GenProto('../src/google/protobuf/descriptor.proto')
generate_proto("../src/google/protobuf/compiler/plugin.proto") GenProto('../src/google/protobuf/compiler/plugin.proto')
generate_proto("../src/google/protobuf/any.proto") GenProto('../src/google/protobuf/any.proto')
generate_proto("../src/google/protobuf/api.proto") GenProto('../src/google/protobuf/api.proto')
generate_proto("../src/google/protobuf/duration.proto") GenProto('../src/google/protobuf/duration.proto')
generate_proto("../src/google/protobuf/empty.proto") GenProto('../src/google/protobuf/empty.proto')
generate_proto("../src/google/protobuf/field_mask.proto") GenProto('../src/google/protobuf/field_mask.proto')
generate_proto("../src/google/protobuf/source_context.proto") GenProto('../src/google/protobuf/source_context.proto')
generate_proto("../src/google/protobuf/struct.proto") GenProto('../src/google/protobuf/struct.proto')
generate_proto("../src/google/protobuf/timestamp.proto") GenProto('../src/google/protobuf/timestamp.proto')
generate_proto("../src/google/protobuf/type.proto") GenProto('../src/google/protobuf/type.proto')
generate_proto("../src/google/protobuf/wrappers.proto") GenProto('../src/google/protobuf/wrappers.proto')
GenerateUnittestProtos() GenerateUnittestProtos()
# _build_py is an old-style class, so super() doesn't work. # _build_py is an old-style class, so super() doesn't work.
@ -148,17 +168,18 @@ class build_py(_build_py):
def find_package_modules(self, package, package_dir): def find_package_modules(self, package, package_dir):
exclude = ( exclude = (
"*test*", '*test*',
"google/protobuf/internal/*_pb2.py", 'google/protobuf/internal/*_pb2.py',
"google/protobuf/internal/_parameterized.py", 'google/protobuf/internal/_parameterized.py',
"google/protobuf/pyext/python_pb2.py", 'google/protobuf/pyext/python_pb2.py',
) )
modules = _build_py.find_package_modules(self, package, package_dir) modules = _build_py.find_package_modules(self, package, package_dir)
return [(pkg, mod, fil) for (pkg, mod, fil) in modules return [(pkg, mod, fil) for (pkg, mod, fil) in modules
if not any(fnmatch.fnmatchcase(fil, pat=pat) for pat in exclude)] if not any(fnmatch.fnmatchcase(fil, pat=pat) for pat in exclude)]
class build_ext(_build_ext): class BuildExtCmd(_build_ext):
"""Command class for building the protobuf Python extension."""
def get_ext_filename(self, ext_name): def get_ext_filename(self, ext_name):
# since python3.5, python extensions' shared libraries use a suffix that # since python3.5, python extensions' shared libraries use a suffix that
@ -169,23 +190,25 @@ class build_ext(_build_ext):
# suffix so that the resulting file name matches the target architecture # suffix so that the resulting file name matches the target architecture
# and we end up with a well-formed wheel. # and we end up with a well-formed wheel.
filename = _build_ext.get_ext_filename(self, ext_name) filename = _build_ext.get_ext_filename(self, ext_name)
orig_ext_suffix = sysconfig.get_config_var("EXT_SUFFIX") orig_ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
new_ext_suffix = os.getenv("PROTOCOL_BUFFERS_OVERRIDE_EXT_SUFFIX") new_ext_suffix = os.getenv('PROTOCOL_BUFFERS_OVERRIDE_EXT_SUFFIX')
if new_ext_suffix and filename.endswith(orig_ext_suffix): if new_ext_suffix and filename.endswith(orig_ext_suffix):
filename = filename[:-len(orig_ext_suffix)] + new_ext_suffix filename = filename[:-len(orig_ext_suffix)] + new_ext_suffix
return filename return filename
class test_conformance(_build_py):
class TestConformanceCmd(_build_py):
target = 'test_python' target = 'test_python'
def run(self): def run(self):
# Python 2.6 dodges these extra failures. # Python 2.6 dodges these extra failures.
os.environ["CONFORMANCE_PYTHON_EXTRA_FAILURES"] = ( os.environ['CONFORMANCE_PYTHON_EXTRA_FAILURES'] = (
"--failure_list failure_list_python-post26.txt") '--failure_list failure_list_python-post26.txt')
cmd = 'cd ../conformance && make %s' % (test_conformance.target) cmd = 'cd ../conformance && make %s' % (TestConformanceCmd.target)
status = subprocess.check_call(cmd, shell=True) subprocess.check_call(cmd, shell=True)
def get_option_from_sys_argv(option_str): def GetOptionFromArgv(option_str):
if option_str in sys.argv: if option_str in sys.argv:
sys.argv.remove(option_str) sys.argv.remove(option_str)
return True return True
@ -195,32 +218,36 @@ def get_option_from_sys_argv(option_str):
if __name__ == '__main__': if __name__ == '__main__':
ext_module_list = [] ext_module_list = []
warnings_as_errors = '--warnings_as_errors' warnings_as_errors = '--warnings_as_errors'
if get_option_from_sys_argv('--cpp_implementation'): if GetOptionFromArgv('--cpp_implementation'):
# Link libprotobuf.a and libprotobuf-lite.a statically with the # Link libprotobuf.a and libprotobuf-lite.a statically with the
# extension. Note that those libraries have to be compiled with # extension. Note that those libraries have to be compiled with
# -fPIC for this to work. # -fPIC for this to work.
compile_static_ext = get_option_from_sys_argv('--compile_static_extension') compile_static_ext = GetOptionFromArgv('--compile_static_extension')
libraries = ['protobuf'] libraries = ['protobuf']
extra_objects = None extra_objects = None
if compile_static_ext: if compile_static_ext:
libraries = None libraries = None
extra_objects = ['../src/.libs/libprotobuf.a', extra_objects = ['../src/.libs/libprotobuf.a',
'../src/.libs/libprotobuf-lite.a'] '../src/.libs/libprotobuf-lite.a']
test_conformance.target = 'test_python_cpp' TestConformanceCmd.target = 'test_python_cpp'
extra_compile_args = [] extra_compile_args = []
message_extra_link_args = None message_extra_link_args = None
api_implementation_link_args = None api_implementation_link_args = None
if "darwin" in sys.platform: if 'darwin' in sys.platform:
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
message_init_symbol = 'init_message' message_init_symbol = 'init_message'
api_implementation_init_symbol = 'init_api_implementation' api_implementation_init_symbol = 'init_api_implementation'
else: else:
message_init_symbol = 'PyInit__message' message_init_symbol = 'PyInit__message'
api_implementation_init_symbol = 'PyInit__api_implementation' api_implementation_init_symbol = 'PyInit__api_implementation'
message_extra_link_args = ['-Wl,-exported_symbol,_%s' % message_init_symbol] message_extra_link_args = [
api_implementation_link_args = ['-Wl,-exported_symbol,_%s' % api_implementation_init_symbol] '-Wl,-exported_symbol,_%s' % message_init_symbol
]
api_implementation_link_args = [
'-Wl,-exported_symbol,_%s' % api_implementation_init_symbol
]
if sys.platform != 'win32': if sys.platform != 'win32':
extra_compile_args.append('-Wno-write-strings') extra_compile_args.append('-Wno-write-strings')
@ -230,8 +257,8 @@ if __name__ == '__main__':
extra_compile_args.append('-std=c++11') extra_compile_args.append('-std=c++11')
if sys.platform == 'darwin': if sys.platform == 'darwin':
extra_compile_args.append("-Wno-shorten-64-to-32"); extra_compile_args.append('-Wno-shorten-64-to-32')
extra_compile_args.append("-Wno-deprecated-register"); extra_compile_args.append('-Wno-deprecated-register')
# https://developer.apple.com/documentation/xcode_release_notes/xcode_10_release_notes # https://developer.apple.com/documentation/xcode_release_notes/xcode_10_release_notes
# C++ projects must now migrate to libc++ and are recommended to set a # C++ projects must now migrate to libc++ and are recommended to set a
@ -254,10 +281,10 @@ if __name__ == '__main__':
extra_compile_args.append('-DMS_WIN64') extra_compile_args.append('-DMS_WIN64')
# MSVS default is dymanic # MSVS default is dymanic
if (sys.platform == 'win32'): if sys.platform == 'win32':
extra_compile_args.append('/MT') extra_compile_args.append('/MT')
if "clang" in os.popen('$CC --version 2> /dev/null').read(): if 'clang' in os.popen('$CC --version 2> /dev/null').read():
extra_compile_args.append('-Wno-shorten-64-to-32') extra_compile_args.append('-Wno-shorten-64-to-32')
if warnings_as_errors in sys.argv: if warnings_as_errors in sys.argv:
@ -267,9 +294,9 @@ if __name__ == '__main__':
# C++ implementation extension # C++ implementation extension
ext_module_list.extend([ ext_module_list.extend([
Extension( Extension(
"google.protobuf.pyext._message", 'google.protobuf.pyext._message',
glob.glob('google/protobuf/pyext/*.cc'), glob.glob('google/protobuf/pyext/*.cc'),
include_dirs=[".", "../src"], include_dirs=['.', '../src'],
libraries=libraries, libraries=libraries,
extra_objects=extra_objects, extra_objects=extra_objects,
extra_link_args=message_extra_link_args, extra_link_args=message_extra_link_args,
@ -277,9 +304,10 @@ if __name__ == '__main__':
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
), ),
Extension( Extension(
"google.protobuf.internal._api_implementation", 'google.protobuf.internal._api_implementation',
glob.glob('google/protobuf/internal/api_implementation.cc'), glob.glob('google/protobuf/internal/api_implementation.cc'),
extra_compile_args=extra_compile_args + ['-DPYTHON_PROTO2_CPP_IMPL_V2'], extra_compile_args=(extra_compile_args +
['-DPYTHON_PROTO2_CPP_IMPL_V2']),
extra_link_args=api_implementation_link_args, extra_link_args=api_implementation_link_args,
), ),
]) ])
@ -299,12 +327,12 @@ if __name__ == '__main__':
maintainer_email='protobuf@googlegroups.com', maintainer_email='protobuf@googlegroups.com',
license='BSD-3-Clause', license='BSD-3-Clause',
classifiers=[ classifiers=[
"Programming Language :: Python", 'Programming Language :: Python',
"Programming Language :: Python :: 3", 'Programming Language :: Python :: 3',
"Programming Language :: Python :: 3.7", 'Programming Language :: Python :: 3.7',
"Programming Language :: Python :: 3.8", 'Programming Language :: Python :: 3.8',
"Programming Language :: Python :: 3.9", 'Programming Language :: Python :: 3.9',
"Programming Language :: Python :: 3.10", 'Programming Language :: Python :: 3.10',
], ],
namespace_packages=['google'], namespace_packages=['google'],
packages=find_packages( packages=find_packages(
@ -314,12 +342,12 @@ if __name__ == '__main__':
],), ],),
test_suite='google.protobuf.internal', test_suite='google.protobuf.internal',
cmdclass={ cmdclass={
'clean': clean, 'clean': CleanCmd,
'build_py': build_py, 'build_py': BuildPyCmd,
'build_ext': build_ext, 'build_ext': BuildExtCmd,
'test_conformance': test_conformance, 'test_conformance': TestConformanceCmd,
}, },
install_requires=install_requires, install_requires=install_requires,
ext_modules=ext_module_list, ext_modules=ext_module_list,
python_requires=">=3.7", python_requires='>=3.7',
) )

View File

@ -86,9 +86,18 @@ const std::string& LazyString::Init() const {
namespace { namespace {
#if defined(NDEBUG) || !GOOGLE_PROTOBUF_INTERNAL_DONATE_STEAL
class ScopedCheckPtrInvariants {
public:
explicit ScopedCheckPtrInvariants(const TaggedStringPtr*) {}
};
#endif // NDEBUG || !GOOGLE_PROTOBUF_INTERNAL_DONATE_STEAL
// Creates a heap allocated std::string value. // Creates a heap allocated std::string value.
inline TaggedPtr<std::string> CreateString(ConstStringParam value) { inline TaggedStringPtr CreateString(ConstStringParam value) {
TaggedPtr<std::string> res; TaggedStringPtr res;
res.SetAllocated(new std::string(value.data(), value.length())); res.SetAllocated(new std::string(value.data(), value.length()));
return res; return res;
} }
@ -96,8 +105,8 @@ inline TaggedPtr<std::string> CreateString(ConstStringParam value) {
#if !GOOGLE_PROTOBUF_INTERNAL_DONATE_STEAL #if !GOOGLE_PROTOBUF_INTERNAL_DONATE_STEAL
// Creates an arena allocated std::string value. // Creates an arena allocated std::string value.
TaggedPtr<std::string> CreateArenaString(Arena& arena, ConstStringParam s) { TaggedStringPtr CreateArenaString(Arena& arena, ConstStringParam s) {
TaggedPtr<std::string> res; TaggedStringPtr res;
res.SetMutableArena(Arena::Create<std::string>(&arena, s.data(), s.length())); res.SetMutableArena(Arena::Create<std::string>(&arena, s.data(), s.length()));
return res; return res;
} }
@ -106,13 +115,8 @@ TaggedPtr<std::string> CreateArenaString(Arena& arena, ConstStringParam s) {
} // namespace } // namespace
std::string* ArenaStringPtr::SetAndReturnNewString() {
std::string* new_string = new std::string();
tagged_ptr_.SetAllocated(new_string);
return new_string;
}
void ArenaStringPtr::Set(ConstStringParam value, Arena* arena) { void ArenaStringPtr::Set(ConstStringParam value, Arena* arena) {
ScopedCheckPtrInvariants check(&tagged_ptr_);
if (IsDefault()) { if (IsDefault()) {
// If we're not on an arena, skip straight to a true string to avoid // If we're not on an arena, skip straight to a true string to avoid
// possible copy cost later. // possible copy cost later.
@ -124,6 +128,7 @@ void ArenaStringPtr::Set(ConstStringParam value, Arena* arena) {
} }
void ArenaStringPtr::Set(std::string&& value, Arena* arena) { void ArenaStringPtr::Set(std::string&& value, Arena* arena) {
ScopedCheckPtrInvariants check(&tagged_ptr_);
if (IsDefault()) { if (IsDefault()) {
NewString(arena, std::move(value)); NewString(arena, std::move(value));
} else if (IsFixedSizeArena()) { } else if (IsFixedSizeArena()) {
@ -137,6 +142,7 @@ void ArenaStringPtr::Set(std::string&& value, Arena* arena) {
} }
std::string* ArenaStringPtr::Mutable(Arena* arena) { std::string* ArenaStringPtr::Mutable(Arena* arena) {
ScopedCheckPtrInvariants check(&tagged_ptr_);
if (tagged_ptr_.IsMutable()) { if (tagged_ptr_.IsMutable()) {
return tagged_ptr_.Get(); return tagged_ptr_.Get();
} else { } else {
@ -146,6 +152,7 @@ std::string* ArenaStringPtr::Mutable(Arena* arena) {
std::string* ArenaStringPtr::Mutable(const LazyString& default_value, std::string* ArenaStringPtr::Mutable(const LazyString& default_value,
Arena* arena) { Arena* arena) {
ScopedCheckPtrInvariants check(&tagged_ptr_);
if (tagged_ptr_.IsMutable()) { if (tagged_ptr_.IsMutable()) {
return tagged_ptr_.Get(); return tagged_ptr_.Get();
} else { } else {
@ -154,6 +161,7 @@ std::string* ArenaStringPtr::Mutable(const LazyString& default_value,
} }
std::string* ArenaStringPtr::MutableNoCopy(Arena* arena) { std::string* ArenaStringPtr::MutableNoCopy(Arena* arena) {
ScopedCheckPtrInvariants check(&tagged_ptr_);
if (tagged_ptr_.IsMutable()) { if (tagged_ptr_.IsMutable()) {
return tagged_ptr_.Get(); return tagged_ptr_.Get();
} else { } else {
@ -171,6 +179,7 @@ std::string* ArenaStringPtr::MutableSlow(::google::protobuf::Arena* arena,
} }
std::string* ArenaStringPtr::Release() { std::string* ArenaStringPtr::Release() {
ScopedCheckPtrInvariants check(&tagged_ptr_);
if (IsDefault()) return nullptr; if (IsDefault()) return nullptr;
std::string* released = tagged_ptr_.Get(); std::string* released = tagged_ptr_.Get();
@ -183,6 +192,7 @@ std::string* ArenaStringPtr::Release() {
} }
void ArenaStringPtr::SetAllocated(std::string* value, Arena* arena) { void ArenaStringPtr::SetAllocated(std::string* value, Arena* arena) {
ScopedCheckPtrInvariants check(&tagged_ptr_);
// Release what we have first. // Release what we have first.
Destroy(); Destroy();
@ -208,6 +218,7 @@ void ArenaStringPtr::Destroy() {
} }
void ArenaStringPtr::ClearToEmpty() { void ArenaStringPtr::ClearToEmpty() {
ScopedCheckPtrInvariants check(&tagged_ptr_);
if (IsDefault()) { if (IsDefault()) {
// Already set to default -- do nothing. // Already set to default -- do nothing.
} else { } else {
@ -222,6 +233,7 @@ void ArenaStringPtr::ClearToEmpty() {
void ArenaStringPtr::ClearToDefault(const LazyString& default_value, void ArenaStringPtr::ClearToDefault(const LazyString& default_value,
::google::protobuf::Arena* arena) { ::google::protobuf::Arena* arena) {
ScopedCheckPtrInvariants check(&tagged_ptr_);
(void)arena; (void)arena;
if (IsDefault()) { if (IsDefault()) {
// Already set to default -- do nothing. // Already set to default -- do nothing.
@ -233,6 +245,7 @@ void ArenaStringPtr::ClearToDefault(const LazyString& default_value,
const char* EpsCopyInputStream::ReadArenaString(const char* ptr, const char* EpsCopyInputStream::ReadArenaString(const char* ptr,
ArenaStringPtr* s, ArenaStringPtr* s,
Arena* arena) { Arena* arena) {
ScopedCheckPtrInvariants check(&s->tagged_ptr_);
GOOGLE_DCHECK(arena != nullptr); GOOGLE_DCHECK(arena != nullptr);
int size = ReadSize(&ptr); int size = ReadSize(&ptr);

View File

@ -94,8 +94,7 @@ class PROTOBUF_EXPORT LazyString {
const std::string& Init() const; const std::string& Init() const;
}; };
template <typename T> class TaggedStringPtr {
class TaggedPtr {
public: public:
// Bit flags qualifying string properties. We can use up to 3 bits as // Bit flags qualifying string properties. We can use up to 3 bits as
// ptr_ is guaranteed and enforced to be aligned on 8 byte boundaries. // ptr_ is guaranteed and enforced to be aligned on 8 byte boundaries.
@ -130,30 +129,36 @@ class TaggedPtr {
kFixedSizeArena = kArenaBit, kFixedSizeArena = kArenaBit,
}; };
TaggedPtr() = default; TaggedStringPtr() = default;
explicit constexpr TaggedPtr(ExplicitlyConstructedArenaString* ptr) explicit constexpr TaggedStringPtr(ExplicitlyConstructedArenaString* ptr)
: ptr_(ptr) {} : ptr_(ptr) {}
// Sets the value to `p`, tagging the value as being a 'default' value. // Sets the value to `p`, tagging the value as being a 'default' value.
// See documentation for kDefault for more info. // See documentation for kDefault for more info.
inline const T* SetDefault(const T* p) { inline const std::string* SetDefault(const std::string* p) {
return TagAs(kDefault, const_cast<T*>(p)); return TagAs(kDefault, const_cast<std::string*>(p));
} }
// Sets the value to `p`, tagging the value as a heap allocated value. // Sets the value to `p`, tagging the value as a heap allocated value.
// Allocated strings are mutable and (as the name implies) owned. // Allocated strings are mutable and (as the name implies) owned.
// `p` must not be null // `p` must not be null
inline T* SetAllocated(T* p) { return TagAs(kAllocated, p); } inline std::string* SetAllocated(std::string* p) {
return TagAs(kAllocated, p);
}
// Sets the value to `p`, tagging the value as a fixed size arena string. // Sets the value to `p`, tagging the value as a fixed size arena string.
// See documentation for kFixedSizeArena for more info. // See documentation for kFixedSizeArena for more info.
// `p` must not be null // `p` must not be null
inline T* SetFixedSizeArena(T* p) { return TagAs(kFixedSizeArena, p); } inline std::string* SetFixedSizeArena(std::string* p) {
return TagAs(kFixedSizeArena, p);
}
// Sets the value to `p`, tagging the value as a mutable arena string. // Sets the value to `p`, tagging the value as a mutable arena string.
// See documentation for kMutableArena for more info. // See documentation for kMutableArena for more info.
// `p` must not be null // `p` must not be null
inline T* SetMutableArena(T* p) { return TagAs(kMutableArena, p); } inline std::string* SetMutableArena(std::string* p) {
return TagAs(kMutableArena, p);
}
// Returns true if the contents of the current string are fully mutable. // Returns true if the contents of the current string are fully mutable.
inline bool IsMutable() const { return as_int() & kMutableBit; } inline bool IsMutable() const { return as_int() & kMutableBit; }
@ -174,7 +179,9 @@ class TaggedPtr {
} }
// Returns the contained string pointer. // Returns the contained string pointer.
inline T* Get() const { return reinterpret_cast<T*>(as_int() & ~kMask); } inline std::string* Get() const {
return reinterpret_cast<std::string*>(as_int() & ~kMask);
}
// Returns true if the contained pointer is null, indicating some error. // Returns true if the contained pointer is null, indicating some error.
// The Null value is only used during parsing for temporary values. // The Null value is only used during parsing for temporary values.
@ -186,7 +193,7 @@ class TaggedPtr {
GOOGLE_DCHECK_EQ(reinterpret_cast<uintptr_t>(p) & kMask, 0UL); GOOGLE_DCHECK_EQ(reinterpret_cast<uintptr_t>(p) & kMask, 0UL);
} }
inline T* TagAs(Type type, T* p) { inline std::string* TagAs(Type type, std::string* p) {
GOOGLE_DCHECK(p != nullptr); GOOGLE_DCHECK(p != nullptr);
assert_aligned(p); assert_aligned(p);
ptr_ = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(p) | type); ptr_ = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(p) | type);
@ -197,8 +204,8 @@ class TaggedPtr {
void* ptr_; void* ptr_;
}; };
static_assert(std::is_trivial<TaggedPtr<std::string>>::value, static_assert(std::is_trivial<TaggedStringPtr>::value,
"TaggedPtr must be trivial"); "TaggedStringPtr must be trivial");
// This class encapsulates a pointer to a std::string with or without arena // This class encapsulates a pointer to a std::string with or without arena
// owned contents, tagged by the bottom bits of the string pointer. It is a // owned contents, tagged by the bottom bits of the string pointer. It is a
@ -225,13 +232,6 @@ struct PROTOBUF_EXPORT ArenaStringPtr {
ExplicitlyConstructedArenaString* default_value) ExplicitlyConstructedArenaString* default_value)
: tagged_ptr_(default_value) {} : tagged_ptr_(default_value) {}
// Some methods below are overloaded on a `default_value` and on tags.
// The tagged overloads help reduce code size in the callers in generated
// code, while the `default_value` overloads are useful from reflection.
// By-value empty struct arguments are elided in the ABI.
struct EmptyDefault {};
struct NonEmptyDefault {};
// Called from generated code / reflection runtime only. Resets value to point // Called from generated code / reflection runtime only. Resets value to point
// to a default string pointer, with the semantics that this ArenaStringPtr // to a default string pointer, with the semantics that this ArenaStringPtr
// does not own the pointed-to memory. Disregards initial value of ptr_ (so // does not own the pointed-to memory. Disregards initial value of ptr_ (so
@ -334,107 +334,9 @@ struct PROTOBUF_EXPORT ArenaStringPtr {
ArenaStringPtr* lhs, ArenaStringPtr* lhs,
Arena* lhs_arena); Arena* lhs_arena);
// --------------------------------------------------------
// Below functions will be removed in subsequent code change
// --------------------------------------------------------
#ifdef DEPRECATED_METHODS_TO_BE_DELETED
PROTOBUF_NDEBUG_INLINE const std::string* GetPointer() const {
return UnsafeGetPointer();
}
template <typename DefaultArg>
void Set(DefaultArg, ConstStringParam value, Arena* arena) {
return Set(value, arena);
}
template <typename DefaultArg>
void Set(DefaultArg, std::string&& value, Arena* arena) {
return Set(std::move(value), arena);
}
template <typename DefaultArg>
void Set(DefaultArg, const char* s, Arena* arena) {
return Set(ConstStringParam{s}, arena);
}
template <typename DefaultArg>
void Set(DefaultArg, const char* s, size_t n, Arena* arena) {
return Set(ConstStringParam{s, n}, arena);
}
void SetBytes(EmptyDefault, ConstStringParam value, Arena* arena) {
return Set(value, arena);
}
void SetBytes(NonEmptyDefault, ConstStringParam value, Arena* arena) {
return Set(value, arena);
}
void SetBytes(const std::string*, ConstStringParam value, Arena* arena) {
return Set(value, arena);
}
void SetBytes(EmptyDefault, std::string&& value, Arena* arena) {
return Set(std::move(value), arena);
}
void SetBytes(NonEmptyDefault, std::string&& value, Arena* arena) {
return Set(std::move(value), arena);
}
void SetBytes(const std::string*, std::string&& value, Arena* arena) {
return Set(std::move(value), arena);
}
void SetBytes(EmptyDefault, const char* s, Arena* arena) {
return Set(s, arena);
}
void SetBytes(NonEmptyDefault, const char* s, Arena* arena) {
return Set(s, arena);
}
void SetBytes(const std::string*, const char* s, Arena* arena) {
return Set(s, arena);
}
void SetBytes(EmptyDefault, const void* p, size_t n, Arena* arena) {
return SetBytes(p, n, arena);
}
void SetBytes(NonEmptyDefault, const void* p, size_t n, Arena* arena) {
return SetBytes(p, n, arena);
}
void SetBytes(const std::string*, const void* p, size_t n, Arena* arena) {
return SetBytes(p, n, arena);
}
std::string* Mutable(EmptyDefault, Arena* arena) { return Mutable(arena); }
std::string* MutableNoArenaNoDefault(const std::string*) {
return Mutable(nullptr);
}
std::string* MutableNoCopy(const std::string*, ::google::protobuf::Arena* arena) {
return MutableNoCopy(arena);
}
PROTOBUF_NODISCARD std::string* Release(const std::string*, Arena* arena) {
return Release();
}
PROTOBUF_NODISCARD std::string* ReleaseNonDefault(const std::string*,
Arena* arena) {
return Release();
}
void SetAllocated(const std::string*, std::string* value, Arena* arena) {
SetAllocated(value, arena);
}
void Destroy(const std::string*, ::google::protobuf::Arena* arena) { Destroy(); }
void Destroy(EmptyDefault, ::google::protobuf::Arena* arena) { Destroy(); }
void Destroy(NonEmptyDefault, ::google::protobuf::Arena* arena) { Destroy(); }
void DestroyNoArena(const std::string*) { Destroy(); }
inline PROTOBUF_NDEBUG_INLINE static void InternalSwap(const std::string*,
ArenaStringPtr* rhs,
Arena* rhs_arena,
ArenaStringPtr* lhs,
Arena* lhs_arena) {
InternalSwap(rhs, rhs_arena, lhs, lhs_arena);
}
#endif // DEPRECATED_METHODS_TO_BE_DELETED
// Internal setter used only at parse time to directly set a donated string // Internal setter used only at parse time to directly set a donated string
// value. // value.
void UnsafeSetTaggedPointer(TaggedPtr<std::string> value) { void UnsafeSetTaggedPointer(TaggedStringPtr value) { tagged_ptr_ = value; }
tagged_ptr_ = value;
}
// Generated code only! An optimization, in certain cases the generated // Generated code only! An optimization, in certain cases the generated
// code is certain we can obtain a std::string with no default checks and // code is certain we can obtain a std::string with no default checks and
// tag tests. // tag tests.
@ -455,7 +357,7 @@ struct PROTOBUF_EXPORT ArenaStringPtr {
} }
} }
TaggedPtr<std::string> tagged_ptr_; TaggedStringPtr tagged_ptr_;
bool IsFixedSizeArena() const { return false; } bool IsFixedSizeArena() const { return false; }
@ -472,18 +374,15 @@ struct PROTOBUF_EXPORT ArenaStringPtr {
// Slow paths. // Slow paths.
// MutableSlow requires that !IsString() || IsDefault // MutableSlow requires that !IsString() || IsDefault
// Variadic to support 0 args for EmptyDefault and 1 arg for LazyString. // Variadic to support 0 args for empty default and 1 arg for LazyString.
template <typename... Lazy> template <typename... Lazy>
std::string* MutableSlow(::google::protobuf::Arena* arena, const Lazy&... lazy_default); std::string* MutableSlow(::google::protobuf::Arena* arena, const Lazy&... lazy_default);
// Sets value to a newly allocated string and returns it
std::string* SetAndReturnNewString();
friend class EpsCopyInputStream; friend class EpsCopyInputStream;
}; };
inline void ArenaStringPtr::InitDefault() { inline void ArenaStringPtr::InitDefault() {
tagged_ptr_ = TaggedPtr<std::string>(&fixed_address_empty_string); tagged_ptr_ = TaggedStringPtr(&fixed_address_empty_string);
} }
inline void ArenaStringPtr::InitExternal(const std::string* str) { inline void ArenaStringPtr::InitExternal(const std::string* str) {

View File

@ -55,8 +55,6 @@ namespace protobuf {
using internal::ArenaStringPtr; using internal::ArenaStringPtr;
using EmptyDefault = ArenaStringPtr::EmptyDefault;
const internal::LazyString nonempty_default{{{"default", 7}}, {nullptr}}; const internal::LazyString nonempty_default{{{"default", 7}}, {nullptr}};
const std::string* empty_default = &internal::GetEmptyString(); const std::string* empty_default = &internal::GetEmptyString();

View File

@ -57,7 +57,7 @@
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <limits.h> //For PATH_MAX #include <limits.h> // For PATH_MAX
#include <memory> #include <memory>
@ -69,7 +69,6 @@
#include <google/protobuf/stubs/common.h> #include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/logging.h> #include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/stringprintf.h>
#include <google/protobuf/compiler/subprocess.h> #include <google/protobuf/compiler/subprocess.h>
#include <google/protobuf/compiler/zip_writer.h> #include <google/protobuf/compiler/zip_writer.h>
#include <google/protobuf/compiler/plugin.pb.h> #include <google/protobuf/compiler/plugin.pb.h>
@ -82,6 +81,7 @@
#include <google/protobuf/dynamic_message.h> #include <google/protobuf/dynamic_message.h>
#include <google/protobuf/text_format.h> #include <google/protobuf/text_format.h>
#include <google/protobuf/stubs/strutil.h> #include <google/protobuf/stubs/strutil.h>
#include <google/protobuf/stubs/stringprintf.h>
#include <google/protobuf/stubs/substitute.h> #include <google/protobuf/stubs/substitute.h>
#include <google/protobuf/io/io_win32.h> #include <google/protobuf/io/io_win32.h>
#include <google/protobuf/stubs/map_util.h> #include <google/protobuf/stubs/map_util.h>
@ -1676,8 +1676,7 @@ CommandLineInterface::InterpretArgument(const std::string& name,
// On Windows, the shell (typically cmd.exe) does not expand wildcards in // On Windows, the shell (typically cmd.exe) does not expand wildcards in
// file names (e.g. foo\*.proto), so we do it ourselves. // file names (e.g. foo\*.proto), so we do it ourselves.
switch (google::protobuf::io::win32::ExpandWildcards( switch (google::protobuf::io::win32::ExpandWildcards(
value, value, [this](const std::string& path) {
[this](const std::string& path) {
this->input_files_.push_back(path); this->input_files_.push_back(path);
})) { })) {
case google::protobuf::io::win32::ExpandWildcardsResult::kSuccess: case google::protobuf::io::win32::ExpandWildcardsResult::kSuccess:
@ -2590,7 +2589,8 @@ void FormatFreeFieldNumbers(const std::string& name,
StringAppendF(&output, " %d", next_free_number); StringAppendF(&output, " %d", next_free_number);
} else { } else {
// Range // Range
StringAppendF(&output, " %d-%d", next_free_number, i->first - 1); StringAppendF(&output, " %d-%d", next_free_number,
i->first - 1);
} }
} }
next_free_number = i->second; next_free_number = i->second;

View File

@ -130,13 +130,13 @@ void EnumFieldGenerator::GenerateSwappingCode(io::Printer* printer) const {
void EnumFieldGenerator::GenerateConstructorCode(io::Printer* printer) const { void EnumFieldGenerator::GenerateConstructorCode(io::Printer* printer) const {
Formatter format(printer, variables_); Formatter format(printer, variables_);
format("$name$_ = $default$;\n"); format("$field$ = $default$;\n");
} }
void EnumFieldGenerator::GenerateCopyConstructorCode( void EnumFieldGenerator::GenerateCopyConstructorCode(
io::Printer* printer) const { io::Printer* printer) const {
Formatter format(printer, variables_); Formatter format(printer, variables_);
format("$name$_ = from.$name$_;\n"); format("$field$ = from.$field$;\n");
} }
void EnumFieldGenerator::GenerateSerializeWithCachedSizesToArray( void EnumFieldGenerator::GenerateSerializeWithCachedSizesToArray(
@ -216,7 +216,7 @@ void EnumOneofFieldGenerator::GenerateSwappingCode(io::Printer* printer) const {
void EnumOneofFieldGenerator::GenerateConstructorCode( void EnumOneofFieldGenerator::GenerateConstructorCode(
io::Printer* printer) const { io::Printer* printer) const {
Formatter format(printer, variables_); Formatter format(printer, variables_);
format("$ns$::_$classname$_default_instance_.$name$_ = $default$;\n"); format("$ns$::_$classname$_default_instance_.$field$ = $default$;\n");
} }
// =================================================================== // ===================================================================

View File

@ -103,7 +103,7 @@ std::string GenerateTemplateForOneofString(const FieldDescriptor* descriptor,
return strings::Substitute( return strings::Substitute(
StrCat("_internal_has_", field_name, "() ? ", field_pointer, " : ", StrCat("_internal_has_", field_name, "() ? ", field_pointer, " : ",
default_value_pointer), default_value_pointer),
field_member, MakeDefaultName(descriptor)); field_member, MakeDefaultFieldName(descriptor));
} }
std::string GenerateTemplateForSingleString(const FieldDescriptor* descriptor, std::string GenerateTemplateForSingleString(const FieldDescriptor* descriptor,
@ -115,7 +115,7 @@ std::string GenerateTemplateForSingleString(const FieldDescriptor* descriptor,
if (descriptor->options().ctype() == google::protobuf::FieldOptions::STRING) { if (descriptor->options().ctype() == google::protobuf::FieldOptions::STRING) {
return strings::Substitute( return strings::Substitute(
"$0.IsDefault() ? &$1.get() : $0.UnsafeGetPointer()", field_member, "$0.IsDefault() ? &$1.get() : $0.UnsafeGetPointer()", field_member,
MakeDefaultName(descriptor)); MakeDefaultFieldName(descriptor));
} }
return StrCat("&", field_member); return StrCat("&", field_member);
@ -150,7 +150,7 @@ void AddAccessorAnnotations(const FieldDescriptor* descriptor,
google::protobuf::FileOptions::LITE_RUNTIME) { google::protobuf::FileOptions::LITE_RUNTIME) {
return; return;
} }
std::string field_member = (*variables)["field_member"]; std::string field_member = (*variables)["field"];
const google::protobuf::OneofDescriptor* oneof_member = const google::protobuf::OneofDescriptor* oneof_member =
descriptor->real_containing_oneof(); descriptor->real_containing_oneof();
const std::string proto_ns = (*variables)["proto_ns"]; const std::string proto_ns = (*variables)["proto_ns"];
@ -241,11 +241,6 @@ void SetCommonFieldVariables(const FieldDescriptor* descriptor,
(*variables)["number"] = StrCat(descriptor->number()); (*variables)["number"] = StrCat(descriptor->number());
(*variables)["classname"] = ClassName(FieldScope(descriptor), false); (*variables)["classname"] = ClassName(FieldScope(descriptor), false);
(*variables)["declared_type"] = DeclaredTypeMethodName(descriptor->type()); (*variables)["declared_type"] = DeclaredTypeMethodName(descriptor->type());
// TODO(b/218325252): convert all usages of "field_member" to "field" and
// remove this. The former may unnecessarily cause line breaks in protoc code.
// Note that the length of variables has no effect on the generated code. It
// only affects the readability of code template in protoc.
(*variables)["field_member"] = FieldMemberName(descriptor);
(*variables)["field"] = FieldMemberName(descriptor); (*variables)["field"] = FieldMemberName(descriptor);
(*variables)["tag_size"] = StrCat( (*variables)["tag_size"] = StrCat(

View File

@ -511,9 +511,10 @@ void FileGenerator::GenerateSourceDefaultInstance(int idx,
format( format(
"PROTOBUF_ATTRIBUTE_INIT_PRIORITY2 std::true_type " "PROTOBUF_ATTRIBUTE_INIT_PRIORITY2 std::true_type "
"$1$::_init_inline_$2$_ = " "$1$::_init_inline_$2$_ = "
"($3$._instance.$2$_.Init(), std::true_type{});\n", "($3$._instance.$4$.Init(), std::true_type{});\n",
ClassName(generator->descriptor_), FieldName(field), ClassName(generator->descriptor_), FieldName(field),
DefaultInstanceName(generator->descriptor_, options_)); DefaultInstanceName(generator->descriptor_, options_),
FieldMemberName(field));
} }
} }

View File

@ -486,6 +486,20 @@ inline std::string MakeDefaultName(const FieldDescriptor* field) {
"_"; "_";
} }
// Semantically distinct from MakeDefaultName in that it gives the C++ code
// referencing a default field from the message scope, rather than just the
// variable name.
// For example, declarations of default variables should always use just
// MakeDefaultName to produce code like:
// Type _i_give_permission_to_break_this_code_default_field_;
//
// Code that references these should use MakeDefaultFieldName, in case the field
// exists at some nested level like:
// internal_container_._i_give_permission_to_break_this_code_default_field_;
inline std::string MakeDefaultFieldName(const FieldDescriptor* field) {
return MakeDefaultName(field);
}
bool IsAnyMessage(const FileDescriptor* descriptor, const Options& options); bool IsAnyMessage(const FileDescriptor* descriptor, const Options& options);
bool IsAnyMessage(const Descriptor* descriptor, const Options& options); bool IsAnyMessage(const Descriptor* descriptor, const Options& options);

View File

@ -126,7 +126,7 @@ void MapFieldGenerator::GenerateInlineAccessorDefinitions(
format( format(
"inline const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n" "inline const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n"
"$classname$::_internal_$name$() const {\n" "$classname$::_internal_$name$() const {\n"
" return $name$_.GetMap();\n" " return $field$.GetMap();\n"
"}\n" "}\n"
"inline const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n" "inline const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n"
"$classname$::$name$() const {\n" "$classname$::$name$() const {\n"
@ -136,7 +136,7 @@ void MapFieldGenerator::GenerateInlineAccessorDefinitions(
"}\n" "}\n"
"inline ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n" "inline ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n"
"$classname$::_internal_mutable_$name$() {\n" "$classname$::_internal_mutable_$name$() {\n"
" return $name$_.MutableMap();\n" " return $field$.MutableMap();\n"
"}\n" "}\n"
"inline ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n" "inline ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n"
"$classname$::mutable_$name$() {\n" "$classname$::mutable_$name$() {\n"
@ -148,17 +148,17 @@ void MapFieldGenerator::GenerateInlineAccessorDefinitions(
void MapFieldGenerator::GenerateClearingCode(io::Printer* printer) const { void MapFieldGenerator::GenerateClearingCode(io::Printer* printer) const {
Formatter format(printer, variables_); Formatter format(printer, variables_);
format("$name$_.Clear();\n"); format("$field$.Clear();\n");
} }
void MapFieldGenerator::GenerateMergingCode(io::Printer* printer) const { void MapFieldGenerator::GenerateMergingCode(io::Printer* printer) const {
Formatter format(printer, variables_); Formatter format(printer, variables_);
format("$name$_.MergeFrom(from.$name$_);\n"); format("$field$.MergeFrom(from.$field$);\n");
} }
void MapFieldGenerator::GenerateSwappingCode(io::Printer* printer) const { void MapFieldGenerator::GenerateSwappingCode(io::Printer* printer) const {
Formatter format(printer, variables_); Formatter format(printer, variables_);
format("$name$_.InternalSwap(&other->$name$_);\n"); format("$field$.InternalSwap(&other->$field$);\n");
} }
void MapFieldGenerator::GenerateCopyConstructorCode( void MapFieldGenerator::GenerateCopyConstructorCode(
@ -267,7 +267,7 @@ void MapFieldGenerator::GenerateIsInitialized(io::Printer* printer) const {
Formatter format(printer, variables_); Formatter format(printer, variables_);
format( format(
"if (!::$proto_ns$::internal::AllAreInitialized($name$_)) return " "if (!::$proto_ns$::internal::AllAreInitialized($field$)) return "
"false;\n"); "false;\n");
} }
@ -285,7 +285,7 @@ void MapFieldGenerator::GenerateDestructorCode(io::Printer* printer) const {
GOOGLE_CHECK(!IsFieldStripped(descriptor_, options_)); GOOGLE_CHECK(!IsFieldStripped(descriptor_, options_));
Formatter format(printer, variables_); Formatter format(printer, variables_);
format("$name$_.Destruct();\n"); format("$field$.Destruct();\n");
} }
void MapFieldGenerator::GenerateArenaDestructorCode( void MapFieldGenerator::GenerateArenaDestructorCode(
@ -296,7 +296,7 @@ void MapFieldGenerator::GenerateArenaDestructorCode(
Formatter format(printer, variables_); Formatter format(printer, variables_);
// _this is the object being destructed (we are inside a static method here). // _this is the object being destructed (we are inside a static method here).
format("_this->$name$_.Destruct();\n"); format("_this->$field$.Destruct();\n");
} }
ArenaDtorNeeds MapFieldGenerator::NeedsArenaDestructor() const { ArenaDtorNeeds MapFieldGenerator::NeedsArenaDestructor() const {

View File

@ -1082,7 +1082,7 @@ void MessageGenerator::GenerateSingularFieldHasBits(
// We maintain the invariant that for a submessage x, has_x() returning // We maintain the invariant that for a submessage x, has_x() returning
// true implies that x_ is not null. By giving this information to the // true implies that x_ is not null. By giving this information to the
// compiler, we allow it to eliminate unnecessary null checks later on. // compiler, we allow it to eliminate unnecessary null checks later on.
format(" PROTOBUF_ASSUME(!value || $name$_ != nullptr);\n"); format(" PROTOBUF_ASSUME(!value || $field$ != nullptr);\n");
} }
format( format(
@ -1097,13 +1097,13 @@ void MessageGenerator::GenerateSingularFieldHasBits(
if (IsLazy(field, options_, scc_analyzer_)) { if (IsLazy(field, options_, scc_analyzer_)) {
format( format(
"inline bool $classname$::_internal_has_$name$() const {\n" "inline bool $classname$::_internal_has_$name$() const {\n"
" return !$name$_.IsCleared();\n" " return !$field$.IsCleared();\n"
"}\n"); "}\n");
} else { } else {
format( format(
"inline bool $classname$::_internal_has_$name$() const {\n" "inline bool $classname$::_internal_has_$name$() const {\n"
" return this != internal_default_instance() " " return this != internal_default_instance() "
"&& $name$_ != nullptr;\n" "&& $field$ != nullptr;\n"
"}\n"); "}\n");
} }
format( format(
@ -2219,7 +2219,7 @@ std::pair<size_t, size_t> MessageGenerator::GenerateOffsets(
// Don't use the top bit because that is for unused fields. // Don't use the top bit because that is for unused fields.
format("::_pbi::kInvalidFieldOffsetTag"); format("::_pbi::kInvalidFieldOffsetTag");
} else { } else {
format("PROTOBUF_FIELD_OFFSET($classtype$, $1$_)", FieldName(field)); format("PROTOBUF_FIELD_OFFSET($classtype$, $1$)", FieldMemberName(field));
} }
// Some information about a field is in the pdproto profile. The profile is // Some information about a field is in the pdproto profile. The profile is
@ -2227,11 +2227,6 @@ std::pair<size_t, size_t> MessageGenerator::GenerateOffsets(
// offset of the field, so that the information is available when // offset of the field, so that the information is available when
// reflectively accessing the field at run time. // reflectively accessing the field at run time.
// //
// Embed whether the field is used to the MSB of the offset.
if (!IsFieldUsed(field, options_)) {
format(" | 0x80000000u // unused\n");
}
// Embed whether the field is eagerly verified lazy or inlined string to the // Embed whether the field is eagerly verified lazy or inlined string to the
// LSB of the offset. // LSB of the offset.
if (IsEagerlyVerifiedLazy(field, options_, scc_analyzer_)) { if (IsEagerlyVerifiedLazy(field, options_, scc_analyzer_)) {
@ -2420,16 +2415,16 @@ void MessageGenerator::GenerateConstructorBody(io::Printer* printer,
std::string pod_template; std::string pod_template;
if (copy_constructor) { if (copy_constructor) {
pod_template = pod_template =
"::memcpy(&$first$_, &from.$first$_,\n" "::memcpy(&$first$, &from.$first$,\n"
" static_cast<size_t>(reinterpret_cast<char*>(&$last$_) -\n" " static_cast<size_t>(reinterpret_cast<char*>(&$last$) -\n"
" reinterpret_cast<char*>(&$first$_)) + sizeof($last$_));\n"; " reinterpret_cast<char*>(&$first$)) + sizeof($last$));\n";
} else { } else {
pod_template = pod_template =
"::memset(reinterpret_cast<char*>(this) + static_cast<size_t>(\n" "::memset(reinterpret_cast<char*>(this) + static_cast<size_t>(\n"
" reinterpret_cast<char*>(&$first$_) - " " reinterpret_cast<char*>(&$first$) - "
"reinterpret_cast<char*>(this)),\n" "reinterpret_cast<char*>(this)),\n"
" 0, static_cast<size_t>(reinterpret_cast<char*>(&$last$_) -\n" " 0, static_cast<size_t>(reinterpret_cast<char*>(&$last$) -\n"
" reinterpret_cast<char*>(&$first$_)) + sizeof($last$_));\n"; " reinterpret_cast<char*>(&$first$)) + sizeof($last$));\n";
} }
for (int i = 0; i < optimized_order_.size(); ++i) { for (int i = 0; i < optimized_order_.size(); ++i) {
@ -2445,9 +2440,9 @@ void MessageGenerator::GenerateConstructorBody(io::Printer* printer,
if (it != runs.end() && it->second > 1) { if (it != runs.end() && it->second > 1) {
// Use a memset, then skip run_length fields. // Use a memset, then skip run_length fields.
const size_t run_length = it->second; const size_t run_length = it->second;
const std::string first_field_name = FieldName(field); const std::string first_field_name = FieldMemberName(field);
const std::string last_field_name = const std::string last_field_name =
FieldName(optimized_order_[i + run_length - 1]); FieldMemberName(optimized_order_[i + run_length - 1]);
format.Set("first", first_field_name); format.Set("first", first_field_name);
format.Set("last", last_field_name); format.Set("last", last_field_name);
@ -2814,10 +2809,10 @@ void MessageGenerator::GenerateClear(io::Printer* printer) {
.GenerateMessageClearingCode(printer); .GenerateMessageClearingCode(printer);
} else { } else {
format( format(
"::memset(&$1$_, 0, static_cast<size_t>(\n" "::memset(&$1$, 0, static_cast<size_t>(\n"
" reinterpret_cast<char*>(&$2$_) -\n" " reinterpret_cast<char*>(&$2$) -\n"
" reinterpret_cast<char*>(&$1$_)) + sizeof($2$_));\n", " reinterpret_cast<char*>(&$1$)) + sizeof($2$));\n",
FieldName(memset_start), FieldName(memset_end)); FieldMemberName(memset_start), FieldMemberName(memset_end));
} }
} }
@ -2971,20 +2966,20 @@ void MessageGenerator::GenerateSwap(io::Printer* printer) {
if (it != runs.end() && it->second > 1) { if (it != runs.end() && it->second > 1) {
// Use a memswap, then skip run_length fields. // Use a memswap, then skip run_length fields.
const size_t run_length = it->second; const size_t run_length = it->second;
const std::string first_field_name = FieldName(field); const std::string first_field_name = FieldMemberName(field);
const std::string last_field_name = const std::string last_field_name =
FieldName(optimized_order_[i + run_length - 1]); FieldMemberName(optimized_order_[i + run_length - 1]);
format.Set("first", first_field_name); format.Set("first", first_field_name);
format.Set("last", last_field_name); format.Set("last", last_field_name);
format( format(
"::PROTOBUF_NAMESPACE_ID::internal::memswap<\n" "::PROTOBUF_NAMESPACE_ID::internal::memswap<\n"
" PROTOBUF_FIELD_OFFSET($classname$, $last$_)\n" " PROTOBUF_FIELD_OFFSET($classname$, $last$)\n"
" + sizeof($classname$::$last$_)\n" " + sizeof($classname$::$last$)\n"
" - PROTOBUF_FIELD_OFFSET($classname$, $first$_)>(\n" " - PROTOBUF_FIELD_OFFSET($classname$, $first$)>(\n"
" reinterpret_cast<char*>(&$first$_),\n" " reinterpret_cast<char*>(&$first$),\n"
" reinterpret_cast<char*>(&other->$first$_));\n"); " reinterpret_cast<char*>(&other->$first$));\n");
i += run_length - 1; i += run_length - 1;
// ++i at the top of the loop. // ++i at the top of the loop.

View File

@ -61,10 +61,10 @@ void SetMessageVariables(const FieldDescriptor* descriptor,
SetCommonFieldVariables(descriptor, variables, options); SetCommonFieldVariables(descriptor, variables, options);
(*variables)["type"] = FieldMessageTypeName(descriptor, options); (*variables)["type"] = FieldMessageTypeName(descriptor, options);
(*variables)["casted_member"] = ReinterpretCast( (*variables)["casted_member"] = ReinterpretCast(
(*variables)["type"] + "*", (*variables)["field_member"], implicit_weak); (*variables)["type"] + "*", (*variables)["field"], implicit_weak);
(*variables)["casted_member_const"] = (*variables)["casted_member_const"] =
ReinterpretCast("const " + (*variables)["type"] + "&", ReinterpretCast("const " + (*variables)["type"] + "&",
"*" + (*variables)["field_member"], implicit_weak); "*" + (*variables)["field"], implicit_weak);
(*variables)["type_default_instance"] = (*variables)["type_default_instance"] =
QualifiedDefaultInstanceName(descriptor->message_type(), options); QualifiedDefaultInstanceName(descriptor->message_type(), options);
(*variables)["type_default_instance_ptr"] = ReinterpretCast( (*variables)["type_default_instance_ptr"] = ReinterpretCast(
@ -435,7 +435,7 @@ void MessageFieldGenerator::GenerateDestructorCode(io::Printer* printer) const {
// care when handling them. // care when handling them.
format("if (this != internal_default_instance()) "); format("if (this != internal_default_instance()) ");
} }
format("delete $name$_;\n"); format("delete $field$;\n");
} }
void MessageFieldGenerator::GenerateConstructorCode( void MessageFieldGenerator::GenerateConstructorCode(
@ -443,7 +443,7 @@ void MessageFieldGenerator::GenerateConstructorCode(
GOOGLE_CHECK(!IsFieldStripped(descriptor_, options_)); GOOGLE_CHECK(!IsFieldStripped(descriptor_, options_));
Formatter format(printer, variables_); Formatter format(printer, variables_);
format("$name$_ = nullptr;\n"); format("$field$ = nullptr;\n");
} }
void MessageFieldGenerator::GenerateCopyConstructorCode( void MessageFieldGenerator::GenerateCopyConstructorCode(
@ -453,9 +453,9 @@ void MessageFieldGenerator::GenerateCopyConstructorCode(
Formatter format(printer, variables_); Formatter format(printer, variables_);
format( format(
"if (from._internal_has_$name$()) {\n" "if (from._internal_has_$name$()) {\n"
" $name$_ = new $type$(*from.$name$_);\n" " $field$ = new $type$(*from.$field$);\n"
"} else {\n" "} else {\n"
" $name$_ = nullptr;\n" " $field$ = nullptr;\n"
"}\n"); "}\n");
} }

View File

@ -87,11 +87,12 @@ bool IsFieldEligibleForFastParsing(
IsLazy(field, options, scc_analyzer)) { IsLazy(field, options, scc_analyzer)) {
return false; return false;
} }
switch (field->type()) {
// Groups are not handled on the fast path.
case FieldDescriptor::TYPE_GROUP:
return false;
// We will check for a valid auxiliary index range later. However, we might
// want to change the value we check for inlined string fields.
int aux_idx = entry.aux_idx;
switch (field->type()) {
case FieldDescriptor::TYPE_ENUM: case FieldDescriptor::TYPE_ENUM:
// If enum values are not validated at parse time, then this field can be // If enum values are not validated at parse time, then this field can be
// handled on the fast path like an int32. // handled on the fast path like an int32.
@ -106,11 +107,16 @@ bool IsFieldEligibleForFastParsing(
// Some bytes fields can be handled on fast path. // Some bytes fields can be handled on fast path.
case FieldDescriptor::TYPE_STRING: case FieldDescriptor::TYPE_STRING:
case FieldDescriptor::TYPE_BYTES: case FieldDescriptor::TYPE_BYTES:
if (field->options().ctype() != FieldOptions::STRING || if (field->options().ctype() != FieldOptions::STRING) {
!field->default_value_string().empty() ||
IsStringInlined(field, options)) {
return false; return false;
} }
if (IsStringInlined(field, options)) {
GOOGLE_CHECK(!field->is_repeated());
// For inlined strings, the donation state index is stored in the
// `aux_idx` field of the fast parsing info. We need to check the range
// of that value instead of the auxiliary index.
aux_idx = entry.inlined_string_idx;
}
break; break;
default: default:
@ -126,7 +132,7 @@ bool IsFieldEligibleForFastParsing(
// If the field needs auxiliary data, then the aux index is needed. This // If the field needs auxiliary data, then the aux index is needed. This
// must fit in a uint8_t. // must fit in a uint8_t.
if (entry.aux_idx > std::numeric_limits<uint8_t>::max()) { if (aux_idx > std::numeric_limits<uint8_t>::max()) {
return false; return false;
} }
@ -187,8 +193,13 @@ std::vector<TailCallTableInfo::FastFieldInfo> SplitFastFieldsForSize(
// If this field does not have presence, then it can set an out-of-bounds // If this field does not have presence, then it can set an out-of-bounds
// bit (tailcall parsing uses a uint64_t for hasbits, but only stores 32). // bit (tailcall parsing uses a uint64_t for hasbits, but only stores 32).
info.hasbit_idx = HasHasbit(field) ? entry.hasbit_idx : 63; info.hasbit_idx = HasHasbit(field) ? entry.hasbit_idx : 63;
if (IsStringInlined(field, options)) {
GOOGLE_CHECK(!field->is_repeated());
info.aux_idx = static_cast<uint8_t>(entry.inlined_string_idx);
} else {
info.aux_idx = static_cast<uint8_t>(entry.aux_idx); info.aux_idx = static_cast<uint8_t>(entry.aux_idx);
} }
}
return result; return result;
} }
@ -244,6 +255,7 @@ std::vector<const FieldDescriptor*> FilterMiniParsedFields(
break; break;
case FieldDescriptor::TYPE_MESSAGE: case FieldDescriptor::TYPE_MESSAGE:
case FieldDescriptor::TYPE_GROUP:
// TODO(b/210762816): support remaining field types. // TODO(b/210762816): support remaining field types.
if (field->is_map() || IsWeak(field, options) || if (field->is_map() || IsWeak(field, options) ||
IsImplicitWeakField(field, options, scc_analyzer) || IsImplicitWeakField(field, options, scc_analyzer) ||
@ -269,7 +281,9 @@ std::vector<const FieldDescriptor*> FilterMiniParsedFields(
TailCallTableInfo::TailCallTableInfo( TailCallTableInfo::TailCallTableInfo(
const Descriptor* descriptor, const Options& options, const Descriptor* descriptor, const Options& options,
const std::vector<const FieldDescriptor*>& ordered_fields, const std::vector<const FieldDescriptor*>& ordered_fields,
const std::vector<int>& has_bit_indices, MessageSCCAnalyzer* scc_analyzer) { const std::vector<int>& has_bit_indices,
const std::vector<int>& inlined_string_indices,
MessageSCCAnalyzer* scc_analyzer) {
int oneof_count = descriptor->real_oneof_decl_count(); int oneof_count = descriptor->real_oneof_decl_count();
// If this message has any oneof fields, store the case offset in the first // If this message has any oneof fields, store the case offset in the first
// auxiliary entry. // auxiliary entry.
@ -280,6 +294,22 @@ TailCallTableInfo::TailCallTableInfo(
aux_entries.push_back(StrCat( aux_entries.push_back(StrCat(
"_fl::Offset{offsetof(", ClassName(descriptor), ", _oneof_case_)}")); "_fl::Offset{offsetof(", ClassName(descriptor), ", _oneof_case_)}"));
} }
int inlined_string_count = 0;
for (const FieldDescriptor* field : ordered_fields) {
if (IsString(field, options) && IsStringInlined(field, options)) {
++inlined_string_count;
}
}
// If this message has any inlined string fields, store the donation state
// offset in the second auxiliary entry.
if (inlined_string_count > 0) {
aux_entries.resize(2); // pad if necessary
aux_entries[1] =
StrCat("_fl::Offset{offsetof(", ClassName(descriptor),
", _inlined_string_donated_)}");
}
// Fill in mini table entries. // Fill in mini table entries.
for (const FieldDescriptor* field : ordered_fields) { for (const FieldDescriptor* field : ordered_fields) {
field_entries.push_back( field_entries.push_back(
@ -344,6 +374,19 @@ TailCallTableInfo::TailCallTableInfo(
aux_entries.push_back( aux_entries.push_back(
StrCat(QualifiedClassName(enum_type, options), "_IsValid")); StrCat(QualifiedClassName(enum_type, options), "_IsValid"));
} }
} else if ((field->type() == FieldDescriptor::TYPE_STRING ||
field->type() == FieldDescriptor::TYPE_BYTES) &&
IsStringInlined(field, options)) {
GOOGLE_CHECK(!field->is_repeated());
// Inlined strings have an extra marker to represent their donation state.
int idx = inlined_string_indices[field->index()];
// For mini parsing, the donation state index is stored as an `offset`
// auxiliary entry.
entry.aux_idx = aux_entries.size();
aux_entries.push_back(StrCat("_fl::Offset{", idx, "}"));
// For fast table parsing, the donation state index is stored instead of
// the aux_idx (this will limit the range to 8 bits).
entry.inlined_string_idx = idx;
} }
} }
@ -403,7 +446,8 @@ ParseFunctionGenerator::ParseFunctionGenerator(
num_hasbits_(max_has_bit_index) { num_hasbits_(max_has_bit_index) {
if (should_generate_tctable()) { if (should_generate_tctable()) {
tc_table_info_.reset(new TailCallTableInfo( tc_table_info_.reset(new TailCallTableInfo(
descriptor_, options_, ordered_fields_, has_bit_indices, scc_analyzer)); descriptor_, options_, ordered_fields_, has_bit_indices,
inlined_string_indices, scc_analyzer));
} }
SetCommonVars(options_, &variables_); SetCommonVars(options_, &variables_);
SetUnknownFieldsVariable(descriptor_, options_, &variables_); SetUnknownFieldsVariable(descriptor_, options_, &variables_);
@ -528,6 +572,31 @@ void ParseFunctionGenerator::GenerateTailcallFallbackFunction(
"}\n"); "}\n");
} }
struct SkipEntry16 {
uint16_t skipmap;
uint16_t field_entry_offset;
};
struct SkipEntryBlock {
uint32_t first_fnum;
std::vector<SkipEntry16> entries;
};
struct NumToEntryTable {
uint32_t skipmap32; // for fields #1 - #32
std::vector<SkipEntryBlock> blocks;
// Compute the number of uint16_t required to represent this table.
int size16() const {
int size = 2; // for the termination field#
for (const auto& block : blocks) {
// 2 for the field#, 1 for a count of skip entries, 2 for each entry.
size += 3 + block.entries.size() * 2;
}
return size;
}
};
static NumToEntryTable MakeNumToEntryTable(
const std::vector<const FieldDescriptor*>& field_descriptors);
void ParseFunctionGenerator::GenerateDataDecls(io::Printer* printer) { void ParseFunctionGenerator::GenerateDataDecls(io::Printer* printer) {
if (!should_generate_tctable()) { if (!should_generate_tctable()) {
return; return;
@ -538,11 +607,13 @@ void ParseFunctionGenerator::GenerateDataDecls(io::Printer* printer) {
format("#ifdef PROTOBUF_TAIL_CALL_TABLE_PARSER_ENABLED\n"); format("#ifdef PROTOBUF_TAIL_CALL_TABLE_PARSER_ENABLED\n");
format.Indent(); format.Indent();
} }
auto field_num_to_entry_table = MakeNumToEntryTable(ordered_fields_);
format( format(
"static const ::$proto_ns$::internal::TcParseTable<$1$, $2$, $3$, $4$> " "static const ::$proto_ns$::internal::"
"_table_;\n", "TcParseTable<$1$, $2$, $3$, $4$, $5$> _table_;\n",
tc_table_info_->table_size_log2, ordered_fields_.size(), tc_table_info_->table_size_log2, ordered_fields_.size(),
tc_table_info_->aux_entries.size(), CalculateFieldNamesSize()); tc_table_info_->aux_entries.size(), CalculateFieldNamesSize(),
field_num_to_entry_table.size16());
if (should_generate_guarded_tctable()) { if (should_generate_guarded_tctable()) {
format.Outdent(); format.Outdent();
format("#endif // PROTOBUF_TAIL_CALL_TABLE_PARSER_ENABLED\n"); format("#endif // PROTOBUF_TAIL_CALL_TABLE_PARSER_ENABLED\n");
@ -610,6 +681,68 @@ void ParseFunctionGenerator::GenerateLoopingParseFunction(Formatter& format) {
"}\n"); "}\n");
} }
static NumToEntryTable MakeNumToEntryTable(
const std::vector<const FieldDescriptor*>& field_descriptors) {
NumToEntryTable num_to_entry_table;
num_to_entry_table.skipmap32 = static_cast<uint32_t>(-1);
// skip_entry_block is the current block of SkipEntries that we're
// appending to. cur_block_first_fnum is the number of the first
// field represented by the block.
uint16_t field_entry_index = 0;
uint16_t N = field_descriptors.size();
// First, handle field numbers 1-32, which affect only the initial
// skipmap32 and don't generate additional skip-entry blocks.
for (; field_entry_index != N; ++field_entry_index) {
auto* field_descriptor = field_descriptors[field_entry_index];
if (field_descriptor->number() > 32) break;
auto skipmap32_index = field_descriptor->number() - 1;
num_to_entry_table.skipmap32 -= 1 << skipmap32_index;
}
// If all the field numbers were less than or equal to 32, we will have
// no further entries to process, and we are already done.
if (field_entry_index == N) return num_to_entry_table;
SkipEntryBlock* block = nullptr;
bool start_new_block = true;
// To determine sparseness, track the field number corresponding to
// the start of the most recent skip entry.
uint32_t last_skip_entry_start = 0;
for (; field_entry_index != N; ++field_entry_index) {
auto* field_descriptor = field_descriptors[field_entry_index];
uint32_t fnum = field_descriptor->number();
GOOGLE_CHECK_GT(fnum, last_skip_entry_start);
if (start_new_block == false) {
// If the next field number is within 15 of the last_skip_entry_start, we
// continue writing just to that entry. If it's between 16 and 31 more,
// then we just extend the current block by one. If it's more than 31
// more, we have to add empty skip entries in order to continue using the
// existing block. Obviously it's just 32 more, it doesn't make sense to
// start a whole new block, since new blocks mean having to write out
// their starting field number, which is 32 bits, as well as the size of
// the additional block, which is 16... while an empty SkipEntry16 only
// costs 32 bits. So if it was 48 more, it's a slight space win; we save
// 16 bits, but probably at the cost of slower run time. We're choosing
// 96 for now.
if (fnum - last_skip_entry_start > 96) start_new_block = true;
}
if (start_new_block) {
num_to_entry_table.blocks.push_back(SkipEntryBlock{fnum});
block = &num_to_entry_table.blocks.back();
start_new_block = false;
}
auto skip_entry_num = (fnum - block->first_fnum) / 16;
auto skip_entry_index = (fnum - block->first_fnum) % 16;
while (skip_entry_num >= block->entries.size())
block->entries.push_back({0xFFFF, field_entry_index});
block->entries[skip_entry_num].skipmap -= 1 << (skip_entry_index);
last_skip_entry_start = fnum - skip_entry_index;
}
return num_to_entry_table;
}
void ParseFunctionGenerator::GenerateTailCallTable(Formatter& format) { void ParseFunctionGenerator::GenerateTailCallTable(Formatter& format) {
GOOGLE_CHECK(should_generate_tctable()); GOOGLE_CHECK(should_generate_tctable());
// All entries without a fast-path parsing function need a fallback. // All entries without a fast-path parsing function need a fallback.
@ -631,12 +764,15 @@ void ParseFunctionGenerator::GenerateTailCallTable(Formatter& format) {
// maps, weak fields, lazy, more than 1 extension range. In the cases // maps, weak fields, lazy, more than 1 extension range. In the cases
// the table is sufficient we can use a generic routine, that just handles // the table is sufficient we can use a generic routine, that just handles
// unknown fields and potentially an extension range. // unknown fields and potentially an extension range.
auto field_num_to_entry_table = MakeNumToEntryTable(ordered_fields_);
format( format(
"PROTOBUF_ATTRIBUTE_INIT_PRIORITY1\n" "PROTOBUF_ATTRIBUTE_INIT_PRIORITY1\n"
"const ::_pbi::TcParseTable<$1$, $2$, $3$, $4$> $classname$::_table_ = " "const ::_pbi::TcParseTable<$1$, $2$, $3$, $4$, $5$> "
"$classname$::_table_ = "
"{\n", "{\n",
tc_table_info_->table_size_log2, ordered_fields_.size(), tc_table_info_->table_size_log2, ordered_fields_.size(),
tc_table_info_->aux_entries.size(), CalculateFieldNamesSize()); tc_table_info_->aux_entries.size(), CalculateFieldNamesSize(),
field_num_to_entry_table.size16());
{ {
auto table_scope = format.ScopedIndent(); auto table_scope = format.ScopedIndent();
format("{\n"); format("{\n");
@ -659,28 +795,16 @@ void ParseFunctionGenerator::GenerateTailCallTable(Formatter& format) {
format("$1$, $2$, // max_field_number, fast_idx_mask\n", format("$1$, $2$, // max_field_number, fast_idx_mask\n",
(ordered_fields_.empty() ? 0 : ordered_fields_.back()->number()), (ordered_fields_.empty() ? 0 : ordered_fields_.back()->number()),
(((1 << tc_table_info_->table_size_log2) - 1) << 3)); (((1 << tc_table_info_->table_size_log2) - 1) << 3));
format(
// Determine the sequential fields that can be looked up by index: "offsetof(decltype(_table_), field_lookup_table),\n"
uint16_t num_sequential_fields = 0; "$1$, // skipmap\n",
uint16_t sequential_fields_start = 0; field_num_to_entry_table.skipmap32);
if (!ordered_fields_.empty() && if (ordered_fields_.empty()) {
ordered_fields_.front()->number() <= format(
std::numeric_limits<uint16_t>::max()) { "offsetof(decltype(_table_), field_names), // no field_entries\n");
sequential_fields_start = ordered_fields_[0]->number(); } else {
const FieldDescriptor* previous_field = ordered_fields_[0]; format("offsetof(decltype(_table_), field_entries),\n");
const int N = std::min(ordered_fields_.size(),
size_t{std::numeric_limits<uint8_t>::max()} + 1);
for (int i = 1; i < N; ++i) {
const FieldDescriptor* current_field = ordered_fields_[i];
if (current_field->number() > previous_field->number() + 1) {
break;
} }
++num_sequential_fields;
previous_field = current_field;
}
}
format("$1$, $2$, // num_sequential_fields, sequential_fields_start\n",
num_sequential_fields, sequential_fields_start);
format( format(
"$1$, // num_field_entries\n" "$1$, // num_field_entries\n"
@ -704,32 +828,41 @@ void ParseFunctionGenerator::GenerateTailCallTable(Formatter& format) {
auto fast_scope = format.ScopedIndent(); auto fast_scope = format.ScopedIndent();
GenerateFastFieldEntries(format); GenerateFastFieldEntries(format);
} }
format("}}, {{\n");
{
// field_lookup_table[]
auto field_lookup_scope = format.ScopedIndent();
int line_entries = 0;
for (int i = 0, N = field_num_to_entry_table.blocks.size(); i < N; ++i) {
SkipEntryBlock& entry_block = field_num_to_entry_table.blocks[i];
format("$1$, $2$, $3$,\n", entry_block.first_fnum & 65535,
entry_block.first_fnum / 65536, entry_block.entries.size());
for (auto se16 : entry_block.entries) {
if (line_entries == 0) {
format("$1$, $2$,", se16.skipmap, se16.field_entry_offset);
++line_entries;
} else if (line_entries < 5) {
format(" $1$, $2$,", se16.skipmap, se16.field_entry_offset);
++line_entries;
} else {
format(" $1$, $2$,\n", se16.skipmap, se16.field_entry_offset);
line_entries = 0;
}
}
}
if (line_entries) format("\n");
format("65535, 65535\n");
}
if (ordered_fields_.empty()) { if (ordered_fields_.empty()) {
GOOGLE_LOG_IF(DFATAL, !tc_table_info_->aux_entries.empty()) GOOGLE_LOG_IF(DFATAL, !tc_table_info_->aux_entries.empty())
<< "Invalid message: " << descriptor_->full_name() << " has " << "Invalid message: " << descriptor_->full_name() << " has "
<< tc_table_info_->aux_entries.size() << tc_table_info_->aux_entries.size()
<< " auxiliary field entries, but no fields"; << " auxiliary field entries, but no fields";
format("}},\n" format(
"// no field_numbers, field_entries, or aux_entries\n" "}},\n"
"// no field_entries, or aux_entries\n"
"{{\n"); "{{\n");
} else { } else {
format("}}, {{\n");
{
// field_numbers[]
auto field_number_scope = format.ScopedIndent();
for (int i = 0, N = ordered_fields_.size(); i < N; ++i) {
const FieldDescriptor* field = ordered_fields_[i];
if (i > 0) {
if (i % 10 == 0) {
format(",\n");
} else {
format(", ");
}
}
format("$1$", field->number());
}
format("\n");
}
format("}}, {{\n"); format("}}, {{\n");
{ {
// field_entries[] // field_entries[]
@ -755,7 +888,7 @@ void ParseFunctionGenerator::GenerateTailCallTable(Formatter& format) {
} // ordered_fields_.empty() } // ordered_fields_.empty()
{ {
// field_names[] // field_names[]
auto field_scope = format.ScopedIndent(); auto field_name_scope = format.ScopedIndent();
GenerateFieldNames(format); GenerateFieldNames(format);
} }
format("}},\n"); format("}},\n");
@ -991,12 +1124,6 @@ void ParseFunctionGenerator::GenerateArenaString(Formatter& format,
if (HasHasbit(field)) { if (HasHasbit(field)) {
format("_Internal::set_has_$1$(&$has_bits$);\n", FieldName(field)); format("_Internal::set_has_$1$(&$has_bits$);\n", FieldName(field));
} }
std::string default_string =
field->default_value_string().empty()
? "::" + ProtobufNamespace(options_) +
"::internal::GetEmptyStringAlreadyInited()"
: QualifiedClassName(field->containing_type(), options_) +
"::" + MakeDefaultName(field) + ".get()";
format( format(
"if (arena != nullptr) {\n" "if (arena != nullptr) {\n"
" ptr = ctx->ReadArenaString(ptr, &$msg$$field$, arena"); " ptr = ctx->ReadArenaString(ptr, &$msg$$field$, arena");
@ -1004,13 +1131,8 @@ void ParseFunctionGenerator::GenerateArenaString(Formatter& format,
GOOGLE_DCHECK(!inlined_string_indices_.empty()); GOOGLE_DCHECK(!inlined_string_indices_.empty());
int inlined_string_index = inlined_string_indices_[field->index()]; int inlined_string_index = inlined_string_indices_[field->index()];
GOOGLE_DCHECK_GT(inlined_string_index, 0); GOOGLE_DCHECK_GT(inlined_string_index, 0);
format( format(", &$msg$_inlined_string_donated_[0], $1$, $this$",
", $msg$_internal_$name$_donated()" inlined_string_index);
", &$msg$_inlined_string_donated_[$1$]"
", ~0x$2$u"
", $this$",
inlined_string_index / 32,
strings::Hex(1u << (inlined_string_index % 32), strings::ZERO_PAD_8));
} else { } else {
GOOGLE_DCHECK(field->default_value_string().empty()); GOOGLE_DCHECK(field->default_value_string().empty());
} }
@ -1305,7 +1427,7 @@ void ParseFunctionGenerator::GenerateFieldBody(
format("_Internal::set_has_$name$(&$has_bits$);\n"); format("_Internal::set_has_$name$(&$has_bits$);\n");
} }
format( format(
"$msg$$name$_ = " "$msg$$field$ = "
"::$proto_ns$::internal::UnalignedLoad<$primitive_type$>(ptr);\n" "::$proto_ns$::internal::UnalignedLoad<$primitive_type$>(ptr);\n"
"ptr += sizeof($primitive_type$);\n"); "ptr += sizeof($primitive_type$);\n");
} }
@ -1433,7 +1555,6 @@ void ParseFunctionGenerator::GenerateFieldSwitch(
format.Indent(); format.Indent();
for (const auto* field : fields) { for (const auto* field : fields) {
// Set abbreviated form instead of field_member.
format.Set("field", FieldMemberName(field)); format.Set("field", FieldMemberName(field));
PrintFieldComment(format, field); PrintFieldComment(format, field);
format("case $1$:\n", field->number()); format("case $1$:\n", field->number());
@ -1539,6 +1660,9 @@ std::string FieldParseFunctionName(
case FieldDescriptor::TYPE_BYTES: case FieldDescriptor::TYPE_BYTES:
name.append("B"); name.append("B");
if (IsStringInlined(field, options)) {
name.append("i");
}
break; break;
case FieldDescriptor::TYPE_STRING: case FieldDescriptor::TYPE_STRING:
switch (GetUtf8CheckMode(field, options)) { switch (GetUtf8CheckMode(field, options)) {
@ -1556,11 +1680,17 @@ std::string FieldParseFunctionName(
<< static_cast<int>(GetUtf8CheckMode(field, options)); << static_cast<int>(GetUtf8CheckMode(field, options));
return ""; return "";
} }
if (IsStringInlined(field, options)) {
name.append("i");
}
break; break;
case FieldDescriptor::TYPE_MESSAGE: case FieldDescriptor::TYPE_MESSAGE:
name.append("M"); name.append("M");
break; break;
case FieldDescriptor::TYPE_GROUP:
name.append("G");
break;
default: default:
GOOGLE_LOG(DFATAL) << "Type not handled: " << field->DebugString(); GOOGLE_LOG(DFATAL) << "Type not handled: " << field->DebugString();

View File

@ -51,6 +51,7 @@ struct TailCallTableInfo {
TailCallTableInfo(const Descriptor* descriptor, const Options& options, TailCallTableInfo(const Descriptor* descriptor, const Options& options,
const std::vector<const FieldDescriptor*>& ordered_fields, const std::vector<const FieldDescriptor*>& ordered_fields,
const std::vector<int>& has_bit_indices, const std::vector<int>& has_bit_indices,
const std::vector<int>& inlined_string_indices,
MessageSCCAnalyzer* scc_analyzer); MessageSCCAnalyzer* scc_analyzer);
// Fields parsed by the table fast-path. // Fields parsed by the table fast-path.
@ -67,6 +68,7 @@ struct TailCallTableInfo {
struct FieldEntryInfo { struct FieldEntryInfo {
const FieldDescriptor* field; const FieldDescriptor* field;
int hasbit_idx; int hasbit_idx;
int inlined_string_idx;
uint16_t aux_idx; uint16_t aux_idx;
// True for enums entirely covered by the start/length fields of FieldAux: // True for enums entirely covered by the start/length fields of FieldAux:
bool is_enum_range; bool is_enum_range;

View File

@ -186,13 +186,13 @@ void PrimitiveFieldGenerator::GenerateSwappingCode(io::Printer* printer) const {
void PrimitiveFieldGenerator::GenerateConstructorCode( void PrimitiveFieldGenerator::GenerateConstructorCode(
io::Printer* printer) const { io::Printer* printer) const {
Formatter format(printer, variables_); Formatter format(printer, variables_);
format("$name$_ = $default$;\n"); format("$field$ = $default$;\n");
} }
void PrimitiveFieldGenerator::GenerateCopyConstructorCode( void PrimitiveFieldGenerator::GenerateCopyConstructorCode(
io::Printer* printer) const { io::Printer* printer) const {
Formatter format(printer, variables_); Formatter format(printer, variables_);
format("$name$_ = from.$name$_;\n"); format("$field$ = from.$field$;\n");
} }
void PrimitiveFieldGenerator::GenerateSerializeWithCachedSizesToArray( void PrimitiveFieldGenerator::GenerateSerializeWithCachedSizesToArray(
@ -286,7 +286,7 @@ void PrimitiveOneofFieldGenerator::GenerateSwappingCode(
void PrimitiveOneofFieldGenerator::GenerateConstructorCode( void PrimitiveOneofFieldGenerator::GenerateConstructorCode(
io::Printer* printer) const { io::Printer* printer) const {
Formatter format(printer, variables_); Formatter format(printer, variables_);
format("$ns$::_$classname$_default_instance_.$name$_ = $default$;\n"); format("$ns$::_$classname$_default_instance_.$field$ = $default$;\n");
} }
// =================================================================== // ===================================================================

View File

@ -58,8 +58,8 @@ void SetStringVariables(const FieldDescriptor* descriptor,
(*variables)["default"] = DefaultValue(options, descriptor); (*variables)["default"] = DefaultValue(options, descriptor);
(*variables)["default_length"] = (*variables)["default_length"] =
StrCat(descriptor->default_value_string().length()); StrCat(descriptor->default_value_string().length());
std::string default_variable_string = MakeDefaultName(descriptor); (*variables)["default_variable_name"] = MakeDefaultName(descriptor);
(*variables)["default_variable_name"] = default_variable_string; (*variables)["default_variable_field"] = MakeDefaultFieldName(descriptor);
if (descriptor->default_value_string().empty()) { if (descriptor->default_value_string().empty()) {
(*variables)["default_string"] = kNS + "GetEmptyStringAlreadyInited()"; (*variables)["default_string"] = kNS + "GetEmptyStringAlreadyInited()";
@ -67,8 +67,8 @@ void SetStringVariables(const FieldDescriptor* descriptor,
(*variables)["lazy_variable_args"] = ""; (*variables)["lazy_variable_args"] = "";
} else { } else {
(*variables)["lazy_variable"] = (*variables)["lazy_variable"] =
QualifiedClassName(descriptor->containing_type(), options) + StrCat(QualifiedClassName(descriptor->containing_type(), options),
"::" + default_variable_string; "::", MakeDefaultFieldName(descriptor));
(*variables)["default_string"] = (*variables)["lazy_variable"] + ".get()"; (*variables)["default_string"] = (*variables)["lazy_variable"] + ".get()";
(*variables)["default_value"] = "nullptr"; (*variables)["default_value"] = "nullptr";
@ -205,7 +205,7 @@ void StringFieldGenerator::GenerateInlineAccessorDefinitions(
if (!descriptor_->default_value_string().empty()) { if (!descriptor_->default_value_string().empty()) {
format( format(
" if ($field$.IsDefault()) return " " if ($field$.IsDefault()) return "
"$default_variable_name$.get();\n"); "$default_variable_field$.get();\n");
} }
format( format(
" return _internal_$name$();\n" " return _internal_$name$();\n"
@ -345,7 +345,7 @@ void StringFieldGenerator::GenerateNonInlineAccessorDefinitions(
if (!descriptor_->default_value_string().empty()) { if (!descriptor_->default_value_string().empty()) {
format( format(
"const ::$proto_ns$::internal::LazyString " "const ::$proto_ns$::internal::LazyString "
"$classname$::$default_variable_name$" "$classname$::$default_variable_field$"
"{{{$default$, $default_length$}}, {nullptr}};\n"); "{{{$default$, $default_length$}}, {nullptr}};\n");
} }
} }
@ -430,12 +430,12 @@ void StringFieldGenerator::GenerateConstructorCode(io::Printer* printer) const {
return; return;
} }
GOOGLE_DCHECK(!inlined_); GOOGLE_DCHECK(!inlined_);
format("$name$_.InitDefault();\n"); format("$field$.InitDefault();\n");
if (IsString(descriptor_, options_) && if (IsString(descriptor_, options_) &&
descriptor_->default_value_string().empty()) { descriptor_->default_value_string().empty()) {
format( format(
"#ifdef PROTOBUF_FORCE_COPY_DEFAULT_STRING\n" "#ifdef PROTOBUF_FORCE_COPY_DEFAULT_STRING\n"
" $name$_.Set(\"\", GetArenaForAllocation());\n" " $field$.Set(\"\", GetArenaForAllocation());\n"
"#endif // PROTOBUF_FORCE_COPY_DEFAULT_STRING\n"); "#endif // PROTOBUF_FORCE_COPY_DEFAULT_STRING\n");
} }
} }
@ -445,7 +445,7 @@ void StringFieldGenerator::GenerateCopyConstructorCode(
Formatter format(printer, variables_); Formatter format(printer, variables_);
GenerateConstructorCode(printer); GenerateConstructorCode(printer);
if (inlined_) { if (inlined_) {
format("new (&$name$_) ::$proto_ns$::internal::InlinedStringField();\n"); format("new (&$field$) ::_pbi::InlinedStringField();\n");
} }
if (HasHasbit(descriptor_)) { if (HasHasbit(descriptor_)) {

View File

@ -156,7 +156,7 @@ template <>
FieldGeneratorMap<ImmutableFieldLiteGenerator>::~FieldGeneratorMap(); FieldGeneratorMap<ImmutableFieldLiteGenerator>::~FieldGeneratorMap();
// Field information used in FieldGeneartors. // Field information used in FieldGenerators.
struct FieldGeneratorInfo { struct FieldGeneratorInfo {
std::string name; std::string name;
std::string capitalized_name; std::string capitalized_name;

View File

@ -39,6 +39,7 @@
#include <google/protobuf/io/printer.h> #include <google/protobuf/io/printer.h>
#include <google/protobuf/io/zero_copy_stream.h> #include <google/protobuf/io/zero_copy_stream.h>
#include <google/protobuf/stubs/stringprintf.h>
#include <google/protobuf/compiler/java/java_file.h> #include <google/protobuf/compiler/java/java_file.h>
#include <google/protobuf/compiler/java/java_generator_factory.h> #include <google/protobuf/compiler/java/java_generator_factory.h>
#include <google/protobuf/compiler/java/java_helpers.h> #include <google/protobuf/compiler/java/java_helpers.h>

View File

@ -40,9 +40,9 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include <google/protobuf/stubs/stringprintf.h>
#include <google/protobuf/wire_format.h> #include <google/protobuf/wire_format.h>
#include <google/protobuf/stubs/strutil.h> #include <google/protobuf/stubs/strutil.h>
#include <google/protobuf/stubs/stringprintf.h>
#include <google/protobuf/stubs/substitute.h> #include <google/protobuf/stubs/substitute.h>
#include <google/protobuf/compiler/java/java_name_resolver.h> #include <google/protobuf/compiler/java/java_name_resolver.h>
#include <google/protobuf/compiler/java/java_names.h> #include <google/protobuf/compiler/java/java_names.h>
@ -83,7 +83,8 @@ const char* kForbiddenWordList[] = {
"AllFields", "AllFields",
"DescriptorForType", "DescriptorForType",
"InitializationErrorString", "InitializationErrorString",
"UnknownFields", // TODO(b/219045204): re-enable
// "UnknownFields",
// obsolete. kept for backwards compatibility of generated code // obsolete. kept for backwards compatibility of generated code
"CachedSize", "CachedSize",
}; };

View File

@ -54,13 +54,13 @@
#include <google/protobuf/stubs/logging.h> #include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/common.h> #include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/stringprintf.h>
#include <google/protobuf/compiler/python/python_helpers.h> #include <google/protobuf/compiler/python/python_helpers.h>
#include <google/protobuf/compiler/python/python_pyi_generator.h> #include <google/protobuf/compiler/python/python_pyi_generator.h>
#include <google/protobuf/io/printer.h> #include <google/protobuf/io/printer.h>
#include <google/protobuf/io/zero_copy_stream.h> #include <google/protobuf/io/zero_copy_stream.h>
#include <google/protobuf/descriptor.h> #include <google/protobuf/descriptor.h>
#include <google/protobuf/stubs/strutil.h> #include <google/protobuf/stubs/strutil.h>
#include <google/protobuf/stubs/stringprintf.h>
#include <google/protobuf/stubs/substitute.h> #include <google/protobuf/stubs/substitute.h>
#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/descriptor.pb.h>

View File

@ -412,6 +412,10 @@ class FlatAllocatorImpl {
const auto push_name = [&](std::string new_name) { const auto push_name = [&](std::string new_name) {
for (size_t i = 0; i < names.size(); ++i) { for (size_t i = 0; i < names.size(); ++i) {
// Do not compare the full_name. It is unlikely to match, except in
// custom json_name. We are not taking this into account in
// PlanFieldNames so better to not try it.
if (i == 1) continue;
if (names[i] == new_name) return i; if (names[i] == new_name) return i;
} }
names.push_back(std::move(new_name)); names.push_back(std::move(new_name));

View File

@ -813,8 +813,8 @@ message SourceCodeInfo {
// location. // location.
// //
// Each element is a field number or an index. They form a path from // Each element is a field number or an index. They form a path from
// the root FileDescriptorProto to the place where the definition occurs. For // the root FileDescriptorProto to the place where the definition occurs.
// example, this path: // For example, this path:
// [ 4, 3, 2, 7, 1 ] // [ 4, 3, 2, 7, 1 ]
// refers to: // refers to:
// file.message_type(3) // 4, 3 // file.message_type(3) // 4, 3

View File

@ -865,6 +865,22 @@ TEST_F(DescriptorTest, FieldNamesDedup) {
ElementsAre("fieldname7")); ElementsAre("fieldname7"));
} }
TEST_F(DescriptorTest, FieldNameDedupJsonEqFull) {
// Test a regression where json_name == full_name
FileDescriptorProto proto;
proto.set_name("file");
auto* message = AddMessage(&proto, "Name1");
auto* field =
AddField(message, "Name2", 1, FieldDescriptorProto::LABEL_OPTIONAL,
FieldDescriptorProto::TYPE_INT32);
field->set_json_name("Name1.Name2");
auto* file = pool_.BuildFile(proto);
EXPECT_EQ(file->message_type(0)->name(), "Name1");
EXPECT_EQ(file->message_type(0)->field(0)->name(), "Name2");
EXPECT_EQ(file->message_type(0)->field(0)->full_name(), "Name1.Name2");
EXPECT_EQ(file->message_type(0)->field(0)->json_name(), "Name1.Name2");
}
TEST_F(DescriptorTest, FieldsByIndex) { TEST_F(DescriptorTest, FieldsByIndex) {
ASSERT_EQ(4, message_->field_count()); ASSERT_EQ(4, message_->field_count());
EXPECT_EQ(foo_, message_->field(0)); EXPECT_EQ(foo_, message_->field(0));

View File

@ -75,7 +75,6 @@ using google::protobuf::internal::RepeatedPtrFieldBase;
using google::protobuf::internal::StringSpaceUsedExcludingSelfLong; using google::protobuf::internal::StringSpaceUsedExcludingSelfLong;
using google::protobuf::internal::WrappedMutex; using google::protobuf::internal::WrappedMutex;
namespace google { namespace google {
namespace protobuf { namespace protobuf {

View File

@ -227,13 +227,6 @@ struct ReflectionSchema {
return false; return false;
} }
// Returns true if the field's accessor is called by any external code (aka,
// non proto library code).
bool IsFieldUsed(const FieldDescriptor* field) const {
(void)field;
return true;
}
bool IsFieldStripped(const FieldDescriptor* field) const { bool IsFieldStripped(const FieldDescriptor* field) const {
(void)field; (void)field;
return false; return false;
@ -271,9 +264,9 @@ struct ReflectionSchema {
if (type == FieldDescriptor::TYPE_MESSAGE || if (type == FieldDescriptor::TYPE_MESSAGE ||
type == FieldDescriptor::TYPE_STRING || type == FieldDescriptor::TYPE_STRING ||
type == FieldDescriptor::TYPE_BYTES) { type == FieldDescriptor::TYPE_BYTES) {
return v & 0x7FFFFFFEu; return v & 0xFFFFFFFEu;
} }
return v & 0x7FFFFFFFu; return v;
} }
static bool Inlined(uint32_t v, FieldDescriptor::Type type) { static bool Inlined(uint32_t v, FieldDescriptor::Type type) {

View File

@ -129,8 +129,9 @@ struct alignas(uint64_t) TcParseTableBase {
uint32_t extension_range_high; uint32_t extension_range_high;
uint32_t max_field_number; uint32_t max_field_number;
uint8_t fast_idx_mask; uint8_t fast_idx_mask;
uint8_t num_sequential_fields; uint16_t lookup_table_offset;
uint16_t sequential_fields_start; uint32_t skipmap32;
uint32_t field_entries_offset;
uint16_t num_field_entries; uint16_t num_field_entries;
uint16_t num_aux_entries; uint16_t num_aux_entries;
@ -150,8 +151,9 @@ struct alignas(uint64_t) TcParseTableBase {
uint16_t has_bits_offset, uint16_t extension_offset, uint16_t has_bits_offset, uint16_t extension_offset,
uint32_t extension_range_low, uint32_t extension_range_high, uint32_t extension_range_low, uint32_t extension_range_high,
uint32_t max_field_number, uint8_t fast_idx_mask, uint32_t max_field_number, uint8_t fast_idx_mask,
uint8_t num_sequential_fields, uint16_t sequential_fields_start, uint16_t lookup_table_offset, uint32_t skipmap32,
uint16_t num_field_entries, uint16_t num_aux_entries, uint32_t aux_offset, uint32_t field_entries_offset, uint16_t num_field_entries,
uint16_t num_aux_entries, uint32_t aux_offset,
const MessageLite* default_instance, TailCallParseFunc fallback) const MessageLite* default_instance, TailCallParseFunc fallback)
: has_bits_offset(has_bits_offset), : has_bits_offset(has_bits_offset),
extension_offset(extension_offset), extension_offset(extension_offset),
@ -159,8 +161,9 @@ struct alignas(uint64_t) TcParseTableBase {
extension_range_high(extension_range_high), extension_range_high(extension_range_high),
max_field_number(max_field_number), max_field_number(max_field_number),
fast_idx_mask(fast_idx_mask), fast_idx_mask(fast_idx_mask),
num_sequential_fields(num_sequential_fields), lookup_table_offset(lookup_table_offset),
sequential_fields_start(sequential_fields_start), skipmap32(skipmap32),
field_entries_offset(field_entries_offset),
num_field_entries(num_field_entries), num_field_entries(num_field_entries),
num_aux_entries(num_aux_entries), num_aux_entries(num_aux_entries),
aux_offset(aux_offset), aux_offset(aux_offset),
@ -179,16 +182,10 @@ struct alignas(uint64_t) TcParseTableBase {
return reinterpret_cast<const FastFieldEntry*>(this + 1) + idx; return reinterpret_cast<const FastFieldEntry*>(this + 1) + idx;
} }
// Returns a begin/end iterator (pointer) for the field numbers array. // Returns a begin iterator (pointer) to the start of the field lookup table.
// The field numbers are a parallel array to the `FieldEntry` array. Note that const uint16_t* field_lookup_begin() const {
// not all numbers may be valid fields; in these cases, the corresponding return reinterpret_cast<const uint16_t*>(reinterpret_cast<uintptr_t>(this) +
// field entry will have a field kind of `field_layout::kFkNone`. lookup_table_offset);
const uint32_t* field_numbers_begin() const {
return reinterpret_cast<const uint32_t*>(
fast_entry((fast_idx_mask >> 3) + 1));
}
const uint32_t* field_numbers_end() const {
return field_numbers_begin() + num_field_entries;
} }
// Field entry for all fields. // Field entry for all fields.
@ -201,7 +198,8 @@ struct alignas(uint64_t) TcParseTableBase {
// Returns a begin iterator (pointer) to the start of the field entries array. // Returns a begin iterator (pointer) to the start of the field entries array.
const FieldEntry* field_entries_begin() const { const FieldEntry* field_entries_begin() const {
return reinterpret_cast<const FieldEntry*>(field_numbers_end()); return reinterpret_cast<const FieldEntry*>(
reinterpret_cast<uintptr_t>(this) + field_entries_offset);
} }
// Auxiliary entries for field types that need extra information. // Auxiliary entries for field types that need extra information.
@ -248,7 +246,8 @@ static_assert(sizeof(TcParseTableBase::FieldEntry) <= 16,
"Field entry is too big."); "Field entry is too big.");
template <size_t kFastTableSizeLog2, size_t kNumFieldEntries = 0, template <size_t kFastTableSizeLog2, size_t kNumFieldEntries = 0,
size_t kNumFieldAux = 0, size_t kNameTableSize = 0> size_t kNumFieldAux = 0, size_t kNameTableSize = 0,
size_t kFieldLookupSize = 2>
struct TcParseTable { struct TcParseTable {
TcParseTableBase header; TcParseTableBase header;
@ -261,8 +260,9 @@ struct TcParseTable {
std::array<TcParseTableBase::FastFieldEntry, (1 << kFastTableSizeLog2)> std::array<TcParseTableBase::FastFieldEntry, (1 << kFastTableSizeLog2)>
fast_entries; fast_entries;
// Just big enough to find all the field entries.
std::array<uint16_t, kFieldLookupSize> field_lookup_table;
// Entries for all fields: // Entries for all fields:
std::array<uint32_t, kNumFieldEntries> field_numbers;
std::array<TcParseTableBase::FieldEntry, kNumFieldEntries> field_entries; std::array<TcParseTableBase::FieldEntry, kNumFieldEntries> field_entries;
std::array<TcParseTableBase::FieldAux, kNumFieldAux> aux_entries; std::array<TcParseTableBase::FieldAux, kNumFieldAux> aux_entries;
std::array<char, kNameTableSize> field_names; std::array<char, kNameTableSize> field_names;
@ -273,24 +273,26 @@ struct TcParseTable {
// However, different implementations have different sizeof(std::array<T, 0>). // However, different implementations have different sizeof(std::array<T, 0>).
// Skipping the member makes offset computations portable. // Skipping the member makes offset computations portable.
template <size_t kFastTableSizeLog2, size_t kNumFieldEntries, template <size_t kFastTableSizeLog2, size_t kNumFieldEntries,
size_t kNameTableSize> size_t kNameTableSize, size_t kFieldLookupSize>
struct TcParseTable<kFastTableSizeLog2, kNumFieldEntries, 0, kNameTableSize> { struct TcParseTable<kFastTableSizeLog2, kNumFieldEntries, 0, kNameTableSize,
kFieldLookupSize> {
TcParseTableBase header; TcParseTableBase header;
std::array<TcParseTableBase::FastFieldEntry, (1 << kFastTableSizeLog2)> std::array<TcParseTableBase::FastFieldEntry, (1 << kFastTableSizeLog2)>
fast_entries; fast_entries;
std::array<uint32_t, kNumFieldEntries> field_numbers; std::array<uint16_t, kFieldLookupSize> field_lookup_table;
std::array<TcParseTableBase::FieldEntry, kNumFieldEntries> field_entries; std::array<TcParseTableBase::FieldEntry, kNumFieldEntries> field_entries;
std::array<char, kNameTableSize> field_names; std::array<char, kNameTableSize> field_names;
}; };
// Partial specialization: if there are no fields at all, then we can save space // Partial specialization: if there are no fields at all, then we can save space
// by skipping the field numbers and entries. // by skipping the field numbers and entries.
template <size_t kNameTableSize> template <size_t kNameTableSize, size_t kFieldLookupSize>
struct TcParseTable<0, 0, 0, kNameTableSize> { struct TcParseTable<0, 0, 0, kNameTableSize, kFieldLookupSize> {
TcParseTableBase header; TcParseTableBase header;
// N.B.: the fast entries are sized by log2, so 2**0 fields = 1 entry. // N.B.: the fast entries are sized by log2, so 2**0 fields = 1 entry.
// The fast parsing loop will always use this entry, so it must be present. // The fast parsing loop will always use this entry, so it must be present.
std::array<TcParseTableBase::FastFieldEntry, 1> fast_entries; std::array<TcParseTableBase::FastFieldEntry, 1> fast_entries;
std::array<uint16_t, kFieldLookupSize> field_lookup_table;
std::array<char, kNameTableSize> field_names; std::array<char, kNameTableSize> field_names;
}; };

View File

@ -339,6 +339,7 @@ class PROTOBUF_EXPORT TcParser final {
// Functions referenced by generated fast tables (string types): // Functions referenced by generated fast tables (string types):
// B: bytes S: string U: UTF-8 string // B: bytes S: string U: UTF-8 string
// (empty): ArenaStringPtr i: InlinedString
// S: singular R: repeated // S: singular R: repeated
// 1/2: tag length (bytes) // 1/2: tag length (bytes)
static const char* FastBS1(PROTOBUF_TC_PARAM_DECL); static const char* FastBS1(PROTOBUF_TC_PARAM_DECL);
@ -354,14 +355,25 @@ class PROTOBUF_EXPORT TcParser final {
static const char* FastUR1(PROTOBUF_TC_PARAM_DECL); static const char* FastUR1(PROTOBUF_TC_PARAM_DECL);
static const char* FastUR2(PROTOBUF_TC_PARAM_DECL); static const char* FastUR2(PROTOBUF_TC_PARAM_DECL);
static const char* FastBiS1(PROTOBUF_TC_PARAM_DECL);
static const char* FastBiS2(PROTOBUF_TC_PARAM_DECL);
static const char* FastSiS1(PROTOBUF_TC_PARAM_DECL);
static const char* FastSiS2(PROTOBUF_TC_PARAM_DECL);
static const char* FastUiS1(PROTOBUF_TC_PARAM_DECL);
static const char* FastUiS2(PROTOBUF_TC_PARAM_DECL);
// Functions referenced by generated fast tables (message types): // Functions referenced by generated fast tables (message types):
// M: message // M: message G: group
// S: singular R: repeated // S: singular R: repeated
// 1/2: tag length (bytes) // 1/2: tag length (bytes)
static const char* FastMS1(PROTOBUF_TC_PARAM_DECL); static const char* FastMS1(PROTOBUF_TC_PARAM_DECL);
static const char* FastMS2(PROTOBUF_TC_PARAM_DECL); static const char* FastMS2(PROTOBUF_TC_PARAM_DECL);
static const char* FastMR1(PROTOBUF_TC_PARAM_DECL); static const char* FastMR1(PROTOBUF_TC_PARAM_DECL);
static const char* FastMR2(PROTOBUF_TC_PARAM_DECL); static const char* FastMR2(PROTOBUF_TC_PARAM_DECL);
static const char* FastGS1(PROTOBUF_TC_PARAM_DECL);
static const char* FastGS2(PROTOBUF_TC_PARAM_DECL);
static const char* FastGR1(PROTOBUF_TC_PARAM_DECL);
static const char* FastGR2(PROTOBUF_TC_PARAM_DECL);
template <typename T> template <typename T>
static inline T& RefAt(void* x, size_t offset) { static inline T& RefAt(void* x, size_t offset) {
@ -402,9 +414,9 @@ class PROTOBUF_EXPORT TcParser final {
private: private:
friend class GeneratedTcTableLiteTest; friend class GeneratedTcTableLiteTest;
template <typename TagType> template <typename TagType, bool group_coding>
static inline const char* SingularParseMessageAuxImpl(PROTOBUF_TC_PARAM_DECL); static inline const char* SingularParseMessageAuxImpl(PROTOBUF_TC_PARAM_DECL);
template <typename TagType> template <typename TagType, bool group_coding>
static inline const char* RepeatedParseMessageAuxImpl(PROTOBUF_TC_PARAM_DECL); static inline const char* RepeatedParseMessageAuxImpl(PROTOBUF_TC_PARAM_DECL);
static inline PROTOBUF_ALWAYS_INLINE void SyncHasbits( static inline PROTOBUF_ALWAYS_INLINE void SyncHasbits(
@ -496,7 +508,7 @@ class PROTOBUF_EXPORT TcParser final {
MessageLite* msg); MessageLite* msg);
// UTF-8 validation: // UTF-8 validation:
static void ReportFastUtf8Error(uint16_t coded_tag, static void ReportFastUtf8Error(uint32_t decoded_tag,
const TcParseTableBase* table); const TcParseTableBase* table);
static bool MpVerifyUtf8(StringPiece wire_bytes, static bool MpVerifyUtf8(StringPiece wire_bytes,
const TcParseTableBase* table, const TcParseTableBase* table,

View File

@ -137,62 +137,121 @@ inline PROTOBUF_ALWAYS_INLINE const char* TcParser::Error(
return nullptr; return nullptr;
} }
// On the fast path, a (matching) 1-byte tag already has the decoded value.
static uint32_t FastDecodeTag(uint8_t coded_tag) {
return coded_tag;
}
// On the fast path, a (matching) 2-byte tag always needs to be decoded.
static uint32_t FastDecodeTag(uint16_t coded_tag) {
uint32_t result = coded_tag;
result += static_cast<int8_t>(coded_tag);
return result >> 1;
}
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// Core mini parsing implementation: // Core mini parsing implementation:
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// Field lookup table layout:
//
// Because it consists of a series of variable-length segments, the lookuup
// table is organized within an array of uint16_t, and each element is either
// a uint16_t or a uint32_t stored little-endian as a pair of uint16_t.
//
// Its fundamental building block maps 16 contiguously ascending field numbers
// to their locations within the field entry table:
struct SkipEntry16 {
uint16_t skipmap;
uint16_t field_entry_offset;
};
// The skipmap is a bitfield of which of those field numbers do NOT have a
// field entry. The lowest bit of the skipmap corresponds to the lowest of
// the 16 field numbers, so if a proto had only fields 1, 2, 3, and 7, the
// skipmap would contain 0b11111111'10111000.
//
// The field lookup table begins with a single 32-bit skipmap that maps the
// field numbers 1 through 32. This is because the majority of proto
// messages only contain fields numbered 1 to 32.
//
// The rest of the lookup table is a repeated series of
// { 32-bit field #, #SkipEntry16s, {SkipEntry16...} }
// That is, the next thing is a pair of uint16_t that form the next
// lowest field number that the lookup table handles. If this number is -1,
// that is the end of the table. Then there is a uint16_t that is
// the number of contiguous SkipEntry16 entries that follow, and then of
// course the SkipEntry16s themselves.
// Originally developed and tested at https://godbolt.org/z/vbc7enYcf
// Returns the address of the field for `tag` in the table's field entries. // Returns the address of the field for `tag` in the table's field entries.
// Returns nullptr if the field was not found. // Returns nullptr if the field was not found.
const TcParseTableBase::FieldEntry* TcParser::FindFieldEntry( const TcParseTableBase::FieldEntry* TcParser::FindFieldEntry(
const TcParseTableBase* table, uint32_t field_num) { const TcParseTableBase* table, uint32_t field_num) {
const FieldEntry* const field_entries = table->field_entries_begin(); const FieldEntry* const field_entries = table->field_entries_begin();
// Most messages have fields numbered sequentially. If the decoded tag is uint32_t fstart = 1;
// within that range, we can look up the field by index. uint32_t adj_fnum = field_num - fstart;
const uint32_t sequential_start = table->sequential_fields_start;
uint32_t adjusted_field_num = field_num - sequential_start;
const uint32_t num_sequential = table->num_sequential_fields;
if (PROTOBUF_PREDICT_TRUE(adjusted_field_num < num_sequential)) {
return field_entries + adjusted_field_num;
}
// Check if this field is larger than the max in the table. This is often an if (PROTOBUF_PREDICT_TRUE(adj_fnum < 32)) {
// extension. uint32_t skipmap = table->skipmap32;
if (field_num > table->max_field_number) { uint32_t skipbit = 1 << adj_fnum;
return nullptr; if (PROTOBUF_PREDICT_FALSE(skipmap & skipbit)) return nullptr;
skipmap &= skipbit - 1;
#if (__GNUC__ || __clang__) && __POPCNT__
// Note: here and below, skipmap typically has very few set bits
// (31 in the worst case, but usually zero) so a loop isn't that
// bad, and a compiler-generated popcount is typically only
// worthwhile if the processor itself has hardware popcount support.
adj_fnum -= __builtin_popcount(skipmap);
#else
while (skipmap) {
--adj_fnum;
skipmap &= skipmap - 1;
} }
#endif
// Otherwise, scan the next few field numbers, skipping the first return field_entries + adj_fnum;
// `num_sequential` entries.
const uint32_t* const field_num_begin = table->field_numbers_begin();
const uint32_t small_scan_limit =
std::min(num_sequential + kMtSmallScanSize,
static_cast<uint32_t>(table->num_field_entries));
for (uint32_t i = num_sequential; i < small_scan_limit; ++i) {
if (field_num <= field_num_begin[i]) {
if (PROTOBUF_PREDICT_FALSE(field_num != field_num_begin[i])) {
// Field number mismatch.
return nullptr;
} }
return field_entries + i; const uint16_t* lookup_table = table->field_lookup_begin();
for (;;) {
#ifdef PROTOBUF_LITTLE_ENDIAN
memcpy(&fstart, lookup_table, sizeof(fstart));
#else
fstart = lookup_table[0] | (lookup_table[1] << 16);
#endif
lookup_table += sizeof(fstart) / sizeof(*lookup_table);
uint32_t num_skip_entries = *lookup_table++;
if (field_num < fstart) return nullptr;
adj_fnum = field_num - fstart;
uint32_t skip_num = adj_fnum / 16;
if (PROTOBUF_PREDICT_TRUE(skip_num < num_skip_entries)) {
// for each group of 16 fields we have:
// a bitmap of 16 bits
// a 16-bit field-entry offset for the first of them.
auto* skip_data = lookup_table + (adj_fnum / 16) * (sizeof(SkipEntry16) /
sizeof(uint16_t));
SkipEntry16 se = {skip_data[0], skip_data[1]};
adj_fnum &= 15;
uint32_t skipmap = se.skipmap;
uint16_t skipbit = 1 << adj_fnum;
if (PROTOBUF_PREDICT_FALSE(skipmap & skipbit)) return nullptr;
skipmap &= skipbit - 1;
adj_fnum += se.field_entry_offset;
#if (__GNUC__ || __clang__) && __POPCNT__
adj_fnum -= __builtin_popcount(skipmap);
#else
while (skipmap) {
--adj_fnum;
skipmap &= skipmap - 1;
} }
#endif
return field_entries + adj_fnum;
} }
lookup_table +=
// Finally, look up with binary search. num_skip_entries * (sizeof(SkipEntry16) / sizeof(*lookup_table));
const uint32_t* const field_num_end = table->field_numbers_end();
auto it = std::lower_bound(field_num_begin + small_scan_limit, field_num_end,
field_num);
if (it == field_num_end) {
// The only reason for binary search failing is if there was nothing to
// search.
GOOGLE_DCHECK_EQ(field_num_begin + small_scan_limit, field_num_end) << field_num;
return nullptr;
} }
if (PROTOBUF_PREDICT_FALSE(*it != field_num)) {
// Field number mismatch.
return nullptr;
}
return field_entries + (it - field_num_begin);
} }
// Field names are stored in a format of: // Field names are stored in a format of:
@ -298,12 +357,13 @@ inline PROTOBUF_ALWAYS_INLINE void InvertPacked(TcFieldData& data) {
// Message fields // Message fields
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
template <typename TagType> template <typename TagType, bool group_coding>
inline PROTOBUF_ALWAYS_INLINE inline PROTOBUF_ALWAYS_INLINE
const char* TcParser::SingularParseMessageAuxImpl(PROTOBUF_TC_PARAM_DECL) { const char* TcParser::SingularParseMessageAuxImpl(PROTOBUF_TC_PARAM_DECL) {
if (PROTOBUF_PREDICT_FALSE(data.coded_tag<TagType>() != 0)) { if (PROTOBUF_PREDICT_FALSE(data.coded_tag<TagType>() != 0)) {
PROTOBUF_MUSTTAIL return MiniParse(PROTOBUF_TC_PARAM_PASS); PROTOBUF_MUSTTAIL return MiniParse(PROTOBUF_TC_PARAM_PASS);
} }
auto saved_tag = UnalignedLoad<TagType>(ptr);
ptr += sizeof(TagType); ptr += sizeof(TagType);
hasbits |= (uint64_t{1} << data.hasbit_idx()); hasbits |= (uint64_t{1} << data.hasbit_idx());
auto& field = RefAt<MessageLite*>(msg, data.offset()); auto& field = RefAt<MessageLite*>(msg, data.offset());
@ -313,25 +373,39 @@ const char* TcParser::SingularParseMessageAuxImpl(PROTOBUF_TC_PARAM_DECL) {
field = default_instance->New(ctx->data().arena); field = default_instance->New(ctx->data().arena);
} }
SyncHasbits(msg, hasbits, table); SyncHasbits(msg, hasbits, table);
if (group_coding) {
return ctx->ParseGroup(field, ptr, FastDecodeTag(saved_tag));
}
return ctx->ParseMessage(field, ptr); return ctx->ParseMessage(field, ptr);
} }
const char* TcParser::FastMS1(PROTOBUF_TC_PARAM_DECL) { const char* TcParser::FastMS1(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return SingularParseMessageAuxImpl<uint8_t>( PROTOBUF_MUSTTAIL return SingularParseMessageAuxImpl<uint8_t, false>(
PROTOBUF_TC_PARAM_PASS); PROTOBUF_TC_PARAM_PASS);
} }
const char* TcParser::FastMS2(PROTOBUF_TC_PARAM_DECL) { const char* TcParser::FastMS2(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return SingularParseMessageAuxImpl<uint16_t>( PROTOBUF_MUSTTAIL return SingularParseMessageAuxImpl<uint16_t, false>(
PROTOBUF_TC_PARAM_PASS); PROTOBUF_TC_PARAM_PASS);
} }
template <typename TagType> const char* TcParser::FastGS1(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return SingularParseMessageAuxImpl<uint8_t, true>(
PROTOBUF_TC_PARAM_PASS);
}
const char* TcParser::FastGS2(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return SingularParseMessageAuxImpl<uint16_t, true>(
PROTOBUF_TC_PARAM_PASS);
}
template <typename TagType, bool group_coding>
inline PROTOBUF_ALWAYS_INLINE inline PROTOBUF_ALWAYS_INLINE
const char* TcParser::RepeatedParseMessageAuxImpl(PROTOBUF_TC_PARAM_DECL) { const char* TcParser::RepeatedParseMessageAuxImpl(PROTOBUF_TC_PARAM_DECL) {
if (PROTOBUF_PREDICT_FALSE(data.coded_tag<TagType>() != 0)) { if (PROTOBUF_PREDICT_FALSE(data.coded_tag<TagType>() != 0)) {
PROTOBUF_MUSTTAIL return MiniParse(PROTOBUF_TC_PARAM_PASS); PROTOBUF_MUSTTAIL return MiniParse(PROTOBUF_TC_PARAM_PASS);
} }
auto saved_tag = UnalignedLoad<TagType>(ptr);
ptr += sizeof(TagType); ptr += sizeof(TagType);
SyncHasbits(msg, hasbits, table); SyncHasbits(msg, hasbits, table);
const MessageLite* default_instance = const MessageLite* default_instance =
@ -339,16 +413,29 @@ const char* TcParser::RepeatedParseMessageAuxImpl(PROTOBUF_TC_PARAM_DECL) {
auto& field = RefAt<RepeatedPtrFieldBase>(msg, data.offset()); auto& field = RefAt<RepeatedPtrFieldBase>(msg, data.offset());
MessageLite* submsg = MessageLite* submsg =
field.Add<GenericTypeHandler<MessageLite>>(default_instance); field.Add<GenericTypeHandler<MessageLite>>(default_instance);
if (group_coding) {
return ctx->ParseGroup(submsg, ptr, FastDecodeTag(saved_tag));
}
return ctx->ParseMessage(submsg, ptr); return ctx->ParseMessage(submsg, ptr);
} }
const char* TcParser::FastMR1(PROTOBUF_TC_PARAM_DECL) { const char* TcParser::FastMR1(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return RepeatedParseMessageAuxImpl<uint8_t>( PROTOBUF_MUSTTAIL return RepeatedParseMessageAuxImpl<uint8_t, false>(
PROTOBUF_TC_PARAM_PASS); PROTOBUF_TC_PARAM_PASS);
} }
const char* TcParser::FastMR2(PROTOBUF_TC_PARAM_DECL) { const char* TcParser::FastMR2(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return RepeatedParseMessageAuxImpl<uint16_t>( PROTOBUF_MUSTTAIL return RepeatedParseMessageAuxImpl<uint16_t, false>(
PROTOBUF_TC_PARAM_PASS);
}
const char* TcParser::FastGR1(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return RepeatedParseMessageAuxImpl<uint8_t, true>(
PROTOBUF_TC_PARAM_PASS);
}
const char* TcParser::FastGR2(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return RepeatedParseMessageAuxImpl<uint16_t, true>(
PROTOBUF_TC_PARAM_PASS); PROTOBUF_TC_PARAM_PASS);
} }
@ -970,12 +1057,9 @@ void PrintUTF8ErrorLog(StringPiece message_name,
StringPiece field_name, const char* operation_str, StringPiece field_name, const char* operation_str,
bool emit_stacktrace); bool emit_stacktrace);
void TcParser::ReportFastUtf8Error(uint16_t coded_tag, void TcParser::ReportFastUtf8Error(uint32_t decoded_tag,
const TcParseTableBase* table) { const TcParseTableBase* table) {
if (coded_tag > 127) { uint32_t field_num = decoded_tag >> 3;
coded_tag = (coded_tag & 0x7f) + ((coded_tag & 0xff00) >> 1);
}
uint32_t field_num = coded_tag >> 3;
const auto* entry = FindFieldEntry(table, field_num); const auto* entry = FindFieldEntry(table, field_num);
PrintUTF8ErrorLog(MessageName(table), FieldName(table, entry), "parsing", PrintUTF8ErrorLog(MessageName(table), FieldName(table, entry), "parsing",
false); false);
@ -1020,7 +1104,7 @@ PROTOBUF_ALWAYS_INLINE const char* TcParser::SingularString(
if (PROTOBUF_PREDICT_TRUE(IsStructurallyValidUTF8(field.Get()))) { if (PROTOBUF_PREDICT_TRUE(IsStructurallyValidUTF8(field.Get()))) {
return ToParseLoop(PROTOBUF_TC_PARAM_PASS); return ToParseLoop(PROTOBUF_TC_PARAM_PASS);
} }
ReportFastUtf8Error(saved_tag, table); ReportFastUtf8Error(FastDecodeTag(saved_tag), table);
return utf8 == kUtf8 ? Error(PROTOBUF_TC_PARAM_PASS) return utf8 == kUtf8 ? Error(PROTOBUF_TC_PARAM_PASS)
: ToParseLoop(PROTOBUF_TC_PARAM_PASS); : ToParseLoop(PROTOBUF_TC_PARAM_PASS);
} }
@ -1051,6 +1135,27 @@ const char* TcParser::FastUS2(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_TC_PARAM_PASS); PROTOBUF_TC_PARAM_PASS);
} }
// Inlined string variants:
const char* TcParser::FastBiS1(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return MiniParse(PROTOBUF_TC_PARAM_PASS);
}
const char* TcParser::FastBiS2(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return MiniParse(PROTOBUF_TC_PARAM_PASS);
}
const char* TcParser::FastSiS1(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return MiniParse(PROTOBUF_TC_PARAM_PASS);
}
const char* TcParser::FastSiS2(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return MiniParse(PROTOBUF_TC_PARAM_PASS);
}
const char* TcParser::FastUiS1(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return MiniParse(PROTOBUF_TC_PARAM_PASS);
}
const char* TcParser::FastUiS2(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return MiniParse(PROTOBUF_TC_PARAM_PASS);
}
template <typename TagType, TcParser::Utf8Type utf8> template <typename TagType, TcParser::Utf8Type utf8>
PROTOBUF_ALWAYS_INLINE const char* TcParser::RepeatedString( PROTOBUF_ALWAYS_INLINE const char* TcParser::RepeatedString(
PROTOBUF_TC_PARAM_DECL) { PROTOBUF_TC_PARAM_DECL) {
@ -1076,7 +1181,7 @@ PROTOBUF_ALWAYS_INLINE const char* TcParser::RepeatedString(
if (PROTOBUF_PREDICT_TRUE(IsStructurallyValidUTF8(*str))) { if (PROTOBUF_PREDICT_TRUE(IsStructurallyValidUTF8(*str))) {
break; break;
} }
ReportFastUtf8Error(expected_tag, table); ReportFastUtf8Error(FastDecodeTag(expected_tag), table);
if (utf8 == kUtf8) return Error(PROTOBUF_TC_PARAM_PASS); if (utf8 == kUtf8) return Error(PROTOBUF_TC_PARAM_PASS);
break; break;
} }
@ -1613,16 +1718,31 @@ const char* TcParser::MpMessage(PROTOBUF_TC_PARAM_DECL) {
if (card == field_layout::kFcRepeated) { if (card == field_layout::kFcRepeated) {
PROTOBUF_MUSTTAIL return MpRepeatedMessage(PROTOBUF_TC_PARAM_PASS); PROTOBUF_MUSTTAIL return MpRepeatedMessage(PROTOBUF_TC_PARAM_PASS);
} }
// Check for wire type mismatch:
// TODO(b/210762816): support groups. const uint32_t decoded_tag = data.tag();
if ((data.tag() & 7) != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) { const uint32_t decoded_wiretype = decoded_tag & 7;
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS); const uint16_t rep = type_card & field_layout::kRepMask;
const bool is_group = rep == field_layout::kRepGroup;
// Validate wiretype:
switch (rep) {
case field_layout::kRepMessage:
if (decoded_wiretype != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
goto fallback;
} }
break;
case field_layout::kRepGroup:
if (decoded_wiretype != WireFormatLite::WIRETYPE_START_GROUP) {
goto fallback;
}
break;
default: {
fallback:
// Lazy and implicit weak fields are handled by generated code: // Lazy and implicit weak fields are handled by generated code:
// TODO(b/210762816): support these. // TODO(b/210762816): support these.
if ((type_card & field_layout::kRepMask) != field_layout::kRepMessage) {
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS); PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS);
} }
}
const bool is_oneof = card == field_layout::kFcOneof; const bool is_oneof = card == field_layout::kFcOneof;
bool need_init = false; bool need_init = false;
@ -1638,6 +1758,9 @@ const char* TcParser::MpMessage(PROTOBUF_TC_PARAM_DECL) {
field = default_instance->New(ctx->data().arena); field = default_instance->New(ctx->data().arena);
} }
SyncHasbits(msg, hasbits, table); SyncHasbits(msg, hasbits, table);
if (is_group) {
return ctx->ParseGroup(field, ptr, decoded_tag);
}
return ctx->ParseMessage(field, ptr); return ctx->ParseMessage(field, ptr);
} }
@ -1646,16 +1769,29 @@ const char* TcParser::MpRepeatedMessage(PROTOBUF_TC_PARAM_DECL) {
const uint16_t type_card = entry.type_card; const uint16_t type_card = entry.type_card;
GOOGLE_DCHECK_EQ(type_card & field_layout::kFcMask, GOOGLE_DCHECK_EQ(type_card & field_layout::kFcMask,
static_cast<uint16_t>(field_layout::kFcRepeated)); static_cast<uint16_t>(field_layout::kFcRepeated));
const uint32_t decoded_tag = data.tag();
const uint32_t decoded_wiretype = decoded_tag & 7;
const uint16_t rep = type_card & field_layout::kRepMask;
const bool is_group = rep == field_layout::kRepGroup;
// Check for wire type mismatch: // Validate wiretype:
// TODO(b/210762816): support groups. switch (rep) {
if ((data.tag() & 7) != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) { case field_layout::kRepMessage:
if (decoded_wiretype != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
goto fallback;
}
break;
case field_layout::kRepGroup:
if (decoded_wiretype != WireFormatLite::WIRETYPE_START_GROUP) {
goto fallback;
}
break;
default: {
fallback:
// Lazy and implicit weak fields are handled by generated code:
// TODO(b/210762816): support these.
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS); PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS);
} }
// Implicit weak fields are handled by generated code:
// TODO(b/210762816): support these.
if ((type_card & field_layout::kRepMask) != field_layout::kRepMessage) {
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS);
} }
SyncHasbits(msg, hasbits, table); SyncHasbits(msg, hasbits, table);
@ -1664,6 +1800,9 @@ const char* TcParser::MpRepeatedMessage(PROTOBUF_TC_PARAM_DECL) {
auto& field = RefAt<RepeatedPtrFieldBase>(msg, entry.offset); auto& field = RefAt<RepeatedPtrFieldBase>(msg, entry.offset);
MessageLite* value = MessageLite* value =
field.Add<GenericTypeHandler<MessageLite>>(default_instance); field.Add<GenericTypeHandler<MessageLite>>(default_instance);
if (is_group) {
return ctx->ParseGroup(value, ptr, decoded_tag);
}
return ctx->ParseMessage(value, ptr); return ctx->ParseMessage(value, ptr);
} }

View File

@ -44,7 +44,7 @@ namespace {
using ::testing::Eq; using ::testing::Eq;
using ::testing::Not; using ::testing::Not;
MATCHER_P2(IsEntryForFieldNum, table, field_num, MATCHER_P3(IsEntryForFieldNum, table, field_num, field_numbers_table,
StrCat(negation ? "isn't " : "", StrCat(negation ? "isn't " : "",
"the field entry for field number ", field_num)) { "the field entry for field number ", field_num)) {
if (arg == nullptr) { if (arg == nullptr) {
@ -54,7 +54,7 @@ MATCHER_P2(IsEntryForFieldNum, table, field_num,
// Use the entry's index to compare field numbers. // Use the entry's index to compare field numbers.
size_t index = static_cast<const TcParseTableBase::FieldEntry*>(arg) - size_t index = static_cast<const TcParseTableBase::FieldEntry*>(arg) -
&table->field_entries[0]; &table->field_entries[0];
uint32_t actual_field_num = table->field_numbers[index]; uint32_t actual_field_num = field_numbers_table[index];
if (actual_field_num != field_num) { if (actual_field_num != field_num) {
*result_listener << "which is the entry for " << actual_field_num; *result_listener << "which is the entry for " << actual_field_num;
return false; return false;
@ -64,25 +64,31 @@ MATCHER_P2(IsEntryForFieldNum, table, field_num,
TEST(IsEntryForFieldNumTest, Matcher) { TEST(IsEntryForFieldNumTest, Matcher) {
// clang-format off // clang-format off
TcParseTable<0, 3, 0, 0> table = { TcParseTable<0, 3, 0, 0, 2> table = {
// header: // header:
{ {
0, 0, 0, 0, // has_bits_offset, extensions 0, 0, 0, 0, // has_bits_offset, extensions
0, // max_field_number 0, // max_field_number
0, 0, // fast_idx_mask, num_sequential_fields 0, // fast_idx_mask,
0, 0, // sequential_fields_start, num_field_entries offsetof(decltype(table), field_lookup_table),
0xFFFFFFFF - 7, // 7 = fields 1, 2, and 3.
offsetof(decltype(table), field_names),
0, // num_field_entries
0, 0, // num_aux_entries, aux_offset, 0, 0, // num_aux_entries, aux_offset,
nullptr, // default instance nullptr, // default instance
nullptr, // fallback function nullptr, // fallback function
}}; }};
// clang-format on // clang-format on
table.field_numbers = {1, 2, 3}; int table_field_numbers[] = {1, 2, 3};
table.field_lookup_table = {65535, 65535};
EXPECT_THAT(&table.field_entries[0], IsEntryForFieldNum(&table, 1)); auto& entries = table.field_entries;
EXPECT_THAT(&table.field_entries[2], IsEntryForFieldNum(&table, 3)); EXPECT_THAT(&entries[0], IsEntryForFieldNum(&table, 1, table_field_numbers));
EXPECT_THAT(&table.field_entries[1], Not(IsEntryForFieldNum(&table, 3))); EXPECT_THAT(&entries[2], IsEntryForFieldNum(&table, 3, table_field_numbers));
EXPECT_THAT(&entries[1],
Not(IsEntryForFieldNum(&table, 3, table_field_numbers)));
EXPECT_THAT(nullptr, Not(IsEntryForFieldNum(&table, 1))); EXPECT_THAT(nullptr, Not(IsEntryForFieldNum(&table, 1, table_field_numbers)));
} }
} // namespace } // namespace
@ -91,30 +97,30 @@ class FindFieldEntryTest : public ::testing::Test {
public: public:
// Calls the private `FindFieldEntry` function. // Calls the private `FindFieldEntry` function.
template <size_t kFastTableSizeLog2, size_t kNumEntries, size_t kNumFieldAux, template <size_t kFastTableSizeLog2, size_t kNumEntries, size_t kNumFieldAux,
size_t kNameTableSize> size_t kNameTableSize, size_t kFieldLookupTableSize>
static const TcParseTableBase::FieldEntry* FindFieldEntry( static const TcParseTableBase::FieldEntry* FindFieldEntry(
const TcParseTable<kFastTableSizeLog2, kNumEntries, kNumFieldAux, const TcParseTable<kFastTableSizeLog2, kNumEntries, kNumFieldAux,
kNameTableSize>& table, kNameTableSize, kFieldLookupTableSize>& table,
uint32_t tag) { uint32_t tag) {
return TcParser::FindFieldEntry(&table.header, tag); return TcParser::FindFieldEntry(&table.header, tag);
} }
// Calls the private `FieldName` function. // Calls the private `FieldName` function.
template <size_t kFastTableSizeLog2, size_t kNumEntries, size_t kNumFieldAux, template <size_t kFastTableSizeLog2, size_t kNumEntries, size_t kNumFieldAux,
size_t kNameTableSize> size_t kNameTableSize, size_t kFieldLookupTableSize>
static StringPiece FieldName( static StringPiece FieldName(
const TcParseTable<kFastTableSizeLog2, kNumEntries, kNumFieldAux, const TcParseTable<kFastTableSizeLog2, kNumEntries, kNumFieldAux,
kNameTableSize>& table, kNameTableSize, kFieldLookupTableSize>& table,
const TcParseTableBase::FieldEntry* entry) { const TcParseTableBase::FieldEntry* entry) {
return TcParser::FieldName(&table.header, entry); return TcParser::FieldName(&table.header, entry);
} }
// Calls the private `MessageName` function. // Calls the private `MessageName` function.
template <size_t kFastTableSizeLog2, size_t kNumEntries, size_t kNumFieldAux, template <size_t kFastTableSizeLog2, size_t kNumEntries, size_t kNumFieldAux,
size_t kNameTableSize> size_t kNameTableSize, size_t kFieldLookupTableSize>
static StringPiece MessageName( static StringPiece MessageName(
const TcParseTable<kFastTableSizeLog2, kNumEntries, kNumFieldAux, const TcParseTable<kFastTableSizeLog2, kNumEntries, kNumFieldAux,
kNameTableSize>& table) { kNameTableSize, kFieldLookupTableSize>& table) {
return TcParser::MessageName(&table.header); return TcParser::MessageName(&table.header);
} }
@ -123,27 +129,38 @@ class FindFieldEntryTest : public ::testing::Test {
}; };
TEST_F(FindFieldEntryTest, SequentialFieldRange) { TEST_F(FindFieldEntryTest, SequentialFieldRange) {
// Look up fields that are within the range of `num_sequential_fields`. // Look up fields that are within the range of `lookup_table_offset`.
// clang-format off // clang-format off
TcParseTable<0, 5, 0, 0> table = { TcParseTable<0, 5, 0, 0, 8> table = {
// header: // header:
{ {
0, 0, 0, 0, // has_bits_offset, extensions 0, 0, 0, 0, // has_bits_offset, extensions
111, // max_field_number 111, // max_field_number
0, 4, // fast_idx_mask, num_sequential_fields 0, // fast_idx_mask,
2, 5, // sequential_fields_start, num_field_entries offsetof(decltype(table), field_lookup_table),
0xFFFFFFFF - (1 << 1) - (1 << 2) // fields 2, 3
- (1 << 3) - (1 << 4), // fields 4, 5
offsetof(decltype(table), field_entries),
5, // num_field_entries
0, 0, // num_aux_entries, aux_offset, 0, 0, // num_aux_entries, aux_offset,
nullptr, // default instance nullptr, // default instance
{}, // fallback function {}, // fallback function
}, },
{}, // fast_entries {}, // fast_entries
// field_numbers: // field_lookup_table for 2, 3, 4, 5, 111:
{{2, 3, 4, 5, 111}}, {{
111, 0, // field 111
1, // 1 skip entry
0xFFFE, 4, // 1 field, entry 4.
65535, 65535, // end of table
}},
}; };
// clang-format on // clang-format on
int table_field_numbers[] = {2, 3, 4, 5, 111};
for (int i : table.field_numbers) { for (int i : table_field_numbers) {
EXPECT_THAT(FindFieldEntry(table, i), IsEntryForFieldNum(&table, i)); EXPECT_THAT(FindFieldEntry(table, i),
IsEntryForFieldNum(&table, i, table_field_numbers));
} }
for (int i : {0, 1, 6, 7, 110, 112, 500000000}) { for (int i : {0, 1, 6, 7, 110, 112, 500000000}) {
GOOGLE_LOG(WARNING) << "Field " << i; GOOGLE_LOG(WARNING) << "Field " << i;
@ -152,33 +169,43 @@ TEST_F(FindFieldEntryTest, SequentialFieldRange) {
} }
TEST_F(FindFieldEntryTest, SmallScanRange) { TEST_F(FindFieldEntryTest, SmallScanRange) {
// Look up fields past `num_sequential_fields`, but before binary search. // Look up fields past `lookup_table_offset`, but before binary search.
ASSERT_THAT(small_scan_size(), Eq(4)) << "test needs to be updated"; ASSERT_THAT(small_scan_size(), Eq(4)) << "test needs to be updated";
// clang-format off // clang-format off
TcParseTable<0, 6, 0, 0> table = { TcParseTable<0, 6, 0, 0, 8> table = {
// header: // header:
{ {
0, 0, 0, 0, // has_bits_offset, extensions 0, 0, 0, 0, // has_bits_offset, extensions
111, // max_field_number 111, // max_field_number
0, 1, // fast_idx_mask, num_sequential_fields 0, // fast_idx_mask,
1, 6, // sequential_fields_start, num_field_entries offsetof(decltype(table), field_lookup_table),
0xFFFFFFFF - (1<<0) - (1<<2) - (1<<3) - (1<<4) - (1<<6), // 1,3-5,7
offsetof(decltype(table), field_entries),
6, // num_field_entries
0, 0, // num_aux_entries, aux_offset, 0, 0, // num_aux_entries, aux_offset,
nullptr, // default instance nullptr, // default instance
{}, // fallback function {}, // fallback function
}, },
{}, // fast_entries {}, // fast_entries
// field_numbers: // field_lookup_table for 1, 3, 4, 5, 7, 111:
{{// Sequential entries: {{
111, 0, // field 111
1, // 1 skip entry
0xFFFE, 5, // 1 field, entry 5
65535, 65535 // end of table
}},
};
// clang-format on
int table_field_numbers[] = {// Sequential entries:
1, 1,
// Small scan range: // Small scan range:
3, 4, 5, 7, 3, 4, 5, 7,
// Binary search range: // Binary search range:
111}}, 111};
};
// clang-format on
for (int i : table.field_numbers) { for (int i : table_field_numbers) {
EXPECT_THAT(FindFieldEntry(table, i), IsEntryForFieldNum(&table, i)); EXPECT_THAT(FindFieldEntry(table, i),
IsEntryForFieldNum(&table, i, table_field_numbers));
} }
for (int i : {0, 2, 6, 8, 9, 110, 112, 500000000}) { for (int i : {0, 2, 6, 8, 9, 110, 112, 500000000}) {
EXPECT_THAT(FindFieldEntry(table, i), Eq(nullptr)); EXPECT_THAT(FindFieldEntry(table, i), Eq(nullptr));
@ -191,29 +218,43 @@ TEST_F(FindFieldEntryTest, BinarySearchRange) {
ASSERT_THAT(small_scan_size(), Eq(4)) << "test needs to be updated"; ASSERT_THAT(small_scan_size(), Eq(4)) << "test needs to be updated";
// clang-format off // clang-format off
TcParseTable<0, 10, 0, 0> table = { TcParseTable<0, 10, 0, 0, 8> table = {
// header: // header:
{ {
0, 0, 0, 0, // has_bits_offset, extensions 0, 0, 0, 0, // has_bits_offset, extensions
70, // max_field_number 70, // max_field_number
0, 1, // fast_idx_mask, num_sequential_fields 0, // fast_idx_mask,
1, 10, // sequential_fields_start, num_field_entries offsetof(decltype(table), field_lookup_table),
0xFFFFFFFF - (1<<0) - (1<<2) - (1<<3) - (1<<4) // 1, 3, 4, 5, 6
- (1<<5) - (1<<7) - (1<<8) - (1<<10) // 8, 9, 11, 12
- (1<<11),
offsetof(decltype(table), field_entries),
10, // num_field_entries
0, 0, // num_aux_entries, aux_offset, 0, 0, // num_aux_entries, aux_offset,
nullptr, // default instance nullptr, // default instance
{}, // fallback function {}, // fallback function
}, },
{}, // fast_entries {}, // fast_entries
// field_numbers: // field_lookup_table for 1, 3, 4, 5, 6, 8, 9, 11, 12, 70
{{// Sequential entries: {{
70, 0, // field 70
1, // 1 skip entry
0xFFFE, 9, // 1 field, entry 9
65535, 65535 // end of table
}},
};
int table_field_numbers[] = {
// Sequential entries:
1, 1,
// Small scan range: // Small scan range:
3, 4, 5, 6, 3, 4, 5, 6,
// Binary search range: // Binary search range:
8, 9, 11, 12, 70}}, 8, 9, 11, 12, 70
}; };
// clang-format on // clang-format on
for (int i : table.field_numbers) { for (int i : table_field_numbers) {
EXPECT_THAT(FindFieldEntry(table, i), IsEntryForFieldNum(&table, i)); EXPECT_THAT(FindFieldEntry(table, i),
IsEntryForFieldNum(&table, i, table_field_numbers));
} }
for (int i : {0, 2, 7, 10, 13, 69, 71, 112, 500000000}) { for (int i : {0, 2, 7, 10, 13, 69, 71, 112, 500000000}) {
EXPECT_THAT(FindFieldEntry(table, i), Eq(nullptr)); EXPECT_THAT(FindFieldEntry(table, i), Eq(nullptr));
@ -223,21 +264,25 @@ TEST_F(FindFieldEntryTest, BinarySearchRange) {
TEST_F(FindFieldEntryTest, OutOfRange) { TEST_F(FindFieldEntryTest, OutOfRange) {
// Look up tags that are larger than the maximum in the message. // Look up tags that are larger than the maximum in the message.
// clang-format off // clang-format off
TcParseTable<0, 3, 0, 15> table = { TcParseTable<0, 3, 0, 15, 2> table = {
// header: // header:
{ {
0, 0, 0, 0, // has_bits_offset, extensions 0, 0, 0, 0, // has_bits_offset, extensions
3, // max_field_number 3, // max_field_number
0, 3, // fast_idx_mask, num_sequential_fields 0, // fast_idx_mask,
1, 3, // sequential_fields_start, num_field_entries offsetof(decltype(table), field_lookup_table),
0xFFFFFFFF - (1<<0) - (1<<1) - (1<<2), // fields 1, 2, 3
offsetof(decltype(table), field_entries),
3, // num_field_entries
0, // num_aux_entries 0, // num_aux_entries
offsetof(decltype(table), field_names), // no aux_entries offsetof(decltype(table), field_names), // no aux_entries
nullptr, // default instance nullptr, // default instance
{}, // fallback function {}, // fallback function
}, },
{}, // fast_entries {}, // fast_entries
// field_numbers: {{// field lookup table
{{1, 2, 3}}, 65535, 65535 // end of table
}},
{}, // "mini" table {}, // "mini" table
// auxiliary entries (none in this test) // auxiliary entries (none in this test)
{{ // name lengths {{ // name lengths
@ -248,10 +293,12 @@ TEST_F(FindFieldEntryTest, OutOfRange) {
"003"}}, "003"}},
}; };
// clang-format on // clang-format on
int table_field_numbers[] = {1, 2, 3};
for (int field_num : table.field_numbers) { for (int field_num : table_field_numbers) {
auto* entry = FindFieldEntry(table, field_num); auto* entry = FindFieldEntry(table, field_num);
EXPECT_THAT(entry, IsEntryForFieldNum(&table, field_num)); EXPECT_THAT(entry,
IsEntryForFieldNum(&table, field_num, table_field_numbers));
StringPiece name = FieldName(table, entry); StringPiece name = FieldName(table, entry);
EXPECT_EQ(name.length(), field_num); EXPECT_EQ(name.length(), field_num);
@ -265,21 +312,27 @@ TEST_F(FindFieldEntryTest, OutOfRange) {
TEST_F(FindFieldEntryTest, EmptyMessage) { TEST_F(FindFieldEntryTest, EmptyMessage) {
// Ensure that tables with no fields are handled correctly. // Ensure that tables with no fields are handled correctly.
using TableType = TcParseTable<0, 0, 0, 20>; using TableType = TcParseTable<0, 0, 0, 20, 2>;
// clang-format off // clang-format off
TableType table = { TableType table = {
// header: // header:
{ {
0, 0, 0, 0, // has_bits_offset, extensions 0, 0, 0, 0, // has_bits_offset, extensions
0, // max_field_number 0, // max_field_number
0, 0, // fast_idx_mask, num_sequential_fields 0, // fast_idx_mask,
0, 0, // sequential_fields_start, num_field_entries offsetof(decltype(table), field_lookup_table),
0xFFFFFFFF, // no fields
offsetof(decltype(table), field_names), // no field_entries
0, // num_field_entries
0, // num_aux_entries 0, // num_aux_entries
offsetof(TableType, field_names), offsetof(TableType, field_names),
nullptr, // default instance nullptr, // default instance
nullptr, // fallback function nullptr, // fallback function
}, },
{}, // fast_entries {}, // fast_entries
{{// empty field lookup table
65535, 65535
}},
{{ {{
"\13\0\0\0\0\0\0\0" "\13\0\0\0\0\0\0\0"
"MessageName" "MessageName"
@ -294,13 +347,32 @@ TEST_F(FindFieldEntryTest, EmptyMessage) {
} }
// Make a monster with lots of field numbers // Make a monster with lots of field numbers
int32_t test_all_types_table_field_numbers[] = {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, //
11, 12, 13, 14, 15, 18, 19, 21, 22, 24, //
25, 27, 31, 32, 33, 34, 35, 36, 37, 38, //
39, 40, 41, 42, 43, 44, 45, 48, 49, 51, //
52, 54, 55, 56, 57, 58, 59, 60, 61, 62, //
63, 64, 65, 66, 67, 68, 69, 70, 71, 72, //
73, 74, 75, 76, 77, 78, 79, 80, 81, 82, //
83, 84, 85, 86, 87, 88, 89, 90, 91, 92, //
93, 94, 95, 96, 97, 98, 99, 100, 101, 102, //
111, 112, 113, 114, 115, 116, 117, 118, 119, 201, //
241, 242, 243, 244, 245, 246, 247, 248, 249, 250, //
251, 252, 253, 254, 255, 321, 322, 401, 402, 403, //
404, 405, 406, 407, 408, 409, 410, 411, 412, 413, //
414, 415, 416, 417};
// clang-format off // clang-format off
const TcParseTable<5, 134, 5, 2176> test_all_types_table = { const TcParseTable<5, 134, 5, 2176, 55> test_all_types_table = {
// header: // header:
{ {
0, 0, 0, 0, // has_bits_offset, extensions 0, 0, 0, 0, // has_bits_offset, extensions
418, 248, // max_field_number, fast_idx_mask 418, 248, // max_field_number, fast_idx_mask
14, 1, // num_sequential_fields, sequential_fields_start offsetof(decltype(test_all_types_table), field_lookup_table),
977895424, // skipmap for fields 1-15,18-19,21-22,24-25,27,31-32
offsetof(decltype(test_all_types_table), field_entries),
135, // num_field_entries 135, // num_field_entries
5, // num_aux_entries 5, // num_aux_entries
offsetof(decltype(test_all_types_table), aux_entries), offsetof(decltype(test_all_types_table), aux_entries),
@ -310,21 +382,18 @@ const TcParseTable<5, 134, 5, 2176> test_all_types_table = {
{{ {{
// tail-call table // tail-call table
}}, }},
{{// field numbers {{ // field lookup table
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, //
11, 12, 13, 14, 15, 18, 19, 21, 22, 24, // fields 33-417, over 25 skipmap / offset pairs
25, 27, 31, 32, 33, 34, 35, 36, 37, 38, 33, 0, 25,
39, 40, 41, 42, 43, 44, 45, 48, 49, 51, 24576, 24, 18, 38, 0, 52, 0, 68, 16320, 84,
52, 54, 55, 56, 57, 58, 59, 60, 61, 62, 65408, 92, 65535, 99, 65535, 99, 65535, 99, 65535, 99,
63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 65279, 99, 65535, 100, 65535, 100, 32768, 100, 65535, 115,
73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 65535, 115, 65535, 115, 65535, 115, 65532, 115, 65535, 117,
83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 65535, 117, 65535, 117, 65535, 117, 0, 117, 65532, 133,
93, 94, 95, 96, 97, 98, 99, 100, 101, 102, // end of table
111, 112, 113, 114, 115, 116, 117, 118, 119, 201, 65535, 65535
241, 242, 243, 244, 245, 246, 247, 248, 249, 250, }},
251, 252, 253, 254, 255, 321, 322, 401, 402, 403,
404, 405, 406, 407, 408, 409, 410, 411, 412, 413,
414, 415, 416, 417}},
{{ {{
// "mini" table // "mini" table
}}, }},

View File

@ -92,8 +92,8 @@
#include <google/protobuf/stubs/common.h> #include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/logging.h> #include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/stringprintf.h>
#include <google/protobuf/stubs/strutil.h> #include <google/protobuf/stubs/strutil.h>
#include <google/protobuf/stubs/stringprintf.h>
#include <google/protobuf/io/strtod.h> #include <google/protobuf/io/strtod.h>
#include <google/protobuf/io/zero_copy_stream.h> #include <google/protobuf/io/zero_copy_stream.h>
#include <google/protobuf/stubs/stl_util.h> #include <google/protobuf/stubs/stl_util.h>

View File

@ -233,13 +233,14 @@
#ifdef PROTOBUF_TAILCALL #ifdef PROTOBUF_TAILCALL
#error PROTOBUF_TAILCALL was previously defined #error PROTOBUF_TAILCALL was previously defined
#endif #endif
#if __has_cpp_attribute(clang::musttail) && \ #if __has_cpp_attribute(clang::musttail) && !defined(__arm__) && \
!defined(__arm__) && !defined(_ARCH_PPC) && !defined(__wasm__) && \ !defined(_ARCH_PPC) && !defined(__wasm__) && \
!(defined(_MSC_VER) && defined(_M_IX86)) !(defined(_MSC_VER) && defined(_M_IX86))
# ifndef PROTO2_OPENSOURCE # ifndef PROTO2_OPENSOURCE
// Compilation fails on ARM32: b/195943306 // Compilation fails on ARM32: b/195943306
// Compilation fails on powerpc64le: b/187985113 // Compilation fails on powerpc64le: b/187985113
// Compilation fails on X86 Windows: https://github.com/llvm/llvm-project/issues/53271 // Compilation fails on X86 Windows:
// https://github.com/llvm/llvm-project/issues/53271
# endif # endif
#define PROTOBUF_MUSTTAIL [[clang::musttail]] #define PROTOBUF_MUSTTAIL [[clang::musttail]]
#define PROTOBUF_TAILCALL true #define PROTOBUF_TAILCALL true

View File

@ -1515,8 +1515,6 @@ class RepeatedPtrIterator {
using iterator = RepeatedPtrIterator<Element>; using iterator = RepeatedPtrIterator<Element>;
using iterator_category = std::random_access_iterator_tag; using iterator_category = std::random_access_iterator_tag;
using value_type = typename std::remove_const<Element>::type; using value_type = typename std::remove_const<Element>::type;
using const_iterator = RepeatedPtrIterator<const value_type>;
using nonconst_iterator = RepeatedPtrIterator<value_type>;
using difference_type = std::ptrdiff_t; using difference_type = std::ptrdiff_t;
using pointer = Element*; using pointer = Element*;
using reference = Element&; using reference = Element&;

View File

@ -178,8 +178,8 @@ std::ostream& operator<<(std::ostream& o, const uint128& b) {
if ((flags & std::ios::adjustfield) == std::ios::left) { if ((flags & std::ios::adjustfield) == std::ios::left) {
rep.append(width - repSize, o.fill()); rep.append(width - repSize, o.fill());
} else { } else {
rep.insert(static_cast<std::string::size_type>(0), rep.insert(static_cast<std::string::size_type>(0), width - repSize,
width - repSize, o.fill()); o.fill());
} }
} }

View File

@ -37,7 +37,8 @@
#include <algorithm> #include <algorithm>
#include <google/protobuf/port_def.inc> // Must be last.
#include <google/protobuf/port_def.inc> // NOLINT
namespace google { namespace google {
namespace protobuf { namespace protobuf {
@ -84,6 +85,6 @@ inline char* string_as_array(std::string* str) {
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc> #include <google/protobuf/port_undef.inc> // NOLINT
#endif // GOOGLE_PROTOBUF_STUBS_STL_UTIL_H__ #endif // GOOGLE_PROTOBUF_STUBS_STL_UTIL_H__

View File

@ -36,7 +36,6 @@
#include <google/protobuf/stubs/logging.h> #include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/common.h> #include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/stringprintf.h>
#include <google/protobuf/io/coded_stream.h> #include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h> #include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/descriptor.h> #include <google/protobuf/descriptor.h>
@ -47,6 +46,7 @@
#include <google/protobuf/stubs/strutil.h> #include <google/protobuf/stubs/strutil.h>
#include <google/protobuf/stubs/casts.h> #include <google/protobuf/stubs/casts.h>
#include <google/protobuf/stubs/status.h> #include <google/protobuf/stubs/status.h>
#include <google/protobuf/stubs/stringprintf.h>
#include <google/protobuf/stubs/time.h> #include <google/protobuf/stubs/time.h>
#include <google/protobuf/util/internal/constants.h> #include <google/protobuf/util/internal/constants.h>
#include <google/protobuf/util/internal/field_mask_utility.h> #include <google/protobuf/util/internal/field_mask_utility.h>
@ -1098,11 +1098,11 @@ const std::string FormatNanos(uint32_t nanos, bool with_trailing_zeros) {
return with_trailing_zeros ? ".000" : ""; return with_trailing_zeros ? ".000" : "";
} }
const char* format = (nanos % 1000 != 0) ? "%.9f" const int precision = (nanos % 1000 != 0) ? 9
: (nanos % 1000000 != 0) ? "%.6f" : (nanos % 1000000 != 0) ? 6
: "%.3f"; : 3;
std::string formatted = std::string formatted = StringPrintf(
StringPrintf(format, static_cast<double>(nanos) / kNanosPerSecond); "%.*f", precision, static_cast<double>(nanos) / kNanosPerSecond);
// remove the leading 0 before decimal. // remove the leading 0 before decimal.
return formatted.substr(1); return formatted.substr(1);
} }

View File

@ -44,7 +44,6 @@
#include <google/protobuf/stubs/logging.h> #include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/common.h> #include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/stringprintf.h>
#include <google/protobuf/io/printer.h> #include <google/protobuf/io/printer.h>
#include <google/protobuf/io/zero_copy_stream.h> #include <google/protobuf/io/zero_copy_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h> #include <google/protobuf/io/zero_copy_stream_impl.h>
@ -56,6 +55,7 @@
#include <google/protobuf/message.h> #include <google/protobuf/message.h>
#include <google/protobuf/text_format.h> #include <google/protobuf/text_format.h>
#include <google/protobuf/stubs/strutil.h> #include <google/protobuf/stubs/strutil.h>
#include <google/protobuf/stubs/stringprintf.h>
#include <google/protobuf/util/field_comparator.h> #include <google/protobuf/util/field_comparator.h>
// Always include as last one, otherwise it can break compilation // Always include as last one, otherwise it can break compilation

View File

@ -32,11 +32,11 @@
#include <cstdint> #include <cstdint>
#include <google/protobuf/stubs/stringprintf.h>
#include <google/protobuf/stubs/strutil.h> #include <google/protobuf/stubs/strutil.h>
#include <google/protobuf/duration.pb.h> #include <google/protobuf/duration.pb.h>
#include <google/protobuf/timestamp.pb.h> #include <google/protobuf/timestamp.pb.h>
#include <google/protobuf/stubs/int128.h> #include <google/protobuf/stubs/int128.h>
#include <google/protobuf/stubs/stringprintf.h>
#include <google/protobuf/stubs/time.h> #include <google/protobuf/stubs/time.h>
// Must go after other includes. // Must go after other includes.