Add spirv-opt pass to replace descriptor accesses based on variable indices (#4574)

This commit adds a spirv-opt pass that replaces accesses to a
descriptor array that use variable indices with accesses to
constant elements of the array.

Before:
```
%descriptor = OpVariable %_ptr_array_Image Uniform
...
%ac = OpAccessChain %_ptr_Image %descriptor %variable_index
(some image instructions using %ac)
```
After:
```
%descriptor = OpVariable %_ptr_array_Image Uniform
...
OpSwitch %variable_index 0 %case0 1 %case1 ...
...
%case0 = OpLabel
%ac = OpAccessChain %_ptr_Image %descriptor %uint_0
...
%case1 = OpLabel
%ac = OpAccessChain %_ptr_Image %descriptor %uint_1
...
(use OpPhi for value with concrete type)
```
This commit is contained in:
Jaebaek Seo 2021-10-26 17:20:58 -04:00 committed by GitHub
parent d78c1c4cd3
commit d997c83b10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 1265 additions and 115 deletions

View File

@ -100,6 +100,7 @@ SPVTOOLS_OPT_SRC_FILES := \
source/opt/debug_info_manager.cpp \
source/opt/def_use_manager.cpp \
source/opt/desc_sroa.cpp \
source/opt/desc_sroa_util.cpp \
source/opt/dominator_analysis.cpp \
source/opt/dominator_tree.cpp \
source/opt/eliminate_dead_constant_pass.cpp \
@ -157,6 +158,7 @@ SPVTOOLS_OPT_SRC_FILES := \
source/opt/relax_float_ops_pass.cpp \
source/opt/remove_duplicates_pass.cpp \
source/opt/remove_unused_interface_variables_pass.cpp \
source/opt/replace_desc_array_access_using_var_index.cpp \
source/opt/replace_invalid_opc.cpp \
source/opt/scalar_analysis.cpp \
source/opt/scalar_analysis_simplification.cpp \

View File

@ -597,6 +597,8 @@ static_library("spvtools_opt") {
"source/opt/def_use_manager.h",
"source/opt/desc_sroa.cpp",
"source/opt/desc_sroa.h",
"source/opt/desc_sroa_util.cpp",
"source/opt/desc_sroa_util.h",
"source/opt/dominator_analysis.cpp",
"source/opt/dominator_analysis.h",
"source/opt/dominator_tree.cpp",
@ -716,6 +718,8 @@ static_library("spvtools_opt") {
"source/opt/remove_duplicates_pass.h",
"source/opt/remove_unused_interface_variables_pass.cpp",
"source/opt/remove_unused_interface_variables_pass.h",
"source/opt/replace_desc_array_access_using_var_index.cpp",
"source/opt/replace_desc_array_access_using_var_index.h",
"source/opt/replace_invalid_opc.cpp",
"source/opt/replace_invalid_opc.h",
"source/opt/scalar_analysis.cpp",

View File

@ -833,6 +833,13 @@ Optimizer::PassToken CreateFixStorageClassPass();
// inclusive.
Optimizer::PassToken CreateGraphicsRobustAccessPass();
// Create a pass to replace a descriptor access using variable index.
// This pass replaces every access using a variable index to array variable
// |desc| that has a DescriptorSet and Binding decorations with a constant
// element of the array. In order to replace the access using a variable index
// with the constant element, it uses a switch statement.
Optimizer::PassToken CreateReplaceDescArrayAccessUsingVarIndexPass();
// Create descriptor scalar replacement pass.
// This pass replaces every array variable |desc| that has a DescriptorSet and
// Binding decorations with a new variable for each element of the array.

View File

@ -39,6 +39,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
debug_info_manager.h
def_use_manager.h
desc_sroa.h
desc_sroa_util.h
dominator_analysis.h
dominator_tree.h
eliminate_dead_constant_pass.h
@ -100,6 +101,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
relax_float_ops_pass.h
remove_duplicates_pass.h
remove_unused_interface_variables_pass.h
replace_desc_array_access_using_var_index.h
replace_invalid_opc.h
scalar_analysis.h
scalar_analysis_nodes.h
@ -148,6 +150,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
debug_info_manager.cpp
def_use_manager.cpp
desc_sroa.cpp
desc_sroa_util.cpp
dominator_analysis.cpp
dominator_tree.cpp
eliminate_dead_constant_pass.cpp
@ -205,6 +208,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
relax_float_ops_pass.cpp
remove_duplicates_pass.cpp
remove_unused_interface_variables_pass.cpp
replace_desc_array_access_using_var_index.cpp
replace_invalid_opc.cpp
scalar_analysis.cpp
scalar_analysis_simplification.cpp

View File

@ -490,6 +490,14 @@ void DecorationManager::ForEachDecoration(
});
}
bool DecorationManager::HasDecoration(uint32_t id, uint32_t decoration) {
bool has_decoration = false;
ForEachDecoration(id, decoration, [&has_decoration](const Instruction&) {
has_decoration = true;
});
return has_decoration;
}
bool DecorationManager::FindDecoration(
uint32_t id, uint32_t decoration,
std::function<bool(const Instruction&)> f) {

View File

@ -90,6 +90,10 @@ class DecorationManager {
bool AreDecorationsTheSame(const Instruction* inst1, const Instruction* inst2,
bool ignore_target) const;
// Returns whether a decoration instruction for |id| with decoration
// |decoration| exists or not.
bool HasDecoration(uint32_t id, uint32_t decoration);
// |f| is run on each decoration instruction for |id| with decoration
// |decoration|. Processed are all decorations which target |id| either
// directly or indirectly by Decoration Groups.

View File

@ -14,6 +14,7 @@
#include "source/opt/desc_sroa.h"
#include "source/opt/desc_sroa_util.h"
#include "source/util/string_utils.h"
namespace spvtools {
@ -25,7 +26,7 @@ Pass::Status DescriptorScalarReplacement::Process() {
std::vector<Instruction*> vars_to_kill;
for (Instruction& var : context()->types_values()) {
if (IsCandidate(&var)) {
if (descsroautil::IsDescriptorArray(context(), &var)) {
modified = true;
if (!ReplaceCandidate(&var)) {
return Status::Failure;
@ -41,72 +42,6 @@ Pass::Status DescriptorScalarReplacement::Process() {
return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
}
// Returns true if |var| is an OpVariable whose pointee type is an array or a
// struct that is not a structured buffer, and which has both a DescriptorSet
// and a Binding decoration.
bool DescriptorScalarReplacement::IsCandidate(Instruction* var) {
  if (var->opcode() != SpvOpVariable) return false;

  auto* def_use_mgr = context()->get_def_use_mgr();
  Instruction* pointer_type = def_use_mgr->GetDef(var->type_id());
  if (pointer_type->opcode() != SpvOpTypePointer) return false;

  // In-operand 1 of OpTypePointer is the pointee type.
  Instruction* pointee_type =
      def_use_mgr->GetDef(pointer_type->GetSingleWordInOperand(1));
  const auto pointee_opcode = pointee_type->opcode();
  const bool is_aggregate =
      pointee_opcode == SpvOpTypeArray || pointee_opcode == SpvOpTypeStruct;
  if (!is_aggregate) return false;

  // All structures with descriptor assignments must be replaced by variables,
  // one for each of their members - with the exceptions of buffers.
  if (IsTypeOfStructuredBuffer(pointee_type)) return false;

  auto* decoration_mgr = context()->get_decoration_mgr();
  const uint32_t var_id = var->result_id();

  bool found_descriptor_set = false;
  decoration_mgr->ForEachDecoration(
      var_id, SpvDecorationDescriptorSet,
      [&found_descriptor_set](const Instruction&) {
        found_descriptor_set = true;
      });
  if (!found_descriptor_set) return false;

  bool found_binding = false;
  decoration_mgr->ForEachDecoration(
      var_id, SpvDecorationBinding,
      [&found_binding](const Instruction&) { found_binding = true; });
  return found_binding;
}
// Returns true if |type| is an OpTypeStruct for which an Offset decoration is
// found, which is how a buffer struct is told apart from a struct of
// descriptors.
bool DescriptorScalarReplacement::IsTypeOfStructuredBuffer(
    const Instruction* type) const {
  if (type->opcode() != SpvOpTypeStruct) return false;

  // All buffers have offset decorations for members of their structure types.
  // This is how we distinguish it from a structure of descriptors.
  bool offset_seen = false;
  const auto note_offset = [&offset_seen](const Instruction&) {
    offset_seen = true;
  };
  context()->get_decoration_mgr()->ForEachDecoration(
      type->result_id(), SpvDecorationOffset, note_offset);
  return offset_seen;
}
bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) {
std::vector<Instruction*> access_chain_work_list;
std::vector<Instruction*> load_work_list;
@ -162,16 +97,15 @@ bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var,
return false;
}
uint32_t idx_id = use->GetSingleWordInOperand(1);
const analysis::Constant* idx_const =
context()->get_constant_mgr()->FindDeclaredConstant(idx_id);
if (idx_const == nullptr) {
const analysis::Constant* const_index =
descsroautil::GetAccessChainIndexAsConst(context(), use);
if (const_index == nullptr) {
context()->EmitErrorMessage("Variable cannot be replaced: invalid index",
use);
return false;
}
uint32_t idx = idx_const->GetU32();
uint32_t idx = const_index->GetU32();
uint32_t replacement_var = GetReplacementVariable(var, idx);
if (use->NumInOperands() == 2) {
@ -208,39 +142,12 @@ uint32_t DescriptorScalarReplacement::GetReplacementVariable(Instruction* var,
uint32_t idx) {
auto replacement_vars = replacement_variables_.find(var);
if (replacement_vars == replacement_variables_.end()) {
uint32_t ptr_type_id = var->type_id();
Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id);
assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
"Variable should be a pointer to an array or structure.");
uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1);
Instruction* pointee_type_inst = get_def_use_mgr()->GetDef(pointee_type_id);
const bool is_array = pointee_type_inst->opcode() == SpvOpTypeArray;
const bool is_struct = pointee_type_inst->opcode() == SpvOpTypeStruct;
assert((is_array || is_struct) &&
"Variable should be a pointer to an array or structure.");
// For arrays, each array element should be replaced with a new replacement
// variable
if (is_array) {
uint32_t array_len_id = pointee_type_inst->GetSingleWordInOperand(1);
const analysis::Constant* array_len_const =
context()->get_constant_mgr()->FindDeclaredConstant(array_len_id);
assert(array_len_const != nullptr && "Array length must be a constant.");
uint32_t array_len = array_len_const->GetU32();
replacement_vars = replacement_variables_
.insert({var, std::vector<uint32_t>(array_len, 0)})
.first;
}
// For structures, each member should be replaced with a new replacement
// variable
if (is_struct) {
const uint32_t num_members = pointee_type_inst->NumInOperands();
replacement_vars =
replacement_variables_
.insert({var, std::vector<uint32_t>(num_members, 0)})
.first;
}
uint32_t number_of_elements =
descsroautil::GetNumberOfElementsForArrayOrStruct(context(), var);
replacement_vars =
replacement_variables_
.insert({var, std::vector<uint32_t>(number_of_elements, 0)})
.first;
}
if (replacement_vars->second[idx] == 0) {
@ -377,7 +284,7 @@ uint32_t DescriptorScalarReplacement::GetNumBindingsUsedByType(
// The number of bindings consumed by a structure is the sum of the bindings
// used by its members.
if (type_inst->opcode() == SpvOpTypeStruct &&
!IsTypeOfStructuredBuffer(type_inst)) {
!descsroautil::IsTypeOfStructuredBuffer(context(), type_inst)) {
uint32_t sum = 0;
for (uint32_t i = 0; i < type_inst->NumInOperands(); i++)
sum += GetNumBindingsUsedByType(type_inst->GetSingleWordInOperand(i));

View File

@ -46,10 +46,6 @@ class DescriptorScalarReplacement : public Pass {
}
private:
// Returns true if |var| is an OpVariable instruction that represents a
// descriptor array. These are the variables that we want to replace.
bool IsCandidate(Instruction* var);
// Replaces all references to |var| by new variables, one for each element of
// the array |var|. The binding for the new variables corresponding to
// element i will be the binding of |var| plus i. Returns true if successful.
@ -93,11 +89,6 @@ class DescriptorScalarReplacement : public Pass {
// bindings used by its members.
uint32_t GetNumBindingsUsedByType(uint32_t type_id);
// Returns true if |type| is a type that could be used for a structured buffer
// as opposed to a type that would be used for a structure of resource
// descriptors.
bool IsTypeOfStructuredBuffer(const Instruction* type) const;
// A map from an OpVariable instruction to the set of variables that will be
// used to replace it. The entry |replacement_variables_[var][i]| is the id of
// a variable that will be used in the place of the the ith element of the

View File

@ -0,0 +1,117 @@
// Copyright (c) 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/opt/desc_sroa_util.h"
namespace spvtools {
namespace opt {
namespace {
// In-operand position of the first Indexes operand of an access chain
// instruction; GetFirstIndexOfAccessChain() reads the operand at this position.
const uint32_t kOpAccessChainInOperandIndexes = 1;
// Returns the length of array type |type|.
// Returns the length of array type |type|. |type| must be an OpTypeArray whose
// length id resolves to a declared constant.
uint32_t GetLengthOfArrayType(IRContext* context, Instruction* type) {
  assert(type->opcode() == SpvOpTypeArray && "type must be array");
  // In-operand 1 of OpTypeArray holds the id of the length.
  const uint32_t length_operand_id = type->GetSingleWordInOperand(1);
  auto* constant_mgr = context->get_constant_mgr();
  const analysis::Constant* length_const =
      constant_mgr->FindDeclaredConstant(length_operand_id);
  assert(length_const != nullptr);
  return length_const->GetU32();
}
} // namespace
namespace descsroautil {
// Returns true if |var| is an OpVariable that points to an array or to a
// struct that is not a structured buffer, and that has both a DescriptorSet
// and a Binding decoration. Despite the name, struct-typed variables are
// accepted as well.
bool IsDescriptorArray(IRContext* context, Instruction* var) {
  if (var->opcode() != SpvOpVariable) return false;

  auto* def_use_mgr = context->get_def_use_mgr();
  Instruction* pointer_type = def_use_mgr->GetDef(var->type_id());
  if (pointer_type->opcode() != SpvOpTypePointer) return false;

  // In-operand 1 of OpTypePointer is the pointee type.
  Instruction* pointee_type =
      def_use_mgr->GetDef(pointer_type->GetSingleWordInOperand(1));
  const auto pointee_opcode = pointee_type->opcode();
  if (pointee_opcode != SpvOpTypeArray && pointee_opcode != SpvOpTypeStruct) {
    return false;
  }

  // All structures with descriptor assignments must be replaced by variables,
  // one for each of their members - with the exceptions of buffers.
  if (IsTypeOfStructuredBuffer(context, pointee_type)) return false;

  // Both decorations are required; Binding is only looked up when
  // DescriptorSet is present.
  auto* decoration_mgr = context->get_decoration_mgr();
  const uint32_t var_id = var->result_id();
  return decoration_mgr->HasDecoration(var_id, SpvDecorationDescriptorSet) &&
         decoration_mgr->HasDecoration(var_id, SpvDecorationBinding);
}
// Returns true if |type| is an OpTypeStruct that has an Offset decoration.
bool IsTypeOfStructuredBuffer(IRContext* context, const Instruction* type) {
  const bool is_struct = type->opcode() == SpvOpTypeStruct;
  // All buffers have offset decorations for members of their structure types.
  // This is how we distinguish it from a structure of descriptors. The
  // decoration lookup is skipped entirely for non-struct types.
  return is_struct && context->get_decoration_mgr()->HasDecoration(
                          type->result_id(), SpvDecorationOffset);
}
// Returns the first index of |access_chain| as a declared constant. Returns
// nullptr when the access chain has no index operand or when the index is not
// a declared constant.
const analysis::Constant* GetAccessChainIndexAsConst(
    IRContext* context, Instruction* access_chain) {
  const bool has_index_operand = access_chain->NumInOperands() > 1;
  if (!has_index_operand) return nullptr;

  const uint32_t first_index_id = GetFirstIndexOfAccessChain(access_chain);
  return context->get_constant_mgr()->FindDeclaredConstant(first_index_id);
}
// Returns the id used as the first Indexes operand of |access_chain|, which
// must have more than one in-operand.
uint32_t GetFirstIndexOfAccessChain(Instruction* access_chain) {
  assert(access_chain->NumInOperands() > 1 &&
         "OpAccessChain does not have Indexes operand");
  const uint32_t first_index_id =
      access_chain->GetSingleWordInOperand(kOpAccessChainInOperandIndexes);
  return first_index_id;
}
// Returns the number of elements of the type |var| points to: the array length
// for an array, or the number of members for a struct. |var| must be a pointer
// to an array or to a struct.
uint32_t GetNumberOfElementsForArrayOrStruct(IRContext* context,
                                             Instruction* var) {
  auto* def_use_mgr = context->get_def_use_mgr();

  Instruction* ptr_type_inst = def_use_mgr->GetDef(var->type_id());
  assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
         "Variable should be a pointer to an array or structure.");

  // In-operand 1 of OpTypePointer is the pointee type.
  Instruction* pointee_type_inst =
      def_use_mgr->GetDef(ptr_type_inst->GetSingleWordInOperand(1));

  const bool is_array = pointee_type_inst->opcode() == SpvOpTypeArray;
  if (!is_array) {
    assert(pointee_type_inst->opcode() == SpvOpTypeStruct &&
           "Variable should be a pointer to an array or structure.");
    // Every in-operand of OpTypeStruct is a member type.
    return pointee_type_inst->NumInOperands();
  }
  return GetLengthOfArrayType(context, pointee_type_inst);
}
} // namespace descsroautil
} // namespace opt
} // namespace spvtools

View File

@ -0,0 +1,54 @@
// Copyright (c) 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_OPT_DESC_SROA_UTIL_H_
#define SOURCE_OPT_DESC_SROA_UTIL_H_
#include "source/opt/ir_context.h"
namespace spvtools {
namespace opt {
// Provides functions for the descriptor array SROA.
namespace descsroautil {
// Returns true if |var| is an OpVariable instruction that represents a
// descriptor array.
bool IsDescriptorArray(IRContext* context, Instruction* var);
// Returns true if |type| is a type that could be used for a structured buffer
// as opposed to a type that would be used for a structure of resource
// descriptors.
bool IsTypeOfStructuredBuffer(IRContext* context, const Instruction* type);
// Returns the first index of the OpAccessChain instruction |access_chain| as
// a constant. Returns nullptr if it is not a constant.
const analysis::Constant* GetAccessChainIndexAsConst(IRContext* context,
Instruction* access_chain);
// Returns the number of elements of an OpVariable instruction |var| whose type
// must be a pointer to an array or a struct.
uint32_t GetNumberOfElementsForArrayOrStruct(IRContext* context,
Instruction* var);
// Returns the first Indexes operand id of the OpAccessChain or
// OpInBoundsAccessChain instruction |access_chain|. The access chain must have
// at least 1 index.
uint32_t GetFirstIndexOfAccessChain(Instruction* access_chain);
} // namespace descsroautil
} // namespace opt
} // namespace spvtools
#endif // SOURCE_OPT_DESC_SROA_UTIL_H_

View File

@ -320,6 +320,8 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) {
RegisterPass(CreateCombineAccessChainsPass());
} else if (pass_name == "convert-local-access-chains") {
RegisterPass(CreateLocalAccessChainConvertPass());
} else if (pass_name == "replace-desc-array-access-using-var-index") {
RegisterPass(CreateReplaceDescArrayAccessUsingVarIndexPass());
} else if (pass_name == "descriptor-scalar-replacement") {
RegisterPass(CreateDescriptorScalarReplacementPass());
} else if (pass_name == "eliminate-dead-code-aggressive") {
@ -958,6 +960,11 @@ Optimizer::PassToken CreateGraphicsRobustAccessPass() {
MakeUnique<opt::GraphicsRobustAccessPass>());
}
// Creates a token wrapping a new ReplaceDescArrayAccessUsingVarIndex pass.
Optimizer::PassToken CreateReplaceDescArrayAccessUsingVarIndexPass() {
  auto pass = MakeUnique<opt::ReplaceDescArrayAccessUsingVarIndex>();
  return MakeUnique<Optimizer::PassToken::Impl>(std::move(pass));
}
Optimizer::PassToken CreateDescriptorScalarReplacementPass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::DescriptorScalarReplacement>());

View File

@ -66,6 +66,7 @@
#include "source/opt/relax_float_ops_pass.h"
#include "source/opt/remove_duplicates_pass.h"
#include "source/opt/remove_unused_interface_variables_pass.h"
#include "source/opt/replace_desc_array_access_using_var_index.h"
#include "source/opt/replace_invalid_opc.h"
#include "source/opt/scalar_replacement_pass.h"
#include "source/opt/set_spec_constant_default_value_pass.h"

View File

@ -0,0 +1,423 @@
// Copyright (c) 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/opt/replace_desc_array_access_using_var_index.h"
#include "source/opt/desc_sroa_util.h"
#include "source/opt/ir_builder.h"
#include "source/util/string_utils.h"
namespace spvtools {
namespace opt {
namespace {
// In-operand position of the first Indexes operand of an access chain.
const uint32_t kOpAccessChainInOperandIndexes = 1;
// In-operand position of the pointee type of OpTypePointer.
const uint32_t kOpTypePointerInOperandType = 1;
// In-operand position of the element type of OpTypeArray.
const uint32_t kOpTypeArrayInOperandType = 0;
// In-operand position of the first member type of OpTypeStruct.
const uint32_t kOpTypeStructInOperandMember = 0;

// Analyses handed to the InstructionBuilder instances created in this file.
// Declared const so that this file-local value cannot be modified after
// initialization.
const IRContext::Analysis kAnalysisDefUseAndInstrToBlockMapping =
    IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping;
// Returns the value stored under |key| in |map|. |key| must be present; this
// is checked with an assert.
uint32_t GetValueWithKeyExistenceCheck(
    uint32_t key, const std::unordered_map<uint32_t, uint32_t>& map) {
  const auto entry = map.find(key);
  const bool key_found = entry != map.end();
  assert(key_found && "Key does not exist");
  (void)key_found;  // Only read by the assert above.
  return entry->second;
}
} // namespace
// Runs the pass over every descriptor array variable in the module's types and
// values. Returns SuccessWithChange if any access was replaced, otherwise
// SuccessWithoutChange.
Pass::Status ReplaceDescArrayAccessUsingVarIndex::Process() {
  bool any_change = false;
  for (Instruction& inst : context()->types_values()) {
    if (!descsroautil::IsDescriptorArray(context(), &inst)) continue;
    if (ReplaceVariableAccessesWithConstantElements(&inst)) any_change = true;
  }
  return any_change ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
// Replaces every OpAccessChain / OpInBoundsAccessChain user of |var| whose
// first index is not a declared constant. Returns true if at least one access
// chain was replaced.
bool ReplaceDescArrayAccessUsingVarIndex::
    ReplaceVariableAccessesWithConstantElements(Instruction* var) const {
  // Collect the access chains up front; ReplaceAccessChain updates def-use
  // information while it runs.
  std::vector<Instruction*> work_list;
  get_def_use_mgr()->ForEachUser(var, [&work_list](Instruction* use) {
    switch (use->opcode()) {
      case SpvOpAccessChain:
      case SpvOpInBoundsAccessChain:
        work_list.push_back(use);
        break;
      default:
        break;
    }
  });

  bool updated = false;
  for (Instruction* access_chain : work_list) {
    // An access chain that carries only its base operand has no index to
    // replace. GetAccessChainIndexAsConst() also returns nullptr for it, so
    // without this guard it would be mistaken for a variable-index access and
    // trip the "does not have Indexes operand" assertion downstream.
    if (access_chain->NumInOperands() <= 1) continue;

    if (descsroautil::GetAccessChainIndexAsConst(context(), access_chain) !=
        nullptr) {
      continue;
    }
    ReplaceAccessChain(var, access_chain);
    updated = true;
  }
  // Note that we do not consider OpLoad and OpCompositeExtract because
  // OpCompositeExtract always has constant literals for indices.
  return updated;
}
// Replaces the variable index of |access_chain|, an access into |var|. When
// |var| has a single element the index is simply rewritten to constant 0;
// otherwise the users of the access chain are rewritten per element.
void ReplaceDescArrayAccessUsingVarIndex::ReplaceAccessChain(
    Instruction* var, Instruction* access_chain) const {
  const uint32_t number_of_elements =
      descsroautil::GetNumberOfElementsForArrayOrStruct(context(), var);
  assert(number_of_elements != 0 && "Number of element is 0");

  if (number_of_elements != 1) {
    ReplaceUsersOfAccessChain(access_chain, number_of_elements);
    return;
  }

  // Element 0 is the only possible target; refresh the use information of the
  // modified instruction afterwards.
  UseConstIndexForAccessChain(access_chain, 0);
  get_def_use_mgr()->AnalyzeInstUse(access_chain);
}
// For every final user of |access_chain| (see
// CollectRecursiveUsersWithConcreteType), gathers the instructions that must
// be cloned along with it and replaces the access with a switch over
// |number_of_elements| cases.
void ReplaceDescArrayAccessUsingVarIndex::ReplaceUsersOfAccessChain(
    Instruction* access_chain, uint32_t number_of_elements) const {
  std::vector<Instruction*> final_users;
  CollectRecursiveUsersWithConcreteType(access_chain, &final_users);

  for (Instruction* final_user : final_users) {
    const std::deque<Instruction*> required_insts =
        CollectRequiredImageInsts(final_user);
    ReplaceNonUniformAccessWithSwitchCase(final_user, access_chain,
                                          number_of_elements, required_insts);
  }
}
// Walks the users of |access_chain| transitively and appends to |final_users|
// each user that either has no result id or has a concrete result type. Users
// with a non-concrete result type are followed further instead of collected.
void ReplaceDescArrayAccessUsingVarIndex::CollectRecursiveUsersWithConcreteType(
    Instruction* access_chain, std::vector<Instruction*>* final_users) const {
  std::queue<Instruction*> pending;
  pending.push(access_chain);

  while (!pending.empty()) {
    Instruction* current = pending.front();
    pending.pop();

    get_def_use_mgr()->ForEachUser(
        current, [this, final_users, &pending](Instruction* user) {
          // TODO: Support Boolean type as well.
          const bool keep_following =
              user->HasResultId() && !IsConcreteType(user->type_id());
          if (keep_following) {
            pending.push(user);
          } else {
            final_users->push_back(user);
          }
        });
  }
}
// Returns |user_of_image_insts| together with the instructions it depends on,
// directly or transitively, that live in a basic block and have an image or
// image-pointer type. Each visited instruction is pushed to the front of the
// result, so |user_of_image_insts| ends up last.
std::deque<Instruction*>
ReplaceDescArrayAccessUsingVarIndex::CollectRequiredImageInsts(
    Instruction* user_of_image_insts) const {
  std::unordered_set<uint32_t> visited_ids;
  std::queue<Instruction*> pending;
  std::deque<Instruction*> required_image_insts;

  // Queues the definition of an operand id the first time that id is seen,
  // provided it is inside a block and has an image-related type.
  const auto enqueue_image_operand = [this, &visited_ids,
                                      &pending](uint32_t* id_ptr) {
    const bool first_visit = visited_ids.insert(*id_ptr).second;
    if (!first_visit) return;

    Instruction* def = get_def_use_mgr()->GetDef(*id_ptr);
    const bool inside_block = context()->get_instr_block(def) != nullptr;
    if (inside_block && HasImageOrImagePtrType(def)) {
      pending.push(def);
    }
  };

  // Seed the traversal with the user itself; it is always part of the result.
  pending.push(user_of_image_insts);
  while (!pending.empty()) {
    Instruction* current = pending.front();
    pending.pop();
    required_image_insts.push_front(current);
    current->ForEachInId(enqueue_image_operand);
  }
  return required_image_insts;
}
bool ReplaceDescArrayAccessUsingVarIndex::HasImageOrImagePtrType(
const Instruction* inst) const {
assert(inst != nullptr && inst->type_id() != 0 && "Invalid instruction");
return IsImageOrImagePtrType(get_def_use_mgr()->GetDef(inst->type_id()));
}
// Returns true if |type_inst| is an image, sampler, or sampled image type, or
// a pointer, array, or struct type that contains one (checked recursively
// through the pointee, the element type, or any struct member respectively).
bool ReplaceDescArrayAccessUsingVarIndex::IsImageOrImagePtrType(
    const Instruction* type_inst) const {
  switch (type_inst->opcode()) {
    case SpvOpTypeImage:
    case SpvOpTypeSampler:
    case SpvOpTypeSampledImage:
      return true;

    case SpvOpTypePointer: {
      Instruction* pointee_type_inst = get_def_use_mgr()->GetDef(
          type_inst->GetSingleWordInOperand(kOpTypePointerInOperandType));
      return IsImageOrImagePtrType(pointee_type_inst);
    }

    case SpvOpTypeArray: {
      Instruction* element_type_inst = get_def_use_mgr()->GetDef(
          type_inst->GetSingleWordInOperand(kOpTypeArrayInOperandType));
      return IsImageOrImagePtrType(element_type_inst);
    }

    case SpvOpTypeStruct: {
      for (uint32_t in_operand_idx = kOpTypeStructInOperandMember;
           in_operand_idx < type_inst->NumInOperands(); ++in_operand_idx) {
        // Index with the loop variable so that every member is inspected, not
        // only the first one.
        Instruction* member_type_inst = get_def_use_mgr()->GetDef(
            type_inst->GetSingleWordInOperand(in_operand_idx));
        if (IsImageOrImagePtrType(member_type_inst)) return true;
      }
      return false;
    }

    default:
      return false;
  }
}
bool ReplaceDescArrayAccessUsingVarIndex::IsConcreteType(
uint32_t type_id) const {
Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
if (type_inst->opcode() == SpvOpTypeInt ||
type_inst->opcode() == SpvOpTypeFloat) {
return true;
}
if (type_inst->opcode() == SpvOpTypeVector ||
type_inst->opcode() == SpvOpTypeMatrix ||
type_inst->opcode() == SpvOpTypeArray) {
return IsConcreteType(type_inst->GetSingleWordInOperand(0));
}
if (type_inst->opcode() == SpvOpTypeStruct) {
for (uint32_t i = 0; i < type_inst->NumInOperands(); ++i) {
if (!IsConcreteType(type_inst->GetSingleWordInOperand(i))) return false;
}
return true;
}
return false;
}
// Builds and returns a new basic block for the case in which the index of
// |access_chain| equals |element_index|. The block contains a clone of
// |access_chain| using the constant index, clones of |insts_to_be_cloned|, and
// a branch to |branch_target_id|. |old_ids_to_new_ids| receives the mapping
// from each original result id to the id of its clone in this block. The
// caller takes ownership of the returned block.
BasicBlock* ReplaceDescArrayAccessUsingVarIndex::CreateCaseBlock(
Instruction* access_chain, uint32_t element_index,
const std::deque<Instruction*>& insts_to_be_cloned,
uint32_t branch_target_id,
std::unordered_map<uint32_t, uint32_t>* old_ids_to_new_ids) const {
auto* case_block = CreateNewBlock();
// The constant-index access chain goes first so that the clones that follow
// can refer to it.
AddConstElementAccessToCaseBlock(case_block, access_chain, element_index,
old_ids_to_new_ids);
// |access_chain| is passed as the instruction to skip because its clone was
// already added above.
CloneInstsToBlock(case_block, access_chain, insts_to_be_cloned,
old_ids_to_new_ids);
AddBranchToBlock(case_block, branch_target_id);
// The clones still reference the original ids; rewrite their operands only
// after the id map has been fully populated.
UseNewIdsInBlock(case_block, *old_ids_to_new_ids);
return case_block;
}
// Appends a clone of every instruction in |insts_to_be_cloned|, except
// |inst_to_skip_cloning|, to |block|. A clone of an instruction with a result
// id gets a fresh id, and the old-to-new pair is recorded in
// |old_ids_to_new_ids|. Operands of the clones are left untouched here.
void ReplaceDescArrayAccessUsingVarIndex::CloneInstsToBlock(
    BasicBlock* block, Instruction* inst_to_skip_cloning,
    const std::deque<Instruction*>& insts_to_be_cloned,
    std::unordered_map<uint32_t, uint32_t>* old_ids_to_new_ids) const {
  for (Instruction* original : insts_to_be_cloned) {
    if (original != inst_to_skip_cloning) {
      std::unique_ptr<Instruction> clone(original->Clone(context()));
      if (original->HasResultId()) {
        const uint32_t fresh_id = context()->TakeNextId();
        clone->SetResultId(fresh_id);
        (*old_ids_to_new_ids)[original->result_id()] = fresh_id;
      }
      // Register the clone with the analyses before handing ownership to the
      // block.
      Instruction* clone_ptr = clone.get();
      get_def_use_mgr()->AnalyzeInstDefUse(clone_ptr);
      context()->set_instr_block(clone_ptr, block);
      block->AddInstruction(std::move(clone));
    }
  }
}
// Rewrites the input ids of every instruction in |block|: an id that is a key
// of |old_ids_to_new_ids| is replaced by its mapped value. The use information
// of each instruction is re-analyzed afterwards.
void ReplaceDescArrayAccessUsingVarIndex::UseNewIdsInBlock(
    BasicBlock* block,
    const std::unordered_map<uint32_t, uint32_t>& old_ids_to_new_ids) const {
  for (Instruction& inst : *block) {
    inst.ForEachInId([&old_ids_to_new_ids](uint32_t* id_ptr) {
      const auto mapping = old_ids_to_new_ids.find(*id_ptr);
      if (mapping != old_ids_to_new_ids.end()) {
        *id_ptr = mapping->second;
      }
    });
    get_def_use_mgr()->AnalyzeInstUse(&inst);
  }
}
// Replaces the use of |access_chain| by |access_chain_final_user| with an
// OpSwitch on the first index of |access_chain|. One case block is created per
// element index in [0, |number_of_elements|); each contains a clone of
// |access_chain| with a constant index and clones of |insts_to_be_cloned|. A
// default block is added as well. If |access_chain_final_user| has a result
// id, an OpPhi in the merge block selects among the per-case results (and a
// null constant for the default block), and all uses of the original result
// are redirected to that OpPhi.
void ReplaceDescArrayAccessUsingVarIndex::ReplaceNonUniformAccessWithSwitchCase(
Instruction* access_chain_final_user, Instruction* access_chain,
uint32_t number_of_elements,
const std::deque<Instruction*>& insts_to_be_cloned) const {
// Split the block containing |access_chain_final_user| at the instruction
// that follows it; |merge_block| is the block returned by the split.
auto* block = context()->get_instr_block(access_chain_final_user);
auto* merge_block = SeparateInstructionsIntoNewBlock(
block, access_chain_final_user->NextNode());
auto* function = block->GetParent();
// Add case blocks, one per element index. Each block gets its own id map
// because every case clones the same instructions under fresh ids.
std::vector<uint32_t> phi_operands;
std::vector<uint32_t> case_block_ids;
for (uint32_t idx = 0; idx < number_of_elements; ++idx) {
std::unordered_map<uint32_t, uint32_t> old_ids_to_new_ids_for_cloned_insts;
// Take ownership of the raw pointer returned by CreateCaseBlock().
std::unique_ptr<BasicBlock> case_block(CreateCaseBlock(
access_chain, idx, insts_to_be_cloned, merge_block->id(),
&old_ids_to_new_ids_for_cloned_insts));
case_block_ids.push_back(case_block->id());
function->InsertBasicBlockBefore(std::move(case_block), merge_block);
// Keep the operand for OpPhi: the id of this case's clone of the final user.
// A final user without a result id produces no value, so no phi is needed.
if (!access_chain_final_user->HasResultId()) continue;
uint32_t phi_operand =
GetValueWithKeyExistenceCheck(access_chain_final_user->result_id(),
old_ids_to_new_ids_for_cloned_insts);
phi_operands.push_back(phi_operand);
}
// Create default block. When the final user has a result id, this also
// appends a null constant to |phi_operands| for the default path.
std::unique_ptr<BasicBlock> default_block(
CreateDefaultBlock(access_chain_final_user->HasResultId(), &phi_operands,
merge_block->id()));
uint32_t default_block_id = default_block->id();
function->InsertBasicBlockBefore(std::move(default_block), merge_block);
// Create OpSwitch at the end of |block|, selecting on the variable index.
uint32_t access_chain_index_var_id =
descsroautil::GetFirstIndexOfAccessChain(access_chain);
AddSwitchForAccessChain(block, access_chain_index_var_id, default_block_id,
merge_block->id(), case_block_ids);
// Create phi instructions and redirect the uses of the original result.
if (!phi_operands.empty()) {
uint32_t phi_id = CreatePhiInstruction(merge_block, phi_operands,
case_block_ids, default_block_id);
context()->ReplaceAllUsesWith(access_chain_final_user->result_id(), phi_id);
}
// Replace OpPhi incoming block operand that uses |block| with |merge_block|
ReplacePhiIncomingBlock(block->id(), merge_block->id());
}
// Splits |block| at |separation_begin_inst| and returns the block produced by
// SplitBasicBlock(). If |separation_begin_inst| is not found in |block|, the
// split position is the end of |block|.
BasicBlock*
ReplaceDescArrayAccessUsingVarIndex::SeparateInstructionsIntoNewBlock(
    BasicBlock* block, Instruction* separation_begin_inst) const {
  auto split_point = block->begin();
  for (; split_point != block->end(); ++split_point) {
    if (&*split_point == separation_begin_inst) break;
  }

  const uint32_t new_block_id = context()->TakeNextId();
  return block->SplitBasicBlock(context(), new_block_id, split_point);
}
// Allocates a new basic block that contains only an OpLabel with a fresh id
// and registers the label with the def-use and instruction-to-block analyses.
// The caller takes ownership of the returned block.
BasicBlock* ReplaceDescArrayAccessUsingVarIndex::CreateNewBlock() const {
  const uint32_t label_id = context()->TakeNextId();
  std::unique_ptr<Instruction> label(
      new Instruction(context(), SpvOpLabel, 0, label_id, {}));
  auto* fresh_block = new BasicBlock(std::move(label));

  auto* label_inst = fresh_block->GetLabelInst();
  get_def_use_mgr()->AnalyzeInstDefUse(label_inst);
  context()->set_instr_block(label_inst, fresh_block);
  return fresh_block;
}
// Overwrites the first Indexes operand of |access_chain| with the id of the
// unsigned integer constant |const_element_idx|.
void ReplaceDescArrayAccessUsingVarIndex::UseConstIndexForAccessChain(
    Instruction* access_chain, uint32_t const_element_idx) const {
  auto* constant_mgr = context()->get_constant_mgr();
  const uint32_t index_constant_id =
      constant_mgr->GetUIntConst(const_element_idx);
  access_chain->SetInOperand(kOpAccessChainInOperandIndexes,
                             {index_constant_id});
}
// Appends to |case_block| a clone of |access_chain| whose first index is the
// constant |const_element_idx|. The clone gets a fresh result id, and the
// mapping from the id it was cloned with to the fresh id is recorded in
// |old_ids_to_new_ids|.
void ReplaceDescArrayAccessUsingVarIndex::AddConstElementAccessToCaseBlock(
    BasicBlock* case_block, Instruction* access_chain,
    uint32_t const_element_idx,
    std::unordered_map<uint32_t, uint32_t>* old_ids_to_new_ids) const {
  std::unique_ptr<Instruction> cloned_chain(access_chain->Clone(context()));
  UseConstIndexForAccessChain(cloned_chain.get(), const_element_idx);

  // Read the clone's current id before it is overwritten below.
  const uint32_t original_id = cloned_chain->result_id();
  const uint32_t fresh_id = context()->TakeNextId();
  (*old_ids_to_new_ids)[original_id] = fresh_id;
  cloned_chain->SetResultId(fresh_id);

  Instruction* cloned_chain_ptr = cloned_chain.get();
  get_def_use_mgr()->AnalyzeInstDefUse(cloned_chain_ptr);
  context()->set_instr_block(cloned_chain_ptr, case_block);
  case_block->AddInstruction(std::move(cloned_chain));
}
// Uses an InstructionBuilder on |parent_block| to add a branch to
// |branch_destination|.
void ReplaceDescArrayAccessUsingVarIndex::AddBranchToBlock(
    BasicBlock* parent_block, uint32_t branch_destination) const {
  InstructionBuilder branch_builder(context(), parent_block,
                                    kAnalysisDefUseAndInstrToBlockMapping);
  branch_builder.AddBranch(branch_destination);
}
// Creates the default block of the switch: a new block that only branches to
// |merge_block_id|. When |null_const_for_phi_is_needed| is true, a null
// constant of the same type as the first entry of |phi_operands| is appended
// to |phi_operands| as the value coming from the default block; in that case
// |phi_operands| must not be empty. The caller takes ownership of the returned
// block.
BasicBlock* ReplaceDescArrayAccessUsingVarIndex::CreateDefaultBlock(
    bool null_const_for_phi_is_needed, std::vector<uint32_t>* phi_operands,
    uint32_t merge_block_id) const {
  BasicBlock* default_block = CreateNewBlock();
  AddBranchToBlock(default_block, merge_block_id);

  if (null_const_for_phi_is_needed) {
    // Create null value for OpPhi
    const uint32_t sample_operand_id = phi_operands->front();
    Instruction* sample_def =
        context()->get_def_use_mgr()->GetDef(sample_operand_id);
    Instruction* null_const_inst = GetConstNull(sample_def->type_id());
    phi_operands->push_back(null_const_inst->result_id());
  }
  return default_block;
}
// Looks up the type registered under |type_id|, asks the constant manager for
// the constant of that type built from an empty list of literal words, and
// returns the instruction that defines it. |type_id| must not be 0.
Instruction* ReplaceDescArrayAccessUsingVarIndex::GetConstNull(
    uint32_t type_id) const {
  assert(type_id != 0 && "Result type is expected");
  auto* result_type = context()->get_type_mgr()->GetType(type_id);
  auto* constant_manager = context()->get_constant_mgr();
  auto* null_value = constant_manager->GetConstant(result_type, {});
  return constant_manager->GetDefiningInstruction(null_value);
}
// Adds an OpSwitch on |access_chain_index_var_id| through an
// InstructionBuilder positioned on |parent_block|. The i-th entry of
// |case_block_ids| becomes the target of case literal i, |default_id| is the
// default target, and |merge_id| is passed on as the merge block id.
void ReplaceDescArrayAccessUsingVarIndex::AddSwitchForAccessChain(
    BasicBlock* parent_block, uint32_t access_chain_index_var_id,
    uint32_t default_id, uint32_t merge_id,
    const std::vector<uint32_t>& case_block_ids) const {
  // Pair each case literal (0, 1, 2, ...) with the id of its case block.
  std::vector<std::pair<Operand::OperandData, uint32_t>> switch_cases;
  uint32_t case_literal = 0;
  for (const uint32_t case_block_id : case_block_ids) {
    switch_cases.emplace_back(Operand::OperandData{case_literal},
                              case_block_id);
    ++case_literal;
  }

  InstructionBuilder switch_builder{context(), parent_block,
                                    kAnalysisDefUseAndInstrToBlockMapping};
  switch_builder.AddSwitch(access_chain_index_var_id, default_id,
                           switch_cases, merge_id);
}
// Adds an OpPhi through an InstructionBuilder positioned at the first
// instruction of |parent_block| and returns its result id. The i-th value of
// |phi_operands| is paired with the i-th id of |case_block_ids|; the last
// value of |phi_operands| is paired with |default_block_id|. The result type
// of the phi is the type of the instruction defining the first phi operand.
uint32_t ReplaceDescArrayAccessUsingVarIndex::CreatePhiInstruction(
    BasicBlock* parent_block, const std::vector<uint32_t>& phi_operands,
    const std::vector<uint32_t>& case_block_ids,
    uint32_t default_block_id) const {
  assert(case_block_ids.size() + 1 == phi_operands.size() &&
         "Number of Phi operands must be exactly 1 bigger than the one of case "
         "blocks");

  // Flattened (value id, incoming block id) pairs: all cases first, then the
  // default block.
  std::vector<uint32_t> value_block_pairs;
  const size_t number_of_cases = case_block_ids.size();
  for (size_t case_idx = 0; case_idx != number_of_cases; ++case_idx) {
    value_block_pairs.insert(value_block_pairs.end(),
                             {phi_operands[case_idx], case_block_ids[case_idx]});
  }
  value_block_pairs.insert(value_block_pairs.end(),
                           {phi_operands.back(), default_block_id});

  const uint32_t result_type_id =
      context()->get_def_use_mgr()->GetDef(phi_operands[0])->type_id();
  InstructionBuilder phi_builder{context(), &*parent_block->begin(),
                                 kAnalysisDefUseAndInstrToBlockMapping};
  return phi_builder.AddPhi(result_type_id, value_block_pairs)->result_id();
}
void ReplaceDescArrayAccessUsingVarIndex::ReplacePhiIncomingBlock(
uint32_t old_incoming_block_id, uint32_t new_incoming_block_id) const {
context()->ReplaceAllUsesWithPredicate(
old_incoming_block_id, new_incoming_block_id,
[](Instruction* use) { return use->opcode() == SpvOpPhi; });
}
} // namespace opt
} // namespace spvtools

View File

@ -0,0 +1,204 @@
// Copyright (c) 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_OPT_REPLACE_DESC_VAR_INDEX_ACCESS_H_
#define SOURCE_OPT_REPLACE_DESC_VAR_INDEX_ACCESS_H_
#include <cstdio>
#include <memory>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "source/opt/function.h"
#include "source/opt/pass.h"
#include "source/opt/type_manager.h"
namespace spvtools {
namespace opt {
// Pass that replaces accesses to descriptor arrays based on variable indices
// with accesses to constant elements, selected by an OpSwitch on the variable
// index. See optimizer.hpp for documentation.
class ReplaceDescArrayAccessUsingVarIndex : public Pass {
public:
ReplaceDescArrayAccessUsingVarIndex() {}
// Returns the name of this pass.
const char* name() const override {
return "replace-desc-array-access-using-var-index";
}
// Entry point of the pass; the definition is not in this header.
Status Process() override;
// Reports the def-use, instruction-to-block mapping, constant and type
// analyses as preserved by this pass.
IRContext::Analysis GetPreservedAnalyses() override {
return IRContext::kAnalysisDefUse |
IRContext::kAnalysisInstrToBlockMapping |
IRContext::kAnalysisConstants | IRContext::kAnalysisTypes;
}
private:
// Replaces all accesses to |var| using variable indices with constant
// elements of the array |var|. Creates switch-case statements to determine
// the value of the variable index for all the possible cases. Returns
// whether replacement is done or not.
bool ReplaceVariableAccessesWithConstantElements(Instruction* var) const;
// Replaces the OpAccessChain or OpInBoundsAccessChain instruction |use| that
// uses the descriptor variable |var| with the OpAccessChain or
// OpInBoundsAccessChain instruction with a constant Indexes operand.
void ReplaceAccessChain(Instruction* var, Instruction* use) const;
// Updates the first Indexes operand of the OpAccessChain or
// OpInBoundsAccessChain instruction |access_chain| to let it use a constant
// index |const_element_idx|.
void UseConstIndexForAccessChain(Instruction* access_chain,
uint32_t const_element_idx) const;
// Replaces users of the OpAccessChain or OpInBoundsAccessChain instruction
// |access_chain| that accesses an array descriptor variable using variable
// indices with constant elements. |number_of_elements| is the number
// of array elements.
void ReplaceUsersOfAccessChain(Instruction* access_chain,
uint32_t number_of_elements) const;
// Puts all the recursive users of |access_chain| with concrete result types
// or the ones without a result id in |final_users|.
void CollectRecursiveUsersWithConcreteType(
Instruction* access_chain, std::vector<Instruction*>* final_users) const;
// Recursively collects the operands of |user_of_image_insts| (and operands
// of the operands) whose result types are images/samplers or pointers/arrays/
// structs of them and returns them.
std::deque<Instruction*> CollectRequiredImageInsts(
Instruction* user_of_image_insts) const;
// Returns whether the result type of |inst| is an image/sampler or a pointer
// to an image or sampler.
bool HasImageOrImagePtrType(const Instruction* inst) const;
// Returns whether |type_inst| is an image/sampler or a pointer/array/struct
// of images or samplers.
bool IsImageOrImagePtrType(const Instruction* type_inst) const;
// Returns whether the type with |type_id| is a concrete type or not.
bool IsConcreteType(uint32_t type_id) const;
// Replaces the non-uniform access to a descriptor variable
// |access_chain_final_user| with OpSwitch instruction and case blocks. Each
// case block will contain a clone of |access_chain| and clones of
// |non_uniform_accesses_to_clone| that are recursively used by
// |access_chain_final_user|. The clone of |access_chain| (or
// OpInBoundsAccessChain) will have a constant index for its first index. The
// OpSwitch instruction will have the cases for the variable index of
// |access_chain| from 0 to |number_of_elements| - 1.
void ReplaceNonUniformAccessWithSwitchCase(
Instruction* access_chain_final_user, Instruction* access_chain,
uint32_t number_of_elements,
const std::deque<Instruction*>& non_uniform_accesses_to_clone) const;
// Creates and returns a new basic block that contains all instructions of
// |block| after |separation_begin_inst|. The new basic block is added to the
// function in this method.
BasicBlock* SeparateInstructionsIntoNewBlock(
BasicBlock* block, Instruction* separation_begin_inst) const;
// Creates and returns a new block.
BasicBlock* CreateNewBlock() const;
// Returns the first operand id of the OpAccessChain or OpInBoundsAccessChain
// instruction |access_chain|.
uint32_t GetFirstIndexOfAccessChain(Instruction* access_chain) const;
// Adds a clone of the OpAccessChain or OpInBoundsAccessChain instruction
// |access_chain| to |case_block|. The clone of |access_chain| will use
// |const_element_idx| for its first index. |old_ids_to_new_ids| keeps the
// mapping from the result id of |access_chain| to the result id of its
// clone.
void AddConstElementAccessToCaseBlock(
BasicBlock* case_block, Instruction* access_chain,
uint32_t const_element_idx,
std::unordered_map<uint32_t, uint32_t>* old_ids_to_new_ids) const;
// Clones all instructions in |insts_to_be_cloned| and puts them in |block|.
// |old_ids_to_new_ids| keeps the mapping from the result id of each
// instruction of |insts_to_be_cloned| to the result id of its clone.
// NOTE(review): |inst_to_skip_cloning| is presumably left out of the cloning;
// confirm against the definition.
void CloneInstsToBlock(
BasicBlock* block, Instruction* inst_to_skip_cloning,
const std::deque<Instruction*>& insts_to_be_cloned,
std::unordered_map<uint32_t, uint32_t>* old_ids_to_new_ids) const;
// Adds OpBranch to |branch_destination| at the end of |parent_block|.
void AddBranchToBlock(BasicBlock* parent_block,
uint32_t branch_destination) const;
// Replaces in-operands of all instructions in the basic block |block| using
// |old_ids_to_new_ids|. It conducts the replacement only if the in-operand
// id is a key of |old_ids_to_new_ids|.
void UseNewIdsInBlock(
BasicBlock* block,
const std::unordered_map<uint32_t, uint32_t>& old_ids_to_new_ids) const;
// Creates a case block for |element_index| case. It adds clones of
// |insts_to_be_cloned| and a clone of |access_chain| with |element_index| as
// its first index. The termination instruction of the created case block will
// be a branch to |branch_target_id|. Puts old ids to new ids map for the
// cloned instructions in |old_ids_to_new_ids|.
BasicBlock* CreateCaseBlock(
Instruction* access_chain, uint32_t element_index,
const std::deque<Instruction*>& insts_to_be_cloned,
uint32_t branch_target_id,
std::unordered_map<uint32_t, uint32_t>* old_ids_to_new_ids) const;
// Creates a default block for switch-case statement that has only a single
// instruction OpBranch whose target is a basic block with |merge_block_id|.
// If |null_const_for_phi_is_needed| is true, gets or creates a default null
// constant value for a phi instruction whose operands are |phi_operands| and
// puts it in |phi_operands|.
BasicBlock* CreateDefaultBlock(bool null_const_for_phi_is_needed,
std::vector<uint32_t>* phi_operands,
uint32_t merge_block_id) const;
// Creates and adds an OpSwitch used for the selection of OpAccessChain whose
// first Indexes operand is |access_chain_index_var_id|. The OpSwitch will be
// added at the end of |parent_block|. It will jump to |default_id| for the
// default case and jumps to one of case blocks whose ids are |case_block_ids|
// if |access_chain_index_var_id| matches the case number. |merge_id| is the
// merge block id.
void AddSwitchForAccessChain(
BasicBlock* parent_block, uint32_t access_chain_index_var_id,
uint32_t default_id, uint32_t merge_id,
const std::vector<uint32_t>& case_block_ids) const;
// Creates a phi instruction with |phi_operands| as values and
// |case_block_ids| and |default_block_id| as incoming blocks. The size of
// |phi_operands| must be exactly 1 larger than the size of |case_block_ids|.
// The last element of |phi_operands| will be used for |default_block_id|. It
// adds the phi instruction to the beginning of |parent_block|. Returns the
// result id of the phi instruction.
uint32_t CreatePhiInstruction(BasicBlock* parent_block,
const std::vector<uint32_t>& phi_operands,
const std::vector<uint32_t>& case_block_ids,
uint32_t default_block_id) const;
// Replaces the incoming block operand of OpPhi instructions with
// |new_incoming_block_id| if the incoming block operand is
// |old_incoming_block_id|.
void ReplacePhiIncomingBlock(uint32_t old_incoming_block_id,
uint32_t new_incoming_block_id) const;
// Returns an OpConstantNull instruction whose result type id is |type_id|.
Instruction* GetConstNull(uint32_t type_id) const;
};
} // namespace opt
} // namespace spvtools
#endif // SOURCE_OPT_REPLACE_DESC_VAR_INDEX_ACCESS_H_

View File

@ -85,6 +85,7 @@ add_spvtools_unittest(TARGET opt
remove_unused_interface_variables_test.cpp
register_liveness.cpp
relax_float_ops_test.cpp
replace_desc_array_access_using_var_index_test.cpp
replace_invalid_opc_test.cpp
scalar_analysis.cpp
scalar_replacement_test.cpp

View File

@ -0,0 +1,411 @@
// Copyright (c) 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "gmock/gmock.h"
#include "test/opt/assembly_builder.h"
#include "test/opt/pass_fixture.h"
#include "test/opt/pass_utils.h"
namespace spvtools {
namespace opt {
namespace {
using ReplaceDescArrayAccessUsingVarIndexTest = PassTest<::testing::Test>;
// Checks the pass on a 3-element image array (%Tex0) accessed through the
// variable index %30 inside one branch of a conditional. The expected output
// has an OpSwitch on the index with cases 0, 1 and 2 plus a default block for
// both the OpImageWrite and the OpImageSampleImplicitLod users. The sampled
// values are merged with an OpPhi whose default incoming value is the v4float
// null constant, and the original OpPhi of the conditional then takes that
// new phi from the switch merge block.
TEST_F(ReplaceDescArrayAccessUsingVarIndexTest,
ReplaceAccessChainToTextureArray) {
const std::string text = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %psmain "psmain" %gl_FragCoord %in_var_INSTANCEID %out_var_SV_TARGET
OpExecutionMode %psmain OriginUpperLeft
OpSource HLSL 600
OpName %type_sampler "type.sampler"
OpName %Sampler0 "Sampler0"
OpName %type_2d_image "type.2d.image"
OpName %Tex0 "Tex0"
OpName %in_var_INSTANCEID "in.var.INSTANCEID"
OpName %out_var_SV_TARGET "out.var.SV_TARGET"
OpName %psmain "psmain"
OpName %type_sampled_image "type.sampled.image"
OpDecorate %gl_FragCoord BuiltIn FragCoord
OpDecorate %in_var_INSTANCEID Flat
OpDecorate %in_var_INSTANCEID Location 0
OpDecorate %out_var_SV_TARGET Location 0
OpDecorate %Sampler0 DescriptorSet 0
OpDecorate %Sampler0 Binding 1
OpDecorate %Tex0 DescriptorSet 0
OpDecorate %Tex0 Binding 2
%bool = OpTypeBool
%type_sampler = OpTypeSampler
%_ptr_UniformConstant_type_sampler = OpTypePointer UniformConstant %type_sampler
%uint = OpTypeInt 32 0
%uint_3 = OpConstant %uint 3
%float = OpTypeFloat 32
%type_2d_image = OpTypeImage %float 2D 2 0 0 0 Unknown
%_arr_type_2d_image_uint_3 = OpTypeArray %type_2d_image %uint_3
%_ptr_UniformConstant__arr_type_2d_image_uint_3 = OpTypePointer UniformConstant %_arr_type_2d_image_uint_3
%v4float = OpTypeVector %float 4
%_ptr_Input_v4float = OpTypePointer Input %v4float
%_ptr_Input_uint = OpTypePointer Input %uint
%_ptr_Output_v4float = OpTypePointer Output %v4float
%void = OpTypeVoid
%21 = OpTypeFunction %void
%_ptr_UniformConstant_type_2d_image = OpTypePointer UniformConstant %type_2d_image
%v2float = OpTypeVector %float 2
%v2uint = OpTypeVector %uint 2
%uint_0 = OpConstant %uint 0
%uint_1 = OpConstant %uint 1
%27 = OpConstantComposite %v2uint %uint_0 %uint_1
%type_sampled_image = OpTypeSampledImage %type_2d_image
%Sampler0 = OpVariable %_ptr_UniformConstant_type_sampler UniformConstant
%Tex0 = OpVariable %_ptr_UniformConstant__arr_type_2d_image_uint_3 UniformConstant
%gl_FragCoord = OpVariable %_ptr_Input_v4float Input
%in_var_INSTANCEID = OpVariable %_ptr_Input_uint Input
%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output
%uint_2 = OpConstant %uint 2
%66 = OpConstantNull %v4float
; CHECK: [[null_value:%\w+]] = OpConstantNull %v4float
%psmain = OpFunction %void None %21
%39 = OpLabel
%29 = OpLoad %v4float %gl_FragCoord
%30 = OpLoad %uint %in_var_INSTANCEID
%37 = OpIEqual %bool %30 %uint_2
OpSelectionMerge %38 None
OpBranchConditional %37 %28 %40
; CHECK: [[var_index:%\w+]] = OpLoad %uint %in_var_INSTANCEID
; CHECK: OpSelectionMerge [[cond_branch_merge:%\w+]] None
; CHECK: OpBranchConditional {{%\w+}} {{%\w+}} [[bb_cond_br:%\w+]]
%28 = OpLabel
%31 = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %30
%32 = OpLoad %type_2d_image %31
OpImageWrite %32 %27 %29
; CHECK: OpSelectionMerge [[merge:%\w+]] None
; CHECK: OpSwitch [[var_index]] [[default:%\w+]] 0 [[case0:%\w+]] 1 [[case1:%\w+]] 2 [[case2:%\w+]]
; CHECK: [[case0]] = OpLabel
; CHECK: OpAccessChain
; CHECK: OpLoad
; CHECK: OpImageWrite
; CHECK: OpBranch [[merge]]
; CHECK: [[case1]] = OpLabel
; CHECK: OpAccessChain
; CHECK: OpLoad
; CHECK: OpImageWrite
; CHECK: OpBranch [[merge]]
; CHECK: [[case2]] = OpLabel
; CHECK: OpAccessChain
; CHECK: OpLoad
; CHECK: OpImageWrite
; CHECK: OpBranch [[merge]]
; CHECK: [[default]] = OpLabel
; CHECK: OpBranch [[merge]]
; CHECK: [[merge]] = OpLabel
%33 = OpLoad %type_sampler %Sampler0
%34 = OpVectorShuffle %v2float %29 %29 0 1
%35 = OpSampledImage %type_sampled_image %32 %33
%36 = OpImageSampleImplicitLod %v4float %35 %34 None
; CHECK: OpSelectionMerge [[merge:%\w+]] None
; CHECK: OpSwitch [[var_index]] [[default:%\w+]] 0 [[case0:%\w+]] 1 [[case1:%\w+]] 2 [[case2:%\w+]]
; CHECK: [[case0]] = OpLabel
; CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_0
; CHECK: [[sam:%\w+]] = OpLoad %type_sampler %Sampler0
; CHECK: [[img:%\w+]] = OpLoad %type_2d_image [[ac]]
; CHECK: [[sampledImg:%\w+]] = OpSampledImage %type_sampled_image [[img]] [[sam]]
; CHECK: [[value0:%\w+]] = OpImageSampleImplicitLod %v4float [[sampledImg]]
; CHECK: OpBranch [[merge]]
; CHECK: [[case1]] = OpLabel
; CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_1
; CHECK: [[sam:%\w+]] = OpLoad %type_sampler %Sampler0
; CHECK: [[img:%\w+]] = OpLoad %type_2d_image [[ac]]
; CHECK: [[sampledImg:%\w+]] = OpSampledImage %type_sampled_image [[img]] [[sam]]
; CHECK: [[value1:%\w+]] = OpImageSampleImplicitLod %v4float [[sampledImg]]
; CHECK: OpBranch [[merge]]
; CHECK: [[case2]] = OpLabel
; CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_2
; CHECK: [[sam:%\w+]] = OpLoad %type_sampler %Sampler0
; CHECK: [[img:%\w+]] = OpLoad %type_2d_image [[ac]]
; CHECK: [[sampledImg:%\w+]] = OpSampledImage %type_sampled_image [[img]] [[sam]]
; CHECK: [[value2:%\w+]] = OpImageSampleImplicitLod %v4float [[sampledImg]]
; CHECK: OpBranch [[merge]]
; CHECK: [[default]] = OpLabel
; CHECK: OpBranch [[merge]]
; CHECK: [[merge]] = OpLabel
; CHECK: [[phi0:%\w+]] = OpPhi %v4float [[value0]] [[case0]] [[value1]] [[case1]] [[value2]] [[case2]] [[null_value]] [[default]]
OpBranch %38
%40 = OpLabel
OpBranch %38
%38 = OpLabel
%41 = OpPhi %v4float %36 %28 %29 %40
; CHECK: OpBranch [[cond_branch_merge]]
; CHECK: [[bb_cond_br]] = OpLabel
; CHECK: OpBranch [[cond_branch_merge]]
; CHECK: [[cond_branch_merge]] = OpLabel
; CHECK: [[phi1:%\w+]] = OpPhi %v4float [[phi0]] [[merge]] {{%\w+}} [[bb_cond_br]]
; CHECK: OpStore {{%\w+}} [[phi1]]
OpStore %out_var_SV_TARGET %41
OpReturn
OpFunctionEnd
)";
// Runs the pass on |text| and matches the output against the CHECK lines.
SinglePassRunAndMatch<ReplaceDescArrayAccessUsingVarIndex>(text, true);
}
// Checks the pass when both a 2-element image array (%Tex0) and a 2-element
// sampler array (%Sampler0) are accessed with the same variable index %30.
// For the sampling instruction the expected output nests the switches: each
// case of the sampler switch contains a texture switch with cases 0 and 1 and
// a default block. Every merge block carries an OpPhi whose default incoming
// value is the v4float null constant.
TEST_F(ReplaceDescArrayAccessUsingVarIndexTest,
ReplaceAccessChainToTextureArrayAndSamplerArray) {
const std::string text = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %psmain "psmain" %gl_FragCoord %in_var_INSTANCEID %out_var_SV_TARGET
OpExecutionMode %psmain OriginUpperLeft
OpSource HLSL 600
OpName %type_sampler "type.sampler"
OpName %Sampler0 "Sampler0"
OpName %type_2d_image "type.2d.image"
OpName %Tex0 "Tex0"
OpName %in_var_INSTANCEID "in.var.INSTANCEID"
OpName %out_var_SV_TARGET "out.var.SV_TARGET"
OpName %psmain "psmain"
OpName %type_sampled_image "type.sampled.image"
OpDecorate %gl_FragCoord BuiltIn FragCoord
OpDecorate %in_var_INSTANCEID Flat
OpDecorate %in_var_INSTANCEID Location 0
OpDecorate %out_var_SV_TARGET Location 0
OpDecorate %Sampler0 DescriptorSet 0
OpDecorate %Sampler0 Binding 1
OpDecorate %Tex0 DescriptorSet 0
OpDecorate %Tex0 Binding 2
%type_sampler = OpTypeSampler
%uint = OpTypeInt 32 0
%uint_2 = OpConstant %uint 2
%_ptr_UniformConstant_type_sampler = OpTypePointer UniformConstant %type_sampler
%_arr_type_sampler_uint_2 = OpTypeArray %type_sampler %uint_2
%_ptr_UniformConstant__arr_type_sampler_uint_2 = OpTypePointer UniformConstant %_arr_type_sampler_uint_2
%float = OpTypeFloat 32
%type_2d_image = OpTypeImage %float 2D 2 0 0 0 Unknown
%_arr_type_2d_image_uint_2 = OpTypeArray %type_2d_image %uint_2
%_ptr_UniformConstant__arr_type_2d_image_uint_2 = OpTypePointer UniformConstant %_arr_type_2d_image_uint_2
%v4float = OpTypeVector %float 4
%_ptr_Input_v4float = OpTypePointer Input %v4float
%_ptr_Input_uint = OpTypePointer Input %uint
%_ptr_Output_v4float = OpTypePointer Output %v4float
%void = OpTypeVoid
%21 = OpTypeFunction %void
%_ptr_UniformConstant_type_2d_image = OpTypePointer UniformConstant %type_2d_image
%v2float = OpTypeVector %float 2
%v2uint = OpTypeVector %uint 2
%uint_0 = OpConstant %uint 0
%uint_1 = OpConstant %uint 1
%27 = OpConstantComposite %v2uint %uint_0 %uint_1
%type_sampled_image = OpTypeSampledImage %type_2d_image
%Sampler0 = OpVariable %_ptr_UniformConstant__arr_type_sampler_uint_2 UniformConstant
%Tex0 = OpVariable %_ptr_UniformConstant__arr_type_2d_image_uint_2 UniformConstant
%gl_FragCoord = OpVariable %_ptr_Input_v4float Input
%in_var_INSTANCEID = OpVariable %_ptr_Input_uint Input
%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output
%66 = OpConstantNull %v4float
%psmain = OpFunction %void None %21
%28 = OpLabel
%29 = OpLoad %v4float %gl_FragCoord
%30 = OpLoad %uint %in_var_INSTANCEID
%31 = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %30
%32 = OpLoad %type_2d_image %31
OpImageWrite %32 %27 %29
; CHECK: [[null_value:%\w+]] = OpConstantNull %v4float
; CHECK: [[var_index:%\w+]] = OpLoad %uint %in_var_INSTANCEID
; CHECK: OpSelectionMerge [[merge:%\w+]] None
; CHECK: OpSwitch [[var_index]] [[default:%\w+]] 0 [[case0:%\w+]] 1 [[case1:%\w+]]
; CHECK: [[case0]] = OpLabel
; CHECK: OpAccessChain
; CHECK: OpLoad
; CHECK: OpImageWrite
; CHECK: OpBranch [[merge]]
; CHECK: [[case1]] = OpLabel
; CHECK: OpAccessChain
; CHECK: OpLoad
; CHECK: OpImageWrite
; CHECK: OpBranch [[merge]]
; CHECK: [[default]] = OpLabel
; CHECK: OpBranch [[merge]]
; CHECK: [[merge]] = OpLabel
%33 = OpAccessChain %_ptr_UniformConstant_type_sampler %Sampler0 %30
%37 = OpLoad %type_sampler %33
%34 = OpVectorShuffle %v2float %29 %29 0 1
%35 = OpSampledImage %type_sampled_image %32 %37
%36 = OpImageSampleImplicitLod %v4float %35 %34 None
; SPIR-V instructions to be replaced (will be killed by ADCE)
; CHECK: OpSelectionMerge
; CHECK: OpSwitch
; CHECK: OpSelectionMerge [[merge_sampler:%\w+]] None
; CHECK: OpSwitch [[var_index]] [[default_sampler:%\w+]] 0 [[case_sampler0:%\w+]] 1 [[case_sampler1:%\w+]]
; CHECK: [[case_sampler0]] = OpLabel
; CHECK: OpSelectionMerge [[merge_texture0:%\w+]] None
; CHECK: OpSwitch [[var_index]] [[default_texture:%\w+]] 0 [[case_texture0:%\w+]] 1 [[case_texture1:%\w+]]
; CHECK: [[case_texture0]] = OpLabel
; CHECK: [[pt0:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_0
; CHECK: [[ps0:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_sampler %Sampler0 %uint_0
; CHECK: [[s0:%\w+]] = OpLoad %type_sampler [[ps0]]
; CHECK: [[t0:%\w+]] = OpLoad %type_2d_image [[pt0]]
; CHECK: [[sampledImg0:%\w+]] = OpSampledImage %type_sampled_image [[t0]] [[s0]]
; CHECK: [[value0:%\w+]] = OpImageSampleImplicitLod %v4float [[sampledImg0]]
; CHECK: OpBranch [[merge_texture0]]
; CHECK: [[case_texture1]] = OpLabel
; CHECK: [[pt1:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_1
; CHECK: [[ps0:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_sampler %Sampler0 %uint_0
; CHECK: [[s0:%\w+]] = OpLoad %type_sampler [[ps0]]
; CHECK: [[t1:%\w+]] = OpLoad %type_2d_image [[pt1]]
; CHECK: [[sampledImg1:%\w+]] = OpSampledImage %type_sampled_image [[t1]] [[s0]]
; CHECK: [[value1:%\w+]] = OpImageSampleImplicitLod %v4float [[sampledImg1]]
; CHECK: OpBranch [[merge_texture0]]
; CHECK: [[default_texture]] = OpLabel
; CHECK: OpBranch [[merge_texture0]]
; CHECK: [[merge_texture0]] = OpLabel
; CHECK: [[phi0:%\w+]] = OpPhi %v4float [[value0]] [[case_texture0]] [[value1]] [[case_texture1]] [[null_value]] [[default_texture]]
; CHECK: OpBranch [[merge_sampler]]
; CHECK: [[case_sampler1]] = OpLabel
; CHECK: OpSelectionMerge [[merge_texture1:%\w+]] None
; CHECK: OpSwitch [[var_index]] [[default_texture:%\w+]] 0 [[case_texture0:%\w+]] 1 [[case_texture1:%\w+]]
; CHECK: [[case_texture0]] = OpLabel
; CHECK: [[pt0:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_0
; CHECK: [[ps1:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_sampler %Sampler0 %uint_1
; CHECK: [[s1:%\w+]] = OpLoad %type_sampler [[ps1]]
; CHECK: [[t0:%\w+]] = OpLoad %type_2d_image [[pt0]]
; CHECK: [[sampledImg0:%\w+]] = OpSampledImage %type_sampled_image [[t0]] [[s1]]
; CHECK: [[value0:%\w+]] = OpImageSampleImplicitLod %v4float [[sampledImg0]]
; CHECK: OpBranch [[merge_texture1]]
; CHECK: [[case_texture1]] = OpLabel
; CHECK: [[pt1:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_1
; CHECK: [[ps1:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_sampler %Sampler0 %uint_1
; CHECK: [[s1:%\w+]] = OpLoad %type_sampler [[ps1]]
; CHECK: [[t1:%\w+]] = OpLoad %type_2d_image [[pt1]]
; CHECK: [[sampledImg1:%\w+]] = OpSampledImage %type_sampled_image [[t1]] [[s1]]
; CHECK: [[value1:%\w+]] = OpImageSampleImplicitLod %v4float [[sampledImg1]]
; CHECK: OpBranch [[merge_texture1]]
; CHECK: [[default_texture]] = OpLabel
; CHECK: OpBranch [[merge_texture1]]
; CHECK: [[merge_texture1]] = OpLabel
; CHECK: [[phi1:%\w+]] = OpPhi %v4float [[value0]] [[case_texture0]] [[value1]] [[case_texture1]] [[null_value]] [[default_texture]]
; CHECK: [[default_sampler]] = OpLabel
; CHECK: OpBranch [[merge_sampler]]
; CHECK: [[merge_sampler]] = OpLabel
; CHECK: OpPhi %v4float [[phi0]] [[merge_texture0]] [[phi1]] [[merge_texture1]] [[null_value]] [[default_sampler]]
; CHECK: OpStore
OpStore %out_var_SV_TARGET %36
OpReturn
OpFunctionEnd
)";
// Runs the pass on |text| and matches the output against the CHECK lines.
SinglePassRunAndMatch<ReplaceDescArrayAccessUsingVarIndex>(text, true);
}
// Checks the pass on an image array with a single element (%Tex0 has length
// %uint_1). The expected output rewrites the access chain to use the constant
// index %uint_0 and contains no further OpAccessChain, no OpSwitch and no
// OpPhi.
TEST_F(ReplaceDescArrayAccessUsingVarIndexTest,
ReplaceAccessChainToTextureArrayWithSingleElement) {
const std::string text = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %psmain "psmain" %gl_FragCoord %in_var_INSTANCEID %out_var_SV_TARGET
OpExecutionMode %psmain OriginUpperLeft
OpSource HLSL 600
OpName %type_sampler "type.sampler"
OpName %Sampler0 "Sampler0"
OpName %type_2d_image "type.2d.image"
OpName %Tex0 "Tex0"
OpName %in_var_INSTANCEID "in.var.INSTANCEID"
OpName %out_var_SV_TARGET "out.var.SV_TARGET"
OpName %psmain "psmain"
OpName %type_sampled_image "type.sampled.image"
OpDecorate %gl_FragCoord BuiltIn FragCoord
OpDecorate %in_var_INSTANCEID Flat
OpDecorate %in_var_INSTANCEID Location 0
OpDecorate %out_var_SV_TARGET Location 0
OpDecorate %Sampler0 DescriptorSet 0
OpDecorate %Sampler0 Binding 1
OpDecorate %Tex0 DescriptorSet 0
OpDecorate %Tex0 Binding 2
%type_sampler = OpTypeSampler
%_ptr_UniformConstant_type_sampler = OpTypePointer UniformConstant %type_sampler
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%float = OpTypeFloat 32
%type_2d_image = OpTypeImage %float 2D 2 0 0 0 Unknown
%_arr_type_2d_image_uint_1 = OpTypeArray %type_2d_image %uint_1
%_ptr_UniformConstant__arr_type_2d_image_uint_1 = OpTypePointer UniformConstant %_arr_type_2d_image_uint_1
%v4float = OpTypeVector %float 4
%_ptr_Input_v4float = OpTypePointer Input %v4float
%_ptr_Input_uint = OpTypePointer Input %uint
%_ptr_Output_v4float = OpTypePointer Output %v4float
%void = OpTypeVoid
%21 = OpTypeFunction %void
%_ptr_UniformConstant_type_2d_image = OpTypePointer UniformConstant %type_2d_image
%v2float = OpTypeVector %float 2
%v2uint = OpTypeVector %uint 2
%uint_0 = OpConstant %uint 0
%27 = OpConstantComposite %v2uint %uint_0 %uint_1
%type_sampled_image = OpTypeSampledImage %type_2d_image
%Sampler0 = OpVariable %_ptr_UniformConstant_type_sampler UniformConstant
%Tex0 = OpVariable %_ptr_UniformConstant__arr_type_2d_image_uint_1 UniformConstant
%gl_FragCoord = OpVariable %_ptr_Input_v4float Input
%in_var_INSTANCEID = OpVariable %_ptr_Input_uint Input
%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output
%uint_2 = OpConstant %uint 2
%66 = OpConstantNull %v4float
%psmain = OpFunction %void None %21
%28 = OpLabel
%29 = OpLoad %v4float %gl_FragCoord
%30 = OpLoad %uint %in_var_INSTANCEID
%31 = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %30
%32 = OpLoad %type_2d_image %31
OpImageWrite %32 %27 %29
; CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %Tex0 %uint_0
; CHECK-NOT: OpAccessChain
; CHECK-NOT: OpSwitch
; CHECK-NOT: OpPhi
%33 = OpLoad %type_sampler %Sampler0
%34 = OpVectorShuffle %v2float %29 %29 0 1
%35 = OpSampledImage %type_sampled_image %32 %33
%36 = OpImageSampleImplicitLod %v4float %35 %34 None
OpStore %out_var_SV_TARGET %36
OpReturn
OpFunctionEnd
)";
// Runs the pass on |text| and matches the output against the CHECK lines.
SinglePassRunAndMatch<ReplaceDescArrayAccessUsingVarIndex>(text, true);
}
} // namespace
} // namespace opt
} // namespace spvtools

View File

@ -163,6 +163,11 @@ Options (in lexicographical order):)",
around known issues with some Vulkan drivers for initialize
variables.)");
printf(R"(
--replace-desc-array-access-using-var-index
Replaces accesses to descriptor arrays based on a variable index
with a switch that has a case for every possible value of the
index.)");
printf(R"(
--descriptor-scalar-replacement
Replaces every array variable |desc| that has a DescriptorSet
and Binding decorations with a new variable for each element of