diff --git a/source/val/function.cpp b/source/val/function.cpp index e033bfe80..d9e5b264b 100644 --- a/source/val/function.cpp +++ b/source/val/function.cpp @@ -137,6 +137,7 @@ spv_result_t Function::RegisterLoopMerge(uint32_t merge_id, AddConstruct({ConstructType::kContinue, &continue_target_block}); continue_construct.set_corresponding_constructs({&loop_construct}); loop_construct.set_corresponding_constructs({&continue_construct}); + merge_block_header_[&merge_block] = current_block_; return SPV_SUCCESS; } @@ -146,6 +147,7 @@ spv_result_t Function::RegisterSelectionMerge(uint32_t merge_id) { BasicBlock& merge_block = blocks_.at(merge_id); current_block_->set_type(kBlockTypeHeader); merge_block.set_type(kBlockTypeMerge); + merge_block_header_[&merge_block] = current_block_; AddConstruct({ConstructType::kSelection, current_block(), &merge_block}); @@ -371,4 +373,54 @@ Construct& Function::FindConstructForEntryBlock(const BasicBlock* entry_block) { return *construct_ptr; } +int Function::GetBlockDepth(BasicBlock* bb) { + // Guard against nullptr. + if (!bb) { + return 0; + } + // Only calculate the depth if it's not already calculated. + // This function uses memoization to avoid duplicate CFG depth calculations. + if (block_depth_.find(bb) != block_depth_.end()) { + return block_depth_[bb]; + } + + BasicBlock* bb_dom = bb->immediate_dominator(); + if (!bb_dom || bb == bb_dom) { + // This block has no dominator, so it's at depth 0. + block_depth_[bb] = 0; + } else if (bb->is_type(kBlockTypeMerge)) { + // If this is a merge block, its depth is equal to the block before + // branching. + BasicBlock* header = merge_block_header_[bb]; + assert(header); + block_depth_[bb] = GetBlockDepth(header); + } else if (bb->is_type(kBlockTypeContinue)) { + // The depth of the continue block entry point is 1 + loop header depth. + Construct* continue_construct = entry_block_to_construct_[bb]; + assert(continue_construct); + // Continue construct has only 1 corresponding construct (loop header). + Construct* loop_construct = + continue_construct->corresponding_constructs()[0]; + assert(loop_construct); + BasicBlock* loop_header = loop_construct->entry_block(); + // The continue target may be the loop itself (while 1). + // In such cases, the depth of the continue block is: 1 + depth of the + // loop's dominator block. + if (loop_header == bb) { + block_depth_[bb] = 1 + GetBlockDepth(bb_dom); + } else { + block_depth_[bb] = 1 + GetBlockDepth(loop_header); + } + } else if (bb_dom->is_type(kBlockTypeHeader) || + bb_dom->is_type(kBlockTypeLoop)) { + // The dominator of the given block is a header block. So, the nesting + // depth of this block is: 1 + nesting depth of the header. + block_depth_[bb] = 1 + GetBlockDepth(bb_dom); + } else { + block_depth_[bb] = GetBlockDepth(bb_dom); + } + + return block_depth_[bb]; +} + } /// namespace libspirv diff --git a/source/val/function.h b/source/val/function.h index bd993c791..991b94fc2 100644 --- a/source/val/function.h +++ b/source/val/function.h @@ -174,6 +174,12 @@ class Function { /// Returns the block predecessors function for the augmented CFG. GetBlocksFunction AugmentedCFGPredecessorsFunction() const; + /// Returns the control flow nesting depth of the given basic block. + /// This function only works when you have structured control flow. + /// This function should only be called after the control flow constructs have + /// been identified and dominators have been computed. + int GetBlockDepth(BasicBlock* bb); + /// Prints a GraphViz digraph of the CFG of the current funciton void PrintDotGraph() const; @@ -278,6 +284,12 @@ class Function { /// Maps a construct's entry block to the construct. std::unordered_map entry_block_to_construct_; + + /// This map provides the header block for a given merge block. + std::unordered_map merge_block_header_; + + /// Stores the control flow nesting depth of a given basic block + std::unordered_map block_depth_; }; } /// namespace libspirv diff --git a/source/validate_cfg.cpp b/source/validate_cfg.cpp index ab866e3fd..e90857740 100644 --- a/source/validate_cfg.cpp +++ b/source/validate_cfg.cpp @@ -435,10 +435,10 @@ spv_result_t PerformCfgChecks(ValidationState_t& _) { } UpdateContinueConstructExitBlocks(function, back_edges); - // Check if the order of blocks in the binary appear before the blocks they - // dominate auto& blocks = function.ordered_blocks(); - if (blocks.empty() == false) { + if (!blocks.empty()) { + // Check if the order of blocks in the binary appear before the blocks + // they dominate for (auto block = begin(blocks) + 1; block != end(blocks); ++block) { if (auto idom = (*block)->immediate_dominator()) { if (idom != function.pseudo_entry_block() && @@ -450,6 +450,18 @@ spv_result_t PerformCfgChecks(ValidationState_t& _) { } } } + // If we have structed control flow, check that no block has a control + // flow nesting depth larger than the limit. + if (_.HasCapability(SpvCapabilityShader)) { + const int control_flow_nesting_depth_limit = 1023; + for (auto block = begin(blocks); block != end(blocks); ++block) { + if (function.GetBlockDepth(*block) > + control_flow_nesting_depth_limit) { + return _.diag(SPV_ERROR_INVALID_CFG) + << "Maximum Control Flow nesting depth exceeded."; + } + } + } } /// Structured control flow checks are only required for shader capabilities diff --git a/test/val/val_limits_test.cpp b/test/val/val_limits_test.cpp index f6f294542..c402e34af 100644 --- a/test/val/val_limits_test.cpp +++ b/test/val/val_limits_test.cpp @@ -22,9 +22,10 @@ #include "unit_spirv.h" #include "val_fixtures.h" +namespace { + using ::testing::HasSubstr; using ::testing::MatchesRegex; - using std::string; using ValidateLimits = spvtest::ValidateBase; @@ -34,7 +35,7 @@ string header = R"( OpMemoryModel Logical GLSL450 )"; -TEST_F(ValidateLimits, idLargerThanBoundBad) { +TEST_F(ValidateLimits, IdLargerThanBoundBad) { string str = header + R"( ; %i32 has ID 1 %i32 = OpTypeInt 32 1 @@ -52,7 +53,7 @@ TEST_F(ValidateLimits, idLargerThanBoundBad) { HasSubstr("Result '64' must be less than the ID bound '3'.")); } -TEST_F(ValidateLimits, idEqualToBoundBad) { +TEST_F(ValidateLimits, IdEqualToBoundBad) { string str = header + R"( ; %i32 has ID 1 %i32 = OpTypeInt 32 1 @@ -76,7 +77,7 @@ TEST_F(ValidateLimits, idEqualToBoundBad) { HasSubstr("Result '64' must be less than the ID bound '64'.")); } -TEST_F(ValidateLimits, structNumMembersGood) { +TEST_F(ValidateLimits, StructNumMembersGood) { std::ostringstream spirv; spirv << header << R"( %1 = OpTypeInt 32 0 @@ -88,7 +89,7 @@ TEST_F(ValidateLimits, structNumMembersGood) { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } -TEST_F(ValidateLimits, structNumMembersExceededBad) { +TEST_F(ValidateLimits, StructNumMembersExceededBad) { std::ostringstream spirv; spirv << header << R"( %1 = OpTypeInt 32 0 @@ -104,7 +105,7 @@ TEST_F(ValidateLimits, structNumMembersExceededBad) { } // Valid: Switch statement has 16,383 branches. -TEST_F(ValidateLimits, switchNumBranchesGood) { +TEST_F(ValidateLimits, SwitchNumBranchesGood) { std::ostringstream spirv; spirv << header << R"( %1 = OpTypeVoid @@ -132,7 +133,7 @@ OpFunctionEnd } // Invalid: Switch statement has 16,384 branches. -TEST_F(ValidateLimits, switchNumBranchesBad) { +TEST_F(ValidateLimits, SwitchNumBranchesBad) { std::ostringstream spirv; spirv << header << R"( %1 = OpTypeVoid @@ -320,3 +321,97 @@ TEST_F(ValidateLimits, StructNestingDepthBad) { HasSubstr( "Structure Nesting Depth may not be larger than 255. Found 256.")); } + +// clang-format off +// Generates an SPIRV program with the given control flow nesting depth +void GenerateSpirvProgramWithCfgNestingDepth(std::string& str, int depth) { + std::ostringstream spirv; + spirv << header << R"( + %void = OpTypeVoid + %3 = OpTypeFunction %void + %bool = OpTypeBool + %12 = OpConstantTrue %bool + %main = OpFunction %void None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpLoopMerge %8 %9 None + OpBranch %10 + %10 = OpLabel + OpBranchConditional %12 %7 %8 + %7 = OpLabel + )"; + int first_id = 13; + int last_id = 14; + // We already have 1 level of nesting due to the Loop. + int num_if_conditions = depth-1; + int largest_index = first_id + 2*num_if_conditions - 2; + for (int i = first_id; i <= largest_index; i = i + 2) { + spirv << "OpSelectionMerge %" << i+1 << " None" << "\n"; + spirv << "OpBranchConditional %12 " << "%" << i << " %" << i+1 << "\n"; + spirv << "%" << i << " = OpLabel" << "\n"; + } + spirv << "OpBranch %9" << "\n"; + + for (int i = largest_index+1; i > last_id; i = i - 2) { + spirv << "%" << i << " = OpLabel" << "\n"; + spirv << "OpBranch %" << i-2 << "\n"; + } + spirv << "%" << last_id << " = OpLabel" << "\n"; + spirv << "OpBranch %9" << "\n"; + spirv << R"( + %9 = OpLabel + OpBranch %6 + %8 = OpLabel + OpReturn + OpFunctionEnd + )"; + str = spirv.str(); +} +// clang-format on + +// Valid: Control Flow Nesting depth is 1023. +TEST_F(ValidateLimits, ControlFlowDepthGood) { + std::string spirv; + GenerateSpirvProgramWithCfgNestingDepth(spirv, 1023); + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: Control Flow Nesting depth is 1024. (limit is 1023). +TEST_F(ValidateLimits, ControlFlowDepthBad) { + std::string spirv; + GenerateSpirvProgramWithCfgNestingDepth(spirv, 1024); + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Maximum Control Flow nesting depth exceeded.")); +} + +// Valid. The purpose here is to test the CFG depth calculation code when a loop +// continue target is the loop iteself. It also exercises the case where a loop +// is unreachable. +TEST_F(ValidateLimits, ControlFlowNoEntryToLoopGood) { + string str = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpName %entry "entry" + OpName %loop "loop" + OpName %exit "exit" +%voidt = OpTypeVoid +%funct = OpTypeFunction %voidt +%main = OpFunction %voidt None %funct +%entry = OpLabel + OpBranch %exit +%loop = OpLabel + OpLoopMerge %loop %loop None + OpBranch %loop +%exit = OpLabel + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +} // anonymous namespace