GLSL: Implement task shaders.

glslang / spirv-tools currently do not treat OpEmitMeshTasksEXT as a block
terminator and may emit a trailing return / branch after it. As a temporary
hack, ignore that invalid SPIR-V in the parser instead of failing.
This commit is contained in:
Hans-Kristian Arntzen 2022-09-05 12:31:22 +02:00
parent 5762617729
commit 4c345166dc
12 changed files with 368 additions and 2 deletions

View File

@ -1082,6 +1082,10 @@ static ExecutionModel stage_to_execution_model(const std::string &stage)
return ExecutionModelMissKHR;
else if (stage == "rcall")
return ExecutionModelCallableKHR;
else if (stage == "mesh")
return spv::ExecutionModelMeshEXT;
else if (stage == "task")
return spv::ExecutionModelTaskEXT;
else
SPIRV_CROSS_THROW("Invalid stage.");
}

View File

@ -0,0 +1,42 @@
#version 450
#extension GL_EXT_mesh_shader : require
layout(local_size_x = 4, local_size_y = 3, local_size_z = 2) in;
// Data block used as the task payload type; holds three floats.
struct Payload
{
float v[3];
};
// Work-group shared scratch array. 24 entries equals the 4 * 3 * 2 local size
// declared in the layout qualifier above.
shared float vs[24];
// Task payload instance (GL_EXT_mesh_shader storage qualifier); written in main().
taskPayloadSharedEXT Payload p;
// Task shader entry point: seeds the shared array `vs`, folds it in three
// barrier-separated steps (strides 12, 6, 3), copies a value into the task
// payload `p`, then launches mesh work groups with counts read from `vs`.
void main()
{
// Each invocation initialises its own slot.
vs[gl_LocalInvocationIndex] = 10.0;
barrier();
// Step 1: invocations 0..11 add in the element 12 slots above their own.
if (gl_LocalInvocationIndex < 12u)
{
vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 12u];
}
barrier();
// Step 2: invocations 0..5 add in the element 6 slots above their own.
if (gl_LocalInvocationIndex < 6u)
{
vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 6u];
}
barrier();
// Step 3: invocations 0..2 add in the element 3 slots above their own.
if (gl_LocalInvocationIndex < 3u)
{
vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 3u];
}
barrier();
// NOTE(review): p.v is declared with 3 elements, but the index here is the
// local invocation index of a 4x3x2 work group - presumably out of range for
// most invocations; confirm this is intentional for the test.
p.v[gl_LocalInvocationIndex] = vs[gl_LocalInvocationIndex];
// Both branches end in EmitMeshTasksEXT; the group counts are floats from
// `vs` converted float -> int -> uint. The branch taken depends on vs[5].
if (vs[5] > 20.0)
{
EmitMeshTasksEXT(uint(int(vs[4])), uint(int(vs[6])), uint(int(vs[8])));
}
else
{
EmitMeshTasksEXT(uint(int(vs[6])), 10u, 50u);
}
}

View File

@ -0,0 +1,35 @@
#version 450
#extension GL_EXT_mesh_shader : require
layout(local_size_x = 4, local_size_y = 3, local_size_z = 2) in;
// Data block used as the task payload type; holds three floats.
struct Payload
{
float v[3];
};
// Work-group shared scratch array. 24 entries equals the 4 * 3 * 2 local size
// declared in the layout qualifier above.
shared float vs[24];
// Task payload instance (GL_EXT_mesh_shader storage qualifier); written in main().
taskPayloadSharedEXT Payload p;
// Task shader entry point: seeds the shared array `vs`, folds it in three
// barrier-separated steps (strides 12, 6, 3), copies a value into the task
// payload `p`, then launches mesh work groups with counts read from `vs`.
void main()
{
// Each invocation initialises its own slot.
vs[gl_LocalInvocationIndex] = 10.0;
barrier();
// Step 1: invocations 0..11 add in the element 12 slots above their own.
if (gl_LocalInvocationIndex < 12u)
{
vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 12u];
}
barrier();
// Step 2: invocations 0..5 add in the element 6 slots above their own.
if (gl_LocalInvocationIndex < 6u)
{
vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 6u];
}
barrier();
// Step 3: invocations 0..2 add in the element 3 slots above their own.
if (gl_LocalInvocationIndex < 3u)
{
vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 3u];
}
barrier();
// NOTE(review): p.v is declared with 3 elements, but the index here is the
// local invocation index of a 4x3x2 work group - presumably out of range for
// most invocations; confirm this is intentional for the test.
p.v[gl_LocalInvocationIndex] = vs[gl_LocalInvocationIndex];
// Unconditional launch; the three group counts are floats from `vs`
// converted float -> int -> uint.
EmitMeshTasksEXT(uint(int(vs[4])), uint(int(vs[6])), uint(int(vs[8])));
}

View File

@ -0,0 +1,35 @@
#version 450
#extension GL_EXT_mesh_shader : require
layout(local_size_x = 4, local_size_y = 3, local_size_z = 2) in;
// Data block used as the task payload type; holds three floats.
struct Payload
{
float v[3];
};
// Work-group shared scratch array. 24 entries equals the 4 * 3 * 2 local size
// declared in the layout qualifier above.
shared float vs[24];
// Task payload instance (GL_EXT_mesh_shader storage qualifier); written in main().
taskPayloadSharedEXT Payload p;
// Task shader entry point: seeds the shared array `vs`, folds it in three
// barrier-separated steps (strides 12, 6, 3), copies a value into the task
// payload `p`, then launches mesh work groups with counts read from `vs`.
void main()
{
// Each invocation initialises its own slot.
vs[gl_LocalInvocationIndex] = 10.0;
barrier();
// Step 1: invocations 0..11 add in the element 12 slots above their own.
if (gl_LocalInvocationIndex < 12u)
{
vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 12u];
}
barrier();
// Step 2: invocations 0..5 add in the element 6 slots above their own.
if (gl_LocalInvocationIndex < 6u)
{
vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 6u];
}
barrier();
// Step 3: invocations 0..2 add in the element 3 slots above their own.
if (gl_LocalInvocationIndex < 3u)
{
vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 3u];
}
barrier();
// NOTE(review): p.v is declared with 3 elements, but the index here is the
// local invocation index of a 4x3x2 work group - presumably out of range for
// most invocations; confirm this is intentional for the test.
p.v[gl_LocalInvocationIndex] = vs[gl_LocalInvocationIndex];
// Unconditional launch; the three group counts are floats from `vs`
// converted float -> int -> uint.
EmitMeshTasksEXT(uint(int(vs[4])), uint(int(vs[6])), uint(int(vs[8])));
}

View File

@ -0,0 +1,42 @@
#version 450
#extension GL_EXT_mesh_shader : require
layout(local_size_x = 4, local_size_y = 3, local_size_z = 2) in;
// Data block used as the task payload type; holds three floats.
struct Payload
{
float v[3];
};
// Work-group shared scratch array. 24 entries equals the 4 * 3 * 2 local size
// declared in the layout qualifier above.
shared float vs[24];
// Task payload instance (GL_EXT_mesh_shader storage qualifier); written in main().
taskPayloadSharedEXT Payload p;
// Task shader entry point: seeds the shared array `vs`, folds it in three
// barrier-separated steps (strides 12, 6, 3), copies a value into the task
// payload `p`, then launches mesh work groups with counts read from `vs`.
void main()
{
// Each invocation initialises its own slot.
vs[gl_LocalInvocationIndex] = 10.0;
barrier();
// Step 1: invocations 0..11 add in the element 12 slots above their own.
if (gl_LocalInvocationIndex < 12u)
{
vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 12u];
}
barrier();
// Step 2: invocations 0..5 add in the element 6 slots above their own.
if (gl_LocalInvocationIndex < 6u)
{
vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 6u];
}
barrier();
// Step 3: invocations 0..2 add in the element 3 slots above their own.
if (gl_LocalInvocationIndex < 3u)
{
vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 3u];
}
barrier();
// NOTE(review): p.v is declared with 3 elements, but the index here is the
// local invocation index of a 4x3x2 work group - presumably out of range for
// most invocations; confirm this is intentional for the test.
p.v[gl_LocalInvocationIndex] = vs[gl_LocalInvocationIndex];
// Both branches end in EmitMeshTasksEXT; the group counts are floats from
// `vs` converted float -> int -> uint. The branch taken depends on vs[5].
if (vs[5] > 20.0)
{
EmitMeshTasksEXT(uint(int(vs[4])), uint(int(vs[6])), uint(int(vs[8])));
}
else
{
EmitMeshTasksEXT(uint(int(vs[6])), 10u, 50u);
}
}

View File

@ -0,0 +1,35 @@
#version 450
#extension GL_EXT_mesh_shader : require
layout(local_size_x = 4, local_size_y = 3, local_size_z = 2) in;
// Data block used as the task payload type; holds three floats.
struct Payload
{
float v[3];
};
// Work-group shared scratch array. 24 entries equals the 4 * 3 * 2 local size
// declared in the layout qualifier above.
shared float vs[24];
// Task payload instance (GL_EXT_mesh_shader storage qualifier); written in main().
taskPayloadSharedEXT Payload p;
// Task shader entry point: seeds the shared array `vs`, folds it in three
// barrier-separated steps (strides 12, 6, 3), copies a value into the task
// payload `p`, then launches mesh work groups with counts read from `vs`.
void main()
{
// Each invocation initialises its own slot.
vs[gl_LocalInvocationIndex] = 10.0;
barrier();
// Step 1: invocations 0..11 add in the element 12 slots above their own.
if (gl_LocalInvocationIndex < 12u)
{
vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 12u];
}
barrier();
// Step 2: invocations 0..5 add in the element 6 slots above their own.
if (gl_LocalInvocationIndex < 6u)
{
vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 6u];
}
barrier();
// Step 3: invocations 0..2 add in the element 3 slots above their own.
if (gl_LocalInvocationIndex < 3u)
{
vs[gl_LocalInvocationIndex] += vs[gl_LocalInvocationIndex + 3u];
}
barrier();
// NOTE(review): p.v is declared with 3 elements, but the index here is the
// local invocation index of a 4x3x2 work group - presumably out of range for
// most invocations; confirm this is intentional for the test.
p.v[gl_LocalInvocationIndex] = vs[gl_LocalInvocationIndex];
// Unconditional launch; the three group counts are floats from `vs`
// converted float -> int -> uint.
EmitMeshTasksEXT(uint(int(vs[4])), uint(int(vs[6])), uint(int(vs[8])));
}

View File

@ -0,0 +1,132 @@
; SPIR-V
; Version: 1.4
; Generator: Khronos Glslang Reference Front End; 10
; Bound: 93
; Schema: 0
OpCapability MeshShadingEXT
OpExtension "SPV_EXT_mesh_shader"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint TaskEXT %main "main" %vs %gl_LocalInvocationIndex %p
OpExecutionMode %main LocalSize 4 3 2
OpSource GLSL 450
OpSourceExtension "GL_EXT_mesh_shader"
OpName %main "main"
OpName %vs "vs"
OpName %gl_LocalInvocationIndex "gl_LocalInvocationIndex"
OpName %Payload "Payload"
OpMemberName %Payload 0 "v"
OpName %p "p"
OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex
OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
%void = OpTypeVoid
%3 = OpTypeFunction %void
%float = OpTypeFloat 32
%uint = OpTypeInt 32 0
%uint_24 = OpConstant %uint 24
%_arr_float_uint_24 = OpTypeArray %float %uint_24
%_ptr_Workgroup__arr_float_uint_24 = OpTypePointer Workgroup %_arr_float_uint_24
%vs = OpVariable %_ptr_Workgroup__arr_float_uint_24 Workgroup
%_ptr_Input_uint = OpTypePointer Input %uint
%gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input
%float_10 = OpConstant %float 10
%_ptr_Workgroup_float = OpTypePointer Workgroup %float
%uint_2 = OpConstant %uint 2
%uint_264 = OpConstant %uint 264
%uint_12 = OpConstant %uint 12
%bool = OpTypeBool
%uint_6 = OpConstant %uint 6
%uint_3 = OpConstant %uint 3
%_arr_float_uint_3 = OpTypeArray %float %uint_3
%Payload = OpTypeStruct %_arr_float_uint_3
%_ptr_TaskPayloadWorkgroupEXT_Payload = OpTypePointer TaskPayloadWorkgroupEXT %Payload
%p = OpVariable %_ptr_TaskPayloadWorkgroupEXT_Payload TaskPayloadWorkgroupEXT
%int = OpTypeInt 32 1
%int_0 = OpConstant %int 0
%_ptr_TaskPayloadWorkgroupEXT_float = OpTypePointer TaskPayloadWorkgroupEXT %float
%int_4 = OpConstant %int 4
%int_6 = OpConstant %int 6
%int_8 = OpConstant %int 8
%v3uint = OpTypeVector %uint 3
%uint_4 = OpConstant %uint 4
%gl_WorkGroupSize = OpConstantComposite %v3uint %uint_4 %uint_3 %uint_2
; Entry point %main: seeds the Workgroup array %vs, runs three folding steps
; separated by control barriers, copies a value into the task payload %p and
; finishes with OpEmitMeshTasksEXT.
%main = OpFunction %void None %3
%5 = OpLabel
; vs[gl_LocalInvocationIndex] = 10.0
%14 = OpLoad %uint %gl_LocalInvocationIndex
%17 = OpAccessChain %_ptr_Workgroup_float %vs %14
OpStore %17 %float_10
; Barrier: execution scope and memory scope are both %uint_2, semantics %uint_264.
OpControlBarrier %uint_2 %uint_2 %uint_264
; Step 1: if (index < 12) vs[index] += vs[index + 12]
%20 = OpLoad %uint %gl_LocalInvocationIndex
%23 = OpULessThan %bool %20 %uint_12
OpSelectionMerge %25 None
OpBranchConditional %23 %24 %25
%24 = OpLabel
%26 = OpLoad %uint %gl_LocalInvocationIndex
%27 = OpLoad %uint %gl_LocalInvocationIndex
%28 = OpIAdd %uint %27 %uint_12
%29 = OpAccessChain %_ptr_Workgroup_float %vs %28
%30 = OpLoad %float %29
%31 = OpAccessChain %_ptr_Workgroup_float %vs %26
%32 = OpLoad %float %31
%33 = OpFAdd %float %32 %30
%34 = OpAccessChain %_ptr_Workgroup_float %vs %26
OpStore %34 %33
OpBranch %25
%25 = OpLabel
OpControlBarrier %uint_2 %uint_2 %uint_264
; Step 2: if (index < 6) vs[index] += vs[index + 6]
%35 = OpLoad %uint %gl_LocalInvocationIndex
%37 = OpULessThan %bool %35 %uint_6
OpSelectionMerge %39 None
OpBranchConditional %37 %38 %39
%38 = OpLabel
%40 = OpLoad %uint %gl_LocalInvocationIndex
%41 = OpLoad %uint %gl_LocalInvocationIndex
%42 = OpIAdd %uint %41 %uint_6
%43 = OpAccessChain %_ptr_Workgroup_float %vs %42
%44 = OpLoad %float %43
%45 = OpAccessChain %_ptr_Workgroup_float %vs %40
%46 = OpLoad %float %45
%47 = OpFAdd %float %46 %44
%48 = OpAccessChain %_ptr_Workgroup_float %vs %40
OpStore %48 %47
OpBranch %39
%39 = OpLabel
OpControlBarrier %uint_2 %uint_2 %uint_264
; Step 3: if (index < 3) vs[index] += vs[index + 3]
%49 = OpLoad %uint %gl_LocalInvocationIndex
%51 = OpULessThan %bool %49 %uint_3
OpSelectionMerge %53 None
OpBranchConditional %51 %52 %53
%52 = OpLabel
%54 = OpLoad %uint %gl_LocalInvocationIndex
%55 = OpLoad %uint %gl_LocalInvocationIndex
%56 = OpIAdd %uint %55 %uint_3
%57 = OpAccessChain %_ptr_Workgroup_float %vs %56
%58 = OpLoad %float %57
%59 = OpAccessChain %_ptr_Workgroup_float %vs %54
%60 = OpLoad %float %59
%61 = OpFAdd %float %60 %58
%62 = OpAccessChain %_ptr_Workgroup_float %vs %54
OpStore %62 %61
OpBranch %53
%53 = OpLabel
OpControlBarrier %uint_2 %uint_2 %uint_264
; p.v[index] = vs[index] (member 0 of %Payload, indexed by the invocation index)
%69 = OpLoad %uint %gl_LocalInvocationIndex
%70 = OpLoad %uint %gl_LocalInvocationIndex
%71 = OpAccessChain %_ptr_Workgroup_float %vs %70
%72 = OpLoad %float %71
%74 = OpAccessChain %_ptr_TaskPayloadWorkgroupEXT_float %p %int_0 %69
OpStore %74 %72
; Group counts: vs[4], vs[6], vs[8], each converted float -> int, then bitcast to uint.
%76 = OpAccessChain %_ptr_Workgroup_float %vs %int_4
%77 = OpLoad %float %76
%78 = OpConvertFToS %int %77
%79 = OpBitcast %uint %78
%81 = OpAccessChain %_ptr_Workgroup_float %vs %int_6
%82 = OpLoad %float %81
%83 = OpConvertFToS %int %82
%84 = OpBitcast %uint %83
%86 = OpAccessChain %_ptr_Workgroup_float %vs %int_8
%87 = OpLoad %float %86
%88 = OpConvertFToS %int %87
%89 = OpBitcast %uint %88
; OpEmitMeshTasksEXT is the last instruction of the block here: no OpReturn
; follows it before OpFunctionEnd. The payload variable %p is the fourth operand.
OpEmitMeshTasksEXT %79 %84 %89 %p
OpFunctionEnd

View File

@ -777,7 +777,8 @@ struct SPIRBlock : IVariant
Unreachable, // Noop
Kill, // Discard
IgnoreIntersection, // Ray Tracing
TerminateRay // Ray Tracing
TerminateRay, // Ray Tracing
EmitMeshTasks // Mesh shaders
};
enum Merge
@ -839,6 +840,13 @@ struct SPIRBlock : IVariant
BlockID false_block = 0;
BlockID default_block = 0;
// If terminator is EmitMeshTasksEXT.
struct
{
ID groups[3];
ID payload;
} mesh = {};
SmallVector<Instruction> ops;
struct Phi

View File

@ -98,7 +98,8 @@ bool Compiler::block_is_pure(const SPIRBlock &block)
// This is a global side effect of the function.
if (block.terminator == SPIRBlock::Kill ||
block.terminator == SPIRBlock::TerminateRay ||
block.terminator == SPIRBlock::IgnoreIntersection)
block.terminator == SPIRBlock::IgnoreIntersection ||
block.terminator == SPIRBlock::EmitMeshTasks)
return false;
for (auto &i : block.ops)
@ -155,6 +156,7 @@ bool Compiler::block_is_pure(const SPIRBlock &block)
return false;
// Mesh shader functions modify global state.
// (EmitMeshTasks is a terminator).
case OpSetMeshOutputsEXT:
return false;

View File

@ -498,6 +498,7 @@ void CompilerGLSL::find_static_extensions()
break;
case ExecutionModelMeshEXT:
case ExecutionModelTaskEXT:
if (options.es || options.version < 450)
SPIRV_CROSS_THROW("Mesh shaders require GLSL 450 or above.");
if (!options.vulkan_semantics)
@ -16105,6 +16106,13 @@ void CompilerGLSL::emit_block_chain(SPIRBlock &block)
statement("terminateRayEXT;");
break;
case SPIRBlock::EmitMeshTasks:
statement("EmitMeshTasksEXT(",
to_unpacked_expression(block.mesh.groups[0]), ", ",
to_unpacked_expression(block.mesh.groups[1]), ", ",
to_unpacked_expression(block.mesh.groups[2]), ");");
break;
default:
SPIRV_CROSS_THROW("Unimplemented block terminator.");
}

View File

@ -183,6 +183,15 @@ void Parser::parse(const Instruction &instruction)
auto op = static_cast<Op>(instruction.op);
uint32_t length = instruction.length;
// HACK for glslang that might emit OpEmitMeshTasksEXT followed by return / branch.
// Instead of failing hard, just ignore it.
if (ignore_trailing_block_opcodes)
{
ignore_trailing_block_opcodes = false;
if (op == OpReturn || op == OpBranch || op == OpUnreachable)
return;
}
switch (op)
{
case OpSourceContinued:
@ -1107,6 +1116,18 @@ void Parser::parse(const Instruction &instruction)
current_block = nullptr;
break;
case OpEmitMeshTasksEXT:
if (!current_block)
SPIRV_CROSS_THROW("Trying to end a non-existing block.");
current_block->terminator = SPIRBlock::EmitMeshTasks;
for (uint32_t i = 0; i < 3; i++)
current_block->mesh.groups[i] = ops[i];
current_block->mesh.payload = length >= 4 ? ops[3] : 0;
current_block = nullptr;
// Currently glslang is bugged and does not treat EmitMeshTasksEXT as a terminator.
ignore_trailing_block_opcodes = true;
break;
case OpReturn:
{
if (!current_block)

View File

@ -46,6 +46,8 @@ private:
ParsedIR ir;
SPIRFunction *current_function = nullptr;
SPIRBlock *current_block = nullptr;
// For workarounds.
bool ignore_trailing_block_opcodes = false;
void parse(const Instruction &instr);
const uint32_t *stream(const Instruction &instr) const;