diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 10c4c850..90db1373 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -2445,13 +2445,10 @@ void CompilerMSL::align_struct(SPIRType &ib_type, unordered_set &align // offsets, array strides and matrix strides. check_member_packing_rules_msl(ib_type, mbr_idx); - auto &mbr_type = get(ib_type.member_types[mbr_idx]); - bool is_packed = member_is_packed_physical_type(ib_type, mbr_idx); - // Align current offset to the current member's default alignment. If the member was packed, it will observe // the updated alignment here. - size_t msl_align_mask = get_declared_type_alignment_msl(mbr_type, is_packed) - 1; - auto aligned_msl_offset = uint32_t((msl_offset + msl_align_mask) & ~msl_align_mask); + size_t msl_align_mask = get_declared_struct_member_alignment_msl(ib_type, mbr_idx) - 1; + uint32_t aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask; // Fetch the member offset as declared in the SPIRV. uint32_t spirv_mbr_offset = get_member_decoration(ib_type_id, mbr_idx, DecorationOffset); @@ -2465,7 +2462,7 @@ void CompilerMSL::align_struct(SPIRType &ib_type, unordered_set &align // Re-align as a sanity check that aligning post-padding matches up. msl_offset += padding_bytes; - aligned_msl_offset = uint32_t((msl_offset + msl_align_mask) & ~msl_align_mask); + aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask; } else if (spirv_mbr_offset < aligned_msl_offset) { @@ -2478,7 +2475,7 @@ void CompilerMSL::align_struct(SPIRType &ib_type, unordered_set &align // Increment the current offset to be positioned immediately after the current member. // Don't do this for the last member since it can be unsized, and it is not relevant for padding purposes here. if (mbr_idx + 1 < mbr_cnt) - msl_offset = aligned_msl_offset + uint32_t(get_declared_type_size_msl(mbr_type, is_packed)); + msl_offset = aligned_msl_offset + get_declared_struct_member_size_msl(ib_type, mbr_idx); } } @@ -2488,7 +2485,6 @@ void CompilerMSL::align_struct(SPIRType &ib_type, unordered_set &align void CompilerMSL::check_member_packing_rules_msl(SPIRType &ib_type, uint32_t index) { auto &mbr_type = get(ib_type.member_types[index]); - bool is_packed = member_is_packed_physical_type(mbr_type, index); uint32_t spirv_offset = get_member_decoration(ib_type.self, index, DecorationOffset); bool conforms_to_packing_rules = true; @@ -2501,7 +2497,7 @@ void CompilerMSL::check_member_packing_rules_msl(SPIRType &ib_type, uint32_t ind uint32_t spirv_offset_next = get_member_decoration(ib_type.self, index + 1, DecorationOffset); assert(spirv_offset_next >= spirv_offset); uint32_t maximum_size = spirv_offset_next - spirv_offset; - uint32_t msl_mbr_size = get_declared_type_size_msl(mbr_type, is_packed); + uint32_t msl_mbr_size = get_declared_struct_member_size_msl(ib_type, index); if (msl_mbr_size > maximum_size) conforms_to_packing_rules = false; } @@ -2510,7 +2506,7 @@ void CompilerMSL::check_member_packing_rules_msl(SPIRType &ib_type, uint32_t ind { // If we have an array type, array stride must match exactly with SPIR-V. uint32_t spirv_array_stride = type_struct_member_array_stride(ib_type, index); - uint32_t msl_array_stride = get_declared_type_array_stride_msl(mbr_type, is_packed); + uint32_t msl_array_stride = get_declared_struct_member_array_stride_msl(ib_type, index); if (spirv_array_stride != msl_array_stride) conforms_to_packing_rules = false; } @@ -2519,13 +2515,13 @@ void CompilerMSL::check_member_packing_rules_msl(SPIRType &ib_type, uint32_t ind { // Need to check MatrixStride as well. uint32_t spirv_matrix_stride = type_struct_member_matrix_stride(ib_type, index); - uint32_t msl_matrix_stride = get_declared_type_matrix_stride_msl(mbr_type, is_packed); + uint32_t msl_matrix_stride = get_declared_struct_member_matrix_stride_msl(ib_type, index); if (spirv_matrix_stride != msl_matrix_stride) conforms_to_packing_rules = false; } // Now, we check alignment. - uint32_t msl_alignment = get_declared_type_alignment_msl(mbr_type, is_packed); + uint32_t msl_alignment = get_declared_struct_member_alignment_msl(ib_type, index); if ((spirv_offset % msl_alignment) != 0) conforms_to_packing_rules = false; @@ -2533,6 +2529,8 @@ void CompilerMSL::check_member_packing_rules_msl(SPIRType &ib_type, uint32_t ind if (conforms_to_packing_rules) return; + // Perform remapping here. + set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked); } uint32_t CompilerMSL::get_member_packed_type(SPIRType &type, uint32_t index) @@ -8650,6 +8648,14 @@ string CompilerMSL::built_in_func_arg(BuiltIn builtin, bool prefix_comma) return bi_arg; } +const SPIRType &CompilerMSL::get_physical_member_type(const SPIRType &type, uint32_t index) const +{ + if (member_is_remapped_physical_type(type, index)) + return get(get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID)); + else + return get(type.member_types[index]); +} + uint32_t CompilerMSL::get_declared_type_array_stride_msl(const SPIRType &type, bool is_packed) const { // Array stride in MSL is always size * array_size. sizeof(float3) == 16, @@ -8664,6 +8670,12 @@ uint32_t CompilerMSL::get_declared_type_array_stride_msl(const SPIRType &type, b return get_declared_type_size_msl(parent_type, is_packed); } +uint32_t CompilerMSL::get_declared_struct_member_array_stride_msl(const SPIRType &type, uint32_t index) const +{ + return get_declared_type_array_stride_msl(get_physical_member_type(type, index), + member_is_packed_physical_type(type, index)); +} + uint32_t CompilerMSL::get_declared_type_matrix_stride_msl(const SPIRType &type, bool packed) const { // For packed matrices, we just use the size of the vector type. @@ -8674,6 +8686,12 @@ uint32_t CompilerMSL::get_declared_type_matrix_stride_msl(const SPIRType &type, return get_declared_type_alignment_msl(type, false); } +uint32_t CompilerMSL::get_declared_struct_member_matrix_stride_msl(const SPIRType &type, uint32_t index) const +{ + return get_declared_type_matrix_stride_msl(get_physical_member_type(type, index), + member_is_packed_physical_type(type, index)); +} + uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type) const { if (struct_type.member_types.empty()) @@ -8685,19 +8703,14 @@ uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type) uint32_t alignment = 1; for (uint32_t i = 0; i < mbr_cnt; i++) { - auto &mbr_type = get(struct_type.member_types[i]); - bool is_packed = has_extended_member_decoration(struct_type.self, i, SPIRVCrossDecorationPhysicalTypePacked); - uint32_t mbr_alignment = get_declared_type_alignment_msl(mbr_type, is_packed); + uint32_t mbr_alignment = get_declared_struct_member_alignment_msl(struct_type, i); alignment = max(alignment, mbr_alignment); } - auto &mbr_type = get(struct_type.member_types[mbr_cnt - 1]); - bool is_packed = member_is_packed_physical_type(struct_type, mbr_cnt - 1); - // Last member will always be matched to the final Offset decoration, but size of struct in MSL now depends // on physical size in MSL, and the size of the struct itself is then aligned to struct alignment. uint32_t spirv_offset = type_struct_member_offset(struct_type, mbr_cnt - 1); - uint32_t msl_size = spirv_offset + get_declared_type_size_msl(mbr_type, is_packed); + uint32_t msl_size = spirv_offset + get_declared_struct_member_size_msl(struct_type, mbr_cnt - 1); msl_size = (msl_size + alignment - 1) & ~(alignment - 1); return msl_size; } @@ -8743,6 +8756,12 @@ uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_p } } +uint32_t CompilerMSL::get_declared_struct_member_size_msl(const SPIRType &type, uint32_t index) const +{ + return get_declared_struct_member_size_msl(get_physical_member_type(type, index), + member_is_packed_physical_type(type, index)); +} + // Returns the byte alignment of a type. uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed) const { @@ -8768,7 +8787,7 @@ uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool // In MSL, a struct's alignment is equal to the maximum alignment of any of its members. uint32_t alignment = 1; for (uint32_t i = 0; i < type.member_types.size(); i++) - alignment = max(alignment, uint32_t(get_declared_type_alignment_msl(type, i))); + alignment = max(alignment, uint32_t(get_declared_struct_member_alignment_msl(type, i))); return alignment; } @@ -8791,6 +8810,12 @@ uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool } } +uint32_t CompilerMSL::get_declared_struct_member_alignment_msl(const SPIRType &type, uint32_t index) const +{ + return get_declared_struct_member_alignment_msl(get_physical_member_type(type, index), + member_is_packed_physical_type(type, index)); +} + bool CompilerMSL::skip_argument(uint32_t) const { return false; diff --git a/spirv_msl.hpp b/spirv_msl.hpp index 3146bbdf..0595862c 100644 --- a/spirv_msl.hpp +++ b/spirv_msl.hpp @@ -517,6 +517,14 @@ protected: uint32_t get_declared_type_array_stride_msl(const SPIRType &type, bool packed) const; uint32_t get_declared_type_matrix_stride_msl(const SPIRType &type, bool packed) const; uint32_t get_declared_type_alignment_msl(const SPIRType &type, bool packed) const; + + uint32_t get_declared_struct_member_size_msl(const SPIRType &struct_type, uint32_t index) const; + uint32_t get_declared_struct_member_array_stride_msl(const SPIRType &struct_type, uint32_t index) const; + uint32_t get_declared_struct_member_matrix_stride_msl(const SPIRType &struct_type, uint32_t index) const; + uint32_t get_declared_struct_member_alignment_msl(const SPIRType &struct_type, uint32_t index) const; + + const SPIRType &get_physical_member_type(const SPIRType &struct_type, uint32_t index) const; + uint32_t get_declared_struct_size_msl(const SPIRType &struct_type) const; std::string to_component_argument(uint32_t id);