Fix #2609 - Handle out-of-bounds scalar replacements. (#2767)

* Fix #2609 - Handle out-of-bounds scalar replacements.

When SROA tries to do a replacement for an OpAccessChain that is exactly
one element out of bounds, the code was trying to access its internal
array of replacements and segfaulting.

This protects the code from doing this, and it additionally fixes the
way SROA works by not returning failure when it refuses to do a
replacement.  Instead of failing the optimization pass, SROA will now
simply refuse to do the replacement and keep going.

Additionally, this patch fixes the SROA logic to now return a proper status so we can
correctly state that the pass made no changes to the IR if it only found
invalid references.
This commit is contained in:
Diego Novillo 2019-07-26 12:33:40 -04:00 committed by GitHub
parent f54b8653dd
commit 9559cdbdf0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 17 deletions

View File

@ -60,25 +60,25 @@ Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
Instruction* varInst = worklist.front();
worklist.pop();
if (!ReplaceVariable(varInst, &worklist))
return Status::Failure;
else
status = Status::SuccessWithChange;
Status var_status = ReplaceVariable(varInst, &worklist);
if (var_status == Status::Failure)
return var_status;
else if (var_status == Status::SuccessWithChange)
status = var_status;
}
return status;
}
bool ScalarReplacementPass::ReplaceVariable(
Pass::Status ScalarReplacementPass::ReplaceVariable(
Instruction* inst, std::queue<Instruction*>* worklist) {
std::vector<Instruction*> replacements;
if (!CreateReplacementVariables(inst, &replacements)) {
return false;
return Status::Failure;
}
std::vector<Instruction*> dead;
dead.push_back(inst);
if (!get_def_use_mgr()->WhileEachUser(
if (get_def_use_mgr()->WhileEachUser(
inst, [this, &replacements, &dead](Instruction* user) {
if (!IsAnnotationInst(user->opcode())) {
switch (user->opcode()) {
@ -92,8 +92,10 @@ bool ScalarReplacementPass::ReplaceVariable(
break;
case SpvOpAccessChain:
case SpvOpInBoundsAccessChain:
if (!ReplaceAccessChain(user, replacements)) return false;
if (ReplaceAccessChain(user, replacements))
dead.push_back(user);
else
return false;
break;
case SpvOpName:
case SpvOpMemberName:
@ -105,7 +107,10 @@ bool ScalarReplacementPass::ReplaceVariable(
}
return true;
}))
return false;
dead.push_back(inst);
// If there are no dead instructions to clean up, return with no changes.
if (dead.empty()) return Status::SuccessWithoutChange;
// Clean up some dead code.
while (!dead.empty()) {
@ -125,7 +130,7 @@ bool ScalarReplacementPass::ReplaceVariable(
}
}
return true;
return Status::SuccessWithChange;
}
void ScalarReplacementPass::ReplaceWholeLoad(
@ -228,8 +233,9 @@ bool ScalarReplacementPass::ReplaceAccessChain(
uint32_t indexId = chain->GetSingleWordInOperand(1u);
const Instruction* index = get_def_use_mgr()->GetDef(indexId);
uint64_t indexValue = GetConstantInteger(index);
if (indexValue > replacements.size()) {
// Out of bounds access, this is illegal IR.
if (indexValue >= replacements.size()) {
// Out of bounds access, this is illegal IR. Notice that OpAccessChain
// indexing is 0-based, so we should also reject index == size-of-array.
return false;
} else {
const Instruction* var = replacements[static_cast<size_t>(indexValue)];

View File

@ -117,9 +117,12 @@ class ScalarReplacementPass : public Pass {
// for element of the composite type. Uses of |inst| are updated as
// appropriate. If the replacement variables are themselves scalarizable, they
// get added to |worklist| for further processing. If any replacement
// variable ends up with no uses it is erased. Returns false if the variable
// could not be replaced.
bool ReplaceVariable(Instruction* inst, std::queue<Instruction*>* worklist);
// variable ends up with no uses it is erased. Returns
// - Status::SuccessWithoutChange if the variable could not be replaced.
// - Status::SuccessWithChange if it made replacements.
// - Status::Failure if it couldn't create replacement variables.
Pass::Status ReplaceVariable(Instruction* inst,
std::queue<Instruction*>* worklist);
// Returns the underlying storage type for |inst|.
//

View File

@ -1657,6 +1657,50 @@ OpFunctionEnd
EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
}
// Test that replacements for OpAccessChain do not go out of bounds.
// https://github.com/KhronosGroup/SPIRV-Tools/issues/2609.
TEST_F(ScalarReplacementTest, OutOfBoundOpAccessChain) {
const std::string text = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %_GLF_color
OpExecutionMode %main OriginUpperLeft
OpSource ESSL 310
OpName %main "main"
OpName %a "a"
OpName %_GLF_color "_GLF_color"
OpDecorate %_GLF_color Location 0
%void = OpTypeVoid
%3 = OpTypeFunction %void
%int = OpTypeInt 32 1
%_ptr_Function_int = OpTypePointer Function %int
%int_1 = OpConstant %int 1
%float = OpTypeFloat 32
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%_arr_float_uint_1 = OpTypeArray %float %uint_1
%_ptr_Function__arr_float_uint_1 = OpTypePointer Function %_arr_float_uint_1
%_ptr_Function_float = OpTypePointer Function %float
%_ptr_Output_float = OpTypePointer Output %float
%_GLF_color = OpVariable %_ptr_Output_float Output
%main = OpFunction %void None %3
%5 = OpLabel
%a = OpVariable %_ptr_Function__arr_float_uint_1 Function
%21 = OpAccessChain %_ptr_Function_float %a %int_1
%22 = OpLoad %float %21
OpStore %_GLF_color %22
OpReturn
OpFunctionEnd
)";
SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
auto result =
SinglePassRunAndDisassemble<ScalarReplacementPass>(text, true, false);
EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
}
} // namespace
} // namespace opt
} // namespace spvtools