opt: Add support for OpExtInst to capability trim pass (#5836)

The grammar does track required capabilities for extended instruction
set operations, so we just need to look them up.
This commit is contained in:
Cassandra Beckley 2024-10-04 00:42:48 -07:00 committed by GitHub
parent c173df736c
commit 522dfead39
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 106 additions and 2 deletions

View File

@ -27,6 +27,7 @@
#include "source/enum_set.h" #include "source/enum_set.h"
#include "source/enum_string_mapping.h" #include "source/enum_string_mapping.h"
#include "source/ext_inst.h"
#include "source/opt/ir_context.h" #include "source/opt/ir_context.h"
#include "source/opt/reflect.h" #include "source/opt/reflect.h"
#include "source/spirv_target_env.h" #include "source/spirv_target_env.h"
@ -49,6 +50,9 @@ constexpr uint32_t kOpTypeImageSampledIndex = kOpTypeImageMSIndex + 1;
constexpr uint32_t kOpTypeImageFormatIndex = kOpTypeImageSampledIndex + 1; constexpr uint32_t kOpTypeImageFormatIndex = kOpTypeImageSampledIndex + 1;
constexpr uint32_t kOpImageReadImageIndex = 0; constexpr uint32_t kOpImageReadImageIndex = 0;
constexpr uint32_t kOpImageSparseReadImageIndex = 0; constexpr uint32_t kOpImageSparseReadImageIndex = 0;
// In-operand index, on an OpExtInst, of the id of the extended instruction set
// being used.
constexpr uint32_t kOpExtInstSetInIndex = 0;
// In-operand index, on an OpExtInst, of the extended instruction number.
constexpr uint32_t kOpExtInstInstructionInIndex = 1;
// In-operand index of the set name on the instruction that defines the
// extended instruction set (presumably an OpExtInstImport).
constexpr uint32_t kOpExtInstImportNameInIndex = 0;
// DFS visit of the type defined by `instruction`. // DFS visit of the type defined by `instruction`.
// If `condition` is true, children of the current node are visited. // If `condition` is true, children of the current node are visited.
@ -514,6 +518,35 @@ void TrimCapabilitiesPass::addInstructionRequirementsForOperand(
} }
} }
// Looks up the extended instruction invoked by `instruction` in the grammar
// and forwards its descriptor to addSupportedCapabilitiesToSet so the
// associated capabilities end up in `capabilities`.
//
// `instruction` must be an OpExtInst. When the grammar has no entry for the
// (instruction set, instruction number) pair, `capabilities` is left
// untouched.
void TrimCapabilitiesPass::addInstructionRequirementsForExtInst(
    Instruction* instruction, CapabilitySet* capabilities) const {
  assert(instruction->opcode() == spv::Op::OpExtInst &&
         "addInstructionRequirementsForExtInst must be passed an OpExtInst "
         "instruction");

  const auto* defUse = context()->get_def_use_mgr();

  // Find the instruction that defines the extended instruction set, and the
  // number of the extended instruction being invoked.
  const uint32_t setId =
      instruction->GetSingleWordInOperand(kOpExtInstSetInIndex);
  const Instruction* importInst = defUse->GetDef(setId);
  const uint32_t extOpcode =
      instruction->GetSingleWordInOperand(kOpExtInstInstructionInIndex);

  // Map the set's name to the grammar's instruction-set type.
  const auto setName =
      importInst->GetInOperand(kOpExtInstImportNameInIndex).AsString();
  const spv_ext_inst_type_t setType = spvExtInstImportTypeGet(setName.c_str());

  // Only a successful grammar lookup contributes capabilities.
  spv_ext_inst_desc desc = {};
  const auto status =
      context()->grammar().lookupExtInst(setType, extOpcode, &desc);
  if (status == SPV_SUCCESS) {
    addSupportedCapabilitiesToSet(desc, capabilities);
  }
}
void TrimCapabilitiesPass::addInstructionRequirements( void TrimCapabilitiesPass::addInstructionRequirements(
Instruction* instruction, CapabilitySet* capabilities, Instruction* instruction, CapabilitySet* capabilities,
ExtensionSet* extensions) const { ExtensionSet* extensions) const {
@ -523,8 +556,12 @@ void TrimCapabilitiesPass::addInstructionRequirements(
return; return;
} }
addInstructionRequirementsForOpcode(instruction->opcode(), capabilities, if (instruction->opcode() == spv::Op::OpExtInst) {
extensions); addInstructionRequirementsForExtInst(instruction, capabilities);
} else {
addInstructionRequirementsForOpcode(instruction->opcode(), capabilities,
extensions);
}
// Second case: one of the opcode operand is gated by a capability. // Second case: one of the opcode operand is gated by a capability.
const uint32_t operandCount = instruction->NumOperands(); const uint32_t operandCount = instruction->NumOperands();

View File

@ -90,6 +90,7 @@ class TrimCapabilitiesPass : public Pass {
spv::Capability::ImageMSArray, spv::Capability::ImageMSArray,
spv::Capability::Int16, spv::Capability::Int16,
spv::Capability::Int64, spv::Capability::Int64,
spv::Capability::InterpolationFunction,
spv::Capability::Linkage, spv::Capability::Linkage,
spv::Capability::MinLod, spv::Capability::MinLod,
spv::Capability::PhysicalStorageBufferAddresses, spv::Capability::PhysicalStorageBufferAddresses,
@ -160,6 +161,9 @@ class TrimCapabilitiesPass : public Pass {
CapabilitySet* capabilities, CapabilitySet* capabilities,
ExtensionSet* extensions) const; ExtensionSet* extensions) const;
// Given an OpExtInst `instruction`, looks up the extended instruction in the
// grammar and, if it is found, passes its descriptor to
// addSupportedCapabilitiesToSet to update `capabilities`.
void addInstructionRequirementsForExtInst(Instruction* instruction,
CapabilitySet* capabilities) const;
// Given an `instruction`, determines the capabilities it requires, and output // Given an `instruction`, determines the capabilities it requires, and output
// them in `capabilities`. The returned capabilities form a subset of // them in `capabilities`. The returned capabilities form a subset of
// kSupportedCapabilities. // kSupportedCapabilities.

View File

@ -3149,6 +3149,69 @@ TEST_P(TrimCapabilitiesPassTestSubgroupClustered_Unsigned,
EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithChange); EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithChange);
} }
// The module declares the InterpolationFunction capability but contains no
// OpExtInst at all. The CHECK-NOT directive expects the capability to be
// absent from the output, and the pass is expected to report a change.
TEST_F(TrimCapabilitiesPassTest, InterpolationFunction_RemovedIfNotUsed) {
const std::string kTest = R"(
OpCapability Shader
OpCapability InterpolationFunction
; CHECK-NOT: OpCapability InterpolationFunction
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %out_var_SV_Target
OpExecutionMode %main OriginUpperLeft
OpSource HLSL 660
OpName %out_var_SV_Target "out.var.SV_Target"
OpName %main "main"
OpDecorate %out_var_SV_Target Location 0
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%_ptr_Output_v4float = OpTypePointer Output %v4float
%void = OpTypeVoid
%7 = OpTypeFunction %void
%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output
%main = OpFunction %void None %7
%8 = OpLabel
OpReturn
OpFunctionEnd
)";
const auto result =
SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithChange);
}
// The module declares the InterpolationFunction capability and invokes the
// GLSL.std.450 InterpolateAtCentroid extended instruction through OpExtInst.
// The CHECK directive expects the capability to still be present, and the
// pass is expected to report that nothing changed.
//
// NOTE(review): the OpExtInst result type is %v4float while the interpolant
// %gl_PointCoord is a pointer to %v2float. This test only exercises capability
// trimming, but confirm whether the type mismatch in the fixture is intended.
TEST_F(TrimCapabilitiesPassTest,
InterpolationFunction_RemainsWithInterpolateAtCentroid) {
const std::string kTest = R"(
OpCapability Shader
OpCapability InterpolationFunction
; CHECK: OpCapability InterpolationFunction
%std450 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %out_var_SV_Target %gl_PointCoord
OpExecutionMode %main OriginUpperLeft
OpSource HLSL 660
OpName %out_var_SV_Target "out.var.SV_Target"
OpName %main "main"
OpDecorate %out_var_SV_Target Location 0
OpDecorate %gl_PointCoord BuiltIn PointCoord
%float = OpTypeFloat 32
%v2float = OpTypeVector %float 2
%v4float = OpTypeVector %float 4
%_ptr_Output_v4float = OpTypePointer Output %v4float
%_ptr_Input_v2float = OpTypePointer Input %v2float
%void = OpTypeVoid
%7 = OpTypeFunction %void
%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output
%gl_PointCoord = OpVariable %_ptr_Input_v2float Input
%main = OpFunction %void None %7
%8 = OpLabel
%9 = OpExtInst %v4float %std450 InterpolateAtCentroid %gl_PointCoord
OpReturn
OpFunctionEnd
)";
const auto result =
SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange);
}
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
TrimCapabilitiesPassTestSubgroupClustered_Unsigned_I, TrimCapabilitiesPassTestSubgroupClustered_Unsigned_I,
TrimCapabilitiesPassTestSubgroupClustered_Unsigned, TrimCapabilitiesPassTestSubgroupClustered_Unsigned,