Check for invalid branches into construct body.

Fixes #1281

* New structured cfg check: all non-construct header blocks'
predecessors must come from within the construct
* New function to calculate blocks in a construct

* Fixed a bug in BasicBlock type bitset

Relaxing check to not consider unreachable predecessors

* Fixing broken common uniform elim test
This commit is contained in:
Alan Baker 2018-04-25 09:26:41 -04:00 committed by David Neto
parent 035afb899c
commit 06de86863b
6 changed files with 174 additions and 38 deletions

View File

@ -205,7 +205,7 @@ class BasicBlock {
std::vector<BasicBlock*> successors_; std::vector<BasicBlock*> successors_;
/// The type of the block /// The type of the block
std::bitset<kBlockTypeCOUNT - 1> type_; std::bitset<kBlockTypeCOUNT> type_;
/// True if the block is reachable in the CFG /// True if the block is reachable in the CFG
bool reachable_; bool reachable_;

View File

@ -13,9 +13,11 @@
// limitations under the License. // limitations under the License.
#include "val/construct.h" #include "val/construct.h"
#include "val/function.h"
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <unordered_set>
namespace libspirv { namespace libspirv {
@ -64,4 +66,59 @@ const BasicBlock* Construct::exit_block() const { return exit_block_; }
BasicBlock* Construct::exit_block() { return exit_block_; } BasicBlock* Construct::exit_block() { return exit_block_; }
void Construct::set_exit(BasicBlock* block) { exit_block_ = block; } void Construct::set_exit(BasicBlock* block) { exit_block_ = block; }
Construct::ConstructBlockSet Construct::blocks(Function* function) const {
auto header = entry_block();
auto merge = exit_block();
assert(header);
assert(merge);
int header_depth = function->GetBlockDepth(const_cast<BasicBlock*>(header));
ConstructBlockSet construct_blocks;
std::unordered_set<BasicBlock*> corresponding_headers;
for (auto& other : corresponding_constructs()) {
corresponding_headers.insert(other->entry_block());
}
std::vector<BasicBlock*> stack;
stack.push_back(const_cast<BasicBlock*>(header));
while (!stack.empty()) {
BasicBlock* block = stack.back();
stack.pop_back();
if (merge == block && ExitBlockIsMergeBlock()) {
// Merge block is not part of the construct.
continue;
}
if (corresponding_headers.count(block)) {
// Entered a corresponding construct.
continue;
}
int block_depth = function->GetBlockDepth(block);
if (block_depth < header_depth) {
// Broke to outer construct.
continue;
}
// In a loop, the continue target is at a depth of the loop construct + 1.
// A selection construct nested directly within the loop construct is also
// at the same depth. It is valid, however, to branch directly to the
// continue target from within the selection construct.
if (block_depth == header_depth && type() == ConstructType::kSelection &&
block->is_type(kBlockTypeContinue)) {
// Continued to outer construct.
continue;
}
if (!construct_blocks.insert(block).second) continue;
if (merge != block) {
for (auto succ : *block->successors()) {
stack.push_back(succ);
}
}
}
return construct_blocks;
}
} // namespace libspirv } // namespace libspirv

View File

@ -15,11 +15,21 @@
#ifndef LIBSPIRV_VAL_CONSTRUCT_H_ #ifndef LIBSPIRV_VAL_CONSTRUCT_H_
#define LIBSPIRV_VAL_CONSTRUCT_H_ #define LIBSPIRV_VAL_CONSTRUCT_H_
#include "val/basic_block.h"
#include <cstdint> #include <cstdint>
#include <set>
#include <vector> #include <vector>
namespace libspirv { namespace libspirv {
/// Functor for ordering BasicBlocks. BasicBlock pointers must not be null.
struct less_than_id {
bool operator()(const BasicBlock* lhs, const BasicBlock* rhs) const {
return lhs->id() < rhs->id();
}
};
enum class ConstructType : int { enum class ConstructType : int {
kNone = 0, kNone = 0,
/// The set of blocks dominated by a selection header, minus the set of blocks /// The set of blocks dominated by a selection header, minus the set of blocks
@ -39,7 +49,7 @@ enum class ConstructType : int {
kCase kCase
}; };
class BasicBlock; class Function;
/// @brief This class tracks the CFG constructs as defined in the SPIR-V spec /// @brief This class tracks the CFG constructs as defined in the SPIR-V spec
class Construct { class Construct {
@ -91,6 +101,13 @@ class Construct {
return type_ == ConstructType::kLoop || type_ == ConstructType::kSelection; return type_ == ConstructType::kLoop || type_ == ConstructType::kSelection;
} }
using ConstructBlockSet = std::set<BasicBlock*, less_than_id>;
// Returns the basic blocks in this construct. This function should not
// be called before the exit block is set and dominators have been
// calculated.
ConstructBlockSet blocks(Function* function) const;
private: private:
/// The type of the construct /// The type of the construct
ConstructType type_; ConstructType type_;

View File

@ -168,7 +168,7 @@ string ConstructErrorString(const Construct& construct,
} }
spv_result_t StructuredControlFlowChecks( spv_result_t StructuredControlFlowChecks(
const ValidationState_t& _, const Function& function, const ValidationState_t& _, Function* function,
const vector<pair<uint32_t, uint32_t>>& back_edges) { const vector<pair<uint32_t, uint32_t>>& back_edges) {
/// Check all backedges target only loop headers and have exactly one /// Check all backedges target only loop headers and have exactly one
/// back-edge branching to it /// back-edge branching to it
@ -179,7 +179,7 @@ spv_result_t StructuredControlFlowChecks(
uint32_t back_edge_block; uint32_t back_edge_block;
uint32_t header_block; uint32_t header_block;
tie(back_edge_block, header_block) = back_edge; tie(back_edge_block, header_block) = back_edge;
if (!function.IsBlockType(header_block, kBlockTypeLoop)) { if (!function->IsBlockType(header_block, kBlockTypeLoop)) {
return _.diag(SPV_ERROR_INVALID_CFG) return _.diag(SPV_ERROR_INVALID_CFG)
<< "Back-edges (" << _.getIdName(back_edge_block) << " -> " << "Back-edges (" << _.getIdName(back_edge_block) << " -> "
<< _.getIdName(header_block) << _.getIdName(header_block)
@ -189,7 +189,7 @@ spv_result_t StructuredControlFlowChecks(
} }
// Check the loop headers have exactly one back-edge branching to it // Check the loop headers have exactly one back-edge branching to it
for (BasicBlock* loop_header : function.ordered_blocks()) { for (BasicBlock* loop_header : function->ordered_blocks()) {
if (!loop_header->reachable()) continue; if (!loop_header->reachable()) continue;
if (!loop_header->is_type(kBlockTypeLoop)) continue; if (!loop_header->is_type(kBlockTypeLoop)) continue;
auto loop_header_id = loop_header->id(); auto loop_header_id = loop_header->id();
@ -203,7 +203,7 @@ spv_result_t StructuredControlFlowChecks(
} }
// Check construct rules // Check construct rules
for (const Construct& construct : function.constructs()) { for (const Construct& construct : function->constructs()) {
auto header = construct.entry_block(); auto header = construct.entry_block();
auto merge = construct.exit_block(); auto merge = construct.exit_block();
@ -242,6 +242,25 @@ spv_result_t StructuredControlFlowChecks(
_.getIdName(merge->id()), "is not post dominated by"); _.getIdName(merge->id()), "is not post dominated by");
} }
} }
// Check that for all non-header blocks, all predecessors are within this
// construct.
Construct::ConstructBlockSet construct_blocks = construct.blocks(function);
for (auto block : construct_blocks) {
if (block == header) continue;
for (auto pred : *block->predecessors()) {
if (pred->reachable() && !construct_blocks.count(pred)) {
string construct_name, header_name, exit_name;
tie(construct_name, header_name, exit_name) =
ConstructNames(construct.type());
return _.diag(SPV_ERROR_INVALID_CFG)
<< "block <ID> " << pred->id() << " branches to the "
<< construct_name << " construct, but not to the "
<< header_name << " <ID> " << header->id();
}
}
}
// TODO(umar): an OpSwitch block dominates all its defined case // TODO(umar): an OpSwitch block dominates all its defined case
// constructs // constructs
// TODO(umar): each case construct has at most one branch to another // TODO(umar): each case construct has at most one branch to another
@ -352,7 +371,7 @@ spv_result_t PerformCfgChecks(ValidationState_t& _) {
/// Structured control flow checks are only required for shader capabilities /// Structured control flow checks are only required for shader capabilities
if (_.HasCapability(SpvCapabilityShader)) { if (_.HasCapability(SpvCapabilityShader)) {
if (auto error = StructuredControlFlowChecks(_, function, back_edges)) if (auto error = StructuredControlFlowChecks(_, &function, back_edges))
return error; return error;
} }
} }

View File

@ -1131,6 +1131,10 @@ OpStore %v %27
%30 = OpINotEqual %bool %29 %uint_0 %30 = OpINotEqual %bool %29 %uint_0
OpSelectionMerge %31 None OpSelectionMerge %31 None
OpBranchConditional %30 %31 %32 OpBranchConditional %30 %31 %32
%32 = OpLabel
%47 = OpLoad %v4float %v
OpStore %gl_FragColor %47
OpReturn
%31 = OpLabel %31 = OpLabel
%33 = OpAccessChain %_ptr_Uniform_float %_ %int_1 %33 = OpAccessChain %_ptr_Uniform_float %_ %int_1
%34 = OpLoad %float %33 %34 = OpLoad %float %33
@ -1146,16 +1150,20 @@ OpBranchConditional %38 %43 %39
%41 = OpLoad %v4float %v %41 = OpLoad %v4float %v
%42 = OpVectorTimesScalar %v4float %41 %40 %42 = OpVectorTimesScalar %v4float %41 %40
OpStore %v %42 OpStore %v %42
OpBranch %32 OpBranch %50
%50 = OpLabel
%51 = OpLoad %v4float %v
OpStore %gl_FragColor %51
OpReturn
%43 = OpLabel %43 = OpLabel
%44 = OpLoad %float %fi %44 = OpLoad %float %fi
%45 = OpLoad %v4float %v %45 = OpLoad %v4float %v
%46 = OpVectorTimesScalar %v4float %45 %44 %46 = OpVectorTimesScalar %v4float %45 %44
OpStore %v %46 OpStore %v %46
OpBranch %32 OpBranch %60
%32 = OpLabel %60 = OpLabel
%47 = OpLoad %v4float %v %61 = OpLoad %v4float %v
OpStore %gl_FragColor %47 OpStore %gl_FragColor %61
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd
)"; )";
@ -1166,35 +1174,43 @@ OpFunctionEnd
%v = OpVariable %_ptr_Function_v4float Function %v = OpVariable %_ptr_Function_v4float Function
%29 = OpLoad %v4float %BaseColor %29 = OpLoad %v4float %BaseColor
OpStore %v %29 OpStore %v %29
%50 = OpLoad %U_t %_ %54 = OpLoad %U_t %_
%51 = OpCompositeExtract %uint %50 0 %55 = OpCompositeExtract %uint %54 0
%32 = OpINotEqual %bool %51 %uint_0 %32 = OpINotEqual %bool %55 %uint_0
OpSelectionMerge %33 None OpSelectionMerge %33 None
OpBranchConditional %32 %33 %34 OpBranchConditional %32 %33 %34
%33 = OpLabel
%54 = OpLoad %float %alpha
%53 = OpCompositeExtract %float %50 1
%37 = OpLoad %v4float %v
%38 = OpVectorTimesScalar %v4float %37 %53
OpStore %v %38
%39 = OpLoad %uint %alpha_B
%40 = OpIEqual %bool %39 %uint_0
OpSelectionMerge %41 None
OpBranchConditional %40 %41 %42
%42 = OpLabel
%44 = OpLoad %v4float %v
%45 = OpVectorTimesScalar %v4float %44 %54
OpStore %v %45
OpBranch %34
%41 = OpLabel
%46 = OpLoad %float %fi
%47 = OpLoad %v4float %v
%48 = OpVectorTimesScalar %v4float %47 %46
OpStore %v %48
OpBranch %34
%34 = OpLabel %34 = OpLabel
%49 = OpLoad %v4float %v %35 = OpLoad %v4float %v
OpStore %gl_FragColor %49 OpStore %gl_FragColor %35
OpReturn
%33 = OpLabel
%58 = OpLoad %float %alpha
%57 = OpCompositeExtract %float %54 1
%38 = OpLoad %v4float %v
%39 = OpVectorTimesScalar %v4float %38 %57
OpStore %v %39
%40 = OpLoad %uint %alpha_B
%41 = OpIEqual %bool %40 %uint_0
OpSelectionMerge %42 None
OpBranchConditional %41 %42 %43
%43 = OpLabel
%45 = OpLoad %v4float %v
%46 = OpVectorTimesScalar %v4float %45 %58
OpStore %v %46
OpBranch %47
%47 = OpLabel
%48 = OpLoad %v4float %v
OpStore %gl_FragColor %48
OpReturn
%42 = OpLabel
%49 = OpLoad %float %fi
%50 = OpLoad %v4float %v
%51 = OpVectorTimesScalar %v4float %50 %49
OpStore %v %51
OpBranch %52
%52 = OpLabel
%53 = OpLoad %v4float %v
OpStore %gl_FragColor %53
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd
)"; )";

View File

@ -1423,6 +1423,33 @@ TEST_F(ValidateCFG, OpReturnInNonVoidFunc) {
"OpReturn can only be called from a function with void return type")); "OpReturn can only be called from a function with void return type"));
} }
TEST_F(ValidateCFG, StructuredCFGBranchIntoSelectionBody) {
std::string spirv = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %func "func"
%void = OpTypeVoid
%bool = OpTypeBool
%true = OpConstantTrue %bool
%functy = OpTypeFunction %void
%func = OpFunction %void None %functy
%entry = OpLabel
OpSelectionMerge %merge None
OpBranchConditional %true %then %merge
%merge = OpLabel
OpBranch %then
%then = OpLabel
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(spirv);
EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("branches to the selection construct, but not to the "
"selection header <ID>"));
}
/// TODO(umar): Switch instructions /// TODO(umar): Switch instructions
/// TODO(umar): Nested CFG constructs /// TODO(umar): Nested CFG constructs
} // namespace } // namespace