mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-27 05:40:06 +00:00
Adding validation code for OpSwitch limits
The number of (literal, label) pairs passed to OpSwitch may not exceed 16,383. Added code to validate this and added unit tests for it. Also fixed a typo in another validor error message.
This commit is contained in:
parent
bef80716d7
commit
3c8bc80e3a
@ -198,7 +198,7 @@ void printDominatorList(const BasicBlock& b) {
|
||||
spv_result_t FirstBlockAssert(ValidationState_t& _, uint32_t target) {
|
||||
if (_.current_function().IsFirstBlock(target)) {
|
||||
return _.diag(SPV_ERROR_INVALID_CFG)
|
||||
<< "First block " << _.getIdName(target) << " of funciton "
|
||||
<< "First block " << _.getIdName(target) << " of function "
|
||||
<< _.getIdName(_.current_function().id()) << " is targeted by block "
|
||||
<< _.getIdName(_.current_function().current_block()->id());
|
||||
}
|
||||
|
@ -156,6 +156,26 @@ spv_result_t LimitCheckStruct(ValidationState_t& _,
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
// Checks that the number of (literal, label) pairs in OpSwitch is within the
|
||||
// limit.
|
||||
spv_result_t LimitCheckSwitch(ValidationState_t& _,
|
||||
const spv_parsed_instruction_t* inst) {
|
||||
if (SpvOpSwitch == inst->opcode) {
|
||||
// The instruction syntax is as follows:
|
||||
// OpSwitch <selector ID> <Default ID> literal label literal label ...
|
||||
// literal,label pairs come after the first 2 operands.
|
||||
// It is guaranteed at this point that num_operands is an even numner.
|
||||
unsigned int num_pairs = (inst->num_operands - 2) / 2;
|
||||
const unsigned int num_pairs_limit = 16383;
|
||||
if (num_pairs > num_pairs_limit) {
|
||||
return _.diag(SPV_ERROR_INVALID_BINARY)
|
||||
<< "Number of (literal, label) pairs in OpSwitch (" << num_pairs
|
||||
<< ") exceeds the limit (" << num_pairs_limit << ").";
|
||||
}
|
||||
}
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t InstructionPass(ValidationState_t& _,
|
||||
const spv_parsed_instruction_t* inst) {
|
||||
const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
|
||||
@ -198,6 +218,7 @@ spv_result_t InstructionPass(ValidationState_t& _,
|
||||
if (auto error = CapCheck(_, inst)) return error;
|
||||
if (auto error = LimitCheckIdBound(_, inst)) return error;
|
||||
if (auto error = LimitCheckStruct(_, inst)) return error;
|
||||
if (auto error = LimitCheckSwitch(_, inst)) return error;
|
||||
|
||||
// All instruction checks have passed.
|
||||
return SPV_SUCCESS;
|
||||
|
@ -25,11 +25,11 @@
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
#include "source/diagnostic.h"
|
||||
#include "source/validate.h"
|
||||
#include "test_fixture.h"
|
||||
#include "unit_spirv.h"
|
||||
#include "val_fixtures.h"
|
||||
#include "source/diagnostic.h"
|
||||
#include "source/validate.h"
|
||||
|
||||
using std::array;
|
||||
using std::make_pair;
|
||||
@ -80,12 +80,12 @@ class Block {
|
||||
|
||||
/// Sets the instructions which will appear in the body of the block
|
||||
Block& SetBody(std::string body) {
|
||||
body_ = body;
|
||||
body_ = body;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Block& AppendBody(std::string body) {
|
||||
body_ += body;
|
||||
body_ += body;
|
||||
return *this;
|
||||
}
|
||||
|
||||
@ -465,7 +465,7 @@ TEST_P(ValidateCFG, BranchTargetFirstBlockBad) {
|
||||
CompileSuccessfully(str);
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(),
|
||||
MatchesRegex("First block .\\[entry\\] of funciton .\\[Main\\] "
|
||||
MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] "
|
||||
"is targeted by block .\\[bad\\]"));
|
||||
}
|
||||
|
||||
@ -489,7 +489,7 @@ TEST_P(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) {
|
||||
CompileSuccessfully(str);
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(),
|
||||
MatchesRegex("First block .\\[entry\\] of funciton .\\[Main\\] "
|
||||
MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] "
|
||||
"is targeted by block .\\[bad\\]"));
|
||||
}
|
||||
|
||||
@ -516,7 +516,7 @@ TEST_P(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) {
|
||||
CompileSuccessfully(str);
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(),
|
||||
MatchesRegex("First block .\\[entry\\] of funciton .\\[Main\\] "
|
||||
MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] "
|
||||
"is targeted by block .\\[bad\\]"));
|
||||
}
|
||||
|
||||
@ -550,7 +550,7 @@ TEST_P(ValidateCFG, SwitchTargetFirstBlockBad) {
|
||||
CompileSuccessfully(str);
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(),
|
||||
MatchesRegex("First block .\\[entry\\] of funciton .\\[Main\\] "
|
||||
MatchesRegex("First block .\\[entry\\] of function .\\[Main\\] "
|
||||
"is targeted by block .\\[bad\\]"));
|
||||
}
|
||||
|
||||
@ -1019,7 +1019,8 @@ TEST_P(ValidateCFG, BranchOutOfConstructToMergeBad) {
|
||||
EXPECT_THAT(getDiagnosticString(),
|
||||
MatchesRegex("The continue construct with the continue target "
|
||||
".\\[loop\\] is not post dominated by the "
|
||||
"back-edge block .\\[cont\\]")) << str;
|
||||
"back-edge block .\\[cont\\]"))
|
||||
<< str;
|
||||
} else {
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
@ -1254,7 +1255,7 @@ TEST_P(ValidateCFG, SingleLatchBlockMultipleBranchesToLoopHeader) {
|
||||
|
||||
str += entry >> loop;
|
||||
str += loop >> vector<Block>({latch, merge});
|
||||
str += latch >> vector<Block>({loop, loop}); // This is the key
|
||||
str += latch >> vector<Block>({loop, loop}); // This is the key
|
||||
str += merge;
|
||||
str += "OpFunctionEnd";
|
||||
|
||||
|
@ -103,3 +103,62 @@ TEST_F(ValidateLimits, structNumMembersExceededBad) {
|
||||
"the limit (16383)."));
|
||||
}
|
||||
|
||||
// Valid: Switch statement has 16,383 branches.
|
||||
TEST_F(ValidateLimits, switchNumBranchesGood) {
|
||||
std::ostringstream spirv;
|
||||
spirv << header << R"(
|
||||
%1 = OpTypeVoid
|
||||
%2 = OpTypeFunction %1
|
||||
%3 = OpTypeInt 32 0
|
||||
%4 = OpConstant %3 1234
|
||||
%5 = OpFunction %1 None %2
|
||||
%7 = OpLabel
|
||||
%8 = OpIAdd %3 %4 %4
|
||||
%9 = OpSwitch %4 %10)";
|
||||
|
||||
// Now add the (literal, label) pairs
|
||||
for (int i = 0; i < 16383; ++i) {
|
||||
spirv << " 1 %10";
|
||||
}
|
||||
|
||||
spirv << R"(
|
||||
%10 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
CompileSuccessfully(spirv.str());
|
||||
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
|
||||
}
|
||||
|
||||
// Invalid: Switch statement has 16,384 branches.
|
||||
TEST_F(ValidateLimits, switchNumBranchesBad) {
|
||||
std::ostringstream spirv;
|
||||
spirv << header << R"(
|
||||
%1 = OpTypeVoid
|
||||
%2 = OpTypeFunction %1
|
||||
%3 = OpTypeInt 32 0
|
||||
%4 = OpConstant %3 1234
|
||||
%5 = OpFunction %1 None %2
|
||||
%7 = OpLabel
|
||||
%8 = OpIAdd %3 %4 %4
|
||||
%9 = OpSwitch %4 %10)";
|
||||
|
||||
// Now add the (literal, label) pairs
|
||||
for (int i = 0; i < 16384; ++i) {
|
||||
spirv << " 1 %10";
|
||||
}
|
||||
|
||||
spirv << R"(
|
||||
%10 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
CompileSuccessfully(spirv.str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions());
|
||||
EXPECT_THAT(getDiagnosticString(),
|
||||
HasSubstr("Number of (literal, label) pairs in OpSwitch (16384) "
|
||||
"exceeds the limit (16383)."));
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user