diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp index 290d70ae5..f7d99cb6b 100644 --- a/source/opt/aggressive_dead_code_elim_pass.cpp +++ b/source/opt/aggressive_dead_code_elim_pass.cpp @@ -433,6 +433,15 @@ bool AggressiveDCEPass::AggressiveDCE(ir::Function* func) { else if (liveInst->opcode() == SpvOpFunctionParameter) { ProcessLoad(liveInst->result_id()); } + // We treat an OpImageTexelPointer as a load of the pointer, and + // that value is manipulated to get the result. + else if (liveInst->opcode() == SpvOpImageTexelPointer) { + uint32_t varId; + (void)GetPtr(liveInst, &varId); + if (varId != 0) { + ProcessLoad(varId); + } + } worklist_.pop(); } diff --git a/source/opt/instruction.cpp b/source/opt/instruction.cpp index f1e483b9e..ba8413e35 100644 --- a/source/opt/instruction.cpp +++ b/source/opt/instruction.cpp @@ -158,8 +158,8 @@ bool Instruction::IsReadOnlyLoad() const { Instruction* Instruction::GetBaseAddress() const { assert((IsLoad() || opcode() == SpvOpStore || opcode() == SpvOpAccessChain || - opcode() == SpvOpInBoundsAccessChain || - opcode() == SpvOpCopyObject) && + opcode() == SpvOpInBoundsAccessChain || opcode() == SpvOpCopyObject || + opcode() == SpvOpImageTexelPointer) && "GetBaseAddress should only be called on instructions that take a " "pointer or image."); uint32_t base = GetSingleWordInOperand(kLoadBaseIndex); diff --git a/source/opt/mem_pass.cpp b/source/opt/mem_pass.cpp index 178bf95f4..13364da0c 100644 --- a/source/opt/mem_pass.cpp +++ b/source/opt/mem_pass.cpp @@ -28,9 +28,7 @@ namespace opt { namespace { const uint32_t kCopyObjectOperandInIdx = 0; -const uint32_t kLoadPtrIdInIdx = 0; const uint32_t kLoopMergeMergeBlockIdInIdx = 0; -const uint32_t kStorePtrIdInIdx = 0; const uint32_t kStoreValIdInIdx = 1; const uint32_t kTypePointerStorageClassInIdx = 0; const uint32_t kTypePointerTypeIdInIdx = 1; @@ -119,10 +117,11 @@ ir::Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) { } ir::Instruction* MemPass::GetPtr(ir::Instruction* ip, uint32_t* varId) { - const SpvOp op = ip->opcode(); - assert(op == SpvOpStore || op == SpvOpLoad); - const uint32_t ptrId = ip->GetSingleWordInOperand( - op == SpvOpStore ? kStorePtrIdInIdx : kLoadPtrIdInIdx); + assert(ip->opcode() == SpvOpStore || ip->opcode() == SpvOpLoad || + ip->opcode() == SpvOpImageTexelPointer); + + // All of these opcode place the pointer in position 0. + const uint32_t ptrId = ip->GetSingleWordInOperand(0); return GetPtr(ptrId, varId); } diff --git a/test/opt/aggressive_dead_code_elim_test.cpp b/test/opt/aggressive_dead_code_elim_test.cpp index 1f2fee88d..b3346dbff 100644 --- a/test/opt/aggressive_dead_code_elim_test.cpp +++ b/test/opt/aggressive_dead_code_elim_test.cpp @@ -5369,6 +5369,47 @@ OpFunctionEnd SinglePassRunAndCheck(text, text, true, true); } +TEST_F(AggressiveDCETest, AtomicAdd) { + const std::string text = R"(OpCapability SampledBuffer +OpCapability StorageImageExtendedFormats +OpCapability ImageBuffer +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %2 "min" %gl_GlobalInvocationID +OpExecutionMode %2 LocalSize 64 1 1 +OpSource HLSL 600 +OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId +OpDecorate %4 DescriptorSet 4 +OpDecorate %4 Binding 70 +%uint = OpTypeInt 32 0 +%6 = OpTypeImage %uint Buffer 0 0 0 2 R32ui +%_ptr_UniformConstant_6 = OpTypePointer UniformConstant %6 +%_ptr_Private_6 = OpTypePointer Private %6 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%v3uint = OpTypeVector %uint 3 +%_ptr_Input_v3uint = OpTypePointer Input %v3uint +%_ptr_Image_uint = OpTypePointer Image %uint +%4 = OpVariable %_ptr_UniformConstant_6 UniformConstant +%16 = OpVariable %_ptr_Private_6 Private +%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input +%2 = OpFunction %void None %10 +%17 = OpLabel +%18 = OpLoad %6 %4 +OpStore %16 %18 +%19 = OpImageTexelPointer %_ptr_Image_uint %16 %uint_0 %uint_0 +%20 = OpAtomicIAdd %uint %19 %uint_1 %uint_0 %uint_1 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(text, text, true, true); +} + // TODO(greg-lunarg): Add tests to verify handling of these cases: // // Check that logical addressing required