From 3726b500b1dec7c08cd8195d0e999db28192266a Mon Sep 17 00:00:00 2001 From: alan-baker Date: Wed, 31 Jul 2019 15:39:33 -0400 Subject: [PATCH] Treat access chain indexes as signed in SROA (#2776) Fixes #2768 * In scalar replacement, interpret access chain indexes as signed counts * Use Constant::GetSignExtendedValue and Constant::GetZeroExtendedValue where appropriate * new tests --- source/opt/constants.cpp | 2 +- source/opt/constants.h | 2 +- source/opt/scalar_replacement_pass.cpp | 58 +++++++----------------- source/opt/scalar_replacement_pass.h | 10 +---- test/opt/scalar_replacement_test.cpp | 61 ++++++++++++++++++++++++++ 5 files changed, 80 insertions(+), 53 deletions(-) diff --git a/source/opt/constants.cpp b/source/opt/constants.cpp index b5875c9ee..5c1468be5 100644 --- a/source/opt/constants.cpp +++ b/source/opt/constants.cpp @@ -291,7 +291,7 @@ std::unique_ptr ConstantManager::CreateConstant( } } -const Constant* ConstantManager::GetConstantFromInst(Instruction* inst) { +const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) { std::vector literal_words_or_ids; // Collect the constant defining literals or component ids. diff --git a/source/opt/constants.h b/source/opt/constants.h index 7b9f24864..93d0847da 100644 --- a/source/opt/constants.h +++ b/source/opt/constants.h @@ -522,7 +522,7 @@ class ConstantManager { // Gets or creates a Constant instance to hold the constant value of the given // instruction. It returns a pointer to a Constant instance or nullptr if it // could not create the constant. - const Constant* GetConstantFromInst(Instruction* inst); + const Constant* GetConstantFromInst(const Instruction* inst); // Gets or creates a constant defining instruction for the given Constant |c|. // If |c| had already been defined, it returns a pointer to the existing diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp index 9ae1ae8e8..7f352df2d 100644 --- a/source/opt/scalar_replacement_pass.cpp +++ b/source/opt/scalar_replacement_pass.cpp @@ -232,8 +232,12 @@ bool ScalarReplacementPass::ReplaceAccessChain( // indexes) or a direct use of the replacement variable. uint32_t indexId = chain->GetSingleWordInOperand(1u); const Instruction* index = get_def_use_mgr()->GetDef(indexId); - uint64_t indexValue = GetConstantInteger(index); - if (indexValue >= replacements.size()) { + int64_t indexValue = context() + ->get_constant_mgr() + ->GetConstantFromInst(index) + ->GetSignExtendedValue(); + if (indexValue < 0 || + indexValue >= static_cast(replacements.size())) { // Out of bounds access, this is illegal IR. Notice that OpAccessChain // indexing is 0-based, so we should also reject index == size-of-array. return false; @@ -269,7 +273,7 @@ bool ScalarReplacementPass::CreateReplacementVariables( Instruction* inst, std::vector* replacements) { Instruction* type = GetStorageType(inst); - std::unique_ptr> components_used = + std::unique_ptr> components_used = GetUsedComponents(inst); uint32_t elem = 0; @@ -467,35 +471,15 @@ void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source, } } -uint64_t ScalarReplacementPass::GetIntegerLiteral(const Operand& op) const { - assert(op.words.size() <= 2); - uint64_t len = 0; - for (uint32_t i = 0; i != op.words.size(); ++i) { - len |= (op.words[i] << (32 * i)); - } - return len; -} - -uint64_t ScalarReplacementPass::GetConstantInteger( - const Instruction* constant) const { - assert(get_def_use_mgr()->GetDef(constant->type_id())->opcode() == - SpvOpTypeInt); - assert(constant->opcode() == SpvOpConstant || - constant->opcode() == SpvOpConstantNull); - if (constant->opcode() == SpvOpConstantNull) { - return 0; - } - - const Operand& op = constant->GetInOperand(0u); - return GetIntegerLiteral(op); -} - uint64_t ScalarReplacementPass::GetArrayLength( const Instruction* arrayType) const { assert(arrayType->opcode() == SpvOpTypeArray); const Instruction* length = get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u)); - return GetConstantInteger(length); + return context() + ->get_constant_mgr() + ->GetConstantFromInst(length) + ->GetZeroExtendedValue(); } uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const { @@ -734,10 +718,10 @@ bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const { return length > max_num_elements_; } -std::unique_ptr> +std::unique_ptr> ScalarReplacementPass::GetUsedComponents(Instruction* inst) { - std::unique_ptr> result( - new std::unordered_set()); + std::unique_ptr> result( + new std::unordered_set()); analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); @@ -775,18 +759,8 @@ ScalarReplacementPass::GetUsedComponents(Instruction* inst) { const analysis::Constant* index_const = const_mgr->FindDeclaredConstant(index_id); if (index_const) { - const analysis::Integer* index_type = - index_const->type()->AsInteger(); - assert(index_type); - if (index_type->width() == 32) { - result->insert(index_const->GetU32()); - return true; - } else if (index_type->width() == 64) { - result->insert(index_const->GetU64()); - return true; - } - result.reset(nullptr); - return false; + result->insert(index_const->GetSignExtendedValue()); + return true; } else { // Could be any element. Assuming all are used. result.reset(nullptr); diff --git a/source/opt/scalar_replacement_pass.h b/source/opt/scalar_replacement_pass.h index 3a17045b4..5b5198155 100644 --- a/source/opt/scalar_replacement_pass.h +++ b/source/opt/scalar_replacement_pass.h @@ -158,14 +158,6 @@ class ScalarReplacementPass : public Pass { bool CreateReplacementVariables(Instruction* inst, std::vector* replacements); - // Returns the value of an OpConstant of integer type. - // - // |constant| must use two or fewer words to generate the value. - uint64_t GetConstantInteger(const Instruction* constant) const; - - // Returns the integer literal for |op|. - uint64_t GetIntegerLiteral(const Operand& op) const; - // Returns the array length for |arrayInst|. uint64_t GetArrayLength(const Instruction* arrayInst) const; @@ -216,7 +208,7 @@ class ScalarReplacementPass : public Pass { // Returns a set containing the which components of the result of |inst| are // potentially used. If the return value is |nullptr|, then every components // is possibly used. - std::unique_ptr> GetUsedComponents( + std::unique_ptr> GetUsedComponents( Instruction* inst); // Returns an instruction defining a null constant with type |type_id|. If diff --git a/test/opt/scalar_replacement_test.cpp b/test/opt/scalar_replacement_test.cpp index 2ed7b5a13..04721c24f 100644 --- a/test/opt/scalar_replacement_test.cpp +++ b/test/opt/scalar_replacement_test.cpp @@ -1701,6 +1701,67 @@ TEST_F(ScalarReplacementTest, OutOfBoundOpAccessChain) { EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); } +TEST_F(ScalarReplacementTest, CharIndex) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[ptr:%\w+]] = OpTypePointer Function [[int]] +; CHECK: OpVariable [[ptr]] Function +OpCapability Shader +OpCapability Int8 +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%int_1024 = OpConstant %int 1024 +%char = OpTypeInt 8 0 +%char_1 = OpConstant %char 1 +%array = OpTypeArray %int %int_1024 +%ptr_func_array = OpTypePointer Function %array +%ptr_func_int = OpTypePointer Function %int +%void_fn = OpTypeFunction %void +%main = OpFunction %void None %void_fn +%entry = OpLabel +%var = OpVariable %ptr_func_array Function +%gep = OpAccessChain %ptr_func_int %var %char_1 +OpStore %gep %int_1024 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true, 0); +} + +TEST_F(ScalarReplacementTest, OutOfBoundsOpAccessChainNegative) { + const std::string text = R"( +OpCapability Shader +OpCapability Int8 +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%int_1024 = OpConstant %int 1024 +%char = OpTypeInt 8 1 +%char_n1 = OpConstant %char -1 +%array = OpTypeArray %int %int_1024 +%ptr_func_array = OpTypePointer Function %array +%ptr_func_int = OpTypePointer Function %int +%void_fn = OpTypeFunction %void +%main = OpFunction %void None %void_fn +%entry = OpLabel +%var = OpVariable %ptr_func_array Function +%gep = OpAccessChain %ptr_func_int %var %char_n1 +OpStore %gep %int_1024 +OpReturn +OpFunctionEnd +)"; + + auto result = + SinglePassRunAndDisassemble(text, true, true, 0); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + } // namespace } // namespace opt } // namespace spvtools