spirv-opt : Add FixFuncCallArgumentsPass (#4775)

spirv validation require OpFunctionCall with memory object, usually this
is non issue as all the functions are inlined.
This pass deal with some case for
DontInline function. accesschain input operand would be replaced new
created variable
This commit is contained in:
JiaoluAMD 2022-05-06 22:39:26 +08:00 committed by GitHub
parent 9e377b0f97
commit c11ea09652
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 323 additions and 2 deletions

View File

@ -109,6 +109,7 @@ SPVTOOLS_OPT_SRC_FILES := \
source/opt/eliminate_dead_input_components_pass.cpp \
source/opt/eliminate_dead_members_pass.cpp \
source/opt/feature_manager.cpp \
source/opt/fix_func_call_arguments.cpp \
source/opt/fix_storage_class.cpp \
source/opt/flatten_decoration_pass.cpp \
source/opt/fold.cpp \

View File

@ -629,6 +629,8 @@ static_library("spvtools_opt") {
"source/opt/empty_pass.h",
"source/opt/feature_manager.cpp",
"source/opt/feature_manager.h",
"source/opt/fix_func_call_arguments.cpp",
"source/opt/fix_func_call_arguments.h",
"source/opt/fix_storage_class.cpp",
"source/opt/fix_storage_class.h",
"source/opt/flatten_decoration_pass.cpp",

View File

@ -907,6 +907,11 @@ Optimizer::PassToken CreateConvertToSampledImagePass(
// from every function in the module. This is useful if you want the inliner to
// inline these functions some reason.
Optimizer::PassToken CreateRemoveDontInlinePass();
// Create a fix-func-call-param pass to fix non memory argument for the function
// call, as spirv-validation requires function parameters to be an memory
// object, currently the pass would remove accesschain pointer argument passed
// to the function
Optimizer::PassToken CreateFixFuncCallArgumentsPass();
} // namespace spvtools
#endif // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
set(SPIRV_TOOLS_OPT_SOURCES
fix_func_call_arguments.h
aggressive_dead_code_elim_pass.h
amd_ext_to_khr.h
basic_block.h
@ -126,6 +127,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
workaround1209.h
wrap_opkill.h
fix_func_call_arguments.cpp
aggressive_dead_code_elim_pass.cpp
amd_ext_to_khr.cpp
basic_block.cpp

View File

@ -0,0 +1,90 @@
// Copyright (c) 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fix_func_call_arguments.h"
#include "ir_builder.h"
using namespace spvtools;
using namespace opt;
bool FixFuncCallArgumentsPass::ModuleHasASingleFunction() {
auto funcsNum = get_module()->end() - get_module()->begin();
return funcsNum == 1;
}
Pass::Status FixFuncCallArgumentsPass::Process() {
bool modified = false;
if (ModuleHasASingleFunction()) return Status::SuccessWithoutChange;
for (auto& func : *get_module()) {
func.ForEachInst([this, &modified](Instruction* inst) {
if (inst->opcode() == SpvOpFunctionCall) {
modified |= FixFuncCallArguments(inst);
}
});
}
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
bool FixFuncCallArgumentsPass::FixFuncCallArguments(
Instruction* func_call_inst) {
bool modified = false;
for (uint32_t i = 0; i < func_call_inst->NumInOperands(); ++i) {
Operand& op = func_call_inst->GetInOperand(i);
if (op.type != SPV_OPERAND_TYPE_ID) continue;
Instruction* operand_inst = get_def_use_mgr()->GetDef(op.AsId());
if (operand_inst->opcode() == SpvOpAccessChain) {
uint32_t var_id =
ReplaceAccessChainFuncCallArguments(func_call_inst, operand_inst);
func_call_inst->SetInOperand(i, {var_id});
modified = true;
}
}
if (modified) {
context()->UpdateDefUse(func_call_inst);
}
return modified;
}
uint32_t FixFuncCallArgumentsPass::ReplaceAccessChainFuncCallArguments(
Instruction* func_call_inst, Instruction* operand_inst) {
InstructionBuilder builder(
context(), func_call_inst,
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
Instruction* next_insert_point = func_call_inst->NextNode();
// Get Variable insertion point
Function* func = context()->get_instr_block(func_call_inst)->GetParent();
Instruction* variable_insertion_point = &*(func->begin()->begin());
Instruction* op_ptr_type = get_def_use_mgr()->GetDef(operand_inst->type_id());
Instruction* op_type =
get_def_use_mgr()->GetDef(op_ptr_type->GetSingleWordInOperand(1));
uint32_t varType = context()->get_type_mgr()->FindPointerToType(
op_type->result_id(), SpvStorageClassFunction);
// Create new variable
builder.SetInsertPoint(variable_insertion_point);
Instruction* var = builder.AddVariable(varType, SpvStorageClassFunction);
// Load access chain to the new variable before function call
builder.SetInsertPoint(func_call_inst);
uint32_t operand_id = operand_inst->result_id();
Instruction* load = builder.AddLoad(op_type->result_id(), operand_id);
builder.AddStore(var->result_id(), load->result_id());
// Load return value to the acesschain after function call
builder.SetInsertPoint(next_insert_point);
load = builder.AddLoad(op_type->result_id(), var->result_id());
builder.AddStore(operand_id, load->result_id());
return var->result_id();
}

View File

@ -0,0 +1,47 @@
// Copyright (c) 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef _VAR_FUNC_CALL_PASS_H
#define _VAR_FUNC_CALL_PASS_H
#include "source/opt/pass.h"
namespace spvtools {
namespace opt {
class FixFuncCallArgumentsPass : public Pass {
public:
FixFuncCallArgumentsPass() {}
const char* name() const override { return "fix-for-funcall-param"; }
Status Process() override;
// Returns true if the module has one one function.
bool ModuleHasASingleFunction();
// Copies from the memory pointed to by |operand_inst| to a new function scope
// variable created before |func_call_inst|, and
// copies the value of the new variable back to the memory pointed to by
// |operand_inst| after |funct_call_inst| Returns the id of
// the new variable.
uint32_t ReplaceAccessChainFuncCallArguments(Instruction* func_call_inst,
Instruction* operand_inst);
// Fix function call |func_call_inst| non memory object arguments
bool FixFuncCallArguments(Instruction* func_call_inst);
IRContext::Analysis GetPreservedAnalyses() override {
return IRContext::kAnalysisTypes;
}
};
} // namespace opt
} // namespace spvtools
#endif // _VAR_FUNC_CALL_PASS_H

View File

@ -487,6 +487,15 @@ class InstructionBuilder {
return AddInstruction(std::move(new_inst));
}
Instruction* AddVariable(uint32_t type_id, uint32_t storage_class) {
std::vector<Operand> operands;
operands.push_back({SPV_OPERAND_TYPE_ID, {storage_class}});
std::unique_ptr<Instruction> new_inst(
new Instruction(GetContext(), SpvOpVariable, type_id,
GetContext()->TakeNextId(), operands));
return AddInstruction(std::move(new_inst));
}
Instruction* AddStore(uint32_t ptr_id, uint32_t obj_id) {
std::vector<Operand> operands;
operands.push_back({SPV_OPERAND_TYPE_ID, {ptr_id}});

View File

@ -525,6 +525,8 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) {
RegisterPass(CreateRemoveDontInlinePass());
} else if (pass_name == "eliminate-dead-input-components") {
RegisterPass(CreateEliminateDeadInputComponentsPass());
} else if (pass_name == "fix-func-call-param") {
RegisterPass(CreateFixFuncCallArgumentsPass());
} else if (pass_name == "convert-to-sampled-image") {
if (pass_args.size() > 0) {
auto descriptor_set_binding_pairs =
@ -1022,4 +1024,9 @@ Optimizer::PassToken CreateRemoveDontInlinePass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::RemoveDontInline>());
}
Optimizer::PassToken CreateFixFuncCallArgumentsPass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::FixFuncCallArgumentsPass>());
}
} // namespace spvtools

View File

@ -37,6 +37,7 @@
#include "source/opt/eliminate_dead_input_components_pass.h"
#include "source/opt/eliminate_dead_members_pass.h"
#include "source/opt/empty_pass.h"
#include "source/opt/fix_func_call_arguments.h"
#include "source/opt/fix_storage_class.h"
#include "source/opt/flatten_decoration_pass.h"
#include "source/opt/fold_spec_constant_op_and_composite_pass.h"

View File

@ -45,6 +45,7 @@ add_spvtools_unittest(TARGET opt
eliminate_dead_input_components_test.cpp
eliminate_dead_member_test.cpp
feature_manager_test.cpp
fix_func_call_arguments_test.cpp
fix_storage_class_test.cpp
flatten_decoration_test.cpp
fold_spec_const_op_composite_test.cpp
@ -84,7 +85,7 @@ add_spvtools_unittest(TARGET opt
reduce_load_size_test.cpp
redundancy_elimination_test.cpp
remove_dontinline_test.cpp
remove_unused_interface_variables_test.cpp
remove_unused_interface_variables_test.cpp
register_liveness.cpp
relax_float_ops_test.cpp
replace_desc_array_access_using_var_index_test.cpp
@ -96,7 +97,7 @@ add_spvtools_unittest(TARGET opt
spread_volatile_semantics_test.cpp
strength_reduction_test.cpp
strip_debug_info_test.cpp
strip_nonsemantic_info_test.cpp
strip_nonsemantic_info_test.cpp
struct_cfg_analysis_test.cpp
type_manager_test.cpp
types_test.cpp

View File

@ -0,0 +1,152 @@
// Copyright (c) 2022 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gmock/gmock.h"
#include "test/opt/pass_fixture.h"
#include "test/opt/pass_utils.h"
namespace spvtools {
namespace opt {
namespace {
using FixFuncCallArgumentsTest = PassTest<::testing::Test>;
TEST_F(FixFuncCallArgumentsTest, Simple) {
const std::string text = R"(
;
; CHECK: [[v0:%\w+]] = OpVariable %_ptr_Function_float Function
; CHECK: [[v1:%\w+]] = OpVariable %_ptr_Function_float Function
; CHECK: [[v2:%\w+]] = OpVariable %_ptr_Function_T Function
; CHECK: [[ac0:%\w+]] = OpAccessChain %_ptr_Function_float %t %int_0
; CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform_float %r1 %int_0 %uint_0
; CHECK: [[ld0:%\w+]] = OpLoad %float [[ac0]]
; CHECK: OpStore [[v1]] [[ld0]]
; CHECK: [[ld1:%\w+]] = OpLoad %float [[ac1]]
; CHECK: OpStore [[v0]] [[ld1]]
; CHECK: [[func:%\w+]] = OpFunctionCall %void %fn [[v1]] [[v0]]
; CHECK: [[ld2:%\w+]] = OpLoad %float [[v0]]
; CHECK: OpStore [[ac1]] [[ld2]]
; CHECK: [[ld3:%\w+]] = OpLoad %float [[v1]]
; CHECK: OpStore [[ac0]] [[ld3]]
;
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
OpSource HLSL 630
OpName %type_RWStructuredBuffer_float "type.RWStructuredBuffer.float"
OpName %r1 "r1"
OpName %type_ACSBuffer_counter "type.ACSBuffer.counter"
OpMemberName %type_ACSBuffer_counter 0 "counter"
OpName %counter_var_r1 "counter.var.r1"
OpName %main "main"
OpName %bb_entry "bb.entry"
OpName %T "T"
OpMemberName %T 0 "t0"
OpName %t "t"
OpName %fn "fn"
OpName %p0 "p0"
OpName %p2 "p2"
OpName %bb_entry_0 "bb.entry"
OpDecorate %main LinkageAttributes "main" Export
OpDecorate %r1 DescriptorSet 0
OpDecorate %r1 Binding 0
OpDecorate %counter_var_r1 DescriptorSet 0
OpDecorate %counter_var_r1 Binding 1
OpDecorate %_runtimearr_float ArrayStride 4
OpMemberDecorate %type_RWStructuredBuffer_float 0 Offset 0
OpDecorate %type_RWStructuredBuffer_float BufferBlock
OpMemberDecorate %type_ACSBuffer_counter 0 Offset 0
OpDecorate %type_ACSBuffer_counter BufferBlock
%int = OpTypeInt 32 1
%int_0 = OpConstant %int 0
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%int_1 = OpConstant %int 1
%float = OpTypeFloat 32
%_runtimearr_float = OpTypeRuntimeArray %float
%type_RWStructuredBuffer_float = OpTypeStruct %_runtimearr_float
%_ptr_Uniform_type_RWStructuredBuffer_float = OpTypePointer Uniform %type_RWStructuredBuffer_float
%type_ACSBuffer_counter = OpTypeStruct %int
%_ptr_Uniform_type_ACSBuffer_counter = OpTypePointer Uniform %type_ACSBuffer_counter
%15 = OpTypeFunction %int
%T = OpTypeStruct %float
%_ptr_Function_T = OpTypePointer Function %T
%_ptr_Function_float = OpTypePointer Function %float
%_ptr_Uniform_float = OpTypePointer Uniform %float
%void = OpTypeVoid
%27 = OpTypeFunction %void %_ptr_Function_float %_ptr_Function_float
%r1 = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_float Uniform
%counter_var_r1 = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
%main = OpFunction %int None %15
%bb_entry = OpLabel
%t = OpVariable %_ptr_Function_T Function
%21 = OpAccessChain %_ptr_Function_float %t %int_0
%23 = OpAccessChain %_ptr_Uniform_float %r1 %int_0 %uint_0
%25 = OpFunctionCall %void %fn %21 %23
OpReturnValue %int_1
OpFunctionEnd
%fn = OpFunction %void DontInline %27
%p0 = OpFunctionParameter %_ptr_Function_float
%p2 = OpFunctionParameter %_ptr_Function_float
%bb_entry_0 = OpLabel
OpReturn
OpFunctionEnd
)";
SinglePassRunAndMatch<FixFuncCallArgumentsPass>(text, true);
}
TEST_F(FixFuncCallArgumentsTest, NotAccessChainInput) {
const std::string text = R"(
;
; CHECK: [[o:%\w+]] = OpCopyObject %_ptr_Function_float %t
; CHECK: [[func:%\w+]] = OpFunctionCall %void %fn [[o]]
;
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
OpSource HLSL 630
OpName %main "main"
OpName %bb_entry "bb.entry"
OpName %t "t"
OpName %fn "fn"
OpName %p0 "p0"
OpName %bb_entry_0 "bb.entry"
OpDecorate %main LinkageAttributes "main" Export
%int = OpTypeInt 32 1
%int_1 = OpConstant %int 1
%4 = OpTypeFunction %int
%float = OpTypeFloat 32
%_ptr_Function_float = OpTypePointer Function %float
%void = OpTypeVoid
%12 = OpTypeFunction %void %_ptr_Function_float
%main = OpFunction %int None %4
%bb_entry = OpLabel
%t = OpVariable %_ptr_Function_float Function
%t1 = OpCopyObject %_ptr_Function_float %t
%10 = OpFunctionCall %void %fn %t1
OpReturnValue %int_1
OpFunctionEnd
%fn = OpFunction %void DontInline %12
%p0 = OpFunctionParameter %_ptr_Function_float
%bb_entry_0 = OpLabel
OpReturn
OpFunctionEnd
)";
SinglePassRunAndMatch<FixFuncCallArgumentsPass>(text, false);
}
} // namespace
} // namespace opt
} // namespace spvtools

View File

@ -237,6 +237,10 @@ Options (in lexicographical order):)",
loads and stores. Performed only on entry point call tree
functions.)");
printf(R"(
--fix-func-call-param
fix non memory argument for the function call, replace
accesschain pointer argument with a variable.)");
printf(R"(
--flatten-decorations
Replace decoration groups with repeated OpDecorate and
OpMemberDecorate instructions.)");