From cad0258d1758a0ca9f1b9725103bcdbae26697e1 Mon Sep 17 00:00:00 2001 From: Feng Xiao Date: Tue, 11 Apr 2017 16:08:16 -0700 Subject: [PATCH 1/2] Cherry-pick cl/151775298 --- src/google/protobuf/map_test.cc | 27 ++++++++++++++++++++++++++ src/google/protobuf/wire_format.cc | 31 ++++++++++++++++++++++++++++-- 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/src/google/protobuf/map_test.cc b/src/google/protobuf/map_test.cc index 3fe445121..e10a42783 100644 --- a/src/google/protobuf/map_test.cc +++ b/src/google/protobuf/map_test.cc @@ -2811,6 +2811,33 @@ TEST(WireFormatForMapFieldTest, SerializeMap) { EXPECT_TRUE(dynamic_data == generated_data); } +TEST(WireFormatForMapFieldTest, SerializeMapDynamicMessage) { + DynamicMessageFactory factory; + google::protobuf::scoped_ptr dynamic_message; + dynamic_message.reset( + factory.GetPrototype(unittest::TestMap::descriptor())->New()); + MapReflectionTester reflection_tester( + unittest::TestMap::descriptor()); + reflection_tester.SetMapFieldsViaReflection(dynamic_message.get()); + reflection_tester.ExpectMapFieldsSetViaReflection(*dynamic_message); + + unittest::TestMap generated_message; + MapTestUtil::SetMapFields(&generated_message); + MapTestUtil::ExpectMapFieldsSet(generated_message); + + string generated_data; + string dynamic_data; + + // Serialize. + generated_message.SerializeToString(&generated_data); + dynamic_message->SerializeToString(&dynamic_data); + + // Because map serialization doesn't guarantee order, we just compare + // serialized size here. This is enough to tell dynamic message doesn't miss + // anything in serialization. + EXPECT_TRUE(dynamic_data.size() == generated_data.size()); +} + TEST(WireFormatForMapFieldTest, MapParseHelpers) { string data; diff --git a/src/google/protobuf/wire_format.cc b/src/google/protobuf/wire_format.cc index d3d21c093..7778ecd19 100644 --- a/src/google/protobuf/wire_format.cc +++ b/src/google/protobuf/wire_format.cc @@ -797,7 +797,16 @@ void WireFormat::SerializeWithCachedSizes( int expected_endpoint = output->ByteCount() + size; std::vector fields; - message_reflection->ListFields(message, &fields); + + // Fields of map entry should always be serialized. + if (descriptor->options().map_entry()) { + for (int i = 0; i < descriptor->field_count(); i++) { + fields.push_back(descriptor->field(i)); + } + } else { + message_reflection->ListFields(message, &fields); + } + for (int i = 0; i < fields.size(); i++) { SerializeFieldWithCachedSizes(fields[i], message, output); } @@ -834,6 +843,9 @@ void WireFormat::SerializeFieldWithCachedSizes( if (field->is_repeated()) { count = message_reflection->FieldSize(message, field); + } else if (field->containing_type()->options().map_entry()) { + // Map entry fields always need to be serialized. + count = 1; } else if (message_reflection->HasField(message, field)) { count = 1; } @@ -984,7 +996,16 @@ size_t WireFormat::ByteSize(const Message& message) { size_t our_size = 0; std::vector fields; - message_reflection->ListFields(message, &fields); + + // Fields of map entry should always be serialized. + if (descriptor->options().map_entry()) { + for (int i = 0; i < descriptor->field_count(); i++) { + fields.push_back(descriptor->field(i)); + } + } else { + message_reflection->ListFields(message, &fields); + } + for (int i = 0; i < fields.size(); i++) { our_size += FieldByteSize(fields[i], message); } @@ -1015,6 +1036,9 @@ size_t WireFormat::FieldByteSize( size_t count = 0; if (field->is_repeated()) { count = FromIntSize(message_reflection->FieldSize(message, field)); + } else if (field->containing_type()->options().map_entry()) { + // Map entry fields always need to be serialized. + count = 1; } else if (message_reflection->HasField(message, field)) { count = 1; } @@ -1044,6 +1068,9 @@ size_t WireFormat::FieldDataOnlyByteSize( if (field->is_repeated()) { count = internal::FromIntSize(message_reflection->FieldSize(message, field)); + } else if (field->containing_type()->options().map_entry()) { + // Map entry fields always need to be serialized. + count = 1; } else if (message_reflection->HasField(message, field)) { count = 1; } From 57772592738cf9ec1464fe2232f4db195875e95d Mon Sep 17 00:00:00 2001 From: Feng Xiao Date: Tue, 11 Apr 2017 16:08:48 -0700 Subject: [PATCH 2/2] Cherry-pick cl/152450543 --- src/google/protobuf/map_field.cc | 14 ++ src/google/protobuf/map_field.h | 2 + src/google/protobuf/map_test.cc | 128 +++++++++++- src/google/protobuf/map_test_util.cc | 16 ++ src/google/protobuf/map_test_util.h | 5 + src/google/protobuf/message.h | 11 + src/google/protobuf/reflection_ops.cc | 22 ++ src/google/protobuf/wire_format.cc | 276 +++++++++++++++++++++++++- 8 files changed, 472 insertions(+), 2 deletions(-) diff --git a/src/google/protobuf/map_field.cc b/src/google/protobuf/map_field.cc index 4cde0aaa0..64dcc990f 100644 --- a/src/google/protobuf/map_field.cc +++ b/src/google/protobuf/map_field.cc @@ -67,6 +67,13 @@ size_t MapFieldBase::SpaceUsedExcludingSelfNoLock() const { } } +bool MapFieldBase::IsMapValid() const { + // "Acquire" insures the operation after SyncRepeatedFieldWithMap won't get + // executed before state_ is checked. + Atomic32 state = google::protobuf::internal::Acquire_Load(&state_); + return state != STATE_MODIFIED_REPEATED; +} + void MapFieldBase::SetMapDirty() { state_ = STATE_MODIFIED_MAP; } void MapFieldBase::SetRepeatedDirty() { state_ = STATE_MODIFIED_REPEATED; } @@ -359,6 +366,13 @@ void DynamicMapField::SyncMapWithRepeatedFieldNoLock() const { GOOGLE_LOG(FATAL) << "Can't get here."; break; } + + // Remove existing map value with same key. + Map::iterator iter = map->find(map_key); + if (iter != map->end()) { + iter->second.DeleteData(); + } + MapValueRef& map_val = (*map)[map_key]; map_val.SetType(val_des->cpp_type()); switch (val_des->cpp_type()) { diff --git a/src/google/protobuf/map_field.h b/src/google/protobuf/map_field.h index 6d9040763..9d5a328eb 100644 --- a/src/google/protobuf/map_field.h +++ b/src/google/protobuf/map_field.h @@ -86,6 +86,8 @@ class LIBPROTOBUF_EXPORT MapFieldBase { virtual bool ContainsMapKey(const MapKey& map_key) const = 0; virtual bool InsertOrLookupMapValue( const MapKey& map_key, MapValueRef* val) = 0; + // Insures operations after won't get executed before calling this. + bool IsMapValid() const; virtual bool DeleteMapValue(const MapKey& map_key) = 0; virtual bool EqualIterator(const MapIterator& a, const MapIterator& b) const = 0; diff --git a/src/google/protobuf/map_test.cc b/src/google/protobuf/map_test.cc index e10a42783..a06b432aa 100644 --- a/src/google/protobuf/map_test.cc +++ b/src/google/protobuf/map_test.cc @@ -975,6 +975,11 @@ static int Int(const string& value) { class MapFieldReflectionTest : public testing::Test { protected: typedef FieldDescriptor FD; + + int MapSize(const Reflection* reflection, const FieldDescriptor* field, + const Message& message) { + return reflection->MapSize(message, field); + } }; TEST_F(MapFieldReflectionTest, RegularFields) { @@ -1782,6 +1787,50 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefMergeFromAndSwap) { // TODO(teboring): add test for duplicated key } +TEST_F(MapFieldReflectionTest, MapSizeWithDuplicatedKey) { + // Dynamic Message + { + DynamicMessageFactory factory; + google::protobuf::scoped_ptr message( + factory.GetPrototype(unittest::TestMap::descriptor())->New()); + const Reflection* reflection = message->GetReflection(); + const FieldDescriptor* field = + unittest::TestMap::descriptor()->FindFieldByName("map_int32_int32"); + + Message* entry1 = reflection->AddMessage(message.get(), field); + Message* entry2 = reflection->AddMessage(message.get(), field); + + const Reflection* entry_reflection = entry1->GetReflection(); + const FieldDescriptor* key_field = + entry1->GetDescriptor()->FindFieldByName("key"); + entry_reflection->SetInt32(entry1, key_field, 1); + entry_reflection->SetInt32(entry2, key_field, 1); + + EXPECT_EQ(2, reflection->FieldSize(*message, field)); + EXPECT_EQ(1, MapSize(reflection, field, *message)); + } + + // Generated Message + { + unittest::TestMap message; + const Reflection* reflection = message.GetReflection(); + const FieldDescriptor* field = + message.GetDescriptor()->FindFieldByName("map_int32_int32"); + + Message* entry1 = reflection->AddMessage(&message, field); + Message* entry2 = reflection->AddMessage(&message, field); + + const Reflection* entry_reflection = entry1->GetReflection(); + const FieldDescriptor* key_field = + entry1->GetDescriptor()->FindFieldByName("key"); + entry_reflection->SetInt32(entry1, key_field, 1); + entry_reflection->SetInt32(entry2, key_field, 1); + + EXPECT_EQ(2, reflection->FieldSize(message, field)); + EXPECT_EQ(1, MapSize(reflection, field, message)); + } +} + // Generated Message Test =========================================== TEST(GeneratedMapFieldTest, Accessors) { @@ -2689,6 +2738,69 @@ TEST_F(MapFieldInDynamicMessageTest, RecursiveMap) { ASSERT_TRUE(to->ParseFromString(data)); } +TEST_F(MapFieldInDynamicMessageTest, MapValueReferernceValidAfterSerialize) { + google::protobuf::scoped_ptr message(map_prototype_->New()); + MapReflectionTester reflection_tester(map_descriptor_); + reflection_tester.SetMapFieldsViaMapReflection(message.get()); + + // Get value reference before serialization, so that we know the value is from + // map. + MapKey map_key; + MapValueRef map_val; + map_key.SetInt32Value(0); + reflection_tester.GetMapValueViaMapReflection( + message.get(), "map_int32_foreign_message", map_key, &map_val); + Message* submsg = map_val.MutableMessageValue(); + + // In previous implementation, calling SerializeToString will cause syncing + // from map to repeated field, which will invalidate the submsg we previously + // got. + string data; + message->SerializeToString(&data); + + const Reflection* submsg_reflection = submsg->GetReflection(); + const Descriptor* submsg_desc = submsg->GetDescriptor(); + const FieldDescriptor* submsg_field = submsg_desc->FindFieldByName("c"); + submsg_reflection->SetInt32(submsg, submsg_field, 128); + + message->SerializeToString(&data); + TestMap to; + to.ParseFromString(data); + EXPECT_EQ(128, to.map_int32_foreign_message().at(0).c()); +} + +TEST_F(MapFieldInDynamicMessageTest, MapEntryReferernceValidAfterSerialize) { + google::protobuf::scoped_ptr message(map_prototype_->New()); + MapReflectionTester reflection_tester(map_descriptor_); + reflection_tester.SetMapFieldsViaReflection(message.get()); + + // Get map entry before serialization, so that we know the it is from + // repeated field. + Message* map_entry = reflection_tester.GetMapEntryViaReflection( + message.get(), "map_int32_foreign_message", 0); + const Reflection* map_entry_reflection = map_entry->GetReflection(); + const Descriptor* map_entry_desc = map_entry->GetDescriptor(); + const FieldDescriptor* value_field = map_entry_desc->FindFieldByName("value"); + Message* submsg = + map_entry_reflection->MutableMessage(map_entry, value_field); + + // In previous implementation, calling SerializeToString will cause syncing + // from repeated field to map, which will invalidate the map_entry we + // previously got. + string data; + message->SerializeToString(&data); + + const Reflection* submsg_reflection = submsg->GetReflection(); + const Descriptor* submsg_desc = submsg->GetDescriptor(); + const FieldDescriptor* submsg_field = submsg_desc->FindFieldByName("c"); + submsg_reflection->SetInt32(submsg, submsg_field, 128); + + message->SerializeToString(&data); + TestMap to; + to.ParseFromString(data); + EXPECT_EQ(128, to.map_int32_foreign_message().at(0).c()); +} + // ReflectionOps Test =============================================== TEST(ReflectionOpsForMapFieldTest, MapSanityCheck) { @@ -2751,6 +2863,20 @@ TEST(ReflectionOpsForMapFieldTest, MapDiscardUnknownFields) { GetUnknownFields(message).field_count()); } +TEST(ReflectionOpsForMapFieldTest, IsInitialized) { + unittest::TestRequiredMessageMap map_message; + + // Add an uninitialized message. + (*map_message.mutable_map_field())[0]; + EXPECT_FALSE(ReflectionOps::IsInitialized(map_message)); + + // Initialize uninitialized message + (*map_message.mutable_map_field())[0].set_a(0); + (*map_message.mutable_map_field())[0].set_b(0); + (*map_message.mutable_map_field())[0].set_c(0); + EXPECT_TRUE(ReflectionOps::IsInitialized(map_message)); +} + // Wire Format Test ================================================= TEST(WireFormatForMapFieldTest, ParseMap) { @@ -3089,7 +3215,7 @@ TEST(ArenaTest, ParsingAndSerializingNoHeapAllocation) { } // Use text format parsing and serializing to test reflection api. -TEST(ArenaTest, RelfectionInTextFormat) { +TEST(ArenaTest, ReflectionInTextFormat) { Arena arena; string data; diff --git a/src/google/protobuf/map_test_util.cc b/src/google/protobuf/map_test_util.cc index 3dd6aae55..4d3ad6092 100644 --- a/src/google/protobuf/map_test_util.cc +++ b/src/google/protobuf/map_test_util.cc @@ -744,6 +744,22 @@ void MapReflectionTester::SetMapFieldsViaMapReflection( sub_foreign_message, foreign_c_, 1); } +void MapReflectionTester::GetMapValueViaMapReflection(Message* message, + const string& field_name, + const MapKey& map_key, + MapValueRef* map_val) { + const Reflection* reflection = message->GetReflection(); + EXPECT_FALSE(reflection->InsertOrLookupMapValue(message, F(field_name), + map_key, map_val)); +} + +Message* MapReflectionTester::GetMapEntryViaReflection(Message* message, + const string& field_name, + int index) { + const Reflection* reflection = message->GetReflection(); + return reflection->MutableRepeatedMessage(message, F(field_name), index); +} + void MapReflectionTester::ClearMapFieldsViaReflection( Message* message) { const Reflection* reflection = message->GetReflection(); diff --git a/src/google/protobuf/map_test_util.h b/src/google/protobuf/map_test_util.h index deaf0f4f4..15c6c2894 100644 --- a/src/google/protobuf/map_test_util.h +++ b/src/google/protobuf/map_test_util.h @@ -106,6 +106,11 @@ class MapReflectionTester { void ExpectClearViaReflection(const Message& message); void ExpectClearViaReflectionIterator(Message* message); void ExpectMapEntryClearViaReflection(Message* message); + void GetMapValueViaMapReflection(Message* message, + const string& field_name, + const MapKey& map_key, MapValueRef* map_val); + Message* GetMapEntryViaReflection(Message* message, const string& field_name, + int index); private: const FieldDescriptor* F(const string& name); diff --git a/src/google/protobuf/message.h b/src/google/protobuf/message.h index 7d9bb8a9e..68acb5b13 100644 --- a/src/google/protobuf/message.h +++ b/src/google/protobuf/message.h @@ -154,6 +154,13 @@ class MapReflectionFriend; // scalar_map_container.h } +namespace internal { +class ReflectionOps; // reflection_ops.h +class MapKeySorter; // wire_format.cc +class WireFormat; // wire_format.h +class MapFieldReflectionTest; // map_test.cc +} + template class RepeatedField; // repeated_field.h @@ -936,6 +943,10 @@ class LIBPROTOBUF_EXPORT Reflection { template friend class MutableRepeatedFieldRef; friend class ::google::protobuf::python::MapReflectionFriend; + friend class internal::MapFieldReflectionTest; + friend class internal::MapKeySorter; + friend class internal::WireFormat; + friend class internal::ReflectionOps; // Special version for specialized implementations of string. We can't call // MutableRawRepeatedField directly here because we don't have access to diff --git a/src/google/protobuf/reflection_ops.cc b/src/google/protobuf/reflection_ops.cc index bb9c7f8bf..d18673118 100644 --- a/src/google/protobuf/reflection_ops.cc +++ b/src/google/protobuf/reflection_ops.cc @@ -38,6 +38,7 @@ #include #include #include +#include #include #include @@ -158,6 +159,27 @@ bool ReflectionOps::IsInitialized(const Message& message) { const FieldDescriptor* field = fields[i]; if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + if (field->is_map()) { + const FieldDescriptor* value_field = field->message_type()->field(1); + if (value_field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + MapFieldBase* map_field = + reflection->MapData(const_cast(&message), field); + if (map_field->IsMapValid()) { + MapIterator iter(const_cast(&message), field); + MapIterator end(const_cast(&message), field); + for (map_field->MapBegin(&iter), map_field->MapEnd(&end); + iter != end; ++iter) { + if (!iter.GetValueRef().GetMessageValue().IsInitialized()) { + return false; + } + } + continue; + } + } else { + continue; + } + } + if (field->is_repeated()) { int size = reflection->FieldSize(message, field); diff --git a/src/google/protobuf/wire_format.cc b/src/google/protobuf/wire_format.cc index 7778ecd19..01704c947 100644 --- a/src/google/protobuf/wire_format.cc +++ b/src/google/protobuf/wire_format.cc @@ -54,9 +54,17 @@ namespace google { +const size_t kMapEntryTagByteSize = 2; + namespace protobuf { namespace internal { +// Forward declare static functions +static size_t MapKeyDataOnlyByteSize(const FieldDescriptor* field, + const MapKey& value); +static size_t MapValueRefDataOnlyByteSize(const FieldDescriptor* field, + const MapValueRef& value); + // =================================================================== bool UnknownFieldSetFieldSkipper::SkipField( @@ -825,6 +833,129 @@ void WireFormat::SerializeWithCachedSizes( "during serialization?"; } +static void SerializeMapKeyWithCachedSizes(const FieldDescriptor* field, + const MapKey& value, + io::CodedOutputStream* output) { + switch (field->type()) { + case FieldDescriptor::TYPE_DOUBLE: + case FieldDescriptor::TYPE_FLOAT: + case FieldDescriptor::TYPE_GROUP: + case FieldDescriptor::TYPE_MESSAGE: + case FieldDescriptor::TYPE_BYTES: + case FieldDescriptor::TYPE_ENUM: + GOOGLE_LOG(FATAL) << "Unsupported"; + break; +#define CASE_TYPE(FieldType, CamelFieldType, CamelCppType) \ + case FieldDescriptor::TYPE_##FieldType: \ + WireFormatLite::Write##CamelFieldType(1, value.Get##CamelCppType##Value(), \ + output); \ + break; + CASE_TYPE(INT64, Int64, Int64) + CASE_TYPE(UINT64, UInt64, UInt64) + CASE_TYPE(INT32, Int32, Int32) + CASE_TYPE(FIXED64, Fixed64, UInt64) + CASE_TYPE(FIXED32, Fixed32, UInt32) + CASE_TYPE(BOOL, Bool, Bool) + CASE_TYPE(UINT32, UInt32, UInt32) + CASE_TYPE(SFIXED32, SFixed32, Int32) + CASE_TYPE(SFIXED64, SFixed64, Int64) + CASE_TYPE(SINT32, SInt32, Int32) + CASE_TYPE(SINT64, SInt64, Int64) + CASE_TYPE(STRING, String, String) +#undef CASE_TYPE + } +} + +static void SerializeMapValueRefWithCachedSizes(const FieldDescriptor* field, + const MapValueRef& value, + io::CodedOutputStream* output) { + switch (field->type()) { +#define CASE_TYPE(FieldType, CamelFieldType, CamelCppType) \ + case FieldDescriptor::TYPE_##FieldType: \ + WireFormatLite::Write##CamelFieldType(2, value.Get##CamelCppType##Value(), \ + output); \ + break; + CASE_TYPE(INT64, Int64, Int64) + CASE_TYPE(UINT64, UInt64, UInt64) + CASE_TYPE(INT32, Int32, Int32) + CASE_TYPE(FIXED64, Fixed64, UInt64) + CASE_TYPE(FIXED32, Fixed32, UInt32) + CASE_TYPE(BOOL, Bool, Bool) + CASE_TYPE(UINT32, UInt32, UInt32) + CASE_TYPE(SFIXED32, SFixed32, Int32) + CASE_TYPE(SFIXED64, SFixed64, Int64) + CASE_TYPE(SINT32, SInt32, Int32) + CASE_TYPE(SINT64, SInt64, Int64) + CASE_TYPE(ENUM, Enum, Enum) + CASE_TYPE(DOUBLE, Double, Double) + CASE_TYPE(FLOAT, Float, Float) + CASE_TYPE(STRING, String, String) + CASE_TYPE(BYTES, Bytes, String) + CASE_TYPE(MESSAGE, Message, Message) + CASE_TYPE(GROUP, Group, Message) +#undef CASE_TYPE + } +} + +class MapKeySorter { + public: + static std::vector SortKey(const Message& message, + const Reflection* reflection, + const FieldDescriptor* field) { + std::vector sorted_key_list; + for (MapIterator it = + reflection->MapBegin(const_cast(&message), field); + it != reflection->MapEnd(const_cast(&message), field); + ++it) { + sorted_key_list.push_back(it.GetKey()); + } + MapKeyComparator comparator; + std::sort(sorted_key_list.begin(), sorted_key_list.end(), comparator); + return sorted_key_list; + } + + private: + class MapKeyComparator { + public: + bool operator()(const MapKey& a, const MapKey& b) const { + GOOGLE_DCHECK(a.type() == b.type()); + switch (a.type()) { +#define CASE_TYPE(CppType, CamelCppType) \ + case FieldDescriptor::CPPTYPE_##CppType: { \ + return a.Get##CamelCppType##Value() < b.Get##CamelCppType##Value(); \ + } + CASE_TYPE(STRING, String) + CASE_TYPE(INT64, Int64) + CASE_TYPE(INT32, Int32) + CASE_TYPE(UINT64, UInt64) + CASE_TYPE(UINT32, UInt32) + CASE_TYPE(BOOL, Bool) +#undef CASE_TYPE + + default: + GOOGLE_LOG(DFATAL) << "Invalid key for map field."; + return true; + } + } + }; +}; + +static void SerializeMapEntry(const FieldDescriptor* field, const MapKey& key, + const MapValueRef& value, + io::CodedOutputStream* output) { + const FieldDescriptor* key_field = field->message_type()->field(0); + const FieldDescriptor* value_field = field->message_type()->field(1); + + WireFormatLite::WriteTag(field->number(), + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output); + size_t size = kMapEntryTagByteSize; + size += MapKeyDataOnlyByteSize(key_field, key); + size += MapValueRefDataOnlyByteSize(value_field, value); + output->WriteVarint32(size); + SerializeMapKeyWithCachedSizes(key_field, key, output); + SerializeMapValueRefWithCachedSizes(value_field, value, output); +} + void WireFormat::SerializeFieldWithCachedSizes( const FieldDescriptor* field, const Message& message, @@ -839,6 +970,48 @@ void WireFormat::SerializeFieldWithCachedSizes( return; } + // For map fields, we can use either repeated field reflection or map + // reflection. Our choice has some subtle effects. If we use repeated field + // reflection here, then the repeated field representation becomes + // authoritative for this field: any existing references that came from map + // reflection remain valid for reading, but mutations to them are lost and + // will be overwritten next time we call map reflection! + // + // So far this mainly affects Python, which keeps long-term references to map + // values around, and always uses map reflection. See: b/35918691 + // + // Here we choose to use map reflection API as long as the internal + // map is valid. In this way, the serialization doesn't change map field's + // internal state and existing references that came from map reflection remain + // valid for both reading and writing. + if (field->is_map()) { + MapFieldBase* map_field = + message_reflection->MapData(const_cast(&message), field); + if (map_field->IsMapValid()) { + if (output->IsSerializationDeterministic()) { + std::vector sorted_key_list = + MapKeySorter::SortKey(message, message_reflection, field); + for (std::vector::iterator it = sorted_key_list.begin(); + it != sorted_key_list.end(); ++it) { + MapValueRef map_value; + message_reflection->InsertOrLookupMapValue( + const_cast(&message), field, *it, &map_value); + SerializeMapEntry(field, *it, map_value, output); + } + } else { + for (MapIterator it = message_reflection->MapBegin( + const_cast(&message), field); + it != + message_reflection->MapEnd(const_cast(&message), field); + ++it) { + SerializeMapEntry(field, it.GetKey(), it.GetValueRef(), output); + } + } + + return; + } + } + int count = 0; if (field->is_repeated()) { @@ -1059,11 +1232,113 @@ size_t WireFormat::FieldByteSize( return our_size; } +static size_t MapKeyDataOnlyByteSize(const FieldDescriptor* field, + const MapKey& value) { + GOOGLE_DCHECK_EQ(FieldDescriptor::TypeToCppType(field->type()), value.type()); + switch (field->type()) { + case FieldDescriptor::TYPE_DOUBLE: + case FieldDescriptor::TYPE_FLOAT: + case FieldDescriptor::TYPE_GROUP: + case FieldDescriptor::TYPE_MESSAGE: + case FieldDescriptor::TYPE_BYTES: + case FieldDescriptor::TYPE_ENUM: + GOOGLE_LOG(FATAL) << "Unsupported"; + return 0; +#define CASE_TYPE(FieldType, CamelFieldType, CamelCppType) \ + case FieldDescriptor::TYPE_##FieldType: \ + return WireFormatLite::CamelFieldType##Size( \ + value.Get##CamelCppType##Value()); + +#define FIXED_CASE_TYPE(FieldType, CamelFieldType) \ + case FieldDescriptor::TYPE_##FieldType: \ + return WireFormatLite::k##CamelFieldType##Size; + + CASE_TYPE(INT32, Int32, Int32); + CASE_TYPE(INT64, Int64, Int64); + CASE_TYPE(UINT32, UInt32, UInt32); + CASE_TYPE(UINT64, UInt64, UInt64); + CASE_TYPE(SINT32, SInt32, Int32); + CASE_TYPE(SINT64, SInt64, Int64); + CASE_TYPE(STRING, String, String); + FIXED_CASE_TYPE(FIXED32, Fixed32); + FIXED_CASE_TYPE(FIXED64, Fixed64); + FIXED_CASE_TYPE(SFIXED32, SFixed32); + FIXED_CASE_TYPE(SFIXED64, SFixed64); + FIXED_CASE_TYPE(BOOL, Bool); + +#undef CASE_TYPE +#undef FIXED_CASE_TYPE + } + GOOGLE_LOG(FATAL) << "Cannot get here"; + return 0; +} + +static size_t MapValueRefDataOnlyByteSize(const FieldDescriptor* field, + const MapValueRef& value) { + switch (field->type()) { + case FieldDescriptor::TYPE_GROUP: + GOOGLE_LOG(FATAL) << "Unsupported"; + return 0; +#define CASE_TYPE(FieldType, CamelFieldType, CamelCppType) \ + case FieldDescriptor::TYPE_##FieldType: \ + return WireFormatLite::CamelFieldType##Size( \ + value.Get##CamelCppType##Value()); + +#define FIXED_CASE_TYPE(FieldType, CamelFieldType) \ + case FieldDescriptor::TYPE_##FieldType: \ + return WireFormatLite::k##CamelFieldType##Size; + + CASE_TYPE(INT32, Int32, Int32); + CASE_TYPE(INT64, Int64, Int64); + CASE_TYPE(UINT32, UInt32, UInt32); + CASE_TYPE(UINT64, UInt64, UInt64); + CASE_TYPE(SINT32, SInt32, Int32); + CASE_TYPE(SINT64, SInt64, Int64); + CASE_TYPE(STRING, String, String); + CASE_TYPE(BYTES, Bytes, String); + CASE_TYPE(ENUM, Enum, Enum); + CASE_TYPE(MESSAGE, Message, Message); + FIXED_CASE_TYPE(FIXED32, Fixed32); + FIXED_CASE_TYPE(FIXED64, Fixed64); + FIXED_CASE_TYPE(SFIXED32, SFixed32); + FIXED_CASE_TYPE(SFIXED64, SFixed64); + FIXED_CASE_TYPE(DOUBLE, Double); + FIXED_CASE_TYPE(FLOAT, Float); + FIXED_CASE_TYPE(BOOL, Bool); + +#undef CASE_TYPE +#undef FIXED_CASE_TYPE + } + GOOGLE_LOG(FATAL) << "Cannot get here"; + return 0; +} + size_t WireFormat::FieldDataOnlyByteSize( const FieldDescriptor* field, const Message& message) { const Reflection* message_reflection = message.GetReflection(); + size_t data_size = 0; + + if (field->is_map()) { + MapFieldBase* map_field = + message_reflection->MapData(const_cast(&message), field); + if (map_field->IsMapValid()) { + MapIterator iter(const_cast(&message), field); + MapIterator end(const_cast(&message), field); + const FieldDescriptor* key_field = field->message_type()->field(0); + const FieldDescriptor* value_field = field->message_type()->field(1); + for (map_field->MapBegin(&iter), map_field->MapEnd(&end); iter != end; + ++iter) { + size_t size = kMapEntryTagByteSize; + size += MapKeyDataOnlyByteSize(key_field, iter.GetKey()); + size += MapValueRefDataOnlyByteSize(value_field, iter.GetValueRef()); + data_size += WireFormatLite::LengthDelimitedSize(size); + } + return data_size; + } + } + size_t count = 0; if (field->is_repeated()) { count = @@ -1075,7 +1350,6 @@ size_t WireFormat::FieldDataOnlyByteSize( count = 1; } - size_t data_size = 0; switch (field->type()) { #define HANDLE_TYPE(TYPE, TYPE_METHOD, CPPTYPE_METHOD) \ case FieldDescriptor::TYPE_##TYPE: \