Add knowledge of cooperative matrices (#5720)

* Add knowledge of cooperative matrices

Some optimizations are not aware of cooperative matrices, and either do
nothing or assert. This commits fixes that up.

* Add int tests, and a handle a couple more cases.

* Add float tests, and a handle a couple more cases.

* Add NV coop matrix as well.
This commit is contained in:
Steven Perron 2024-06-26 08:00:29 -04:00 committed by GitHub
parent 64d37e2811
commit ca004da9f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 201 additions and 5 deletions

View File

@ -1004,7 +1004,9 @@ void AggressiveDCEPass::InitExtensions() {
"SPV_NV_bindless_texture",
"SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives"
"SPV_NV_compute_shader_derivatives",
"SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix"
});
// clang-format on
}

View File

@ -112,6 +112,12 @@ bool IsValidResult(T val) {
}
}
// Returns true if `type` is a cooperative matrix.
bool IsCooperativeMatrix(const analysis::Type* type) {
return type->kind() == analysis::Type::kCooperativeMatrixKHR ||
type->kind() == analysis::Type::kCooperativeMatrixNV;
}
const analysis::Constant* ConstInput(
const std::vector<const analysis::Constant*>& constants) {
return constants[0] ? constants[0] : constants[1];
@ -313,6 +319,11 @@ FoldingRule ReciprocalFDiv() {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (IsCooperativeMatrix(type)) {
return false;
}
if (!inst->IsFloatingPointFoldingAllowed()) return false;
uint32_t width = ElementWidth(type);
@ -394,6 +405,11 @@ FoldingRule MergeNegateMulDivArithmetic() {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (IsCooperativeMatrix(type)) {
return false;
}
if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
return false;
@ -455,6 +471,11 @@ FoldingRule MergeNegateAddSubArithmetic() {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (IsCooperativeMatrix(type)) {
return false;
}
if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
return false;
@ -686,6 +707,11 @@ FoldingRule MergeMulMulArithmetic() {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (IsCooperativeMatrix(type)) {
return false;
}
if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
return false;
@ -740,6 +766,11 @@ FoldingRule MergeMulDivArithmetic() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (IsCooperativeMatrix(type)) {
return false;
}
if (!inst->IsFloatingPointFoldingAllowed()) return false;
uint32_t width = ElementWidth(type);
@ -813,6 +844,11 @@ FoldingRule MergeMulNegateArithmetic() {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (IsCooperativeMatrix(type)) {
return false;
}
bool uses_float = HasFloatingPoint(type);
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
@ -853,6 +889,11 @@ FoldingRule MergeDivDivArithmetic() {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (IsCooperativeMatrix(type)) {
return false;
}
if (!inst->IsFloatingPointFoldingAllowed()) return false;
uint32_t width = ElementWidth(type);
@ -926,6 +967,11 @@ FoldingRule MergeDivMulArithmetic() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (IsCooperativeMatrix(type)) {
return false;
}
if (!inst->IsFloatingPointFoldingAllowed()) return false;
uint32_t width = ElementWidth(type);
@ -1068,6 +1114,11 @@ FoldingRule MergeSubNegateArithmetic() {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (IsCooperativeMatrix(type)) {
return false;
}
bool uses_float = HasFloatingPoint(type);
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
@ -1116,6 +1167,11 @@ FoldingRule MergeAddAddArithmetic() {
inst->opcode() == spv::Op::OpIAdd);
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (IsCooperativeMatrix(type)) {
return false;
}
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
bool uses_float = HasFloatingPoint(type);
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
@ -1164,6 +1220,11 @@ FoldingRule MergeAddSubArithmetic() {
inst->opcode() == spv::Op::OpIAdd);
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (IsCooperativeMatrix(type)) {
return false;
}
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
bool uses_float = HasFloatingPoint(type);
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
@ -1224,6 +1285,11 @@ FoldingRule MergeSubAddArithmetic() {
inst->opcode() == spv::Op::OpISub);
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (IsCooperativeMatrix(type)) {
return false;
}
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
bool uses_float = HasFloatingPoint(type);
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
@ -1290,6 +1356,11 @@ FoldingRule MergeSubSubArithmetic() {
inst->opcode() == spv::Op::OpISub);
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (IsCooperativeMatrix(type)) {
return false;
}
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
bool uses_float = HasFloatingPoint(type);
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
@ -1383,6 +1454,11 @@ FoldingRule MergeGenericAddSubArithmetic() {
inst->opcode() == spv::Op::OpIAdd);
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
if (IsCooperativeMatrix(type)) {
return false;
}
bool uses_float = HasFloatingPoint(type);
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;

View File

@ -428,8 +428,8 @@ void LocalAccessChainConvertPass::InitExtensions() {
"SPV_KHR_uniform_group_instructions",
"SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model",
"SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives"});
"SPV_EXT_fragment_shader_interlock", "SPV_NV_compute_shader_derivatives",
"SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix"});
}
bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(

View File

@ -291,7 +291,9 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() {
"SPV_NV_bindless_texture",
"SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives"});
"SPV_NV_compute_shader_derivatives",
"SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix"});
}
} // namespace opt

View File

@ -141,7 +141,9 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() {
"SPV_NV_bindless_texture",
"SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives"});
"SPV_NV_compute_shader_derivatives",
"SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix"});
}
bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) {
std::vector<Instruction*> users;

View File

@ -43,6 +43,8 @@ bool MemPass::IsBaseTargetType(const Instruction* typeInst) const {
case spv::Op::OpTypeSampler:
case spv::Op::OpTypeSampledImage:
case spv::Op::OpTypePointer:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
return true;
default:
break;

View File

@ -215,6 +215,8 @@ OpCapability Float64
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
@ -434,6 +436,12 @@ OpName %main "main"
%ushort_0xBC00 = OpConstant %ushort 0xBC00
%short_0xBC00 = OpConstant %short 0xBC00
%int_arr_2_undef = OpUndef %int_arr_2
%int_coop_matrix = OpTypeCooperativeMatrixKHR %int %uint_3 %uint_3 %uint_32 %uint_0
%undef_int_coop_matrix = OpUndef %int_coop_matrix
%uint_coop_matrix = OpTypeCooperativeMatrixKHR %uint %uint_3 %uint_3 %uint_32 %uint_0
%undef_uint_coop_matrix = OpUndef %uint_coop_matrix
%float_coop_matrix = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_3 %uint_32 %uint_0
%undef_float_coop_matrix = OpUndef %float_coop_matrix
)";
return header;
@ -4148,6 +4156,62 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe
"%2 = OpSLessThan %bool %long_0 %long_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 41: Don't fold OpSNegate for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpSNegate %int_coop_matrix %undef_int_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 42: Don't fold OpIAdd for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpIAdd %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 43: Don't fold OpISub for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpISub %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 44: Don't fold OpIMul for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpIMul %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 45: Don't fold OpSDiv for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpSDiv %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 46: Don't fold OpUDiv for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpUDiv %uint_coop_matrix %undef_uint_coop_matrix %undef_uint_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 47: Don't fold OpMatrixTimesScalar for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpMatrixTimesScalar %uint_coop_matrix %undef_uint_coop_matrix %uint_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0)
));
@ -4689,6 +4753,54 @@ INSTANTIATE_TEST_SUITE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTes
"%2 = OpFDiv %half %half_1 %half_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 24: Don't fold OpFNegate for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFNegate %float_coop_matrix %undef_float_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 25: Don't fold OpIAdd for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFAdd %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 26: Don't fold OpISub for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFSub %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 27: Don't fold OpIMul for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFMul %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 28: Don't fold OpSDiv for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFDiv %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 29: Don't fold OpMatrixTimesScalar for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpMatrixTimesScalar %float_coop_matrix %undef_float_coop_matrix %float_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0)
));