diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp index 9b8c112e1..9ff4ec6f0 100644 --- a/source/opt/local_access_chain_convert_pass.cpp +++ b/source/opt/local_access_chain_convert_pass.cpp @@ -77,6 +77,15 @@ void LocalAccessChainConvertPass::AppendConstantOperands( bool LocalAccessChainConvertPass::ReplaceAccessChainLoad( const Instruction* address_inst, Instruction* original_load) { // Build and append load of variable in ptrInst + if (address_inst->NumInOperands() == 1) { + // An access chain with no indices is essentially a copy. All that is + // needed is to propagate the address. + context()->ReplaceAllUsesWith( + address_inst->result_id(), + address_inst->GetSingleWordInOperand(kAccessChainPtrIdInIdx)); + return true; + } + std::vector> new_inst; uint32_t varId; uint32_t varPteTypeId; @@ -109,6 +118,18 @@ bool LocalAccessChainConvertPass::ReplaceAccessChainLoad( bool LocalAccessChainConvertPass::GenAccessChainStoreReplacement( const Instruction* ptrInst, uint32_t valId, std::vector>* newInsts) { + if (ptrInst->NumInOperands() == 1) { + // An access chain with no indices is essentially a copy. However, we still + // have to create a new store because the old ones will be deleted. + BuildAndAppendInst( + SpvOpStore, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, + {ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx)}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}}}, + newInsts); + return true; + } + // Build and append load of variable in ptrInst uint32_t varId; uint32_t varPteTypeId; @@ -246,11 +267,13 @@ Pass::Status LocalAccessChainConvertPass::ConvertLocalAccessChains( if (!GenAccessChainStoreReplacement(ptrInst, valId, &newInsts)) { return Status::Failure; } + size_t num_of_instructions_to_skip = newInsts.size() - 1; dead_instructions.push_back(&*ii); ++ii; ii = ii.InsertBefore(std::move(newInsts)); - ++ii; - ++ii; + for (size_t i = 0; i < num_of_instructions_to_skip; ++i) { + ++ii; + } modified = true; } break; default: diff --git a/test/opt/local_access_chain_convert_test.cpp b/test/opt/local_access_chain_convert_test.cpp index 39899e3ee..3161d903f 100644 --- a/test/opt/local_access_chain_convert_test.cpp +++ b/test/opt/local_access_chain_convert_test.cpp @@ -927,6 +927,37 @@ TEST_F(LocalAccessChainConvertTest, IdOverflowReplacingStore2) { EXPECT_EQ(Pass::Status::Failure, std::get<1>(result)); } +TEST_F(LocalAccessChainConvertTest, AccessChainWithNoIndex) { + const std::string before = + R"( +; CHECK: OpFunction +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: OpStore [[var]] %true +; CHECK: OpLoad %bool [[var]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %void = OpTypeVoid + %4 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool +%_ptr_Function_bool = OpTypePointer Function %bool + %2 = OpFunction %void None %4 + %8 = OpLabel + %9 = OpVariable %_ptr_Function_bool Function + %10 = OpAccessChain %_ptr_Function_bool %9 + OpStore %10 %true + %11 = OpLoad %bool %10 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(before, true); +} + // TODO(greg-lunarg): Add tests to verify handling of these cases: // // Assorted vector and matrix types