SPIRV-Tools/source/opt/local_access_chain_convert_pass.cpp
Nathan Gauër ad11927e6c
opt: add SPV_EXT_mesh_shader to opt allowlist (#5551)
Add this extension to the allowlist, allowing DCE and other
optimizations on modules exposing this.
Note: NV equivalent is already allowed.
2024-01-30 12:13:46 -05:00

474 lines
18 KiB
C++

// Copyright (c) 2017 The Khronos Group Inc.
// Copyright (c) 2017 Valve Corporation
// Copyright (c) 2017 LunarG Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/opt/local_access_chain_convert_pass.h"
#include "ir_context.h"
#include "iterator.h"
#include "source/util/string_utils.h"
namespace spvtools {
namespace opt {
namespace {
// In-operand position of the base pointer of an access chain. The indices
// follow it, starting at in-operand 1.
constexpr uint32_t kAccessChainPtrIdInIdx{0};
// In-operand position of the value operand of an OpStore. In-operand 0 is the
// pointer being stored through.
constexpr uint32_t kStoreValIdInIdx{1};
}  // namespace
// Creates an instruction with the given opcode, result type id, result id and
// in-operands. It is registered with the def-use manager and then appended to
// |newInsts|, which takes ownership of it.
void LocalAccessChainConvertPass::BuildAndAppendInst(
    spv::Op opcode, uint32_t typeId, uint32_t resultId,
    const std::vector<Operand>& in_opnds,
    std::vector<std::unique_ptr<Instruction>>* newInsts) {
  std::unique_ptr<Instruction> created{
      new Instruction(context(), opcode, typeId, resultId, in_opnds)};
  // Record defs/uses before handing the instruction over to the caller's list.
  get_def_use_mgr()->AnalyzeInstDefUse(created.get());
  newInsts->push_back(std::move(created));
}
uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad(
const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId,
std::vector<std::unique_ptr<Instruction>>* newInsts) {
const uint32_t ldResultId = TakeNextId();
if (ldResultId == 0) {
return 0;
}
*varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
const Instruction* varInst = get_def_use_mgr()->GetDef(*varId);
assert(varInst->opcode() == spv::Op::OpVariable);
*varPteTypeId = GetPointeeTypeId(varInst);
BuildAndAppendInst(spv::Op::OpLoad, *varPteTypeId, ldResultId,
{{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*varId}}},
newInsts);
return ldResultId;
}
void LocalAccessChainConvertPass::AppendConstantOperands(
const Instruction* ptrInst, std::vector<Operand>* in_opnds) {
uint32_t iidIdx = 0;
ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) {
if (iidIdx > 0) {
const Instruction* cInst = get_def_use_mgr()->GetDef(*iid);
const auto* constant_value =
context()->get_constant_mgr()->GetConstantFromInst(cInst);
assert(constant_value != nullptr &&
"Expecting the index to be a constant.");
// We take the sign extended value because OpAccessChain interprets the
// index as signed.
int64_t long_value = constant_value->GetSignExtendedValue();
assert(long_value <= UINT32_MAX && long_value >= 0 &&
"The index value is too large for a composite insert or extract "
"instruction.");
uint32_t val = static_cast<uint32_t>(long_value);
in_opnds->push_back(
{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
}
++iidIdx;
});
}
bool LocalAccessChainConvertPass::ReplaceAccessChainLoad(
const Instruction* address_inst, Instruction* original_load) {
// Build and append load of variable in ptrInst
if (address_inst->NumInOperands() == 1) {
// An access chain with no indices is essentially a copy. All that is
// needed is to propagate the address.
context()->ReplaceAllUsesWith(
address_inst->result_id(),
address_inst->GetSingleWordInOperand(kAccessChainPtrIdInIdx));
return true;
}
std::vector<std::unique_ptr<Instruction>> new_inst;
uint32_t varId;
uint32_t varPteTypeId;
const uint32_t ldResultId =
BuildAndAppendVarLoad(address_inst, &varId, &varPteTypeId, &new_inst);
if (ldResultId == 0) {
return false;
}
new_inst[0]->UpdateDebugInfoFrom(original_load);
context()->get_decoration_mgr()->CloneDecorations(
original_load->result_id(), ldResultId,
{spv::Decoration::RelaxedPrecision});
original_load->InsertBefore(std::move(new_inst));
context()->get_debug_info_mgr()->AnalyzeDebugInst(
original_load->PreviousNode());
// Rewrite |original_load| into an extract.
Instruction::OperandList new_operands;
// copy the result id and the type id to the new operand list.
new_operands.emplace_back(original_load->GetOperand(0));
new_operands.emplace_back(original_load->GetOperand(1));
new_operands.emplace_back(
Operand({spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}));
AppendConstantOperands(address_inst, &new_operands);
original_load->SetOpcode(spv::Op::OpCompositeExtract);
original_load->ReplaceOperands(new_operands);
context()->UpdateDefUse(original_load);
return true;
}
bool LocalAccessChainConvertPass::GenAccessChainStoreReplacement(
const Instruction* ptrInst, uint32_t valId,
std::vector<std::unique_ptr<Instruction>>* newInsts) {
if (ptrInst->NumInOperands() == 1) {
// An access chain with no indices is essentially a copy. However, we still
// have to create a new store because the old ones will be deleted.
BuildAndAppendInst(
spv::Op::OpStore, 0, 0,
{{spv_operand_type_t::SPV_OPERAND_TYPE_ID,
{ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx)}},
{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}}},
newInsts);
return true;
}
// Build and append load of variable in ptrInst
uint32_t varId;
uint32_t varPteTypeId;
const uint32_t ldResultId =
BuildAndAppendVarLoad(ptrInst, &varId, &varPteTypeId, newInsts);
if (ldResultId == 0) {
return false;
}
context()->get_decoration_mgr()->CloneDecorations(
varId, ldResultId, {spv::Decoration::RelaxedPrecision});
// Build and append Insert
const uint32_t insResultId = TakeNextId();
if (insResultId == 0) {
return false;
}
std::vector<Operand> ins_in_opnds = {
{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}},
{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}};
AppendConstantOperands(ptrInst, &ins_in_opnds);
BuildAndAppendInst(spv::Op::OpCompositeInsert, varPteTypeId, insResultId,
ins_in_opnds, newInsts);
context()->get_decoration_mgr()->CloneDecorations(
varId, insResultId, {spv::Decoration::RelaxedPrecision});
// Build and append Store
BuildAndAppendInst(spv::Op::OpStore, 0, 0,
{{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}},
{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {insResultId}}},
newInsts);
return true;
}
bool LocalAccessChainConvertPass::Is32BitConstantIndexAccessChain(
const Instruction* acp) const {
uint32_t inIdx = 0;
return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) {
if (inIdx > 0) {
Instruction* opInst = get_def_use_mgr()->GetDef(*tid);
if (opInst->opcode() != spv::Op::OpConstant) return false;
const auto* index =
context()->get_constant_mgr()->GetConstantFromInst(opInst);
int64_t index_value = index->GetSignExtendedValue();
if (index_value > UINT32_MAX) return false;
if (index_value < 0) return false;
}
++inIdx;
return true;
});
}
bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) {
if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true;
if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) {
if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue ||
user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) {
return true;
}
spv::Op op = user->opcode();
if (IsNonPtrAccessChain(op) || op == spv::Op::OpCopyObject) {
if (!HasOnlySupportedRefs(user->result_id())) {
return false;
}
} else if (op != spv::Op::OpStore && op != spv::Op::OpLoad &&
op != spv::Op::OpName && !IsNonTypeDecorate(op)) {
return false;
}
return true;
})) {
supported_ref_ptrs_.insert(ptrId);
return true;
}
return false;
}
// Scans every OpLoad and OpStore in |func|. A target variable is demoted to
// |seen_non_target_vars_| (and dropped from |seen_target_vars_|) when it
// meets any of these conditions:
// - it has unsupported references,
// - it is reached through a nested access chain,
// - it is indexed with a non-constant or out-of-32-bit-range index,
// - it is indexed with a constant index that is out of bounds.
void LocalAccessChainConvertPass::FindTargetVars(Function* func) {
  for (auto& block : *func) {
    for (auto& inst : block) {
      const spv::Op opcode = inst.opcode();
      if (opcode != spv::Op::OpLoad && opcode != spv::Op::OpStore) continue;

      uint32_t varId;
      Instruction* ptrInst = GetPtr(&inst, &varId);
      if (!IsTargetVar(varId)) continue;

      const bool via_access_chain = IsNonPtrAccessChain(ptrInst->opcode());
      // The checks are evaluated in order and short-circuit.
      const bool reject =
          // Non-supported refs, e.g. function calls.
          !HasOnlySupportedRefs(varId) ||
          // Nested access chains. TODO(): Convert nested access chains
          (via_access_chain &&
           ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx) != varId) ||
          // Non-constant indices.
          !Is32BitConstantIndexAccessChain(ptrInst) ||
          // Constant indices that are out of bounds.
          (via_access_chain && AnyIndexIsOutOfBounds(ptrInst));
      if (reject) {
        seen_non_target_vars_.insert(varId);
        seen_target_vars_.erase(varId);
      }
    }
  }
}
// Rewrites, within |func|, every load and store that goes through a
// non-pointer access chain into a target variable:
// - a load becomes a load of the whole variable plus OpCompositeExtract
//   (see ReplaceAccessChainLoad);
// - a store becomes load + OpCompositeInsert + store of the whole variable
//   (see GenAccessChainStoreReplacement).
// The target variables are first filtered by FindTargetVars.
//
// Returns Failure if a replacement could not be generated,
// SuccessWithChange if anything was rewritten, and SuccessWithoutChange
// otherwise.
Pass::Status LocalAccessChainConvertPass::ConvertLocalAccessChains(
    Function* func) {
  FindTargetVars(func);
  // Replace access chains of all targeted variables with equivalent
  // extract and insert sequences
  bool modified = false;
  for (auto bi = func->begin(); bi != func->end(); ++bi) {
    // Stores replaced in this block. They are deleted only after the block
    // has been fully walked, so the iteration below is not disturbed.
    std::vector<Instruction*> dead_instructions;
    for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
      switch (ii->opcode()) {
        case spv::Op::OpLoad: {
          uint32_t varId;
          Instruction* ptrInst = GetPtr(&*ii, &varId);
          if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
          if (!IsTargetVar(varId)) break;
          // ReplaceAccessChainLoad only inserts before |ii| or rewrites it in
          // place, so |ii| remains valid afterwards.
          if (!ReplaceAccessChainLoad(ptrInst, &*ii)) {
            return Status::Failure;
          }
          modified = true;
        } break;
        case spv::Op::OpStore: {
          uint32_t varId;
          Instruction* store = &*ii;
          Instruction* ptrInst = GetPtr(store, &varId);
          if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
          if (!IsTargetVar(varId)) break;
          std::vector<std::unique_ptr<Instruction>> newInsts;
          uint32_t valId = store->GetSingleWordInOperand(kStoreValIdInIdx);
          if (!GenAccessChainStoreReplacement(ptrInst, valId, &newInsts)) {
            return Status::Failure;
          }
          // On success |newInsts| holds at least one instruction (the new
          // store). All but the last are visited by the loop below; the last
          // one is handled after it.
          size_t num_of_instructions_to_skip = newInsts.size() - 1;
          dead_instructions.push_back(store);
          // Insert the replacement sequence right after the old store. |ii|
          // is left on the first inserted instruction.
          ++ii;
          ii = ii.InsertBefore(std::move(newInsts));
          // Give every new instruction the old store's debug info.
          for (size_t i = 0; i < num_of_instructions_to_skip; ++i) {
            ii->UpdateDebugInfoFrom(store);
            context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
            ++ii;
          }
          // |ii| now points at the last inserted instruction. The outer ++ii
          // therefore resumes with whatever followed the old store.
          ii->UpdateDebugInfoFrom(store);
          context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
          modified = true;
        } break;
        default:
          break;
      }
    }
    // Delete the replaced stores. The callback removes from the worklist any
    // instruction that DCEInst reports, presumably so an instruction already
    // killed as a side effect is not processed again.
    while (!dead_instructions.empty()) {
      Instruction* inst = dead_instructions.back();
      dead_instructions.pop_back();
      DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
        auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
                           other_inst);
        if (i != dead_instructions.end()) {
          dead_instructions.erase(i);
        }
      });
    }
  }
  return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
}
// Resets all per-run state before processing a module:
// - the memoized set of pointers known to have only supported references,
// - the target / non-target variable caches,
// - the extension allowlist.
void LocalAccessChainConvertPass::Initialize() {
  supported_ref_ptrs_.clear();

  seen_non_target_vars_.clear();
  seen_target_vars_.clear();

  InitExtensions();
}
// Returns true only if the module uses nothing this pass cannot handle:
// - the VariablePointers capability is not declared,
// - every OpExtension is in |extensions_allowlist_|,
// - no "NonSemantic." extended instruction set is imported other than
//   NonSemantic.Shader.DebugInfo.100.
bool LocalAccessChainConvertPass::AllExtensionsSupported() const {
  // The capability can exist without the extension, so check it directly.
  // Only the full VariablePointers capability is rejected: this pass looks
  // only at function-scope symbols, so variable pointers on storage buffers
  // do not matter.
  const bool has_variable_pointers = context()->get_feature_mgr()->HasCapability(
      spv::Capability::VariablePointers);
  if (has_variable_pointers) return false;

  for (const auto& ext : get_module()->extensions()) {
    const std::string ext_name = ext.GetInOperand(0).AsString();
    const bool allowed =
        extensions_allowlist_.find(ext_name) != extensions_allowlist_.end();
    if (!allowed) return false;
  }

  // It is not safe to optimize around unknown extended instruction sets, even
  // non-semantic ones, so only the debug-info set is accepted.
  for (const auto& ext_import : context()->module()->ext_inst_imports()) {
    assert(ext_import.opcode() == spv::Op::OpExtInstImport &&
           "Expecting an import of an extension's instruction set.");
    const std::string set_name = ext_import.GetInOperand(0).AsString();
    const bool is_non_semantic =
        spvtools::utils::starts_with(set_name, "NonSemantic.");
    if (is_non_semantic && set_name != "NonSemantic.Shader.DebugInfo.100") {
      return false;
    }
  }
  return true;
}
// Runs the conversion over every function in the module.
// - Modules containing OpGroupDecorate, or failing AllExtensionsSupported(),
//   are returned as SuccessWithoutChange without being touched.
// - Processing stops at the first function that reports Failure.
Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
  // OpGroupDecorate needs additional support in KillNamesAndDecorates().
  // TODO(greg-lunarg): Add support for OpGroupDecorate
  for (const auto& annotation : get_module()->annotations()) {
    if (annotation.opcode() == spv::Op::OpGroupDecorate) {
      return Status::SuccessWithoutChange;
    }
  }
  if (!AllExtensionsSupported()) return Status::SuccessWithoutChange;

  Status result = Status::SuccessWithoutChange;
  for (Function& func : *get_module()) {
    result = CombineStatus(result, ConvertLocalAccessChains(&func));
    if (result == Status::Failure) return result;
  }
  return result;
}
// Nothing to set up at construction time. Per-run state is reset by
// Initialize(), which Process() calls before doing any work.
LocalAccessChainConvertPass::LocalAccessChainConvertPass() = default;
// Pass entry point: resets the per-run state, then performs the conversion.
Pass::Status LocalAccessChainConvertPass::Process() {
  Initialize();
  const Status outcome = ProcessImpl();
  return outcome;
}
// Fills |extensions_allowlist_| with the OpExtension names this pass is known
// to handle. AllExtensionsSupported() returns false for any module that
// declares an extension missing from this list, and the pass then leaves the
// module unchanged. List each name exactly once.
void LocalAccessChainConvertPass::InitExtensions() {
  extensions_allowlist_.clear();
  extensions_allowlist_.insert(
      {"SPV_AMD_shader_explicit_vertex_parameter",
       "SPV_AMD_shader_trinary_minmax", "SPV_AMD_gcn_shader",
       "SPV_KHR_shader_ballot", "SPV_AMD_shader_ballot",
       "SPV_AMD_gpu_shader_half_float", "SPV_KHR_shader_draw_parameters",
       "SPV_KHR_subgroup_vote", "SPV_KHR_8bit_storage", "SPV_KHR_16bit_storage",
       "SPV_KHR_device_group", "SPV_KHR_multiview",
       "SPV_NVX_multiview_per_view_attributes", "SPV_NV_viewport_array2",
       "SPV_NV_stereo_view_rendering", "SPV_NV_sample_mask_override_coverage",
       "SPV_NV_geometry_shader_passthrough", "SPV_AMD_texture_gather_bias_lod",
       "SPV_KHR_storage_buffer_storage_class",
       // SPV_KHR_variable_pointers
       // Currently do not support extended pointer expressions
       "SPV_AMD_gpu_shader_int16", "SPV_KHR_post_depth_coverage",
       "SPV_KHR_shader_atomic_counter_ops", "SPV_EXT_shader_stencil_export",
       "SPV_EXT_shader_viewport_index_layer",
       "SPV_AMD_shader_image_load_store_lod", "SPV_AMD_shader_fragment_mask",
       "SPV_EXT_fragment_fully_covered", "SPV_AMD_gpu_shader_half_float_fetch",
       "SPV_GOOGLE_decorate_string", "SPV_GOOGLE_hlsl_functionality1",
       "SPV_GOOGLE_user_type", "SPV_NV_shader_subgroup_partitioned",
       "SPV_EXT_demote_to_helper_invocation", "SPV_EXT_descriptor_indexing",
       "SPV_NV_fragment_shader_barycentric",
       "SPV_NV_compute_shader_derivatives", "SPV_NV_shader_image_footprint",
       "SPV_NV_shading_rate", "SPV_NV_mesh_shader", "SPV_EXT_mesh_shader",
       "SPV_NV_ray_tracing", "SPV_KHR_ray_tracing", "SPV_KHR_ray_query",
       "SPV_EXT_fragment_invocation_density", "SPV_KHR_terminate_invocation",
       "SPV_KHR_subgroup_uniform_control_flow", "SPV_KHR_integer_dot_product",
       "SPV_EXT_shader_image_int64", "SPV_KHR_non_semantic_info",
       "SPV_KHR_uniform_group_instructions",
       "SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model",
       "SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
       "SPV_EXT_fragment_shader_interlock"});
}
bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(
const Instruction* access_chain_inst) {
assert(IsNonPtrAccessChain(access_chain_inst->opcode()));
analysis::TypeManager* type_mgr = context()->get_type_mgr();
analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
auto constants = const_mgr->GetOperandConstants(access_chain_inst);
uint32_t base_pointer_id = access_chain_inst->GetSingleWordInOperand(0);
Instruction* base_pointer = get_def_use_mgr()->GetDef(base_pointer_id);
const analysis::Pointer* base_pointer_type =
type_mgr->GetType(base_pointer->type_id())->AsPointer();
assert(base_pointer_type != nullptr &&
"The base of the access chain is not a pointer.");
const analysis::Type* current_type = base_pointer_type->pointee_type();
for (uint32_t i = 1; i < access_chain_inst->NumInOperands(); ++i) {
if (IsIndexOutOfBounds(constants[i], current_type)) {
return true;
}
uint32_t index =
(constants[i]
? static_cast<uint32_t>(constants[i]->GetZeroExtendedValue())
: 0);
current_type = type_mgr->GetMemberType(current_type, {index});
}
return false;
}
// Returns true if |index| is a known constant whose zero-extended value is
// not a valid component index of |type|, i.e. it is at least
// type->NumberOfComponents(). A null |index| (not a known constant) is never
// considered out of bounds.
bool LocalAccessChainConvertPass::IsIndexOutOfBounds(
    const analysis::Constant* index, const analysis::Type* type) const {
  return index != nullptr &&
         index->GetZeroExtendedValue() >= type->NumberOfComponents();
}
} // namespace opt
} // namespace spvtools