opt: prevent meld to merge block with MaximalReconvergence (#5557)

The extension SPV_KHR_maximal_reconvergence adds more constraints
around the merge blocks, and how the control flow can be altered.

The one we address here is explained in the following part of the spec:

  Note: This means that the instructions in a break block will execute as if
  they were still diverged according to the loop iteration. This restricts
  potential transformations an implementation may perform on the IR to match
  shader author expectations. Similarly, instructions in the loop construct
  cannot be moved into the continue construct unless it can be proven that
  invocations are always converged.

Until the optimizer is clever enough to determine if the invocation
have already converged, we shall not meld a block which branches to a
merge block into it, as it might move some instructions outside of the
convergence region.

This behavior being only required with the extension, this commit
behavior change is gated by the extension.
This means using wave operations without the maximal reconvergence
extension might lead to undefined behaviors.

Co-authored-by: Natalie Chouinard <chouinard.nm@gmail.com>
This commit is contained in:
Nathan Gauër 2024-02-06 12:12:00 +01:00 committed by GitHub
parent 6c11c2bd46
commit ab59dc6087
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 153 additions and 0 deletions

View File

@ -98,6 +98,17 @@ bool CanMergeWithSuccessor(IRContext* context, BasicBlock* block) {
return false;
}
// Note: This means that the instructions in a break block will execute as if
// they were still diverged according to the loop iteration. This restricts
// potential transformations an implementation may perform on the IR to match
// shader author expectations. Similarly, instructions in the loop construct
// cannot be moved into the continue construct unless it can be proven that
// invocations are always converged.
if (succ_is_merge && context->get_feature_mgr()->HasExtension(
kSPV_KHR_maximal_reconvergence)) {
return false;
}
if (pred_is_merge && IsContinue(context, lab_id)) {
// Cannot merge a continue target with a merge block.
return false;

View File

@ -1320,6 +1320,148 @@ OpFunctionEnd
SinglePassRunAndMatch<BlockMergePass>(text, true);
}
TEST_F(BlockMergeTest, MaximalReconvergenceNoMeldToMerge) {
const std::string text = R"(
OpCapability Shader
OpCapability GroupNonUniformBallot
OpCapability GroupNonUniformArithmetic
OpExtension "SPV_KHR_maximal_reconvergence"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main" %gl_GlobalInvocationID %output
OpExecutionMode %main LocalSize 1 1 1
OpExecutionMode %main MaximallyReconvergesKHR
OpSource HLSL 660
OpName %type_RWStructuredBuffer_uint "type.RWStructuredBuffer.uint"
OpName %output "output"
OpName %main "main"
OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
OpDecorate %output DescriptorSet 0
OpDecorate %output Binding 0
OpDecorate %_runtimearr_uint ArrayStride 4
OpMemberDecorate %type_RWStructuredBuffer_uint 0 Offset 0
OpDecorate %type_RWStructuredBuffer_uint Block
%uint = OpTypeInt 32 0
%bool = OpTypeBool
%int = OpTypeInt 32 1
%int_0 = OpConstant %int 0
%int_1 = OpConstant %int 1
%_runtimearr_uint = OpTypeRuntimeArray %uint
%type_RWStructuredBuffer_uint = OpTypeStruct %_runtimearr_uint
%_ptr_StorageBuffer_type_RWStructuredBuffer_uint = OpTypePointer StorageBuffer %type_RWStructuredBuffer_uint
%v3uint = OpTypeVector %uint 3
%_ptr_Input_v3uint = OpTypePointer Input %v3uint
%void = OpTypeVoid
%15 = OpTypeFunction %void
%uint_3 = OpConstant %uint 3
%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
%output = OpVariable %_ptr_StorageBuffer_type_RWStructuredBuffer_uint StorageBuffer
%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
%main = OpFunction %void None %15
%18 = OpLabel
%19 = OpLoad %v3uint %gl_GlobalInvocationID
OpBranch %20
%20 = OpLabel
OpLoopMerge %21 %22 None
; CHECK: OpLoopMerge [[merge:%\w+]] [[continue:%\w+]]
OpBranch %23
%23 = OpLabel
%24 = OpCompositeExtract %uint %19 0
%25 = OpGroupNonUniformBroadcastFirst %uint %uint_3 %24
%26 = OpIEqual %bool %24 %25
OpSelectionMerge %27 None
OpBranchConditional %26 %28 %27
%28 = OpLabel
%29 = OpGroupNonUniformIAdd %int %uint_3 Reduce %int_1
%30 = OpBitcast %uint %29
OpBranch %21
; CHECK: [[t1:%\w+]] = OpGroupNonUniformIAdd %int %uint_3 Reduce %int_1
; CHECK-NEXT: [[t2:%\w+]] = OpBitcast %uint [[t1]]
; CHECK-NEXT: OpBranch [[merge]]
%27 = OpLabel
OpBranch %22
%22 = OpLabel
OpBranch %20
%21 = OpLabel
%31 = OpAccessChain %_ptr_StorageBuffer_uint %output %int_0 %24
OpStore %31 %30
OpReturn
OpFunctionEnd
)";
SetTargetEnv(SPV_ENV_VULKAN_1_3);
SinglePassRunAndMatch<BlockMergePass>(text, true);
}
TEST_F(BlockMergeTest, NoMaximalReconvergenceMeldToMerge) {
const std::string text = R"(
OpCapability Shader
OpCapability GroupNonUniformBallot
OpCapability GroupNonUniformArithmetic
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main" %gl_GlobalInvocationID %output
OpExecutionMode %main LocalSize 1 1 1
OpSource HLSL 660
OpName %type_RWStructuredBuffer_uint "type.RWStructuredBuffer.uint"
OpName %output "output"
OpName %main "main"
OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
OpDecorate %output DescriptorSet 0
OpDecorate %output Binding 0
OpDecorate %_runtimearr_uint ArrayStride 4
OpMemberDecorate %type_RWStructuredBuffer_uint 0 Offset 0
OpDecorate %type_RWStructuredBuffer_uint Block
%uint = OpTypeInt 32 0
%bool = OpTypeBool
%int = OpTypeInt 32 1
%int_0 = OpConstant %int 0
%int_1 = OpConstant %int 1
%_runtimearr_uint = OpTypeRuntimeArray %uint
%type_RWStructuredBuffer_uint = OpTypeStruct %_runtimearr_uint
%_ptr_StorageBuffer_type_RWStructuredBuffer_uint = OpTypePointer StorageBuffer %type_RWStructuredBuffer_uint
%v3uint = OpTypeVector %uint 3
%_ptr_Input_v3uint = OpTypePointer Input %v3uint
%void = OpTypeVoid
%15 = OpTypeFunction %void
%uint_3 = OpConstant %uint 3
%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
%output = OpVariable %_ptr_StorageBuffer_type_RWStructuredBuffer_uint StorageBuffer
%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
%main = OpFunction %void None %15
%18 = OpLabel
%19 = OpLoad %v3uint %gl_GlobalInvocationID
OpBranch %20
%20 = OpLabel
OpLoopMerge %21 %22 None
; CHECK: OpLoopMerge [[merge:%\w+]] [[continue:%\w+]]
OpBranch %23
%23 = OpLabel
%24 = OpCompositeExtract %uint %19 0
%25 = OpGroupNonUniformBroadcastFirst %uint %uint_3 %24
%26 = OpIEqual %bool %24 %25
OpSelectionMerge %27 None
OpBranchConditional %26 %28 %27
%28 = OpLabel
%29 = OpGroupNonUniformIAdd %int %uint_3 Reduce %int_1
%30 = OpBitcast %uint %29
OpBranch %21
; CHECK: [[merge]] = OpLabel
; CHECK-NEXT: [[t1:%\w+]] = OpGroupNonUniformIAdd %int %uint_3 Reduce %int_1
; CHECK-NEXT: [[t2:%\w+]] = OpBitcast %uint [[t1]]
%27 = OpLabel
OpBranch %22
%22 = OpLabel
OpBranch %20
%21 = OpLabel
%31 = OpAccessChain %_ptr_StorageBuffer_uint %output %int_0 %24
OpStore %31 %30
OpReturn
OpFunctionEnd
)";
SetTargetEnv(SPV_ENV_VULKAN_1_3);
SinglePassRunAndMatch<BlockMergePass>(text, true);
}
// TODO(greg-lunarg): Add tests to verify handling of these cases:
//
// More complex control flow