From 46413d57802cc68351426d43ab72080b1f79dc95 Mon Sep 17 00:00:00 2001 From: John Kessenich Date: Mon, 26 Feb 2018 19:20:05 -0700 Subject: [PATCH] SPV: Fix #1258: cache constant structs by id, not opcode. Constants were generally cached by type opcode, but all structures share the same type opcode (OpTypeStruct), so they need to be cached by type id. --- SPIRV/SpvBuilder.cpp | 54 ++++++++++++++++++----- SPIRV/SpvBuilder.h | 15 ++++--- Test/baseResults/spv.constStruct.vert.out | 45 +++++++++++++++++++ Test/spv.constStruct.vert | 22 +++++++++ gtests/Spv.FromFile.cpp | 1 + 5 files changed, 120 insertions(+), 17 deletions(-) create mode 100755 Test/baseResults/spv.constStruct.vert.out create mode 100644 Test/spv.constStruct.vert diff --git a/SPIRV/SpvBuilder.cpp b/SPIRV/SpvBuilder.cpp index 0262a12cc..8d1bfb9b0 100644 --- a/SPIRV/SpvBuilder.cpp +++ b/SPIRV/SpvBuilder.cpp @@ -622,7 +622,7 @@ Id Builder::getContainedTypeId(Id typeId) const // See if a scalar constant of this type has already been created, so it // can be reused rather than duplicated. (Required by the specification). -Id Builder::findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned value) const +Id Builder::findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned value) { Instruction* constant; for (int i = 0; i < (int)groupedConstants[typeClass].size(); ++i) { @@ -637,7 +637,7 @@ Id Builder::findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned valu } // Version of findScalarConstant (see above) for scalars that take two operands (e.g. a 'double' or 'int64'). -Id Builder::findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned v1, unsigned v2) const +Id Builder::findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned v1, unsigned v2) { Instruction* constant; for (int i = 0; i < (int)groupedConstants[typeClass].size(); ++i) { @@ -849,7 +849,7 @@ Id Builder::makeFloat16Constant(float f16, bool specConstant) } #endif -Id Builder::findCompositeConstant(Op typeClass, const std::vector& comps) const +Id Builder::findCompositeConstant(Op typeClass, const std::vector& comps) { Instruction* constant = 0; bool found = false; @@ -877,6 +877,30 @@ Id Builder::findCompositeConstant(Op typeClass, const std::vector& comps) co return found ? constant->getResultId() : NoResult; } +Id Builder::findStructConstant(Id typeId, const std::vector& comps) +{ + Instruction* constant = 0; + bool found = false; + for (int i = 0; i < (int)groupedStructConstants[typeId].size(); ++i) { + constant = groupedStructConstants[typeId][i]; + + // same contents? + bool mismatch = false; + for (int op = 0; op < constant->getNumOperands(); ++op) { + if (constant->getIdOperand(op) != comps[op]) { + mismatch = true; + break; + } + } + if (! mismatch) { + found = true; + break; + } + } + + return found ? constant->getResultId() : NoResult; +} + // Comments in header Id Builder::makeCompositeConstant(Id typeId, const std::vector& members, bool specConstant) { @@ -887,25 +911,33 @@ Id Builder::makeCompositeConstant(Id typeId, const std::vector& members, boo switch (typeClass) { case OpTypeVector: case OpTypeArray: - case OpTypeStruct: case OpTypeMatrix: + if (! specConstant) { + Id existing = findCompositeConstant(typeClass, members); + if (existing) + return existing; + } + break; + case OpTypeStruct: + if (! specConstant) { + Id existing = findStructConstant(typeId, members); + if (existing) + return existing; + } break; default: assert(0); return makeFloatConstant(0.0); } - if (! specConstant) { - Id existing = findCompositeConstant(typeClass, members); - if (existing) - return existing; - } - Instruction* c = new Instruction(getUniqueId(), typeId, opcode); for (int op = 0; op < (int)members.size(); ++op) c->addIdOperand(members[op]); constantsTypesGlobals.push_back(std::unique_ptr(c)); - groupedConstants[typeClass].push_back(c); + if (typeClass == OpTypeStruct) + groupedStructConstants[typeId].push_back(c); + else + groupedConstants[typeClass].push_back(c); module.mapInstruction(c); return c->getResultId(); diff --git a/SPIRV/SpvBuilder.h b/SPIRV/SpvBuilder.h index fcf351b46..154687d7f 100755 --- a/SPIRV/SpvBuilder.h +++ b/SPIRV/SpvBuilder.h @@ -55,6 +55,7 @@ #include #include #include +#include namespace spv { @@ -149,7 +150,7 @@ public: bool isAggregate(Id resultId) const { return isAggregateType(getTypeId(resultId)); } bool isSampledImage(Id resultId) const { return isSampledImageType(getTypeId(resultId)); } - bool isBoolType(Id typeId) const { return groupedTypes[OpTypeBool].size() > 0 && typeId == groupedTypes[OpTypeBool].back()->getResultId(); } + bool isBoolType(Id typeId) { return groupedTypes[OpTypeBool].size() > 0 && typeId == groupedTypes[OpTypeBool].back()->getResultId(); } bool isIntType(Id typeId) const { return getTypeClass(typeId) == OpTypeInt && module.getInstruction(typeId)->getImmediateOperand(1) != 0; } bool isUintType(Id typeId) const { return getTypeClass(typeId) == OpTypeInt && module.getInstruction(typeId)->getImmediateOperand(1) == 0; } bool isFloatType(Id typeId) const { return getTypeClass(typeId) == OpTypeFloat; } @@ -576,9 +577,10 @@ public: protected: Id makeIntConstant(Id typeId, unsigned value, bool specConstant); Id makeInt64Constant(Id typeId, unsigned long long value, bool specConstant); - Id findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned value) const; - Id findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned v1, unsigned v2) const; - Id findCompositeConstant(Op typeClass, const std::vector& comps) const; + Id findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned value); + Id findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned v1, unsigned v2); + Id findCompositeConstant(Op typeClass, const std::vector& comps); + Id findStructConstant(Id typeId, const std::vector& comps); Id collapseAccessChain(); void remapDynamicSwizzle(); void transferAccessChainSwizzle(bool dynamic); @@ -622,8 +624,9 @@ public: std::vector > functions; // not output, internally used for quick & dirty canonical (unique) creation - std::vector groupedConstants[OpConstant]; // all types appear before OpConstant - std::vector groupedTypes[OpConstant]; + std::unordered_map> groupedConstants; // map type opcodes to constant inst. + std::unordered_map> groupedStructConstants; // map struct-id to constant instructions + std::unordered_map> groupedTypes; // map type opcodes to type instructions // stack of switches std::stack switchMerges; diff --git a/Test/baseResults/spv.constStruct.vert.out b/Test/baseResults/spv.constStruct.vert.out new file mode 100755 index 000000000..1a5130265 --- /dev/null +++ b/Test/baseResults/spv.constStruct.vert.out @@ -0,0 +1,45 @@ +spv.constStruct.vert +// Module Version 10000 +// Generated by (magic number): 80004 +// Id's are bound by 23 + + Capability Shader + 1: ExtInstImport "GLSL.std.450" + MemoryModel Logical GLSL450 + EntryPoint Vertex 4 "main" + Source GLSL 450 + Name 4 "main" + Name 9 "T" + MemberName 9(T) 0 "m" + Name 10 "U" + MemberName 10(U) 0 "m" + Name 11 "S" + MemberName 11(S) 0 "t" + MemberName 11(S) 1 "u" + Name 13 "s1" + Name 22 "s2" + 2: TypeVoid + 3: TypeFunction 2 + 6: TypeFloat 32 + 7: TypeVector 6(float) 2 + 8: TypeMatrix 7(fvec2) 2 + 9(T): TypeStruct 8 + 10(U): TypeStruct 8 + 11(S): TypeStruct 9(T) 10(U) + 12: TypePointer Function 11(S) + 14: 6(float) Constant 1065353216 + 15: 6(float) Constant 0 + 16: 7(fvec2) ConstantComposite 14 15 + 17: 7(fvec2) ConstantComposite 15 14 + 18: 8 ConstantComposite 16 17 + 19: 9(T) ConstantComposite 18 + 20: 10(U) ConstantComposite 18 + 21: 11(S) ConstantComposite 19 20 + 4(main): 2 Function None 3 + 5: Label + 13(s1): 12(ptr) Variable Function + 22(s2): 12(ptr) Variable Function + Store 13(s1) 21 + Store 22(s2) 21 + Return + FunctionEnd diff --git a/Test/spv.constStruct.vert b/Test/spv.constStruct.vert new file mode 100644 index 000000000..d5dd8da97 --- /dev/null +++ b/Test/spv.constStruct.vert @@ -0,0 +1,22 @@ +#version 450 + +precision highp float; + +struct U { + mat2 m; +}; + +struct T { + mat2 m; +}; + +struct S { + T t; + U u; +}; + +void main() +{ + S s1 = S(T(mat2(1.0)), U(mat2(1.0))); + S s2 = S(T(mat2(1.0)), U(mat2(1.0))); +} diff --git a/gtests/Spv.FromFile.cpp b/gtests/Spv.FromFile.cpp index 8cf1da83c..51d9deebd 100644 --- a/gtests/Spv.FromFile.cpp +++ b/gtests/Spv.FromFile.cpp @@ -235,6 +235,7 @@ INSTANTIATE_TEST_CASE_P( "spv.branch-return.vert", "spv.builtInXFB.vert", "spv.conditionalDiscard.frag", + "spv.constStruct.vert", "spv.controlFlowAttributes.frag", "spv.conversion.frag", "spv.dataOut.frag",