Handle types with self references.

By using forward pointers, we are able to define a struct that has a
pointer to itself.  This could be directly or indirectly.  The current
implementation of the type manager did not handle this case.  There are
three changes that are made in this commit inorder to handle this case:

1) Change the handling of OpTypeForwardPointer

The current handling of OpTypeForwardsPointer is broken if there is a
reference to the pointer before the real definition.  When build the
type that contain the forward delared pointer, the type manager will ask
for the type for that ID, and will get a nullptr because it does not
exists.  This nullptr is not handleded very well.

The change is to keep track of the incomplete types the first time
through all of the types.  An incomplete type is a ForwardPointer or any
type that references an incomplete type.

Then we implement a second pass through the incomplete types that will
complete them.

2) Hashing types.

When hashing a type, we want to uses all of the subtypes as part of the
hash.  However, with types that reference them selves, this creates an
infinite recursion.  To get around this, we keep track of which types
have been seen on the path from the root type.  If we have see the
current type already then we can stop the recursion.

3) Comparing types.

In order to check if two types are the same, we must check that all of
their subtypes are the same as well.  This also causes an infinit
recursion.  The solution is to stop comparing the subtypes if we are
trying to compare two pointer types that we are already in the middle of
comparing.  The ideas is that if the two pointer are different, then in
progress compare will return false itself.

Fixes https://github.com/KhronosGroup/SPIRV-Tools/issues/1578.
This commit is contained in:
Steven Perron 2018-05-25 14:00:03 -04:00
parent 6b83643cfe
commit 93c4c184d5
7 changed files with 774 additions and 189 deletions

View File

@ -747,9 +747,9 @@ uint32_t CopyPropagateArrays::GenerateCopy(
original_type->AsStruct()) {
analysis::Struct* new_struct_type = new_type->AsStruct();
const std::vector<analysis::Type*>& original_types =
const std::vector<const analysis::Type*>& original_types =
original_struct_type->element_types();
const std::vector<analysis::Type*>& new_types =
const std::vector<const analysis::Type*>& new_types =
new_struct_type->element_types();
std::vector<uint32_t> element_ids;
for (uint32_t i = 0; i < original_types.size(); i++) {

View File

@ -41,6 +41,8 @@ TypeManager::TypeManager(const MessageConsumer& consumer,
Type* TypeManager::GetType(uint32_t id) const {
auto iter = id_to_type_.find(id);
if (iter != id_to_type_.end()) return (*iter).second;
iter = id_to_incomplete_type_.find(id);
if (iter != id_to_incomplete_type_.end()) return (*iter).second;
return nullptr;
}
@ -60,13 +62,107 @@ uint32_t TypeManager::GetId(const Type* type) const {
return 0;
}
ForwardPointer* TypeManager::GetForwardPointer(uint32_t index) const {
if (index >= forward_pointers_.size()) return nullptr;
return forward_pointers_.at(index).get();
}
void TypeManager::AnalyzeTypes(const spvtools::ir::Module& module) {
for (const auto* inst : module.GetTypes()) RecordIfTypeDefinition(*inst);
// First pass through the types. Any types that reference a forward pointer
// (directly or indirectly) are incomplete, and are added to incomplete types.
for (const auto* inst : module.GetTypes()) {
RecordIfTypeDefinition(*inst);
}
if (incomplete_types_.empty()) {
return;
}
// Get the real pointer definition for all of the forward pointers.
for (auto& type : incomplete_types_) {
if (type.type()->kind() == Type::kForwardPointer) {
auto* t = GetType(type.id());
assert(t);
auto* p = t->AsPointer();
assert(p);
type.type()->AsForwardPointer()->SetTargetPointer(p);
}
}
// Replaces the references to the forward pointers in the incomplete types.
for (auto& type : incomplete_types_) {
ReplaceForwardPointers(type.type());
}
// Delete the forward pointers now that they are not referenced anymore.
for (auto& type : incomplete_types_) {
if (type.type()->kind() == Type::kForwardPointer) {
type.ResetType(nullptr);
}
}
// Compare the complete types looking for types that are the same. If there
// are two types that are the same, then replace one with the other.
// Continue until we reach a fixed point.
bool restart = true;
while (restart) {
restart = false;
for (auto it1 = incomplete_types_.begin(); it1 != incomplete_types_.end();
++it1) {
uint32_t id1 = it1->id();
Type* type1 = it1->type();
if (!type1) {
continue;
}
for (auto it2 = it1 + 1; it2 != incomplete_types_.end(); ++it2) {
uint32_t id2 = it2->id();
(void)(id2 + id1);
Type* type2 = it2->type();
if (!type2) {
continue;
}
if (type1->IsSame(type2)) {
ReplaceType(type1, type2);
it2->ResetType(nullptr);
id_to_incomplete_type_[it2->id()] = type1;
restart = true;
}
}
}
}
// Add the remaining incomplete types to the type pool.
for (auto& type : incomplete_types_) {
if (type.type() && !type.type()->AsForwardPointer()) {
std::vector<ir::Instruction*> decorations =
context()->get_decoration_mgr()->GetDecorationsFor(type.id(), true);
for (auto dec : decorations) {
AttachDecoration(*dec, type.type());
}
auto pair = type_pool_.insert(type.ReleaseType());
id_to_type_[type.id()] = pair.first->get();
type_to_id_[pair.first->get()] = type.id();
id_to_incomplete_type_.erase(type.id());
}
}
// Add a mapping for any ids that whose original type was replaced by an
// equivalent type.
for (auto& type : id_to_incomplete_type_) {
id_to_type_[type.first] = type.second;
}
#ifndef NDEBUG
// Check if the type pool contains two types that are the same. This
// is an indication that the hashing and comparision are wrong. It
// will cause a problem if the type pool gets resized and everything
// is rehashed.
for (auto& i : type_pool_) {
for (auto& j : type_pool_) {
Type* ti = i.get();
Type* tj = j.get();
assert((ti == tj || !ti->IsSame(tj)) &&
"Type pool contains two types that are the same.");
}
}
#endif
}
void TypeManager::RemoveId(uint32_t id) {
@ -421,7 +517,7 @@ Type* TypeManager::RebuildType(const Type& type) {
}
case Type::kStruct: {
const Struct* struct_ty = type.AsStruct();
std::vector<Type*> subtypes;
std::vector<const Type*> subtypes;
subtypes.reserve(struct_ty->element_types().size());
for (const auto* ele_ty : struct_ty->element_types()) {
subtypes.push_back(RebuildType(*ele_ty));
@ -448,7 +544,7 @@ Type* TypeManager::RebuildType(const Type& type) {
case Type::kFunction: {
const Function* function_ty = type.AsFunction();
const Type* ret_ty = function_ty->return_type();
std::vector<Type*> param_types;
std::vector<const Type*> param_types;
param_types.reserve(function_ty->param_types().size());
for (const auto* param_ty : function_ty->param_types()) {
param_types.push_back(RebuildType(*param_ty));
@ -544,42 +640,79 @@ Type* TypeManager::RecordIfTypeDefinition(
case SpvOpTypeArray:
type = new Array(GetType(inst.GetSingleWordInOperand(0)),
inst.GetSingleWordInOperand(1));
if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) {
incomplete_types_.emplace_back(inst.result_id(), type);
id_to_incomplete_type_[inst.result_id()] = type;
return type;
}
break;
case SpvOpTypeRuntimeArray:
type = new RuntimeArray(GetType(inst.GetSingleWordInOperand(0)));
if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) {
incomplete_types_.emplace_back(inst.result_id(), type);
id_to_incomplete_type_[inst.result_id()] = type;
return type;
}
break;
case SpvOpTypeStruct: {
std::vector<Type*> element_types;
std::vector<const Type*> element_types;
bool incomplete_type = false;
for (uint32_t i = 0; i < inst.NumInOperands(); ++i) {
element_types.push_back(GetType(inst.GetSingleWordInOperand(i)));
uint32_t type_id = inst.GetSingleWordInOperand(i);
element_types.push_back(GetType(type_id));
if (id_to_incomplete_type_.count(type_id)) {
incomplete_type = true;
}
}
type = new Struct(element_types);
if (incomplete_type) {
incomplete_types_.emplace_back(inst.result_id(), type);
id_to_incomplete_type_[inst.result_id()] = type;
return type;
}
} break;
case SpvOpTypeOpaque: {
const uint32_t* data = inst.GetInOperand(0).words.data();
type = new Opaque(reinterpret_cast<const char*>(data));
} break;
case SpvOpTypePointer: {
auto* ptr = new Pointer(
GetType(inst.GetSingleWordInOperand(1)),
uint32_t pointee_type_id = inst.GetSingleWordInOperand(1);
type = new Pointer(
GetType(pointee_type_id),
static_cast<SpvStorageClass>(inst.GetSingleWordInOperand(0)));
// Let's see if somebody forward references this pointer.
for (auto* fp : unresolved_forward_pointers_) {
if (fp->target_id() == inst.result_id()) {
fp->SetTargetPointer(ptr);
unresolved_forward_pointers_.erase(fp);
break;
}
if (id_to_incomplete_type_.count(pointee_type_id)) {
incomplete_types_.emplace_back(inst.result_id(), type);
id_to_incomplete_type_[inst.result_id()] = type;
return type;
}
type = ptr;
id_to_incomplete_type_.erase(inst.result_id());
} break;
case SpvOpTypeFunction: {
Type* return_type = GetType(inst.GetSingleWordInOperand(0));
std::vector<Type*> param_types;
for (uint32_t i = 1; i < inst.NumInOperands(); ++i) {
param_types.push_back(GetType(inst.GetSingleWordInOperand(i)));
bool incomplete_type = false;
uint32_t return_type_id = inst.GetSingleWordInOperand(0);
if (id_to_incomplete_type_.count(return_type_id)) {
incomplete_type = true;
}
Type* return_type = GetType(return_type_id);
std::vector<const Type*> param_types;
for (uint32_t i = 1; i < inst.NumInOperands(); ++i) {
uint32_t param_type_id = inst.GetSingleWordInOperand(i);
param_types.push_back(GetType(param_type_id));
if (id_to_incomplete_type_.count(param_type_id)) {
incomplete_type = true;
}
}
type = new Function(return_type, param_types);
if (incomplete_type) {
incomplete_types_.emplace_back(inst.result_id(), type);
id_to_incomplete_type_[inst.result_id()] = type;
return type;
}
} break;
case SpvOpTypeEvent:
type = new Event();
@ -599,12 +732,12 @@ Type* TypeManager::RecordIfTypeDefinition(
break;
case SpvOpTypeForwardPointer: {
// Handling of forward pointers is different from the other types.
auto* fp = new ForwardPointer(
inst.GetSingleWordInOperand(0),
static_cast<SpvStorageClass>(inst.GetSingleWordInOperand(1)));
forward_pointers_.emplace_back(fp);
unresolved_forward_pointers_.insert(fp);
return fp;
uint32_t target_id = inst.GetSingleWordInOperand(0);
type = new ForwardPointer(target_id, static_cast<SpvStorageClass>(
inst.GetSingleWordInOperand(1)));
incomplete_types_.emplace_back(target_id, type);
id_to_incomplete_type_[target_id] = type;
return type;
}
case SpvOpTypePipeStorage:
type = new PipeStorage();
@ -618,22 +751,18 @@ Type* TypeManager::RecordIfTypeDefinition(
}
uint32_t id = inst.result_id();
if (id == 0) {
SPIRV_ASSERT(consumer_, inst.opcode() == SpvOpTypeForwardPointer,
"instruction without result id found");
} else {
SPIRV_ASSERT(consumer_, type != nullptr,
"type should not be nullptr at this point");
std::vector<ir::Instruction*> decorations =
context()->get_decoration_mgr()->GetDecorationsFor(id, true);
for (auto dec : decorations) {
AttachDecoration(*dec, type);
}
std::unique_ptr<Type> unique(type);
auto pair = type_pool_.insert(std::move(unique));
id_to_type_[id] = pair.first->get();
type_to_id_[pair.first->get()] = id;
SPIRV_ASSERT(consumer_, id != 0, "instruction without result id found");
SPIRV_ASSERT(consumer_, type != nullptr,
"type should not be nullptr at this point");
std::vector<ir::Instruction*> decorations =
context()->get_decoration_mgr()->GetDecorationsFor(id, true);
for (auto dec : decorations) {
AttachDecoration(*dec, type);
}
std::unique_ptr<Type> unique(type);
auto pair = type_pool_.insert(std::move(unique));
id_to_type_[id] = pair.first->get();
type_to_id_[pair.first->get()] = id;
return type;
}
@ -690,6 +819,115 @@ const Type* TypeManager::GetMemberType(
return parent_type;
}
void TypeManager::ReplaceForwardPointers(Type* type) {
switch (type->kind()) {
case Type::kArray: {
const analysis::ForwardPointer* element_type =
type->AsArray()->element_type()->AsForwardPointer();
if (element_type) {
type->AsArray()->ReplaceElementType(element_type->target_pointer());
}
} break;
case Type::kRuntimeArray: {
const analysis::ForwardPointer* element_type =
type->AsRuntimeArray()->element_type()->AsForwardPointer();
if (element_type) {
type->AsRuntimeArray()->ReplaceElementType(
element_type->target_pointer());
}
} break;
case Type::kStruct: {
auto& member_types = type->AsStruct()->element_types();
for (auto& member_type : member_types) {
if (member_type->AsForwardPointer()) {
member_type = member_type->AsForwardPointer()->target_pointer();
assert(member_type);
}
}
} break;
case Type::kPointer: {
const analysis::ForwardPointer* pointee_type =
type->AsPointer()->pointee_type()->AsForwardPointer();
if (pointee_type) {
type->AsPointer()->SetPointeeType(pointee_type->target_pointer());
}
} break;
case Type::kFunction: {
Function* func_type = type->AsFunction();
const analysis::ForwardPointer* return_type =
func_type->return_type()->AsForwardPointer();
if (return_type) {
func_type->SetReturnType(return_type->target_pointer());
}
auto& param_types = func_type->param_types();
for (auto& param_type : param_types) {
if (param_type->AsForwardPointer()) {
param_type = param_type->AsForwardPointer()->target_pointer();
}
}
} break;
default:
break;
}
}
void TypeManager::ReplaceType(Type* new_type, Type* original_type) {
assert(original_type->kind() == new_type->kind() &&
"Types must be the same for replacement.\n");
for (auto& p : incomplete_types_) {
Type* type = p.type();
if (!type) {
continue;
}
switch (type->kind()) {
case Type::kArray: {
const Type* element_type = type->AsArray()->element_type();
if (element_type == original_type) {
type->AsArray()->ReplaceElementType(new_type);
}
} break;
case Type::kRuntimeArray: {
const Type* element_type = type->AsRuntimeArray()->element_type();
if (element_type == original_type) {
type->AsRuntimeArray()->ReplaceElementType(new_type);
}
} break;
case Type::kStruct: {
auto& member_types = type->AsStruct()->element_types();
for (auto& member_type : member_types) {
if (member_type == original_type) {
member_type = new_type;
}
}
} break;
case Type::kPointer: {
const Type* pointee_type = type->AsPointer()->pointee_type();
if (pointee_type == original_type) {
type->AsPointer()->SetPointeeType(new_type);
}
} break;
case Type::kFunction: {
Function* func_type = type->AsFunction();
const Type* return_type = func_type->return_type();
if (return_type == original_type) {
func_type->SetReturnType(new_type);
}
auto& param_types = func_type->param_types();
for (auto& param_type : param_types) {
if (param_type == original_type) {
param_type = new_type;
}
}
} break;
default:
break;
}
}
}
} // namespace analysis
} // namespace opt
} // namespace spvtools

View File

@ -94,11 +94,6 @@ class TypeManager {
IdToTypeMap::const_iterator begin() const { return id_to_type_.cbegin(); }
IdToTypeMap::const_iterator end() const { return id_to_type_.cend(); }
// Returns the forward pointer type at the given |index|.
ForwardPointer* GetForwardPointer(uint32_t index) const;
// Returns the number of forward pointer types hold in this manager.
size_t NumForwardPointers() const { return forward_pointers_.size(); }
// Returns a pair of the type and pointer to the type in |sc|.
//
// |id| must be a registered type.
@ -146,14 +141,31 @@ class TypeManager {
using TypePool =
std::unordered_set<std::unique_ptr<Type>, HashTypeUniquePointer,
CompareTypeUniquePointers>;
using ForwardPointerVector = std::vector<std::unique_ptr<ForwardPointer>>;
class UnresolvedType {
public:
UnresolvedType(uint32_t i, Type* t) : id_(i), type_(t) {}
UnresolvedType(const UnresolvedType&) = delete;
UnresolvedType(UnresolvedType&& that)
: id_(that.id_), type_(std::move(that.type_)) {}
uint32_t id() { return id_; }
Type* type() { return type_.get(); }
std::unique_ptr<Type>&& ReleaseType() { return std::move(type_); }
void ResetType(Type* t) { type_.reset(t); }
private:
uint32_t id_;
std::unique_ptr<Type> type_;
};
using IdToUnresolvedType = std::vector<UnresolvedType>;
// Analyzes the types and decorations on types in the given |module|.
void AnalyzeTypes(const spvtools::ir::Module& module);
spvtools::ir::IRContext* context() { return context_; }
// Attachs the decorations on |type| to |id|.
// Attaches the decorations on |type| to |id|.
void AttachDecorations(uint32_t id, const Type* type);
// Create the annotation instruction.
@ -177,15 +189,25 @@ class TypeManager {
// replacing the bool subtype with one owned by |type_pool_|.
Type* RebuildType(const Type& type);
// Completes the incomplete type |type|, by replaces all references to
// ForwardPointer by the defining Pointer.
void ReplaceForwardPointers(Type* type);
// Replaces all references to |original_type| in |incomplete_types_| by
// |new_type|.
void ReplaceType(Type* new_type, Type* original_type);
const MessageConsumer& consumer_; // Message consumer.
spvtools::ir::IRContext* context_;
IdToTypeMap id_to_type_; // Mapping from ids to their type representations.
TypeToIdMap type_to_id_; // Mapping from types to their defining ids.
TypePool type_pool_; // Memory owner of type pointers.
ForwardPointerVector forward_pointers_; // All forward pointer declarations.
// All unresolved forward pointer declarations.
// Refers the contents in the above vector.
std::unordered_set<ForwardPointer*> unresolved_forward_pointers_;
IdToUnresolvedType incomplete_types_; // All incomplete types. Stored in an
// std::vector to make traversals
// deterministic.
IdToTypeMap id_to_incomplete_type_; // Maps ids to their type representations
// for incomplete types.
};
} // namespace analysis

View File

@ -16,6 +16,7 @@
#include <cassert>
#include <cstdint>
#include <sstream>
#include <unordered_set>
#include "types.h"
@ -171,7 +172,12 @@ bool Type::operator==(const Type& other) const {
}
}
void Type::GetHashWords(std::vector<uint32_t>* words) const {
void Type::GetHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* seen) const {
if (!seen->insert(this).second) {
return;
}
words->push_back(kind_);
for (const auto& d : decorations_) {
for (auto w : d) {
@ -180,9 +186,9 @@ void Type::GetHashWords(std::vector<uint32_t>* words) const {
}
switch (kind_) {
#define DeclareKindCase(type) \
case k##type: \
As##type()->GetExtraHashWords(words); \
#define DeclareKindCase(type) \
case k##type: \
As##type()->GetExtraHashWords(words, seen); \
break;
DeclareKindCase(Void);
DeclareKindCase(Bool);
@ -212,6 +218,8 @@ void Type::GetHashWords(std::vector<uint32_t>* words) const {
assert(false && "Unhandled type");
break;
}
seen->erase(this);
}
size_t Type::HashValue() const {
@ -225,7 +233,7 @@ size_t Type::HashValue() const {
return std::hash<std::u32string>()(h);
}
bool Integer::IsSame(const Type* that) const {
bool Integer::IsSameImpl(const Type* that, IsSameCache*) const {
const Integer* it = that->AsInteger();
return it && width_ == it->width_ && signed_ == it->signed_ &&
HasSameDecorations(that);
@ -237,12 +245,13 @@ std::string Integer::str() const {
return oss.str();
}
void Integer::GetExtraHashWords(std::vector<uint32_t>* words) const {
void Integer::GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>*) const {
words->push_back(width_);
words->push_back(signed_);
}
bool Float::IsSame(const Type* that) const {
bool Float::IsSameImpl(const Type* that, IsSameCache*) const {
const Float* ft = that->AsFloat();
return ft && width_ == ft->width_ && HasSameDecorations(that);
}
@ -253,7 +262,8 @@ std::string Float::str() const {
return oss.str();
}
void Float::GetExtraHashWords(std::vector<uint32_t>* words) const {
void Float::GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>*) const {
words->push_back(width_);
}
@ -262,10 +272,11 @@ Vector::Vector(Type* type, uint32_t count)
assert(type->AsBool() || type->AsInteger() || type->AsFloat());
}
bool Vector::IsSame(const Type* that) const {
bool Vector::IsSameImpl(const Type* that, IsSameCache* seen) const {
const Vector* vt = that->AsVector();
if (!vt) return false;
return count_ == vt->count_ && element_type_->IsSame(vt->element_type_) &&
return count_ == vt->count_ &&
element_type_->IsSameImpl(vt->element_type_, seen) &&
HasSameDecorations(that);
}
@ -275,8 +286,9 @@ std::string Vector::str() const {
return oss.str();
}
void Vector::GetExtraHashWords(std::vector<uint32_t>* words) const {
element_type_->GetHashWords(words);
void Vector::GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* seen) const {
element_type_->GetHashWords(words, seen);
words->push_back(count_);
}
@ -285,10 +297,11 @@ Matrix::Matrix(Type* type, uint32_t count)
assert(type->AsVector());
}
bool Matrix::IsSame(const Type* that) const {
bool Matrix::IsSameImpl(const Type* that, IsSameCache* seen) const {
const Matrix* mt = that->AsMatrix();
if (!mt) return false;
return count_ == mt->count_ && element_type_->IsSame(mt->element_type_) &&
return count_ == mt->count_ &&
element_type_->IsSameImpl(mt->element_type_, seen) &&
HasSameDecorations(that);
}
@ -298,8 +311,9 @@ std::string Matrix::str() const {
return oss.str();
}
void Matrix::GetExtraHashWords(std::vector<uint32_t>* words) const {
element_type_->GetHashWords(words);
void Matrix::GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* seen) const {
element_type_->GetHashWords(words, seen);
words->push_back(count_);
}
@ -317,13 +331,14 @@ Image::Image(Type* type, SpvDim dimen, uint32_t d, bool array, bool multisample,
// TODO(antiagainst): check sampled_type
}
bool Image::IsSame(const Type* that) const {
bool Image::IsSameImpl(const Type* that, IsSameCache* seen) const {
const Image* it = that->AsImage();
if (!it) return false;
return dim_ == it->dim_ && depth_ == it->depth_ && arrayed_ == it->arrayed_ &&
ms_ == it->ms_ && sampled_ == it->sampled_ && format_ == it->format_ &&
access_qualifier_ == it->access_qualifier_ &&
sampled_type_->IsSame(it->sampled_type_) && HasSameDecorations(that);
sampled_type_->IsSameImpl(it->sampled_type_, seen) &&
HasSameDecorations(that);
}
std::string Image::str() const {
@ -334,8 +349,9 @@ std::string Image::str() const {
return oss.str();
}
void Image::GetExtraHashWords(std::vector<uint32_t>* words) const {
sampled_type_->GetHashWords(words);
void Image::GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* seen) const {
sampled_type_->GetHashWords(words, seen);
words->push_back(dim_);
words->push_back(depth_);
words->push_back(arrayed_);
@ -345,10 +361,11 @@ void Image::GetExtraHashWords(std::vector<uint32_t>* words) const {
words->push_back(access_qualifier_);
}
bool SampledImage::IsSame(const Type* that) const {
bool SampledImage::IsSameImpl(const Type* that, IsSameCache* seen) const {
const SampledImage* sit = that->AsSampledImage();
if (!sit) return false;
return image_type_->IsSame(sit->image_type_) && HasSameDecorations(that);
return image_type_->IsSameImpl(sit->image_type_, seen) &&
HasSameDecorations(that);
}
std::string SampledImage::str() const {
@ -357,8 +374,9 @@ std::string SampledImage::str() const {
return oss.str();
}
void SampledImage::GetExtraHashWords(std::vector<uint32_t>* words) const {
image_type_->GetHashWords(words);
void SampledImage::GetExtraHashWords(
std::vector<uint32_t>* words, std::unordered_set<const Type*>* seen) const {
image_type_->GetHashWords(words, seen);
}
Array::Array(Type* type, uint32_t length_id)
@ -366,11 +384,12 @@ Array::Array(Type* type, uint32_t length_id)
assert(!type->AsVoid());
}
bool Array::IsSame(const Type* that) const {
bool Array::IsSameImpl(const Type* that, IsSameCache* seen) const {
const Array* at = that->AsArray();
if (!at) return false;
return length_id_ == at->length_id_ &&
element_type_->IsSame(at->element_type_) && HasSameDecorations(that);
element_type_->IsSameImpl(at->element_type_, seen) &&
HasSameDecorations(that);
}
std::string Array::str() const {
@ -379,20 +398,24 @@ std::string Array::str() const {
return oss.str();
}
void Array::GetExtraHashWords(std::vector<uint32_t>* words) const {
element_type_->GetHashWords(words);
void Array::GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* seen) const {
element_type_->GetHashWords(words, seen);
words->push_back(length_id_);
}
void Array::ReplaceElementType(const Type* type) { element_type_ = type; }
RuntimeArray::RuntimeArray(Type* type)
: Type(kRuntimeArray), element_type_(type) {
assert(!type->AsVoid());
}
bool RuntimeArray::IsSame(const Type* that) const {
bool RuntimeArray::IsSameImpl(const Type* that, IsSameCache* seen) const {
const RuntimeArray* rat = that->AsRuntimeArray();
if (!rat) return false;
return element_type_->IsSame(rat->element_type_) && HasSameDecorations(that);
return element_type_->IsSameImpl(rat->element_type_, seen) &&
HasSameDecorations(that);
}
std::string RuntimeArray::str() const {
@ -401,11 +424,16 @@ std::string RuntimeArray::str() const {
return oss.str();
}
void RuntimeArray::GetExtraHashWords(std::vector<uint32_t>* words) const {
element_type_->GetHashWords(words);
void RuntimeArray::GetExtraHashWords(
std::vector<uint32_t>* words, std::unordered_set<const Type*>* seen) const {
element_type_->GetHashWords(words, seen);
}
Struct::Struct(const std::vector<Type*>& types)
void RuntimeArray::ReplaceElementType(const Type* type) {
element_type_ = type;
}
Struct::Struct(const std::vector<const Type*>& types)
: Type(kStruct), element_types_(types) {
for (const auto* t : types) {
(void)t;
@ -423,7 +451,7 @@ void Struct::AddMemberDecoration(uint32_t index,
element_decorations_[index].push_back(std::move(decoration));
}
bool Struct::IsSame(const Type* that) const {
bool Struct::IsSameImpl(const Type* that, IsSameCache* seen) const {
const Struct* st = that->AsStruct();
if (!st) return false;
if (element_types_.size() != st->element_types_.size()) return false;
@ -432,7 +460,8 @@ bool Struct::IsSame(const Type* that) const {
if (!HasSameDecorations(that)) return false;
for (size_t i = 0; i < element_types_.size(); ++i) {
if (!element_types_[i]->IsSame(st->element_types_[i])) return false;
if (!element_types_[i]->IsSameImpl(st->element_types_[i], seen))
return false;
}
for (const auto& p : element_decorations_) {
if (st->element_decorations_.count(p.first) == 0) return false;
@ -454,9 +483,10 @@ std::string Struct::str() const {
return oss.str();
}
void Struct::GetExtraHashWords(std::vector<uint32_t>* words) const {
void Struct::GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* seen) const {
for (auto* t : element_types_) {
t->GetHashWords(words);
t->GetHashWords(words, seen);
}
for (const auto& pair : element_decorations_) {
words->push_back(pair.first);
@ -468,7 +498,7 @@ void Struct::GetExtraHashWords(std::vector<uint32_t>* words) const {
}
}
bool Opaque::IsSame(const Type* that) const {
bool Opaque::IsSameImpl(const Type* that, IsSameCache*) const {
const Opaque* ot = that->AsOpaque();
if (!ot) return false;
return name_ == ot->name_ && HasSameDecorations(that);
@ -480,7 +510,8 @@ std::string Opaque::str() const {
return oss.str();
}
void Opaque::GetExtraHashWords(std::vector<uint32_t>* words) const {
void Opaque::GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>*) const {
for (auto c : name_) {
words->push_back(static_cast<char32_t>(c));
}
@ -489,22 +520,33 @@ void Opaque::GetExtraHashWords(std::vector<uint32_t>* words) const {
Pointer::Pointer(const Type* type, SpvStorageClass sc)
: Type(kPointer), pointee_type_(type), storage_class_(sc) {}
bool Pointer::IsSame(const Type* that) const {
bool Pointer::IsSameImpl(const Type* that, IsSameCache* seen) const {
const Pointer* pt = that->AsPointer();
if (!pt) return false;
if (storage_class_ != pt->storage_class_) return false;
if (!pointee_type_->IsSame(pt->pointee_type_)) return false;
auto p = seen->insert(std::make_pair(this, that->AsPointer()));
if (!p.second) {
return true;
}
bool same_pointee = pointee_type_->IsSameImpl(pt->pointee_type_, seen);
seen->erase(p.first);
if (!same_pointee) {
return false;
}
return HasSameDecorations(that);
}
std::string Pointer::str() const { return pointee_type_->str() + "*"; }
void Pointer::GetExtraHashWords(std::vector<uint32_t>* words) const {
pointee_type_->GetHashWords(words);
void Pointer::GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* seen) const {
pointee_type_->GetHashWords(words, seen);
words->push_back(storage_class_);
}
Function::Function(Type* ret_type, const std::vector<Type*>& params)
void Pointer::SetPointeeType(const Type* type) { pointee_type_ = type; }
Function::Function(Type* ret_type, const std::vector<const Type*>& params)
: Type(kFunction), return_type_(ret_type), param_types_(params) {
for (auto* t : params) {
(void)t;
@ -512,13 +554,13 @@ Function::Function(Type* ret_type, const std::vector<Type*>& params)
}
}
bool Function::IsSame(const Type* that) const {
bool Function::IsSameImpl(const Type* that, IsSameCache* seen) const {
const Function* ft = that->AsFunction();
if (!ft) return false;
if (!return_type_->IsSame(ft->return_type_)) return false;
if (!return_type_->IsSameImpl(ft->return_type_, seen)) return false;
if (param_types_.size() != ft->param_types_.size()) return false;
for (size_t i = 0; i < param_types_.size(); ++i) {
if (!param_types_[i]->IsSame(ft->param_types_[i])) return false;
if (!param_types_[i]->IsSameImpl(ft->param_types_[i], seen)) return false;
}
return HasSameDecorations(that);
}
@ -535,14 +577,17 @@ std::string Function::str() const {
return oss.str();
}
void Function::GetExtraHashWords(std::vector<uint32_t>* words) const {
return_type_->GetHashWords(words);
void Function::GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* seen) const {
return_type_->GetHashWords(words, seen);
for (const auto* t : param_types_) {
t->GetHashWords(words);
t->GetHashWords(words, seen);
}
}
bool Pipe::IsSame(const Type* that) const {
void Function::SetReturnType(const Type* type) { return_type_ = type; }
bool Pipe::IsSameImpl(const Type* that, IsSameCache*) const {
const Pipe* pt = that->AsPipe();
if (!pt) return false;
return access_qualifier_ == pt->access_qualifier_ && HasSameDecorations(that);
@ -554,11 +599,12 @@ std::string Pipe::str() const {
return oss.str();
}
void Pipe::GetExtraHashWords(std::vector<uint32_t>* words) const {
void Pipe::GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>*) const {
words->push_back(access_qualifier_);
}
bool ForwardPointer::IsSame(const Type* that) const {
bool ForwardPointer::IsSameImpl(const Type* that, IsSameCache*) const {
const ForwardPointer* fpt = that->AsForwardPointer();
if (!fpt) return false;
return target_id_ == fpt->target_id_ &&
@ -577,10 +623,11 @@ std::string ForwardPointer::str() const {
return oss.str();
}
void ForwardPointer::GetExtraHashWords(std::vector<uint32_t>* words) const {
void ForwardPointer::GetExtraHashWords(
std::vector<uint32_t>* words, std::unordered_set<const Type*>* seen) const {
words->push_back(target_id_);
words->push_back(storage_class_);
if (pointer_) pointer_->GetHashWords(words);
if (pointer_) pointer_->GetHashWords(words, seen);
}
} // namespace analysis

View File

@ -19,8 +19,10 @@
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "latest_version_spirv_header.h"
@ -58,6 +60,8 @@ class NamedBarrier;
// which is used as a way to probe the actual <subclass>.
class Type {
public:
typedef std::set<std::pair<const Pointer*, const Pointer*>> IsSameCache;
// Available subtypes.
//
// When adding a new derived class of Type, please add an entry to the enum.
@ -101,7 +105,16 @@ class Type {
bool HasSameDecorations(const Type* that) const;
// Returns true if this type is exactly the same as |that| type, including
// decorations.
virtual bool IsSame(const Type* that) const = 0;
bool IsSame(const Type* that) const {
IsSameCache seen;
return IsSameImpl(that, &seen);
}
// Returns true if this type is exactly the same as |that| type, including
// decorations. |seen| is the set of |Pointer*| pair that are currently being
// compared in a parent call to |IsSameImpl|.
virtual bool IsSameImpl(const Type* that, IsSameCache* seen) const = 0;
// Returns a human-readable string to represent this type.
virtual std::string str() const = 0;
@ -164,11 +177,20 @@ class Type {
size_t HashValue() const;
// Adds the necessary words to compute a hash value of this type to |words|.
void GetHashWords(std::vector<uint32_t>* words) const;
void GetHashWords(std::vector<uint32_t>* words) const {
std::unordered_set<const Type*> seen;
GetHashWords(words, &seen);
}
// Adds the necessary words to compute a hash value of this type to |words|.
void GetHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* seen) const;
// Adds necessary extra words for a subtype to calculate a hash value into
// |words|.
virtual void GetExtraHashWords(std::vector<uint32_t>* words) const = 0;
virtual void GetExtraHashWords(
std::vector<uint32_t>* words,
std::unordered_set<const Type*>* pSet) const = 0;
protected:
// Decorations attached to this type. Each decoration is encoded as a vector
@ -190,7 +212,6 @@ class Integer : public Type {
: Type(kInteger), width_(w), signed_(is_signed) {}
Integer(const Integer&) = default;
bool IsSame(const Type* that) const override;
std::string str() const override;
Integer* AsInteger() override { return this; }
@ -198,9 +219,12 @@ class Integer : public Type {
uint32_t width() const { return width_; }
bool IsSigned() const { return signed_; }
void GetExtraHashWords(std::vector<uint32_t>* words) const override;
void GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* pSet) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
uint32_t width_; // bit width
bool signed_; // true if this integer is signed
};
@ -210,16 +234,18 @@ class Float : public Type {
Float(uint32_t w) : Type(kFloat), width_(w) {}
Float(const Float&) = default;
bool IsSame(const Type* that) const override;
std::string str() const override;
Float* AsFloat() override { return this; }
const Float* AsFloat() const override { return this; }
uint32_t width() const { return width_; }
void GetExtraHashWords(std::vector<uint32_t>* words) const override;
void GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* pSet) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
uint32_t width_; // bit width
};
@ -228,7 +254,6 @@ class Vector : public Type {
Vector(Type* element_type, uint32_t count);
Vector(const Vector&) = default;
bool IsSame(const Type* that) const override;
std::string str() const override;
const Type* element_type() const { return element_type_; }
uint32_t element_count() const { return count_; }
@ -236,10 +261,13 @@ class Vector : public Type {
Vector* AsVector() override { return this; }
const Vector* AsVector() const override { return this; }
void GetExtraHashWords(std::vector<uint32_t>* words) const override;
void GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* pSet) const override;
private:
Type* element_type_;
bool IsSameImpl(const Type* that, IsSameCache*) const override;
const Type* element_type_;
uint32_t count_;
};
@ -248,7 +276,6 @@ class Matrix : public Type {
Matrix(Type* element_type, uint32_t count);
Matrix(const Matrix&) = default;
bool IsSame(const Type* that) const override;
std::string str() const override;
const Type* element_type() const { return element_type_; }
uint32_t element_count() const { return count_; }
@ -256,10 +283,13 @@ class Matrix : public Type {
Matrix* AsMatrix() override { return this; }
const Matrix* AsMatrix() const override { return this; }
void GetExtraHashWords(std::vector<uint32_t>* words) const override;
void GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* pSet) const override;
private:
Type* element_type_;
bool IsSameImpl(const Type* that, IsSameCache*) const override;
const Type* element_type_;
uint32_t count_;
};
@ -270,7 +300,6 @@ class Image : public Type {
SpvAccessQualifier qualifier = SpvAccessQualifierReadOnly);
Image(const Image&) = default;
bool IsSame(const Type* that) const override;
std::string str() const override;
Image* AsImage() override { return this; }
@ -285,9 +314,12 @@ class Image : public Type {
SpvImageFormat format() const { return format_; }
SpvAccessQualifier access_qualifier() const { return access_qualifier_; }
void GetExtraHashWords(std::vector<uint32_t>* words) const override;
void GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* pSet) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
Type* sampled_type_;
SpvDim dim_;
uint32_t depth_;
@ -303,7 +335,6 @@ class SampledImage : public Type {
SampledImage(Type* image) : Type(kSampledImage), image_type_(image) {}
SampledImage(const SampledImage&) = default;
bool IsSame(const Type* that) const override;
std::string str() const override;
SampledImage* AsSampledImage() override { return this; }
@ -311,9 +342,11 @@ class SampledImage : public Type {
const Type* image_type() const { return image_type_; }
void GetExtraHashWords(std::vector<uint32_t>* words) const override;
void GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* pSet) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
Type* image_type_;
};
@ -322,7 +355,6 @@ class Array : public Type {
Array(Type* element_type, uint32_t length_id);
Array(const Array&) = default;
bool IsSame(const Type* that) const override;
std::string str() const override;
const Type* element_type() const { return element_type_; }
uint32_t LengthId() const { return length_id_; }
@ -330,10 +362,15 @@ class Array : public Type {
Array* AsArray() override { return this; }
const Array* AsArray() const override { return this; }
void GetExtraHashWords(std::vector<uint32_t>* words) const override;
void GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* pSet) const override;
void ReplaceElementType(const Type* element_type);
private:
Type* element_type_;
bool IsSameImpl(const Type* that, IsSameCache*) const override;
const Type* element_type_;
uint32_t length_id_;
};
@ -342,31 +379,37 @@ class RuntimeArray : public Type {
RuntimeArray(Type* element_type);
RuntimeArray(const RuntimeArray&) = default;
bool IsSame(const Type* that) const override;
std::string str() const override;
const Type* element_type() const { return element_type_; }
RuntimeArray* AsRuntimeArray() override { return this; }
const RuntimeArray* AsRuntimeArray() const override { return this; }
void GetExtraHashWords(std::vector<uint32_t>* words) const override;
void GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* pSet) const override;
void ReplaceElementType(const Type* element_type);
private:
Type* element_type_;
bool IsSameImpl(const Type* that, IsSameCache*) const override;
const Type* element_type_;
};
class Struct : public Type {
public:
Struct(const std::vector<Type*>& element_types);
Struct(const std::vector<const Type*>& element_types);
Struct(const Struct&) = default;
// Adds a decoration to the member at the given index. The first word is the
// decoration enum, and the remaining words, if any, are its operands.
void AddMemberDecoration(uint32_t index, std::vector<uint32_t>&& decoration);
bool IsSame(const Type* that) const override;
std::string str() const override;
const std::vector<Type*>& element_types() const { return element_types_; }
const std::vector<const Type*>& element_types() const {
return element_types_;
}
std::vector<const Type*>& element_types() { return element_types_; }
bool decoration_empty() const override {
return decorations_.empty() && element_decorations_.empty();
}
@ -379,15 +422,18 @@ class Struct : public Type {
Struct* AsStruct() override { return this; }
const Struct* AsStruct() const override { return this; }
void GetExtraHashWords(std::vector<uint32_t>* words) const override;
void GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* pSet) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
void ClearDecorations() override {
decorations_.clear();
element_decorations_.clear();
}
std::vector<Type*> element_types_;
std::vector<const Type*> element_types_;
// We can attach decorations to struct members and that should not affect the
// underlying element type. So we need an extra data structure here to keep
// track of element type decorations. They must be stored in an ordered map
@ -401,7 +447,6 @@ class Opaque : public Type {
Opaque(std::string n) : Type(kOpaque), name_(std::move(n)) {}
Opaque(const Opaque&) = default;
bool IsSame(const Type* that) const override;
std::string str() const override;
Opaque* AsOpaque() override { return this; }
@ -409,9 +454,12 @@ class Opaque : public Type {
const std::string& name() const { return name_; }
void GetExtraHashWords(std::vector<uint32_t>* words) const override;
void GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* pSet) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
std::string name_;
};
@ -420,7 +468,6 @@ class Pointer : public Type {
Pointer(const Type* pointee, SpvStorageClass sc);
Pointer(const Pointer&) = default;
bool IsSame(const Type* that) const override;
std::string str() const override;
const Type* pointee_type() const { return pointee_type_; }
SpvStorageClass storage_class() const { return storage_class_; }
@ -428,32 +475,42 @@ class Pointer : public Type {
Pointer* AsPointer() override { return this; }
const Pointer* AsPointer() const override { return this; }
void GetExtraHashWords(std::vector<uint32_t>* words) const override;
void GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* pSet) const override;
void SetPointeeType(const Type* type);
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
const Type* pointee_type_;
SpvStorageClass storage_class_;
};
class Function : public Type {
public:
Function(Type* ret_type, const std::vector<Type*>& params);
Function(Type* ret_type, const std::vector<const Type*>& params);
Function(const Function&) = default;
bool IsSame(const Type* that) const override;
std::string str() const override;
Function* AsFunction() override { return this; }
const Function* AsFunction() const override { return this; }
const Type* return_type() const { return return_type_; }
const std::vector<Type*>& param_types() const { return param_types_; }
const std::vector<const Type*>& param_types() const { return param_types_; }
std::vector<const Type*>& param_types() { return param_types_; }
void GetExtraHashWords(std::vector<uint32_t>* words) const override;
void GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>*) const override;
void SetReturnType(const Type* type);
private:
Type* return_type_;
std::vector<Type*> param_types_;
bool IsSameImpl(const Type* that, IsSameCache*) const override;
const Type* return_type_;
std::vector<const Type*> param_types_;
};
class Pipe : public Type {
@ -462,7 +519,6 @@ class Pipe : public Type {
: Type(kPipe), access_qualifier_(qualifier) {}
Pipe(const Pipe&) = default;
bool IsSame(const Type* that) const override;
std::string str() const override;
Pipe* AsPipe() override { return this; }
@ -470,9 +526,12 @@ class Pipe : public Type {
SpvAccessQualifier access_qualifier() const { return access_qualifier_; }
void GetExtraHashWords(std::vector<uint32_t>* words) const override;
void GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* pSet) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
SpvAccessQualifier access_qualifier_;
};
@ -486,39 +545,44 @@ class ForwardPointer : public Type {
ForwardPointer(const ForwardPointer&) = default;
uint32_t target_id() const { return target_id_; }
void SetTargetPointer(Pointer* pointer) { pointer_ = pointer; }
void SetTargetPointer(const Pointer* pointer) { pointer_ = pointer; }
SpvStorageClass storage_class() const { return storage_class_; }
const Pointer* target_pointer() const { return pointer_; }
bool IsSame(const Type* that) const override;
std::string str() const override;
ForwardPointer* AsForwardPointer() override { return this; }
const ForwardPointer* AsForwardPointer() const override { return this; }
void GetExtraHashWords(std::vector<uint32_t>* words) const override;
void GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* pSet) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
uint32_t target_id_;
SpvStorageClass storage_class_;
Pointer* pointer_;
const Pointer* pointer_;
};
#define DefineParameterlessType(type, name) \
class type : public Type { \
public: \
type() : Type(k##type) {} \
type(const type&) = default; \
\
bool IsSame(const Type* that) const override { \
return that->As##type() && HasSameDecorations(that); \
} \
std::string str() const override { return #name; } \
\
type* As##type() override { return this; } \
const type* As##type() const override { return this; } \
\
void GetExtraHashWords(std::vector<uint32_t>*) const override {} \
#define DefineParameterlessType(type, name) \
class type : public Type { \
public: \
type() : Type(k##type) {} \
type(const type&) = default; \
\
std::string str() const override { return #name; } \
\
type* As##type() override { return this; } \
const type* As##type() const override { return this; } \
\
void GetExtraHashWords(std::vector<uint32_t>*, \
std::unordered_set<const Type*>*) const override {} \
\
private: \
bool IsSameImpl(const Type* that, IsSameCache*) const override { \
return that->As##type() && HasSameDecorations(that); \
} \
}
DefineParameterlessType(Void, void);
DefineParameterlessType(Bool, bool);

View File

@ -131,10 +131,11 @@ std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
auto* rav3s32 = types.back().get();
// Struct
types.emplace_back(new Struct(std::vector<Type*>{s32}));
types.emplace_back(new Struct(std::vector<Type*>{s32, f32}));
types.emplace_back(new Struct(std::vector<const Type*>{s32}));
types.emplace_back(new Struct(std::vector<const Type*>{s32, f32}));
auto* sts32f32 = types.back().get();
types.emplace_back(new Struct(std::vector<Type*>{u64, a42f32, rav3s32}));
types.emplace_back(
new Struct(std::vector<const Type*>{u64, a42f32, rav3s32}));
// Opaque
types.emplace_back(new Opaque(""));
@ -173,7 +174,6 @@ std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
TEST(TypeManager, TypeStrings) {
const std::string text = R"(
OpTypeForwardPointer !20 !2 ; id for %p is 20, Uniform is 2
OpTypeForwardPointer !10000 !1
%void = OpTypeVoid
%bool = OpTypeBool
%u32 = OpTypeInt 32 0
@ -240,14 +240,232 @@ TEST(TypeManager, TypeStrings) {
opt::analysis::TypeManager manager(nullptr, context.get());
EXPECT_EQ(type_id_strs.size(), manager.NumTypes());
EXPECT_EQ(2u, manager.NumForwardPointers());
for (const auto& p : type_id_strs) {
EXPECT_EQ(p.second, manager.GetType(p.first)->str());
EXPECT_EQ(p.first, manager.GetId(manager.GetType(p.first)));
}
EXPECT_EQ("forward_pointer({uint32}*)", manager.GetForwardPointer(0)->str());
EXPECT_EQ("forward_pointer(10000)", manager.GetForwardPointer(1)->str());
}
TEST(TypeManager, StructWithFwdPtr) {
const std::string text = R"(
OpCapability Addresses
OpCapability Kernel
%1 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %7 "test"
OpSource OpenCL_C 102000
OpDecorate %11 FuncParamAttr NoCapture
%11 = OpDecorationGroup
OpGroupDecorate %11 %8 %9
OpTypeForwardPointer %100 CrossWorkgroup
%void = OpTypeVoid
%150 = OpTypeStruct %100
%100 = OpTypePointer CrossWorkgroup %150
%6 = OpTypeFunction %void %100 %100
%7 = OpFunction %void Pure %6
%8 = OpFunctionParameter %100
%9 = OpFunctionParameter %100
%10 = OpLabel
OpReturn
OpFunctionEnd
)";
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
opt::analysis::TypeManager manager(nullptr, context.get());
Type* p100 = manager.GetType(100);
Type* s150 = manager.GetType(150);
EXPECT_TRUE(p100->AsPointer());
EXPECT_EQ(p100->AsPointer()->pointee_type(), s150);
EXPECT_TRUE(s150->AsStruct());
EXPECT_EQ(s150->AsStruct()->element_types()[0], p100);
}
TEST(TypeManager, CircularFwdPtr) {
const std::string text = R"(
OpCapability Addresses
OpCapability Kernel
%1 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %7 "test"
OpSource OpenCL_C 102000
OpDecorate %11 FuncParamAttr NoCapture
%11 = OpDecorationGroup
OpGroupDecorate %11 %8 %9
OpTypeForwardPointer %100 CrossWorkgroup
OpTypeForwardPointer %200 CrossWorkgroup
%void = OpTypeVoid
%int = OpTypeInt 32 0
%float = OpTypeFloat 32
%150 = OpTypeStruct %200 %int
%250 = OpTypeStruct %100 %float
%100 = OpTypePointer CrossWorkgroup %150
%200 = OpTypePointer CrossWorkgroup %250
%6 = OpTypeFunction %void %100 %200
%7 = OpFunction %void Pure %6
%8 = OpFunctionParameter %100
%9 = OpFunctionParameter %200
%10 = OpLabel
OpReturn
OpFunctionEnd
)";
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
opt::analysis::TypeManager manager(nullptr, context.get());
Type* p100 = manager.GetType(100);
Type* s150 = manager.GetType(150);
Type* p200 = manager.GetType(200);
Type* s250 = manager.GetType(250);
EXPECT_TRUE(p100->AsPointer());
EXPECT_EQ(p100->AsPointer()->pointee_type(), s150);
EXPECT_TRUE(p200->AsPointer());
EXPECT_EQ(p200->AsPointer()->pointee_type(), s250);
EXPECT_TRUE(s150->AsStruct());
EXPECT_EQ(s150->AsStruct()->element_types()[0], p200);
EXPECT_TRUE(s250->AsStruct());
EXPECT_EQ(s250->AsStruct()->element_types()[0], p100);
}
TEST(TypeManager, IsomorphicStructWithFwdPtr) {
const std::string text = R"(
OpCapability Addresses
OpCapability Kernel
%1 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %7 "test"
OpSource OpenCL_C 102000
OpDecorate %11 FuncParamAttr NoCapture
%11 = OpDecorationGroup
OpGroupDecorate %11 %8 %9
OpTypeForwardPointer %100 CrossWorkgroup
OpTypeForwardPointer %200 CrossWorkgroup
%void = OpTypeVoid
%_struct_1 = OpTypeStruct %100
%_struct_2 = OpTypeStruct %200
%100 = OpTypePointer CrossWorkgroup %_struct_1
%200 = OpTypePointer CrossWorkgroup %_struct_2
%6 = OpTypeFunction %void %100 %200
%7 = OpFunction %void Pure %6
%8 = OpFunctionParameter %100
%9 = OpFunctionParameter %200
%10 = OpLabel
OpReturn
OpFunctionEnd
)";
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
opt::analysis::TypeManager manager(nullptr, context.get());
EXPECT_EQ(manager.GetType(100), manager.GetType(200));
}
TEST(TypeManager, IsomorphicCircularFwdPtr) {
const std::string text = R"(
OpCapability Addresses
OpCapability Kernel
%1 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %7 "test"
OpSource OpenCL_C 102000
OpDecorate %11 FuncParamAttr NoCapture
%11 = OpDecorationGroup
OpGroupDecorate %11 %8 %9
OpTypeForwardPointer %100 CrossWorkgroup
OpTypeForwardPointer %200 CrossWorkgroup
OpTypeForwardPointer %300 CrossWorkgroup
OpTypeForwardPointer %400 CrossWorkgroup
%void = OpTypeVoid
%int = OpTypeInt 32 0
%float = OpTypeFloat 32
%150 = OpTypeStruct %200 %int
%250 = OpTypeStruct %100 %float
%350 = OpTypeStruct %400 %int
%450 = OpTypeStruct %300 %float
%100 = OpTypePointer CrossWorkgroup %150
%200 = OpTypePointer CrossWorkgroup %250
%300 = OpTypePointer CrossWorkgroup %350
%400 = OpTypePointer CrossWorkgroup %450
%6 = OpTypeFunction %void %100 %200
%7 = OpFunction %void Pure %6
%8 = OpFunctionParameter %100
%9 = OpFunctionParameter %200
%10 = OpLabel
OpReturn
OpFunctionEnd
)";
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
opt::analysis::TypeManager manager(nullptr, context.get());
Type* p100 = manager.GetType(100);
Type* p300 = manager.GetType(300);
EXPECT_EQ(p100, p300);
Type* p200 = manager.GetType(200);
Type* p400 = manager.GetType(400);
EXPECT_EQ(p200, p400);
Type* p150 = manager.GetType(150);
Type* p350 = manager.GetType(350);
EXPECT_EQ(p150, p350);
Type* p250 = manager.GetType(250);
Type* p450 = manager.GetType(450);
EXPECT_EQ(p250, p450);
}
TEST(TypeManager, PartialIsomorphicFwdPtr) {
const std::string text = R"(
OpCapability Addresses
OpCapability Kernel
%1 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %7 "test"
OpSource OpenCL_C 102000
OpDecorate %11 FuncParamAttr NoCapture
%11 = OpDecorationGroup
OpGroupDecorate %11 %8 %9
OpTypeForwardPointer %100 CrossWorkgroup
OpTypeForwardPointer %200 CrossWorkgroup
%void = OpTypeVoid
%int = OpTypeInt 32 0
%float = OpTypeFloat 32
%150 = OpTypeStruct %200 %int
%250 = OpTypeStruct %200 %int
%100 = OpTypePointer CrossWorkgroup %150
%200 = OpTypePointer CrossWorkgroup %250
%6 = OpTypeFunction %void %100 %200
%7 = OpFunction %void Pure %6
%8 = OpFunctionParameter %100
%9 = OpFunctionParameter %200
%10 = OpLabel
OpReturn
OpFunctionEnd
)";
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
opt::analysis::TypeManager manager(nullptr, context.get());
Type* p100 = manager.GetType(100);
Type* p200 = manager.GetType(200);
EXPECT_EQ(p100->AsPointer()->pointee_type(),
p200->AsPointer()->pointee_type());
}
TEST(TypeManager, DecorationOnStruct) {
@ -270,7 +488,6 @@ TEST(TypeManager, DecorationOnStruct) {
opt::analysis::TypeManager manager(nullptr, context.get());
ASSERT_EQ(7u, manager.NumTypes());
ASSERT_EQ(0u, manager.NumForwardPointers());
// Make sure we get ids correct.
ASSERT_EQ("uint32", manager.GetType(5)->str());
ASSERT_EQ("float32", manager.GetType(6)->str());
@ -320,7 +537,6 @@ TEST(TypeManager, DecorationOnMember) {
opt::analysis::TypeManager manager(nullptr, context.get());
ASSERT_EQ(10u, manager.NumTypes());
ASSERT_EQ(0u, manager.NumForwardPointers());
// Make sure we get ids correct.
ASSERT_EQ("uint32", manager.GetType(8)->str());
ASSERT_EQ("float32", manager.GetType(9)->str());
@ -358,7 +574,6 @@ TEST(TypeManager, DecorationEmpty) {
opt::analysis::TypeManager manager(nullptr, context.get());
ASSERT_EQ(5u, manager.NumTypes());
ASSERT_EQ(0u, manager.NumForwardPointers());
// Make sure we get ids correct.
ASSERT_EQ("uint32", manager.GetType(3)->str());
ASSERT_EQ("float32", manager.GetType(4)->str());
@ -379,7 +594,6 @@ TEST(TypeManager, BeginEndForEmptyModule) {
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
opt::analysis::TypeManager manager(nullptr, context.get());
ASSERT_EQ(0u, manager.NumTypes());
ASSERT_EQ(0u, manager.NumForwardPointers());
EXPECT_EQ(manager.begin(), manager.end());
}
@ -396,7 +610,6 @@ TEST(TypeManager, BeginEnd) {
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
opt::analysis::TypeManager manager(nullptr, context.get());
ASSERT_EQ(5u, manager.NumTypes());
ASSERT_EQ(0u, manager.NumForwardPointers());
EXPECT_NE(manager.begin(), manager.end());
for (const auto& t : manager) {

View File

@ -73,8 +73,8 @@ TestMultipleInstancesOfTheSameType(Sampler);
TestMultipleInstancesOfTheSameType(SampledImage, image_t_.get());
TestMultipleInstancesOfTheSameType(Array, u32_t_.get(), 10);
TestMultipleInstancesOfTheSameType(RuntimeArray, u32_t_.get());
TestMultipleInstancesOfTheSameType(Struct, std::vector<Type*>{u32_t_.get(),
f64_t_.get()});
TestMultipleInstancesOfTheSameType(Struct, std::vector<const Type*>{
u32_t_.get(), f64_t_.get()});
TestMultipleInstancesOfTheSameType(Opaque, "testing rocks");
TestMultipleInstancesOfTheSameType(Pointer, u32_t_.get(), SpvStorageClassInput);
TestMultipleInstancesOfTheSameType(Function, u32_t_.get(),
@ -160,10 +160,11 @@ std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
auto* rav3s32 = types.back().get();
// Struct
types.emplace_back(new Struct(std::vector<Type*>{s32}));
types.emplace_back(new Struct(std::vector<Type*>{s32, f32}));
types.emplace_back(new Struct(std::vector<const Type*>{s32}));
types.emplace_back(new Struct(std::vector<const Type*>{s32, f32}));
auto* sts32f32 = types.back().get();
types.emplace_back(new Struct(std::vector<Type*>{u64, a42f32, rav3s32}));
types.emplace_back(
new Struct(std::vector<const Type*>{u64, a42f32, rav3s32}));
// Opaque
types.emplace_back(new Opaque(""));