mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-12-25 17:21:06 +00:00
Treat access chain indexes as signed in SROA (#2776)
Fixes #2768 * In scalar replacement, interpret access chain indexes as signed counts * Use Constant::GetSignExtendedValue and Constant::GetZeroExtendedValue where appropriate * new tests
This commit is contained in:
parent
31590104ec
commit
3726b500b1
@ -291,7 +291,7 @@ std::unique_ptr<Constant> ConstantManager::CreateConstant(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const Constant* ConstantManager::GetConstantFromInst(Instruction* inst) {
|
const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) {
|
||||||
std::vector<uint32_t> literal_words_or_ids;
|
std::vector<uint32_t> literal_words_or_ids;
|
||||||
|
|
||||||
// Collect the constant defining literals or component ids.
|
// Collect the constant defining literals or component ids.
|
||||||
|
@ -522,7 +522,7 @@ class ConstantManager {
|
|||||||
// Gets or creates a Constant instance to hold the constant value of the given
|
// Gets or creates a Constant instance to hold the constant value of the given
|
||||||
// instruction. It returns a pointer to a Constant instance or nullptr if it
|
// instruction. It returns a pointer to a Constant instance or nullptr if it
|
||||||
// could not create the constant.
|
// could not create the constant.
|
||||||
const Constant* GetConstantFromInst(Instruction* inst);
|
const Constant* GetConstantFromInst(const Instruction* inst);
|
||||||
|
|
||||||
// Gets or creates a constant defining instruction for the given Constant |c|.
|
// Gets or creates a constant defining instruction for the given Constant |c|.
|
||||||
// If |c| had already been defined, it returns a pointer to the existing
|
// If |c| had already been defined, it returns a pointer to the existing
|
||||||
|
@ -232,8 +232,12 @@ bool ScalarReplacementPass::ReplaceAccessChain(
|
|||||||
// indexes) or a direct use of the replacement variable.
|
// indexes) or a direct use of the replacement variable.
|
||||||
uint32_t indexId = chain->GetSingleWordInOperand(1u);
|
uint32_t indexId = chain->GetSingleWordInOperand(1u);
|
||||||
const Instruction* index = get_def_use_mgr()->GetDef(indexId);
|
const Instruction* index = get_def_use_mgr()->GetDef(indexId);
|
||||||
uint64_t indexValue = GetConstantInteger(index);
|
int64_t indexValue = context()
|
||||||
if (indexValue >= replacements.size()) {
|
->get_constant_mgr()
|
||||||
|
->GetConstantFromInst(index)
|
||||||
|
->GetSignExtendedValue();
|
||||||
|
if (indexValue < 0 ||
|
||||||
|
indexValue >= static_cast<int64_t>(replacements.size())) {
|
||||||
// Out of bounds access, this is illegal IR. Notice that OpAccessChain
|
// Out of bounds access, this is illegal IR. Notice that OpAccessChain
|
||||||
// indexing is 0-based, so we should also reject index == size-of-array.
|
// indexing is 0-based, so we should also reject index == size-of-array.
|
||||||
return false;
|
return false;
|
||||||
@ -269,7 +273,7 @@ bool ScalarReplacementPass::CreateReplacementVariables(
|
|||||||
Instruction* inst, std::vector<Instruction*>* replacements) {
|
Instruction* inst, std::vector<Instruction*>* replacements) {
|
||||||
Instruction* type = GetStorageType(inst);
|
Instruction* type = GetStorageType(inst);
|
||||||
|
|
||||||
std::unique_ptr<std::unordered_set<uint64_t>> components_used =
|
std::unique_ptr<std::unordered_set<int64_t>> components_used =
|
||||||
GetUsedComponents(inst);
|
GetUsedComponents(inst);
|
||||||
|
|
||||||
uint32_t elem = 0;
|
uint32_t elem = 0;
|
||||||
@ -467,35 +471,15 @@ void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t ScalarReplacementPass::GetIntegerLiteral(const Operand& op) const {
|
|
||||||
assert(op.words.size() <= 2);
|
|
||||||
uint64_t len = 0;
|
|
||||||
for (uint32_t i = 0; i != op.words.size(); ++i) {
|
|
||||||
len |= (op.words[i] << (32 * i));
|
|
||||||
}
|
|
||||||
return len;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t ScalarReplacementPass::GetConstantInteger(
|
|
||||||
const Instruction* constant) const {
|
|
||||||
assert(get_def_use_mgr()->GetDef(constant->type_id())->opcode() ==
|
|
||||||
SpvOpTypeInt);
|
|
||||||
assert(constant->opcode() == SpvOpConstant ||
|
|
||||||
constant->opcode() == SpvOpConstantNull);
|
|
||||||
if (constant->opcode() == SpvOpConstantNull) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
const Operand& op = constant->GetInOperand(0u);
|
|
||||||
return GetIntegerLiteral(op);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t ScalarReplacementPass::GetArrayLength(
|
uint64_t ScalarReplacementPass::GetArrayLength(
|
||||||
const Instruction* arrayType) const {
|
const Instruction* arrayType) const {
|
||||||
assert(arrayType->opcode() == SpvOpTypeArray);
|
assert(arrayType->opcode() == SpvOpTypeArray);
|
||||||
const Instruction* length =
|
const Instruction* length =
|
||||||
get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
|
get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
|
||||||
return GetConstantInteger(length);
|
return context()
|
||||||
|
->get_constant_mgr()
|
||||||
|
->GetConstantFromInst(length)
|
||||||
|
->GetZeroExtendedValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
|
uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
|
||||||
@ -734,10 +718,10 @@ bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
|
|||||||
return length > max_num_elements_;
|
return length > max_num_elements_;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<std::unordered_set<uint64_t>>
|
std::unique_ptr<std::unordered_set<int64_t>>
|
||||||
ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
|
ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
|
||||||
std::unique_ptr<std::unordered_set<uint64_t>> result(
|
std::unique_ptr<std::unordered_set<int64_t>> result(
|
||||||
new std::unordered_set<uint64_t>());
|
new std::unordered_set<int64_t>());
|
||||||
|
|
||||||
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
|
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
|
||||||
|
|
||||||
@ -775,18 +759,8 @@ ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
|
|||||||
const analysis::Constant* index_const =
|
const analysis::Constant* index_const =
|
||||||
const_mgr->FindDeclaredConstant(index_id);
|
const_mgr->FindDeclaredConstant(index_id);
|
||||||
if (index_const) {
|
if (index_const) {
|
||||||
const analysis::Integer* index_type =
|
result->insert(index_const->GetSignExtendedValue());
|
||||||
index_const->type()->AsInteger();
|
return true;
|
||||||
assert(index_type);
|
|
||||||
if (index_type->width() == 32) {
|
|
||||||
result->insert(index_const->GetU32());
|
|
||||||
return true;
|
|
||||||
} else if (index_type->width() == 64) {
|
|
||||||
result->insert(index_const->GetU64());
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
result.reset(nullptr);
|
|
||||||
return false;
|
|
||||||
} else {
|
} else {
|
||||||
// Could be any element. Assuming all are used.
|
// Could be any element. Assuming all are used.
|
||||||
result.reset(nullptr);
|
result.reset(nullptr);
|
||||||
|
@ -158,14 +158,6 @@ class ScalarReplacementPass : public Pass {
|
|||||||
bool CreateReplacementVariables(Instruction* inst,
|
bool CreateReplacementVariables(Instruction* inst,
|
||||||
std::vector<Instruction*>* replacements);
|
std::vector<Instruction*>* replacements);
|
||||||
|
|
||||||
// Returns the value of an OpConstant of integer type.
|
|
||||||
//
|
|
||||||
// |constant| must use two or fewer words to generate the value.
|
|
||||||
uint64_t GetConstantInteger(const Instruction* constant) const;
|
|
||||||
|
|
||||||
// Returns the integer literal for |op|.
|
|
||||||
uint64_t GetIntegerLiteral(const Operand& op) const;
|
|
||||||
|
|
||||||
// Returns the array length for |arrayInst|.
|
// Returns the array length for |arrayInst|.
|
||||||
uint64_t GetArrayLength(const Instruction* arrayInst) const;
|
uint64_t GetArrayLength(const Instruction* arrayInst) const;
|
||||||
|
|
||||||
@ -216,7 +208,7 @@ class ScalarReplacementPass : public Pass {
|
|||||||
// Returns a set containing the which components of the result of |inst| are
|
// Returns a set containing the which components of the result of |inst| are
|
||||||
// potentially used. If the return value is |nullptr|, then every components
|
// potentially used. If the return value is |nullptr|, then every components
|
||||||
// is possibly used.
|
// is possibly used.
|
||||||
std::unique_ptr<std::unordered_set<uint64_t>> GetUsedComponents(
|
std::unique_ptr<std::unordered_set<int64_t>> GetUsedComponents(
|
||||||
Instruction* inst);
|
Instruction* inst);
|
||||||
|
|
||||||
// Returns an instruction defining a null constant with type |type_id|. If
|
// Returns an instruction defining a null constant with type |type_id|. If
|
||||||
|
@ -1701,6 +1701,67 @@ TEST_F(ScalarReplacementTest, OutOfBoundOpAccessChain) {
|
|||||||
EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
|
EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ScalarReplacementTest, CharIndex) {
|
||||||
|
const std::string text = R"(
|
||||||
|
; CHECK: [[int:%\w+]] = OpTypeInt 32 0
|
||||||
|
; CHECK: [[ptr:%\w+]] = OpTypePointer Function [[int]]
|
||||||
|
; CHECK: OpVariable [[ptr]] Function
|
||||||
|
OpCapability Shader
|
||||||
|
OpCapability Int8
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %main "main"
|
||||||
|
OpExecutionMode %main LocalSize 1 1 1
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%int = OpTypeInt 32 0
|
||||||
|
%int_1024 = OpConstant %int 1024
|
||||||
|
%char = OpTypeInt 8 0
|
||||||
|
%char_1 = OpConstant %char 1
|
||||||
|
%array = OpTypeArray %int %int_1024
|
||||||
|
%ptr_func_array = OpTypePointer Function %array
|
||||||
|
%ptr_func_int = OpTypePointer Function %int
|
||||||
|
%void_fn = OpTypeFunction %void
|
||||||
|
%main = OpFunction %void None %void_fn
|
||||||
|
%entry = OpLabel
|
||||||
|
%var = OpVariable %ptr_func_array Function
|
||||||
|
%gep = OpAccessChain %ptr_func_int %var %char_1
|
||||||
|
OpStore %gep %int_1024
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
)";
|
||||||
|
|
||||||
|
SinglePassRunAndMatch<ScalarReplacementPass>(text, true, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ScalarReplacementTest, OutOfBoundsOpAccessChainNegative) {
|
||||||
|
const std::string text = R"(
|
||||||
|
OpCapability Shader
|
||||||
|
OpCapability Int8
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %main "main"
|
||||||
|
OpExecutionMode %main LocalSize 1 1 1
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%int = OpTypeInt 32 0
|
||||||
|
%int_1024 = OpConstant %int 1024
|
||||||
|
%char = OpTypeInt 8 1
|
||||||
|
%char_n1 = OpConstant %char -1
|
||||||
|
%array = OpTypeArray %int %int_1024
|
||||||
|
%ptr_func_array = OpTypePointer Function %array
|
||||||
|
%ptr_func_int = OpTypePointer Function %int
|
||||||
|
%void_fn = OpTypeFunction %void
|
||||||
|
%main = OpFunction %void None %void_fn
|
||||||
|
%entry = OpLabel
|
||||||
|
%var = OpVariable %ptr_func_array Function
|
||||||
|
%gep = OpAccessChain %ptr_func_int %var %char_n1
|
||||||
|
OpStore %gep %int_1024
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
)";
|
||||||
|
|
||||||
|
auto result =
|
||||||
|
SinglePassRunAndDisassemble<ScalarReplacementPass>(text, true, true, 0);
|
||||||
|
EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace spvtools
|
} // namespace spvtools
|
||||||
|
Loading…
Reference in New Issue
Block a user