Implement WebGPU specific CFG validation (#2386)

In WebGPU all blocks are required to be reachable, unless they are one of two
specific degenerate cases for merge-block or continue-target. This PR adds in
checking for these conditions.

Fixes #2068
This commit is contained in:
Ryan Harrison 2019-03-08 13:01:09 -05:00 committed by GitHub
parent a2ef7be242
commit b12e7338ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 979 additions and 81 deletions

View File

@ -86,6 +86,12 @@ spv_result_t Function::RegisterLoopMerge(uint32_t merge_id,
continue_construct.set_corresponding_constructs({&loop_construct}); continue_construct.set_corresponding_constructs({&loop_construct});
loop_construct.set_corresponding_constructs({&continue_construct}); loop_construct.set_corresponding_constructs({&continue_construct});
merge_block_header_[&merge_block] = current_block_; merge_block_header_[&merge_block] = current_block_;
if (continue_target_headers_.find(&continue_target_block) ==
continue_target_headers_.end()) {
continue_target_headers_[&continue_target_block] = {current_block_};
} else {
continue_target_headers_[&continue_target_block].push_back(current_block_);
}
return SPV_SUCCESS; return SPV_SUCCESS;
} }

View File

@ -232,6 +232,28 @@ class Function {
return function_call_targets_; return function_call_targets_;
} }
// Returns the block containing the OpSelectionMerge or OpLoopMerge that
// references |merge_block|.
// Values of |merge_block_header_| inserted by CFGPass, so do not call before
// the first iteration of ordered instructions in
// ValidateBinaryUsingContextAndValidationState has completed.
BasicBlock* GetMergeHeader(BasicBlock* merge_block) {
return merge_block_header_[merge_block];
}
// Returns vector of the blocks containing a OpLoopMerge that references
// |continue_target|.
// Values of |continue_target_headers_| inserted by CFGPass, so do not call
// before the first iteration of ordered instructions in
// ValidateBinaryUsingContextAndValidationState has completed.
std::vector<BasicBlock*> GetContinueHeaders(BasicBlock* continue_target) {
if (continue_target_headers_.find(continue_target) ==
continue_target_headers_.end()) {
return {};
}
return continue_target_headers_[continue_target];
}
private: private:
// Computes the representation of the augmented CFG. // Computes the representation of the augmented CFG.
// Populates augmented_successors_map_ and augmented_predecessors_map_. // Populates augmented_successors_map_ and augmented_predecessors_map_.
@ -340,6 +362,10 @@ class Function {
/// This map provides the header block for a given merge block. /// This map provides the header block for a given merge block.
std::unordered_map<BasicBlock*, BasicBlock*> merge_block_header_; std::unordered_map<BasicBlock*, BasicBlock*> merge_block_header_;
/// This map provides the header blocks for a given continue target.
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>>
continue_target_headers_;
/// Stores the control flow nesting depth of a given basic block /// Stores the control flow nesting depth of a given basic block
std::unordered_map<BasicBlock*, int> block_depth_; std::unordered_map<BasicBlock*, int> block_depth_;

View File

@ -29,6 +29,7 @@
#include "source/cfa.h" #include "source/cfa.h"
#include "source/opcode.h" #include "source/opcode.h"
#include "source/spirv_target_env.h"
#include "source/spirv_validator_options.h" #include "source/spirv_validator_options.h"
#include "source/val/basic_block.h" #include "source/val/basic_block.h"
#include "source/val/construct.h" #include "source/val/construct.h"
@ -610,6 +611,120 @@ spv_result_t StructuredControlFlowChecks(
return SPV_SUCCESS; return SPV_SUCCESS;
} }
spv_result_t PerformWebGPUCfgChecks(ValidationState_t& _, Function* function) {
for (auto& block : function->ordered_blocks()) {
if (block->reachable()) continue;
if (block->is_type(kBlockTypeMerge)) {
// 1. Find the referencing merge and confirm that it is reachable.
BasicBlock* merge_header = function->GetMergeHeader(block);
assert(merge_header != nullptr);
if (!merge_header->reachable()) {
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
<< "For WebGPU, unreachable merge-blocks must be referenced by "
"a reachable merge instruction.";
}
// 2. Check that the only instructions are OpLabel and OpUnreachable.
auto* label_inst = block->label();
auto* terminator_inst = block->terminator();
assert(label_inst != nullptr);
assert(terminator_inst != nullptr);
if (terminator_inst->opcode() != SpvOpUnreachable) {
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
<< "For WebGPU, unreachable merge-blocks must terminate with "
"OpUnreachable.";
}
auto label_idx = label_inst - &_.ordered_instructions()[0];
auto terminator_idx = terminator_inst - &_.ordered_instructions()[0];
if (label_idx + 1 != terminator_idx) {
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
<< "For WebGPU, unreachable merge-blocks must only contain an "
"OpLabel and OpUnreachable instruction.";
}
// 3. Use label instruction to confirm there is no uses by branches.
for (auto use : label_inst->uses()) {
const auto* use_inst = use.first;
if (spvOpcodeIsBranch(use_inst->opcode())) {
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
<< "For WebGPU, unreachable merge-blocks cannot be the target "
"of a branch.";
}
}
} else if (block->is_type(kBlockTypeContinue)) {
// 1. Find referencing loop and confirm that it is reachable.
std::vector<BasicBlock*> continue_headers =
function->GetContinueHeaders(block);
if (continue_headers.empty()) {
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
<< "For WebGPU, unreachable continue-target must be referenced "
"by a loop instruction.";
}
std::vector<BasicBlock*> reachable_headers(continue_headers.size());
auto iter =
std::copy_if(continue_headers.begin(), continue_headers.end(),
reachable_headers.begin(),
[](BasicBlock* header) { return header->reachable(); });
reachable_headers.resize(std::distance(reachable_headers.begin(), iter));
if (reachable_headers.empty()) {
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
<< "For WebGPU, unreachable continue-target must be referenced "
"by a reachable loop instruction.";
}
// 2. Check that the only instructions are OpLabel and OpBranch.
auto* label_inst = block->label();
auto* terminator_inst = block->terminator();
assert(label_inst != nullptr);
assert(terminator_inst != nullptr);
if (terminator_inst->opcode() != SpvOpBranch) {
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
<< "For WebGPU, unreachable continue-target must terminate with "
"OpBranch.";
}
auto label_idx = label_inst - &_.ordered_instructions()[0];
auto terminator_idx = terminator_inst - &_.ordered_instructions()[0];
if (label_idx + 1 != terminator_idx) {
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
<< "For WebGPU, unreachable continue-target must only contain "
"an OpLabel and an OpBranch instruction.";
}
// 3. Use label instruction to confirm there is no uses by branches.
for (auto use : label_inst->uses()) {
const auto* use_inst = use.first;
if (spvOpcodeIsBranch(use_inst->opcode())) {
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
<< "For WebGPU, unreachable continue-target cannot be the "
"target of a branch.";
}
}
// 4. Confirm that continue-target has a back edge to a reachable loop
// header block.
auto branch_target = terminator_inst->GetOperandAs<uint32_t>(0);
for (auto* continue_header : reachable_headers) {
if (branch_target != continue_header->id()) {
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
<< "For WebGPU, unreachable continue-target must only have a "
"back edge to a single reachable loop instruction.";
}
}
} else {
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
<< "For WebGPU, all blocks must be reachable, unless they are "
<< "degenerate cases of merge-block or continue-target.";
}
}
return SPV_SUCCESS;
}
spv_result_t PerformCfgChecks(ValidationState_t& _) { spv_result_t PerformCfgChecks(ValidationState_t& _) {
for (auto& function : _.functions()) { for (auto& function : _.functions()) {
// Check all referenced blocks are defined within a function // Check all referenced blocks are defined within a function
@ -689,6 +804,13 @@ spv_result_t PerformCfgChecks(ValidationState_t& _) {
<< _.getIdName(idom->id()); << _.getIdName(idom->id());
} }
} }
// For WebGPU check that all unreachable blocks are degenerate cases for
// merge-block or continue-target.
if (spvIsWebGPUEnv(_.context()->target_env)) {
spv_result_t result = PerformWebGPUCfgChecks(_, &function);
if (result != SPV_SUCCESS) return result;
}
} }
// If we have structed control flow, check that no block has a control // If we have structed control flow, check that no block has a control
// flow nesting depth larger than the limit. // flow nesting depth larger than the limit.

File diff suppressed because it is too large Load Diff