diff --git a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs index c4c92efd3..42c740ac4 100644 --- a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs +++ b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs @@ -442,5 +442,92 @@ namespace Google.Protobuf var input = new CodedInputStream(new byte[] { 0 }); Assert.Throws(() => input.ReadTag()); } + + [Test] + public void SkipGroup() + { + // Create an output stream with a group in: + // Field 1: string "field 1" + // Field 2: group containing: + // Field 1: fixed int32 value 100 + // Field 2: string "ignore me" + // Field 3: nested group containing + // Field 1: fixed int64 value 1000 + // Field 3: string "field 3" + var stream = new MemoryStream(); + var output = new CodedOutputStream(stream); + output.WriteTag(1, WireFormat.WireType.LengthDelimited); + output.WriteString("field 1"); + + // The outer group... + output.WriteTag(2, WireFormat.WireType.StartGroup); + output.WriteTag(1, WireFormat.WireType.Fixed32); + output.WriteFixed32(100); + output.WriteTag(2, WireFormat.WireType.LengthDelimited); + output.WriteString("ignore me"); + // The nested group... + output.WriteTag(3, WireFormat.WireType.StartGroup); + output.WriteTag(1, WireFormat.WireType.Fixed64); + output.WriteFixed64(1000); + // Note: Not sure the field number is relevant for end group... + output.WriteTag(3, WireFormat.WireType.EndGroup); + + // End the outer group + output.WriteTag(2, WireFormat.WireType.EndGroup); + + output.WriteTag(3, WireFormat.WireType.LengthDelimited); + output.WriteString("field 3"); + output.Flush(); + stream.Position = 0; + + // Now act like a generated client + var input = new CodedInputStream(stream); + Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited), input.ReadTag()); + Assert.AreEqual("field 1", input.ReadString()); + Assert.AreEqual(WireFormat.MakeTag(2, WireFormat.WireType.StartGroup), input.ReadTag()); + input.SkipLastField(); // Should consume the whole group, including the nested one. + Assert.AreEqual(WireFormat.MakeTag(3, WireFormat.WireType.LengthDelimited), input.ReadTag()); + Assert.AreEqual("field 3", input.ReadString()); + } + + [Test] + public void EndOfStreamReachedWhileSkippingGroup() + { + var stream = new MemoryStream(); + var output = new CodedOutputStream(stream); + output.WriteTag(1, WireFormat.WireType.StartGroup); + output.WriteTag(2, WireFormat.WireType.StartGroup); + output.WriteTag(2, WireFormat.WireType.EndGroup); + + output.Flush(); + stream.Position = 0; + + // Now act like a generated client + var input = new CodedInputStream(stream); + input.ReadTag(); + Assert.Throws(() => input.SkipLastField()); + } + + [Test] + public void RecursionLimitAppliedWhileSkippingGroup() + { + var stream = new MemoryStream(); + var output = new CodedOutputStream(stream); + for (int i = 0; i < CodedInputStream.DefaultRecursionLimit + 1; i++) + { + output.WriteTag(1, WireFormat.WireType.StartGroup); + } + for (int i = 0; i < CodedInputStream.DefaultRecursionLimit + 1; i++) + { + output.WriteTag(1, WireFormat.WireType.EndGroup); + } + output.Flush(); + stream.Position = 0; + + // Now act like a generated client + var input = new CodedInputStream(stream); + Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.StartGroup), input.ReadTag()); + Assert.Throws(() => input.SkipLastField()); + } } } \ No newline at end of file diff --git a/csharp/src/Google.Protobuf/CodedInputStream.cs b/csharp/src/Google.Protobuf/CodedInputStream.cs index 0e2495f1d..a37fefc10 100644 --- a/csharp/src/Google.Protobuf/CodedInputStream.cs +++ b/csharp/src/Google.Protobuf/CodedInputStream.cs @@ -236,17 +236,16 @@ namespace Google.Protobuf #region Validation /// - /// Verifies that the last call to ReadTag() returned the given tag value. - /// This is used to verify that a nested group ended with the correct - /// end tag. + /// Verifies that the last call to ReadTag() returned tag 0 - in other words, + /// we've reached the end of the stream when we expected to. /// - /// The last + /// The /// tag read was not the one specified - internal void CheckLastTagWas(uint value) + internal void CheckReadEndOfStreamTag() { - if (lastTag != value) + if (lastTag != 0) { - throw InvalidProtocolBufferException.InvalidEndTag(); + throw InvalidProtocolBufferException.MoreDataAvailable(); } } #endregion @@ -275,6 +274,11 @@ namespace Google.Protobuf /// /// Reads a field tag, returning the tag of 0 for "end of stream". /// + /// + /// If this method returns 0, it doesn't necessarily mean the end of all + /// the data in this CodedInputStream; it may be the end of the logical stream + /// for an embedded message, for example. + /// /// The next field tag, or 0 for end of stream. (0 is never a valid tag.) public uint ReadTag() { @@ -329,22 +333,24 @@ namespace Google.Protobuf } /// - /// Consumes the data for the field with the tag we've just read. + /// Skips the data for the field with the tag we've just read. /// This should be called directly after , when /// the caller wishes to skip an unknown field. /// - public void ConsumeLastField() + public void SkipLastField() { if (lastTag == 0) { - throw new InvalidOperationException("ConsumeLastField cannot be called at the end of a stream"); + throw new InvalidOperationException("SkipLastField cannot be called at the end of a stream"); } switch (WireFormat.GetTagWireType(lastTag)) { case WireFormat.WireType.StartGroup: + ConsumeGroup(); + break; case WireFormat.WireType.EndGroup: - // TODO: Work out how to skip them instead? See issue 688. - throw new InvalidProtocolBufferException("Group tags not supported by proto3 C# implementation"); + // Just ignore; there's no data following the tag. + break; case WireFormat.WireType.Fixed32: ReadFixed32(); break; @@ -361,6 +367,29 @@ namespace Google.Protobuf } } + private void ConsumeGroup() + { + // Note: Currently we expect this to be the way that groups are read. We could put the recursion + // depth changes into the ReadTag method instead, potentially... + recursionDepth++; + if (recursionDepth >= recursionLimit) + { + throw InvalidProtocolBufferException.RecursionLimitExceeded(); + } + uint tag; + do + { + tag = ReadTag(); + if (tag == 0) + { + throw InvalidProtocolBufferException.TruncatedMessage(); + } + // This recursion will allow us to handle nested groups. + SkipLastField(); + } while (WireFormat.GetTagWireType(tag) != WireFormat.WireType.EndGroup); + recursionDepth--; + } + /// /// Reads a double field from the stream. /// @@ -475,7 +504,7 @@ namespace Google.Protobuf int oldLimit = PushLimit(length); ++recursionDepth; builder.MergeFrom(this); - CheckLastTagWas(0); + CheckReadEndOfStreamTag(); // Check that we've read exactly as much data as expected. if (!ReachedLimit) { diff --git a/csharp/src/Google.Protobuf/Collections/MapField.cs b/csharp/src/Google.Protobuf/Collections/MapField.cs index 5eb2c2fcb..dc4b04cbb 100644 --- a/csharp/src/Google.Protobuf/Collections/MapField.cs +++ b/csharp/src/Google.Protobuf/Collections/MapField.cs @@ -637,10 +637,9 @@ namespace Google.Protobuf.Collections { Value = codec.valueCodec.Read(input); } - else if (WireFormat.IsEndGroupTag(tag)) + else { - // TODO(jonskeet): Do we need this? (Given that we don't support groups...) - return; + input.SkipLastField(); } } } diff --git a/csharp/src/Google.Protobuf/FieldCodec.cs b/csharp/src/Google.Protobuf/FieldCodec.cs index 15d52c7d5..20a1f438f 100644 --- a/csharp/src/Google.Protobuf/FieldCodec.cs +++ b/csharp/src/Google.Protobuf/FieldCodec.cs @@ -304,12 +304,13 @@ namespace Google.Protobuf { value = codec.Read(input); } - if (WireFormat.IsEndGroupTag(tag)) + else { - break; + input.SkipLastField(); } + } - input.CheckLastTagWas(0); + input.CheckReadEndOfStreamTag(); input.PopLimit(oldLimit); return value; diff --git a/csharp/src/Google.Protobuf/MessageExtensions.cs b/csharp/src/Google.Protobuf/MessageExtensions.cs index ee78dc8dd..d2d057c0c 100644 --- a/csharp/src/Google.Protobuf/MessageExtensions.cs +++ b/csharp/src/Google.Protobuf/MessageExtensions.cs @@ -50,7 +50,7 @@ namespace Google.Protobuf Preconditions.CheckNotNull(data, "data"); CodedInputStream input = new CodedInputStream(data); message.MergeFrom(input); - input.CheckLastTagWas(0); + input.CheckReadEndOfStreamTag(); } /// @@ -64,7 +64,7 @@ namespace Google.Protobuf Preconditions.CheckNotNull(data, "data"); CodedInputStream input = data.CreateCodedInput(); message.MergeFrom(input); - input.CheckLastTagWas(0); + input.CheckReadEndOfStreamTag(); } /// @@ -78,7 +78,7 @@ namespace Google.Protobuf Preconditions.CheckNotNull(input, "input"); CodedInputStream codedInput = new CodedInputStream(input); message.MergeFrom(codedInput); - codedInput.CheckLastTagWas(0); + codedInput.CheckReadEndOfStreamTag(); } /// diff --git a/csharp/src/Google.Protobuf/WireFormat.cs b/csharp/src/Google.Protobuf/WireFormat.cs index bbd7e4f96..b0e4a41f0 100644 --- a/csharp/src/Google.Protobuf/WireFormat.cs +++ b/csharp/src/Google.Protobuf/WireFormat.cs @@ -98,16 +98,6 @@ namespace Google.Protobuf return (WireType) (tag & TagTypeMask); } - /// - /// Determines whether the given tag is an end group tag. - /// - /// The tag to check. - /// true if the given tag is an end group tag; false otherwise. - public static bool IsEndGroupTag(uint tag) - { - return (WireType) (tag & TagTypeMask) == WireType.EndGroup; - } - /// /// Given a tag value, determines the field number (the upper 29 bits). /// diff --git a/src/google/protobuf/compiler/csharp/csharp_message.cc b/src/google/protobuf/compiler/csharp/csharp_message.cc index 40c13de5b..a71a7909f 100644 --- a/src/google/protobuf/compiler/csharp/csharp_message.cc +++ b/src/google/protobuf/compiler/csharp/csharp_message.cc @@ -423,10 +423,7 @@ void MessageGenerator::GenerateMergingMethods(io::Printer* printer) { printer->Indent(); printer->Print( "default:\n" - " if (pb::WireFormat.IsEndGroupTag(tag)) {\n" - " return;\n" - " }\n" - " input.ConsumeLastField();\n" // We're not storing the data, but we still need to consume it. + " input.SkipLastField();\n" // We're not storing the data, but we still need to consume it. " break;\n"); for (int i = 0; i < fields_by_number().size(); i++) { const FieldDescriptor* field = fields_by_number()[i];