Use to_unpacked_row_major_expression to unify row-major in MSL/GLSL.
This commit is contained in:
parent
47a18b9f1b
commit
7277c7ac46
@ -0,0 +1,21 @@
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct SSBO
|
||||
{
|
||||
float2x4 m0;
|
||||
float2x4 m1;
|
||||
float2 v0;
|
||||
float2 v1;
|
||||
};
|
||||
|
||||
kernel void main0(device SSBO& _11 [[buffer(0)]])
|
||||
{
|
||||
_11.v0 = _11.v1 * (float2x2(_11.m1[0].xy, _11.m1[1].xy) * float2x2(_11.m0[0].xy, _11.m0[1].xy));
|
||||
_11.v0 = (_11.v1 * float2x2(_11.m1[0].xy, _11.m1[1].xy)) * float2x2(_11.m0[0].xy, _11.m0[1].xy);
|
||||
_11.v0 = float2x2(_11.m1[0].xy, _11.m1[1].xy) * (float2x2(_11.m0[0].xy, _11.m0[1].xy) * _11.v1);
|
||||
_11.v0 = (float2x2(_11.m1[0].xy, _11.m1[1].xy) * float2x2(_11.m0[0].xy, _11.m0[1].xy)) * _11.v1;
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ struct main0_in
|
||||
vertex main0_out main0(main0_in in [[stage_in]], constant UBO& _18 [[buffer(0)]])
|
||||
{
|
||||
main0_out out = {};
|
||||
float2 v = float2x4(_18.uMVP[0], _18.uMVP[1]) * in.aVertex;
|
||||
float2 v = float4x2(_18.uMVP[0].xy, _18.uMVP[1].xy, _18.uMVP[2].xy, _18.uMVP[3].xy) * in.aVertex;
|
||||
out.gl_Position = (_18.uMVPR * in.aVertex) + (in.aVertex * _18.uMVPC);
|
||||
return out;
|
||||
}
|
||||
|
@ -0,0 +1,18 @@
|
||||
#version 450
|
||||
layout(local_size_x = 1) in;
|
||||
|
||||
layout(std140, row_major, set = 0, binding = 0) buffer SSBO
|
||||
{
|
||||
mat2 m0;
|
||||
mat2 m1;
|
||||
vec2 v0;
|
||||
vec2 v1;
|
||||
};
|
||||
|
||||
void main()
|
||||
{
|
||||
v0 = (m0 * m1) * v1;
|
||||
v0 = m0 * (m1 * v1);
|
||||
v0 = (v1 * m0) * m1;
|
||||
v0 = v1 * (m0 * m1);
|
||||
}
|
@ -2838,6 +2838,15 @@ string CompilerGLSL::to_enclosed_expression(uint32_t id, bool register_expressio
|
||||
return enclose_expression(to_expression(id, register_expression_read));
|
||||
}
|
||||
|
||||
// Used explicitly when we want to read a row-major expression, but without any transpose shenanigans.
|
||||
// need_transpose must be forced to false.
|
||||
string CompilerGLSL::to_unpacked_row_major_matrix_expression(uint32_t id)
|
||||
{
|
||||
return unpack_expression_type(to_expression(id), expression_type(id),
|
||||
get_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID),
|
||||
has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked), true);
|
||||
}
|
||||
|
||||
string CompilerGLSL::to_unpacked_expression(uint32_t id, bool register_expression_read)
|
||||
{
|
||||
// If we need to transpose, it will also take care of unpacking rules.
|
||||
@ -2845,6 +2854,7 @@ string CompilerGLSL::to_unpacked_expression(uint32_t id, bool register_expressio
|
||||
bool need_transpose = e && e->need_transpose;
|
||||
bool is_remapped = has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID);
|
||||
bool is_packed = has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked);
|
||||
|
||||
if (!need_transpose && (is_remapped || is_packed))
|
||||
{
|
||||
return unpack_expression_type(to_expression(id, register_expression_read), expression_type(id),
|
||||
@ -8311,8 +8321,18 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
|
||||
if (e && e->need_transpose)
|
||||
{
|
||||
e->need_transpose = false;
|
||||
emit_binary_op(ops[0], ops[1], ops[3], ops[2], "*");
|
||||
string expr;
|
||||
|
||||
if (opcode == OpMatrixTimesVector)
|
||||
expr = join(to_enclosed_expression(ops[3]), " * ", enclose_expression(to_unpacked_row_major_matrix_expression(ops[2])));
|
||||
else
|
||||
expr = join(enclose_expression(to_unpacked_row_major_matrix_expression(ops[3])), " * ", to_enclosed_expression(ops[2]));
|
||||
|
||||
bool forward = should_forward(ops[2]) && should_forward(ops[3]);
|
||||
emit_op(ops[0], ops[1], expr, forward);
|
||||
e->need_transpose = true;
|
||||
inherit_expression_dependencies(ops[1], ops[2]);
|
||||
inherit_expression_dependencies(ops[1], ops[3]);
|
||||
}
|
||||
else
|
||||
GLSL_BOP(*);
|
||||
@ -8330,10 +8350,15 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
|
||||
{
|
||||
a->need_transpose = false;
|
||||
b->need_transpose = false;
|
||||
emit_binary_op(ops[0], ops[1], ops[3], ops[2], "*");
|
||||
get<SPIRExpression>(ops[1]).need_transpose = true;
|
||||
auto expr = join(enclose_expression(to_unpacked_row_major_matrix_expression(ops[3])), " * ",
|
||||
enclose_expression(to_unpacked_row_major_matrix_expression(ops[2])));
|
||||
bool forward = should_forward(ops[2]) && should_forward(ops[3]);
|
||||
auto &e = emit_op(ops[0], ops[1], expr, forward);
|
||||
e.need_transpose = true;
|
||||
a->need_transpose = true;
|
||||
b->need_transpose = true;
|
||||
inherit_expression_dependencies(ops[1], ops[2]);
|
||||
inherit_expression_dependencies(ops[1], ops[3]);
|
||||
}
|
||||
else
|
||||
GLSL_BOP(*);
|
||||
|
@ -511,6 +511,7 @@ protected:
|
||||
std::string to_rerolled_array_expression(const std::string &expr, const SPIRType &type);
|
||||
std::string to_enclosed_expression(uint32_t id, bool register_expression_read = true);
|
||||
std::string to_unpacked_expression(uint32_t id, bool register_expression_read = true);
|
||||
std::string to_unpacked_row_major_matrix_expression(uint32_t id);
|
||||
std::string to_enclosed_unpacked_expression(uint32_t id, bool register_expression_read = true);
|
||||
std::string to_dereferenced_expression(uint32_t id, bool register_expression_read = true);
|
||||
std::string to_pointer_expression(uint32_t id, bool register_expression_read = true);
|
||||
|
@ -2718,7 +2718,7 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
|
||||
// Direct copy, but might need to unpack RHS.
|
||||
// Skip the transpose, as we will transpose when writing to LHS and transpose(transpose(T)) == T.
|
||||
rhs_e->need_transpose = false;
|
||||
statement(to_expression(lhs_expression), " = ", unpack_expression_explicit(rhs_expression, true), ";");
|
||||
statement(to_expression(lhs_expression), " = ", to_unpacked_row_major_matrix_expression(rhs_expression), ";");
|
||||
rhs_e->need_transpose = true;
|
||||
}
|
||||
else
|
||||
@ -2802,7 +2802,7 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
|
||||
for (uint32_t i = 0; i < type.vecsize; i++)
|
||||
{
|
||||
statement(to_enclosed_expression(lhs_expression),
|
||||
"[", i, "]", store_swiz, " = ", unpack_expression_explicit(rhs_expression, true), "[", i, "];");
|
||||
"[", i, "]", store_swiz, " = ", to_unpacked_row_major_matrix_expression(rhs_expression), "[", i, "];");
|
||||
}
|
||||
}
|
||||
else
|
||||
@ -2851,7 +2851,7 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
|
||||
for (uint32_t j = 0; j < vector_type.vecsize; j++)
|
||||
{
|
||||
// Need to explicitly unpack expression since we've mucked with transpose state.
|
||||
auto unpacked_expr = unpack_expression_explicit(rhs_expression, true);
|
||||
auto unpacked_expr = to_unpacked_row_major_matrix_expression(rhs_expression);
|
||||
rhs_row += join(unpacked_expr, "[", j, "][", i, "]");
|
||||
if (j + 1 < vector_type.vecsize)
|
||||
rhs_row += ", ";
|
||||
@ -2926,13 +2926,6 @@ void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_exp
|
||||
}
|
||||
}
|
||||
|
||||
string CompilerMSL::unpack_expression_explicit(uint32_t id, bool row_major)
|
||||
{
|
||||
return unpack_expression_type(to_expression(id), expression_type(id),
|
||||
get_extended_decoration(id, SPIRVCrossDecorationPhysicalTypeID),
|
||||
has_extended_decoration(id, SPIRVCrossDecorationPhysicalTypePacked), row_major);
|
||||
}
|
||||
|
||||
// Converts the format of the current expression from packed to unpacked,
|
||||
// by wrapping the expression in a constructor of the appropriate type.
|
||||
// Also, handle special physical ID remapping scenarios, similar to emit_store_statement().
|
||||
@ -4557,23 +4550,6 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
|
||||
emit_barrier(ops[0], ops[1], ops[2]);
|
||||
break;
|
||||
|
||||
case OpVectorTimesMatrix:
|
||||
case OpMatrixTimesVector:
|
||||
{
|
||||
// If the matrix needs transpose and it is square or packed, just flip the multiply order.
|
||||
uint32_t mtx_id = ops[opcode == OpMatrixTimesVector ? 2 : 3];
|
||||
auto *e = maybe_get<SPIRExpression>(mtx_id);
|
||||
if (e && e->need_transpose)
|
||||
{
|
||||
e->need_transpose = false;
|
||||
emit_binary_op(ops[0], ops[1], ops[3], ops[2], "*");
|
||||
e->need_transpose = true;
|
||||
}
|
||||
else
|
||||
MSL_BOP(*);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpOuterProduct:
|
||||
{
|
||||
uint32_t result_type = ops[0];
|
||||
|
@ -432,9 +432,6 @@ protected:
|
||||
std::string to_initializer_expression(const SPIRVariable &var) override;
|
||||
std::string unpack_expression_type(std::string expr_str, const SPIRType &type, uint32_t physical_type_id, bool is_packed, bool row_major) override;
|
||||
|
||||
// Special purpose unpack which overrides transpose state, used internally only for matrix packing.
|
||||
std::string unpack_expression_explicit(uint32_t id, bool row_major);
|
||||
|
||||
std::string bitcast_glsl_op(const SPIRType &result_type, const SPIRType &argument_type) override;
|
||||
bool skip_argument(uint32_t id) const override;
|
||||
std::string to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain) override;
|
||||
|
Loading…
Reference in New Issue
Block a user