Fix extract with out-of-bounds index (#4529)

* Fix extract with out-of-bounds index

When folding a OpCompositeExtract that is fed by an
OpCompositeConstruct, we handle and out of bounds
index, but only in the case where the result of the
OpCompostiteConstruct is a struct.  This change
refactors that folding rule and then improves it to
handle an out-of-bounds access when the result of the
OpCompositeConstruct is a vector.
This commit is contained in:
Steven Perron 2021-09-20 13:02:47 -04:00 committed by GitHub
parent 1454c95d1b
commit 59f51bb4f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 120 additions and 76 deletions

View File

@ -14,6 +14,7 @@
#include "source/opt/folding_rules.h"
#include <climits>
#include <limits>
#include <memory>
#include <utility>
@ -1463,90 +1464,121 @@ FoldingRule IntMultipleBy1() {
};
}
FoldingRule CompositeConstructFeedingExtract() {
return [](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>&) {
// If the input to an OpCompositeExtract is an OpCompositeConstruct,
// then we can simply use the appropriate element in the construction.
assert(inst->opcode() == SpvOpCompositeExtract &&
"Wrong opcode. Should be OpCompositeExtract.");
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
// Returns the number of elements that the |index|th in operand in |inst|
// contributes to the result of |inst|. |inst| must be an
// OpCompositeConstructInstruction.
uint32_t GetNumOfElementsContributedByOperand(IRContext* context,
const Instruction* inst,
uint32_t index) {
assert(inst->opcode() == SpvOpCompositeConstruct);
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
// If there are no index operands, then this rule cannot do anything.
if (inst->NumInOperands() <= 1) {
return false;
}
analysis::Vector* result_type =
type_mgr->GetType(inst->type_id())->AsVector();
if (result_type == nullptr) {
// If the result of the OpCompositeConstruct is not a vector then every
// operands corresponds to a single element in the result.
return 1;
}
uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
Instruction* cinst = def_use_mgr->GetDef(cid);
// If the result type is a vector then the operands are either scalars or
// vectors. If it is a scalar, then it corresponds to a single element. If it
// is a vector, then each element in the vector will be an element in the
// result.
uint32_t id = inst->GetSingleWordInOperand(index);
Instruction* def = def_use_mgr->GetDef(id);
analysis::Vector* type = type_mgr->GetType(def->type_id())->AsVector();
if (type == nullptr) {
return 1;
}
return type->element_count();
}
if (cinst->opcode() != SpvOpCompositeConstruct) {
return false;
}
// Returns the in-operands for an OpCompositeExtract instruction that are needed
// to extract the |result_index|th element in the result of |inst| without using
// the result of |inst|. Returns the empty vector if |result_index| is
// out-of-bounds. |inst| must be an |OpCompositeConstruct| instruction.
std::vector<Operand> GetExtractOperandsForElementOfCompositeConstruct(
IRContext* context, const Instruction* inst, uint32_t result_index) {
assert(inst->opcode() == SpvOpCompositeConstruct);
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
std::vector<Operand> operands;
analysis::Type* composite_type = type_mgr->GetType(cinst->type_id());
if (composite_type->AsVector() == nullptr) {
// Get the element being extracted from the OpCompositeConstruct
// Since it is not a vector, it is simple to extract the single element.
uint32_t element_index = inst->GetSingleWordInOperand(1);
uint32_t element_id = cinst->GetSingleWordInOperand(element_index);
operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
analysis::Type* result_type = type_mgr->GetType(inst->type_id());
if (result_type->AsVector() == nullptr) {
uint32_t id = inst->GetSingleWordInOperand(result_index);
return {Operand(SPV_OPERAND_TYPE_ID, {id})};
}
// Add the remaining indices for extraction.
for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
{inst->GetSingleWordInOperand(i)}});
}
} else {
// With vectors we have to handle the case where it is concatenating
// vectors.
assert(inst->NumInOperands() == 2 &&
"Expecting a vector of scalar values.");
uint32_t element_index = inst->GetSingleWordInOperand(1);
for (uint32_t construct_index = 0;
construct_index < cinst->NumInOperands(); ++construct_index) {
uint32_t element_id = cinst->GetSingleWordInOperand(construct_index);
Instruction* element_def = def_use_mgr->GetDef(element_id);
analysis::Vector* element_type =
type_mgr->GetType(element_def->type_id())->AsVector();
if (element_type) {
uint32_t vector_size = element_type->element_count();
if (vector_size <= element_index) {
// The element we want comes after this vector.
element_index -= vector_size;
} else {
// We want an element of this vector.
operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
operands.push_back(
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}});
break;
}
} else {
if (element_index == 0) {
// This is a scalar, and we this is the element we are extracting.
operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
break;
} else {
// Skip over this scalar value.
--element_index;
}
}
// If the result type is a vector, then vector operands are concatenated.
uint32_t total_element_count = 0;
for (uint32_t idx = 0; idx < inst->NumInOperands(); ++idx) {
uint32_t element_count =
GetNumOfElementsContributedByOperand(context, inst, idx);
total_element_count += element_count;
if (result_index < total_element_count) {
std::vector<Operand> operands;
uint32_t id = inst->GetSingleWordInOperand(idx);
Instruction* operand_def = def_use_mgr->GetDef(id);
analysis::Type* operand_type = type_mgr->GetType(operand_def->type_id());
operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
if (operand_type->AsVector()) {
uint32_t start_index_of_id = total_element_count - element_count;
uint32_t index_into_id = result_index - start_index_of_id;
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index_into_id}});
}
return operands;
}
}
return {};
}
bool CompositeConstructFeedingExtract(
IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>&) {
// If the input to an OpCompositeExtract is an OpCompositeConstruct,
// then we can simply use the appropriate element in the construction.
assert(inst->opcode() == SpvOpCompositeExtract &&
"Wrong opcode. Should be OpCompositeExtract.");
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
// If there are no index operands, then this rule cannot do anything.
if (inst->NumInOperands() <= 1) {
return false;
}
uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
Instruction* cinst = def_use_mgr->GetDef(cid);
if (cinst->opcode() != SpvOpCompositeConstruct) {
return false;
}
uint32_t index_into_result = inst->GetSingleWordInOperand(1);
std::vector<Operand> operands =
GetExtractOperandsForElementOfCompositeConstruct(context, cinst,
index_into_result);
if (operands.empty()) {
return false;
}
// Add the remaining indices for extraction.
for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
operands.push_back(
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}});
}
if (operands.size() == 1) {
// If there were no extra indices, then we have the final object. No need
// to extract even more.
if (operands.size() == 1) {
inst->SetOpcode(SpvOpCopyObject);
}
// to extract any more.
inst->SetOpcode(SpvOpCopyObject);
}
inst->SetInOperands(std::move(operands));
return true;
};
inst->SetInOperands(std::move(operands));
return true;
}
// If the OpCompositeConstruct is simply putting back together elements that
@ -2505,7 +2537,7 @@ void FoldingRules::AddFoldingRules() {
rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct);
rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract);
rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());

View File

@ -3608,7 +3608,19 @@ INSTANTIATE_TEST_SUITE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingT
"%4 = OpCompositeExtract %int %3 2\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, INT_0_ID)
4, INT_0_ID),
// Test case 15:
// Don't fold extract fed by construct with vector result if the index is
// past the last element.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpCompositeConstruct %v2int %int_0 %int_0\n" +
"%3 = OpCompositeConstruct %v4int %2 %100 %int_0\n" +
"%4 = OpCompositeExtract %int %3 4\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, 0)
));
INSTANTIATE_TEST_SUITE_P(CompositeConstructFoldingTest, GeneralInstructionFoldingTest,