CompilerMSL support conversion of non-square row-major matrices.
This commit is contained in:
parent
1845f31397
commit
8890578d2a
@ -2237,7 +2237,7 @@ string CompilerGLSL::to_expression(uint32_t id)
|
||||
if (e.base_expression)
|
||||
return to_enclosed_expression(e.base_expression) + e.expression;
|
||||
else if (e.need_transpose)
|
||||
return convert_row_major_matrix(e.expression);
|
||||
return convert_row_major_matrix(e.expression, get<SPIRType>(e.expression_type));
|
||||
else
|
||||
return e.expression;
|
||||
}
|
||||
@ -4350,7 +4350,7 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
|
||||
{
|
||||
if (row_major_matrix_needs_conversion)
|
||||
{
|
||||
expr = convert_row_major_matrix(expr);
|
||||
expr = convert_row_major_matrix(expr, *type);
|
||||
row_major_matrix_needs_conversion = false;
|
||||
}
|
||||
|
||||
@ -4529,7 +4529,7 @@ std::string CompilerGLSL::flattened_access_chain_struct(uint32_t base, const uin
|
||||
|
||||
// Cannot forward transpositions, so resolve them here.
|
||||
if (need_transpose)
|
||||
expr += convert_row_major_matrix(tmp);
|
||||
expr += convert_row_major_matrix(tmp, member_type);
|
||||
else
|
||||
expr += tmp;
|
||||
}
|
||||
@ -5654,9 +5654,10 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
|
||||
case OpVectorTimesMatrix:
|
||||
case OpMatrixTimesVector:
|
||||
{
|
||||
// If the matrix needs transpose, just flip the multiply order.
|
||||
// If the matrix needs transpose and it is square, just flip the multiply order.
|
||||
SPIRType *t;
|
||||
auto *e = maybe_get<SPIRExpression>(ops[opcode == OpMatrixTimesVector ? 2 : 3]);
|
||||
if (e && e->need_transpose)
|
||||
if (e && e->need_transpose && (t = &get<SPIRType>(e->expression_type)) && t->columns == t->vecsize)
|
||||
{
|
||||
e->need_transpose = false;
|
||||
emit_binary_op(ops[0], ops[1], ops[3], ops[2], "*");
|
||||
@ -6840,7 +6841,7 @@ bool CompilerGLSL::member_is_packed_type(const SPIRType &type, uint32_t index) c
|
||||
// row_major matrix result of the expression to a column_major matrix.
|
||||
// Base implementation uses the standard library transpose() function.
|
||||
// Subclasses may override to use a different function.
|
||||
string CompilerGLSL::convert_row_major_matrix(string exp_str)
|
||||
string CompilerGLSL::convert_row_major_matrix(string exp_str, const SPIRType & /*exp_type*/)
|
||||
{
|
||||
strip_enclosed_expression(exp_str);
|
||||
return join("transpose(", exp_str, ")");
|
||||
|
@ -297,10 +297,10 @@ protected:
|
||||
void add_resource_name(uint32_t id);
|
||||
void add_member_name(SPIRType &type, uint32_t name);
|
||||
|
||||
bool is_non_native_row_major_matrix(uint32_t id);
|
||||
bool member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index);
|
||||
virtual bool is_non_native_row_major_matrix(uint32_t id);
|
||||
virtual bool member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index);
|
||||
bool member_is_packed_type(const SPIRType &type, uint32_t index) const;
|
||||
virtual std::string convert_row_major_matrix(std::string exp_str);
|
||||
virtual std::string convert_row_major_matrix(std::string exp_str, const SPIRType &exp_type);
|
||||
|
||||
std::unordered_set<std::string> local_variable_names;
|
||||
std::unordered_set<std::string> resource_names;
|
||||
|
134
spirv_msl.cpp
134
spirv_msl.cpp
@ -1048,6 +1048,64 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplRowMajor2x3:
|
||||
statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
|
||||
statement("float2x3 spvConvertFromRowMajor2x3(float2x3 m)");
|
||||
begin_scope();
|
||||
statement("return float2x3(float3(m[0][0], m[0][2], m[1][1]), float3(m[0][1], m[1][0], m[1][2]));");
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplRowMajor2x4:
|
||||
statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
|
||||
statement("float2x4 spvConvertFromRowMajor2x4(float2x4 m)");
|
||||
begin_scope();
|
||||
statement("return float2x4(float4(m[0][0], m[0][2], m[1][0], m[1][2]), float4(m[0][1], m[0][3], m[1][1], "
|
||||
"m[1][3]));");
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplRowMajor3x2:
|
||||
statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
|
||||
statement("float3x2 spvConvertFromRowMajor3x2(float3x2 m)");
|
||||
begin_scope();
|
||||
statement("return float3x2(float2(m[0][0], m[1][1]), float2(m[0][1], m[2][0]), float2(m[1][0], m[2][1]));");
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplRowMajor3x4:
|
||||
statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
|
||||
statement("float3x4 spvConvertFromRowMajor3x4(float3x4 m)");
|
||||
begin_scope();
|
||||
statement("return float3x4(float4(m[0][0], m[0][3], m[1][2], m[2][1]), float4(m[0][1], m[1][0], m[1][3], "
|
||||
"m[2][2]), float4(m[0][2], m[1][1], m[2][0], m[2][3]));");
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplRowMajor4x2:
|
||||
statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
|
||||
statement("float4x2 spvConvertFromRowMajor4x2(float4x2 m)");
|
||||
begin_scope();
|
||||
statement("return float4x2(float2(m[0][0], m[2][0]), float2(m[0][1], m[2][1]), float2(m[1][0], m[3][0]), "
|
||||
"float2(m[1][1], m[3][1]));");
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplRowMajor4x3:
|
||||
statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
|
||||
statement("float4x3 spvConvertFromRowMajor4x3(float4x3 m)");
|
||||
begin_scope();
|
||||
statement("return float4x3(float3(m[0][0], m[1][1], m[2][2]), float3(m[0][1], m[1][2], m[3][0]), "
|
||||
"float3(m[0][2], m[2][0], m[3][1]), float3(m[1][0], m[2][1], m[3][2]));");
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -2220,6 +2278,82 @@ string CompilerMSL::to_sampler_expression(uint32_t id)
|
||||
return samp_id ? to_expression(samp_id) : to_expression(id) + sampler_name_suffix;
|
||||
}
|
||||
|
||||
// Checks whether the ID is a row_major matrix that requires conversion before use
|
||||
bool CompilerMSL::is_non_native_row_major_matrix(uint32_t id)
|
||||
{
|
||||
// Natively supported row-major matrices do not need to be converted.
|
||||
// Legacy targets do not support row major.
|
||||
if (backend.native_row_major_matrix && !is_legacy())
|
||||
return false;
|
||||
|
||||
// Non-matrix or column-major matrix types do not need to be converted.
|
||||
if (!(meta[id].decoration.decoration_flags & (1ull << DecorationRowMajor)))
|
||||
return false;
|
||||
|
||||
// Generate a function that will swap matrix elements from row-major to column-major.
|
||||
const auto type = expression_type(id);
|
||||
add_convert_row_major_matrix_function(type.columns, type.vecsize);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Checks whether the member is a row_major matrix that requires conversion before use
|
||||
bool CompilerMSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index)
|
||||
{
|
||||
// Natively supported row-major matrices do not need to be converted.
|
||||
if (backend.native_row_major_matrix && !is_legacy())
|
||||
return false;
|
||||
|
||||
// Non-matrix or column-major matrix types do not need to be converted.
|
||||
if (!(combined_decoration_for_member(type, index) & (1ull << DecorationRowMajor)))
|
||||
return false;
|
||||
|
||||
// Generate a function that will swap matrix elements from row-major to column-major.
|
||||
const auto mbr_type = get<SPIRType>(type.member_types[index]);
|
||||
add_convert_row_major_matrix_function(mbr_type.columns, mbr_type.vecsize);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Adds a function suitable for converting a non-square row-major matrix to a column-major matrix.
|
||||
void CompilerMSL::add_convert_row_major_matrix_function(uint32_t cols, uint32_t rows)
|
||||
{
|
||||
SPVFuncImpl spv_func;
|
||||
if (cols == rows) // Square matrix...just use transpose() function
|
||||
return;
|
||||
else if (cols == 2 && rows == 3)
|
||||
spv_func = SPVFuncImplRowMajor2x3;
|
||||
else if (cols == 2 && rows == 4)
|
||||
spv_func = SPVFuncImplRowMajor2x4;
|
||||
else if (cols == 3 && rows == 2)
|
||||
spv_func = SPVFuncImplRowMajor3x2;
|
||||
else if (cols == 3 && rows == 4)
|
||||
spv_func = SPVFuncImplRowMajor3x4;
|
||||
else if (cols == 4 && rows == 2)
|
||||
spv_func = SPVFuncImplRowMajor4x2;
|
||||
else if (cols == 4 && rows == 3)
|
||||
spv_func = SPVFuncImplRowMajor4x3;
|
||||
else
|
||||
SPIRV_CROSS_THROW("Could not convert row-major matrix.");
|
||||
|
||||
auto rslt = spv_function_implementations.insert(spv_func);
|
||||
if (rslt.second)
|
||||
force_recompile = true;
|
||||
}
|
||||
|
||||
// Wraps the expression string in a function call that converts the
|
||||
// row_major matrix result of the expression to a column_major matrix.
|
||||
string CompilerMSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type)
|
||||
{
|
||||
strip_enclosed_expression(exp_str);
|
||||
|
||||
string func_name;
|
||||
if (exp_type.columns == exp_type.vecsize)
|
||||
func_name = "transpose";
|
||||
else
|
||||
func_name = string("spvConvertFromRowMajor") + to_string(exp_type.columns) + "x" + to_string(exp_type.vecsize);
|
||||
|
||||
return join(func_name, "(", exp_str, ")");
|
||||
}
|
||||
|
||||
// Called automatically at the end of the entry point function
|
||||
void CompilerMSL::emit_fixup()
|
||||
{
|
||||
|
@ -137,6 +137,12 @@ public:
|
||||
SPVFuncImplInverse2x2,
|
||||
SPVFuncImplInverse3x3,
|
||||
SPVFuncImplInverse4x4,
|
||||
SPVFuncImplRowMajor2x3,
|
||||
SPVFuncImplRowMajor2x4,
|
||||
SPVFuncImplRowMajor3x2,
|
||||
SPVFuncImplRowMajor3x4,
|
||||
SPVFuncImplRowMajor4x2,
|
||||
SPVFuncImplRowMajor4x3,
|
||||
};
|
||||
|
||||
// Constructs an instance to compile the SPIR-V code into Metal Shading Language,
|
||||
@ -201,6 +207,9 @@ protected:
|
||||
std::string to_qualifiers_glsl(uint32_t id) override;
|
||||
void replace_illegal_names() override;
|
||||
void declare_undefined_values() override;
|
||||
bool is_non_native_row_major_matrix(uint32_t id) override;
|
||||
bool member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index) override;
|
||||
std::string convert_row_major_matrix(std::string exp_str, const SPIRType &exp_type) override;
|
||||
|
||||
void preprocess_op_codes();
|
||||
void localize_global_variables();
|
||||
@ -222,6 +231,7 @@ protected:
|
||||
void emit_interface_block(uint32_t ib_var_id);
|
||||
bool maybe_emit_input_struct_assignment(uint32_t id_lhs, uint32_t id_rhs);
|
||||
bool maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs);
|
||||
void add_convert_row_major_matrix_function(uint32_t cols, uint32_t rows);
|
||||
|
||||
std::string func_type_decl(SPIRType &type);
|
||||
std::string entry_point_args(bool append_comma);
|
||||
|
Loading…
Reference in New Issue
Block a user