mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-26 21:30:07 +00:00
Implement WebGPU specific CFG validation (#2386)
In WebGPU all blocks are required to be reachable, unless they are one of two specific degenerate cases for merge-block or continue-target. This PR adds in checking for these conditions. Fixes #2068
This commit is contained in:
parent
a2ef7be242
commit
b12e7338ee
@ -86,6 +86,12 @@ spv_result_t Function::RegisterLoopMerge(uint32_t merge_id,
|
|||||||
continue_construct.set_corresponding_constructs({&loop_construct});
|
continue_construct.set_corresponding_constructs({&loop_construct});
|
||||||
loop_construct.set_corresponding_constructs({&continue_construct});
|
loop_construct.set_corresponding_constructs({&continue_construct});
|
||||||
merge_block_header_[&merge_block] = current_block_;
|
merge_block_header_[&merge_block] = current_block_;
|
||||||
|
if (continue_target_headers_.find(&continue_target_block) ==
|
||||||
|
continue_target_headers_.end()) {
|
||||||
|
continue_target_headers_[&continue_target_block] = {current_block_};
|
||||||
|
} else {
|
||||||
|
continue_target_headers_[&continue_target_block].push_back(current_block_);
|
||||||
|
}
|
||||||
|
|
||||||
return SPV_SUCCESS;
|
return SPV_SUCCESS;
|
||||||
}
|
}
|
||||||
|
@ -232,6 +232,28 @@ class Function {
|
|||||||
return function_call_targets_;
|
return function_call_targets_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the block containing the OpSelectionMerge or OpLoopMerge that
|
||||||
|
// references |merge_block|.
|
||||||
|
// Values of |merge_block_header_| inserted by CFGPass, so do not call before
|
||||||
|
// the first iteration of ordered instructions in
|
||||||
|
// ValidateBinaryUsingContextAndValidationState has completed.
|
||||||
|
BasicBlock* GetMergeHeader(BasicBlock* merge_block) {
|
||||||
|
return merge_block_header_[merge_block];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns vector of the blocks containing a OpLoopMerge that references
|
||||||
|
// |continue_target|.
|
||||||
|
// Values of |continue_target_headers_| inserted by CFGPass, so do not call
|
||||||
|
// before the first iteration of ordered instructions in
|
||||||
|
// ValidateBinaryUsingContextAndValidationState has completed.
|
||||||
|
std::vector<BasicBlock*> GetContinueHeaders(BasicBlock* continue_target) {
|
||||||
|
if (continue_target_headers_.find(continue_target) ==
|
||||||
|
continue_target_headers_.end()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
return continue_target_headers_[continue_target];
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Computes the representation of the augmented CFG.
|
// Computes the representation of the augmented CFG.
|
||||||
// Populates augmented_successors_map_ and augmented_predecessors_map_.
|
// Populates augmented_successors_map_ and augmented_predecessors_map_.
|
||||||
@ -340,6 +362,10 @@ class Function {
|
|||||||
/// This map provides the header block for a given merge block.
|
/// This map provides the header block for a given merge block.
|
||||||
std::unordered_map<BasicBlock*, BasicBlock*> merge_block_header_;
|
std::unordered_map<BasicBlock*, BasicBlock*> merge_block_header_;
|
||||||
|
|
||||||
|
/// This map provides the header blocks for a given continue target.
|
||||||
|
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>>
|
||||||
|
continue_target_headers_;
|
||||||
|
|
||||||
/// Stores the control flow nesting depth of a given basic block
|
/// Stores the control flow nesting depth of a given basic block
|
||||||
std::unordered_map<BasicBlock*, int> block_depth_;
|
std::unordered_map<BasicBlock*, int> block_depth_;
|
||||||
|
|
||||||
|
@ -29,6 +29,7 @@
|
|||||||
|
|
||||||
#include "source/cfa.h"
|
#include "source/cfa.h"
|
||||||
#include "source/opcode.h"
|
#include "source/opcode.h"
|
||||||
|
#include "source/spirv_target_env.h"
|
||||||
#include "source/spirv_validator_options.h"
|
#include "source/spirv_validator_options.h"
|
||||||
#include "source/val/basic_block.h"
|
#include "source/val/basic_block.h"
|
||||||
#include "source/val/construct.h"
|
#include "source/val/construct.h"
|
||||||
@ -610,6 +611,120 @@ spv_result_t StructuredControlFlowChecks(
|
|||||||
return SPV_SUCCESS;
|
return SPV_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
spv_result_t PerformWebGPUCfgChecks(ValidationState_t& _, Function* function) {
|
||||||
|
for (auto& block : function->ordered_blocks()) {
|
||||||
|
if (block->reachable()) continue;
|
||||||
|
if (block->is_type(kBlockTypeMerge)) {
|
||||||
|
// 1. Find the referencing merge and confirm that it is reachable.
|
||||||
|
BasicBlock* merge_header = function->GetMergeHeader(block);
|
||||||
|
assert(merge_header != nullptr);
|
||||||
|
if (!merge_header->reachable()) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
|
||||||
|
<< "For WebGPU, unreachable merge-blocks must be referenced by "
|
||||||
|
"a reachable merge instruction.";
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Check that the only instructions are OpLabel and OpUnreachable.
|
||||||
|
auto* label_inst = block->label();
|
||||||
|
auto* terminator_inst = block->terminator();
|
||||||
|
assert(label_inst != nullptr);
|
||||||
|
assert(terminator_inst != nullptr);
|
||||||
|
|
||||||
|
if (terminator_inst->opcode() != SpvOpUnreachable) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
|
||||||
|
<< "For WebGPU, unreachable merge-blocks must terminate with "
|
||||||
|
"OpUnreachable.";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto label_idx = label_inst - &_.ordered_instructions()[0];
|
||||||
|
auto terminator_idx = terminator_inst - &_.ordered_instructions()[0];
|
||||||
|
if (label_idx + 1 != terminator_idx) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
|
||||||
|
<< "For WebGPU, unreachable merge-blocks must only contain an "
|
||||||
|
"OpLabel and OpUnreachable instruction.";
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Use label instruction to confirm there is no uses by branches.
|
||||||
|
for (auto use : label_inst->uses()) {
|
||||||
|
const auto* use_inst = use.first;
|
||||||
|
if (spvOpcodeIsBranch(use_inst->opcode())) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
|
||||||
|
<< "For WebGPU, unreachable merge-blocks cannot be the target "
|
||||||
|
"of a branch.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (block->is_type(kBlockTypeContinue)) {
|
||||||
|
// 1. Find referencing loop and confirm that it is reachable.
|
||||||
|
std::vector<BasicBlock*> continue_headers =
|
||||||
|
function->GetContinueHeaders(block);
|
||||||
|
if (continue_headers.empty()) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
|
||||||
|
<< "For WebGPU, unreachable continue-target must be referenced "
|
||||||
|
"by a loop instruction.";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<BasicBlock*> reachable_headers(continue_headers.size());
|
||||||
|
auto iter =
|
||||||
|
std::copy_if(continue_headers.begin(), continue_headers.end(),
|
||||||
|
reachable_headers.begin(),
|
||||||
|
[](BasicBlock* header) { return header->reachable(); });
|
||||||
|
reachable_headers.resize(std::distance(reachable_headers.begin(), iter));
|
||||||
|
|
||||||
|
if (reachable_headers.empty()) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
|
||||||
|
<< "For WebGPU, unreachable continue-target must be referenced "
|
||||||
|
"by a reachable loop instruction.";
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Check that the only instructions are OpLabel and OpBranch.
|
||||||
|
auto* label_inst = block->label();
|
||||||
|
auto* terminator_inst = block->terminator();
|
||||||
|
assert(label_inst != nullptr);
|
||||||
|
assert(terminator_inst != nullptr);
|
||||||
|
|
||||||
|
if (terminator_inst->opcode() != SpvOpBranch) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
|
||||||
|
<< "For WebGPU, unreachable continue-target must terminate with "
|
||||||
|
"OpBranch.";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto label_idx = label_inst - &_.ordered_instructions()[0];
|
||||||
|
auto terminator_idx = terminator_inst - &_.ordered_instructions()[0];
|
||||||
|
if (label_idx + 1 != terminator_idx) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
|
||||||
|
<< "For WebGPU, unreachable continue-target must only contain "
|
||||||
|
"an OpLabel and an OpBranch instruction.";
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Use label instruction to confirm there is no uses by branches.
|
||||||
|
for (auto use : label_inst->uses()) {
|
||||||
|
const auto* use_inst = use.first;
|
||||||
|
if (spvOpcodeIsBranch(use_inst->opcode())) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
|
||||||
|
<< "For WebGPU, unreachable continue-target cannot be the "
|
||||||
|
"target of a branch.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Confirm that continue-target has a back edge to a reachable loop
|
||||||
|
// header block.
|
||||||
|
auto branch_target = terminator_inst->GetOperandAs<uint32_t>(0);
|
||||||
|
for (auto* continue_header : reachable_headers) {
|
||||||
|
if (branch_target != continue_header->id()) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
|
||||||
|
<< "For WebGPU, unreachable continue-target must only have a "
|
||||||
|
"back edge to a single reachable loop instruction.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
|
||||||
|
<< "For WebGPU, all blocks must be reachable, unless they are "
|
||||||
|
<< "degenerate cases of merge-block or continue-target.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return SPV_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
spv_result_t PerformCfgChecks(ValidationState_t& _) {
|
spv_result_t PerformCfgChecks(ValidationState_t& _) {
|
||||||
for (auto& function : _.functions()) {
|
for (auto& function : _.functions()) {
|
||||||
// Check all referenced blocks are defined within a function
|
// Check all referenced blocks are defined within a function
|
||||||
@ -689,6 +804,13 @@ spv_result_t PerformCfgChecks(ValidationState_t& _) {
|
|||||||
<< _.getIdName(idom->id());
|
<< _.getIdName(idom->id());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For WebGPU check that all unreachable blocks are degenerate cases for
|
||||||
|
// merge-block or continue-target.
|
||||||
|
if (spvIsWebGPUEnv(_.context()->target_env)) {
|
||||||
|
spv_result_t result = PerformWebGPUCfgChecks(_, &function);
|
||||||
|
if (result != SPV_SUCCESS) return result;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// If we have structed control flow, check that no block has a control
|
// If we have structed control flow, check that no block has a control
|
||||||
// flow nesting depth larger than the limit.
|
// flow nesting depth larger than the limit.
|
||||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user