diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index ad8b4703..289b81e8 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -2237,7 +2237,7 @@ string CompilerGLSL::to_expression(uint32_t id) if (e.base_expression) return to_enclosed_expression(e.base_expression) + e.expression; else if (e.need_transpose) - return convert_row_major_matrix(e.expression); + return convert_row_major_matrix(e.expression, get(e.expression_type)); else return e.expression; } @@ -4350,7 +4350,7 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice { if (row_major_matrix_needs_conversion) { - expr = convert_row_major_matrix(expr); + expr = convert_row_major_matrix(expr, *type); row_major_matrix_needs_conversion = false; } @@ -4529,7 +4529,7 @@ std::string CompilerGLSL::flattened_access_chain_struct(uint32_t base, const uin // Cannot forward transpositions, so resolve them here. if (need_transpose) - expr += convert_row_major_matrix(tmp); + expr += convert_row_major_matrix(tmp, member_type); else expr += tmp; } @@ -5654,9 +5654,10 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) case OpVectorTimesMatrix: case OpMatrixTimesVector: { - // If the matrix needs transpose, just flip the multiply order. + // If the matrix needs transpose and it is square, just flip the multiply order. + SPIRType *t; auto *e = maybe_get(ops[opcode == OpMatrixTimesVector ? 2 : 3]); - if (e && e->need_transpose) + if (e && e->need_transpose && (t = &get(e->expression_type)) && t->columns == t->vecsize) { e->need_transpose = false; emit_binary_op(ops[0], ops[1], ops[3], ops[2], "*"); @@ -6840,7 +6841,7 @@ bool CompilerGLSL::member_is_packed_type(const SPIRType &type, uint32_t index) c // row_major matrix result of the expression to a column_major matrix. // Base implementation uses the standard library transpose() function. // Subclasses may override to use a different function. -string CompilerGLSL::convert_row_major_matrix(string exp_str) +string CompilerGLSL::convert_row_major_matrix(string exp_str, const SPIRType & /*exp_type*/) { strip_enclosed_expression(exp_str); return join("transpose(", exp_str, ")"); diff --git a/spirv_glsl.hpp b/spirv_glsl.hpp index 7ff1111b..25a95f49 100644 --- a/spirv_glsl.hpp +++ b/spirv_glsl.hpp @@ -297,10 +297,10 @@ protected: void add_resource_name(uint32_t id); void add_member_name(SPIRType &type, uint32_t name); - bool is_non_native_row_major_matrix(uint32_t id); - bool member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index); + virtual bool is_non_native_row_major_matrix(uint32_t id); + virtual bool member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index); bool member_is_packed_type(const SPIRType &type, uint32_t index) const; - virtual std::string convert_row_major_matrix(std::string exp_str); + virtual std::string convert_row_major_matrix(std::string exp_str, const SPIRType &exp_type); std::unordered_set local_variable_names; std::unordered_set resource_names; diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 28ec36c0..b7b085dd 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -1048,6 +1048,64 @@ void CompilerMSL::emit_custom_functions() statement(""); break; + case SPVFuncImplRowMajor2x3: + statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization."); + statement("float2x3 spvConvertFromRowMajor2x3(float2x3 m)"); + begin_scope(); + statement("return float2x3(float3(m[0][0], m[0][2], m[1][1]), float3(m[0][1], m[1][0], m[1][2]));"); + end_scope(); + statement(""); + break; + + case SPVFuncImplRowMajor2x4: + statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization."); + statement("float2x4 spvConvertFromRowMajor2x4(float2x4 m)"); + begin_scope(); + statement("return float2x4(float4(m[0][0], m[0][2], m[1][0], m[1][2]), float4(m[0][1], m[0][3], m[1][1], " + "m[1][3]));"); + end_scope(); + statement(""); + break; + + case SPVFuncImplRowMajor3x2: + statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization."); + statement("float3x2 spvConvertFromRowMajor3x2(float3x2 m)"); + begin_scope(); + statement("return float3x2(float2(m[0][0], m[1][1]), float2(m[0][1], m[2][0]), float2(m[1][0], m[2][1]));"); + end_scope(); + statement(""); + break; + + case SPVFuncImplRowMajor3x4: + statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization."); + statement("float3x4 spvConvertFromRowMajor3x4(float3x4 m)"); + begin_scope(); + statement("return float3x4(float4(m[0][0], m[0][3], m[1][2], m[2][1]), float4(m[0][1], m[1][0], m[1][3], " + "m[2][2]), float4(m[0][2], m[1][1], m[2][0], m[2][3]));"); + end_scope(); + statement(""); + break; + + case SPVFuncImplRowMajor4x2: + statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization."); + statement("float4x2 spvConvertFromRowMajor4x2(float4x2 m)"); + begin_scope(); + statement("return float4x2(float2(m[0][0], m[2][0]), float2(m[0][1], m[2][1]), float2(m[1][0], m[3][0]), " + "float2(m[1][1], m[3][1]));"); + end_scope(); + statement(""); + break; + + case SPVFuncImplRowMajor4x3: + statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization."); + statement("float4x3 spvConvertFromRowMajor4x3(float4x3 m)"); + begin_scope(); + statement("return float4x3(float3(m[0][0], m[1][1], m[2][2]), float3(m[0][1], m[1][2], m[3][0]), " + "float3(m[0][2], m[2][0], m[3][1]), float3(m[1][0], m[2][1], m[3][2]));"); + end_scope(); + statement(""); + break; + default: break; } @@ -2220,6 +2278,82 @@ string CompilerMSL::to_sampler_expression(uint32_t id) return samp_id ? to_expression(samp_id) : to_expression(id) + sampler_name_suffix; } +// Checks whether the ID is a row_major matrix that requires conversion before use +bool CompilerMSL::is_non_native_row_major_matrix(uint32_t id) +{ + // Natively supported row-major matrices do not need to be converted. + // Legacy targets do not support row major. + if (backend.native_row_major_matrix && !is_legacy()) + return false; + + // Non-matrix or column-major matrix types do not need to be converted. + if (!(meta[id].decoration.decoration_flags & (1ull << DecorationRowMajor))) + return false; + + // Generate a function that will swap matrix elements from row-major to column-major. + const auto type = expression_type(id); + add_convert_row_major_matrix_function(type.columns, type.vecsize); + return true; +} + +// Checks whether the member is a row_major matrix that requires conversion before use +bool CompilerMSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index) +{ + // Natively supported row-major matrices do not need to be converted. + if (backend.native_row_major_matrix && !is_legacy()) + return false; + + // Non-matrix or column-major matrix types do not need to be converted. + if (!(combined_decoration_for_member(type, index) & (1ull << DecorationRowMajor))) + return false; + + // Generate a function that will swap matrix elements from row-major to column-major. + const auto mbr_type = get(type.member_types[index]); + add_convert_row_major_matrix_function(mbr_type.columns, mbr_type.vecsize); + return true; +} + +// Adds a function suitable for converting a non-square row-major matrix to a column-major matrix. +void CompilerMSL::add_convert_row_major_matrix_function(uint32_t cols, uint32_t rows) +{ + SPVFuncImpl spv_func; + if (cols == rows) // Square matrix...just use transpose() function + return; + else if (cols == 2 && rows == 3) + spv_func = SPVFuncImplRowMajor2x3; + else if (cols == 2 && rows == 4) + spv_func = SPVFuncImplRowMajor2x4; + else if (cols == 3 && rows == 2) + spv_func = SPVFuncImplRowMajor3x2; + else if (cols == 3 && rows == 4) + spv_func = SPVFuncImplRowMajor3x4; + else if (cols == 4 && rows == 2) + spv_func = SPVFuncImplRowMajor4x2; + else if (cols == 4 && rows == 3) + spv_func = SPVFuncImplRowMajor4x3; + else + SPIRV_CROSS_THROW("Could not convert row-major matrix."); + + auto rslt = spv_function_implementations.insert(spv_func); + if (rslt.second) + force_recompile = true; +} + +// Wraps the expression string in a function call that converts the +// row_major matrix result of the expression to a column_major matrix. +string CompilerMSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type) +{ + strip_enclosed_expression(exp_str); + + string func_name; + if (exp_type.columns == exp_type.vecsize) + func_name = "transpose"; + else + func_name = string("spvConvertFromRowMajor") + to_string(exp_type.columns) + "x" + to_string(exp_type.vecsize); + + return join(func_name, "(", exp_str, ")"); +} + // Called automatically at the end of the entry point function void CompilerMSL::emit_fixup() { diff --git a/spirv_msl.hpp b/spirv_msl.hpp index 3ec98397..8994561a 100644 --- a/spirv_msl.hpp +++ b/spirv_msl.hpp @@ -137,6 +137,12 @@ public: SPVFuncImplInverse2x2, SPVFuncImplInverse3x3, SPVFuncImplInverse4x4, + SPVFuncImplRowMajor2x3, + SPVFuncImplRowMajor2x4, + SPVFuncImplRowMajor3x2, + SPVFuncImplRowMajor3x4, + SPVFuncImplRowMajor4x2, + SPVFuncImplRowMajor4x3, }; // Constructs an instance to compile the SPIR-V code into Metal Shading Language, @@ -201,6 +207,9 @@ protected: std::string to_qualifiers_glsl(uint32_t id) override; void replace_illegal_names() override; void declare_undefined_values() override; + bool is_non_native_row_major_matrix(uint32_t id) override; + bool member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index) override; + std::string convert_row_major_matrix(std::string exp_str, const SPIRType &exp_type) override; void preprocess_op_codes(); void localize_global_variables(); @@ -222,6 +231,7 @@ protected: void emit_interface_block(uint32_t ib_var_id); bool maybe_emit_input_struct_assignment(uint32_t id_lhs, uint32_t id_rhs); bool maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs); + void add_convert_row_major_matrix_function(uint32_t cols, uint32_t rows); std::string func_type_decl(SPIRType &type); std::string entry_point_args(bool append_comma);