mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-21 19:20:07 +00:00
Add knowledge of cooperative matrices (#5720)
* Add knowledge of cooperative matrices Some optimizations are not aware of cooperative matrices, and either do nothing or assert. This commits fixes that up. * Add int tests, and a handle a couple more cases. * Add float tests, and a handle a couple more cases. * Add NV coop matrix as well.
This commit is contained in:
parent
64d37e2811
commit
ca004da9f9
@ -1004,7 +1004,9 @@ void AggressiveDCEPass::InitExtensions() {
|
||||
"SPV_NV_bindless_texture",
|
||||
"SPV_EXT_shader_atomic_float_add",
|
||||
"SPV_EXT_fragment_shader_interlock",
|
||||
"SPV_NV_compute_shader_derivatives"
|
||||
"SPV_NV_compute_shader_derivatives",
|
||||
"SPV_NV_cooperative_matrix",
|
||||
"SPV_KHR_cooperative_matrix"
|
||||
});
|
||||
// clang-format on
|
||||
}
|
||||
|
@ -112,6 +112,12 @@ bool IsValidResult(T val) {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true if `type` is a cooperative matrix.
|
||||
bool IsCooperativeMatrix(const analysis::Type* type) {
|
||||
return type->kind() == analysis::Type::kCooperativeMatrixKHR ||
|
||||
type->kind() == analysis::Type::kCooperativeMatrixNV;
|
||||
}
|
||||
|
||||
const analysis::Constant* ConstInput(
|
||||
const std::vector<const analysis::Constant*>& constants) {
|
||||
return constants[0] ? constants[0] : constants[1];
|
||||
@ -313,6 +319,11 @@ FoldingRule ReciprocalFDiv() {
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
|
||||
if (IsCooperativeMatrix(type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!inst->IsFloatingPointFoldingAllowed()) return false;
|
||||
|
||||
uint32_t width = ElementWidth(type);
|
||||
@ -394,6 +405,11 @@ FoldingRule MergeNegateMulDivArithmetic() {
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
|
||||
if (IsCooperativeMatrix(type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
|
||||
return false;
|
||||
|
||||
@ -455,6 +471,11 @@ FoldingRule MergeNegateAddSubArithmetic() {
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
|
||||
if (IsCooperativeMatrix(type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
|
||||
return false;
|
||||
|
||||
@ -686,6 +707,11 @@ FoldingRule MergeMulMulArithmetic() {
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
|
||||
if (IsCooperativeMatrix(type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
|
||||
return false;
|
||||
|
||||
@ -740,6 +766,11 @@ FoldingRule MergeMulDivArithmetic() {
|
||||
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
|
||||
if (IsCooperativeMatrix(type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!inst->IsFloatingPointFoldingAllowed()) return false;
|
||||
|
||||
uint32_t width = ElementWidth(type);
|
||||
@ -813,6 +844,11 @@ FoldingRule MergeMulNegateArithmetic() {
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
|
||||
if (IsCooperativeMatrix(type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool uses_float = HasFloatingPoint(type);
|
||||
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
|
||||
|
||||
@ -853,6 +889,11 @@ FoldingRule MergeDivDivArithmetic() {
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
|
||||
if (IsCooperativeMatrix(type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!inst->IsFloatingPointFoldingAllowed()) return false;
|
||||
|
||||
uint32_t width = ElementWidth(type);
|
||||
@ -926,6 +967,11 @@ FoldingRule MergeDivMulArithmetic() {
|
||||
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
|
||||
if (IsCooperativeMatrix(type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!inst->IsFloatingPointFoldingAllowed()) return false;
|
||||
|
||||
uint32_t width = ElementWidth(type);
|
||||
@ -1068,6 +1114,11 @@ FoldingRule MergeSubNegateArithmetic() {
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
|
||||
if (IsCooperativeMatrix(type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool uses_float = HasFloatingPoint(type);
|
||||
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
|
||||
|
||||
@ -1116,6 +1167,11 @@ FoldingRule MergeAddAddArithmetic() {
|
||||
inst->opcode() == spv::Op::OpIAdd);
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
|
||||
if (IsCooperativeMatrix(type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
bool uses_float = HasFloatingPoint(type);
|
||||
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
|
||||
@ -1164,6 +1220,11 @@ FoldingRule MergeAddSubArithmetic() {
|
||||
inst->opcode() == spv::Op::OpIAdd);
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
|
||||
if (IsCooperativeMatrix(type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
bool uses_float = HasFloatingPoint(type);
|
||||
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
|
||||
@ -1224,6 +1285,11 @@ FoldingRule MergeSubAddArithmetic() {
|
||||
inst->opcode() == spv::Op::OpISub);
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
|
||||
if (IsCooperativeMatrix(type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
bool uses_float = HasFloatingPoint(type);
|
||||
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
|
||||
@ -1290,6 +1356,11 @@ FoldingRule MergeSubSubArithmetic() {
|
||||
inst->opcode() == spv::Op::OpISub);
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
|
||||
if (IsCooperativeMatrix(type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
bool uses_float = HasFloatingPoint(type);
|
||||
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
|
||||
@ -1383,6 +1454,11 @@ FoldingRule MergeGenericAddSubArithmetic() {
|
||||
inst->opcode() == spv::Op::OpIAdd);
|
||||
const analysis::Type* type =
|
||||
context->get_type_mgr()->GetType(inst->type_id());
|
||||
|
||||
if (IsCooperativeMatrix(type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool uses_float = HasFloatingPoint(type);
|
||||
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
|
||||
|
||||
|
@ -428,8 +428,8 @@ void LocalAccessChainConvertPass::InitExtensions() {
|
||||
"SPV_KHR_uniform_group_instructions",
|
||||
"SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model",
|
||||
"SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
|
||||
"SPV_EXT_fragment_shader_interlock",
|
||||
"SPV_NV_compute_shader_derivatives"});
|
||||
"SPV_EXT_fragment_shader_interlock", "SPV_NV_compute_shader_derivatives",
|
||||
"SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix"});
|
||||
}
|
||||
|
||||
bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(
|
||||
|
@ -291,7 +291,9 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() {
|
||||
"SPV_NV_bindless_texture",
|
||||
"SPV_EXT_shader_atomic_float_add",
|
||||
"SPV_EXT_fragment_shader_interlock",
|
||||
"SPV_NV_compute_shader_derivatives"});
|
||||
"SPV_NV_compute_shader_derivatives",
|
||||
"SPV_NV_cooperative_matrix",
|
||||
"SPV_KHR_cooperative_matrix"});
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
|
@ -141,7 +141,9 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() {
|
||||
"SPV_NV_bindless_texture",
|
||||
"SPV_EXT_shader_atomic_float_add",
|
||||
"SPV_EXT_fragment_shader_interlock",
|
||||
"SPV_NV_compute_shader_derivatives"});
|
||||
"SPV_NV_compute_shader_derivatives",
|
||||
"SPV_NV_cooperative_matrix",
|
||||
"SPV_KHR_cooperative_matrix"});
|
||||
}
|
||||
bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) {
|
||||
std::vector<Instruction*> users;
|
||||
|
@ -43,6 +43,8 @@ bool MemPass::IsBaseTargetType(const Instruction* typeInst) const {
|
||||
case spv::Op::OpTypeSampler:
|
||||
case spv::Op::OpTypeSampledImage:
|
||||
case spv::Op::OpTypePointer:
|
||||
case spv::Op::OpTypeCooperativeMatrixNV:
|
||||
case spv::Op::OpTypeCooperativeMatrixKHR:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
|
@ -215,6 +215,8 @@ OpCapability Float64
|
||||
OpCapability Int8
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpCapability CooperativeMatrixKHR
|
||||
OpExtension "SPV_KHR_cooperative_matrix"
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint Fragment %main "main"
|
||||
@ -434,6 +436,12 @@ OpName %main "main"
|
||||
%ushort_0xBC00 = OpConstant %ushort 0xBC00
|
||||
%short_0xBC00 = OpConstant %short 0xBC00
|
||||
%int_arr_2_undef = OpUndef %int_arr_2
|
||||
%int_coop_matrix = OpTypeCooperativeMatrixKHR %int %uint_3 %uint_3 %uint_32 %uint_0
|
||||
%undef_int_coop_matrix = OpUndef %int_coop_matrix
|
||||
%uint_coop_matrix = OpTypeCooperativeMatrixKHR %uint %uint_3 %uint_3 %uint_32 %uint_0
|
||||
%undef_uint_coop_matrix = OpUndef %uint_coop_matrix
|
||||
%float_coop_matrix = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_3 %uint_32 %uint_0
|
||||
%undef_float_coop_matrix = OpUndef %float_coop_matrix
|
||||
)";
|
||||
|
||||
return header;
|
||||
@ -4148,6 +4156,62 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe
|
||||
"%2 = OpSLessThan %bool %long_0 %long_2\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0),
|
||||
// Test case 41: Don't fold OpSNegate for cooperative matrices.
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpSNegate %int_coop_matrix %undef_int_coop_matrix\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0),
|
||||
// Test case 42: Don't fold OpIAdd for cooperative matrices.
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpIAdd %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0),
|
||||
// Test case 43: Don't fold OpISub for cooperative matrices.
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpISub %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0),
|
||||
// Test case 44: Don't fold OpIMul for cooperative matrices.
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpIMul %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0),
|
||||
// Test case 45: Don't fold OpSDiv for cooperative matrices.
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpSDiv %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0),
|
||||
// Test case 46: Don't fold OpUDiv for cooperative matrices.
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpUDiv %uint_coop_matrix %undef_uint_coop_matrix %undef_uint_coop_matrix\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0),
|
||||
// Test case 47: Don't fold OpMatrixTimesScalar for cooperative matrices.
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpMatrixTimesScalar %uint_coop_matrix %undef_uint_coop_matrix %uint_3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0)
|
||||
));
|
||||
|
||||
@ -4689,6 +4753,54 @@ INSTANTIATE_TEST_SUITE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTes
|
||||
"%2 = OpFDiv %half %half_1 %half_2\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0),
|
||||
// Test case 24: Don't fold OpFNegate for cooperative matrices.
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpFNegate %float_coop_matrix %undef_float_coop_matrix\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0),
|
||||
// Test case 25: Don't fold OpIAdd for cooperative matrices.
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpFAdd %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0),
|
||||
// Test case 26: Don't fold OpISub for cooperative matrices.
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpFSub %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0),
|
||||
// Test case 27: Don't fold OpIMul for cooperative matrices.
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpFMul %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0),
|
||||
// Test case 28: Don't fold OpSDiv for cooperative matrices.
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpFDiv %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0),
|
||||
// Test case 29: Don't fold OpMatrixTimesScalar for cooperative matrices.
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpMatrixTimesScalar %float_coop_matrix %undef_float_coop_matrix %float_3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0)
|
||||
));
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user