spirv-val: Validate MemoryAccessMask of OpCooperativeMatrixStoreKHR (#5668)

Reject `OpCooperativeMatrixStoreKHR` with a `MakePointerVisibleKHR`
MemoryAccess operand, as `MakePointerVisibleKHR` is not supposed to be
used with store operations.

The `CoopMatKHRStoreMemoryAccessFail` test failed to catch this
because it used the helper function `GenCoopMatLoadStoreShader` which
generates `...NV` instead of `...KHR` instructions.  Add a new helper
function to generate similar shaders for the KHR extension, as the NV
and KHR extensions have various subtle differences that makes
parameterizing the original helper function non-trivial.

Signed-off-by: Sven van Haastregt <sven.vanhaastregt@arm.com>
This commit is contained in:
Sven van Haastregt 2024-05-10 21:49:10 +02:00 committed by GitHub
parent 9241a58a80
commit 199038f10c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 201 additions and 10 deletions

View File

@ -349,7 +349,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
if (mask & uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR)) {
if (inst->opcode() == spv::Op::OpStore ||
inst->opcode() == spv::Op::OpCooperativeMatrixStoreNV) {
inst->opcode() == spv::Op::OpCooperativeMatrixStoreNV ||
inst->opcode() == spv::Op::OpCooperativeMatrixStoreKHR) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "MakePointerVisibleKHR cannot be used with OpStore.";
}

View File

@ -2348,19 +2348,209 @@ OpFunctionEnd)";
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
std::string GenCoopMatLoadStoreShaderKHR(const std::string& storeMemoryAccess,
const std::string& loadMemoryAccess) {
std::string s = R"(
OpCapability Shader
OpCapability GroupNonUniform
OpCapability VulkanMemoryModelKHR
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_vulkan_memory_model"
OpExtension "SPV_KHR_cooperative_matrix"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical VulkanKHR
OpEntryPoint GLCompute %4 "main" %11 %21
OpExecutionMode %4 LocalSize 1 1 1
OpDecorate %11 BuiltIn SubgroupId
OpDecorate %21 BuiltIn WorkgroupId
OpDecorate %74 ArrayStride 4
OpMemberDecorate %75 0 Offset 0
OpDecorate %75 Block
OpDecorate %77 DescriptorSet 0
OpDecorate %77 Binding 0
OpDecorate %92 ArrayStride 4
OpMemberDecorate %93 0 Offset 0
OpDecorate %93 Block
OpDecorate %95 DescriptorSet 0
OpDecorate %95 Binding 1
OpDecorate %102 ArrayStride 4
OpMemberDecorate %103 0 Offset 0
OpDecorate %103 Block
OpDecorate %105 DescriptorSet 0
OpDecorate %105 Binding 2
OpDecorate %117 ArrayStride 4
OpMemberDecorate %118 0 Offset 0
OpDecorate %118 Block
OpDecorate %120 DescriptorSet 0
OpDecorate %120 Binding 3
OpDecorate %123 SpecId 2
OpDecorate %124 SpecId 3
OpDecorate %125 SpecId 4
OpDecorate %126 SpecId 5
OpDecorate %127 SpecId 0
OpDecorate %128 SpecId 1
OpDecorate %129 BuiltIn WorkgroupSize
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%6 = OpTypeInt 32 0
%7 = OpTypeVector %6 2
%8 = OpTypePointer Function %7
%10 = OpTypePointer Input %6
%11 = OpVariable %10 Input
%13 = OpConstant %6 2
%19 = OpTypeVector %6 3
%20 = OpTypePointer Input %19
%21 = OpVariable %20 Input
%27 = OpConstantComposite %7 %13 %13
%31 = OpTypePointer Function %6
%33 = OpConstant %6 1024
%34 = OpConstant %6 1
%38 = OpConstant %6 8
%39 = OpConstant %6 0
%68 = OpTypeFloat 32
%69 = OpConstant %6 16
%70 = OpConstant %6 3
%71 = OpTypeCooperativeMatrixKHR %68 %70 %69 %38 %39
%72 = OpTypePointer Function %71
%74 = OpTypeRuntimeArray %68
%75 = OpTypeStruct %74
%76 = OpTypePointer StorageBuffer %75
%77 = OpVariable %76 StorageBuffer
%78 = OpTypeInt 32 1
%79 = OpConstant %78 0
%81 = OpConstant %6 5
%82 = OpTypePointer StorageBuffer %68
%84 = OpConstant %6 64
%88 = OpTypePointer Private %71
%89 = OpVariable %88 Private
%92 = OpTypeRuntimeArray %68
%93 = OpTypeStruct %92
%94 = OpTypePointer StorageBuffer %93
%95 = OpVariable %94 StorageBuffer
%99 = OpVariable %88 Private
%102 = OpTypeRuntimeArray %68
%103 = OpTypeStruct %102
%104 = OpTypePointer StorageBuffer %103
%105 = OpVariable %104 StorageBuffer
%109 = OpVariable %88 Private
%111 = OpVariable %88 Private
%112 = OpSpecConstantOp %6 CooperativeMatrixLengthKHR %71
%113 = OpSpecConstantOp %78 IAdd %112 %79
%117 = OpTypeRuntimeArray %68
%118 = OpTypeStruct %117
%119 = OpTypePointer StorageBuffer %118
%120 = OpVariable %119 StorageBuffer
%123 = OpSpecConstant %78 1
%124 = OpSpecConstant %78 1
%125 = OpSpecConstant %78 1
%126 = OpSpecConstant %78 1
%127 = OpSpecConstant %6 1
%128 = OpSpecConstant %6 1
%129 = OpSpecConstantComposite %19 %127 %128 %34
%4 = OpFunction %2 None %3
%5 = OpLabel
%9 = OpVariable %8 Function
%18 = OpVariable %8 Function
%32 = OpVariable %31 Function
%44 = OpVariable %31 Function
%52 = OpVariable %31 Function
%60 = OpVariable %31 Function
%73 = OpVariable %72 Function
%91 = OpVariable %72 Function
%101 = OpVariable %72 Function
%12 = OpLoad %6 %11
%14 = OpUMod %6 %12 %13
%15 = OpLoad %6 %11
%16 = OpUDiv %6 %15 %13
%17 = OpCompositeConstruct %7 %14 %16
OpStore %9 %17
%22 = OpLoad %19 %21
%23 = OpVectorShuffle %7 %22 %22 0 1
%24 = OpCompositeExtract %6 %23 0
%25 = OpCompositeExtract %6 %23 1
%26 = OpCompositeConstruct %7 %24 %25
%28 = OpIMul %7 %26 %27
%29 = OpLoad %7 %9
%30 = OpIAdd %7 %28 %29
OpStore %18 %30
%35 = OpAccessChain %31 %18 %34
%36 = OpLoad %6 %35
%37 = OpIMul %6 %33 %36
%40 = OpAccessChain %31 %18 %39
%41 = OpLoad %6 %40
%42 = OpIMul %6 %38 %41
%43 = OpIAdd %6 %37 %42
OpStore %32 %43
%45 = OpAccessChain %31 %18 %34
%46 = OpLoad %6 %45
%47 = OpIMul %6 %33 %46
%48 = OpAccessChain %31 %18 %39
%49 = OpLoad %6 %48
%50 = OpIMul %6 %38 %49
%51 = OpIAdd %6 %47 %50
OpStore %44 %51
%53 = OpAccessChain %31 %18 %34
%54 = OpLoad %6 %53
%55 = OpIMul %6 %33 %54
%56 = OpAccessChain %31 %18 %39
%57 = OpLoad %6 %56
%58 = OpIMul %6 %38 %57
%59 = OpIAdd %6 %55 %58
OpStore %52 %59
%61 = OpAccessChain %31 %18 %34
%62 = OpLoad %6 %61
%63 = OpIMul %6 %33 %62
%64 = OpAccessChain %31 %18 %39
%65 = OpLoad %6 %64
%66 = OpIMul %6 %38 %65
%67 = OpIAdd %6 %63 %66
OpStore %60 %67
%80 = OpLoad %6 %32
%83 = OpAccessChain %82 %77 %79 %80
%87 = OpCooperativeMatrixLoadKHR %71 %83 %39 %84 )" +
loadMemoryAccess + R"( %81
OpStore %73 %87
%90 = OpLoad %71 %73
OpStore %89 %90
%96 = OpLoad %6 %44
%97 = OpAccessChain %82 %95 %79 %96
%98 = OpCooperativeMatrixLoadKHR %71 %97 %39 %84 MakePointerVisibleKHR|NonPrivatePointerKHR %81
OpStore %91 %98
%100 = OpLoad %71 %91
OpStore %99 %100
%106 = OpLoad %6 %52
%107 = OpAccessChain %82 %105 %79 %106
%108 = OpCooperativeMatrixLoadKHR %71 %107 %39 %84 MakePointerVisibleKHR|NonPrivatePointerKHR %81
OpStore %101 %108
%110 = OpLoad %71 %101
OpStore %109 %110
%114 = OpConvertSToF %68 %113
%115 = OpCompositeConstruct %71 %114
OpStore %111 %115
%116 = OpLoad %71 %111
%121 = OpLoad %6 %60
%122 = OpAccessChain %82 %120 %79 %121
OpCooperativeMatrixStoreKHR %122 %116 %39 %84 )" + storeMemoryAccess + R"( %81
OpReturn
OpFunctionEnd
)";
return s;
}
TEST_F(ValidateMemory, CoopMatKHRLoadStoreSuccess) {
std::string spirv =
GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
"MakePointerVisibleKHR|NonPrivatePointerKHR");
std::string spirv = GenCoopMatLoadStoreShaderKHR(
"MakePointerAvailableKHR|NonPrivatePointerKHR",
"MakePointerVisibleKHR|NonPrivatePointerKHR");
CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1));
}
TEST_F(ValidateMemory, CoopMatKHRStoreMemoryAccessFail) {
std::string spirv =
GenCoopMatLoadStoreShader("MakePointerVisibleKHR|NonPrivatePointerKHR",
"MakePointerVisibleKHR|NonPrivatePointerKHR");
std::string spirv = GenCoopMatLoadStoreShaderKHR(
"MakePointerVisibleKHR|NonPrivatePointerKHR",
"MakePointerVisibleKHR|NonPrivatePointerKHR");
CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
@ -2369,9 +2559,9 @@ TEST_F(ValidateMemory, CoopMatKHRStoreMemoryAccessFail) {
}
TEST_F(ValidateMemory, CoopMatKHRLoadMemoryAccessFail) {
std::string spirv =
GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
"MakePointerAvailableKHR|NonPrivatePointerKHR");
std::string spirv = GenCoopMatLoadStoreShaderKHR(
"MakePointerAvailableKHR|NonPrivatePointerKHR",
"MakePointerAvailableKHR|NonPrivatePointerKHR");
CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));