Validate that end-group tags match their corresponding start-group tags
This detects: - An end-group tag with the wrong field number (doesn't match the start-group field) - An end-group tag with no preceding start-group tag Fixes issue #688.
This commit is contained in:
parent
e35e24800f
commit
9bdc848832
@ -469,6 +469,52 @@ namespace Google.Protobuf
|
||||
Assert.AreEqual("field 3", input.ReadString());
|
||||
}
|
||||
|
||||
[Test]
|
||||
public void SkipGroup_WrongEndGroupTag()
|
||||
{
|
||||
// Create an output stream with:
|
||||
// Field 1: string "field 1"
|
||||
// Start group 2
|
||||
// Field 3: fixed int32
|
||||
// End group 4 (should give an error)
|
||||
var stream = new MemoryStream();
|
||||
var output = new CodedOutputStream(stream);
|
||||
output.WriteTag(1, WireFormat.WireType.LengthDelimited);
|
||||
output.WriteString("field 1");
|
||||
|
||||
// The outer group...
|
||||
output.WriteTag(2, WireFormat.WireType.StartGroup);
|
||||
output.WriteTag(3, WireFormat.WireType.Fixed32);
|
||||
output.WriteFixed32(100);
|
||||
output.WriteTag(4, WireFormat.WireType.EndGroup);
|
||||
output.Flush();
|
||||
stream.Position = 0;
|
||||
|
||||
// Now act like a generated client
|
||||
var input = new CodedInputStream(stream);
|
||||
Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited), input.ReadTag());
|
||||
Assert.AreEqual("field 1", input.ReadString());
|
||||
Assert.AreEqual(WireFormat.MakeTag(2, WireFormat.WireType.StartGroup), input.ReadTag());
|
||||
Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField);
|
||||
}
|
||||
|
||||
[Test]
|
||||
public void RogueEndGroupTag()
|
||||
{
|
||||
// If we have an end-group tag without a leading start-group tag, generated
|
||||
// code will just call SkipLastField... so that should fail.
|
||||
|
||||
var stream = new MemoryStream();
|
||||
var output = new CodedOutputStream(stream);
|
||||
output.WriteTag(1, WireFormat.WireType.EndGroup);
|
||||
output.Flush();
|
||||
stream.Position = 0;
|
||||
|
||||
var input = new CodedInputStream(stream);
|
||||
Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.EndGroup), input.ReadTag());
|
||||
Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField);
|
||||
}
|
||||
|
||||
[Test]
|
||||
public void EndOfStreamReachedWhileSkippingGroup()
|
||||
{
|
||||
@ -484,7 +530,7 @@ namespace Google.Protobuf
|
||||
// Now act like a generated client
|
||||
var input = new CodedInputStream(stream);
|
||||
input.ReadTag();
|
||||
Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField());
|
||||
Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField);
|
||||
}
|
||||
|
||||
[Test]
|
||||
@ -506,7 +552,7 @@ namespace Google.Protobuf
|
||||
// Now act like a generated client
|
||||
var input = new CodedInputStream(stream);
|
||||
Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.StartGroup), input.ReadTag());
|
||||
Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField());
|
||||
Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField);
|
||||
}
|
||||
|
||||
[Test]
|
||||
|
@ -679,21 +679,20 @@ namespace Google.Protobuf
|
||||
/// for details; we may want to change this.
|
||||
/// </summary>
|
||||
[Test]
|
||||
public void ExtraEndGroupSkipped()
|
||||
public void ExtraEndGroupThrows()
|
||||
{
|
||||
var message = SampleMessages.CreateFullTestAllTypes();
|
||||
var stream = new MemoryStream();
|
||||
var output = new CodedOutputStream(stream);
|
||||
|
||||
output.WriteTag(100, WireFormat.WireType.EndGroup);
|
||||
output.WriteTag(TestAllTypes.SingleFixed32FieldNumber, WireFormat.WireType.Fixed32);
|
||||
output.WriteFixed32(123);
|
||||
output.WriteTag(100, WireFormat.WireType.EndGroup);
|
||||
|
||||
output.Flush();
|
||||
|
||||
stream.Position = 0;
|
||||
var parsed = TestAllTypes.Parser.ParseFrom(stream);
|
||||
Assert.AreEqual(new TestAllTypes { SingleFixed32 = 123 }, parsed);
|
||||
Assert.Throws<InvalidProtocolBufferException>(() => TestAllTypes.Parser.ParseFrom(stream));
|
||||
}
|
||||
|
||||
[Test]
|
||||
|
@ -349,6 +349,14 @@ namespace Google.Protobuf
|
||||
/// This should be called directly after <see cref="ReadTag"/>, when
|
||||
/// the caller wishes to skip an unknown field.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// This method throws <see cref="InvalidProtocolBufferException"/> if the last-read tag was an end-group tag.
|
||||
/// If a caller wishes to skip a group, they should skip the whole group, by calling this method after reading the
|
||||
/// start-group tag. This behavior allows callers to call this method on any field they don't understand, correctly
|
||||
/// resulting in an error if an end-group tag has not been paired with an earlier start-group tag.
|
||||
/// </remarks>
|
||||
/// <exception cref="InvalidProtocolBufferException">The last tag was an end-group tag</exception>
|
||||
/// <exception cref="InvalidOperationException">The last read operation read to the end of the logical stream</exception>
|
||||
public void SkipLastField()
|
||||
{
|
||||
if (lastTag == 0)
|
||||
@ -358,11 +366,11 @@ namespace Google.Protobuf
|
||||
switch (WireFormat.GetTagWireType(lastTag))
|
||||
{
|
||||
case WireFormat.WireType.StartGroup:
|
||||
SkipGroup();
|
||||
SkipGroup(lastTag);
|
||||
break;
|
||||
case WireFormat.WireType.EndGroup:
|
||||
// Just ignore; there's no data following the tag.
|
||||
break;
|
||||
throw new InvalidProtocolBufferException(
|
||||
"SkipLastField called on an end-group tag, indicating that the corresponding start-group was missing");
|
||||
case WireFormat.WireType.Fixed32:
|
||||
ReadFixed32();
|
||||
break;
|
||||
@ -379,7 +387,7 @@ namespace Google.Protobuf
|
||||
}
|
||||
}
|
||||
|
||||
private void SkipGroup()
|
||||
private void SkipGroup(uint startGroupTag)
|
||||
{
|
||||
// Note: Currently we expect this to be the way that groups are read. We could put the recursion
|
||||
// depth changes into the ReadTag method instead, potentially...
|
||||
@ -389,16 +397,28 @@ namespace Google.Protobuf
|
||||
throw InvalidProtocolBufferException.RecursionLimitExceeded();
|
||||
}
|
||||
uint tag;
|
||||
do
|
||||
while (true)
|
||||
{
|
||||
tag = ReadTag();
|
||||
if (tag == 0)
|
||||
{
|
||||
throw InvalidProtocolBufferException.TruncatedMessage();
|
||||
}
|
||||
// Can't call SkipLastField for this case- that would throw.
|
||||
if (WireFormat.GetTagWireType(tag) == WireFormat.WireType.EndGroup)
|
||||
{
|
||||
break;
|
||||
}
|
||||
// This recursion will allow us to handle nested groups.
|
||||
SkipLastField();
|
||||
} while (WireFormat.GetTagWireType(tag) != WireFormat.WireType.EndGroup);
|
||||
}
|
||||
int startField = WireFormat.GetTagFieldNumber(startGroupTag);
|
||||
int endField = WireFormat.GetTagFieldNumber(tag);
|
||||
if (startField != endField)
|
||||
{
|
||||
throw new InvalidProtocolBufferException(
|
||||
$"Mismatched end-group tag. Started with field {startField}; ended with field {endField}");
|
||||
}
|
||||
recursionDepth--;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user