Have scalar replacement use undef instead of null (#4691)

Scalar replacement generates a null when there value for a member will
not be used.  The null is used to make sure things are
deterministic in case there is an error.

However, some type cannot be null, so we will change that to use undef.
To keep the code simpler we will always use the undef.

Fixes #3996
This commit is contained in:
Steven Perron 2022-02-03 15:51:15 +00:00 committed by GitHub
parent 7fa9e746ef
commit 5b371918b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 38 deletions

View File

@ -24,6 +24,7 @@
#include "source/opt/reflect.h"
#include "source/opt/types.h"
#include "source/util/make_unique.h"
#include "types.h"
static const uint32_t kDebugValueOperandValueIndex = 5;
static const uint32_t kDebugValueOperandExpressionIndex = 6;
@ -395,7 +396,7 @@ bool ScalarReplacementPass::CreateReplacementVariables(
if (!components_used || components_used->count(elem)) {
CreateVariable(*id, inst, elem, replacements);
} else {
replacements->push_back(CreateNullConstant(*id));
replacements->push_back(GetUndef(*id));
}
elem++;
});
@ -406,8 +407,8 @@ bool ScalarReplacementPass::CreateReplacementVariables(
CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
replacements);
} else {
replacements->push_back(
CreateNullConstant(type->GetSingleWordInOperand(0u)));
uint32_t element_type_id = type->GetSingleWordInOperand(0);
replacements->push_back(GetUndef(element_type_id));
}
}
break;
@ -429,6 +430,10 @@ bool ScalarReplacementPass::CreateReplacementVariables(
replacements->end();
}
Instruction* ScalarReplacementPass::GetUndef(uint32_t type_id) {
return get_def_use_mgr()->GetDef(Type2Undef(type_id));
}
void ScalarReplacementPass::TransferAnnotations(
const Instruction* source, std::vector<Instruction*>* replacements) {
// Only transfer invariant and restrict decorations on the variable. There are
@ -981,20 +986,6 @@ ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
return result;
}
Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) {
analysis::TypeManager* type_mgr = context()->get_type_mgr();
analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
const analysis::Type* type = type_mgr->GetType(type_id);
const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
Instruction* null_inst =
const_mgr->GetDefiningInstruction(null_const, type_id);
if (null_inst != nullptr) {
context()->UpdateDefUse(null_inst);
}
return null_inst;
}
uint64_t ScalarReplacementPass::GetMaxLegalIndex(
const Instruction* var_inst) const {
assert(var_inst->opcode() == SpvOpVariable &&

View File

@ -23,14 +23,14 @@
#include <vector>
#include "source/opt/function.h"
#include "source/opt/pass.h"
#include "source/opt/mem_pass.h"
#include "source/opt/type_manager.h"
namespace spvtools {
namespace opt {
// Documented in optimizer.hpp
class ScalarReplacementPass : public Pass {
class ScalarReplacementPass : public MemPass {
private:
static const uint32_t kDefaultLimit = 100;
@ -234,10 +234,8 @@ class ScalarReplacementPass : public Pass {
std::unique_ptr<std::unordered_set<int64_t>> GetUsedComponents(
Instruction* inst);
// Returns an instruction defining a null constant with type |type_id|. If
// one already exists, it is returned. Otherwise a new one is created.
// Returns |nullptr| if the new constant could not be created.
Instruction* CreateNullConstant(uint32_t type_id);
// Returns an instruction defining an undefined value type |type_id|.
Instruction* GetUndef(uint32_t type_id);
// Maps storage type to a pointer type enclosing that type.
std::unordered_map<uint32_t, uint32_t> pointee_to_pointer_;

View File

@ -470,9 +470,9 @@ TEST_F(ScalarReplacementTest, NonUniformCompositeInitialization) {
; CHECK: [[const_array:%\w+]] = OpConstantComposite [[array]]
; CHECK: [[const_matrix:%\w+]] = OpConstantNull [[matrix]]
; CHECK: [[const_struct1:%\w+]] = OpConstantComposite [[struct1]]
; CHECK: OpConstantNull [[uint]]
; CHECK: OpConstantNull [[vector]]
; CHECK: OpConstantNull [[long]]
; CHECK: OpUndef [[uint]]
; CHECK: OpUndef [[vector]]
; CHECK: OpUndef [[long]]
; CHECK: OpFunction
; CHECK-NOT: OpVariable [[struct2_ptr]] Function
; CHECK: OpVariable [[uint_ptr]] Function
@ -654,11 +654,10 @@ TEST_F(ScalarReplacementTest, ReplaceWholeLoadCopyMemoryAccess) {
; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]]
; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]]
; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0
; CHECK: [[null:%\w+]] = OpConstantNull [[uint]]
; CHECK: [[undef:%\w+]] = OpUndef [[uint]]
; CHECK: [[var0:%\w+]] = OpVariable [[uint_ptr]] Function
; CHECK: [[l0:%\w+]] = OpLoad [[uint]] [[var0]] Nontemporal
; CHECK: OpCompositeConstruct [[struct1]] [[l0]] [[null]]
; CHECK: OpCompositeConstruct [[struct1]] [[l0]] [[undef]]
;
OpCapability Shader
OpCapability Linkage
@ -1267,16 +1266,16 @@ TEST_F(ScalarReplacementTest, ReplaceWholeLoadAndStore) {
; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]]
; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]]
; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0
; CHECK: [[null:%\w+]] = OpConstantNull [[uint]]
; CHECK: [[undef:%\w+]] = OpUndef [[uint]]
; CHECK: [[var0:%\w+]] = OpVariable [[uint_ptr]] Function
; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function
; CHECK-NOT: OpVariable
; CHECK: [[l0:%\w+]] = OpLoad [[uint]] [[var0]]
; CHECK: [[c0:%\w+]] = OpCompositeConstruct [[struct1]] [[l0]] [[null]]
; CHECK: [[c0:%\w+]] = OpCompositeConstruct [[struct1]] [[l0]] [[undef]]
; CHECK: [[e0:%\w+]] = OpCompositeExtract [[uint]] [[c0]] 0
; CHECK: OpStore [[var1]] [[e0]]
; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]]
; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[null]]
; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[undef]]
; CHECK: [[e1:%\w+]] = OpCompositeExtract [[uint]] [[c1]] 0
;
OpCapability Shader
@ -1314,7 +1313,7 @@ TEST_F(ScalarReplacementTest, ReplaceWholeLoadAndStore2) {
; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]]
; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]]
; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0
; CHECK: [[null:%\w+]] = OpConstantNull [[uint]]
; CHECK: [[undef:%\w+]] = OpUndef [[uint]]
; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function
; CHECK: [[var0a:%\w+]] = OpVariable [[uint_ptr]] Function
; CHECK: [[var0b:%\w+]] = OpVariable [[uint_ptr]] Function
@ -1325,7 +1324,7 @@ TEST_F(ScalarReplacementTest, ReplaceWholeLoadAndStore2) {
; CHECK: [[e0:%\w+]] = OpCompositeExtract [[uint]] [[c0]] 0
; CHECK: OpStore [[var1]] [[e0]]
; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]]
; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[null]]
; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[undef]]
; CHECK: [[e1:%\w+]] = OpCompositeExtract [[uint]] [[c1]] 0
;
OpCapability Shader
@ -1362,14 +1361,14 @@ TEST_F(ScalarReplacementTest, CreateAmbiguousNullConstant1) {
; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[struct_member:%\w+]]
; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]]
; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0
; CHECK: [[null:%\w+]] = OpConstantNull [[struct_member]]
; CHECK: [[undef:%\w+]] = OpUndef [[struct_member]]
; CHECK: [[var0a:%\w+]] = OpVariable [[uint_ptr]] Function
; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function
; CHECK: [[var0b:%\w+]] = OpVariable [[uint_ptr]] Function
; CHECK-NOT: OpVariable
; CHECK: OpStore [[var1]]
; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]]
; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[null]]
; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[undef]]
; CHECK: [[e1:%\w+]] = OpCompositeExtract [[uint]] [[c1]] 0
;
OpCapability Shader
@ -1444,13 +1443,13 @@ TEST_F(ScalarReplacementTest, CreateAmbiguousNullConstant2) {
; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[struct_member:%\w+]]
; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]]
; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0
; CHECK: [[null:%\w+]] = OpConstantNull [[struct_member]]
; CHECK: [[undef:%\w+]] = OpUndef [[struct_member]]
; CHECK: [[var0a:%\w+]] = OpVariable [[uint_ptr]] Function
; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function
; CHECK: [[var0b:%\w+]] = OpVariable [[uint_ptr]] Function
; CHECK: OpStore [[var1]]
; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]]
; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[null]]
; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[undef]]
; CHECK: [[e1:%\w+]] = OpCompositeExtract [[uint]] [[c1]] 0
;
OpCapability Shader
@ -2263,6 +2262,40 @@ OpFunctionEnd
SinglePassRunAndCheck<ScalarReplacementPass>(text, text, false);
}
TEST_F(ScalarReplacementTest, UndefImageMember) {
// Test that scalar replacement creates an undef for a type that cannot have
// and OpConstantNull.
const std::string text = R"(
; CHECK: [[image_type:%\w+]] = OpTypeSampledImage {{%\w+}}
; CHECK: [[struct_type:%\w+]] = OpTypeStruct [[image_type]]
; CHECK: [[undef:%\w+]] = OpUndef [[image_type]]
; CHECK: {{%\w+}} = OpCompositeConstruct [[struct_type]] [[undef]]
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %2 "main"
OpExecutionMode %2 OriginUpperLeft
%void = OpTypeVoid
%4 = OpTypeFunction %void
%float = OpTypeFloat 32
%6 = OpTypeImage %float 2D 0 0 0 1 Unknown
%7 = OpTypeSampledImage %6
%_struct_8 = OpTypeStruct %7
%9 = OpTypeFunction %_struct_8
%10 = OpUndef %_struct_8
%_ptr_Function__struct_8 = OpTypePointer Function %_struct_8
%2 = OpFunction %void None %4
%11 = OpLabel
%16 = OpVariable %_ptr_Function__struct_8 Function
OpStore %16 %10
%12 = OpLoad %_struct_8 %16
OpReturn
OpFunctionEnd
)";
SinglePassRunAndMatch<ScalarReplacementPass>(text, true);
}
} // namespace
} // namespace opt
} // namespace spvtools