SPIRV-Cross/reference/opt/shaders-hlsl/comp/inverse.comp
Hans-Kristian Arntzen b380a2113a Implement MatrixInverse on HLSL.
Copy-paste implementation from MSL. I assume it's correct.
2018-02-23 16:42:40 +01:00

123 lines
6.1 KiB
Plaintext

RWByteAddressBuffer _15 : register(u0);
ByteAddressBuffer _20 : register(t1);
// Returns the inverse of a matrix, by using the algorithm of calculating the classical
// adjoint and dividing by the determinant. The contents of the matrix are changed.
float2x2 SPIRV_Cross_Inverse(float2x2 m)
{
float2x2 adj; // The adjoint matrix (inverse after dividing by determinant)
// Create the transpose of the cofactors, as the classical adjoint of the matrix.
adj[0][0] = m[1][1];
adj[0][1] = -m[0][1];
adj[1][0] = -m[1][0];
adj[1][1] = m[0][0];
// Calculate the determinant as a combination of the cofactors of the first row.
float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);
// Divide the classical adjoint matrix by the determinant.
// If determinant is zero, matrix is not invertable, so leave it unchanged.
return (det != 0.0f) ? (adj * (1.0f / det)) : m;
}
// Returns the determinant of a 2x2 matrix.
float SPIRV_Cross_Det2x2(float a1, float a2, float b1, float b2)
{
return a1 * b2 - b1 * a2;
}
// Returns the inverse of a matrix, by using the algorithm of calculating the classical
// adjoint and dividing by the determinant. The contents of the matrix are changed.
float3x3 SPIRV_Cross_Inverse(float3x3 m)
{
float3x3 adj; // The adjoint matrix (inverse after dividing by determinant)
// Create the transpose of the cofactors, as the classical adjoint of the matrix.
adj[0][0] = SPIRV_Cross_Det2x2(m[1][1], m[1][2], m[2][1], m[2][2]);
adj[0][1] = -SPIRV_Cross_Det2x2(m[0][1], m[0][2], m[2][1], m[2][2]);
adj[0][2] = SPIRV_Cross_Det2x2(m[0][1], m[0][2], m[1][1], m[1][2]);
adj[1][0] = -SPIRV_Cross_Det2x2(m[1][0], m[1][2], m[2][0], m[2][2]);
adj[1][1] = SPIRV_Cross_Det2x2(m[0][0], m[0][2], m[2][0], m[2][2]);
adj[1][2] = -SPIRV_Cross_Det2x2(m[0][0], m[0][2], m[1][0], m[1][2]);
adj[2][0] = SPIRV_Cross_Det2x2(m[1][0], m[1][1], m[2][0], m[2][1]);
adj[2][1] = -SPIRV_Cross_Det2x2(m[0][0], m[0][1], m[2][0], m[2][1]);
adj[2][2] = SPIRV_Cross_Det2x2(m[0][0], m[0][1], m[1][0], m[1][1]);
// Calculate the determinant as a combination of the cofactors of the first row.
float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);
// Divide the classical adjoint matrix by the determinant.
// If determinant is zero, matrix is not invertable, so leave it unchanged.
return (det != 0.0f) ? (adj * (1.0f / det)) : m;
}
// Returns the determinant of a 3x3 matrix.
float SPIRV_Cross_Det3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, float c2, float c3)
{
return a1 * SPIRV_Cross_Det2x2(b2, b3, c2, c3) - b1 * SPIRV_Cross_Det2x2(a2, a3, c2, c3) + c1 * SPIRV_Cross_Det2x2(a2, a3, b2, b3);
}
// Returns the inverse of a matrix, by using the algorithm of calculating the classical
// adjoint and dividing by the determinant. The contents of the matrix are changed.
float4x4 SPIRV_Cross_Inverse(float4x4 m)
{
float4x4 adj; // The adjoint matrix (inverse after dividing by determinant)
// Create the transpose of the cofactors, as the classical adjoint of the matrix.
adj[0][0] = SPIRV_Cross_Det3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], m[3][3]);
adj[0][1] = -SPIRV_Cross_Det3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], m[3][3]);
adj[0][2] = SPIRV_Cross_Det3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], m[3][3]);
adj[0][3] = -SPIRV_Cross_Det3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3]);
adj[1][0] = -SPIRV_Cross_Det3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], m[3][3]);
adj[1][1] = SPIRV_Cross_Det3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], m[3][3]);
adj[1][2] = -SPIRV_Cross_Det3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], m[3][3]);
adj[1][3] = SPIRV_Cross_Det3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3]);
adj[2][0] = SPIRV_Cross_Det3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], m[3][3]);
adj[2][1] = -SPIRV_Cross_Det3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], m[3][3]);
adj[2][2] = SPIRV_Cross_Det3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], m[3][3]);
adj[2][3] = -SPIRV_Cross_Det3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3]);
adj[3][0] = -SPIRV_Cross_Det3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], m[3][2]);
adj[3][1] = SPIRV_Cross_Det3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], m[3][2]);
adj[3][2] = -SPIRV_Cross_Det3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], m[3][2]);
adj[3][3] = SPIRV_Cross_Det3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2]);
// Calculate the determinant as a combination of the cofactors of the first row.
float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] * m[3][0]);
// Divide the classical adjoint matrix by the determinant.
// If determinant is zero, matrix is not invertable, so leave it unchanged.
return (det != 0.0f) ? (adj * (1.0f / det)) : m;
}
void comp_main()
{
float2x2 _23 = asfloat(uint2x2(_20.Load2(0), _20.Load2(8)));
float2x2 _24 = SPIRV_Cross_Inverse(_23);
_15.Store2(0, asuint(_24[0]));
_15.Store2(8, asuint(_24[1]));
float3x3 _29 = asfloat(uint3x3(_20.Load3(16), _20.Load3(32), _20.Load3(48)));
float3x3 _30 = SPIRV_Cross_Inverse(_29);
_15.Store3(16, asuint(_30[0]));
_15.Store3(32, asuint(_30[1]));
_15.Store3(48, asuint(_30[2]));
float4x4 _35 = asfloat(uint4x4(_20.Load4(64), _20.Load4(80), _20.Load4(96), _20.Load4(112)));
float4x4 _36 = SPIRV_Cross_Inverse(_35);
_15.Store4(64, asuint(_36[0]));
_15.Store4(80, asuint(_36[1]));
_15.Store4(96, asuint(_36[2]));
_15.Store4(112, asuint(_36[3]));
}
[numthreads(1, 1, 1)]
void main()
{
comp_main();
}