CompilerMSL support conversion of non-square row-major matrices.

This commit is contained in:
Bill Hollings 2018-01-04 16:33:45 -05:00
parent 1845f31397
commit 8890578d2a
4 changed files with 154 additions and 9 deletions

View File

@ -2237,7 +2237,7 @@ string CompilerGLSL::to_expression(uint32_t id)
if (e.base_expression)
return to_enclosed_expression(e.base_expression) + e.expression;
else if (e.need_transpose)
return convert_row_major_matrix(e.expression);
return convert_row_major_matrix(e.expression, get<SPIRType>(e.expression_type));
else
return e.expression;
}
@ -4350,7 +4350,7 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
{
if (row_major_matrix_needs_conversion)
{
expr = convert_row_major_matrix(expr);
expr = convert_row_major_matrix(expr, *type);
row_major_matrix_needs_conversion = false;
}
@ -4529,7 +4529,7 @@ std::string CompilerGLSL::flattened_access_chain_struct(uint32_t base, const uin
// Cannot forward transpositions, so resolve them here.
if (need_transpose)
expr += convert_row_major_matrix(tmp);
expr += convert_row_major_matrix(tmp, member_type);
else
expr += tmp;
}
@ -5654,9 +5654,10 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
case OpVectorTimesMatrix:
case OpMatrixTimesVector:
{
// If the matrix needs transpose, just flip the multiply order.
// If the matrix needs transpose and it is square, just flip the multiply order.
SPIRType *t;
auto *e = maybe_get<SPIRExpression>(ops[opcode == OpMatrixTimesVector ? 2 : 3]);
if (e && e->need_transpose)
if (e && e->need_transpose && (t = &get<SPIRType>(e->expression_type)) && t->columns == t->vecsize)
{
e->need_transpose = false;
emit_binary_op(ops[0], ops[1], ops[3], ops[2], "*");
@ -6840,7 +6841,7 @@ bool CompilerGLSL::member_is_packed_type(const SPIRType &type, uint32_t index) c
// row_major matrix result of the expression to a column_major matrix.
// Base implementation uses the standard library transpose() function.
// Subclasses may override to use a different function.
string CompilerGLSL::convert_row_major_matrix(string exp_str)
string CompilerGLSL::convert_row_major_matrix(string exp_str, const SPIRType & /*exp_type*/)
{
strip_enclosed_expression(exp_str);
return join("transpose(", exp_str, ")");

View File

@ -297,10 +297,10 @@ protected:
void add_resource_name(uint32_t id);
void add_member_name(SPIRType &type, uint32_t name);
bool is_non_native_row_major_matrix(uint32_t id);
bool member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index);
virtual bool is_non_native_row_major_matrix(uint32_t id);
virtual bool member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index);
bool member_is_packed_type(const SPIRType &type, uint32_t index) const;
virtual std::string convert_row_major_matrix(std::string exp_str);
virtual std::string convert_row_major_matrix(std::string exp_str, const SPIRType &exp_type);
std::unordered_set<std::string> local_variable_names;
std::unordered_set<std::string> resource_names;

View File

@ -1048,6 +1048,64 @@ void CompilerMSL::emit_custom_functions()
statement("");
break;
case SPVFuncImplRowMajor2x3:
statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
statement("float2x3 spvConvertFromRowMajor2x3(float2x3 m)");
begin_scope();
statement("return float2x3(float3(m[0][0], m[0][2], m[1][1]), float3(m[0][1], m[1][0], m[1][2]));");
end_scope();
statement("");
break;
case SPVFuncImplRowMajor2x4:
statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
statement("float2x4 spvConvertFromRowMajor2x4(float2x4 m)");
begin_scope();
statement("return float2x4(float4(m[0][0], m[0][2], m[1][0], m[1][2]), float4(m[0][1], m[0][3], m[1][1], "
"m[1][3]));");
end_scope();
statement("");
break;
case SPVFuncImplRowMajor3x2:
statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
statement("float3x2 spvConvertFromRowMajor3x2(float3x2 m)");
begin_scope();
statement("return float3x2(float2(m[0][0], m[1][1]), float2(m[0][1], m[2][0]), float2(m[1][0], m[2][1]));");
end_scope();
statement("");
break;
case SPVFuncImplRowMajor3x4:
statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
statement("float3x4 spvConvertFromRowMajor3x4(float3x4 m)");
begin_scope();
statement("return float3x4(float4(m[0][0], m[0][3], m[1][2], m[2][1]), float4(m[0][1], m[1][0], m[1][3], "
"m[2][2]), float4(m[0][2], m[1][1], m[2][0], m[2][3]));");
end_scope();
statement("");
break;
case SPVFuncImplRowMajor4x2:
statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
statement("float4x2 spvConvertFromRowMajor4x2(float4x2 m)");
begin_scope();
statement("return float4x2(float2(m[0][0], m[2][0]), float2(m[0][1], m[2][1]), float2(m[1][0], m[3][0]), "
"float2(m[1][1], m[3][1]));");
end_scope();
statement("");
break;
case SPVFuncImplRowMajor4x3:
statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
statement("float4x3 spvConvertFromRowMajor4x3(float4x3 m)");
begin_scope();
statement("return float4x3(float3(m[0][0], m[1][1], m[2][2]), float3(m[0][1], m[1][2], m[3][0]), "
"float3(m[0][2], m[2][0], m[3][1]), float3(m[1][0], m[2][1], m[3][2]));");
end_scope();
statement("");
break;
default:
break;
}
@ -2220,6 +2278,82 @@ string CompilerMSL::to_sampler_expression(uint32_t id)
return samp_id ? to_expression(samp_id) : to_expression(id) + sampler_name_suffix;
}
// Checks whether the ID is a row_major matrix that requires conversion before use
bool CompilerMSL::is_non_native_row_major_matrix(uint32_t id)
{
// Natively supported row-major matrices do not need to be converted.
// Legacy targets do not support row major.
if (backend.native_row_major_matrix && !is_legacy())
return false;
// Non-matrix or column-major matrix types do not need to be converted.
if (!(meta[id].decoration.decoration_flags & (1ull << DecorationRowMajor)))
return false;
// Generate a function that will swap matrix elements from row-major to column-major.
const auto type = expression_type(id);
add_convert_row_major_matrix_function(type.columns, type.vecsize);
return true;
}
// Checks whether the member is a row_major matrix that requires conversion before use
bool CompilerMSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index)
{
// Natively supported row-major matrices do not need to be converted.
if (backend.native_row_major_matrix && !is_legacy())
return false;
// Non-matrix or column-major matrix types do not need to be converted.
if (!(combined_decoration_for_member(type, index) & (1ull << DecorationRowMajor)))
return false;
// Generate a function that will swap matrix elements from row-major to column-major.
const auto mbr_type = get<SPIRType>(type.member_types[index]);
add_convert_row_major_matrix_function(mbr_type.columns, mbr_type.vecsize);
return true;
}
// Adds a function suitable for converting a non-square row-major matrix to a column-major matrix.
void CompilerMSL::add_convert_row_major_matrix_function(uint32_t cols, uint32_t rows)
{
SPVFuncImpl spv_func;
if (cols == rows) // Square matrix...just use transpose() function
return;
else if (cols == 2 && rows == 3)
spv_func = SPVFuncImplRowMajor2x3;
else if (cols == 2 && rows == 4)
spv_func = SPVFuncImplRowMajor2x4;
else if (cols == 3 && rows == 2)
spv_func = SPVFuncImplRowMajor3x2;
else if (cols == 3 && rows == 4)
spv_func = SPVFuncImplRowMajor3x4;
else if (cols == 4 && rows == 2)
spv_func = SPVFuncImplRowMajor4x2;
else if (cols == 4 && rows == 3)
spv_func = SPVFuncImplRowMajor4x3;
else
SPIRV_CROSS_THROW("Could not convert row-major matrix.");
auto rslt = spv_function_implementations.insert(spv_func);
if (rslt.second)
force_recompile = true;
}
// Wraps the expression string in a function call that converts the
// row_major matrix result of the expression to a column_major matrix.
string CompilerMSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type)
{
strip_enclosed_expression(exp_str);
string func_name;
if (exp_type.columns == exp_type.vecsize)
func_name = "transpose";
else
func_name = string("spvConvertFromRowMajor") + to_string(exp_type.columns) + "x" + to_string(exp_type.vecsize);
return join(func_name, "(", exp_str, ")");
}
// Called automatically at the end of the entry point function
void CompilerMSL::emit_fixup()
{

View File

@ -137,6 +137,12 @@ public:
SPVFuncImplInverse2x2,
SPVFuncImplInverse3x3,
SPVFuncImplInverse4x4,
SPVFuncImplRowMajor2x3,
SPVFuncImplRowMajor2x4,
SPVFuncImplRowMajor3x2,
SPVFuncImplRowMajor3x4,
SPVFuncImplRowMajor4x2,
SPVFuncImplRowMajor4x3,
};
// Constructs an instance to compile the SPIR-V code into Metal Shading Language,
@ -201,6 +207,9 @@ protected:
std::string to_qualifiers_glsl(uint32_t id) override;
void replace_illegal_names() override;
void declare_undefined_values() override;
bool is_non_native_row_major_matrix(uint32_t id) override;
bool member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index) override;
std::string convert_row_major_matrix(std::string exp_str, const SPIRType &exp_type) override;
void preprocess_op_codes();
void localize_global_variables();
@ -222,6 +231,7 @@ protected:
void emit_interface_block(uint32_t ib_var_id);
bool maybe_emit_input_struct_assignment(uint32_t id_lhs, uint32_t id_rhs);
bool maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs);
void add_convert_row_major_matrix_function(uint32_t cols, uint32_t rows);
std::string func_type_decl(SPIRType &type);
std::string entry_point_args(bool append_comma);