mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2025-01-12 17:30:15 +00:00
Treat access chain indexes as signed in SROA (#2776)
Fixes #2768 * In scalar replacement, interpret access chain indexes as signed counts * Use Constant::GetSignExtendedValue and Constant::GetZeroExtendedValue where appropriate * new tests
This commit is contained in:
parent
31590104ec
commit
3726b500b1
@ -291,7 +291,7 @@ std::unique_ptr<Constant> ConstantManager::CreateConstant(
|
||||
}
|
||||
}
|
||||
|
||||
const Constant* ConstantManager::GetConstantFromInst(Instruction* inst) {
|
||||
const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) {
|
||||
std::vector<uint32_t> literal_words_or_ids;
|
||||
|
||||
// Collect the constant defining literals or component ids.
|
||||
|
@ -522,7 +522,7 @@ class ConstantManager {
|
||||
// Gets or creates a Constant instance to hold the constant value of the given
|
||||
// instruction. It returns a pointer to a Constant instance or nullptr if it
|
||||
// could not create the constant.
|
||||
const Constant* GetConstantFromInst(Instruction* inst);
|
||||
const Constant* GetConstantFromInst(const Instruction* inst);
|
||||
|
||||
// Gets or creates a constant defining instruction for the given Constant |c|.
|
||||
// If |c| had already been defined, it returns a pointer to the existing
|
||||
|
@ -232,8 +232,12 @@ bool ScalarReplacementPass::ReplaceAccessChain(
|
||||
// indexes) or a direct use of the replacement variable.
|
||||
uint32_t indexId = chain->GetSingleWordInOperand(1u);
|
||||
const Instruction* index = get_def_use_mgr()->GetDef(indexId);
|
||||
uint64_t indexValue = GetConstantInteger(index);
|
||||
if (indexValue >= replacements.size()) {
|
||||
int64_t indexValue = context()
|
||||
->get_constant_mgr()
|
||||
->GetConstantFromInst(index)
|
||||
->GetSignExtendedValue();
|
||||
if (indexValue < 0 ||
|
||||
indexValue >= static_cast<int64_t>(replacements.size())) {
|
||||
// Out of bounds access, this is illegal IR. Notice that OpAccessChain
|
||||
// indexing is 0-based, so we should also reject index == size-of-array.
|
||||
return false;
|
||||
@ -269,7 +273,7 @@ bool ScalarReplacementPass::CreateReplacementVariables(
|
||||
Instruction* inst, std::vector<Instruction*>* replacements) {
|
||||
Instruction* type = GetStorageType(inst);
|
||||
|
||||
std::unique_ptr<std::unordered_set<uint64_t>> components_used =
|
||||
std::unique_ptr<std::unordered_set<int64_t>> components_used =
|
||||
GetUsedComponents(inst);
|
||||
|
||||
uint32_t elem = 0;
|
||||
@ -467,35 +471,15 @@ void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t ScalarReplacementPass::GetIntegerLiteral(const Operand& op) const {
|
||||
assert(op.words.size() <= 2);
|
||||
uint64_t len = 0;
|
||||
for (uint32_t i = 0; i != op.words.size(); ++i) {
|
||||
len |= (op.words[i] << (32 * i));
|
||||
}
|
||||
return len;
|
||||
}
|
||||
|
||||
uint64_t ScalarReplacementPass::GetConstantInteger(
|
||||
const Instruction* constant) const {
|
||||
assert(get_def_use_mgr()->GetDef(constant->type_id())->opcode() ==
|
||||
SpvOpTypeInt);
|
||||
assert(constant->opcode() == SpvOpConstant ||
|
||||
constant->opcode() == SpvOpConstantNull);
|
||||
if (constant->opcode() == SpvOpConstantNull) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const Operand& op = constant->GetInOperand(0u);
|
||||
return GetIntegerLiteral(op);
|
||||
}
|
||||
|
||||
uint64_t ScalarReplacementPass::GetArrayLength(
|
||||
const Instruction* arrayType) const {
|
||||
assert(arrayType->opcode() == SpvOpTypeArray);
|
||||
const Instruction* length =
|
||||
get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
|
||||
return GetConstantInteger(length);
|
||||
return context()
|
||||
->get_constant_mgr()
|
||||
->GetConstantFromInst(length)
|
||||
->GetZeroExtendedValue();
|
||||
}
|
||||
|
||||
uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
|
||||
@ -734,10 +718,10 @@ bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
|
||||
return length > max_num_elements_;
|
||||
}
|
||||
|
||||
std::unique_ptr<std::unordered_set<uint64_t>>
|
||||
std::unique_ptr<std::unordered_set<int64_t>>
|
||||
ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
|
||||
std::unique_ptr<std::unordered_set<uint64_t>> result(
|
||||
new std::unordered_set<uint64_t>());
|
||||
std::unique_ptr<std::unordered_set<int64_t>> result(
|
||||
new std::unordered_set<int64_t>());
|
||||
|
||||
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
|
||||
|
||||
@ -775,18 +759,8 @@ ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
|
||||
const analysis::Constant* index_const =
|
||||
const_mgr->FindDeclaredConstant(index_id);
|
||||
if (index_const) {
|
||||
const analysis::Integer* index_type =
|
||||
index_const->type()->AsInteger();
|
||||
assert(index_type);
|
||||
if (index_type->width() == 32) {
|
||||
result->insert(index_const->GetU32());
|
||||
result->insert(index_const->GetSignExtendedValue());
|
||||
return true;
|
||||
} else if (index_type->width() == 64) {
|
||||
result->insert(index_const->GetU64());
|
||||
return true;
|
||||
}
|
||||
result.reset(nullptr);
|
||||
return false;
|
||||
} else {
|
||||
// Could be any element. Assuming all are used.
|
||||
result.reset(nullptr);
|
||||
|
@ -158,14 +158,6 @@ class ScalarReplacementPass : public Pass {
|
||||
bool CreateReplacementVariables(Instruction* inst,
|
||||
std::vector<Instruction*>* replacements);
|
||||
|
||||
// Returns the value of an OpConstant of integer type.
|
||||
//
|
||||
// |constant| must use two or fewer words to generate the value.
|
||||
uint64_t GetConstantInteger(const Instruction* constant) const;
|
||||
|
||||
// Returns the integer literal for |op|.
|
||||
uint64_t GetIntegerLiteral(const Operand& op) const;
|
||||
|
||||
// Returns the array length for |arrayInst|.
|
||||
uint64_t GetArrayLength(const Instruction* arrayInst) const;
|
||||
|
||||
@ -216,7 +208,7 @@ class ScalarReplacementPass : public Pass {
|
||||
// Returns a set containing the which components of the result of |inst| are
|
||||
// potentially used. If the return value is |nullptr|, then every components
|
||||
// is possibly used.
|
||||
std::unique_ptr<std::unordered_set<uint64_t>> GetUsedComponents(
|
||||
std::unique_ptr<std::unordered_set<int64_t>> GetUsedComponents(
|
||||
Instruction* inst);
|
||||
|
||||
// Returns an instruction defining a null constant with type |type_id|. If
|
||||
|
@ -1701,6 +1701,67 @@ TEST_F(ScalarReplacementTest, OutOfBoundOpAccessChain) {
|
||||
EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
|
||||
}
|
||||
|
||||
TEST_F(ScalarReplacementTest, CharIndex) {
|
||||
const std::string text = R"(
|
||||
; CHECK: [[int:%\w+]] = OpTypeInt 32 0
|
||||
; CHECK: [[ptr:%\w+]] = OpTypePointer Function [[int]]
|
||||
; CHECK: OpVariable [[ptr]] Function
|
||||
OpCapability Shader
|
||||
OpCapability Int8
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %main "main"
|
||||
OpExecutionMode %main LocalSize 1 1 1
|
||||
%void = OpTypeVoid
|
||||
%int = OpTypeInt 32 0
|
||||
%int_1024 = OpConstant %int 1024
|
||||
%char = OpTypeInt 8 0
|
||||
%char_1 = OpConstant %char 1
|
||||
%array = OpTypeArray %int %int_1024
|
||||
%ptr_func_array = OpTypePointer Function %array
|
||||
%ptr_func_int = OpTypePointer Function %int
|
||||
%void_fn = OpTypeFunction %void
|
||||
%main = OpFunction %void None %void_fn
|
||||
%entry = OpLabel
|
||||
%var = OpVariable %ptr_func_array Function
|
||||
%gep = OpAccessChain %ptr_func_int %var %char_1
|
||||
OpStore %gep %int_1024
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
SinglePassRunAndMatch<ScalarReplacementPass>(text, true, 0);
|
||||
}
|
||||
|
||||
TEST_F(ScalarReplacementTest, OutOfBoundsOpAccessChainNegative) {
|
||||
const std::string text = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Int8
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %main "main"
|
||||
OpExecutionMode %main LocalSize 1 1 1
|
||||
%void = OpTypeVoid
|
||||
%int = OpTypeInt 32 0
|
||||
%int_1024 = OpConstant %int 1024
|
||||
%char = OpTypeInt 8 1
|
||||
%char_n1 = OpConstant %char -1
|
||||
%array = OpTypeArray %int %int_1024
|
||||
%ptr_func_array = OpTypePointer Function %array
|
||||
%ptr_func_int = OpTypePointer Function %int
|
||||
%void_fn = OpTypeFunction %void
|
||||
%main = OpFunction %void None %void_fn
|
||||
%entry = OpLabel
|
||||
%var = OpVariable %ptr_func_array Function
|
||||
%gep = OpAccessChain %ptr_func_int %var %char_n1
|
||||
OpStore %gep %int_1024
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
auto result =
|
||||
SinglePassRunAndDisassemble<ScalarReplacementPass>(text, true, true, 0);
|
||||
EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
||||
|
Loading…
Reference in New Issue
Block a user