From aab3107a3f3f3695cc6366162a2bd00b7d70b6cf Mon Sep 17 00:00:00 2001 From: Hans-Kristian Arntzen Date: Fri, 29 Sep 2017 12:16:53 +0200 Subject: [PATCH] Add WorkGroupID/NumWorkGroups to MSL. Fix block name alias. --- reference/shaders-hlsl/comp/builtins.comp | 4 ++++ reference/shaders-msl/comp/builtins.comp | 17 +++++++++++++++++ shaders-hlsl/comp/builtins.comp | 1 + shaders-msl/comp/builtins.comp | 12 ++++++++++++ spirv_cross.cpp | 23 ++++++++++++++++++++--- spirv_cross.hpp | 9 +++++---- spirv_glsl.cpp | 4 ++-- spirv_msl.cpp | 21 ++++++++++++++++++--- 8 files changed, 79 insertions(+), 12 deletions(-) create mode 100644 reference/shaders-msl/comp/builtins.comp create mode 100644 shaders-msl/comp/builtins.comp diff --git a/reference/shaders-hlsl/comp/builtins.comp b/reference/shaders-hlsl/comp/builtins.comp index 189aa959..45b6c030 100644 --- a/reference/shaders-hlsl/comp/builtins.comp +++ b/reference/shaders-hlsl/comp/builtins.comp @@ -1,10 +1,12 @@ const uint3 gl_WorkGroupSize = uint3(8u, 4u, 2u); +static uint3 gl_WorkGroupID; static uint3 gl_LocalInvocationID; static uint3 gl_GlobalInvocationID; static uint gl_LocalInvocationIndex; struct SPIRV_Cross_Input { + uint3 gl_WorkGroupID : SV_GroupID; uint3 gl_LocalInvocationID : SV_GroupThreadID; uint3 gl_GlobalInvocationID : SV_DispatchThreadID; uint gl_LocalInvocationIndex : SV_GroupIndex; @@ -16,11 +18,13 @@ void comp_main() uint3 global_id = gl_GlobalInvocationID; uint local_index = gl_LocalInvocationIndex; uint3 work_group_size = gl_WorkGroupSize; + uint3 work_group_id = gl_WorkGroupID; } [numthreads(8, 4, 2)] void main(SPIRV_Cross_Input stage_input) { + gl_WorkGroupID = stage_input.gl_WorkGroupID; gl_LocalInvocationID = stage_input.gl_LocalInvocationID; gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID; gl_LocalInvocationIndex = stage_input.gl_LocalInvocationIndex; diff --git a/reference/shaders-msl/comp/builtins.comp b/reference/shaders-msl/comp/builtins.comp new file mode 100644 index 00000000..4330d578 --- /dev/null +++ b/reference/shaders-msl/comp/builtins.comp @@ -0,0 +1,17 @@ +#include +#include + +using namespace metal; + +constant uint3 gl_WorkGroupSize = uint3(8u, 4u, 2u); + +kernel void main0(uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_NumWorkGroups [[threadgroups_per_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]]) +{ + uint3 local_id = gl_LocalInvocationID; + uint3 global_id = gl_GlobalInvocationID; + uint local_index = gl_LocalInvocationIndex; + uint3 work_group_size = gl_WorkGroupSize; + uint3 num_work_groups = gl_NumWorkGroups; + uint3 work_group_id = gl_WorkGroupID; +} + diff --git a/shaders-hlsl/comp/builtins.comp b/shaders-hlsl/comp/builtins.comp index ca0be279..b41cb539 100644 --- a/shaders-hlsl/comp/builtins.comp +++ b/shaders-hlsl/comp/builtins.comp @@ -7,4 +7,5 @@ void main() uvec3 global_id = gl_GlobalInvocationID; uint local_index = gl_LocalInvocationIndex; uvec3 work_group_size = gl_WorkGroupSize; + uvec3 work_group_id = gl_WorkGroupID; } diff --git a/shaders-msl/comp/builtins.comp b/shaders-msl/comp/builtins.comp new file mode 100644 index 00000000..88bb5951 --- /dev/null +++ b/shaders-msl/comp/builtins.comp @@ -0,0 +1,12 @@ +#version 310 es +layout(local_size_x = 8, local_size_y = 4, local_size_z = 2) in; + +void main() +{ + uvec3 local_id = gl_LocalInvocationID; + uvec3 global_id = gl_GlobalInvocationID; + uint local_index = gl_LocalInvocationIndex; + uvec3 work_group_size = gl_WorkGroupSize; + uvec3 num_work_groups = gl_NumWorkGroups; + uvec3 work_group_id = gl_WorkGroupID; +} diff --git a/spirv_cross.cpp b/spirv_cross.cpp index c9f0cf7d..c04b6be0 100644 --- a/spirv_cross.cpp +++ b/spirv_cross.cpp @@ -655,18 +655,24 @@ ShaderResources Compiler::get_shader_resources(const unordered_set *ac else if (type.storage == StorageClassUniform && (meta[type.self].decoration.decoration_flags & (1ull << DecorationBlock))) { - res.uniform_buffers.push_back({ var.self, var.basetype, type.self, meta[type.self].decoration.alias }); + auto &block_name = meta[type.self].decoration.alias; + res.uniform_buffers.push_back({ var.self, var.basetype, type.self, + block_name.empty() ? get_block_fallback_name(var.self) : block_name }); } // Old way to declare SSBOs. else if (type.storage == StorageClassUniform && (meta[type.self].decoration.decoration_flags & (1ull << DecorationBufferBlock))) { - res.storage_buffers.push_back({ var.self, var.basetype, type.self, meta[type.self].decoration.alias }); + auto &block_name = meta[type.self].decoration.alias; + res.storage_buffers.push_back({ var.self, var.basetype, type.self, + block_name.empty() ? get_block_fallback_name(var.self) : block_name }); } // Modern way to declare SSBOs. else if (type.storage == StorageClassStorageBuffer) { - res.storage_buffers.push_back({ var.self, var.basetype, type.self, meta[type.self].decoration.alias }); + auto &block_name = meta[type.self].decoration.alias; + res.storage_buffers.push_back({ var.self, var.basetype, type.self, + block_name.empty() ? get_block_fallback_name(var.self) : block_name }); } // Push constant blocks else if (type.storage == StorageClassPushConstant) @@ -1121,6 +1127,17 @@ const std::string &Compiler::get_name(uint32_t id) const return meta.at(id).decoration.alias; } +const std::string Compiler::get_fallback_name(uint32_t id) const +{ + return join("_", id); +} + +const std::string Compiler::get_block_fallback_name(uint32_t id) const +{ + auto &var = get(id); + return join("_", get(var.basetype).self, "_", id); +} + uint64_t Compiler::get_decoration_mask(uint32_t id) const { auto &dec = meta.at(id).decoration; diff --git a/spirv_cross.hpp b/spirv_cross.hpp index b85dc040..375689a7 100644 --- a/spirv_cross.hpp +++ b/spirv_cross.hpp @@ -155,10 +155,11 @@ public: // If get_name() is an empty string, get the fallback name which will be used // instead in the disassembled source. - virtual const std::string get_fallback_name(uint32_t id) const - { - return join("_", id); - } + virtual const std::string get_fallback_name(uint32_t id) const; + + // If get_name() of a Block struct is an empty string, get the fallback name. + // This needs to be per-variable as multiple variables can use the same block type. + virtual const std::string get_block_fallback_name(uint32_t id) const; // Given an OpTypeStruct in ID, obtain the identifier for member number "index". // This may be an empty string. diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index a65f900d..cdde579b 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -1171,8 +1171,8 @@ void CompilerGLSL::emit_buffer_block_native(const SPIRVariable &var) // Shaders never use the block by interface name, so we don't // have to track this other than updating name caches. - if (resource_names.find(buffer_name) != end(resource_names)) - buffer_name = get_fallback_name(type.self); + if (meta[type.self].decoration.alias.empty() || resource_names.find(buffer_name) != end(resource_names)) + buffer_name = get_block_fallback_name(var.self); else resource_names.insert(buffer_name); diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 08ee0cff..33203947 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -996,6 +996,7 @@ void CompilerMSL::emit_resources() } // Output Uniform buffers and constants + unordered_set declared_interface_structs; for (auto &id : ids) { if (id.get_type() == TypeVariable) @@ -1009,8 +1010,13 @@ void CompilerMSL::emit_resources() (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)) && !is_hidden_variable(var)) { - align_struct(type); - emit_struct(type); + // Avoid declaring the same struct multiple times. + if (declared_interface_structs.count(type.self) == 0) + { + align_struct(type); + emit_struct(type); + declared_interface_structs.insert(type.self); + } } } } @@ -2177,6 +2183,8 @@ string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t in switch (builtin) { case BuiltInGlobalInvocationId: + case BuiltInWorkgroupId: + case BuiltInNumWorkgroups: case BuiltInLocalInvocationId: case BuiltInLocalInvocationIndex: return string(" [[") + builtin_qualifier(builtin) + "]]"; @@ -2800,6 +2808,12 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin) case BuiltInGlobalInvocationId: return "thread_position_in_grid"; + case BuiltInWorkgroupId: + return "threadgroup_position_in_grid"; + + case BuiltInNumWorkgroups: + return "threadgroups_per_grid"; + case BuiltInLocalInvocationId: return "thread_position_in_threadgroup"; @@ -2848,8 +2862,9 @@ string CompilerMSL::builtin_type_decl(BuiltIn builtin) // Compute function in case BuiltInGlobalInvocationId: - return "uint3"; case BuiltInLocalInvocationId: + case BuiltInNumWorkgroups: + case BuiltInWorkgroupId: return "uint3"; case BuiltInLocalInvocationIndex: return "uint";