Mark function call results as control dependent as necessary.

Inner function calls can contain flow-control sensitive code.
In this case, the function call itself must inherit the
control-dependence.

Rarely happens in practice since optimized code with SSA tends to
inline.
This commit is contained in:
Hans-Kristian Arntzen 2024-05-10 14:52:24 +02:00
parent 476f384eb7
commit 04ddb9a809
3 changed files with 106 additions and 6 deletions

View File

@ -93,6 +93,97 @@ bool Compiler::variable_storage_is_aliased(const SPIRVariable &v)
return !is_restrict && (ssbo || image || counter || buffer_reference);
}
bool Compiler::block_is_control_dependent(const SPIRBlock &block)
{
for (auto &i : block.ops)
{
auto ops = stream(i);
auto op = static_cast<Op>(i.op);
switch (op)
{
case OpFunctionCall:
{
uint32_t func = ops[2];
if (function_is_control_dependent(get<SPIRFunction>(func)))
return true;
break;
}
// Derivatives
case OpDPdx:
case OpDPdxCoarse:
case OpDPdxFine:
case OpDPdy:
case OpDPdyCoarse:
case OpDPdyFine:
case OpFwidth:
case OpFwidthCoarse:
case OpFwidthFine:
// Anything implicit LOD
case OpImageSampleImplicitLod:
case OpImageSampleDrefImplicitLod:
case OpImageSampleProjImplicitLod:
case OpImageSampleProjDrefImplicitLod:
case OpImageSparseSampleImplicitLod:
case OpImageSparseSampleDrefImplicitLod:
case OpImageSparseSampleProjImplicitLod:
case OpImageSparseSampleProjDrefImplicitLod:
case OpImageQueryLod:
case OpImageDrefGather:
case OpImageGather:
case OpImageSparseDrefGather:
case OpImageSparseGather:
// Anything subgroups
case OpGroupNonUniformElect:
case OpGroupNonUniformAll:
case OpGroupNonUniformAny:
case OpGroupNonUniformAllEqual:
case OpGroupNonUniformBroadcast:
case OpGroupNonUniformBroadcastFirst:
case OpGroupNonUniformBallot:
case OpGroupNonUniformInverseBallot:
case OpGroupNonUniformBallotBitExtract:
case OpGroupNonUniformBallotBitCount:
case OpGroupNonUniformBallotFindLSB:
case OpGroupNonUniformBallotFindMSB:
case OpGroupNonUniformShuffle:
case OpGroupNonUniformShuffleXor:
case OpGroupNonUniformShuffleUp:
case OpGroupNonUniformShuffleDown:
case OpGroupNonUniformIAdd:
case OpGroupNonUniformFAdd:
case OpGroupNonUniformIMul:
case OpGroupNonUniformFMul:
case OpGroupNonUniformSMin:
case OpGroupNonUniformUMin:
case OpGroupNonUniformFMin:
case OpGroupNonUniformSMax:
case OpGroupNonUniformUMax:
case OpGroupNonUniformFMax:
case OpGroupNonUniformBitwiseAnd:
case OpGroupNonUniformBitwiseOr:
case OpGroupNonUniformBitwiseXor:
case OpGroupNonUniformLogicalAnd:
case OpGroupNonUniformLogicalOr:
case OpGroupNonUniformLogicalXor:
case OpGroupNonUniformQuadBroadcast:
case OpGroupNonUniformQuadSwap:
// Control barriers
case OpControlBarrier:
return true;
default:
break;
}
}
return false;
}
bool Compiler::block_is_pure(const SPIRBlock &block)
{
// This is a global side effect of the function.
@ -247,18 +338,21 @@ string Compiler::to_name(uint32_t id, bool allow_alias) const
bool Compiler::function_is_pure(const SPIRFunction &func)
{
for (auto block : func.blocks)
{
if (!block_is_pure(get<SPIRBlock>(block)))
{
//fprintf(stderr, "Function %s is impure!\n", to_name(func.self).c_str());
return false;
}
}
//fprintf(stderr, "Function %s is pure!\n", to_name(func.self).c_str());
return true;
}
bool Compiler::function_is_control_dependent(const SPIRFunction &func)
{
for (auto block : func.blocks)
if (block_is_control_dependent(get<SPIRBlock>(block)))
return true;
return false;
}
void Compiler::register_global_read_dependencies(const SPIRBlock &block, uint32_t id)
{
for (auto &i : block.ops)

View File

@ -744,6 +744,8 @@ protected:
bool function_is_pure(const SPIRFunction &func);
bool block_is_pure(const SPIRBlock &block);
bool function_is_control_dependent(const SPIRFunction &func);
bool block_is_control_dependent(const SPIRBlock &block);
bool execution_is_branchless(const SPIRBlock &from, const SPIRBlock &to) const;
bool execution_is_direct_branch(const SPIRBlock &from, const SPIRBlock &to) const;

View File

@ -12171,6 +12171,7 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
auto &callee = get<SPIRFunction>(func);
auto &return_type = get<SPIRType>(callee.return_type);
bool pure = function_is_pure(callee);
bool control_dependent = function_is_control_dependent(callee);
bool callee_has_out_variables = false;
bool emit_return_value_as_argument = false;
@ -12264,6 +12265,9 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
else
statement(funexpr, ";");
if (control_dependent)
register_control_dependent_expression(id);
break;
}