diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp index 6997bf6b2..9ae1ae8e8 100644 --- a/source/opt/scalar_replacement_pass.cpp +++ b/source/opt/scalar_replacement_pass.cpp @@ -60,25 +60,25 @@ Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) { Instruction* varInst = worklist.front(); worklist.pop(); - if (!ReplaceVariable(varInst, &worklist)) - return Status::Failure; - else - status = Status::SuccessWithChange; + Status var_status = ReplaceVariable(varInst, &worklist); + if (var_status == Status::Failure) + return var_status; + else if (var_status == Status::SuccessWithChange) + status = var_status; } return status; } -bool ScalarReplacementPass::ReplaceVariable( +Pass::Status ScalarReplacementPass::ReplaceVariable( Instruction* inst, std::queue* worklist) { std::vector replacements; if (!CreateReplacementVariables(inst, &replacements)) { - return false; + return Status::Failure; } std::vector dead; - dead.push_back(inst); - if (!get_def_use_mgr()->WhileEachUser( + if (get_def_use_mgr()->WhileEachUser( inst, [this, &replacements, &dead](Instruction* user) { if (!IsAnnotationInst(user->opcode())) { switch (user->opcode()) { @@ -92,8 +92,10 @@ bool ScalarReplacementPass::ReplaceVariable( break; case SpvOpAccessChain: case SpvOpInBoundsAccessChain: - if (!ReplaceAccessChain(user, replacements)) return false; - dead.push_back(user); + if (ReplaceAccessChain(user, replacements)) + dead.push_back(user); + else + return false; break; case SpvOpName: case SpvOpMemberName: @@ -105,7 +107,10 @@ bool ScalarReplacementPass::ReplaceVariable( } return true; })) - return false; + dead.push_back(inst); + + // If there are no dead instructions to clean up, return with no changes. + if (dead.empty()) return Status::SuccessWithoutChange; // Clean up some dead code. while (!dead.empty()) { @@ -125,7 +130,7 @@ bool ScalarReplacementPass::ReplaceVariable( } } - return true; + return Status::SuccessWithChange; } void ScalarReplacementPass::ReplaceWholeLoad( @@ -228,8 +233,9 @@ bool ScalarReplacementPass::ReplaceAccessChain( uint32_t indexId = chain->GetSingleWordInOperand(1u); const Instruction* index = get_def_use_mgr()->GetDef(indexId); uint64_t indexValue = GetConstantInteger(index); - if (indexValue > replacements.size()) { - // Out of bounds access, this is illegal IR. + if (indexValue >= replacements.size()) { + // Out of bounds access, this is illegal IR. Notice that OpAccessChain + // indexing is 0-based, so we should also reject index == size-of-array. return false; } else { const Instruction* var = replacements[static_cast(indexValue)]; diff --git a/source/opt/scalar_replacement_pass.h b/source/opt/scalar_replacement_pass.h index ca8c0b453..3a17045b4 100644 --- a/source/opt/scalar_replacement_pass.h +++ b/source/opt/scalar_replacement_pass.h @@ -117,9 +117,12 @@ class ScalarReplacementPass : public Pass { // for element of the composite type. Uses of |inst| are updated as // appropriate. If the replacement variables are themselves scalarizable, they // get added to |worklist| for further processing. If any replacement - // variable ends up with no uses it is erased. Returns false if the variable - // could not be replaced. - bool ReplaceVariable(Instruction* inst, std::queue* worklist); + // variable ends up with no uses it is erased. Returns + // - Status::SuccessWithoutChange if the variable could not be replaced. + // - Status::SuccessWithChange if it made replacements. + // - Status::Failure if it couldn't create replacement variables. + Pass::Status ReplaceVariable(Instruction* inst, + std::queue* worklist); // Returns the underlying storage type for |inst|. // diff --git a/test/opt/scalar_replacement_test.cpp b/test/opt/scalar_replacement_test.cpp index 6c182d5e6..2ed7b5a13 100644 --- a/test/opt/scalar_replacement_test.cpp +++ b/test/opt/scalar_replacement_test.cpp @@ -1657,6 +1657,50 @@ OpFunctionEnd EXPECT_EQ(Pass::Status::Failure, std::get<1>(result)); } +// Test that replacements for OpAccessChain do not go out of bounds. +// https://github.com/KhronosGroup/SPIRV-Tools/issues/2609. +TEST_F(ScalarReplacementTest, OutOfBoundOpAccessChain) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %_GLF_color + OpExecutionMode %main OriginUpperLeft + OpSource ESSL 310 + OpName %main "main" + OpName %a "a" + OpName %_GLF_color "_GLF_color" + OpDecorate %_GLF_color Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_1 = OpConstant %int 1 + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 + %uint_1 = OpConstant %uint 1 +%_arr_float_uint_1 = OpTypeArray %float %uint_1 +%_ptr_Function__arr_float_uint_1 = OpTypePointer Function %_arr_float_uint_1 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Output_float = OpTypePointer Output %float + %_GLF_color = OpVariable %_ptr_Output_float Output + %main = OpFunction %void None %3 + %5 = OpLabel + %a = OpVariable %_ptr_Function__arr_float_uint_1 Function + %21 = OpAccessChain %_ptr_Function_float %a %int_1 + %22 = OpLoad %float %21 + OpStore %_GLF_color %22 + OpReturn + OpFunctionEnd + )"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + + auto result = + SinglePassRunAndDisassemble(text, true, false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + } // namespace } // namespace opt } // namespace spvtools