diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 1f107158..0cab94fd 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -67,6 +67,32 @@ void CompilerMSL::add_msl_resource_binding(const MSLResourceBinding &binding) { StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding }; resource_bindings[tuple] = { binding, false }; + + // If we might need to pad argument buffer members to positionally align + // arg buffer indexes, also maintain a lookup by argument buffer index. + if (msl_options.pad_argument_buffer_resources) + { + uint32_t arg_buff_idx = k_unknown_component; + switch (binding.base_type) + { + case SPIRType::Void: + arg_buff_idx = binding.msl_buffer; + break; + case SPIRType::Image: + case SPIRType::SampledImage: + arg_buff_idx = binding.msl_texture; + break; + case SPIRType::Sampler: + arg_buff_idx = binding.msl_sampler; + break; + default: + SPIRV_CROSS_THROW("Unexpected argument buffer resource base type. When padding argument buffer elements, " + "all descriptor set resources must be supplied with a base type by the app."); + break; + } + StageSetBinding arg_idx_tuple = { binding.stage, binding.desc_set, arg_buff_idx }; + resource_arg_buff_idx_to_binding_number[arg_idx_tuple] = binding.binding; + } } void CompilerMSL::add_dynamic_buffer(uint32_t desc_set, uint32_t binding, uint32_t index) @@ -15106,26 +15132,26 @@ void CompilerMSL::analyze_argument_buffers() { while (resource.index > next_arg_buff_index) { - auto& rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index); + auto &rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index); switch (rez_bind.base_type) { - case SPIRType::Void: - add_argument_buffer_padding_buffer_type(buffer_type, member_index, rez_bind.count); - break; - case SPIRType::Image: - add_argument_buffer_padding_image_type(buffer_type, member_index, rez_bind.count); - break; - case SPIRType::Sampler: - add_argument_buffer_padding_sampler_type(buffer_type, member_index, rez_bind.count); - break; - case SPIRType::SampledImage: - add_argument_buffer_padding_image_type(buffer_type, member_index, rez_bind.count); - add_argument_buffer_padding_sampler_type(buffer_type, member_index, rez_bind.count); - break; - default: - break; + case SPIRType::Void: + next_arg_buff_index += add_argument_buffer_padding_buffer_type(buffer_type, member_index, rez_bind.count); + break; + case SPIRType::Image: + next_arg_buff_index += add_argument_buffer_padding_image_type(buffer_type, member_index, rez_bind.count); + break; + case SPIRType::Sampler: + next_arg_buff_index += add_argument_buffer_padding_sampler_type(buffer_type, member_index, rez_bind.count); + break; + case SPIRType::SampledImage: + next_arg_buff_index += add_argument_buffer_padding_image_type(buffer_type, member_index, rez_bind.count); + next_arg_buff_index += add_argument_buffer_padding_sampler_type(buffer_type, member_index, rez_bind.count); + break; + default: + next_arg_buff_index += rez_bind.count; + break; } - next_arg_buff_index += rez_bind.count; } // Adjust the number of slots consumed by current member itself. @@ -15134,6 +15160,10 @@ void CompilerMSL::analyze_argument_buffers() if (elem_cnt == 0) elem_cnt = get_resource_array_size(var.self); + // And if the member is a combined image sampler, it takes double the slots + if (type.basetype == SPIRType::SampledImage) + elem_cnt *= 2; + next_arg_buff_index += elem_cnt; } @@ -15237,40 +15267,30 @@ void CompilerMSL::analyze_argument_buffers() // Return the resource type of the app-provided resources for the descriptor set, // that matches the resource index of the argument buffer index. -MSLResourceBinding& CompilerMSL::get_argument_buffer_resource(uint32_t desc_set, uint32_t arg_idx) +// This is a two-step lookup, first lookup the resource binding number from the argument buffer index, +// then lookup the resource binding using the binding number. +MSLResourceBinding &CompilerMSL::get_argument_buffer_resource(uint32_t desc_set, uint32_t arg_idx) { - for (auto itr = resource_bindings.begin(); itr != resource_bindings.end(); itr++) + auto stage = get_entry_point().model; + StageSetBinding arg_idx_tuple = { stage, desc_set, arg_idx }; + auto arg_itr = resource_arg_buff_idx_to_binding_number.find(arg_idx_tuple); + if (arg_itr != end(resource_arg_buff_idx_to_binding_number)) { - auto& rez_bind = itr->second.first; - uint32_t rez_idx = ~0; - switch (rez_bind.base_type) - { - case SPIRType::Void: - rez_idx = rez_bind.msl_buffer; - break; - case SPIRType::Image: - case SPIRType::SampledImage: - rez_idx = rez_bind.msl_texture; - break; - case SPIRType::Sampler: - rez_idx = rez_bind.msl_sampler; - break; - default: - SPIRV_CROSS_THROW("Unexpected argument buffer resource base type. When padding argument buffer elements, all descriptor set resources must be supplied with a base type by the app."); - break; - } - - if (rez_bind.desc_set == desc_set && rez_idx == arg_idx) - return rez_bind; + StageSetBinding bind_tuple = { stage, desc_set, arg_itr->second }; + auto bind_itr = resource_bindings.find(bind_tuple); + if (bind_itr != end(resource_bindings)) + return bind_itr->second.first; } - SPIRV_CROSS_THROW("Argument buffer resource base type could not be determined. When padding argument buffer elements, all descriptor set resources must be supplied with a base type by the app."); + SPIRV_CROSS_THROW("Argument buffer resource base type could not be determined. When padding argument buffer " + "elements, all descriptor set resources must be supplied with a base type by the app."); static MSLResourceBinding unkwn_rez; return unkwn_rez; } // Adds an argument buffer padding argument buffer type as one or more members of the struct type at the member index. // Metal does not support arrays of buffers, so these are emitted as multiple struct members. -void CompilerMSL::add_argument_buffer_padding_buffer_type(SPIRType& struct_type, uint32_t& mbr_idx, uint32_t count) +// Returns the number of argument buffer slots consumed. +uint32_t CompilerMSL::add_argument_buffer_padding_buffer_type(SPIRType &struct_type, uint32_t &mbr_idx, uint32_t count) { if (!argument_buffer_padding_buffer_type_id) { @@ -15288,13 +15308,16 @@ void CompilerMSL::add_argument_buffer_padding_buffer_type(SPIRType& struct_type, argument_buffer_padding_buffer_type_id = ptr_type_id; } - for (uint32_t rez_idx = 0; rez_idx < count; rez_idx++) { + + for (uint32_t rez_idx = 0; rez_idx < count; rez_idx++) add_argument_buffer_padding_type(argument_buffer_padding_buffer_type_id, struct_type, mbr_idx, 1); - } + + return count; } // Adds an argument buffer padding argument image type as a member of the struct type at the member index. -void CompilerMSL::add_argument_buffer_padding_image_type(SPIRType& struct_type, uint32_t& mbr_idx, uint32_t count) +// Returns the number of argument buffer slots consumed. +uint32_t CompilerMSL::add_argument_buffer_padding_image_type(SPIRType &struct_type, uint32_t &mbr_idx, uint32_t count) { if (!argument_buffer_padding_image_type_id) { @@ -15321,11 +15344,13 @@ void CompilerMSL::add_argument_buffer_padding_image_type(SPIRType& struct_type, } add_argument_buffer_padding_type(argument_buffer_padding_image_type_id, struct_type, mbr_idx, count); + + return count; } // Adds an argument buffer padding argument sampler type as a member of the struct type at the member index. // Returns the number of argument buffer slots consumed. -void CompilerMSL::add_argument_buffer_padding_sampler_type(SPIRType& struct_type, uint32_t& mbr_idx, uint32_t count) +uint32_t CompilerMSL::add_argument_buffer_padding_sampler_type(SPIRType &struct_type, uint32_t &mbr_idx, uint32_t count) { if (!argument_buffer_padding_sampler_type_id) { @@ -15338,11 +15363,13 @@ void CompilerMSL::add_argument_buffer_padding_sampler_type(SPIRType& struct_type } add_argument_buffer_padding_type(argument_buffer_padding_sampler_type_id, struct_type, mbr_idx, count); + + return count; } // Adds the argument buffer padding argument type as a member of the struct type at the member index. -// Returns the number of argument buffer slots consumed. -void CompilerMSL::add_argument_buffer_padding_type(uint32_t mbr_type_id, SPIRType& struct_type, uint32_t& mbr_idx, uint32_t count) +void CompilerMSL::add_argument_buffer_padding_type(uint32_t mbr_type_id, SPIRType &struct_type, uint32_t &mbr_idx, + uint32_t count) { uint32_t type_id = mbr_type_id; if (count > 1) diff --git a/spirv_msl.hpp b/spirv_msl.hpp index 7dda3fa4..ff7c0f4c 100644 --- a/spirv_msl.hpp +++ b/spirv_msl.hpp @@ -972,6 +972,7 @@ protected: SmallVector vars_needing_early_declaration; std::unordered_map, InternalHasher> resource_bindings; + std::unordered_map resource_arg_buff_idx_to_binding_number; uint32_t type_to_location_count(const SPIRType &type) const; uint32_t next_metal_resource_index_buffer = 0; @@ -1051,11 +1052,12 @@ protected: void analyze_argument_buffers(); bool descriptor_set_is_argument_buffer(uint32_t desc_set) const; - MSLResourceBinding& get_argument_buffer_resource(uint32_t desc_set, uint32_t arg_idx); - void add_argument_buffer_padding_buffer_type(SPIRType& struct_type, uint32_t& mbr_idx, uint32_t count); - void add_argument_buffer_padding_image_type(SPIRType& struct_type, uint32_t& mbr_idx, uint32_t count); - void add_argument_buffer_padding_sampler_type(SPIRType& struct_type, uint32_t& mbr_idx, uint32_t count); - void add_argument_buffer_padding_type(uint32_t mbr_type_id, SPIRType& struct_type, uint32_t& mbr_idx, uint32_t count); + MSLResourceBinding &get_argument_buffer_resource(uint32_t desc_set, uint32_t arg_idx); + uint32_t add_argument_buffer_padding_buffer_type(SPIRType &struct_type, uint32_t &mbr_idx, uint32_t count); + uint32_t add_argument_buffer_padding_image_type(SPIRType &struct_type, uint32_t &mbr_idx, uint32_t count); + uint32_t add_argument_buffer_padding_sampler_type(SPIRType &struct_type, uint32_t &mbr_idx, uint32_t count); + void add_argument_buffer_padding_type(uint32_t mbr_type_id, SPIRType &struct_type, uint32_t &mbr_idx, + uint32_t count); uint32_t get_target_components_for_fragment_location(uint32_t location) const; uint32_t build_extended_vector_type(uint32_t type_id, uint32_t components,