Validator structured flow checks: back-edge, constructs

Skip structured control flow chekcs for non-shader capability.

Fix infinite loop in dominator algorithm when there's an
unreachable block.
This commit is contained in:
Umar Arshad 2016-06-03 21:24:24 -04:00 committed by David Neto
parent 7cdf39c8f1
commit f61db0bcc6
11 changed files with 1010 additions and 278 deletions

View File

@ -35,25 +35,38 @@ namespace libspirv {
BasicBlock::BasicBlock(uint32_t id) BasicBlock::BasicBlock(uint32_t id)
: id_(id), : id_(id),
immediate_dominator_(nullptr), immediate_dominator_(nullptr),
immediate_post_dominator_(nullptr),
predecessors_(), predecessors_(),
successors_(), successors_(),
type_(0),
reachable_(false) {} reachable_(false) {}
void BasicBlock::SetImmediateDominator(BasicBlock* dom_block) { void BasicBlock::SetImmediateDominator(BasicBlock* dom_block) {
immediate_dominator_ = dom_block; immediate_dominator_ = dom_block;
} }
void BasicBlock::SetImmediatePostDominator(BasicBlock* pdom_block) {
immediate_post_dominator_ = pdom_block;
}
const BasicBlock* BasicBlock::GetImmediateDominator() const { const BasicBlock* BasicBlock::GetImmediateDominator() const {
return immediate_dominator_; return immediate_dominator_;
} }
BasicBlock* BasicBlock::GetImmediateDominator() { return immediate_dominator_; } const BasicBlock* BasicBlock::GetImmediatePostDominator() const {
return immediate_post_dominator_;
}
void BasicBlock::RegisterSuccessors(vector<BasicBlock*> next_blocks) { BasicBlock* BasicBlock::GetImmediateDominator() { return immediate_dominator_; }
BasicBlock* BasicBlock::GetImmediatePostDominator() {
return immediate_post_dominator_;
}
void BasicBlock::RegisterSuccessors(const vector<BasicBlock*>& next_blocks) {
for (auto& block : next_blocks) { for (auto& block : next_blocks) {
block->predecessors_.push_back(this); block->predecessors_.push_back(this);
successors_.push_back(block); successors_.push_back(block);
if (block->reachable_ == false) block->set_reachability(reachable_); if (block->reachable_ == false) block->set_reachable(reachable_);
} }
} }
@ -63,24 +76,29 @@ void BasicBlock::RegisterBranchInstruction(SpvOp branch_instruction) {
} }
BasicBlock::DominatorIterator::DominatorIterator() : current_(nullptr) {} BasicBlock::DominatorIterator::DominatorIterator() : current_(nullptr) {}
BasicBlock::DominatorIterator::DominatorIterator(const BasicBlock* block)
: current_(block) {} BasicBlock::DominatorIterator::DominatorIterator(
const BasicBlock* block,
std::function<const BasicBlock*(const BasicBlock*)> dominator_func)
: current_(block), dom_func_(dominator_func) {}
BasicBlock::DominatorIterator& BasicBlock::DominatorIterator::operator++() { BasicBlock::DominatorIterator& BasicBlock::DominatorIterator::operator++() {
if (current_ == current_->GetImmediateDominator()) { if (current_ == dom_func_(current_)) {
current_ = nullptr; current_ = nullptr;
} else { } else {
current_ = current_->GetImmediateDominator(); current_ = dom_func_(current_);
} }
return *this; return *this;
} }
const BasicBlock::DominatorIterator BasicBlock::dom_begin() const { const BasicBlock::DominatorIterator BasicBlock::dom_begin() const {
return DominatorIterator(this); return DominatorIterator(
this, [](const BasicBlock* b) { return b->GetImmediateDominator(); });
} }
BasicBlock::DominatorIterator BasicBlock::dom_begin() { BasicBlock::DominatorIterator BasicBlock::dom_begin() {
return DominatorIterator(this); return DominatorIterator(
this, [](const BasicBlock* b) { return b->GetImmediateDominator(); });
} }
const BasicBlock::DominatorIterator BasicBlock::dom_end() const { const BasicBlock::DominatorIterator BasicBlock::dom_end() const {
@ -91,6 +109,24 @@ BasicBlock::DominatorIterator BasicBlock::dom_end() {
return DominatorIterator(); return DominatorIterator();
} }
const BasicBlock::DominatorIterator BasicBlock::pdom_begin() const {
return DominatorIterator(
this, [](const BasicBlock* b) { return b->GetImmediatePostDominator(); });
}
BasicBlock::DominatorIterator BasicBlock::pdom_begin() {
return DominatorIterator(
this, [](const BasicBlock* b) { return b->GetImmediatePostDominator(); });
}
const BasicBlock::DominatorIterator BasicBlock::pdom_end() const {
return DominatorIterator();
}
BasicBlock::DominatorIterator BasicBlock::pdom_end() {
return DominatorIterator();
}
bool operator==(const BasicBlock::DominatorIterator& lhs, bool operator==(const BasicBlock::DominatorIterator& lhs,
const BasicBlock::DominatorIterator& rhs) { const BasicBlock::DominatorIterator& rhs) {
return lhs.current_ == rhs.current_; return lhs.current_ == rhs.current_;

View File

@ -30,10 +30,24 @@
#include "spirv/1.1/spirv.h" #include "spirv/1.1/spirv.h"
#include <cstdint> #include <cstdint>
#include <bitset>
#include <functional>
#include <vector> #include <vector>
namespace libspirv { namespace libspirv {
enum BlockType : uint32_t {
kBlockTypeUndefined,
kBlockTypeHeader,
kBlockTypeLoop,
kBlockTypeMerge,
kBlockTypeBreak,
kBlockTypeContinue,
kBlockTypeReturn,
kBlockTypeCOUNT ///< Total number of block types. (must be the last element)
};
// This class represents a basic block in a SPIR-V module // This class represents a basic block in a SPIR-V module
class BasicBlock { class BasicBlock {
public: public:
@ -61,27 +75,53 @@ class BasicBlock {
/// Returns the successors of the BasicBlock /// Returns the successors of the BasicBlock
std::vector<BasicBlock*>* get_successors() { return &successors_; } std::vector<BasicBlock*>* get_successors() { return &successors_; }
/// Returns true if the block should be reachable in the CFG /// Returns true if the block is reachable in the CFG
bool is_reachable() const { return reachable_; } bool is_reachable() const { return reachable_; }
void set_reachability(bool reachability) { reachable_ = reachability; } /// Returns true if BasicBlock is of the given type
bool is_type(BlockType type) const {
if (type == kBlockTypeUndefined) return type_.none();
return type_.test(type);
}
/// Sets the reachability of the basic block in the CFG
void set_reachable(bool reachability) { reachable_ = reachability; }
/// Sets the type of the BasicBlock
void set_type(BlockType type) {
if (type == kBlockTypeUndefined)
type_.reset();
else
type_.set(type);
}
/// Sets the immedate dominator of this basic block /// Sets the immedate dominator of this basic block
/// ///
/// @param[in] dom_block The dominator block /// @param[in] dom_block The dominator block
void SetImmediateDominator(BasicBlock* dom_block); void SetImmediateDominator(BasicBlock* dom_block);
/// Sets the immedate post dominator of this basic block
///
/// @param[in] pdom_block The post dominator block
void SetImmediatePostDominator(BasicBlock* pdom_block);
/// Returns the immedate dominator of this basic block /// Returns the immedate dominator of this basic block
BasicBlock* GetImmediateDominator(); BasicBlock* GetImmediateDominator();
/// Returns the immedate dominator of this basic block /// Returns the immedate dominator of this basic block
const BasicBlock* GetImmediateDominator() const; const BasicBlock* GetImmediateDominator() const;
/// Returns the immedate post dominator of this basic block
BasicBlock* GetImmediatePostDominator();
/// Returns the immedate post dominator of this basic block
const BasicBlock* GetImmediatePostDominator() const;
/// Ends the block without a successor /// Ends the block without a successor
void RegisterBranchInstruction(SpvOp branch_instruction); void RegisterBranchInstruction(SpvOp branch_instruction);
/// Adds @p next BasicBlocks as successors of this BasicBlock /// Adds @p next BasicBlocks as successors of this BasicBlock
void RegisterSuccessors(std::vector<BasicBlock*> next = {}); void RegisterSuccessors(const std::vector<BasicBlock*>& next = {});
/// Returns true if the id of the BasicBlock matches /// Returns true if the id of the BasicBlock matches
bool operator==(const BasicBlock& other) const { return other.id_ == id_; } bool operator==(const BasicBlock& other) const { return other.id_ == id_; }
@ -91,7 +131,7 @@ class BasicBlock {
/// @brief A BasicBlock dominator iterator class /// @brief A BasicBlock dominator iterator class
/// ///
/// This iterator will iterate over the dominators of the block /// This iterator will iterate over the (post)dominators of the block
class DominatorIterator class DominatorIterator
: public std::iterator<std::forward_iterator_tag, BasicBlock*> { : public std::iterator<std::forward_iterator_tag, BasicBlock*> {
public: public:
@ -104,8 +144,12 @@ class BasicBlock {
/// @brief Constructs an iterator for the given block which points to /// @brief Constructs an iterator for the given block which points to
/// @p block /// @p block
/// ///
/// @param block The block which is referenced by the iterator /// @param block The block which is referenced by the iterator
explicit DominatorIterator(const BasicBlock* block); /// @param dominator_func This function will be called to get the immediate
/// (post)dominator of the current block
DominatorIterator(
const BasicBlock* block,
std::function<const BasicBlock*(const BasicBlock*)> dominator_func);
/// @brief Advances the iterator /// @brief Advances the iterator
DominatorIterator& operator++(); DominatorIterator& operator++();
@ -118,16 +162,36 @@ class BasicBlock {
private: private:
const BasicBlock* current_; const BasicBlock* current_;
std::function<const BasicBlock*(const BasicBlock*)> dom_func_;
}; };
/// Returns an iterator which points to the current block /// Returns a dominator iterator which points to the current block
const DominatorIterator dom_begin() const; const DominatorIterator dom_begin() const;
/// Returns a dominator iterator which points to the current block
DominatorIterator dom_begin(); DominatorIterator dom_begin();
/// Returns an iterator which points to one element past the first block /// Returns a dominator iterator which points to one element past the first
/// block
const DominatorIterator dom_end() const; const DominatorIterator dom_end() const;
/// Returns a dominator iterator which points to one element past the first
/// block
DominatorIterator dom_end(); DominatorIterator dom_end();
/// Returns a post dominator iterator which points to the current block
const DominatorIterator pdom_begin() const;
/// Returns a post dominator iterator which points to the current block
DominatorIterator pdom_begin();
/// Returns a post dominator iterator which points to one element past the
/// last block
const DominatorIterator pdom_end() const;
/// Returns a post dominator iterator which points to one element past the
/// last block
DominatorIterator pdom_end();
private: private:
/// Id of the BasicBlock /// Id of the BasicBlock
const uint32_t id_; const uint32_t id_;
@ -135,12 +199,19 @@ class BasicBlock {
/// Pointer to the immediate dominator of the BasicBlock /// Pointer to the immediate dominator of the BasicBlock
BasicBlock* immediate_dominator_; BasicBlock* immediate_dominator_;
/// Pointer to the immediate dominator of the BasicBlock
BasicBlock* immediate_post_dominator_;
/// The set of predecessors of the BasicBlock /// The set of predecessors of the BasicBlock
std::vector<BasicBlock*> predecessors_; std::vector<BasicBlock*> predecessors_;
/// The set of successors of the BasicBlock /// The set of successors of the BasicBlock
std::vector<BasicBlock*> successors_; std::vector<BasicBlock*> successors_;
/// The type of the block
std::bitset<kBlockTypeCOUNT - 1> type_;
/// True if the block is reachable in the CFG
bool reachable_; bool reachable_;
}; };

View File

@ -26,19 +26,51 @@
#include "val/Construct.h" #include "val/Construct.h"
#include <cassert>
#include <cstddef>
namespace libspirv { namespace libspirv {
Construct::Construct(BasicBlock* header_block, BasicBlock* merge_block, Construct::Construct(ConstructType type, BasicBlock* entry,
BasicBlock* continue_block) BasicBlock* exit, std::vector<Construct*> constructs)
: header_block_(header_block), : type_(type),
merge_block_(merge_block), corresponding_constructs_(constructs),
continue_block_(continue_block) {} entry_block_(entry),
exit_block_(exit) {}
const BasicBlock* Construct::get_header() const { return header_block_; } ConstructType Construct::get_type() const { return type_; }
const BasicBlock* Construct::get_merge() const { return merge_block_; }
const BasicBlock* Construct::get_continue() const { return continue_block_; }
BasicBlock* Construct::get_header() { return header_block_; } const std::vector<Construct*>& Construct::get_corresponding_constructs() const {
BasicBlock* Construct::get_merge() { return merge_block_; } return corresponding_constructs_;
BasicBlock* Construct::get_continue() { return continue_block_; }
} }
std::vector<Construct*>& Construct::get_corresponding_constructs() {
return corresponding_constructs_;
}
bool ValidateConstructSize(ConstructType type, size_t size) {
switch (type) {
case ConstructType::kSelection: return size == 0;
case ConstructType::kContinue: return size == 1;
case ConstructType::kLoop: return size == 1;
case ConstructType::kCase: return size >= 1;
default: assert(1 == 0 && "Type not defined");
}
return false;
}
void Construct::set_corresponding_constructs(
std::vector<Construct*> constructs) {
assert(ValidateConstructSize(type_, constructs.size()));
corresponding_constructs_ = constructs;
}
const BasicBlock* Construct::get_entry() const { return entry_block_; }
BasicBlock* Construct::get_entry() { return entry_block_; }
const BasicBlock* Construct::get_exit() const { return exit_block_; }
BasicBlock* Construct::get_exit() { return exit_block_; }
void Construct::set_exit(BasicBlock* exit_block) {
exit_block_ = exit_block;
}
} /// namespace libspirv

View File

@ -28,29 +28,109 @@
#define LIBSPIRV_VAL_CONSTRUCT_H_ #define LIBSPIRV_VAL_CONSTRUCT_H_
#include <cstdint> #include <cstdint>
#include <vector>
namespace libspirv { namespace libspirv {
enum class ConstructType {
kNone,
/// The set of blocks dominated by a selection header, minus the set of blocks
/// dominated by the header's merge block
kSelection,
/// The set of blocks dominated by an OpLoopMerge's Continue Target and post
/// dominated by the corresponding back
kContinue,
/// The set of blocks dominated by a loop header, minus the set of blocks
/// dominated by the loop's merge block, minus the loop's corresponding
/// continue construct
kLoop,
/// The set of blocks dominated by an OpSwitch's Target or Default, minus the
/// set of blocks dominated by the OpSwitch's merge block (this construct is
/// only defined for those OpSwitch Target or Default that are not equal to
/// the OpSwitch's corresponding merge block)
kCase
};
class BasicBlock; class BasicBlock;
/// @brief This class tracks the CFG constructs as defined in the SPIR-V spec /// @brief This class tracks the CFG constructs as defined in the SPIR-V spec
class Construct { class Construct {
public: public:
Construct(BasicBlock* header_block, BasicBlock* merge_block, Construct(ConstructType type, BasicBlock* dominator,
BasicBlock* continue_block = nullptr); BasicBlock* exit = nullptr,
std::vector<Construct*> constructs = {});
const BasicBlock* get_header() const; /// Returns the type of the construct
const BasicBlock* get_merge() const; ConstructType get_type() const;
const BasicBlock* get_continue() const;
BasicBlock* get_header(); const std::vector<Construct*>& get_corresponding_constructs() const;
BasicBlock* get_merge(); std::vector<Construct*>& get_corresponding_constructs();
BasicBlock* get_continue(); void set_corresponding_constructs(std::vector<Construct*> constructs);
/// Returns the dominator block of the construct.
///
/// This is usually the header block or the first block of the construct.
const BasicBlock* get_entry() const;
/// Returns the dominator block of the construct.
///
/// This is usually the header block or the first block of the construct.
BasicBlock* get_entry();
/// Returns the exit block of the construct.
///
/// For a continue construct it is the backedge block of the corresponding
/// loop construct. For the case construct it is the block that branches to
/// the OpSwitch merge block or other case blocks. Otherwise it is the merge
/// block of the corresponding header block
const BasicBlock* get_exit() const;
/// Returns the exit block of the construct.
///
/// For a continue construct it is the backedge block of the corresponding
/// loop construct. For the case construct it is the block that branches to
/// the OpSwitch merge block or other case blocks. Otherwise it is the merge
/// block of the corresponding header block
BasicBlock* get_exit();
/// Sets the exit block for this construct. This is useful for continue
/// constructs which do not know the back-edge block during construction
void set_exit(BasicBlock* exit_block);
private: private:
BasicBlock* header_block_; ///< The header block of a loop or selection /// The type of the construct
BasicBlock* merge_block_; ///< The merge block of a loop or selection ConstructType type_;
BasicBlock* continue_block_; ///< The continue block of a loop block
/// These are the constructs that are related to this construct. These
/// constructs can be the continue construct, for the corresponding loop
/// construct, the case construct that are part of the same OpSwitch
/// instruction
///
/// Here is a table that describes what constructs are included in
/// @p corresponding_constructs_
/// | this construct | corresponding construct |
/// |----------------|----------------------------------|
/// | loop | continue |
/// | continue | loop |
/// | case | other cases in the same OpSwitch |
///
/// kContinue and kLoop constructs will always have corresponding
/// constructs even if they are represented by the same block
std::vector<Construct*> corresponding_constructs_;
/// @brief Dominator block for the construct
///
/// The dominator block for the construct. Depending on the construct this may
/// be a selection header, a continue target of a loop, a loop header or a
/// Target or Default block of a switch
BasicBlock* entry_block_;
/// @brief Exiting block for the construct
///
/// The exit block for the construct. This can be a merge block for the loop
/// and selection constructs, a back-edge block for a continue construct, or
/// the branching block for the case construct
BasicBlock* exit_block_;
}; };
} /// namespace libspirv } /// namespace libspirv

View File

@ -29,13 +29,18 @@
#include <cassert> #include <cassert>
#include <algorithm> #include <algorithm>
#include <utility>
#include "val/BasicBlock.h" #include "val/BasicBlock.h"
#include "val/Construct.h" #include "val/Construct.h"
#include "val/ValidationState.h" #include "val/ValidationState.h"
using std::ignore;
using std::list; using std::list;
using std::make_pair;
using std::pair;
using std::string; using std::string;
using std::tie;
using std::vector; using std::vector;
namespace libspirv { namespace libspirv {
@ -66,6 +71,7 @@ Function::Function(uint32_t id, uint32_t result_type_id,
declaration_type_(FunctionDecl::kFunctionDeclUnknown), declaration_type_(FunctionDecl::kFunctionDeclUnknown),
blocks_(), blocks_(),
current_block_(nullptr), current_block_(nullptr),
pseudo_exit_block_(kInvalidId),
cfg_constructs_(), cfg_constructs_(),
variable_ids_(), variable_ids_(),
parameter_ids_() {} parameter_ids_() {}
@ -93,15 +99,33 @@ spv_result_t Function::RegisterLoopMerge(uint32_t merge_id,
uint32_t continue_id) { uint32_t continue_id) {
RegisterBlock(merge_id, false); RegisterBlock(merge_id, false);
RegisterBlock(continue_id, false); RegisterBlock(continue_id, false);
cfg_constructs_.emplace_back(get_current_block(), &blocks_.at(merge_id), BasicBlock& merge_block = blocks_.at(merge_id);
&blocks_.at(continue_id)); BasicBlock& continue_block = blocks_.at(continue_id);
assert(current_block_ &&
"RegisterLoopMerge must be called when called within a block");
current_block_->set_type(kBlockTypeLoop);
merge_block.set_type(kBlockTypeMerge);
continue_block.set_type(kBlockTypeContinue);
cfg_constructs_.emplace_back(ConstructType::kLoop, current_block_,
&merge_block);
Construct& loop_construct = cfg_constructs_.back();
cfg_constructs_.emplace_back(ConstructType::kContinue, &continue_block);
Construct& continue_construct = cfg_constructs_.back();
continue_construct.set_corresponding_constructs({&loop_construct});
loop_construct.set_corresponding_constructs({&continue_construct});
return SPV_SUCCESS; return SPV_SUCCESS;
} }
spv_result_t Function::RegisterSelectionMerge(uint32_t merge_id) { spv_result_t Function::RegisterSelectionMerge(uint32_t merge_id) {
RegisterBlock(merge_id, false); RegisterBlock(merge_id, false);
cfg_constructs_.emplace_back(get_current_block(), &blocks_.at(merge_id)); BasicBlock& merge_block = blocks_.at(merge_id);
current_block_->set_type(kBlockTypeHeader);
merge_block.set_type(kBlockTypeMerge);
cfg_constructs_.emplace_back(ConstructType::kSelection, get_current_block(),
&merge_block);
return SPV_SUCCESS; return SPV_SUCCESS;
} }
@ -152,7 +176,7 @@ spv_result_t Function::RegisterBlock(uint32_t id, bool is_definition) {
undefined_blocks_.erase(id); undefined_blocks_.erase(id);
current_block_ = &inserted_block->second; current_block_ = &inserted_block->second;
ordered_blocks_.push_back(current_block_); ordered_blocks_.push_back(current_block_);
if (IsFirstBlock(id)) current_block_->set_reachability(true); if (IsFirstBlock(id)) current_block_->set_reachable(true);
} else if (success) { // Block doesn't exsist but this is not a definition } else if (success) { // Block doesn't exsist but this is not a definition
undefined_blocks_.insert(id); undefined_blocks_.insert(id);
} }
@ -182,6 +206,11 @@ void Function::RegisterBlockEnd(vector<uint32_t> next_list,
next_blocks.push_back(&inserted_block->second); next_blocks.push_back(&inserted_block->second);
} }
if (branch_instruction == SpvOpReturn ||
branch_instruction == SpvOpReturnValue) {
assert(next_blocks.empty());
next_blocks.push_back(&pseudo_exit_block_);
}
current_block_->RegisterBranchInstruction(branch_instruction); current_block_->RegisterBranchInstruction(branch_instruction);
current_block_->RegisterSuccessors(next_blocks); current_block_->RegisterSuccessors(next_blocks);
current_block_ = nullptr; current_block_ = nullptr;
@ -202,6 +231,11 @@ vector<BasicBlock*>& Function::get_blocks() { return ordered_blocks_; }
const BasicBlock* Function::get_current_block() const { return current_block_; } const BasicBlock* Function::get_current_block() const { return current_block_; }
BasicBlock* Function::get_current_block() { return current_block_; } BasicBlock* Function::get_current_block() { return current_block_; }
BasicBlock* Function::get_pseudo_exit_block() { return &pseudo_exit_block_; }
const BasicBlock* Function::get_pseudo_exit_block() const {
return &pseudo_exit_block_;
}
const list<Construct>& Function::get_constructs() const { const list<Construct>& Function::get_constructs() const {
return cfg_constructs_; return cfg_constructs_;
} }
@ -216,17 +250,32 @@ BasicBlock* Function::get_first_block() {
return ordered_blocks_[0]; return ordered_blocks_[0];
} }
bool Function::IsMergeBlock(uint32_t merge_block_id) const { bool Function::IsBlockType(uint32_t merge_block_id, BlockType type) const {
const auto b = blocks_.find(merge_block_id); bool ret = false;
const BasicBlock* block;
tie(block, ignore) = GetBlock(merge_block_id);
if (block) {
ret = block->is_type(type);
}
return ret;
}
pair<const BasicBlock*, bool> Function::GetBlock(uint32_t id) const {
const auto b = blocks_.find(id);
if (b != end(blocks_)) { if (b != end(blocks_)) {
return cfg_constructs_.end() != const BasicBlock* block = &(b->second);
find_if(begin(cfg_constructs_), end(cfg_constructs_), bool defined =
[&](const Construct& construct) { undefined_blocks_.find(block->get_id()) == end(undefined_blocks_);
return construct.get_merge() == &b->second; return make_pair(block, defined);
});
} else { } else {
return false; return make_pair(nullptr, false);
} }
} }
pair<BasicBlock*, bool> Function::GetBlock(uint32_t id) {
const BasicBlock* out;
bool defined;
tie(out, defined) = const_cast<const Function*>(this)->GetBlock(id);
return make_pair(const_cast<BasicBlock*>(out), defined);
}
} /// namespace libspirv } /// namespace libspirv

View File

@ -28,9 +28,9 @@
#define LIBSPIRV_VAL_FUNCTION_H_ #define LIBSPIRV_VAL_FUNCTION_H_
#include <list> #include <list>
#include <vector>
#include <unordered_set>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector>
#include "spirv/1.1/spirv.h" #include "spirv/1.1/spirv.h"
#include "spirv-tools/libspirv.h" #include "spirv-tools/libspirv.h"
@ -100,12 +100,19 @@ class Function {
void RegisterBlockEnd(std::vector<uint32_t> successors_list, void RegisterBlockEnd(std::vector<uint32_t> successors_list,
SpvOp branch_instruction); SpvOp branch_instruction);
/// Returns true if the \p merge_block_id is a merge block /// Returns true if the \p id block is the first block of this function
bool IsMergeBlock(uint32_t merge_block_id) const;
/// Returns true if the \p id is the first block of this function
bool IsFirstBlock(uint32_t id) const; bool IsFirstBlock(uint32_t id) const;
/// Returns true if the \p merge_block_id is a BlockType of \p type
bool IsBlockType(uint32_t merge_block_id, BlockType type) const;
/// Returns a pair consisting of the BasicBlock with \p id and a bool
/// which is true if the block has been defined, and false if it is
/// declared but not defined. This function will return nullptr if the
/// \p id was not declared and not defined at the current point in the binary
std::pair<const BasicBlock*, bool> GetBlock(uint32_t id) const;
std::pair<BasicBlock*, bool> GetBlock(uint32_t id);
/// Returns the first block of the current function /// Returns the first block of the current function
const BasicBlock* get_first_block() const; const BasicBlock* get_first_block() const;
@ -142,6 +149,12 @@ class Function {
/// Returns the block that is currently being parsed in the binary /// Returns the block that is currently being parsed in the binary
const BasicBlock* get_current_block() const; const BasicBlock* get_current_block() const;
/// Returns the psudo exit block
BasicBlock* get_pseudo_exit_block();
/// Returns the psudo exit block
const BasicBlock* get_pseudo_exit_block() const;
/// Prints a GraphViz digraph of the CFG of the current funciton /// Prints a GraphViz digraph of the CFG of the current funciton
void printDotGraph() const; void printDotGraph() const;
@ -179,6 +192,9 @@ class Function {
/// The block that is currently being parsed /// The block that is currently being parsed
BasicBlock* current_block_; BasicBlock* current_block_;
/// A pseudo exit block that is the successor to all return blocks
BasicBlock pseudo_exit_block_;
/// The constructs that are available in this function /// The constructs that are available in this function
std::list<Construct> cfg_constructs_; std::list<Construct> cfg_constructs_;
@ -191,5 +207,4 @@ class Function {
} /// namespace libspirv } /// namespace libspirv
#endif /// LIBSPIRV_VAL_FUNCTION_H_ #endif /// LIBSPIRV_VAL_FUNCTION_H_

View File

@ -42,6 +42,9 @@
namespace libspirv { namespace libspirv {
// Universal Limit of ResultID + 1
static const uint32_t kInvalidId = 0x400000;
// Info about a result ID. // Info about a result ID.
typedef struct spv_id_info_t { typedef struct spv_id_info_t {
/// Id value. /// Id value.

View File

@ -29,6 +29,7 @@
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <functional>
#include <list> #include <list>
#include <map> #include <map>
#include <string> #include <string>
@ -52,16 +53,25 @@ namespace libspirv {
class ValidationState_t; class ValidationState_t;
/// @brief Calculates dominator edges of a root basic block /// A function that returns a vector of BasicBlocks given a BasicBlock. Used to
/// get the successor and predecessor nodes of a CFG block
using get_blocks_func =
std::function<const std::vector<BasicBlock*>*(const BasicBlock*)>;
/// @brief Calculates dominator edges for a set of blocks
/// ///
/// This function calculates the dominator edges form a root BasicBlock. Uses /// This function calculates the dominator edges for a set of blocks in the CFG.
/// the dominator algorithm by Cooper et al. /// Uses the dominator algorithm by Cooper et al.
/// ///
/// @param[in] first_block the root or entry BasicBlock of a function /// @param[in] postorder A vector of blocks in post order traversal order
/// in a CFG
/// @param[in] predecessor_func Function used to get the predecessor nodes of a
/// block
/// ///
/// @return a set of dominator edges represented as a pair of blocks /// @return a set of dominator edges represented as a pair of blocks
std::vector<std::pair<BasicBlock*, BasicBlock*>> CalculateDominators( std::vector<std::pair<BasicBlock*, BasicBlock*>> CalculateDominators(
const BasicBlock& first_block); const std::vector<const BasicBlock*>& postorder,
get_blocks_func predecessor_func);
/// @brief Performs the Control Flow Graph checks /// @brief Performs the Control Flow Graph checks
/// ///
@ -76,8 +86,11 @@ spv_result_t PerformCfgChecks(ValidationState_t& _);
/// provided by the @p dom_edges parameter /// provided by the @p dom_edges parameter
/// ///
/// @param[in,out] dom_edges The edges of the dominator tree /// @param[in,out] dom_edges The edges of the dominator tree
/// @param[in] set_func This function will be called to updated the Immediate
/// dominator
void UpdateImmediateDominators( void UpdateImmediateDominators(
std::vector<std::pair<BasicBlock*, BasicBlock*>>& dom_edges); const std::vector<std::pair<BasicBlock*, BasicBlock*>>& dom_edges,
std::function<void(BasicBlock*, BasicBlock*)> set_func);
/// @brief Prints all of the dominators of a BasicBlock /// @brief Prints all of the dominators of a BasicBlock
/// ///

View File

@ -30,6 +30,8 @@
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <set>
#include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
@ -43,9 +45,13 @@
using std::find; using std::find;
using std::function; using std::function;
using std::get; using std::get;
using std::ignore;
using std::make_pair; using std::make_pair;
using std::numeric_limits; using std::numeric_limits;
using std::pair; using std::pair;
using std::set;
using std::string;
using std::tie;
using std::transform; using std::transform;
using std::unordered_map; using std::unordered_map;
using std::unordered_set; using std::unordered_set;
@ -61,8 +67,6 @@ using bb_ptr = BasicBlock*;
using cbb_ptr = const BasicBlock*; using cbb_ptr = const BasicBlock*;
using bb_iter = vector<BasicBlock*>::const_iterator; using bb_iter = vector<BasicBlock*>::const_iterator;
using get_blocks_func = function<const vector<BasicBlock*>*(const BasicBlock*)>;
struct block_info { struct block_info {
cbb_ptr block; ///< pointer to the block cbb_ptr block; ///< pointer to the block
bb_iter iter; ///< Iterator to the current child node being processed bb_iter iter; ///< Iterator to the current child node being processed
@ -92,8 +96,8 @@ bool FindInWorkList(const vector<block_info>& work_list, uint32_t id) {
/// @param[in] entry The root BasicBlock of a CFG tree /// @param[in] entry The root BasicBlock of a CFG tree
/// @param[in] successor_func A function which will return a pointer to the /// @param[in] successor_func A function which will return a pointer to the
/// successor nodes /// successor nodes
/// @param[in] preorder A function that will be called for every block in a CFG /// @param[in] preorder A function that will be called for every block in a
/// following preorder traversal semantics /// CFG following preorder traversal semantics
/// @param[in] postorder A function that will be called for every block in a /// @param[in] postorder A function that will be called for every block in a
/// CFG following postorder traversal semantics /// CFG following postorder traversal semantics
/// @param[in] backedge A function that will be called when a backedge is /// @param[in] backedge A function that will be called when a backedge is
@ -143,45 +147,44 @@ const vector<BasicBlock*>* successor(const BasicBlock* b) {
return b->get_successors(); return b->get_successors();
} }
const vector<BasicBlock*>* predecessor(const BasicBlock* b) {
return b->get_predecessors();
}
} // namespace } // namespace
vector<pair<BasicBlock*, BasicBlock*>> CalculateDominators( vector<pair<BasicBlock*, BasicBlock*>> CalculateDominators(
vector<cbb_ptr>& postorder) { const vector<cbb_ptr>& postorder, get_blocks_func predecessor_func) {
struct block_detail { struct block_detail {
size_t dominator; ///< The index of blocks's dominator in post order array size_t dominator; ///< The index of blocks's dominator in post order array
size_t postorder_index; ///< The index of the block in the post order array size_t postorder_index; ///< The index of the block in the post order array
}; };
const size_t undefined_dom = postorder.size();
const size_t undefined_dom = static_cast<size_t>(postorder.size());
unordered_map<cbb_ptr, block_detail> idoms; unordered_map<cbb_ptr, block_detail> idoms;
for (size_t i = 0; i < postorder.size(); i++) { for (size_t i = 0; i < postorder.size(); i++) {
idoms[postorder[i]] = {undefined_dom, i}; idoms[postorder[i]] = {undefined_dom, i};
} }
idoms[postorder.back()].dominator = idoms[postorder.back()].postorder_index; idoms[postorder.back()].dominator = idoms[postorder.back()].postorder_index;
bool changed = true; bool changed = true;
while (changed) { while (changed) {
changed = false; changed = false;
for (auto b = postorder.rbegin() + 1; b != postorder.rend(); b++) { for (auto b = postorder.rbegin() + 1; b != postorder.rend(); b++) {
size_t& b_dom = idoms[*b].dominator; const vector<BasicBlock*>* predecessors = predecessor_func(*b);
const vector<BasicBlock*>* predecessors = (*b)->get_predecessors(); // first processed/reachable predecessor
// first processed predecessor
auto res = find_if(begin(*predecessors), end(*predecessors), auto res = find_if(begin(*predecessors), end(*predecessors),
[&idoms, undefined_dom](BasicBlock* pred) { [&idoms, undefined_dom](BasicBlock* pred) {
return idoms[pred].dominator != undefined_dom; return idoms[pred].dominator != undefined_dom &&
pred->is_reachable();
}); });
assert(res != end(*predecessors)); if (res == end(*predecessors)) continue;
BasicBlock* idom = *res; BasicBlock* idom = *res;
size_t idom_idx = idoms[idom].postorder_index; size_t idom_idx = idoms[idom].postorder_index;
// all other predecessors // all other predecessors
for (auto p : *predecessors) { for (auto p : *predecessors) {
if (idom == p || p->is_reachable() == false) { if (idom == p || p->is_reachable() == false) continue;
continue;
}
if (idoms[p].dominator != undefined_dom) { if (idoms[p].dominator != undefined_dom) {
size_t finger1 = idoms[p].postorder_index; size_t finger1 = idoms[p].postorder_index;
size_t finger2 = idom_idx; size_t finger2 = idom_idx;
@ -196,8 +199,8 @@ vector<pair<BasicBlock*, BasicBlock*>> CalculateDominators(
idom_idx = finger1; idom_idx = finger1;
} }
} }
if (b_dom != idom_idx) { if (idoms[*b].dominator != idom_idx) {
b_dom = idom_idx; idoms[*b].dominator = idom_idx;
changed = true; changed = true;
} }
} }
@ -213,13 +216,15 @@ vector<pair<BasicBlock*, BasicBlock*>> CalculateDominators(
return out; return out;
} }
void UpdateImmediateDominators(vector<pair<bb_ptr, bb_ptr>>& dom_edges) { void UpdateImmediateDominators(
const vector<pair<bb_ptr, bb_ptr>>& dom_edges,
function<void(BasicBlock*, BasicBlock*)> set_func) {
for (auto& edge : dom_edges) { for (auto& edge : dom_edges) {
get<0>(edge)->SetImmediateDominator(get<1>(edge)); set_func(get<0>(edge), get<1>(edge));
} }
} }
void printDominatorList(BasicBlock& b) { void printDominatorList(const BasicBlock& b) {
std::cout << b.get_id() << " is dominated by: "; std::cout << b.get_id() << " is dominated by: ";
const BasicBlock* bb = &b; const BasicBlock* bb = &b;
while (bb->GetImmediateDominator() != bb) { while (bb->GetImmediateDominator() != bb) {
@ -244,7 +249,7 @@ spv_result_t FirstBlockAssert(ValidationState_t& _, uint32_t target) {
} }
spv_result_t MergeBlockAssert(ValidationState_t& _, uint32_t merge_block) { spv_result_t MergeBlockAssert(ValidationState_t& _, uint32_t merge_block) {
if (_.get_current_function().IsMergeBlock(merge_block)) { if (_.get_current_function().IsBlockType(merge_block, kBlockTypeMerge)) {
return _.diag(SPV_ERROR_INVALID_CFG) return _.diag(SPV_ERROR_INVALID_CFG)
<< "Block " << _.getIdName(merge_block) << "Block " << _.getIdName(merge_block)
<< " is already a merge block for another header"; << " is already a merge block for another header";
@ -252,21 +257,188 @@ spv_result_t MergeBlockAssert(ValidationState_t& _, uint32_t merge_block) {
return SPV_SUCCESS; return SPV_SUCCESS;
} }
/// Update the continue construct's exit blocks once the backedge blocks are
/// identified in the CFG.
void UpdateContinueConstructExitBlocks(
Function& function, const vector<pair<uint32_t, uint32_t>>& back_edges) {
auto& constructs = function.get_constructs();
// TODO(umar): Think of a faster way to do this
for (auto& edge : back_edges) {
uint32_t back_edge_block_id;
uint32_t loop_header_block_id;
tie(back_edge_block_id, loop_header_block_id) = edge;
auto is_this_header = [=](Construct& c) {
return c.get_type() == ConstructType::kLoop &&
c.get_entry()->get_id() == loop_header_block_id;
};
for (auto construct : constructs) {
if (is_this_header(construct)) {
Construct* continue_construct =
construct.get_corresponding_constructs().back();
assert(continue_construct->get_type() == ConstructType::kContinue);
BasicBlock* back_edge_block;
tie(back_edge_block, ignore) = function.GetBlock(back_edge_block_id);
continue_construct->set_exit(back_edge_block);
}
}
}
}
/// Constructs an error message for construct validation errors
string ConstructErrorString(const Construct& construct,
const string& header_string,
const string& exit_string,
bool post_dominate = false) {
string construct_name;
string header_name;
string exit_name;
string dominate_text;
if (post_dominate) {
dominate_text = "is not post dominated by";
} else {
dominate_text = "does not dominate";
}
switch (construct.get_type()) {
case ConstructType::kSelection:
construct_name = "selection";
header_name = "selection header";
exit_name = "merge block";
break;
case ConstructType::kLoop:
construct_name = "loop";
header_name = "loop header";
exit_name = "merge block";
break;
case ConstructType::kContinue:
construct_name = "continue";
header_name = "continue target";
exit_name = "back-edge block";
break;
case ConstructType::kCase:
construct_name = "case";
header_name = "case block";
exit_name = "exit block"; // TODO(umar): there has to be a better name
break;
default:
assert(1 == 0 && "Not defined type");
}
// TODO(umar): Add header block for continue constructs to error message
return "The " + construct_name + " construct with the " + header_name + " " +
header_string + " " + dominate_text + " the " + exit_name + " " +
exit_string;
}
spv_result_t StructuredControlFlowChecks(
const ValidationState_t& _, const Function& function,
const vector<pair<uint32_t, uint32_t>>& back_edges) {
/// Check all backedges target only loop headers and have exactly one
/// back-edge branching to it
set<uint32_t> loop_headers;
for (auto back_edge : back_edges) {
uint32_t back_edge_block;
uint32_t header_block;
tie(back_edge_block, header_block) = back_edge;
if (!function.IsBlockType(header_block, kBlockTypeLoop)) {
return _.diag(SPV_ERROR_INVALID_CFG)
<< "Back-edges (" << _.getIdName(back_edge_block) << " -> "
<< _.getIdName(header_block)
<< ") can only be formed between a block and a loop header.";
}
bool success;
tie(ignore, success) = loop_headers.insert(header_block);
if (!success) {
// TODO(umar): List the back-edge blocks that are branching to loop
// header
return _.diag(SPV_ERROR_INVALID_CFG)
<< "Loop header " << _.getIdName(header_block)
<< " targeted by multiple back-edges";
}
}
// Check construct rules
for (const Construct& construct : function.get_constructs()) {
auto header = construct.get_entry();
auto merge = construct.get_exit();
// if the merge block is reachable then it's dominated by the header
if (merge->is_reachable() &&
find(merge->dom_begin(), merge->dom_end(), header) ==
merge->dom_end()) {
return _.diag(SPV_ERROR_INVALID_CFG)
<< ConstructErrorString(construct, _.getIdName(header->get_id()),
_.getIdName(merge->get_id()));
}
if (construct.get_type() == ConstructType::kContinue) {
if (find(header->pdom_begin(), header->pdom_end(), merge) ==
merge->pdom_end()) {
return _.diag(SPV_ERROR_INVALID_CFG)
<< ConstructErrorString(construct, _.getIdName(header->get_id()),
_.getIdName(merge->get_id()), true);
}
}
// TODO(umar): an OpSwitch block dominates all its defined case
// constructs
// TODO(umar): each case construct has at most one branch to another
// case construct
// TODO(umar): each case construct is branched to by at most one other
// case construct
// TODO(umar): if Target T1 branches to Target T2, or if Target T1
// branches to the Default and the Default branches to Target T2, then
// T1 must immediately precede T2 in the list of the OpSwitch Target
// operands
}
return SPV_SUCCESS;
}
spv_result_t PerformCfgChecks(ValidationState_t& _) { spv_result_t PerformCfgChecks(ValidationState_t& _) {
for (auto& function : _.get_functions()) { for (auto& function : _.get_functions()) {
// Check all referenced blocks are defined within a function
if (function.get_undefined_block_count() != 0) {
string undef_blocks("{");
for (auto undefined_block : function.get_undefined_blocks()) {
undef_blocks += _.getIdName(undefined_block) + " ";
}
return _.diag(SPV_ERROR_INVALID_CFG)
<< "Block(s) " << undef_blocks << "\b}"
<< " are referenced but not defined in function "
<< _.getIdName(function.get_id());
}
// Updates each blocks immediate dominators // Updates each blocks immediate dominators
vector<const BasicBlock*> postorder; vector<const BasicBlock*> postorder;
vector<const BasicBlock*> postdom_postorder;
vector<pair<uint32_t, uint32_t>> back_edges; vector<pair<uint32_t, uint32_t>> back_edges;
if (auto* first_block = function.get_first_block()) { if (auto* first_block = function.get_first_block()) {
/// calculate dominators
DepthFirstTraversal(*first_block, successor, [](cbb_ptr) {}, DepthFirstTraversal(*first_block, successor, [](cbb_ptr) {},
[&](cbb_ptr b) { postorder.push_back(b); }, [&](cbb_ptr b) { postorder.push_back(b); },
[&](cbb_ptr from, cbb_ptr to) { [&](cbb_ptr from, cbb_ptr to) {
back_edges.emplace_back(from->get_id(), back_edges.emplace_back(from->get_id(),
to->get_id()); to->get_id());
}); });
auto edges = libspirv::CalculateDominators(postorder); auto edges = libspirv::CalculateDominators(postorder, predecessor);
libspirv::UpdateImmediateDominators(edges); libspirv::UpdateImmediateDominators(
edges, [](bb_ptr block, bb_ptr dominator) {
block->SetImmediateDominator(dominator);
});
/// calculate post dominators
auto exit_block = function.get_pseudo_exit_block();
DepthFirstTraversal(*exit_block, predecessor, [](cbb_ptr) {},
[&](cbb_ptr b) { postdom_postorder.push_back(b); },
[&](cbb_ptr, cbb_ptr) {});
auto postdom_edges =
libspirv::CalculateDominators(postdom_postorder, successor);
libspirv::UpdateImmediateDominators(
postdom_edges, [](bb_ptr block, bb_ptr dominator) {
block->SetImmediatePostDominator(dominator);
});
} }
UpdateContinueConstructExitBlocks(function, back_edges);
// Check if the order of blocks in the binary appear before the blocks they // Check if the order of blocks in the binary appear before the blocks they
// dominate // dominate
@ -284,41 +456,10 @@ spv_result_t PerformCfgChecks(ValidationState_t& _) {
} }
} }
// Check all referenced blocks are defined within a function /// Structured control flow checks are only required for shader capabilities
if (function.get_undefined_block_count() != 0) { if (_.hasCapability(SpvCapabilityShader)) {
std::stringstream ss; spvCheckReturn(StructuredControlFlowChecks(_, function, back_edges));
ss << "{";
for (auto undefined_block : function.get_undefined_blocks()) {
ss << _.getIdName(undefined_block) << " ";
}
return _.diag(SPV_ERROR_INVALID_CFG)
<< "Block(s) " << ss.str() << "\b}"
<< " are referenced but not defined in function "
<< _.getIdName(function.get_id());
} }
// Check all headers dominate their merge blocks
for (Construct& construct : function.get_constructs()) {
auto header = construct.get_header();
auto merge = construct.get_merge();
// auto cont = construct.get_continue();
if (merge->is_reachable() &&
find(merge->dom_begin(), merge->dom_end(), header) ==
merge->dom_end()) {
return _.diag(SPV_ERROR_INVALID_CFG)
<< "Header block " << _.getIdName(header->get_id())
<< " doesn't dominate its merge block "
<< _.getIdName(merge->get_id());
}
}
// TODO(umar): All CFG back edges must branch to a loop header, with each
// loop header having exactly one back edge branching to it
// TODO(umar): For a given loop, its back-edge block must post dominate the
// OpLoopMerge's Continue Target, and that Continue Target must dominate the
// back-edge block
} }
return SPV_SUCCESS; return SPV_SUCCESS;
} }
@ -331,7 +472,6 @@ spv_result_t CfgPass(ValidationState_t& _,
spvCheckReturn(_.get_current_function().RegisterBlock(inst->result_id)); spvCheckReturn(_.get_current_function().RegisterBlock(inst->result_id));
break; break;
case SpvOpLoopMerge: { case SpvOpLoopMerge: {
// TODO(umar): mark current block as a loop header
uint32_t merge_block = inst->words[inst->operands[0].offset]; uint32_t merge_block = inst->words[inst->operands[0].offset];
uint32_t continue_block = inst->words[inst->operands[1].offset]; uint32_t continue_block = inst->words[inst->operands[1].offset];
CFG_ASSERT(MergeBlockAssert, merge_block); CFG_ASSERT(MergeBlockAssert, merge_block);

View File

@ -56,7 +56,7 @@ using ::testing::MatchesRegex;
using libspirv::BasicBlock; using libspirv::BasicBlock;
using libspirv::ValidationState_t; using libspirv::ValidationState_t;
using ValidateCFG = spvtest::ValidateBase<bool>; using ValidateCFG = spvtest::ValidateBase<SpvCapability>;
using spvtest::ScopedContext; using spvtest::ScopedContext;
namespace { namespace {
@ -160,34 +160,52 @@ Block& operator>>(Block& lhs, Block& successor) {
return lhs; return lhs;
} }
string header = const char* header(SpvCapability cap) {
"OpCapability Shader\n" static const char* shader_header =
"OpMemoryModel Logical GLSL450\n"; "OpCapability Shader\n"
"OpMemoryModel Logical GLSL450\n";
string types_consts = static const char* kernel_header =
"%voidt = OpTypeVoid\n" "OpCapability Kernel\n"
"%boolt = OpTypeBool\n" "OpMemoryModel Logical OpenCL\n";
"%intt = OpTypeInt 32 1\n"
"%one = OpConstant %intt 1\n"
"%two = OpConstant %intt 2\n"
"%ptrt = OpTypePointer Function %intt\n"
"%funct = OpTypeFunction %voidt\n";
TEST_F(ValidateCFG, Simple) { return (cap == SpvCapabilityShader) ? shader_header : kernel_header;
Block first("first"); }
const char* types_consts() {
static const char* types =
"%voidt = OpTypeVoid\n"
"%boolt = OpTypeBool\n"
"%intt = OpTypeInt 32 1\n"
"%one = OpConstant %intt 1\n"
"%two = OpConstant %intt 2\n"
"%ptrt = OpTypePointer Function %intt\n"
"%funct = OpTypeFunction %voidt\n";
return types;
}
INSTANTIATE_TEST_CASE_P(StructuredControlFlow, ValidateCFG,
::testing::Values(SpvCapabilityShader,
SpvCapabilityKernel));
TEST_P(ValidateCFG, Simple) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry");
Block loop("loop", SpvOpBranchConditional); Block loop("loop", SpvOpBranchConditional);
Block cont("cont"); Block cont("cont");
Block merge("merge", SpvOpReturn); Block merge("merge", SpvOpReturn);
loop.setBody( entry.setBody("%cond = OpSLessThan %intt %one %two\n");
"%cond = OpSLessThan %intt %one %two\n" if (is_shader) {
"OpLoopMerge %merge %cont None\n"); loop.setBody("OpLoopMerge %merge %cont None\n");
}
string str = header + nameOps("loop", "first", "cont", "merge", string str = header(GetParam()) + nameOps("loop", "entry", "cont", "merge",
make_pair("func", "Main")) + make_pair("func", "Main")) +
types_consts + "%func = OpFunction %voidt None %funct\n"; types_consts() + "%func = OpFunction %voidt None %funct\n";
str += first >> loop; str += entry >> loop;
str += loop >> vector<Block>({cont, merge}); str += loop >> vector<Block>({cont, merge});
str += cont >> loop; str += cont >> loop;
str += merge; str += merge;
@ -197,15 +215,15 @@ TEST_F(ValidateCFG, Simple) {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
} }
TEST_F(ValidateCFG, Variable) { TEST_P(ValidateCFG, Variable) {
Block entry("entry"); Block entry("entry");
Block cont("cont"); Block cont("cont");
Block exit("exit", SpvOpReturn); Block exit("exit", SpvOpReturn);
entry.setBody("%var = OpVariable %ptrt Function\n"); entry.setBody("%var = OpVariable %ptrt Function\n");
string str = header + nameOps(make_pair("func", "Main")) + types_consts + string str = header(GetParam()) + nameOps(make_pair("func", "Main")) +
" %func = OpFunction %voidt None %funct\n"; types_consts() + " %func = OpFunction %voidt None %funct\n";
str += entry >> cont; str += entry >> cont;
str += cont >> exit; str += cont >> exit;
str += exit; str += exit;
@ -215,7 +233,7 @@ TEST_F(ValidateCFG, Variable) {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
} }
TEST_F(ValidateCFG, VariableNotInFirstBlockBad) { TEST_P(ValidateCFG, VariableNotInFirstBlockBad) {
Block entry("entry"); Block entry("entry");
Block cont("cont"); Block cont("cont");
Block exit("exit", SpvOpReturn); Block exit("exit", SpvOpReturn);
@ -223,8 +241,8 @@ TEST_F(ValidateCFG, VariableNotInFirstBlockBad) {
// This operation should only be performed in the entry block // This operation should only be performed in the entry block
cont.setBody("%var = OpVariable %ptrt Function\n"); cont.setBody("%var = OpVariable %ptrt Function\n");
string str = header + nameOps(make_pair("func", "Main")) + types_consts + string str = header(GetParam()) + nameOps(make_pair("func", "Main")) +
" %func = OpFunction %voidt None %funct\n"; types_consts() + " %func = OpFunction %voidt None %funct\n";
str += entry >> cont; str += entry >> cont;
str += cont >> exit; str += cont >> exit;
@ -239,18 +257,19 @@ TEST_F(ValidateCFG, VariableNotInFirstBlockBad) {
"Variables can only be defined in the first block of a function")); "Variables can only be defined in the first block of a function"));
} }
TEST_F(ValidateCFG, BlockAppearsBeforeDominatorBad) { TEST_P(ValidateCFG, BlockAppearsBeforeDominatorBad) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry"); Block entry("entry");
Block cont("cont"); Block cont("cont");
Block branch("branch", SpvOpBranchConditional); Block branch("branch", SpvOpBranchConditional);
Block merge("merge", SpvOpReturn); Block merge("merge", SpvOpReturn);
branch.setBody( entry.setBody("%cond = OpSLessThan %intt %one %two\n");
" %cond = OpSLessThan %intt %one %two\n" if (is_shader) branch.setBody("OpSelectionMerge %merge None\n");
"OpSelectionMerge %merge None\n");
string str = header + nameOps("cont", "branch", make_pair("func", "Main")) + string str = header(GetParam()) +
types_consts + "%func = OpFunction %voidt None %funct\n"; nameOps("cont", "branch", make_pair("func", "Main")) +
types_consts() + "%func = OpFunction %voidt None %funct\n";
str += entry >> branch; str += entry >> branch;
str += cont >> merge; // cont appears before its dominator str += cont >> merge; // cont appears before its dominator
@ -265,20 +284,22 @@ TEST_F(ValidateCFG, BlockAppearsBeforeDominatorBad) {
"before its dominator .\\[branch\\]")); "before its dominator .\\[branch\\]"));
} }
TEST_F(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksBad) { TEST_P(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksBad) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry"); Block entry("entry");
Block loop("loop"); Block loop("loop");
Block selection("selection", SpvOpBranchConditional); Block selection("selection", SpvOpBranchConditional);
Block merge("merge", SpvOpReturn); Block merge("merge", SpvOpReturn);
loop.setBody( entry.setBody("%cond = OpSLessThan %intt %one %two\n");
" %cond = OpSLessThan %intt %one %two\n" if (is_shader) loop.setBody(" OpLoopMerge %merge %loop None\n");
" OpLoopMerge %merge %loop None\n");
// cannot share the same merge
selection.setBody("OpSelectionMerge %merge None\n");
string str = header + nameOps("merge", make_pair("func", "Main")) + // cannot share the same merge
types_consts + "%func = OpFunction %voidt None %funct\n"; if (is_shader) selection.setBody("OpSelectionMerge %merge None\n");
string str = header(GetParam()) +
nameOps("merge", make_pair("func", "Main")) + types_consts() +
"%func = OpFunction %voidt None %funct\n";
str += entry >> loop; str += entry >> loop;
str += loop >> selection; str += loop >> selection;
@ -287,26 +308,32 @@ TEST_F(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksBad) {
str += "OpFunctionEnd\n"; str += "OpFunctionEnd\n";
CompileSuccessfully(str); CompileSuccessfully(str);
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); if (is_shader) {
EXPECT_THAT(getDiagnosticString(), ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
MatchesRegex("Block .\\[merge\\] is already a merge block " EXPECT_THAT(getDiagnosticString(),
"for another header")); MatchesRegex("Block .\\[merge\\] is already a merge block "
"for another header"));
} else {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
} }
TEST_F(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksSelectionBad) { TEST_P(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksSelectionBad) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry"); Block entry("entry");
Block loop("loop", SpvOpBranchConditional); Block loop("loop", SpvOpBranchConditional);
Block selection("selection", SpvOpBranchConditional); Block selection("selection", SpvOpBranchConditional);
Block merge("merge", SpvOpReturn); Block merge("merge", SpvOpReturn);
selection.setBody( entry.setBody("%cond = OpSLessThan %intt %one %two\n");
" %cond = OpSLessThan %intt %one %two\n" if (is_shader) selection.setBody(" OpSelectionMerge %merge None\n");
" OpSelectionMerge %merge None\n");
// cannot share the same merge
loop.setBody(" OpLoopMerge %merge %loop None\n");
string str = header + nameOps("merge", make_pair("func", "Main")) + // cannot share the same merge
types_consts + "%func = OpFunction %voidt None %funct\n"; if (is_shader) loop.setBody(" OpLoopMerge %merge %loop None\n");
string str = header(GetParam()) +
nameOps("merge", make_pair("func", "Main")) + types_consts() +
"%func = OpFunction %voidt None %funct\n";
str += entry >> selection; str += entry >> selection;
str += selection >> vector<Block>({merge, loop}); str += selection >> vector<Block>({merge, loop});
@ -315,18 +342,23 @@ TEST_F(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksSelectionBad) {
str += "OpFunctionEnd\n"; str += "OpFunctionEnd\n";
CompileSuccessfully(str); CompileSuccessfully(str);
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); if (is_shader) {
EXPECT_THAT(getDiagnosticString(), ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
MatchesRegex("Block .\\[merge\\] is already a merge block " EXPECT_THAT(getDiagnosticString(),
"for another header")); MatchesRegex("Block .\\[merge\\] is already a merge block "
"for another header"));
} else {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
} }
TEST_F(ValidateCFG, BranchTargetFirstBlockBad) { TEST_P(ValidateCFG, BranchTargetFirstBlockBad) {
Block entry("entry"); Block entry("entry");
Block bad("bad"); Block bad("bad");
Block end("end", SpvOpReturn); Block end("end", SpvOpReturn);
string str = header + nameOps("entry", "bad", make_pair("func", "Main")) + string str = header(GetParam()) +
types_consts + "%func = OpFunction %voidt None %funct\n"; nameOps("entry", "bad", make_pair("func", "Main")) +
types_consts() + "%func = OpFunction %voidt None %funct\n";
str += entry >> bad; str += entry >> bad;
str += bad >> entry; // Cannot target entry block str += bad >> entry; // Cannot target entry block
@ -340,17 +372,17 @@ TEST_F(ValidateCFG, BranchTargetFirstBlockBad) {
"is targeted by block .\\[bad\\]")); "is targeted by block .\\[bad\\]"));
} }
TEST_F(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) { TEST_P(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) {
Block entry("entry"); Block entry("entry");
Block bad("bad", SpvOpBranchConditional); Block bad("bad", SpvOpBranchConditional);
Block exit("exit", SpvOpReturn); Block exit("exit", SpvOpReturn);
bad.setBody( entry.setBody("%cond = OpSLessThan %intt %one %two\n");
" %cond = OpSLessThan %intt %one %two\n" bad.setBody(" OpLoopMerge %entry %exit None\n");
" OpLoopMerge %entry %exit None\n");
string str = header + nameOps("entry", "bad", make_pair("func", "Main")) + string str = header(GetParam()) +
types_consts + "%func = OpFunction %voidt None %funct\n"; nameOps("entry", "bad", make_pair("func", "Main")) +
types_consts() + "%func = OpFunction %voidt None %funct\n";
str += entry >> bad; str += entry >> bad;
str += bad >> vector<Block>({entry, exit}); // cannot target entry block str += bad >> vector<Block>({entry, exit}); // cannot target entry block
@ -364,19 +396,19 @@ TEST_F(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) {
"is targeted by block .\\[bad\\]")); "is targeted by block .\\[bad\\]"));
} }
TEST_F(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) { TEST_P(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) {
Block entry("entry"); Block entry("entry");
Block bad("bad", SpvOpBranchConditional); Block bad("bad", SpvOpBranchConditional);
Block t("t"); Block t("t");
Block merge("merge"); Block merge("merge");
Block end("end", SpvOpReturn); Block end("end", SpvOpReturn);
bad.setBody( entry.setBody("%cond = OpSLessThan %intt %one %two\n");
"%cond = OpSLessThan %intt %one %two\n" bad.setBody("OpLoopMerge %merge %cont None\n");
"OpLoopMerge %merge %cont None\n");
string str = header + nameOps("entry", "bad", make_pair("func", "Main")) + string str = header(GetParam()) +
types_consts + "%func = OpFunction %voidt None %funct\n"; nameOps("entry", "bad", make_pair("func", "Main")) +
types_consts() + "%func = OpFunction %voidt None %funct\n";
str += entry >> bad; str += entry >> bad;
str += bad >> vector<Block>({t, entry}); str += bad >> vector<Block>({t, entry});
@ -391,7 +423,7 @@ TEST_F(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) {
"is targeted by block .\\[bad\\]")); "is targeted by block .\\[bad\\]"));
} }
TEST_F(ValidateCFG, SwitchTargetFirstBlockBad) { TEST_P(ValidateCFG, SwitchTargetFirstBlockBad) {
Block entry("entry"); Block entry("entry");
Block bad("bad", SpvOpSwitch); Block bad("bad", SpvOpSwitch);
Block block1("block1"); Block block1("block1");
@ -401,12 +433,12 @@ TEST_F(ValidateCFG, SwitchTargetFirstBlockBad) {
Block merge("merge"); Block merge("merge");
Block end("end", SpvOpReturn); Block end("end", SpvOpReturn);
bad.setBody( entry.setBody("%cond = OpSLessThan %intt %one %two\n");
"%cond = OpSLessThan %intt %one %two\n" bad.setBody("OpSelectionMerge %merge None\n");
"OpSelectionMerge %merge None\n");
string str = header + nameOps("entry", "bad", make_pair("func", "Main")) + string str = header(GetParam()) +
types_consts + "%func = OpFunction %voidt None %funct\n"; nameOps("entry", "bad", make_pair("func", "Main")) +
types_consts() + "%func = OpFunction %voidt None %funct\n";
str += entry >> bad; str += entry >> bad;
str += bad >> vector<Block>({def, block1, block2, block3, entry}); str += bad >> vector<Block>({def, block1, block2, block3, entry});
@ -425,21 +457,21 @@ TEST_F(ValidateCFG, SwitchTargetFirstBlockBad) {
"is targeted by block .\\[bad\\]")); "is targeted by block .\\[bad\\]"));
} }
TEST_F(ValidateCFG, BranchToBlockInOtherFunctionBad) { TEST_P(ValidateCFG, BranchToBlockInOtherFunctionBad) {
Block entry("entry"); Block entry("entry");
Block middle("middle", SpvOpBranchConditional); Block middle("middle", SpvOpBranchConditional);
Block end("end", SpvOpReturn); Block end("end", SpvOpReturn);
middle.setBody( entry.setBody("%cond = OpSLessThan %intt %one %two\n");
"%cond = OpSLessThan %intt %one %two\n" middle.setBody("OpSelectionMerge %end None\n");
"OpSelectionMerge %end None\n");
Block entry2("entry2"); Block entry2("entry2");
Block middle2("middle2"); Block middle2("middle2");
Block end2("end2", SpvOpReturn); Block end2("end2", SpvOpReturn);
string str = header + nameOps("middle2", make_pair("func", "Main")) + string str = header(GetParam()) +
types_consts + "%func = OpFunction %voidt None %funct\n"; nameOps("middle2", make_pair("func", "Main")) + types_consts() +
"%func = OpFunction %voidt None %funct\n";
str += entry >> middle; str += entry >> middle;
str += middle >> vector<Block>({end, middle2}); str += middle >> vector<Block>({end, middle2});
@ -460,7 +492,8 @@ TEST_F(ValidateCFG, BranchToBlockInOtherFunctionBad) {
"defined in function .\\[Main\\]")); "defined in function .\\[Main\\]"));
} }
TEST_F(ValidateCFG, HeaderDoesntDominatesMergeBad) { TEST_P(ValidateCFG, HeaderDoesntDominatesMergeBad) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry"); Block entry("entry");
Block head("head", SpvOpBranchConditional); Block head("head", SpvOpBranchConditional);
Block f("f"); Block f("f");
@ -468,10 +501,11 @@ TEST_F(ValidateCFG, HeaderDoesntDominatesMergeBad) {
entry.setBody("%cond = OpSLessThan %intt %one %two\n"); entry.setBody("%cond = OpSLessThan %intt %one %two\n");
head.setBody("OpSelectionMerge %merge None\n"); if (is_shader) head.setBody("OpSelectionMerge %merge None\n");
string str = header + nameOps("head", "merge", make_pair("func", "Main")) + string str = header(GetParam()) +
types_consts + "%func = OpFunction %voidt None %funct\n"; nameOps("head", "merge", make_pair("func", "Main")) +
types_consts() + "%func = OpFunction %voidt None %funct\n";
str += entry >> merge; str += entry >> merge;
str += head >> vector<Block>({merge, f}); str += head >> vector<Block>({merge, f});
@ -479,26 +513,33 @@ TEST_F(ValidateCFG, HeaderDoesntDominatesMergeBad) {
str += merge; str += merge;
CompileSuccessfully(str); CompileSuccessfully(str);
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT( if (is_shader) {
getDiagnosticString(), ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
MatchesRegex("Header block .\\[head\\] doesn't dominate its merge block " EXPECT_THAT(
".\\[merge\\]")); getDiagnosticString(),
MatchesRegex("The selection construct with the selection header "
".\\[head\\] does not dominate the merge block "
".\\[merge\\]"));
} else {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
} }
TEST_F(ValidateCFG, UnreachableMerge) { TEST_P(ValidateCFG, UnreachableMerge) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry"); Block entry("entry");
Block branch("branch", SpvOpBranchConditional); Block branch("branch", SpvOpBranchConditional);
Block t("t", SpvOpReturn); Block t("t", SpvOpReturn);
Block f("f", SpvOpReturn); Block f("f", SpvOpReturn);
Block merge("merge", SpvOpReturn); Block merge("merge", SpvOpReturn);
branch.setBody( entry.setBody("%cond = OpSLessThan %intt %one %two\n");
" %cond = OpSLessThan %intt %one %two\n" if (is_shader) branch.setBody("OpSelectionMerge %merge None\n");
"OpSelectionMerge %merge None\n");
string str = header + nameOps("branch", "merge", make_pair("func", "Main")) + string str = header(GetParam()) +
types_consts + "%func = OpFunction %voidt None %funct\n"; nameOps("branch", "merge", make_pair("func", "Main")) +
types_consts() + "%func = OpFunction %voidt None %funct\n";
str += entry >> branch; str += entry >> branch;
str += branch >> vector<Block>({t, f}); str += branch >> vector<Block>({t, f});
@ -511,19 +552,20 @@ TEST_F(ValidateCFG, UnreachableMerge) {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
} }
TEST_F(ValidateCFG, UnreachableMergeDefinedByOpUnreachable) { TEST_P(ValidateCFG, UnreachableMergeDefinedByOpUnreachable) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry"); Block entry("entry");
Block branch("branch", SpvOpBranchConditional); Block branch("branch", SpvOpBranchConditional);
Block t("t", SpvOpReturn); Block t("t", SpvOpReturn);
Block f("f", SpvOpReturn); Block f("f", SpvOpReturn);
Block merge("merge", SpvOpUnreachable); Block merge("merge", SpvOpUnreachable);
branch.setBody( entry.setBody("%cond = OpSLessThan %intt %one %two\n");
" %cond = OpSLessThan %intt %one %two\n" if (is_shader) branch.setBody("OpSelectionMerge %merge None\n");
"OpSelectionMerge %merge None\n");
string str = header + nameOps("branch", "merge", make_pair("func", "Main")) + string str = header(GetParam()) +
types_consts + "%func = OpFunction %voidt None %funct\n"; nameOps("branch", "merge", make_pair("func", "Main")) +
types_consts() + "%func = OpFunction %voidt None %funct\n";
str += entry >> branch; str += entry >> branch;
str += branch >> vector<Block>({t, f}); str += branch >> vector<Block>({t, f});
@ -536,14 +578,14 @@ TEST_F(ValidateCFG, UnreachableMergeDefinedByOpUnreachable) {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
} }
TEST_F(ValidateCFG, UnreachableBlock) { TEST_P(ValidateCFG, UnreachableBlock) {
Block entry("entry"); Block entry("entry");
Block unreachable("unreachable"); Block unreachable("unreachable");
Block exit("exit", SpvOpReturn); Block exit("exit", SpvOpReturn);
string str = header + string str = header(GetParam()) +
nameOps("unreachable", "exit", make_pair("func", "Main")) + nameOps("unreachable", "exit", make_pair("func", "Main")) +
types_consts + "%func = OpFunction %voidt None %funct\n"; types_consts() + "%func = OpFunction %voidt None %funct\n";
str += entry >> exit; str += entry >> exit;
str += unreachable >> exit; str += unreachable >> exit;
@ -554,7 +596,8 @@ TEST_F(ValidateCFG, UnreachableBlock) {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
} }
TEST_F(ValidateCFG, UnreachableBranch) { TEST_P(ValidateCFG, UnreachableBranch) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry"); Block entry("entry");
Block unreachable("unreachable", SpvOpBranchConditional); Block unreachable("unreachable", SpvOpBranchConditional);
Block unreachablechildt("unreachablechildt"); Block unreachablechildt("unreachablechildt");
@ -562,12 +605,11 @@ TEST_F(ValidateCFG, UnreachableBranch) {
Block merge("merge"); Block merge("merge");
Block exit("exit", SpvOpReturn); Block exit("exit", SpvOpReturn);
unreachable.setBody( entry.setBody("%cond = OpSLessThan %intt %one %two\n");
" %cond = OpSLessThan %intt %one %two\n" if (is_shader) unreachable.setBody("OpSelectionMerge %merge None\n");
"OpSelectionMerge %merge None\n"); string str = header(GetParam()) +
string str = header +
nameOps("unreachable", "exit", make_pair("func", "Main")) + nameOps("unreachable", "exit", make_pair("func", "Main")) +
types_consts + "%func = OpFunction %voidt None %funct\n"; types_consts() + "%func = OpFunction %voidt None %funct\n";
str += entry >> exit; str += entry >> exit;
str += unreachable >> vector<Block>({unreachablechildt, unreachablechildf}); str += unreachable >> vector<Block>({unreachablechildt, unreachablechildf});
@ -581,25 +623,25 @@ TEST_F(ValidateCFG, UnreachableBranch) {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
} }
TEST_F(ValidateCFG, EmptyFunction) { TEST_P(ValidateCFG, EmptyFunction) {
string str = header + types_consts + string str = header(GetParam()) + string(types_consts()) +
"%func = OpFunction %voidt None %funct\n" + "OpFunctionEnd\n"; "%func = OpFunction %voidt None %funct\n" + "OpFunctionEnd\n";
CompileSuccessfully(str); CompileSuccessfully(str);
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
} }
TEST_F(ValidateCFG, SingleBlockLoop) { TEST_P(ValidateCFG, SingleBlockLoop) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry"); Block entry("entry");
Block loop("loop", SpvOpBranchConditional); Block loop("loop", SpvOpBranchConditional);
Block exit("exit", SpvOpReturn); Block exit("exit", SpvOpReturn);
loop.setBody( entry.setBody("%cond = OpSLessThan %intt %one %two\n");
"%cond = OpSLessThan %intt %one %two\n" if (is_shader) loop.setBody("OpLoopMerge %exit %loop None\n");
"OpLoopMerge %exit %loop None\n");
string str = string str = header(GetParam()) + string(types_consts()) +
header + types_consts + "%func = OpFunction %voidt None %funct\n"; "%func = OpFunction %voidt None %funct\n";
str += entry >> loop; str += entry >> loop;
str += loop >> vector<Block>({loop, exit}); str += loop >> vector<Block>({loop, exit});
@ -610,7 +652,8 @@ TEST_F(ValidateCFG, SingleBlockLoop) {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
} }
TEST_F(ValidateCFG, NestedLoops) { TEST_P(ValidateCFG, NestedLoops) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry"); Block entry("entry");
Block loop1("loop1"); Block loop1("loop1");
Block loop1_cont_break_block("loop1_cont_break_block", Block loop1_cont_break_block("loop1_cont_break_block",
@ -620,14 +663,14 @@ TEST_F(ValidateCFG, NestedLoops) {
Block loop1_merge("loop1_merge"); Block loop1_merge("loop1_merge");
Block exit("exit", SpvOpReturn); Block exit("exit", SpvOpReturn);
loop1.setBody( entry.setBody("%cond = OpSLessThan %intt %one %two\n");
"%cond = OpSLessThan %intt %one %two\n" if (is_shader) {
"OpLoopMerge %loop1_merge %loop2 None\n"); loop1.setBody("OpLoopMerge %loop1_merge %loop2 None\n");
loop2.setBody("OpLoopMerge %loop2_merge %loop2 None\n");
}
loop2.setBody("OpLoopMerge %loop2_merge %loop2 None\n"); string str = header(GetParam()) + nameOps("loop2", "loop2_merge") +
types_consts() + "%func = OpFunction %voidt None %funct\n";
string str =
header + types_consts + "%func = OpFunction %voidt None %funct\n";
str += entry >> loop1; str += entry >> loop1;
str += loop1 >> loop1_cont_break_block; str += loop1 >> loop1_cont_break_block;
@ -641,29 +684,33 @@ TEST_F(ValidateCFG, NestedLoops) {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
} }
TEST_F(ValidateCFG, NestedSelection) { TEST_P(ValidateCFG, NestedSelection) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry"); Block entry("entry");
const int N = 256; const int N = 256;
vector<Block> if_blocks; vector<Block> if_blocks;
vector<Block> merge_blocks; vector<Block> merge_blocks;
Block inner("inner"); Block inner("inner");
entry.setBody("%cond = OpSLessThan %intt %one %two\n");
if_blocks.emplace_back("if0", SpvOpBranchConditional); if_blocks.emplace_back("if0", SpvOpBranchConditional);
if_blocks[0].setBody(
"%cond = OpSLessThan %intt %one %two\n" if (is_shader) if_blocks[0].setBody("OpSelectionMerge %if_merge0 None\n");
"OpSelectionMerge %if_merge0 None\n");
merge_blocks.emplace_back("if_merge0", SpvOpReturn); merge_blocks.emplace_back("if_merge0", SpvOpReturn);
for (int i = 1; i < N; i++) { for (int i = 1; i < N; i++) {
stringstream ss; stringstream ss;
ss << i; ss << i;
if_blocks.emplace_back("if" + ss.str(), SpvOpBranchConditional); if_blocks.emplace_back("if" + ss.str(), SpvOpBranchConditional);
if_blocks[i].setBody("OpSelectionMerge %if_merge" + ss.str() + " None\n"); if (is_shader)
if_blocks[i].setBody("OpSelectionMerge %if_merge" + ss.str() + " None\n");
merge_blocks.emplace_back("if_merge" + ss.str(), SpvOpBranch); merge_blocks.emplace_back("if_merge" + ss.str(), SpvOpBranch);
} }
string str = string str = header(GetParam()) + string(types_consts()) +
header + types_consts + "%func = OpFunction %voidt None %funct\n"; "%func = OpFunction %voidt None %funct\n";
str += entry >> if_blocks[0];
for (int i = 0; i < N - 1; i++) { for (int i = 0; i < N - 1; i++) {
str += if_blocks[i] >> vector<Block>({if_blocks[i + 1], merge_blocks[i]}); str += if_blocks[i] >> vector<Block>({if_blocks[i + 1], merge_blocks[i]});
} }
@ -679,37 +726,282 @@ TEST_F(ValidateCFG, NestedSelection) {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
} }
// TODO(umar): enable this test TEST_P(ValidateCFG, BackEdgeBlockDoesntPostDominateContinueTargetBad) {
TEST_F(ValidateCFG, DISABLED_BackEdgeBlockDoesntPostDominateContinueTargetBad) { bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry"); Block entry("entry");
Block loop1("loop1", SpvOpBranchConditional); Block loop1("loop1", SpvOpBranchConditional);
Block loop2("loop2", SpvOpBranchConditional); Block loop2("loop2", SpvOpBranchConditional);
Block loop2_merge("loop2_merge"); Block loop2_merge("loop2_merge", SpvOpBranchConditional);
Block loop1_merge("loop1_merge", SpvOpBranchConditional); Block be_block("be_block");
Block exit("exit", SpvOpReturn); Block exit("exit", SpvOpReturn);
loop1.setBody( entry.setBody("%cond = OpSLessThan %intt %one %two\n");
"%cond = OpSLessThan %intt %one %two\n" if (is_shader) {
"OpLoopMerge %loop1_merge %loop2 None\n"); loop1.setBody("OpLoopMerge %exit %loop2_merge None\n");
loop2.setBody("OpLoopMerge %loop2_merge %loop2 None\n");
}
loop2.setBody("OpLoopMerge %loop2_merge %loop2 None\n"); string str = header(GetParam()) +
nameOps("loop1", "loop2", "be_block", "loop2_merge") +
string str = types_consts() + "%func = OpFunction %voidt None %funct\n";
header + types_consts + "%func = OpFunction %voidt None %funct\n";
str += entry >> loop1; str += entry >> loop1;
str += loop1 >> vector<Block>({loop2, loop1_merge}); str += loop1 >> vector<Block>({loop2, exit});
str += loop2 >> vector<Block>({loop2, loop2_merge}); str += loop2 >> vector<Block>({loop2, loop2_merge});
str += loop2_merge >> loop1_merge; str += loop2_merge >> vector<Block>({be_block, exit});
str += loop1_merge >> vector<Block>({loop1, exit}); str += be_block >> loop1;
str += exit; str += exit;
str += "OpFunctionEnd"; str += "OpFunctionEnd";
CompileSuccessfully(str); CompileSuccessfully(str);
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); if (GetParam() == SpvCapabilityShader) {
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
MatchesRegex("The continue construct with the continue target "
".\\[loop2_merge\\] is not post dominated by the "
"back-edge block .\\[be_block\\]"));
} else {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
}
TEST_P(ValidateCFG, BranchingToNonLoopHeaderBlockBad) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry");
Block split("split", SpvOpBranchConditional);
Block t("t");
Block f("f");
Block exit("exit", SpvOpReturn);
entry.setBody("%cond = OpSLessThan %intt %one %two\n");
if (is_shader) split.setBody("OpSelectionMerge %exit None\n");
string str = header(GetParam()) + nameOps("split", "f") + types_consts() +
"%func = OpFunction %voidt None %funct\n";
str += entry >> split;
str += split >> vector<Block>({t, f});
str += t >> exit;
str += f >> split;
str += exit;
str += "OpFunctionEnd";
CompileSuccessfully(str);
if (is_shader) {
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
MatchesRegex("Back-edges \\(.\\[f\\] -> .\\[split\\]\\) can only "
"be formed between a block and a loop header."));
} else {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
}
TEST_P(ValidateCFG, BranchingToSameNonLoopHeaderBlockBad) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry");
Block split("split", SpvOpBranchConditional);
Block exit("exit", SpvOpReturn);
entry.setBody("%cond = OpSLessThan %intt %one %two\n");
if (is_shader) split.setBody("OpSelectionMerge %exit None\n");
string str = header(GetParam()) + nameOps("split") + types_consts() +
"%func = OpFunction %voidt None %funct\n";
str += entry >> split;
str += split >> vector<Block>({split, exit});
str += exit;
str += "OpFunctionEnd";
CompileSuccessfully(str);
if (is_shader) {
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
MatchesRegex(
"Back-edges \\(.\\[split\\] -> .\\[split\\]\\) can only be "
"formed between a block and a loop header."));
} else {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
}
TEST_P(ValidateCFG, MultipleBackEdgesToLoopHeaderBad) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry");
Block loop("loop", SpvOpBranchConditional);
Block cont("cont", SpvOpBranchConditional);
Block merge("merge", SpvOpReturn);
entry.setBody("%cond = OpSLessThan %intt %one %two\n");
if (is_shader) loop.setBody("OpLoopMerge %merge %loop None\n");
string str = header(GetParam()) + nameOps("cont", "loop") + types_consts() +
"%func = OpFunction %voidt None %funct\n";
str += entry >> loop;
str += loop >> vector<Block>({cont, merge});
str += cont >> vector<Block>({loop, loop});
str += merge;
str += "OpFunctionEnd";
CompileSuccessfully(str);
if (is_shader) {
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
MatchesRegex(
"Loop header .\\[loop\\] targeted by multiple back-edges"));
} else {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
}
TEST_P(ValidateCFG, ContinueTargetMustBePostDominatedByBackEdge) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry");
Block loop("loop", SpvOpBranchConditional);
Block cheader("cheader", SpvOpBranchConditional);
Block be_block("be_block");
Block merge("merge", SpvOpReturn);
Block exit("exit", SpvOpReturn);
entry.setBody("%cond = OpSLessThan %intt %one %two\n");
if (is_shader) loop.setBody("OpLoopMerge %merge %cheader None\n");
string str = header(GetParam()) + nameOps("cheader", "be_block") +
types_consts() + "%func = OpFunction %voidt None %funct\n";
str += entry >> loop;
str += loop >> vector<Block>({cheader, merge});
str += cheader >> vector<Block>({exit, be_block});
str += exit; // Branches out of a continue construct
str += be_block >> loop;
str += merge;
str += "OpFunctionEnd";
CompileSuccessfully(str);
if (is_shader) {
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
MatchesRegex("The continue construct with the continue target "
".\\[cheader\\] is not post dominated by the "
"back-edge block .\\[be_block\\]"));
} else {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
}
TEST_P(ValidateCFG, BranchOutOfConstructToMergeBad) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry");
Block loop("loop", SpvOpBranchConditional);
Block cont("cont", SpvOpBranchConditional);
Block merge("merge", SpvOpReturn);
entry.setBody("%cond = OpSLessThan %intt %one %two\n");
if (is_shader) loop.setBody("OpLoopMerge %merge %loop None\n");
string str = header(GetParam()) + nameOps("cont", "loop") + types_consts() +
"%func = OpFunction %voidt None %funct\n";
str += entry >> loop;
str += loop >> vector<Block>({cont, merge});
str += cont >> vector<Block>({loop, merge});
str += merge;
str += "OpFunctionEnd";
CompileSuccessfully(str);
if (is_shader) {
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
MatchesRegex("The continue construct with the continue target "
".\\[loop\\] is not post dominated by the "
"back-edge block .\\[cont\\]"));
} else {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
}
TEST_P(ValidateCFG, BranchOutOfConstructBad) {
bool is_shader = GetParam() == SpvCapabilityShader;
Block entry("entry");
Block loop("loop", SpvOpBranchConditional);
Block cont("cont", SpvOpBranchConditional);
Block merge("merge");
Block exit("exit", SpvOpReturn);
entry.setBody("%cond = OpSLessThan %intt %one %two\n");
if (is_shader) loop.setBody("OpLoopMerge %merge %loop None\n");
string str = header(GetParam()) + nameOps("cont", "loop") + types_consts() +
"%func = OpFunction %voidt None %funct\n";
str += entry >> loop;
str += loop >> vector<Block>({cont, merge});
str += cont >> vector<Block>({loop, exit});
str += merge >> exit;
str += exit;
str += "OpFunctionEnd";
CompileSuccessfully(str);
if (is_shader) {
ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
MatchesRegex("The continue construct with the continue target "
".\\[loop\\] is not post dominated by the "
"back-edge block .\\[cont\\]"));
} else {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
}
TEST_F(ValidateCFG, OpSwitchToUnreachableBlock) {
Block entry("entry", SpvOpSwitch);
Block case0("case0");
Block case1("case1");
Block case2("case2");
Block def("default", SpvOpUnreachable);
Block phi("phi", SpvOpReturn);
string str = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main" %id
OpExecutionMode %main LocalSize 1 1 1
OpSource GLSL 430
OpName %main "main"
OpDecorate %id BuiltIn GlobalInvocationId
%void = OpTypeVoid
%voidf = OpTypeFunction %void
%u32 = OpTypeInt 32 0
%f32 = OpTypeFloat 32
%uvec3 = OpTypeVector %u32 3
%fvec3 = OpTypeVector %f32 3
%uvec3ptr = OpTypePointer Input %uvec3
%id = OpVariable %uvec3ptr Input
%one = OpConstant %u32 1
%three = OpConstant %u32 3
%main = OpFunction %void None %voidf
)";
entry.setBody(
"%idval = OpLoad %uvec3 %id\n"
"%x = OpCompositeExtract %u32 %idval 0\n"
"%selector = OpUMod %u32 %x %three\n"
"OpSelectionMerge %phi None\n");
str += entry >> vector<Block>({def, case0, case1, case2});
str += case1 >> phi;
str += def;
str += phi;
str += case0 >> phi;
str += case2 >> phi;
str += "OpFunctionEnd";
CompileSuccessfully(str);
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
} }
/// TODO(umar): Switch instructions /// TODO(umar): Switch instructions
/// TODO(umar): CFG branching outside of CFG construct
/// TODO(umar): Nested CFG constructs /// TODO(umar): Nested CFG constructs
} } /// namespace

View File

@ -90,4 +90,5 @@ template class spvtest::ValidateBase<
template class spvtest::ValidateBase< template class spvtest::ValidateBase<
std::tuple<int, std::tuple<std::string, std::function<spv_result_t(int)>, std::tuple<int, std::tuple<std::string, std::function<spv_result_t(int)>,
std::function<spv_result_t(int)>>>>; std::function<spv_result_t(int)>>>>;
template class spvtest::ValidateBase<SpvCapability>;
} }