diff --git a/spirv_common.hpp b/spirv_common.hpp index c645aacf..ba420e1d 100644 --- a/spirv_common.hpp +++ b/spirv_common.hpp @@ -842,6 +842,13 @@ struct SPIRBlock : IVariant BlockID false_block = 0; BlockID default_block = 0; + // If terminator is EmitMeshTasksEXT. + struct + { + ID groups[3]; + ID payload; + } mesh = {}; + SmallVector ops; struct Phi diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index a2a193cf..902299a8 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -6280,6 +6280,14 @@ void CompilerGLSL::emit_unary_op_cast(uint32_t result_type, uint32_t result_id, inherit_expression_dependencies(result_id, op0); } +void CompilerGLSL::emit_mesh_tasks(SPIRBlock &block) +{ + statement("EmitMeshTasksEXT(", + to_unpacked_expression(block.mesh.groups[0]), ", ", + to_unpacked_expression(block.mesh.groups[1]), ", ", + to_unpacked_expression(block.mesh.groups[2]), ");"); +} + void CompilerGLSL::emit_binary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op) { // Various FP arithmetic opcodes such as add, sub, mul will hit this. @@ -16877,6 +16885,7 @@ void CompilerGLSL::emit_block_chain(SPIRBlock &block) break; case SPIRBlock::EmitMeshTasks: + emit_mesh_tasks(block); break; default: diff --git a/spirv_glsl.hpp b/spirv_glsl.hpp index d48ca8fd..ac6a4924 100644 --- a/spirv_glsl.hpp +++ b/spirv_glsl.hpp @@ -708,6 +708,7 @@ protected: void emit_unary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op); void emit_unary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op); + virtual void emit_mesh_tasks(SPIRBlock &block); bool expression_is_forwarded(uint32_t id) const; bool expression_suppresses_usage_tracking(uint32_t id) const; bool expression_read_implies_multiple_reads(uint32_t id) const; diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp index 7f332ee8..391d629f 100644 --- a/spirv_hlsl.cpp +++ b/spirv_hlsl.cpp @@ -2576,6 +2576,19 @@ void CompilerHLSL::emit_rayquery_function(const char *commited, const char *cand emit_op(ops[0], ops[1], join(to_expression(ops[2]), is_commited ? commited : candidate), false); } +void CompilerHLSL::emit_mesh_tasks(SPIRBlock &block) +{ + if (block.mesh.payload != 0) + { + statement("DispatchMesh(", to_unpacked_expression(block.mesh.groups[0]), ", ", to_unpacked_expression(block.mesh.groups[1]), ", ", + to_unpacked_expression(block.mesh.groups[2]), ", ", to_unpacked_expression(block.mesh.payload), ");"); + } + else + { + SPIRV_CROSS_THROW("Amplification shader in HLSL must have payload"); + } +} + void CompilerHLSL::emit_buffer_block(const SPIRVariable &var) { auto &type = get(var.basetype); @@ -2936,7 +2949,7 @@ void CompilerHLSL::emit_hlsl_entry_point() switch (execution.model) { - case spv::ExecutionModelTaskEXT: + case ExecutionModelTaskEXT: case ExecutionModelMeshEXT: case ExecutionModelGLCompute: { @@ -6361,20 +6374,6 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction) statement("SetMeshOutputCounts(", to_unpacked_expression(ops[0]), ", ", to_unpacked_expression(ops[1]), ");"); break; } - case OpEmitMeshTasksEXT: - { - if (instruction.length == 4) - { - statement("DispatchMesh(", to_unpacked_expression(ops[0]), ", ", to_unpacked_expression(ops[1]), ", ", - to_unpacked_expression(ops[2]), ", ", to_unpacked_expression(ops[3]), ");"); - } - else - { - SPIRV_CROSS_THROW("Amplification shader in HLSL must have payload"); - } - break; - } - default: CompilerGLSL::emit_instruction(instruction); break; diff --git a/spirv_hlsl.hpp b/spirv_hlsl.hpp index 2e16ebd3..51af5bf0 100644 --- a/spirv_hlsl.hpp +++ b/spirv_hlsl.hpp @@ -280,6 +280,7 @@ private: void emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index, const std::string &qualifier, uint32_t base_offset = 0) override; void emit_rayquery_function(const char *commited, const char *candidate, const uint32_t *ops); + void emit_mesh_tasks(SPIRBlock &block) override; const char *to_storage_qualifiers_glsl(const SPIRVariable &var) override; void replace_illegal_names() override; diff --git a/spirv_parser.cpp b/spirv_parser.cpp index c8cb752a..01c2e381 100644 --- a/spirv_parser.cpp +++ b/spirv_parser.cpp @@ -1123,21 +1123,16 @@ void Parser::parse(const Instruction &instruction) break; case OpEmitMeshTasksEXT: - { if (!current_block) SPIRV_CROSS_THROW("Trying to end a non-existing block."); - - const auto *type = maybe_get(ops[0]); - if (type) - ir.load_type_width.insert({ ops[1], type->width }); - current_block->ops.push_back(instruction); - current_block->terminator = SPIRBlock::EmitMeshTasks; + for (uint32_t i = 0; i < 3; i++) + current_block->mesh.groups[i] = ops[i]; + current_block->mesh.payload = length >= 4 ? ops[3] : 0; current_block = nullptr; // Currently glslang is bugged and does not treat EmitMeshTasksEXT as a terminator. ignore_trailing_block_opcodes = true; break; - } case OpReturn: {