Implement relaxed rule for opaque struct members

This commit is contained in:
Samuel Bourasseau 2023-11-29 01:19:02 +01:00 committed by GitHub
parent a3069e1df4
commit c59b876ca0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 2787 additions and 2187 deletions

File diff suppressed because it is too large Load Diff

View File

@ -7,10 +7,17 @@ uniform vec4 a;
uniform vec2 b = vec2(0, 0); // initializer will be ignored
layout(location = 0) uniform vec2 c; // location qualifier will be ignored
uniform vec4 d[10];
struct SamplerArray{
sampler2D tn[4];
};
uniform struct e {
vec2 x;
float y;
uint z;
sampler2D t0;
SamplerArray samplers;
} structUniform;
// opaque types will not be grouped into uniform block
@ -64,11 +71,15 @@ uint bar() {
vec4 foo() {
float f = j + bufferInstance.j + structUniform.y + structUniform.z;
vec2 v2 = b + c + structUniform.x;
vec4 v4 = a + d[0] + d[1] + d[2] + k + bufferInstance.k + texture(t1, vec2(0, 0));
vec4 v4 = a + d[0] + d[1] + d[2] + k + bufferInstance.k + texture(t1, vec2(0, 0)) + texture(structUniform.t0, vec2(0, 0));
return vec4(f) * vec4(v2, 1, 1) * v4;
}
vec4 baz(SamplerArray samplers) {
return texture(samplers.tn[0], vec2(0, 0)) + texture(samplers.tn[1], vec2(0, 0)) + texture(samplers.tn[2], vec2(0, 0)) + texture(samplers.tn[3], vec2(0, 0));
}
void main() {
float j = float(bar());
o = j * foo();
}
o = j * foo() + baz(structUniform.samplers);
}

View File

@ -2317,6 +2317,40 @@ TIntermAggregate* TIntermediate::growAggregate(TIntermNode* left, TIntermNode* r
return aggNode;
}
TIntermAggregate* TIntermediate::mergeAggregate(TIntermNode* left, TIntermNode* right)
{
if (left == nullptr && right == nullptr)
return nullptr;
TIntermAggregate* aggNode = nullptr;
if (left != nullptr)
aggNode = left->getAsAggregate();
if (aggNode == nullptr || aggNode->getOp() != EOpNull) {
aggNode = new TIntermAggregate;
if (left != nullptr)
aggNode->getSequence().push_back(left);
}
TIntermAggregate* rhsagg = right->getAsAggregate();
if (rhsagg == nullptr || rhsagg->getOp() != EOpNull)
aggNode->getSequence().push_back(right);
else
aggNode->getSequence().insert(aggNode->getSequence().end(),
rhsagg->getSequence().begin(),
rhsagg->getSequence().end());
return aggNode;
}
TIntermAggregate* TIntermediate::mergeAggregate(TIntermNode* left, TIntermNode* right, const TSourceLoc& loc)
{
TIntermAggregate* aggNode = mergeAggregate(left, right);
if (aggNode)
aggNode->setLoc(loc);
return aggNode;
}
//
// Turn an existing node into an aggregate.
//

View File

@ -722,6 +722,24 @@ void TParseContextBase::finish()
if (parsingBuiltins)
return;
for (const TString& relaxedSymbol : relaxedSymbols)
{
TSymbol* symbol = symbolTable.find(relaxedSymbol);
TType& type = symbol->getWritableType();
for (const TTypeLoc& typeLoc : *type.getStruct())
{
if (typeLoc.type->isOpaque())
{
typeLoc.type->getSampler() = TSampler{};
typeLoc.type->setBasicType(EbtInt);
TString fieldName("/*");
fieldName.append(typeLoc.type->getFieldName());
fieldName.append("*/");
typeLoc.type->setFieldName(fieldName);
}
}
}
// Transfer the linkage symbols to AST nodes, preserving order.
TIntermAggregate* linkage = new TIntermAggregate;
for (auto i = linkageSymbols.begin(); i != linkageSymbols.end(); ++i)

View File

@ -993,17 +993,25 @@ TIntermTyped* TParseContext::handleDotDereference(const TSourceLoc& loc, TInterm
break;
}
}
if (fieldFound) {
if (base->getType().getQualifier().isFrontEndConstant())
result = intermediate.foldDereference(base, member, loc);
else {
blockMemberExtensionCheck(loc, base, member, field);
TIntermTyped* index = intermediate.addConstantUnion(member, loc);
result = intermediate.addIndex(EOpIndexDirectStruct, base, index, loc);
result->setType(*(*fields)[member].type);
if ((*fields)[member].type->getQualifier().isIo())
intermediate.addIoAccessed(field);
if (spvVersion.vulkan != 0 && spvVersion.vulkanRelaxed)
result = vkRelaxedRemapDotDereference(loc, *base, *(*fields)[member].type, field);
if (result == base)
{
if (base->getType().getQualifier().isFrontEndConstant())
result = intermediate.foldDereference(base, member, loc);
else {
blockMemberExtensionCheck(loc, base, member, field);
TIntermTyped* index = intermediate.addConstantUnion(member, loc);
result = intermediate.addIndex(EOpIndexDirectStruct, base, index, loc);
result->setType(*(*fields)[member].type);
if ((*fields)[member].type->getQualifier().isIo())
intermediate.addIoAccessed(field);
}
}
inheritMemoryQualifiers(base->getQualifier(), result->getWritableType().getQualifier());
} else {
auto baseSymbol = base;
@ -7357,12 +7365,14 @@ void TParseContext::coopMatTypeParametersCheck(const TSourceLoc& loc, const TPub
}
}
bool TParseContext::vkRelaxedRemapUniformVariable(const TSourceLoc& loc, TString& identifier, const TPublicType&,
bool TParseContext::vkRelaxedRemapUniformVariable(const TSourceLoc& loc, TString& identifier, const TPublicType& publicType,
TArraySizes*, TIntermTyped* initializer, TType& type)
{
vkRelaxedRemapUniformMembers(loc, publicType, type, identifier);
if (parsingBuiltins || symbolTable.atBuiltInLevel() || !symbolTable.atGlobalLevel() ||
type.getQualifier().storage != EvqUniform ||
!(type.containsNonOpaque()|| type.getBasicType() == EbtAtomicUint)) {
!(type.containsNonOpaque() || type.getBasicType() == EbtAtomicUint || (type.containsSampler() && type.isStruct()))) {
return false;
}
@ -7436,6 +7446,251 @@ bool TParseContext::vkRelaxedRemapUniformVariable(const TSourceLoc& loc, TString
return true;
}
template <typename Function>
static void ForEachOpaque(const TType& type, const TString& path, Function callback)
{
auto recursion = [&callback](const TType& type, const TString& path, bool skipArray, auto& recursion) -> void {
if (!skipArray && type.isArray())
{
std::vector<int> indices(type.getArraySizes()->getNumDims());
for (int flatIndex = 0;
flatIndex < type.getArraySizes()->getCumulativeSize();
++flatIndex)
{
TString subscriptPath = path;
for (int dimIndex = 0; dimIndex < indices.size(); ++dimIndex)
{
int index = indices[dimIndex];
subscriptPath.append("[");
subscriptPath.append(String(index));
subscriptPath.append("]");
}
recursion(type, subscriptPath, true, recursion);
for (int dimIndex = 0; dimIndex < indices.size(); ++dimIndex)
{
++indices[dimIndex];
if (indices[dimIndex] < type.getArraySizes()->getDimSize(dimIndex))
break;
else
indices[dimIndex] = 0;
}
}
}
else if (type.isStruct() && type.containsOpaque())
{
const TTypeList& types = *type.getStruct();
for (const TTypeLoc& typeLoc : types)
{
TString nextPath = path;
nextPath.append(".");
nextPath.append(typeLoc.type->getFieldName());
recursion(*(typeLoc.type), nextPath, false, recursion);
}
}
else if (type.isOpaque())
{
callback(type, path);
}
};
recursion(type, path, false, recursion);
}
void TParseContext::vkRelaxedRemapUniformMembers(const TSourceLoc& loc, const TPublicType& publicType, const TType& type,
const TString& identifier)
{
if (!type.isStruct() || !type.containsOpaque())
return;
ForEachOpaque(type, identifier,
[&publicType, &loc, this](const TType& type, const TString& path) {
TArraySizes arraySizes = {};
if (type.getArraySizes()) arraySizes = *type.getArraySizes();
TTypeParameters typeParameters = {};
if (type.getTypeParameters()) typeParameters = *type.getTypeParameters();
TPublicType memberType{};
memberType.basicType = type.getBasicType();
memberType.sampler = type.getSampler();
memberType.qualifier = type.getQualifier();
memberType.vectorSize = type.getVectorSize();
memberType.matrixCols = type.getMatrixCols();
memberType.matrixRows = type.getMatrixRows();
memberType.coopmatNV = type.isCoopMatNV();
memberType.coopmatKHR = type.isCoopMatKHR();
memberType.arraySizes = nullptr;
memberType.userDef = nullptr;
memberType.loc = loc;
memberType.typeParameters = (type.getTypeParameters() ? &typeParameters : nullptr);
memberType.spirvType = nullptr;
memberType.qualifier.storage = publicType.qualifier.storage;
memberType.shaderQualifiers = publicType.shaderQualifiers;
TString& structMemberName = *NewPoolTString(path.c_str()); // A copy is required due to declareVariable() signature.
declareVariable(loc, structMemberName, memberType, nullptr, nullptr);
});
}
void TParseContext::vkRelaxedRemapFunctionParameter(const TSourceLoc& loc, TFunction* function, TParameter& param, std::vector<int>* newParams)
{
function->addParameter(param);
if (!param.type->isStruct() || !param.type->containsOpaque())
return;
ForEachOpaque(*param.type, (param.name ? *param.name : param.type->getFieldName()),
[function, param, newParams](const TType& type, const TString& path) {
TString* memberName = NewPoolTString(path.c_str());
TType* memberType = new TType();
memberType->shallowCopy(type);
memberType->getQualifier().storage = param.type->getQualifier().storage;
memberType->clearArraySizes();
TParameter memberParam = {};
memberParam.name = memberName;
memberParam.type = memberType;
memberParam.defaultValue = nullptr;
function->addParameter(memberParam);
if (newParams)
newParams->push_back(function->getParamCount()-1);
});
}
//
// Generates a valid GLSL dereferencing string for the input TIntermNode
//
struct AccessChainTraverser : public TIntermTraverser {
AccessChainTraverser() : TIntermTraverser(false, false, true)
{}
TString path = "";
bool visitBinary(TVisit, TIntermBinary* binary) override {
if (binary->getOp() == EOpIndexDirectStruct)
{
const TTypeList& members = *binary->getLeft()->getType().getStruct();
const TTypeLoc& member =
members[binary->getRight()->getAsConstantUnion()->getConstArray()[0].getIConst()];
TString memberName = member.type->getFieldName();
if (path != "")
path.append(".");
path.append(memberName);
}
if (binary->getOp() == EOpIndexDirect)
{
const TConstUnionArray& indices = binary->getRight()->getAsConstantUnion()->getConstArray();
for (int index = 0; index < indices.size(); ++index)
{
path.append("[");
path.append(String(indices[index].getIConst()));
path.append("]");
}
}
return true;
}
void visitSymbol(TIntermSymbol* symbol) override {
if (!IsAnonymous(symbol->getName()))
path.append(symbol->getName());
}
};
TIntermNode* TParseContext::vkRelaxedRemapFunctionArgument(const TSourceLoc& loc, TFunction* function, TIntermTyped* intermTyped)
{
AccessChainTraverser accessChainTraverser{};
intermTyped->traverse(&accessChainTraverser);
TParameter param = { NewPoolTString(accessChainTraverser.path.c_str()), new TType };
param.type->shallowCopy(intermTyped->getType());
std::vector<int> newParams = {};
vkRelaxedRemapFunctionParameter(loc, function, param, &newParams);
if (intermTyped->getType().isOpaque())
{
TIntermNode* remappedArgument = intermTyped;
{
TIntermSymbol* intermSymbol = nullptr;
TSymbol* symbol = symbolTable.find(*param.name);
if (symbol && symbol->getAsVariable())
intermSymbol = intermediate.addSymbol(*symbol->getAsVariable(), loc);
else
{
TVariable* variable = new TVariable(param.name, *param.type);
intermSymbol = intermediate.addSymbol(*variable, loc);
}
remappedArgument = intermSymbol;
}
return remappedArgument;
}
else if (!(intermTyped->isStruct() && intermTyped->getType().containsOpaque()))
return intermTyped;
else
{
TIntermNode* remappedArgument = intermTyped;
{
TSymbol* symbol = symbolTable.find(*param.name);
if (symbol && symbol->getAsVariable())
remappedArgument = intermediate.addSymbol(*symbol->getAsVariable(), loc);
}
if (!newParams.empty())
remappedArgument = intermediate.makeAggregate(remappedArgument, loc);
for (int paramIndex : newParams)
{
TParameter& newParam = function->operator[](paramIndex);
TIntermSymbol* intermSymbol = nullptr;
TSymbol* symbol = symbolTable.find(*newParam.name);
if (symbol && symbol->getAsVariable())
intermSymbol = intermediate.addSymbol(*symbol->getAsVariable(), loc);
else
{
TVariable* variable = new TVariable(newParam.name, *newParam.type);
intermSymbol = intermediate.addSymbol(*variable, loc);
}
remappedArgument = intermediate.growAggregate(remappedArgument, intermSymbol);
}
return remappedArgument;
}
}
TIntermTyped* TParseContext::vkRelaxedRemapDotDereference(const TSourceLoc&, TIntermTyped& base, const TType& member,
const TString& identifier)
{
if (!member.isOpaque())
return &base;
AccessChainTraverser traverser{};
base.traverse(&traverser);
if (!traverser.path.empty())
traverser.path.append(".");
traverser.path.append(identifier);
const TSymbol* symbol = symbolTable.find(traverser.path);
if (!symbol)
return &base;
TIntermTyped* result = intermediate.addSymbol(*symbol->getAsVariable());
result->setType(symbol->getType());
return result;
}
//
// Do everything necessary to handle a variable (non-block) declaration.
// Either redeclaring a variable, or making a new one, updating the symbol

View File

@ -180,6 +180,7 @@ public:
// Basic parsing state, easily accessible to the grammar
TSymbolTable& symbolTable; // symbol table that goes with the current language, version, and profile
TVector<TString> relaxedSymbols;
int statementNestingLevel; // 0 if outside all flow control or compound statements
int loopNestingLevel; // 0 if outside all loops
int structNestingLevel; // 0 if outside structures
@ -367,6 +368,10 @@ public:
TIntermTyped* vkRelaxedRemapFunctionCall(const TSourceLoc&, TFunction*, TIntermNode*);
// returns true if the variable was remapped to something else
bool vkRelaxedRemapUniformVariable(const TSourceLoc&, TString&, const TPublicType&, TArraySizes*, TIntermTyped*, TType&);
void vkRelaxedRemapUniformMembers(const TSourceLoc&, const TPublicType&, const TType&, const TString&);
void vkRelaxedRemapFunctionParameter(const TSourceLoc&, TFunction*, TParameter&, std::vector<int>* newParams = nullptr);
TIntermNode* vkRelaxedRemapFunctionArgument(const TSourceLoc&, TFunction*, TIntermTyped*);
TIntermTyped* vkRelaxedRemapDotDereference(const TSourceLoc&, TIntermTyped&, const TType&, const TString&);
void assignError(const TSourceLoc&, const char* op, TString left, TString right);
void unaryOpError(const TSourceLoc&, const char* op, TString operand);

View File

@ -492,18 +492,41 @@ function_call_header_no_parameters
function_call_header_with_parameters
: function_call_header assignment_expression {
TParameter param = { 0, new TType };
param.type->shallowCopy($2->getType());
$1.function->addParameter(param);
$$.function = $1.function;
$$.intermNode = $2;
if (parseContext.spvVersion.vulkan > 0
&& parseContext.spvVersion.vulkanRelaxed
&& $2->getType().containsOpaque())
{
$$.intermNode = parseContext.vkRelaxedRemapFunctionArgument($$.loc, $1.function, $2);
$$.function = $1.function;
}
else
{
TParameter param = { 0, new TType };
param.type->shallowCopy($2->getType());
$1.function->addParameter(param);
$$.function = $1.function;
$$.intermNode = $2;
}
}
| function_call_header_with_parameters COMMA assignment_expression {
TParameter param = { 0, new TType };
param.type->shallowCopy($3->getType());
$1.function->addParameter(param);
$$.function = $1.function;
$$.intermNode = parseContext.intermediate.growAggregate($1.intermNode, $3, $2.loc);
if (parseContext.spvVersion.vulkan > 0
&& parseContext.spvVersion.vulkanRelaxed
&& $3->getType().containsOpaque())
{
TIntermNode* remappedNode = parseContext.vkRelaxedRemapFunctionArgument($2.loc, $1.function, $3);
$$.intermNode = parseContext.intermediate.mergeAggregate($1.intermNode, remappedNode, $2.loc);
$$.function = $1.function;
}
else
{
TParameter param = { 0, new TType };
param.type->shallowCopy($3->getType());
$1.function->addParameter(param);
$$.function = $1.function;
$$.intermNode = parseContext.intermediate.growAggregate($1.intermNode, $3, $2.loc);
}
}
;
@ -980,7 +1003,12 @@ function_header_with_parameters
// Add the parameter
$$ = $1;
if ($2.param.type->getBasicType() != EbtVoid)
$1->addParameter($2.param);
{
if (!(parseContext.spvVersion.vulkan > 0 && parseContext.spvVersion.vulkanRelaxed))
$1->addParameter($2.param);
else
parseContext.vkRelaxedRemapFunctionParameter($2.loc, $1, $2.param);
}
else
delete $2.param.type;
}
@ -998,7 +1026,10 @@ function_header_with_parameters
} else {
// Add the parameter
$$ = $1;
$1->addParameter($3.param);
if (!(parseContext.spvVersion.vulkan > 0 && parseContext.spvVersion.vulkanRelaxed))
$1->addParameter($3.param);
else
parseContext.vkRelaxedRemapFunctionParameter($3.loc, $1, $3.param);
}
}
;
@ -3549,11 +3580,17 @@ precision_qualifier
struct_specifier
: STRUCT IDENTIFIER LEFT_BRACE { parseContext.nestedStructCheck($1.loc); } struct_declaration_list RIGHT_BRACE {
TType* structure = new TType($5, *$2.string);
parseContext.structArrayCheck($2.loc, *structure);
TVariable* userTypeDef = new TVariable($2.string, *structure, true);
if (! parseContext.symbolTable.insert(*userTypeDef))
parseContext.error($2.loc, "redefinition", $2.string->c_str(), "struct");
else if (parseContext.spvVersion.vulkanRelaxed
&& structure->containsOpaque())
parseContext.relaxedSymbols.push_back(structure->getTypeName());
$$.init($1.loc);
$$.basicType = EbtStruct;
$$.userDef = structure;

File diff suppressed because it is too large Load Diff

View File

@ -528,6 +528,8 @@ public:
TOperator mapTypeToConstructorOp(const TType&) const;
TIntermAggregate* growAggregate(TIntermNode* left, TIntermNode* right);
TIntermAggregate* growAggregate(TIntermNode* left, TIntermNode* right, const TSourceLoc&);
TIntermAggregate* mergeAggregate(TIntermNode* left, TIntermNode* right);
TIntermAggregate* mergeAggregate(TIntermNode* left, TIntermNode* right, const TSourceLoc&);
TIntermAggregate* makeAggregate(TIntermNode* node);
TIntermAggregate* makeAggregate(TIntermNode* node, const TSourceLoc&);
TIntermAggregate* makeAggregate(const TSourceLoc&);