Adding validation code for OpSwitch limits

The number of (literal, label) pairs passed to OpSwitch may not exceed
16,383. Added code to validate this and added unit tests for it.

Also fixed a typo in another validor error message.
This commit is contained in:
Ehsan Nasiri 2016-11-25 09:26:26 -05:00 committed by David Neto
parent bef80716d7
commit 3c8bc80e3a
4 changed files with 92 additions and 11 deletions

View File

@ -198,7 +198,7 @@ void printDominatorList(const BasicBlock& b) {
spv_result_t FirstBlockAssert(ValidationState_t& _, uint32_t target) {
if (_.current_function().IsFirstBlock(target)) {
return _.diag(SPV_ERROR_INVALID_CFG)
<< "First block " << _.getIdName(target) << " of funciton "
<< "First block " << _.getIdName(target) << " of function "
<< _.getIdName(_.current_function().id()) << " is targeted by block "
<< _.getIdName(_.current_function().current_block()->id());
}

View File

@ -156,6 +156,26 @@ spv_result_t LimitCheckStruct(ValidationState_t& _,
return SPV_SUCCESS;
}
// Checks that the number of (literal, label) pairs in OpSwitch is within the
// limit.
spv_result_t LimitCheckSwitch(ValidationState_t& _,
const spv_parsed_instruction_t* inst) {
if (SpvOpSwitch == inst->opcode) {
// The instruction syntax is as follows:
// OpSwitch <selector ID> <Default ID> literal label literal label ...
// literal,label pairs come after the first 2 operands.
// It is guaranteed at this point that num_operands is an even numner.
unsigned int num_pairs = (inst->num_operands - 2) / 2;
const unsigned int num_pairs_limit = 16383;
if (num_pairs > num_pairs_limit) {
return _.diag(SPV_ERROR_INVALID_BINARY)
<< "Number of (literal, label) pairs in OpSwitch (" << num_pairs
<< ") exceeds the limit (" << num_pairs_limit << ").";
}
}
return SPV_SUCCESS;
}
spv_result_t InstructionPass(ValidationState_t& _,
const spv_parsed_instruction_t* inst) {
const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
@ -198,6 +218,7 @@ spv_result_t InstructionPass(ValidationState_t& _,
if (auto error = CapCheck(_, inst)) return error;
if (auto error = LimitCheckIdBound(_, inst)) return error;
if (auto error = LimitCheckStruct(_, inst)) return error;
if (auto error = LimitCheckSwitch(_, inst)) return error;
// All instruction checks have passed.
return SPV_SUCCESS;

View File

@ -25,11 +25,11 @@
#include "gmock/gmock.h"
#include "source/diagnostic.h"
#include "source/validate.h"
#include "test_fixture.h"
#include "unit_spirv.h"
#include "val_fixtures.h"
#include "source/diagnostic.h"
#include "source/validate.h"
using std::array;
using std::make_pair;
@ -80,12 +80,12 @@ class Block {
/// Sets the instructions which will appear in the body of the block
Block& SetBody(std::string body) {
body_ = body;
body_ = body;
return *this;
}
Block& AppendBody(std::string body) {
body_ += body;
body_ += body;
return *this;
}
@ -465,7 +465,7 @@ TEST_P(ValidateCFG, BranchTargetFirstBlockBad) {
CompileSuccessfully(str);
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
MatchesRegex("First block .\\[entry\\] of funciton .\\[Main\\] "
MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] "
"is targeted by block .\\[bad\\]"));
}
@ -489,7 +489,7 @@ TEST_P(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) {
CompileSuccessfully(str);
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
MatchesRegex("First block .\\[entry\\] of funciton .\\[Main\\] "
MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] "
"is targeted by block .\\[bad\\]"));
}
@ -516,7 +516,7 @@ TEST_P(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) {
CompileSuccessfully(str);
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
MatchesRegex("First block .\\[entry\\] of funciton .\\[Main\\] "
MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] "
"is targeted by block .\\[bad\\]"));
}
@ -550,7 +550,7 @@ TEST_P(ValidateCFG, SwitchTargetFirstBlockBad) {
CompileSuccessfully(str);
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
MatchesRegex("First block .\\[entry\\] of funciton .\\[Main\\] "
MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] "
"is targeted by block .\\[bad\\]"));
}
@ -1019,7 +1019,8 @@ TEST_P(ValidateCFG, BranchOutOfConstructToMergeBad) {
EXPECT_THAT(getDiagnosticString(),
MatchesRegex("The continue construct with the continue target "
".\\[loop\\] is not post dominated by the "
"back-edge block .\\[cont\\]")) << str;
"back-edge block .\\[cont\\]"))
<< str;
} else {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
@ -1254,7 +1255,7 @@ TEST_P(ValidateCFG, SingleLatchBlockMultipleBranchesToLoopHeader) {
str += entry >> loop;
str += loop >> vector<Block>({latch, merge});
str += latch >> vector<Block>({loop, loop}); // This is the key
str += latch >> vector<Block>({loop, loop}); // This is the key
str += merge;
str += "OpFunctionEnd";

View File

@ -103,3 +103,62 @@ TEST_F(ValidateLimits, structNumMembersExceededBad) {
"the limit (16383)."));
}
// Valid: Switch statement has 16,383 branches.
TEST_F(ValidateLimits, switchNumBranchesGood) {
std::ostringstream spirv;
spirv << header << R"(
%1 = OpTypeVoid
%2 = OpTypeFunction %1
%3 = OpTypeInt 32 0
%4 = OpConstant %3 1234
%5 = OpFunction %1 None %2
%7 = OpLabel
%8 = OpIAdd %3 %4 %4
%9 = OpSwitch %4 %10)";
// Now add the (literal, label) pairs
for (int i = 0; i < 16383; ++i) {
spirv << " 1 %10";
}
spirv << R"(
%10 = OpLabel
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(spirv.str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
// Invalid: Switch statement has 16,384 branches.
TEST_F(ValidateLimits, switchNumBranchesBad) {
std::ostringstream spirv;
spirv << header << R"(
%1 = OpTypeVoid
%2 = OpTypeFunction %1
%3 = OpTypeInt 32 0
%4 = OpConstant %3 1234
%5 = OpFunction %1 None %2
%7 = OpLabel
%8 = OpIAdd %3 %4 %4
%9 = OpSwitch %4 %10)";
// Now add the (literal, label) pairs
for (int i = 0; i < 16384; ++i) {
spirv << " 1 %10";
}
spirv << R"(
%10 = OpLabel
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(spirv.str());
ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Number of (literal, label) pairs in OpSwitch (16384) "
"exceeds the limit (16383)."));
}