Handle overflowing id in merge return (#4606)

If the ids overflow when creating an integer constant in the ir_builder, there will be a nullptr dereference.  This is happening from inside merge return.

We need to propagate the error up, and make sure it is handled appropriately.
This commit is contained in:
Steven Perron 2021-11-01 08:45:32 -04:00 committed by GitHub
parent 97d4495600
commit 1082de6bb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 65 additions and 15 deletions

View File

@ -359,8 +359,9 @@ class InstructionBuilder {
return AddInstruction(std::move(select));
}
// Adds a signed int32 constant to the binary.
// The |value| parameter is the constant value to be added.
// Returns a pointer to the definition of a signed 32-bit integer constant
// with the given value. Returns |nullptr| if the constant does not exist and
// cannot be created.
Instruction* GetSintConstant(int32_t value) {
return GetIntConstant<int32_t>(value, true);
}
@ -381,21 +382,24 @@ class InstructionBuilder {
GetContext()->TakeNextId(), ops));
return AddInstruction(std::move(construct));
}
// Adds an unsigned int32 constant to the binary.
// The |value| parameter is the constant value to be added.
// Returns a pointer to the definition of an unsigned 32-bit integer constant
// with the given value. Returns |nullptr| if the constant does not exist and
// cannot be created.
Instruction* GetUintConstant(uint32_t value) {
return GetIntConstant<uint32_t>(value, false);
}
uint32_t GetUintConstantId(uint32_t value) {
Instruction* uint_inst = GetUintConstant(value);
return uint_inst->result_id();
return (uint_inst != nullptr ? uint_inst->result_id() : 0);
}
// Adds either a signed or unsigned 32 bit integer constant to the binary
// depedning on the |sign|. If |sign| is true then the value is added as a
// depending on the |sign|. If |sign| is true then the value is added as a
// signed constant otherwise as an unsigned constant. If |sign| is false the
// value must not be a negative number.
// value must not be a negative number. Returns false if the constant does
// not exists and could be be created.
template <typename T>
Instruction* GetIntConstant(T value, bool sign) {
// Assert that we are not trying to store a negative number in an unsigned
@ -411,6 +415,10 @@ class InstructionBuilder {
uint32_t type_id =
GetContext()->get_type_mgr()->GetTypeInstruction(&int_type);
if (type_id == 0) {
return nullptr;
}
// Get the memory managed type so that it is safe to be stored by
// GetConstant.
analysis::Type* rebuilt_type =

View File

@ -111,7 +111,9 @@ bool MergeReturnPass::ProcessStructured(
}
RecordImmediateDominators(function);
AddSingleCaseSwitchAroundFunction();
if (!AddSingleCaseSwitchAroundFunction()) {
return false;
}
std::list<BasicBlock*> order;
cfg()->ComputeStructuredOrder(function, &*function->begin(), &order);
@ -770,7 +772,7 @@ void MergeReturnPass::InsertAfterElement(BasicBlock* element,
list->insert(pos, new_element);
}
void MergeReturnPass::AddSingleCaseSwitchAroundFunction() {
bool MergeReturnPass::AddSingleCaseSwitchAroundFunction() {
CreateReturnBlock();
CreateReturn(final_return_block_);
@ -778,7 +780,10 @@ void MergeReturnPass::AddSingleCaseSwitchAroundFunction() {
cfg()->RegisterBlock(final_return_block_);
}
CreateSingleCaseSwitch(final_return_block_);
if (!CreateSingleCaseSwitch(final_return_block_)) {
return false;
}
return true;
}
BasicBlock* MergeReturnPass::CreateContinueTarget(uint32_t header_label_id) {
@ -813,7 +818,7 @@ BasicBlock* MergeReturnPass::CreateContinueTarget(uint32_t header_label_id) {
return new_block;
}
void MergeReturnPass::CreateSingleCaseSwitch(BasicBlock* merge_target) {
bool MergeReturnPass::CreateSingleCaseSwitch(BasicBlock* merge_target) {
// Insert the switch before any code is run. We have to split the entry
// block to make sure the OpVariable instructions remain in the entry block.
BasicBlock* start_block = &*function_->begin();
@ -830,13 +835,17 @@ void MergeReturnPass::CreateSingleCaseSwitch(BasicBlock* merge_target) {
context(), start_block,
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
builder.AddSwitch(builder.GetUintConstantId(0u), old_block->id(), {},
merge_target->id());
uint32_t const_zero_id = builder.GetUintConstantId(0u);
if (const_zero_id == 0) {
return false;
}
builder.AddSwitch(const_zero_id, old_block->id(), {}, merge_target->id());
if (context()->AreAnalysesValid(IRContext::kAnalysisCFG)) {
cfg()->RegisterBlock(old_block);
cfg()->AddEdges(start_block);
}
return true;
}
bool MergeReturnPass::HasNontrivialUnreachableBlocks(Function* function) {

View File

@ -277,7 +277,7 @@ class MergeReturnPass : public MemPass {
// current function where the switch and case value are both zero and the
// default is the merge block. Returns after the switch is executed. Sets
// |final_return_block_|.
void AddSingleCaseSwitchAroundFunction();
bool AddSingleCaseSwitchAroundFunction();
// Creates a new basic block that branches to |header_label_id|. Returns the
// new basic block. The block will be the second last basic block in the
@ -286,7 +286,7 @@ class MergeReturnPass : public MemPass {
// Creates a one case switch around the executable code of the function with
// |merge_target| as the merge node.
void CreateSingleCaseSwitch(BasicBlock* merge_target);
bool CreateSingleCaseSwitch(BasicBlock* merge_target);
// Returns true if |function| has an unreachable block that is not a continue
// target that simply branches back to the header, or a merge block containing

View File

@ -2567,6 +2567,39 @@ TEST_F(MergeReturnPassTest, ChainedPointerUsedAfterLoop) {
SinglePassRunAndMatch<MergeReturnPass>(before, true);
}
TEST_F(MergeReturnPassTest, OverflowTest1) {
const std::string text =
R"(
; CHECK: OpReturn
; CHECK-NOT: OpReturn
; CHECK: OpFunctionEnd
OpCapability ClipDistance
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %2 "main"
OpExecutionMode %2 OriginUpperLeft
%void = OpTypeVoid
%6 = OpTypeFunction %void
%2 = OpFunction %void None %6
%4194303 = OpLabel
OpBranch %18
%18 = OpLabel
OpLoopMerge %19 %20 None
OpBranch %21
%21 = OpLabel
OpReturn
%20 = OpLabel
OpBranch %18
%19 = OpLabel
OpUnreachable
OpFunctionEnd
)";
SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
auto result =
SinglePassRunToBinary<MergeReturnPass>(text, /* skip_nop = */ true);
EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
}
} // namespace
} // namespace opt
} // namespace spvtools