diff --git a/SPIRV/GlslangToSpv.cpp b/SPIRV/GlslangToSpv.cpp index 50806d091..138a2cd25 100755 --- a/SPIRV/GlslangToSpv.cpp +++ b/SPIRV/GlslangToSpv.cpp @@ -122,6 +122,7 @@ protected: spv::Decoration TranslateAuxiliaryStorageDecoration(const glslang::TQualifier& qualifier); spv::BuiltIn TranslateBuiltInDecoration(glslang::TBuiltInVariable, bool memberDeclaration); spv::ImageFormat TranslateImageFormat(const glslang::TType& type); + spv::SelectionControlMask TranslateSelectionControl(glslang::TSelectionControl) const; spv::LoopControlMask TranslateLoopControl(glslang::TLoopControl) const; spv::StorageClass TranslateStorageClass(const glslang::TType&); spv::Id createSpvVariable(const glslang::TIntermSymbol*); @@ -741,6 +742,16 @@ spv::ImageFormat TGlslangToSpvTraverser::TranslateImageFormat(const glslang::TTy } } +spv::SelectionControlMask TGlslangToSpvTraverser::TranslateSelectionControl(glslang::TSelectionControl selectionControl) const +{ + switch (selectionControl) { + case glslang::ESelectionControlNone: return spv::SelectionControlMaskNone; + case glslang::ESelectionControlFlatten: return spv::SelectionControlFlattenMask; + case glslang::ESelectionControlDontFlatten: return spv::SelectionControlDontFlattenMask; + default: return spv::SelectionControlMaskNone; + } +} + spv::LoopControlMask TGlslangToSpvTraverser::TranslateLoopControl(glslang::TLoopControl loopControl) const { switch (loopControl) { @@ -1941,8 +1952,7 @@ bool TGlslangToSpvTraverser::visitSelection(glslang::TVisit /* visit */, glslang return false; } - // Instead, emit control flow... - + // Instead, emit control flow... // Don't handle results as temporaries, because there will be two names // and better to leave SSA to later passes. spv::Id result = (node->getBasicType() == glslang::EbtVoid) @@ -1952,8 +1962,11 @@ bool TGlslangToSpvTraverser::visitSelection(glslang::TVisit /* visit */, glslang // emit the condition before doing anything with selection node->getCondition()->traverse(this); + // Selection control: + const spv::SelectionControlMask control = TranslateSelectionControl(node->getSelectionControl()); + // make an "if" based on the value created by the condition - spv::Builder::If ifBuilder(accessChainLoad(node->getCondition()->getType()), builder); + spv::Builder::If ifBuilder(accessChainLoad(node->getCondition()->getType()), control, builder); // emit the "then" statement if (node->getTrueBlock() != nullptr) { @@ -1991,6 +2004,9 @@ bool TGlslangToSpvTraverser::visitSwitch(glslang::TVisit /* visit */, glslang::T node->getCondition()->traverse(this); spv::Id selector = accessChainLoad(node->getCondition()->getAsTyped()->getType()); + // Selection control: + const spv::SelectionControlMask control = TranslateSelectionControl(node->getSelectionControl()); + // browse the children to sort out code segments int defaultSegment = -1; std::vector codeSegments; @@ -2016,7 +2032,7 @@ bool TGlslangToSpvTraverser::visitSwitch(glslang::TVisit /* visit */, glslang::T // make the switch statement std::vector segmentBlocks; // returned, as the blocks allocated in the call - builder.makeSwitch(selector, (int)codeSegments.size(), caseValues, valueIndexToSegment, defaultSegment, segmentBlocks); + builder.makeSwitch(selector, control, (int)codeSegments.size(), caseValues, valueIndexToSegment, defaultSegment, segmentBlocks); // emit all the code in the segments breakForLoop.push(false); @@ -5707,7 +5723,7 @@ spv::Id TGlslangToSpvTraverser::createShortCircuit(glslang::TOperator op, glslan leftId = builder.createUnaryOp(spv::OpLogicalNot, boolTypeId, leftId); // make an "if" based on the left value - spv::Builder::If ifBuilder(leftId, builder); + spv::Builder::If ifBuilder(leftId, spv::SelectionControlMaskNone, builder); // emit right operand as the "then" part of the "if" builder.clearAccessChain(); diff --git a/SPIRV/SpvBuilder.cpp b/SPIRV/SpvBuilder.cpp index ffd17af1b..11e9fe600 100644 --- a/SPIRV/SpvBuilder.cpp +++ b/SPIRV/SpvBuilder.cpp @@ -2009,9 +2009,10 @@ Id Builder::createMatrixConstructor(Decoration precision, const std::vector& } // Comments in header -Builder::If::If(Id cond, Builder& gb) : +Builder::If::If(Id cond, unsigned int ctrl, Builder& gb) : builder(gb), condition(cond), + control(ctrl), elseBlock(0) { function = &builder.getBuildPoint()->getParent(); @@ -2052,7 +2053,7 @@ void Builder::If::makeEndIf() // Go back to the headerBlock and make the flow control split builder.setBuildPoint(headerBlock); - builder.createSelectionMerge(mergeBlock, SelectionControlMaskNone); + builder.createSelectionMerge(mergeBlock, control); if (elseBlock) builder.createConditionalBranch(condition, thenBlock, elseBlock); else @@ -2064,7 +2065,7 @@ void Builder::If::makeEndIf() } // Comments in header -void Builder::makeSwitch(Id selector, int numSegments, const std::vector& caseValues, +void Builder::makeSwitch(Id selector, unsigned int control, int numSegments, const std::vector& caseValues, const std::vector& valueIndexToSegment, int defaultSegment, std::vector& segmentBlocks) { @@ -2077,7 +2078,7 @@ void Builder::makeSwitch(Id selector, int numSegments, const std::vector& c Block* mergeBlock = new Block(getUniqueId(), function); // make and insert the switch's selection-merge instruction - createSelectionMerge(mergeBlock, SelectionControlMaskNone); + createSelectionMerge(mergeBlock, control); // make the switch instruction Instruction* switchInst = new Instruction(NoResult, NoType, OpSwitch); diff --git a/SPIRV/SpvBuilder.h b/SPIRV/SpvBuilder.h index 92f5084c1..2b17ba608 100755 --- a/SPIRV/SpvBuilder.h +++ b/SPIRV/SpvBuilder.h @@ -385,7 +385,7 @@ public: // Helper to use for building nested control flow with if-then-else. class If { public: - If(Id condition, Builder& builder); + If(Id condition, unsigned int ctrl, Builder& builder); ~If() {} void makeBeginElse(); @@ -397,6 +397,7 @@ public: Builder& builder; Id condition; + unsigned int control; Function* function; Block* headerBlock; Block* thenBlock; @@ -416,7 +417,7 @@ public: // Returns the right set of basic blocks to start each code segment with, so that the caller's // recursion stack can hold the memory for it. // - void makeSwitch(Id condition, int numSegments, const std::vector& caseValues, + void makeSwitch(Id condition, unsigned int control, int numSegments, const std::vector& caseValues, const std::vector& valueToSegment, int defaultSegment, std::vector& segmentBB); // return argument // Add a branch to the innermost switch's merge block. diff --git a/SPIRV/doc.cpp b/SPIRV/doc.cpp index 03097d0cb..fb0cc36ea 100755 --- a/SPIRV/doc.cpp +++ b/SPIRV/doc.cpp @@ -637,13 +637,15 @@ const char* SelectControlString(int cont) } } -const int LoopControlCeiling = 2; +const int LoopControlCeiling = 4; const char* LoopControlString(int cont) { switch (cont) { case 0: return "Unroll"; case 1: return "DontUnroll"; + case 2: return "DependencyInfinite"; + case 3: return "DependencyLength"; case LoopControlCeiling: default: return "Bad"; diff --git a/Test/baseResults/hlsl.attribute.frag.out b/Test/baseResults/hlsl.attribute.frag.out index ccd7693f7..47fb67760 100755 --- a/Test/baseResults/hlsl.attribute.frag.out +++ b/Test/baseResults/hlsl.attribute.frag.out @@ -90,7 +90,7 @@ gl_FragCoord origin is upper left 11(@PixelShaderFunction(vf4;): 2 Function None 9 10(input): 8(ptr) FunctionParameter 12: Label - SelectionMerge 16 None + SelectionMerge 16 DontFlatten BranchConditional 14 15 16 15: Label Branch 16 diff --git a/Test/baseResults/hlsl.if.frag.out b/Test/baseResults/hlsl.if.frag.out index 89e0bb1b8..eade928aa 100755 --- a/Test/baseResults/hlsl.if.frag.out +++ b/Test/baseResults/hlsl.if.frag.out @@ -319,7 +319,7 @@ gl_FragCoord origin is upper left 48: 7(fvec4) Load 10(input) 49: 16(bvec4) FOrdEqual 47 48 50: 15(bool) All 49 - SelectionMerge 52 None + SelectionMerge 52 Flatten BranchConditional 50 51 52 51: Label 53: 7(fvec4) Load 10(input) diff --git a/Test/baseResults/hlsl.switch.frag.out b/Test/baseResults/hlsl.switch.frag.out index 192e1b4ff..9cb52f115 100755 --- a/Test/baseResults/hlsl.switch.frag.out +++ b/Test/baseResults/hlsl.switch.frag.out @@ -399,7 +399,7 @@ gl_FragCoord origin is upper left Branch 25 25: Label 36: 9(int) Load 13(c) - SelectionMerge 40 None + SelectionMerge 40 DontFlatten Switch 36 39 case 1: 37 case 2: 38 diff --git a/Test/hlsl.switch.frag b/Test/hlsl.switch.frag index 88239c2be..78ebfef34 100644 --- a/Test/hlsl.switch.frag +++ b/Test/hlsl.switch.frag @@ -18,7 +18,7 @@ float4 PixelShaderFunction(float4 input, int c, int d) : COLOR0 break; } - switch (c) { + [branch] switch (c) { case 1: ++input; break; diff --git a/glslang/Include/intermediate.h b/glslang/Include/intermediate.h index 0fcac0b54..0922d125a 100644 --- a/glslang/Include/intermediate.h +++ b/glslang/Include/intermediate.h @@ -861,6 +861,15 @@ protected: TType type; }; +// +// Selection control hints +// +enum TSelectionControl { + ESelectionControlNone, + ESelectionControlFlatten, + ESelectionControlDontFlatten, +}; + // // Loop control hints // @@ -1285,19 +1294,22 @@ protected: class TIntermSelection : public TIntermTyped { public: TIntermSelection(TIntermTyped* cond, TIntermNode* trueB, TIntermNode* falseB) : - TIntermTyped(EbtVoid), condition(cond), trueBlock(trueB), falseBlock(falseB) {} + TIntermTyped(EbtVoid), condition(cond), trueBlock(trueB), falseBlock(falseB), control(ESelectionControlNone) {} TIntermSelection(TIntermTyped* cond, TIntermNode* trueB, TIntermNode* falseB, const TType& type) : - TIntermTyped(type), condition(cond), trueBlock(trueB), falseBlock(falseB) {} + TIntermTyped(type), condition(cond), trueBlock(trueB), falseBlock(falseB), control(ESelectionControlNone) {} virtual void traverse(TIntermTraverser*); virtual TIntermTyped* getCondition() const { return condition; } virtual TIntermNode* getTrueBlock() const { return trueBlock; } virtual TIntermNode* getFalseBlock() const { return falseBlock; } virtual TIntermSelection* getAsSelectionNode() { return this; } virtual const TIntermSelection* getAsSelectionNode() const { return this; } + void setSelectionControl(TSelectionControl c) { control = c; } + TSelectionControl getSelectionControl() const { return control; } protected: TIntermTyped* condition; TIntermNode* trueBlock; TIntermNode* falseBlock; + TSelectionControl control; // selection control hint }; // @@ -1308,15 +1320,18 @@ protected: // class TIntermSwitch : public TIntermNode { public: - TIntermSwitch(TIntermTyped* cond, TIntermAggregate* b) : condition(cond), body(b) { } + TIntermSwitch(TIntermTyped* cond, TIntermAggregate* b) : condition(cond), body(b), control(ESelectionControlNone) { } virtual void traverse(TIntermTraverser*); virtual TIntermNode* getCondition() const { return condition; } virtual TIntermAggregate* getBody() const { return body; } virtual TIntermSwitch* getAsSwitchNode() { return this; } virtual const TIntermSwitch* getAsSwitchNode() const { return this; } + void setSelectionControl(TSelectionControl c) { control = c; } + TSelectionControl getSelectionControl() const { return control; } protected: TIntermTyped* condition; TIntermAggregate* body; + TSelectionControl control; // selection control hint }; enum TVisit diff --git a/glslang/MachineIndependent/Intermediate.cpp b/glslang/MachineIndependent/Intermediate.cpp index a08944a5e..04a45e039 100644 --- a/glslang/MachineIndependent/Intermediate.cpp +++ b/glslang/MachineIndependent/Intermediate.cpp @@ -1614,7 +1614,7 @@ TIntermAggregate* TIntermediate::makeAggregate(const TSourceLoc& loc) // // Returns the selection node created. // -TIntermTyped* TIntermediate::addSelection(TIntermTyped* cond, TIntermNodePair nodePair, const TSourceLoc& loc) +TIntermTyped* TIntermediate::addSelection(TIntermTyped* cond, TIntermNodePair nodePair, const TSourceLoc& loc, TSelectionControl control) { // // Don't prune the false path for compile-time constants; it's needed @@ -1623,6 +1623,7 @@ TIntermTyped* TIntermediate::addSelection(TIntermTyped* cond, TIntermNodePair no TIntermSelection* node = new TIntermSelection(cond, nodePair.node1, nodePair.node2); node->setLoc(loc); + node->setSelectionControl(control); return node; } @@ -1665,12 +1666,12 @@ TIntermTyped* TIntermediate::addMethod(TIntermTyped* object, const TType& type, // // Returns the selection node created, or nullptr if one could not be. // -TIntermTyped* TIntermediate::addSelection(TIntermTyped* cond, TIntermTyped* trueBlock, TIntermTyped* falseBlock, const TSourceLoc& loc) +TIntermTyped* TIntermediate::addSelection(TIntermTyped* cond, TIntermTyped* trueBlock, TIntermTyped* falseBlock, const TSourceLoc& loc, TSelectionControl control) { // If it's void, go to the if-then-else selection() if (trueBlock->getBasicType() == EbtVoid && falseBlock->getBasicType() == EbtVoid) { TIntermNodePair pair = { trueBlock, falseBlock }; - return addSelection(cond, pair, loc); + return addSelection(cond, pair, loc, control); } // diff --git a/glslang/MachineIndependent/localintermediate.h b/glslang/MachineIndependent/localintermediate.h index 4eacf58ca..c1daf1b50 100644 --- a/glslang/MachineIndependent/localintermediate.h +++ b/glslang/MachineIndependent/localintermediate.h @@ -276,8 +276,8 @@ public: TIntermAggregate* makeAggregate(const TSourceLoc&); TIntermTyped* setAggregateOperator(TIntermNode*, TOperator, const TType& type, TSourceLoc); bool areAllChildConst(TIntermAggregate* aggrNode); - TIntermTyped* addSelection(TIntermTyped* cond, TIntermNodePair code, const TSourceLoc&); - TIntermTyped* addSelection(TIntermTyped* cond, TIntermTyped* trueBlock, TIntermTyped* falseBlock, const TSourceLoc&); + TIntermTyped* addSelection(TIntermTyped* cond, TIntermNodePair code, const TSourceLoc&, TSelectionControl = ESelectionControlNone); + TIntermTyped* addSelection(TIntermTyped* cond, TIntermTyped* trueBlock, TIntermTyped* falseBlock, const TSourceLoc&, TSelectionControl = ESelectionControlNone); TIntermTyped* addComma(TIntermTyped* left, TIntermTyped* right, const TSourceLoc&); TIntermTyped* addMethod(TIntermTyped*, const TType&, const TString*, const TSourceLoc&); TIntermConstantUnion* addConstantUnion(const TConstUnionArray&, const TType&, const TSourceLoc&, bool literal = false) const; diff --git a/hlsl/hlslGrammar.cpp b/hlsl/hlslGrammar.cpp index a711fd338..95869e8d4 100755 --- a/hlsl/hlslGrammar.cpp +++ b/hlsl/hlslGrammar.cpp @@ -3201,10 +3201,10 @@ bool HlslGrammar::acceptStatement(TIntermNode*& statement) return acceptScopedCompoundStatement(statement); case EHTokIf: - return acceptSelectionStatement(statement); + return acceptSelectionStatement(statement, attributes); case EHTokSwitch: - return acceptSwitchStatement(statement); + return acceptSwitchStatement(statement, attributes); case EHTokFor: case EHTokDo: @@ -3317,10 +3317,12 @@ void HlslGrammar::acceptAttributes(TAttributeMap& attributes) // : IF LEFT_PAREN expression RIGHT_PAREN statement // : IF LEFT_PAREN expression RIGHT_PAREN statement ELSE statement // -bool HlslGrammar::acceptSelectionStatement(TIntermNode*& statement) +bool HlslGrammar::acceptSelectionStatement(TIntermNode*& statement, const TAttributeMap& attributes) { TSourceLoc loc = token.loc; + const TSelectionControl control = parseContext.handleSelectionControl(attributes); + // IF if (! acceptTokenClass(EHTokIf)) return false; @@ -3358,7 +3360,7 @@ bool HlslGrammar::acceptSelectionStatement(TIntermNode*& statement) } // Put the pieces together - statement = intermediate.addSelection(condition, thenElse, loc); + statement = intermediate.addSelection(condition, thenElse, loc, control); parseContext.popScope(); --parseContext.controlFlowNestingLevel; @@ -3368,10 +3370,13 @@ bool HlslGrammar::acceptSelectionStatement(TIntermNode*& statement) // switch_statement // : SWITCH LEFT_PAREN expression RIGHT_PAREN compound_statement // -bool HlslGrammar::acceptSwitchStatement(TIntermNode*& statement) +bool HlslGrammar::acceptSwitchStatement(TIntermNode*& statement, const TAttributeMap& attributes) { // SWITCH TSourceLoc loc = token.loc; + + const TSelectionControl control = parseContext.handleSelectionControl(attributes); + if (! acceptTokenClass(EHTokSwitch)) return false; @@ -3391,7 +3396,7 @@ bool HlslGrammar::acceptSwitchStatement(TIntermNode*& statement) --parseContext.controlFlowNestingLevel; if (statementOkay) - statement = parseContext.addSwitch(loc, switchExpression, statement ? statement->getAsAggregate() : nullptr); + statement = parseContext.addSwitch(loc, switchExpression, statement ? statement->getAsAggregate() : nullptr, control); parseContext.popSwitchSequence(); parseContext.popScope(); diff --git a/hlsl/hlslGrammar.h b/hlsl/hlslGrammar.h index 5e56eddcc..ded8e9669 100755 --- a/hlsl/hlslGrammar.h +++ b/hlsl/hlslGrammar.h @@ -116,8 +116,8 @@ namespace glslang { bool acceptStatement(TIntermNode*&); bool acceptNestedStatement(TIntermNode*&); void acceptAttributes(TAttributeMap&); - bool acceptSelectionStatement(TIntermNode*&); - bool acceptSwitchStatement(TIntermNode*&); + bool acceptSelectionStatement(TIntermNode*&, const TAttributeMap&); + bool acceptSwitchStatement(TIntermNode*&, const TAttributeMap&); bool acceptIterationStatement(TIntermNode*&, const TAttributeMap&); bool acceptJumpStatement(TIntermNode*&); bool acceptCaseLabel(TIntermNode*&); diff --git a/hlsl/hlslParseHelper.cpp b/hlsl/hlslParseHelper.cpp index 1df7ba3c3..27dac886a 100755 --- a/hlsl/hlslParseHelper.cpp +++ b/hlsl/hlslParseHelper.cpp @@ -8273,6 +8273,19 @@ bool HlslParseContext::handleOutputGeometry(const TSourceLoc& loc, const TLayout return true; } +// +// Selection hints +// +TSelectionControl HlslParseContext::handleSelectionControl(const TAttributeMap& attributes) const +{ + if (attributes.contains(EatFlatten)) + return ESelectionControlFlatten; + else if (attributes.contains(EatBranch)) + return ESelectionControlDontFlatten; + else + return ESelectionControlNone; +} + // // Loop hints // @@ -8286,7 +8299,6 @@ TLoopControl HlslParseContext::handleLoopControl(const TAttributeMap& attributes return ELoopControlNone; } - // // Updating default qualifier for the case of a declaration with just a qualifier, // no type, block, or identifier. @@ -8425,7 +8437,7 @@ void HlslParseContext::wrapupSwitchSubsequence(TIntermAggregate* statements, TIn // Turn the top-level node sequence built up of wrapupSwitchSubsequence // into a switch node. // -TIntermNode* HlslParseContext::addSwitch(const TSourceLoc& loc, TIntermTyped* expression, TIntermAggregate* lastStatements) +TIntermNode* HlslParseContext::addSwitch(const TSourceLoc& loc, TIntermTyped* expression, TIntermAggregate* lastStatements, TSelectionControl control) { wrapupSwitchSubsequence(lastStatements, nullptr); @@ -8452,6 +8464,7 @@ TIntermNode* HlslParseContext::addSwitch(const TSourceLoc& loc, TIntermTyped* ex TIntermSwitch* switchNode = new TIntermSwitch(expression, body); switchNode->setLoc(loc); + switchNode->setSelectionControl(control); return switchNode; } diff --git a/hlsl/hlslParseHelper.h b/hlsl/hlslParseHelper.h index 15c7b0705..713e8303e 100755 --- a/hlsl/hlslParseHelper.h +++ b/hlsl/hlslParseHelper.h @@ -159,7 +159,7 @@ public: void addQualifierToExisting(const TSourceLoc&, TQualifier, TIdentifierList&); void updateStandaloneQualifierDefaults(const TSourceLoc&, const TPublicType&); void wrapupSwitchSubsequence(TIntermAggregate* statements, TIntermNode* branchNode); - TIntermNode* addSwitch(const TSourceLoc&, TIntermTyped* expression, TIntermAggregate* body); + TIntermNode* addSwitch(const TSourceLoc&, TIntermTyped* expression, TIntermAggregate* body, TSelectionControl control); void updateImplicitArraySize(const TSourceLoc&, TIntermNode*, int index); @@ -198,6 +198,9 @@ public: bool handleOutputGeometry(const TSourceLoc&, const TLayoutGeometry& geometry); bool handleInputGeometry(const TSourceLoc&, const TLayoutGeometry& geometry); + // Determine selection control from attributes + TSelectionControl handleSelectionControl(const TAttributeMap& attributes) const; + // Determine loop control from attributes TLoopControl handleLoopControl(const TAttributeMap& attributes) const;