mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-22 19:50:05 +00:00
Fix extract with out-of-bounds index (#4529)
* Fix extract with out-of-bounds index When folding a OpCompositeExtract that is fed by an OpCompositeConstruct, we handle and out of bounds index, but only in the case where the result of the OpCompostiteConstruct is a struct. This change refactors that folding rule and then improves it to handle an out-of-bounds access when the result of the OpCompositeConstruct is a vector.
This commit is contained in:
parent
1454c95d1b
commit
59f51bb4f8
@ -14,6 +14,7 @@
|
||||
|
||||
#include "source/opt/folding_rules.h"
|
||||
|
||||
#include <climits>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
@ -1463,90 +1464,121 @@ FoldingRule IntMultipleBy1() {
|
||||
};
|
||||
}
|
||||
|
||||
FoldingRule CompositeConstructFeedingExtract() {
|
||||
return [](IRContext* context, Instruction* inst,
|
||||
const std::vector<const analysis::Constant*>&) {
|
||||
// If the input to an OpCompositeExtract is an OpCompositeConstruct,
|
||||
// then we can simply use the appropriate element in the construction.
|
||||
assert(inst->opcode() == SpvOpCompositeExtract &&
|
||||
"Wrong opcode. Should be OpCompositeExtract.");
|
||||
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
|
||||
analysis::TypeManager* type_mgr = context->get_type_mgr();
|
||||
// Returns the number of elements that the |index|th in operand in |inst|
|
||||
// contributes to the result of |inst|. |inst| must be an
|
||||
// OpCompositeConstructInstruction.
|
||||
uint32_t GetNumOfElementsContributedByOperand(IRContext* context,
|
||||
const Instruction* inst,
|
||||
uint32_t index) {
|
||||
assert(inst->opcode() == SpvOpCompositeConstruct);
|
||||
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
|
||||
analysis::TypeManager* type_mgr = context->get_type_mgr();
|
||||
|
||||
// If there are no index operands, then this rule cannot do anything.
|
||||
if (inst->NumInOperands() <= 1) {
|
||||
return false;
|
||||
}
|
||||
analysis::Vector* result_type =
|
||||
type_mgr->GetType(inst->type_id())->AsVector();
|
||||
if (result_type == nullptr) {
|
||||
// If the result of the OpCompositeConstruct is not a vector then every
|
||||
// operands corresponds to a single element in the result.
|
||||
return 1;
|
||||
}
|
||||
|
||||
uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
|
||||
Instruction* cinst = def_use_mgr->GetDef(cid);
|
||||
// If the result type is a vector then the operands are either scalars or
|
||||
// vectors. If it is a scalar, then it corresponds to a single element. If it
|
||||
// is a vector, then each element in the vector will be an element in the
|
||||
// result.
|
||||
uint32_t id = inst->GetSingleWordInOperand(index);
|
||||
Instruction* def = def_use_mgr->GetDef(id);
|
||||
analysis::Vector* type = type_mgr->GetType(def->type_id())->AsVector();
|
||||
if (type == nullptr) {
|
||||
return 1;
|
||||
}
|
||||
return type->element_count();
|
||||
}
|
||||
|
||||
if (cinst->opcode() != SpvOpCompositeConstruct) {
|
||||
return false;
|
||||
}
|
||||
// Returns the in-operands for an OpCompositeExtract instruction that are needed
|
||||
// to extract the |result_index|th element in the result of |inst| without using
|
||||
// the result of |inst|. Returns the empty vector if |result_index| is
|
||||
// out-of-bounds. |inst| must be an |OpCompositeConstruct| instruction.
|
||||
std::vector<Operand> GetExtractOperandsForElementOfCompositeConstruct(
|
||||
IRContext* context, const Instruction* inst, uint32_t result_index) {
|
||||
assert(inst->opcode() == SpvOpCompositeConstruct);
|
||||
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
|
||||
analysis::TypeManager* type_mgr = context->get_type_mgr();
|
||||
|
||||
std::vector<Operand> operands;
|
||||
analysis::Type* composite_type = type_mgr->GetType(cinst->type_id());
|
||||
if (composite_type->AsVector() == nullptr) {
|
||||
// Get the element being extracted from the OpCompositeConstruct
|
||||
// Since it is not a vector, it is simple to extract the single element.
|
||||
uint32_t element_index = inst->GetSingleWordInOperand(1);
|
||||
uint32_t element_id = cinst->GetSingleWordInOperand(element_index);
|
||||
operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
|
||||
analysis::Type* result_type = type_mgr->GetType(inst->type_id());
|
||||
if (result_type->AsVector() == nullptr) {
|
||||
uint32_t id = inst->GetSingleWordInOperand(result_index);
|
||||
return {Operand(SPV_OPERAND_TYPE_ID, {id})};
|
||||
}
|
||||
|
||||
// Add the remaining indices for extraction.
|
||||
for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
|
||||
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
|
||||
{inst->GetSingleWordInOperand(i)}});
|
||||
}
|
||||
|
||||
} else {
|
||||
// With vectors we have to handle the case where it is concatenating
|
||||
// vectors.
|
||||
assert(inst->NumInOperands() == 2 &&
|
||||
"Expecting a vector of scalar values.");
|
||||
|
||||
uint32_t element_index = inst->GetSingleWordInOperand(1);
|
||||
for (uint32_t construct_index = 0;
|
||||
construct_index < cinst->NumInOperands(); ++construct_index) {
|
||||
uint32_t element_id = cinst->GetSingleWordInOperand(construct_index);
|
||||
Instruction* element_def = def_use_mgr->GetDef(element_id);
|
||||
analysis::Vector* element_type =
|
||||
type_mgr->GetType(element_def->type_id())->AsVector();
|
||||
if (element_type) {
|
||||
uint32_t vector_size = element_type->element_count();
|
||||
if (vector_size <= element_index) {
|
||||
// The element we want comes after this vector.
|
||||
element_index -= vector_size;
|
||||
} else {
|
||||
// We want an element of this vector.
|
||||
operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
|
||||
operands.push_back(
|
||||
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}});
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
if (element_index == 0) {
|
||||
// This is a scalar, and we this is the element we are extracting.
|
||||
operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
|
||||
break;
|
||||
} else {
|
||||
// Skip over this scalar value.
|
||||
--element_index;
|
||||
}
|
||||
}
|
||||
// If the result type is a vector, then vector operands are concatenated.
|
||||
uint32_t total_element_count = 0;
|
||||
for (uint32_t idx = 0; idx < inst->NumInOperands(); ++idx) {
|
||||
uint32_t element_count =
|
||||
GetNumOfElementsContributedByOperand(context, inst, idx);
|
||||
total_element_count += element_count;
|
||||
if (result_index < total_element_count) {
|
||||
std::vector<Operand> operands;
|
||||
uint32_t id = inst->GetSingleWordInOperand(idx);
|
||||
Instruction* operand_def = def_use_mgr->GetDef(id);
|
||||
analysis::Type* operand_type = type_mgr->GetType(operand_def->type_id());
|
||||
|
||||
operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
|
||||
if (operand_type->AsVector()) {
|
||||
uint32_t start_index_of_id = total_element_count - element_count;
|
||||
uint32_t index_into_id = result_index - start_index_of_id;
|
||||
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index_into_id}});
|
||||
}
|
||||
return operands;
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
bool CompositeConstructFeedingExtract(
|
||||
IRContext* context, Instruction* inst,
|
||||
const std::vector<const analysis::Constant*>&) {
|
||||
// If the input to an OpCompositeExtract is an OpCompositeConstruct,
|
||||
// then we can simply use the appropriate element in the construction.
|
||||
assert(inst->opcode() == SpvOpCompositeExtract &&
|
||||
"Wrong opcode. Should be OpCompositeExtract.");
|
||||
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
|
||||
|
||||
// If there are no index operands, then this rule cannot do anything.
|
||||
if (inst->NumInOperands() <= 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
|
||||
Instruction* cinst = def_use_mgr->GetDef(cid);
|
||||
|
||||
if (cinst->opcode() != SpvOpCompositeConstruct) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t index_into_result = inst->GetSingleWordInOperand(1);
|
||||
std::vector<Operand> operands =
|
||||
GetExtractOperandsForElementOfCompositeConstruct(context, cinst,
|
||||
index_into_result);
|
||||
|
||||
if (operands.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Add the remaining indices for extraction.
|
||||
for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
|
||||
operands.push_back(
|
||||
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}});
|
||||
}
|
||||
|
||||
if (operands.size() == 1) {
|
||||
// If there were no extra indices, then we have the final object. No need
|
||||
// to extract even more.
|
||||
if (operands.size() == 1) {
|
||||
inst->SetOpcode(SpvOpCopyObject);
|
||||
}
|
||||
// to extract any more.
|
||||
inst->SetOpcode(SpvOpCopyObject);
|
||||
}
|
||||
|
||||
inst->SetInOperands(std::move(operands));
|
||||
return true;
|
||||
};
|
||||
inst->SetInOperands(std::move(operands));
|
||||
return true;
|
||||
}
|
||||
|
||||
// If the OpCompositeConstruct is simply putting back together elements that
|
||||
@ -2505,7 +2537,7 @@ void FoldingRules::AddFoldingRules() {
|
||||
rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct);
|
||||
|
||||
rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
|
||||
rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
|
||||
rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract);
|
||||
rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
|
||||
rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
|
||||
|
||||
|
@ -3608,7 +3608,19 @@ INSTANTIATE_TEST_SUITE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingT
|
||||
"%4 = OpCompositeExtract %int %3 2\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
4, INT_0_ID)
|
||||
4, INT_0_ID),
|
||||
// Test case 15:
|
||||
// Don't fold extract fed by construct with vector result if the index is
|
||||
// past the last element.
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpCompositeConstruct %v2int %int_0 %int_0\n" +
|
||||
"%3 = OpCompositeConstruct %v4int %2 %100 %int_0\n" +
|
||||
"%4 = OpCompositeExtract %int %3 4\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
4, 0)
|
||||
));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(CompositeConstructFoldingTest, GeneralInstructionFoldingTest,
|
||||
|
Loading…
Reference in New Issue
Block a user