Treat access chain indexes as signed in SROA (#2776)

Fixes #2768

* In scalar replacement, interpret access chain indexes as signed counts
* Use Constant::GetSignExtendedValue and Constant::GetZeroExtendedValue
where appropriate
* new tests
This commit is contained in:
alan-baker 2019-07-31 15:39:33 -04:00 committed by GitHub
parent 31590104ec
commit 3726b500b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 80 additions and 53 deletions

View File

@ -291,7 +291,7 @@ std::unique_ptr<Constant> ConstantManager::CreateConstant(
} }
} }
const Constant* ConstantManager::GetConstantFromInst(Instruction* inst) { const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) {
std::vector<uint32_t> literal_words_or_ids; std::vector<uint32_t> literal_words_or_ids;
// Collect the constant defining literals or component ids. // Collect the constant defining literals or component ids.

View File

@ -522,7 +522,7 @@ class ConstantManager {
// Gets or creates a Constant instance to hold the constant value of the given // Gets or creates a Constant instance to hold the constant value of the given
// instruction. It returns a pointer to a Constant instance or nullptr if it // instruction. It returns a pointer to a Constant instance or nullptr if it
// could not create the constant. // could not create the constant.
const Constant* GetConstantFromInst(Instruction* inst); const Constant* GetConstantFromInst(const Instruction* inst);
// Gets or creates a constant defining instruction for the given Constant |c|. // Gets or creates a constant defining instruction for the given Constant |c|.
// If |c| had already been defined, it returns a pointer to the existing // If |c| had already been defined, it returns a pointer to the existing

View File

@ -232,8 +232,12 @@ bool ScalarReplacementPass::ReplaceAccessChain(
// indexes) or a direct use of the replacement variable. // indexes) or a direct use of the replacement variable.
uint32_t indexId = chain->GetSingleWordInOperand(1u); uint32_t indexId = chain->GetSingleWordInOperand(1u);
const Instruction* index = get_def_use_mgr()->GetDef(indexId); const Instruction* index = get_def_use_mgr()->GetDef(indexId);
uint64_t indexValue = GetConstantInteger(index); int64_t indexValue = context()
if (indexValue >= replacements.size()) { ->get_constant_mgr()
->GetConstantFromInst(index)
->GetSignExtendedValue();
if (indexValue < 0 ||
indexValue >= static_cast<int64_t>(replacements.size())) {
// Out of bounds access, this is illegal IR. Notice that OpAccessChain // Out of bounds access, this is illegal IR. Notice that OpAccessChain
// indexing is 0-based, so we should also reject index == size-of-array. // indexing is 0-based, so we should also reject index == size-of-array.
return false; return false;
@ -269,7 +273,7 @@ bool ScalarReplacementPass::CreateReplacementVariables(
Instruction* inst, std::vector<Instruction*>* replacements) { Instruction* inst, std::vector<Instruction*>* replacements) {
Instruction* type = GetStorageType(inst); Instruction* type = GetStorageType(inst);
std::unique_ptr<std::unordered_set<uint64_t>> components_used = std::unique_ptr<std::unordered_set<int64_t>> components_used =
GetUsedComponents(inst); GetUsedComponents(inst);
uint32_t elem = 0; uint32_t elem = 0;
@ -467,35 +471,15 @@ void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
} }
} }
uint64_t ScalarReplacementPass::GetIntegerLiteral(const Operand& op) const {
assert(op.words.size() <= 2);
uint64_t len = 0;
for (uint32_t i = 0; i != op.words.size(); ++i) {
len |= (op.words[i] << (32 * i));
}
return len;
}
uint64_t ScalarReplacementPass::GetConstantInteger(
const Instruction* constant) const {
assert(get_def_use_mgr()->GetDef(constant->type_id())->opcode() ==
SpvOpTypeInt);
assert(constant->opcode() == SpvOpConstant ||
constant->opcode() == SpvOpConstantNull);
if (constant->opcode() == SpvOpConstantNull) {
return 0;
}
const Operand& op = constant->GetInOperand(0u);
return GetIntegerLiteral(op);
}
uint64_t ScalarReplacementPass::GetArrayLength( uint64_t ScalarReplacementPass::GetArrayLength(
const Instruction* arrayType) const { const Instruction* arrayType) const {
assert(arrayType->opcode() == SpvOpTypeArray); assert(arrayType->opcode() == SpvOpTypeArray);
const Instruction* length = const Instruction* length =
get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u)); get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
return GetConstantInteger(length); return context()
->get_constant_mgr()
->GetConstantFromInst(length)
->GetZeroExtendedValue();
} }
uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const { uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
@ -734,10 +718,10 @@ bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
return length > max_num_elements_; return length > max_num_elements_;
} }
std::unique_ptr<std::unordered_set<uint64_t>> std::unique_ptr<std::unordered_set<int64_t>>
ScalarReplacementPass::GetUsedComponents(Instruction* inst) { ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
std::unique_ptr<std::unordered_set<uint64_t>> result( std::unique_ptr<std::unordered_set<int64_t>> result(
new std::unordered_set<uint64_t>()); new std::unordered_set<int64_t>());
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
@ -775,18 +759,8 @@ ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
const analysis::Constant* index_const = const analysis::Constant* index_const =
const_mgr->FindDeclaredConstant(index_id); const_mgr->FindDeclaredConstant(index_id);
if (index_const) { if (index_const) {
const analysis::Integer* index_type = result->insert(index_const->GetSignExtendedValue());
index_const->type()->AsInteger(); return true;
assert(index_type);
if (index_type->width() == 32) {
result->insert(index_const->GetU32());
return true;
} else if (index_type->width() == 64) {
result->insert(index_const->GetU64());
return true;
}
result.reset(nullptr);
return false;
} else { } else {
// Could be any element. Assuming all are used. // Could be any element. Assuming all are used.
result.reset(nullptr); result.reset(nullptr);

View File

@ -158,14 +158,6 @@ class ScalarReplacementPass : public Pass {
bool CreateReplacementVariables(Instruction* inst, bool CreateReplacementVariables(Instruction* inst,
std::vector<Instruction*>* replacements); std::vector<Instruction*>* replacements);
// Returns the value of an OpConstant of integer type.
//
// |constant| must use two or fewer words to generate the value.
uint64_t GetConstantInteger(const Instruction* constant) const;
// Returns the integer literal for |op|.
uint64_t GetIntegerLiteral(const Operand& op) const;
// Returns the array length for |arrayInst|. // Returns the array length for |arrayInst|.
uint64_t GetArrayLength(const Instruction* arrayInst) const; uint64_t GetArrayLength(const Instruction* arrayInst) const;
@ -216,7 +208,7 @@ class ScalarReplacementPass : public Pass {
// Returns a set containing the which components of the result of |inst| are // Returns a set containing the which components of the result of |inst| are
// potentially used. If the return value is |nullptr|, then every components // potentially used. If the return value is |nullptr|, then every components
// is possibly used. // is possibly used.
std::unique_ptr<std::unordered_set<uint64_t>> GetUsedComponents( std::unique_ptr<std::unordered_set<int64_t>> GetUsedComponents(
Instruction* inst); Instruction* inst);
// Returns an instruction defining a null constant with type |type_id|. If // Returns an instruction defining a null constant with type |type_id|. If

View File

@ -1701,6 +1701,67 @@ TEST_F(ScalarReplacementTest, OutOfBoundOpAccessChain) {
EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
} }
TEST_F(ScalarReplacementTest, CharIndex) {
const std::string text = R"(
; CHECK: [[int:%\w+]] = OpTypeInt 32 0
; CHECK: [[ptr:%\w+]] = OpTypePointer Function [[int]]
; CHECK: OpVariable [[ptr]] Function
OpCapability Shader
OpCapability Int8
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
%void = OpTypeVoid
%int = OpTypeInt 32 0
%int_1024 = OpConstant %int 1024
%char = OpTypeInt 8 0
%char_1 = OpConstant %char 1
%array = OpTypeArray %int %int_1024
%ptr_func_array = OpTypePointer Function %array
%ptr_func_int = OpTypePointer Function %int
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
%var = OpVariable %ptr_func_array Function
%gep = OpAccessChain %ptr_func_int %var %char_1
OpStore %gep %int_1024
OpReturn
OpFunctionEnd
)";
SinglePassRunAndMatch<ScalarReplacementPass>(text, true, 0);
}
TEST_F(ScalarReplacementTest, OutOfBoundsOpAccessChainNegative) {
const std::string text = R"(
OpCapability Shader
OpCapability Int8
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
%void = OpTypeVoid
%int = OpTypeInt 32 0
%int_1024 = OpConstant %int 1024
%char = OpTypeInt 8 1
%char_n1 = OpConstant %char -1
%array = OpTypeArray %int %int_1024
%ptr_func_array = OpTypePointer Function %array
%ptr_func_int = OpTypePointer Function %int
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
%var = OpVariable %ptr_func_array Function
%gep = OpAccessChain %ptr_func_int %var %char_n1
OpStore %gep %int_1024
OpReturn
OpFunctionEnd
)";
auto result =
SinglePassRunAndDisassemble<ScalarReplacementPass>(text, true, true, 0);
EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
}
} // namespace } // namespace
} // namespace opt } // namespace opt
} // namespace spvtools } // namespace spvtools