From 9bdc848832b6f6e27ea4389a72c566ec43329114 Mon Sep 17 00:00:00 2001 From: Jon Skeet Date: Mon, 15 Feb 2016 11:58:01 +0000 Subject: [PATCH] Validate that end-group tags match their corresponding start-group tags This detects: - An end-group tag with the wrong field number (doesn't match the start-group field) - An end-group tag with no preceding start-group tag Fixes issue #688. --- .../CodedInputStreamTest.cs | 50 ++++++++++++++++++- .../GeneratedMessageTest.cs | 7 ++- .../src/Google.Protobuf/CodedInputStream.cs | 32 +++++++++--- 3 files changed, 77 insertions(+), 12 deletions(-) diff --git a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs index 6ae021124..0e7cf04ef 100644 --- a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs +++ b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs @@ -469,6 +469,52 @@ namespace Google.Protobuf Assert.AreEqual("field 3", input.ReadString()); } + [Test] + public void SkipGroup_WrongEndGroupTag() + { + // Create an output stream with: + // Field 1: string "field 1" + // Start group 2 + // Field 3: fixed int32 + // End group 4 (should give an error) + var stream = new MemoryStream(); + var output = new CodedOutputStream(stream); + output.WriteTag(1, WireFormat.WireType.LengthDelimited); + output.WriteString("field 1"); + + // The outer group... + output.WriteTag(2, WireFormat.WireType.StartGroup); + output.WriteTag(3, WireFormat.WireType.Fixed32); + output.WriteFixed32(100); + output.WriteTag(4, WireFormat.WireType.EndGroup); + output.Flush(); + stream.Position = 0; + + // Now act like a generated client + var input = new CodedInputStream(stream); + Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited), input.ReadTag()); + Assert.AreEqual("field 1", input.ReadString()); + Assert.AreEqual(WireFormat.MakeTag(2, WireFormat.WireType.StartGroup), input.ReadTag()); + Assert.Throws(input.SkipLastField); + } + + [Test] + public void RogueEndGroupTag() + { + // If we have an end-group tag without a leading start-group tag, generated + // code will just call SkipLastField... so that should fail. + + var stream = new MemoryStream(); + var output = new CodedOutputStream(stream); + output.WriteTag(1, WireFormat.WireType.EndGroup); + output.Flush(); + stream.Position = 0; + + var input = new CodedInputStream(stream); + Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.EndGroup), input.ReadTag()); + Assert.Throws(input.SkipLastField); + } + [Test] public void EndOfStreamReachedWhileSkippingGroup() { @@ -484,7 +530,7 @@ namespace Google.Protobuf // Now act like a generated client var input = new CodedInputStream(stream); input.ReadTag(); - Assert.Throws(() => input.SkipLastField()); + Assert.Throws(input.SkipLastField); } [Test] @@ -506,7 +552,7 @@ namespace Google.Protobuf // Now act like a generated client var input = new CodedInputStream(stream); Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.StartGroup), input.ReadTag()); - Assert.Throws(() => input.SkipLastField()); + Assert.Throws(input.SkipLastField); } [Test] diff --git a/csharp/src/Google.Protobuf.Test/GeneratedMessageTest.cs b/csharp/src/Google.Protobuf.Test/GeneratedMessageTest.cs index 14cc6d194..67069954a 100644 --- a/csharp/src/Google.Protobuf.Test/GeneratedMessageTest.cs +++ b/csharp/src/Google.Protobuf.Test/GeneratedMessageTest.cs @@ -679,21 +679,20 @@ namespace Google.Protobuf /// for details; we may want to change this. /// [Test] - public void ExtraEndGroupSkipped() + public void ExtraEndGroupThrows() { var message = SampleMessages.CreateFullTestAllTypes(); var stream = new MemoryStream(); var output = new CodedOutputStream(stream); - output.WriteTag(100, WireFormat.WireType.EndGroup); output.WriteTag(TestAllTypes.SingleFixed32FieldNumber, WireFormat.WireType.Fixed32); output.WriteFixed32(123); + output.WriteTag(100, WireFormat.WireType.EndGroup); output.Flush(); stream.Position = 0; - var parsed = TestAllTypes.Parser.ParseFrom(stream); - Assert.AreEqual(new TestAllTypes { SingleFixed32 = 123 }, parsed); + Assert.Throws(() => TestAllTypes.Parser.ParseFrom(stream)); } [Test] diff --git a/csharp/src/Google.Protobuf/CodedInputStream.cs b/csharp/src/Google.Protobuf/CodedInputStream.cs index 91bed8e3c..1c02d9519 100644 --- a/csharp/src/Google.Protobuf/CodedInputStream.cs +++ b/csharp/src/Google.Protobuf/CodedInputStream.cs @@ -349,6 +349,14 @@ namespace Google.Protobuf /// This should be called directly after , when /// the caller wishes to skip an unknown field. /// + /// + /// This method throws if the last-read tag was an end-group tag. + /// If a caller wishes to skip a group, they should skip the whole group, by calling this method after reading the + /// start-group tag. This behavior allows callers to call this method on any field they don't understand, correctly + /// resulting in an error if an end-group tag has not been paired with an earlier start-group tag. + /// + /// The last tag was an end-group tag + /// The last read operation read to the end of the logical stream public void SkipLastField() { if (lastTag == 0) @@ -358,11 +366,11 @@ namespace Google.Protobuf switch (WireFormat.GetTagWireType(lastTag)) { case WireFormat.WireType.StartGroup: - SkipGroup(); + SkipGroup(lastTag); break; case WireFormat.WireType.EndGroup: - // Just ignore; there's no data following the tag. - break; + throw new InvalidProtocolBufferException( + "SkipLastField called on an end-group tag, indicating that the corresponding start-group was missing"); case WireFormat.WireType.Fixed32: ReadFixed32(); break; @@ -379,7 +387,7 @@ namespace Google.Protobuf } } - private void SkipGroup() + private void SkipGroup(uint startGroupTag) { // Note: Currently we expect this to be the way that groups are read. We could put the recursion // depth changes into the ReadTag method instead, potentially... @@ -389,16 +397,28 @@ namespace Google.Protobuf throw InvalidProtocolBufferException.RecursionLimitExceeded(); } uint tag; - do + while (true) { tag = ReadTag(); if (tag == 0) { throw InvalidProtocolBufferException.TruncatedMessage(); } + // Can't call SkipLastField for this case- that would throw. + if (WireFormat.GetTagWireType(tag) == WireFormat.WireType.EndGroup) + { + break; + } // This recursion will allow us to handle nested groups. SkipLastField(); - } while (WireFormat.GetTagWireType(tag) != WireFormat.WireType.EndGroup); + } + int startField = WireFormat.GetTagFieldNumber(startGroupTag); + int endField = WireFormat.GetTagFieldNumber(tag); + if (startField != endField) + { + throw new InvalidProtocolBufferException( + $"Mismatched end-group tag. Started with field {startField}; ended with field {endField}"); + } recursionDepth--; }