Treat access chain indexes as signed in SROA (#2776)

Fixes #2768

* In scalar replacement, interpret access chain indexes as signed counts
* Use Constant::GetSignExtendedValue and Constant::GetZeroExtendedValue
where appropriate
* new tests
This commit is contained in:
alan-baker 2019-07-31 15:39:33 -04:00 committed by GitHub
parent 31590104ec
commit 3726b500b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 80 additions and 53 deletions

View File

@ -291,7 +291,7 @@ std::unique_ptr<Constant> ConstantManager::CreateConstant(
}
}
const Constant* ConstantManager::GetConstantFromInst(Instruction* inst) {
const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) {
std::vector<uint32_t> literal_words_or_ids;
// Collect the constant defining literals or component ids.

View File

@ -522,7 +522,7 @@ class ConstantManager {
// Gets or creates a Constant instance to hold the constant value of the given
// instruction. It returns a pointer to a Constant instance or nullptr if it
// could not create the constant.
const Constant* GetConstantFromInst(Instruction* inst);
const Constant* GetConstantFromInst(const Instruction* inst);
// Gets or creates a constant defining instruction for the given Constant |c|.
// If |c| had already been defined, it returns a pointer to the existing

View File

@ -232,8 +232,12 @@ bool ScalarReplacementPass::ReplaceAccessChain(
// indexes) or a direct use of the replacement variable.
uint32_t indexId = chain->GetSingleWordInOperand(1u);
const Instruction* index = get_def_use_mgr()->GetDef(indexId);
uint64_t indexValue = GetConstantInteger(index);
if (indexValue >= replacements.size()) {
int64_t indexValue = context()
->get_constant_mgr()
->GetConstantFromInst(index)
->GetSignExtendedValue();
if (indexValue < 0 ||
indexValue >= static_cast<int64_t>(replacements.size())) {
// Out of bounds access, this is illegal IR. Notice that OpAccessChain
// indexing is 0-based, so we should also reject index == size-of-array.
return false;
@ -269,7 +273,7 @@ bool ScalarReplacementPass::CreateReplacementVariables(
Instruction* inst, std::vector<Instruction*>* replacements) {
Instruction* type = GetStorageType(inst);
std::unique_ptr<std::unordered_set<uint64_t>> components_used =
std::unique_ptr<std::unordered_set<int64_t>> components_used =
GetUsedComponents(inst);
uint32_t elem = 0;
@ -467,35 +471,15 @@ void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
}
}
uint64_t ScalarReplacementPass::GetIntegerLiteral(const Operand& op) const {
assert(op.words.size() <= 2);
uint64_t len = 0;
for (uint32_t i = 0; i != op.words.size(); ++i) {
len |= (op.words[i] << (32 * i));
}
return len;
}
uint64_t ScalarReplacementPass::GetConstantInteger(
const Instruction* constant) const {
assert(get_def_use_mgr()->GetDef(constant->type_id())->opcode() ==
SpvOpTypeInt);
assert(constant->opcode() == SpvOpConstant ||
constant->opcode() == SpvOpConstantNull);
if (constant->opcode() == SpvOpConstantNull) {
return 0;
}
const Operand& op = constant->GetInOperand(0u);
return GetIntegerLiteral(op);
}
uint64_t ScalarReplacementPass::GetArrayLength(
const Instruction* arrayType) const {
assert(arrayType->opcode() == SpvOpTypeArray);
const Instruction* length =
get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
return GetConstantInteger(length);
return context()
->get_constant_mgr()
->GetConstantFromInst(length)
->GetZeroExtendedValue();
}
uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
@ -734,10 +718,10 @@ bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
return length > max_num_elements_;
}
std::unique_ptr<std::unordered_set<uint64_t>>
std::unique_ptr<std::unordered_set<int64_t>>
ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
std::unique_ptr<std::unordered_set<uint64_t>> result(
new std::unordered_set<uint64_t>());
std::unique_ptr<std::unordered_set<int64_t>> result(
new std::unordered_set<int64_t>());
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
@ -775,18 +759,8 @@ ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
const analysis::Constant* index_const =
const_mgr->FindDeclaredConstant(index_id);
if (index_const) {
const analysis::Integer* index_type =
index_const->type()->AsInteger();
assert(index_type);
if (index_type->width() == 32) {
result->insert(index_const->GetU32());
result->insert(index_const->GetSignExtendedValue());
return true;
} else if (index_type->width() == 64) {
result->insert(index_const->GetU64());
return true;
}
result.reset(nullptr);
return false;
} else {
// Could be any element. Assuming all are used.
result.reset(nullptr);

View File

@ -158,14 +158,6 @@ class ScalarReplacementPass : public Pass {
bool CreateReplacementVariables(Instruction* inst,
std::vector<Instruction*>* replacements);
// Returns the value of an OpConstant of integer type.
//
// |constant| must use two or fewer words to generate the value.
uint64_t GetConstantInteger(const Instruction* constant) const;
// Returns the integer literal for |op|.
uint64_t GetIntegerLiteral(const Operand& op) const;
// Returns the array length for |arrayInst|.
uint64_t GetArrayLength(const Instruction* arrayInst) const;
@ -216,7 +208,7 @@ class ScalarReplacementPass : public Pass {
// Returns a set containing the which components of the result of |inst| are
// potentially used. If the return value is |nullptr|, then every components
// is possibly used.
std::unique_ptr<std::unordered_set<uint64_t>> GetUsedComponents(
std::unique_ptr<std::unordered_set<int64_t>> GetUsedComponents(
Instruction* inst);
// Returns an instruction defining a null constant with type |type_id|. If

View File

@ -1701,6 +1701,67 @@ TEST_F(ScalarReplacementTest, OutOfBoundOpAccessChain) {
EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
}
TEST_F(ScalarReplacementTest, CharIndex) {
const std::string text = R"(
; CHECK: [[int:%\w+]] = OpTypeInt 32 0
; CHECK: [[ptr:%\w+]] = OpTypePointer Function [[int]]
; CHECK: OpVariable [[ptr]] Function
OpCapability Shader
OpCapability Int8
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
%void = OpTypeVoid
%int = OpTypeInt 32 0
%int_1024 = OpConstant %int 1024
%char = OpTypeInt 8 0
%char_1 = OpConstant %char 1
%array = OpTypeArray %int %int_1024
%ptr_func_array = OpTypePointer Function %array
%ptr_func_int = OpTypePointer Function %int
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
%var = OpVariable %ptr_func_array Function
%gep = OpAccessChain %ptr_func_int %var %char_1
OpStore %gep %int_1024
OpReturn
OpFunctionEnd
)";
SinglePassRunAndMatch<ScalarReplacementPass>(text, true, 0);
}
TEST_F(ScalarReplacementTest, OutOfBoundsOpAccessChainNegative) {
const std::string text = R"(
OpCapability Shader
OpCapability Int8
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
%void = OpTypeVoid
%int = OpTypeInt 32 0
%int_1024 = OpConstant %int 1024
%char = OpTypeInt 8 1
%char_n1 = OpConstant %char -1
%array = OpTypeArray %int %int_1024
%ptr_func_array = OpTypePointer Function %array
%ptr_func_int = OpTypePointer Function %int
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
%var = OpVariable %ptr_func_array Function
%gep = OpAccessChain %ptr_func_int %var %char_n1
OpStore %gep %int_1024
OpReturn
OpFunctionEnd
)";
auto result =
SinglePassRunAndDisassemble<ScalarReplacementPass>(text, true, true, 0);
EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
}
} // namespace
} // namespace opt
} // namespace spvtools