Use to_unpacked_row_major_expression to unify row-major in MSL/GLSL.

This commit is contained in:
Hans-Kristian Arntzen 2019-07-23 11:36:54 +02:00
parent 47a18b9f1b
commit 7277c7ac46
7 changed files with 72 additions and 34 deletions

View File

@ -0,0 +1,21 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct SSBO
{
float2x4 m0;
float2x4 m1;
float2 v0;
float2 v1;
};
kernel void main0(device SSBO& _11 [[buffer(0)]])
{
_11.v0 = _11.v1 * (float2x2(_11.m1[0].xy, _11.m1[1].xy) * float2x2(_11.m0[0].xy, _11.m0[1].xy));
_11.v0 = (_11.v1 * float2x2(_11.m1[0].xy, _11.m1[1].xy)) * float2x2(_11.m0[0].xy, _11.m0[1].xy);
_11.v0 = float2x2(_11.m1[0].xy, _11.m1[1].xy) * (float2x2(_11.m0[0].xy, _11.m0[1].xy) * _11.v1);
_11.v0 = (float2x2(_11.m1[0].xy, _11.m1[1].xy) * float2x2(_11.m0[0].xy, _11.m0[1].xy)) * _11.v1;
}

View File

@ -23,7 +23,7 @@ struct main0_in
vertex main0_out main0(main0_in in [[stage_in]], constant UBO& _18 [[buffer(0)]]) vertex main0_out main0(main0_in in [[stage_in]], constant UBO& _18 [[buffer(0)]])
{ {
main0_out out = {}; main0_out out = {};
float2 v = float2x4(_18.uMVP[0], _18.uMVP[1]) * in.aVertex; float2 v = float4x2(_18.uMVP[0].xy, _18.uMVP[1].xy, _18.uMVP[2].xy, _18.uMVP[3].xy) * in.aVertex;
out.gl_Position = (_18.uMVPR * in.aVertex) + (in.aVertex * _18.uMVPC); out.gl_Position = (_18.uMVPR * in.aVertex) + (in.aVertex * _18.uMVPC);
return out; return out;
} }

View File

@ -0,0 +1,18 @@
#version 450
layout(local_size_x = 1) in;
layout(std140, row_major, set = 0, binding = 0) buffer SSBO
{
mat2 m0;
mat2 m1;
vec2 v0;
vec2 v1;
};
void main()
{
v0 = (m0 * m1) * v1;
v0 = m0 * (m1 * v1);
v0 = (v1 * m0) * m1;
v0 = v1 * (m0 * m1);
}

View File

@ -2838,6 +2838,15 @@ string CompilerGLSL::to_enclosed_expression(uint32_t id, bool register_expressio
return enclose_expression(to_expression(id, register_expression_read)); return enclose_expression(to_expression(id, register_expression_read));
} }
// Used explicitly when we want to read a row-major expression, but without any transpose shenanigans.
// need_transpose must be forced to false.
string CompilerGLSL::to_unpacked_row_major_matrix_expression(uint32_t id)
{
return unpack_expression_type(to_expression(id), expression_type(id),
get_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID),
has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked), true);
}
string CompilerGLSL::to_unpacked_expression(uint32_t id, bool register_expression_read) string CompilerGLSL::to_unpacked_expression(uint32_t id, bool register_expression_read)
{ {
// If we need to transpose, it will also take care of unpacking rules. // If we need to transpose, it will also take care of unpacking rules.
@ -2845,6 +2854,7 @@ string CompilerGLSL::to_unpacked_expression(uint32_t id, bool register_expressio
bool need_transpose = e && e->need_transpose; bool need_transpose = e && e->need_transpose;
bool is_remapped = has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID); bool is_remapped = has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID);
bool is_packed = has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked); bool is_packed = has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked);
if (!need_transpose && (is_remapped || is_packed)) if (!need_transpose && (is_remapped || is_packed))
{ {
return unpack_expression_type(to_expression(id, register_expression_read), expression_type(id), return unpack_expression_type(to_expression(id, register_expression_read), expression_type(id),
@ -8311,8 +8321,18 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
if (e && e->need_transpose) if (e && e->need_transpose)
{ {
e->need_transpose = false; e->need_transpose = false;
emit_binary_op(ops[0], ops[1], ops[3], ops[2], "*"); string expr;
if (opcode == OpMatrixTimesVector)
expr = join(to_enclosed_expression(ops[3]), " * ", enclose_expression(to_unpacked_row_major_matrix_expression(ops[2])));
else
expr = join(enclose_expression(to_unpacked_row_major_matrix_expression(ops[3])), " * ", to_enclosed_expression(ops[2]));
bool forward = should_forward(ops[2]) && should_forward(ops[3]);
emit_op(ops[0], ops[1], expr, forward);
e->need_transpose = true; e->need_transpose = true;
inherit_expression_dependencies(ops[1], ops[2]);
inherit_expression_dependencies(ops[1], ops[3]);
} }
else else
GLSL_BOP(*); GLSL_BOP(*);
@ -8330,10 +8350,15 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
{ {
a->need_transpose = false; a->need_transpose = false;
b->need_transpose = false; b->need_transpose = false;
emit_binary_op(ops[0], ops[1], ops[3], ops[2], "*"); auto expr = join(enclose_expression(to_unpacked_row_major_matrix_expression(ops[3])), " * ",
get<SPIRExpression>(ops[1]).need_transpose = true; enclose_expression(to_unpacked_row_major_matrix_expression(ops[2])));
bool forward = should_forward(ops[2]) && should_forward(ops[3]);
auto &e = emit_op(ops[0], ops[1], expr, forward);
e.need_transpose = true;
a->need_transpose = true; a->need_transpose = true;
b->need_transpose = true; b->need_transpose = true;
inherit_expression_dependencies(ops[1], ops[2]);
inherit_expression_dependencies(ops[1], ops[3]);
} }
else else
GLSL_BOP(*); GLSL_BOP(*);

View File

@ -511,6 +511,7 @@ protected:
std::string to_rerolled_array_expression(const std::string &expr, const SPIRType &type); std::string to_rerolled_array_expression(const std::string &expr, const SPIRType &type);
std::string to_enclosed_expression(uint32_t id, bool register_expression_read = true); std::string to_enclosed_expression(uint32_t id, bool register_expression_read = true);
std::string to_unpacked_expression(uint32_t id, bool register_expression_read = true); std::string to_unpacked_expression(uint32_t id, bool register_expression_read = true);
std::string to_unpacked_row_major_matrix_expression(uint32_t id);
std::string to_enclosed_unpacked_expression(uint32_t id, bool register_expression_read = true); std::string to_enclosed_unpacked_expression(uint32_t id, bool register_expression_read = true);
std::string to_dereferenced_expression(uint32_t id, bool register_expression_read = true); std::string to_dereferenced_expression(uint32_t id, bool register_expression_read = true);
std::string to_pointer_expression(uint32_t id, bool register_expression_read = true); std::string to_pointer_expression(uint32_t id, bool register_expression_read = true);

View File

@ -2718,7 +2718,7 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
// Direct copy, but might need to unpack RHS. // Direct copy, but might need to unpack RHS.
// Skip the transpose, as we will transpose when writing to LHS and transpose(transpose(T)) == T. // Skip the transpose, as we will transpose when writing to LHS and transpose(transpose(T)) == T.
rhs_e->need_transpose = false; rhs_e->need_transpose = false;
statement(to_expression(lhs_expression), " = ", unpack_expression_explicit(rhs_expression, true), ";"); statement(to_expression(lhs_expression), " = ", to_unpacked_row_major_matrix_expression(rhs_expression), ";");
rhs_e->need_transpose = true; rhs_e->need_transpose = true;
} }
else else
@ -2802,7 +2802,7 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
for (uint32_t i = 0; i < type.vecsize; i++) for (uint32_t i = 0; i < type.vecsize; i++)
{ {
statement(to_enclosed_expression(lhs_expression), statement(to_enclosed_expression(lhs_expression),
"[", i, "]", store_swiz, " = ", unpack_expression_explicit(rhs_expression, true), "[", i, "];"); "[", i, "]", store_swiz, " = ", to_unpacked_row_major_matrix_expression(rhs_expression), "[", i, "];");
} }
} }
else else
@ -2851,7 +2851,7 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
for (uint32_t j = 0; j < vector_type.vecsize; j++) for (uint32_t j = 0; j < vector_type.vecsize; j++)
{ {
// Need to explicitly unpack expression since we've mucked with transpose state. // Need to explicitly unpack expression since we've mucked with transpose state.
auto unpacked_expr = unpack_expression_explicit(rhs_expression, true); auto unpacked_expr = to_unpacked_row_major_matrix_expression(rhs_expression);
rhs_row += join(unpacked_expr, "[", j, "][", i, "]"); rhs_row += join(unpacked_expr, "[", j, "][", i, "]");
if (j + 1 < vector_type.vecsize) if (j + 1 < vector_type.vecsize)
rhs_row += ", "; rhs_row += ", ";
@ -2926,13 +2926,6 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
} }
} }
string CompilerMSL::unpack_expression_explicit(uint32_t id, bool row_major)
{
return unpack_expression_type(to_expression(id), expression_type(id),
get_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID),
has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked), row_major);
}
// Converts the format of the current expression from packed to unpacked, // Converts the format of the current expression from packed to unpacked,
// by wrapping the expression in a constructor of the appropriate type. // by wrapping the expression in a constructor of the appropriate type.
// Also, handle special physical ID remapping scenarios, similar to emit_store_statement(). // Also, handle special physical ID remapping scenarios, similar to emit_store_statement().
@ -4557,23 +4550,6 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
emit_barrier(ops[0], ops[1], ops[2]); emit_barrier(ops[0], ops[1], ops[2]);
break; break;
case OpVectorTimesMatrix:
case OpMatrixTimesVector:
{
// If the matrix needs transpose and it is square or packed, just flip the multiply order.
uint32_t mtx_id = ops[opcode == OpMatrixTimesVector ? 2 : 3];
auto *e = maybe_get<SPIRExpression>(mtx_id);
if (e && e->need_transpose)
{
e->need_transpose = false;
emit_binary_op(ops[0], ops[1], ops[3], ops[2], "*");
e->need_transpose = true;
}
else
MSL_BOP(*);
break;
}
case OpOuterProduct: case OpOuterProduct:
{ {
uint32_t result_type = ops[0]; uint32_t result_type = ops[0];

View File

@ -432,9 +432,6 @@ protected:
std::string to_initializer_expression(const SPIRVariable &var) override; std::string to_initializer_expression(const SPIRVariable &var) override;
std::string unpack_expression_type(std::string expr_str, const SPIRType &type, uint32_t physical_type_id, bool is_packed, bool row_major) override; std::string unpack_expression_type(std::string expr_str, const SPIRType &type, uint32_t physical_type_id, bool is_packed, bool row_major) override;
// Special purpose unpack which overrides transpose state, used internally only for matrix packing.
std::string unpack_expression_explicit(uint32_t id, bool row_major);
std::string bitcast_glsl_op(const SPIRType &result_type, const SPIRType &argument_type) override; std::string bitcast_glsl_op(const SPIRType &result_type, const SPIRType &argument_type) override;
bool skip_argument(uint32_t id) const override; bool skip_argument(uint32_t id) const override;
std::string to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain) override; std::string to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain) override;