From 199038f10cbe56bf7cbfeb5472eb0a25af2f09f5 Mon Sep 17 00:00:00 2001 From: Sven van Haastregt Date: Fri, 10 May 2024 21:49:10 +0200 Subject: [PATCH] spirv-val: Validate MemoryAccessMask of OpCooperativeMatrixStoreKHR (#5668) Reject `OpCooperativeMatrixStoreKHR` with a `MakePointerVisibleKHR` MemoryAccess operand, as `MakePointerVisibleKHR` is not supposed to be used with store operations. The `CoopMatKHRStoreMemoryAccessFail` test failed to catch this because it used the helper function `GenCoopMatLoadStoreShader` which generates `...NV` instead of `...KHR` instructions. Add a new helper function to generate similar shaders for the KHR extension, as the NV and KHR extensions have various subtle differences that makes parameterizing the original helper function non-trivial. Signed-off-by: Sven van Haastregt --- source/val/validate_memory.cpp | 3 +- test/val/val_memory_test.cpp | 208 +++++++++++++++++++++++++++++++-- 2 files changed, 201 insertions(+), 10 deletions(-) diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp index 2d6715f42..ef6676fb7 100644 --- a/source/val/validate_memory.cpp +++ b/source/val/validate_memory.cpp @@ -349,7 +349,8 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst, if (mask & uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR)) { if (inst->opcode() == spv::Op::OpStore || - inst->opcode() == spv::Op::OpCooperativeMatrixStoreNV) { + inst->opcode() == spv::Op::OpCooperativeMatrixStoreNV || + inst->opcode() == spv::Op::OpCooperativeMatrixStoreKHR) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "MakePointerVisibleKHR cannot be used with OpStore."; } diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp index 74a17e984..dfddc9872 100644 --- a/test/val/val_memory_test.cpp +++ b/test/val/val_memory_test.cpp @@ -2348,19 +2348,209 @@ OpFunctionEnd)"; EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } +std::string GenCoopMatLoadStoreShaderKHR(const std::string& storeMemoryAccess, + const std::string& loadMemoryAccess) { + std::string s = R"( +OpCapability Shader +OpCapability GroupNonUniform +OpCapability VulkanMemoryModelKHR +OpCapability CooperativeMatrixKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpExtension "SPV_KHR_cooperative_matrix" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical VulkanKHR +OpEntryPoint GLCompute %4 "main" %11 %21 +OpExecutionMode %4 LocalSize 1 1 1 +OpDecorate %11 BuiltIn SubgroupId +OpDecorate %21 BuiltIn WorkgroupId +OpDecorate %74 ArrayStride 4 +OpMemberDecorate %75 0 Offset 0 +OpDecorate %75 Block +OpDecorate %77 DescriptorSet 0 +OpDecorate %77 Binding 0 +OpDecorate %92 ArrayStride 4 +OpMemberDecorate %93 0 Offset 0 +OpDecorate %93 Block +OpDecorate %95 DescriptorSet 0 +OpDecorate %95 Binding 1 +OpDecorate %102 ArrayStride 4 +OpMemberDecorate %103 0 Offset 0 +OpDecorate %103 Block +OpDecorate %105 DescriptorSet 0 +OpDecorate %105 Binding 2 +OpDecorate %117 ArrayStride 4 +OpMemberDecorate %118 0 Offset 0 +OpDecorate %118 Block +OpDecorate %120 DescriptorSet 0 +OpDecorate %120 Binding 3 +OpDecorate %123 SpecId 2 +OpDecorate %124 SpecId 3 +OpDecorate %125 SpecId 4 +OpDecorate %126 SpecId 5 +OpDecorate %127 SpecId 0 +OpDecorate %128 SpecId 1 +OpDecorate %129 BuiltIn WorkgroupSize +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%6 = OpTypeInt 32 0 +%7 = OpTypeVector %6 2 +%8 = OpTypePointer Function %7 +%10 = OpTypePointer Input %6 +%11 = OpVariable %10 Input +%13 = OpConstant %6 2 +%19 = OpTypeVector %6 3 +%20 = OpTypePointer Input %19 +%21 = OpVariable %20 Input +%27 = OpConstantComposite %7 %13 %13 +%31 = OpTypePointer Function %6 +%33 = OpConstant %6 1024 +%34 = OpConstant %6 1 +%38 = OpConstant %6 8 +%39 = OpConstant %6 0 +%68 = OpTypeFloat 32 +%69 = OpConstant %6 16 +%70 = OpConstant %6 3 +%71 = OpTypeCooperativeMatrixKHR %68 %70 %69 %38 %39 +%72 = OpTypePointer Function %71 +%74 = OpTypeRuntimeArray %68 +%75 = OpTypeStruct %74 +%76 = OpTypePointer StorageBuffer %75 +%77 = OpVariable %76 StorageBuffer +%78 = OpTypeInt 32 1 +%79 = OpConstant %78 0 +%81 = OpConstant %6 5 +%82 = OpTypePointer StorageBuffer %68 +%84 = OpConstant %6 64 +%88 = OpTypePointer Private %71 +%89 = OpVariable %88 Private +%92 = OpTypeRuntimeArray %68 +%93 = OpTypeStruct %92 +%94 = OpTypePointer StorageBuffer %93 +%95 = OpVariable %94 StorageBuffer +%99 = OpVariable %88 Private +%102 = OpTypeRuntimeArray %68 +%103 = OpTypeStruct %102 +%104 = OpTypePointer StorageBuffer %103 +%105 = OpVariable %104 StorageBuffer +%109 = OpVariable %88 Private +%111 = OpVariable %88 Private +%112 = OpSpecConstantOp %6 CooperativeMatrixLengthKHR %71 +%113 = OpSpecConstantOp %78 IAdd %112 %79 +%117 = OpTypeRuntimeArray %68 +%118 = OpTypeStruct %117 +%119 = OpTypePointer StorageBuffer %118 +%120 = OpVariable %119 StorageBuffer +%123 = OpSpecConstant %78 1 +%124 = OpSpecConstant %78 1 +%125 = OpSpecConstant %78 1 +%126 = OpSpecConstant %78 1 +%127 = OpSpecConstant %6 1 +%128 = OpSpecConstant %6 1 +%129 = OpSpecConstantComposite %19 %127 %128 %34 +%4 = OpFunction %2 None %3 +%5 = OpLabel +%9 = OpVariable %8 Function +%18 = OpVariable %8 Function +%32 = OpVariable %31 Function +%44 = OpVariable %31 Function +%52 = OpVariable %31 Function +%60 = OpVariable %31 Function +%73 = OpVariable %72 Function +%91 = OpVariable %72 Function +%101 = OpVariable %72 Function +%12 = OpLoad %6 %11 +%14 = OpUMod %6 %12 %13 +%15 = OpLoad %6 %11 +%16 = OpUDiv %6 %15 %13 +%17 = OpCompositeConstruct %7 %14 %16 +OpStore %9 %17 +%22 = OpLoad %19 %21 +%23 = OpVectorShuffle %7 %22 %22 0 1 +%24 = OpCompositeExtract %6 %23 0 +%25 = OpCompositeExtract %6 %23 1 +%26 = OpCompositeConstruct %7 %24 %25 +%28 = OpIMul %7 %26 %27 +%29 = OpLoad %7 %9 +%30 = OpIAdd %7 %28 %29 +OpStore %18 %30 +%35 = OpAccessChain %31 %18 %34 +%36 = OpLoad %6 %35 +%37 = OpIMul %6 %33 %36 +%40 = OpAccessChain %31 %18 %39 +%41 = OpLoad %6 %40 +%42 = OpIMul %6 %38 %41 +%43 = OpIAdd %6 %37 %42 +OpStore %32 %43 +%45 = OpAccessChain %31 %18 %34 +%46 = OpLoad %6 %45 +%47 = OpIMul %6 %33 %46 +%48 = OpAccessChain %31 %18 %39 +%49 = OpLoad %6 %48 +%50 = OpIMul %6 %38 %49 +%51 = OpIAdd %6 %47 %50 +OpStore %44 %51 +%53 = OpAccessChain %31 %18 %34 +%54 = OpLoad %6 %53 +%55 = OpIMul %6 %33 %54 +%56 = OpAccessChain %31 %18 %39 +%57 = OpLoad %6 %56 +%58 = OpIMul %6 %38 %57 +%59 = OpIAdd %6 %55 %58 +OpStore %52 %59 +%61 = OpAccessChain %31 %18 %34 +%62 = OpLoad %6 %61 +%63 = OpIMul %6 %33 %62 +%64 = OpAccessChain %31 %18 %39 +%65 = OpLoad %6 %64 +%66 = OpIMul %6 %38 %65 +%67 = OpIAdd %6 %63 %66 +OpStore %60 %67 +%80 = OpLoad %6 %32 +%83 = OpAccessChain %82 %77 %79 %80 +%87 = OpCooperativeMatrixLoadKHR %71 %83 %39 %84 )" + + loadMemoryAccess + R"( %81 +OpStore %73 %87 +%90 = OpLoad %71 %73 +OpStore %89 %90 +%96 = OpLoad %6 %44 +%97 = OpAccessChain %82 %95 %79 %96 +%98 = OpCooperativeMatrixLoadKHR %71 %97 %39 %84 MakePointerVisibleKHR|NonPrivatePointerKHR %81 +OpStore %91 %98 +%100 = OpLoad %71 %91 +OpStore %99 %100 +%106 = OpLoad %6 %52 +%107 = OpAccessChain %82 %105 %79 %106 +%108 = OpCooperativeMatrixLoadKHR %71 %107 %39 %84 MakePointerVisibleKHR|NonPrivatePointerKHR %81 +OpStore %101 %108 +%110 = OpLoad %71 %101 +OpStore %109 %110 +%114 = OpConvertSToF %68 %113 +%115 = OpCompositeConstruct %71 %114 +OpStore %111 %115 +%116 = OpLoad %71 %111 +%121 = OpLoad %6 %60 +%122 = OpAccessChain %82 %120 %79 %121 +OpCooperativeMatrixStoreKHR %122 %116 %39 %84 )" + storeMemoryAccess + R"( %81 +OpReturn +OpFunctionEnd +)"; + + return s; +} + TEST_F(ValidateMemory, CoopMatKHRLoadStoreSuccess) { - std::string spirv = - GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR", - "MakePointerVisibleKHR|NonPrivatePointerKHR"); + std::string spirv = GenCoopMatLoadStoreShaderKHR( + "MakePointerAvailableKHR|NonPrivatePointerKHR", + "MakePointerVisibleKHR|NonPrivatePointerKHR"); CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); } TEST_F(ValidateMemory, CoopMatKHRStoreMemoryAccessFail) { - std::string spirv = - GenCoopMatLoadStoreShader("MakePointerVisibleKHR|NonPrivatePointerKHR", - "MakePointerVisibleKHR|NonPrivatePointerKHR"); + std::string spirv = GenCoopMatLoadStoreShaderKHR( + "MakePointerVisibleKHR|NonPrivatePointerKHR", + "MakePointerVisibleKHR|NonPrivatePointerKHR"); CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); @@ -2369,9 +2559,9 @@ TEST_F(ValidateMemory, CoopMatKHRStoreMemoryAccessFail) { } TEST_F(ValidateMemory, CoopMatKHRLoadMemoryAccessFail) { - std::string spirv = - GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR", - "MakePointerAvailableKHR|NonPrivatePointerKHR"); + std::string spirv = GenCoopMatLoadStoreShaderKHR( + "MakePointerAvailableKHR|NonPrivatePointerKHR", + "MakePointerAvailableKHR|NonPrivatePointerKHR"); CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));