diff --git a/SPIRV/SpvBuilder.cpp b/SPIRV/SpvBuilder.cpp index 0262a12cc..8d1bfb9b0 100644 --- a/SPIRV/SpvBuilder.cpp +++ b/SPIRV/SpvBuilder.cpp @@ -622,7 +622,7 @@ Id Builder::getContainedTypeId(Id typeId) const // See if a scalar constant of this type has already been created, so it // can be reused rather than duplicated. (Required by the specification). -Id Builder::findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned value) const +Id Builder::findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned value) { Instruction* constant; for (int i = 0; i < (int)groupedConstants[typeClass].size(); ++i) { @@ -637,7 +637,7 @@ Id Builder::findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned valu } // Version of findScalarConstant (see above) for scalars that take two operands (e.g. a 'double' or 'int64'). -Id Builder::findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned v1, unsigned v2) const +Id Builder::findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned v1, unsigned v2) { Instruction* constant; for (int i = 0; i < (int)groupedConstants[typeClass].size(); ++i) { @@ -849,7 +849,7 @@ Id Builder::makeFloat16Constant(float f16, bool specConstant) } #endif -Id Builder::findCompositeConstant(Op typeClass, const std::vector& comps) const +Id Builder::findCompositeConstant(Op typeClass, const std::vector& comps) { Instruction* constant = 0; bool found = false; @@ -877,6 +877,30 @@ Id Builder::findCompositeConstant(Op typeClass, const std::vector& comps) co return found ? constant->getResultId() : NoResult; } +Id Builder::findStructConstant(Id typeId, const std::vector& comps) +{ + Instruction* constant = 0; + bool found = false; + for (int i = 0; i < (int)groupedStructConstants[typeId].size(); ++i) { + constant = groupedStructConstants[typeId][i]; + + // same contents? + bool mismatch = false; + for (int op = 0; op < constant->getNumOperands(); ++op) { + if (constant->getIdOperand(op) != comps[op]) { + mismatch = true; + break; + } + } + if (! mismatch) { + found = true; + break; + } + } + + return found ? constant->getResultId() : NoResult; +} + // Comments in header Id Builder::makeCompositeConstant(Id typeId, const std::vector& members, bool specConstant) { @@ -887,25 +911,33 @@ Id Builder::makeCompositeConstant(Id typeId, const std::vector& members, boo switch (typeClass) { case OpTypeVector: case OpTypeArray: - case OpTypeStruct: case OpTypeMatrix: + if (! specConstant) { + Id existing = findCompositeConstant(typeClass, members); + if (existing) + return existing; + } + break; + case OpTypeStruct: + if (! specConstant) { + Id existing = findStructConstant(typeId, members); + if (existing) + return existing; + } break; default: assert(0); return makeFloatConstant(0.0); } - if (! specConstant) { - Id existing = findCompositeConstant(typeClass, members); - if (existing) - return existing; - } - Instruction* c = new Instruction(getUniqueId(), typeId, opcode); for (int op = 0; op < (int)members.size(); ++op) c->addIdOperand(members[op]); constantsTypesGlobals.push_back(std::unique_ptr(c)); - groupedConstants[typeClass].push_back(c); + if (typeClass == OpTypeStruct) + groupedStructConstants[typeId].push_back(c); + else + groupedConstants[typeClass].push_back(c); module.mapInstruction(c); return c->getResultId(); diff --git a/SPIRV/SpvBuilder.h b/SPIRV/SpvBuilder.h index fcf351b46..154687d7f 100755 --- a/SPIRV/SpvBuilder.h +++ b/SPIRV/SpvBuilder.h @@ -55,6 +55,7 @@ #include #include #include +#include namespace spv { @@ -149,7 +150,7 @@ public: bool isAggregate(Id resultId) const { return isAggregateType(getTypeId(resultId)); } bool isSampledImage(Id resultId) const { return isSampledImageType(getTypeId(resultId)); } - bool isBoolType(Id typeId) const { return groupedTypes[OpTypeBool].size() > 0 && typeId == groupedTypes[OpTypeBool].back()->getResultId(); } + bool isBoolType(Id typeId) { return groupedTypes[OpTypeBool].size() > 0 && typeId == groupedTypes[OpTypeBool].back()->getResultId(); } bool isIntType(Id typeId) const { return getTypeClass(typeId) == OpTypeInt && module.getInstruction(typeId)->getImmediateOperand(1) != 0; } bool isUintType(Id typeId) const { return getTypeClass(typeId) == OpTypeInt && module.getInstruction(typeId)->getImmediateOperand(1) == 0; } bool isFloatType(Id typeId) const { return getTypeClass(typeId) == OpTypeFloat; } @@ -576,9 +577,10 @@ public: protected: Id makeIntConstant(Id typeId, unsigned value, bool specConstant); Id makeInt64Constant(Id typeId, unsigned long long value, bool specConstant); - Id findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned value) const; - Id findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned v1, unsigned v2) const; - Id findCompositeConstant(Op typeClass, const std::vector& comps) const; + Id findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned value); + Id findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned v1, unsigned v2); + Id findCompositeConstant(Op typeClass, const std::vector& comps); + Id findStructConstant(Id typeId, const std::vector& comps); Id collapseAccessChain(); void remapDynamicSwizzle(); void transferAccessChainSwizzle(bool dynamic); @@ -622,8 +624,9 @@ public: std::vector > functions; // not output, internally used for quick & dirty canonical (unique) creation - std::vector groupedConstants[OpConstant]; // all types appear before OpConstant - std::vector groupedTypes[OpConstant]; + std::unordered_map> groupedConstants; // map type opcodes to constant inst. + std::unordered_map> groupedStructConstants; // map struct-id to constant instructions + std::unordered_map> groupedTypes; // map type opcodes to type instructions // stack of switches std::stack switchMerges; diff --git a/Test/baseResults/spv.constStruct.vert.out b/Test/baseResults/spv.constStruct.vert.out new file mode 100755 index 000000000..1a5130265 --- /dev/null +++ b/Test/baseResults/spv.constStruct.vert.out @@ -0,0 +1,45 @@ +spv.constStruct.vert +// Module Version 10000 +// Generated by (magic number): 80004 +// Id's are bound by 23 + + Capability Shader + 1: ExtInstImport "GLSL.std.450" + MemoryModel Logical GLSL450 + EntryPoint Vertex 4 "main" + Source GLSL 450 + Name 4 "main" + Name 9 "T" + MemberName 9(T) 0 "m" + Name 10 "U" + MemberName 10(U) 0 "m" + Name 11 "S" + MemberName 11(S) 0 "t" + MemberName 11(S) 1 "u" + Name 13 "s1" + Name 22 "s2" + 2: TypeVoid + 3: TypeFunction 2 + 6: TypeFloat 32 + 7: TypeVector 6(float) 2 + 8: TypeMatrix 7(fvec2) 2 + 9(T): TypeStruct 8 + 10(U): TypeStruct 8 + 11(S): TypeStruct 9(T) 10(U) + 12: TypePointer Function 11(S) + 14: 6(float) Constant 1065353216 + 15: 6(float) Constant 0 + 16: 7(fvec2) ConstantComposite 14 15 + 17: 7(fvec2) ConstantComposite 15 14 + 18: 8 ConstantComposite 16 17 + 19: 9(T) ConstantComposite 18 + 20: 10(U) ConstantComposite 18 + 21: 11(S) ConstantComposite 19 20 + 4(main): 2 Function None 3 + 5: Label + 13(s1): 12(ptr) Variable Function + 22(s2): 12(ptr) Variable Function + Store 13(s1) 21 + Store 22(s2) 21 + Return + FunctionEnd diff --git a/Test/spv.constStruct.vert b/Test/spv.constStruct.vert new file mode 100644 index 000000000..d5dd8da97 --- /dev/null +++ b/Test/spv.constStruct.vert @@ -0,0 +1,22 @@ +#version 450 + +precision highp float; + +struct U { + mat2 m; +}; + +struct T { + mat2 m; +}; + +struct S { + T t; + U u; +}; + +void main() +{ + S s1 = S(T(mat2(1.0)), U(mat2(1.0))); + S s2 = S(T(mat2(1.0)), U(mat2(1.0))); +} diff --git a/gtests/Spv.FromFile.cpp b/gtests/Spv.FromFile.cpp index 8cf1da83c..51d9deebd 100644 --- a/gtests/Spv.FromFile.cpp +++ b/gtests/Spv.FromFile.cpp @@ -235,6 +235,7 @@ INSTANTIATE_TEST_CASE_P( "spv.branch-return.vert", "spv.builtInXFB.vert", "spv.conditionalDiscard.frag", + "spv.constStruct.vert", "spv.controlFlowAttributes.frag", "spv.conversion.frag", "spv.dataOut.frag",