enforce recursion depth checking for unknown fields (#7210)

This commit is contained in:
Jan Tattermusch 2020-02-14 19:17:06 +01:00 committed by GitHub
parent d314101531
commit 0e8f69e532
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 108 additions and 8 deletions

View File

@ -33,6 +33,7 @@
using System;
using System.IO;
using Google.Protobuf.TestProtos;
using Proto2 = Google.Protobuf.TestProtos.Proto2;
using NUnit.Framework;
namespace Google.Protobuf
@ -337,6 +338,66 @@ namespace Google.Protobuf
CodedInputStream input = CodedInputStream.CreateWithLimits(new MemoryStream(atRecursiveLimit.ToByteArray()), 1000000, CodedInputStream.DefaultRecursionLimit - 1);
Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(input));
}
private static byte[] MakeMaliciousRecursionUnknownFieldsPayload(int recursionDepth)
{
// generate recursively nested groups that will be parsed as unknown fields
int unknownFieldNumber = 14; // an unused field number
MemoryStream ms = new MemoryStream();
CodedOutputStream output = new CodedOutputStream(ms);
for (int i = 0; i < recursionDepth; i++)
{
output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.StartGroup));
}
for (int i = 0; i < recursionDepth; i++)
{
output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.EndGroup));
}
output.Flush();
return ms.ToArray();
}
[Test]
public void MaliciousRecursion_UnknownFields()
{
byte[] payloadAtRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit);
byte[] payloadBeyondRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit + 1);
Assert.DoesNotThrow(() => TestRecursiveMessage.Parser.ParseFrom(payloadAtRecursiveLimit));
Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(payloadBeyondRecursiveLimit));
}
[Test]
public void ReadGroup_WrongEndGroupTag()
{
int groupFieldNumber = Proto2.TestAllTypes.OptionalGroupFieldNumber;
// write Proto2.TestAllTypes with "optional_group" set, but use wrong EndGroup closing tag
MemoryStream ms = new MemoryStream();
CodedOutputStream output = new CodedOutputStream(ms);
output.WriteTag(WireFormat.MakeTag(groupFieldNumber, WireFormat.WireType.StartGroup));
output.WriteGroup(new Proto2.TestAllTypes.Types.OptionalGroup { A = 12345 });
// end group with different field number
output.WriteTag(WireFormat.MakeTag(groupFieldNumber + 1, WireFormat.WireType.EndGroup));
output.Flush();
var payload = ms.ToArray();
Assert.Throws<InvalidProtocolBufferException>(() => Proto2.TestAllTypes.Parser.ParseFrom(payload));
}
[Test]
public void ReadGroup_UnknownFields_WrongEndGroupTag()
{
MemoryStream ms = new MemoryStream();
CodedOutputStream output = new CodedOutputStream(ms);
output.WriteTag(WireFormat.MakeTag(14, WireFormat.WireType.StartGroup));
// end group with different field number
output.WriteTag(WireFormat.MakeTag(15, WireFormat.WireType.EndGroup));
output.Flush();
var payload = ms.ToArray();
Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(payload));
}
[Test]
public void SizeLimit()
@ -735,4 +796,4 @@ namespace Google.Protobuf
}
}
}
}
}

View File

@ -307,10 +307,17 @@ namespace Google.Protobuf
throw InvalidProtocolBufferException.MoreDataAvailable();
}
}
internal void CheckLastTagWas(uint expectedTag)
{
if (lastTag != expectedTag) {
throw InvalidProtocolBufferException.InvalidEndTag();
}
}
#endregion
#region Reading of tags etc
/// <summary>
/// Peeks at the next field tag. This is like calling <see cref="ReadTag"/>, but the
/// tag is not consumed. (So a subsequent call to <see cref="ReadTag"/> will return the
@ -636,7 +643,27 @@ namespace Google.Protobuf
throw InvalidProtocolBufferException.RecursionLimitExceeded();
}
++recursionDepth;
uint tag = lastTag;
int fieldNumber = WireFormat.GetTagFieldNumber(tag);
builder.MergeFrom(this);
CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
--recursionDepth;
}
/// <summary>
/// Reads an embedded group unknown field from the stream.
/// </summary>
internal void ReadGroup(int fieldNumber, UnknownFieldSet set)
{
if (recursionDepth >= recursionLimit)
{
throw InvalidProtocolBufferException.RecursionLimitExceeded();
}
++recursionDepth;
set.MergeGroupFrom(this);
CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
--recursionDepth;
}

View File

@ -215,12 +215,8 @@ namespace Google.Protobuf
}
case WireFormat.WireType.StartGroup:
{
uint endTag = WireFormat.MakeTag(number, WireFormat.WireType.EndGroup);
UnknownFieldSet set = new UnknownFieldSet();
while (input.ReadTag() != endTag)
{
set.MergeFieldFrom(input);
}
input.ReadGroup(number, set);
GetOrAddField(number).AddGroup(set);
return true;
}
@ -233,6 +229,22 @@ namespace Google.Protobuf
}
}
internal void MergeGroupFrom(CodedInputStream input)
{
while (true)
{
uint tag = input.ReadTag();
if (tag == 0)
{
break;
}
if (!MergeFieldFrom(input))
{
break;
}
}
}
/// <summary>
/// Create a new UnknownFieldSet if unknownFields is null.
/// Parse a single field from <paramref name="input"/> and merge it