Merge pull request #1051 from KhronosGroup/fix-1049

MSL/HLSL: Support OpOuterProduct.
This commit is contained in:
Hans-Kristian Arntzen 2019-07-01 13:40:57 +02:00 committed by GitHub
commit 41399fc899
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 404 additions and 1 deletions

View File

@ -0,0 +1,48 @@
RWByteAddressBuffer _21 : register(u0);
ByteAddressBuffer _26 : register(t1);
void comp_main()
{
float2x2 _32 = float2x2(asfloat(_26.Load2(0)) * asfloat(_26.Load2(0)).x, asfloat(_26.Load2(0)) * asfloat(_26.Load2(0)).y);
_21.Store2(0, asuint(_32[0]));
_21.Store2(8, asuint(_32[1]));
float2x3 _41 = float2x3(asfloat(_26.Load3(16)) * asfloat(_26.Load2(0)).x, asfloat(_26.Load3(16)) * asfloat(_26.Load2(0)).y);
_21.Store3(16, asuint(_41[0]));
_21.Store3(32, asuint(_41[1]));
float2x4 _50 = float2x4(asfloat(_26.Load4(32)) * asfloat(_26.Load2(0)).x, asfloat(_26.Load4(32)) * asfloat(_26.Load2(0)).y);
_21.Store4(48, asuint(_50[0]));
_21.Store4(64, asuint(_50[1]));
float3x2 _58 = float3x2(asfloat(_26.Load2(0)) * asfloat(_26.Load3(16)).x, asfloat(_26.Load2(0)) * asfloat(_26.Load3(16)).y, asfloat(_26.Load2(0)) * asfloat(_26.Load3(16)).z);
_21.Store2(80, asuint(_58[0]));
_21.Store2(88, asuint(_58[1]));
_21.Store2(96, asuint(_58[2]));
float3x3 _66 = float3x3(asfloat(_26.Load3(16)) * asfloat(_26.Load3(16)).x, asfloat(_26.Load3(16)) * asfloat(_26.Load3(16)).y, asfloat(_26.Load3(16)) * asfloat(_26.Load3(16)).z);
_21.Store3(112, asuint(_66[0]));
_21.Store3(128, asuint(_66[1]));
_21.Store3(144, asuint(_66[2]));
float3x4 _74 = float3x4(asfloat(_26.Load4(32)) * asfloat(_26.Load3(16)).x, asfloat(_26.Load4(32)) * asfloat(_26.Load3(16)).y, asfloat(_26.Load4(32)) * asfloat(_26.Load3(16)).z);
_21.Store4(160, asuint(_74[0]));
_21.Store4(176, asuint(_74[1]));
_21.Store4(192, asuint(_74[2]));
float4x2 _82 = float4x2(asfloat(_26.Load2(0)) * asfloat(_26.Load4(32)).x, asfloat(_26.Load2(0)) * asfloat(_26.Load4(32)).y, asfloat(_26.Load2(0)) * asfloat(_26.Load4(32)).z, asfloat(_26.Load2(0)) * asfloat(_26.Load4(32)).w);
_21.Store2(208, asuint(_82[0]));
_21.Store2(216, asuint(_82[1]));
_21.Store2(224, asuint(_82[2]));
_21.Store2(232, asuint(_82[3]));
float4x3 _90 = float4x3(asfloat(_26.Load3(16)) * asfloat(_26.Load4(32)).x, asfloat(_26.Load3(16)) * asfloat(_26.Load4(32)).y, asfloat(_26.Load3(16)) * asfloat(_26.Load4(32)).z, asfloat(_26.Load3(16)) * asfloat(_26.Load4(32)).w);
_21.Store3(240, asuint(_90[0]));
_21.Store3(256, asuint(_90[1]));
_21.Store3(272, asuint(_90[2]));
_21.Store3(288, asuint(_90[3]));
float4x4 _98 = float4x4(asfloat(_26.Load4(32)) * asfloat(_26.Load4(32)).x, asfloat(_26.Load4(32)) * asfloat(_26.Load4(32)).y, asfloat(_26.Load4(32)) * asfloat(_26.Load4(32)).z, asfloat(_26.Load4(32)) * asfloat(_26.Load4(32)).w);
_21.Store4(304, asuint(_98[0]));
_21.Store4(320, asuint(_98[1]));
_21.Store4(336, asuint(_98[2]));
_21.Store4(352, asuint(_98[3]));
}
[numthreads(1, 1, 1)]
void main()
{
comp_main();
}

View File

@ -0,0 +1,38 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct SSBO
{
float2x2 m22;
float2x3 m23;
float2x4 m24;
float3x2 m32;
float3x3 m33;
float3x4 m34;
float4x2 m42;
float4x3 m43;
float4x4 m44;
};
struct ReadSSBO
{
float2 v2;
float3 v3;
float4 v4;
};
kernel void main0(device SSBO& _21 [[buffer(0)]], const device ReadSSBO& _26 [[buffer(1)]])
{
_21.m22 = float2x2(_26.v2 * _26.v2.x, _26.v2 * _26.v2.y);
_21.m23 = float2x3(_26.v3 * _26.v2.x, _26.v3 * _26.v2.y);
_21.m24 = float2x4(_26.v4 * _26.v2.x, _26.v4 * _26.v2.y);
_21.m32 = float3x2(_26.v2 * _26.v3.x, _26.v2 * _26.v3.y, _26.v2 * _26.v3.z);
_21.m33 = float3x3(_26.v3 * _26.v3.x, _26.v3 * _26.v3.y, _26.v3 * _26.v3.z);
_21.m34 = float3x4(_26.v4 * _26.v3.x, _26.v4 * _26.v3.y, _26.v4 * _26.v3.z);
_21.m42 = float4x2(_26.v2 * _26.v4.x, _26.v2 * _26.v4.y, _26.v2 * _26.v4.z, _26.v2 * _26.v4.w);
_21.m43 = float4x3(_26.v3 * _26.v4.x, _26.v3 * _26.v4.y, _26.v3 * _26.v4.z, _26.v3 * _26.v4.w);
_21.m44 = float4x4(_26.v4 * _26.v4.x, _26.v4 * _26.v4.y, _26.v4 * _26.v4.z, _26.v4 * _26.v4.w);
}

View File

@ -0,0 +1,36 @@
#version 450
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(binding = 0, std430) writeonly buffer SSBO
{
mat2 m22;
mat2x3 m23;
mat2x4 m24;
mat3x2 m32;
mat3 m33;
mat3x4 m34;
mat4x2 m42;
mat4x3 m43;
mat4 m44;
} _21;
layout(binding = 1, std430) readonly buffer ReadSSBO
{
vec2 v2;
vec3 v3;
vec4 v4;
} _26;
void main()
{
_21.m22 = outerProduct(_26.v2, _26.v2);
_21.m23 = outerProduct(_26.v3, _26.v2);
_21.m24 = outerProduct(_26.v4, _26.v2);
_21.m32 = outerProduct(_26.v2, _26.v3);
_21.m33 = outerProduct(_26.v3, _26.v3);
_21.m34 = outerProduct(_26.v4, _26.v3);
_21.m42 = outerProduct(_26.v2, _26.v4);
_21.m43 = outerProduct(_26.v3, _26.v4);
_21.m44 = outerProduct(_26.v4, _26.v4);
}

View File

@ -0,0 +1,48 @@
RWByteAddressBuffer _21 : register(u0);
ByteAddressBuffer _26 : register(t1);
void comp_main()
{
float2x2 _32 = float2x2(asfloat(_26.Load2(0)) * asfloat(_26.Load2(0)).x, asfloat(_26.Load2(0)) * asfloat(_26.Load2(0)).y);
_21.Store2(0, asuint(_32[0]));
_21.Store2(8, asuint(_32[1]));
float2x3 _41 = float2x3(asfloat(_26.Load3(16)) * asfloat(_26.Load2(0)).x, asfloat(_26.Load3(16)) * asfloat(_26.Load2(0)).y);
_21.Store3(16, asuint(_41[0]));
_21.Store3(32, asuint(_41[1]));
float2x4 _50 = float2x4(asfloat(_26.Load4(32)) * asfloat(_26.Load2(0)).x, asfloat(_26.Load4(32)) * asfloat(_26.Load2(0)).y);
_21.Store4(48, asuint(_50[0]));
_21.Store4(64, asuint(_50[1]));
float3x2 _58 = float3x2(asfloat(_26.Load2(0)) * asfloat(_26.Load3(16)).x, asfloat(_26.Load2(0)) * asfloat(_26.Load3(16)).y, asfloat(_26.Load2(0)) * asfloat(_26.Load3(16)).z);
_21.Store2(80, asuint(_58[0]));
_21.Store2(88, asuint(_58[1]));
_21.Store2(96, asuint(_58[2]));
float3x3 _66 = float3x3(asfloat(_26.Load3(16)) * asfloat(_26.Load3(16)).x, asfloat(_26.Load3(16)) * asfloat(_26.Load3(16)).y, asfloat(_26.Load3(16)) * asfloat(_26.Load3(16)).z);
_21.Store3(112, asuint(_66[0]));
_21.Store3(128, asuint(_66[1]));
_21.Store3(144, asuint(_66[2]));
float3x4 _74 = float3x4(asfloat(_26.Load4(32)) * asfloat(_26.Load3(16)).x, asfloat(_26.Load4(32)) * asfloat(_26.Load3(16)).y, asfloat(_26.Load4(32)) * asfloat(_26.Load3(16)).z);
_21.Store4(160, asuint(_74[0]));
_21.Store4(176, asuint(_74[1]));
_21.Store4(192, asuint(_74[2]));
float4x2 _82 = float4x2(asfloat(_26.Load2(0)) * asfloat(_26.Load4(32)).x, asfloat(_26.Load2(0)) * asfloat(_26.Load4(32)).y, asfloat(_26.Load2(0)) * asfloat(_26.Load4(32)).z, asfloat(_26.Load2(0)) * asfloat(_26.Load4(32)).w);
_21.Store2(208, asuint(_82[0]));
_21.Store2(216, asuint(_82[1]));
_21.Store2(224, asuint(_82[2]));
_21.Store2(232, asuint(_82[3]));
float4x3 _90 = float4x3(asfloat(_26.Load3(16)) * asfloat(_26.Load4(32)).x, asfloat(_26.Load3(16)) * asfloat(_26.Load4(32)).y, asfloat(_26.Load3(16)) * asfloat(_26.Load4(32)).z, asfloat(_26.Load3(16)) * asfloat(_26.Load4(32)).w);
_21.Store3(240, asuint(_90[0]));
_21.Store3(256, asuint(_90[1]));
_21.Store3(272, asuint(_90[2]));
_21.Store3(288, asuint(_90[3]));
float4x4 _98 = float4x4(asfloat(_26.Load4(32)) * asfloat(_26.Load4(32)).x, asfloat(_26.Load4(32)) * asfloat(_26.Load4(32)).y, asfloat(_26.Load4(32)) * asfloat(_26.Load4(32)).z, asfloat(_26.Load4(32)) * asfloat(_26.Load4(32)).w);
_21.Store4(304, asuint(_98[0]));
_21.Store4(320, asuint(_98[1]));
_21.Store4(336, asuint(_98[2]));
_21.Store4(352, asuint(_98[3]));
}
[numthreads(1, 1, 1)]
void main()
{
comp_main();
}

View File

@ -0,0 +1,38 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct SSBO
{
float2x2 m22;
float2x3 m23;
float2x4 m24;
float3x2 m32;
float3x3 m33;
float3x4 m34;
float4x2 m42;
float4x3 m43;
float4x4 m44;
};
struct ReadSSBO
{
float2 v2;
float3 v3;
float4 v4;
};
kernel void main0(device SSBO& _21 [[buffer(0)]], const device ReadSSBO& _26 [[buffer(1)]])
{
_21.m22 = float2x2(_26.v2 * _26.v2.x, _26.v2 * _26.v2.y);
_21.m23 = float2x3(_26.v3 * _26.v2.x, _26.v3 * _26.v2.y);
_21.m24 = float2x4(_26.v4 * _26.v2.x, _26.v4 * _26.v2.y);
_21.m32 = float3x2(_26.v2 * _26.v3.x, _26.v2 * _26.v3.y, _26.v2 * _26.v3.z);
_21.m33 = float3x3(_26.v3 * _26.v3.x, _26.v3 * _26.v3.y, _26.v3 * _26.v3.z);
_21.m34 = float3x4(_26.v4 * _26.v3.x, _26.v4 * _26.v3.y, _26.v4 * _26.v3.z);
_21.m42 = float4x2(_26.v2 * _26.v4.x, _26.v2 * _26.v4.y, _26.v2 * _26.v4.z, _26.v2 * _26.v4.w);
_21.m43 = float4x3(_26.v3 * _26.v4.x, _26.v3 * _26.v4.y, _26.v3 * _26.v4.z, _26.v3 * _26.v4.w);
_21.m44 = float4x4(_26.v4 * _26.v4.x, _26.v4 * _26.v4.y, _26.v4 * _26.v4.z, _26.v4 * _26.v4.w);
}

View File

@ -0,0 +1,36 @@
#version 450
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(binding = 0, std430) writeonly buffer SSBO
{
mat2 m22;
mat2x3 m23;
mat2x4 m24;
mat3x2 m32;
mat3 m33;
mat3x4 m34;
mat4x2 m42;
mat4x3 m43;
mat4 m44;
} _21;
layout(binding = 1, std430) readonly buffer ReadSSBO
{
vec2 v2;
vec3 v3;
vec4 v4;
} _26;
void main()
{
_21.m22 = outerProduct(_26.v2, _26.v2);
_21.m23 = outerProduct(_26.v3, _26.v2);
_21.m24 = outerProduct(_26.v4, _26.v2);
_21.m32 = outerProduct(_26.v2, _26.v3);
_21.m33 = outerProduct(_26.v3, _26.v3);
_21.m34 = outerProduct(_26.v4, _26.v3);
_21.m42 = outerProduct(_26.v2, _26.v4);
_21.m43 = outerProduct(_26.v3, _26.v4);
_21.m44 = outerProduct(_26.v4, _26.v4);
}

View File

@ -0,0 +1,37 @@
#version 450
layout(local_size_x = 1) in;
layout(set = 0, binding = 0, std430) writeonly buffer SSBO
{
mat2 m22;
mat2x3 m23;
mat2x4 m24;
mat3x2 m32;
mat3 m33;
mat3x4 m34;
mat4x2 m42;
mat4x3 m43;
mat4 m44;
};
layout(set = 0, binding = 1, std430) readonly buffer ReadSSBO
{
vec2 v2;
vec3 v3;
vec4 v4;
};
void main()
{
m22 = outerProduct(v2, v2);
m23 = outerProduct(v3, v2);
m24 = outerProduct(v4, v2);
m32 = outerProduct(v2, v3);
m33 = outerProduct(v3, v3);
m34 = outerProduct(v4, v3);
m42 = outerProduct(v2, v4);
m43 = outerProduct(v3, v4);
m44 = outerProduct(v4, v4);
}

View File

@ -0,0 +1,37 @@
#version 450
layout(local_size_x = 1) in;
layout(set = 0, binding = 0, std430) writeonly buffer SSBO
{
mat2 m22;
mat2x3 m23;
mat2x4 m24;
mat3x2 m32;
mat3 m33;
mat3x4 m34;
mat4x2 m42;
mat4x3 m43;
mat4 m44;
};
layout(set = 0, binding = 1, std430) readonly buffer ReadSSBO
{
vec2 v2;
vec3 v3;
vec4 v4;
};
void main()
{
m22 = outerProduct(v2, v2);
m23 = outerProduct(v3, v2);
m24 = outerProduct(v4, v2);
m32 = outerProduct(v2, v3);
m33 = outerProduct(v3, v3);
m34 = outerProduct(v4, v3);
m42 = outerProduct(v2, v4);
m43 = outerProduct(v3, v4);
m44 = outerProduct(v4, v4);
}

View File

@ -0,0 +1,37 @@
#version 450
layout(local_size_x = 1) in;
layout(set = 0, binding = 0, std430) writeonly buffer SSBO
{
mat2 m22;
mat2x3 m23;
mat2x4 m24;
mat3x2 m32;
mat3 m33;
mat3x4 m34;
mat4x2 m42;
mat4x3 m43;
mat4 m44;
};
layout(set = 0, binding = 1, std430) readonly buffer ReadSSBO
{
vec2 v2;
vec3 v3;
vec4 v4;
};
void main()
{
m22 = outerProduct(v2, v2);
m23 = outerProduct(v3, v2);
m24 = outerProduct(v4, v2);
m32 = outerProduct(v2, v3);
m33 = outerProduct(v3, v3);
m34 = outerProduct(v4, v3);
m42 = outerProduct(v2, v4);
m43 = outerProduct(v3, v4);
m44 = outerProduct(v4, v4);
}

View File

@ -3965,6 +3965,31 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
break; break;
} }
case OpOuterProduct:
{
uint32_t result_type = ops[0];
uint32_t id = ops[1];
uint32_t a = ops[2];
uint32_t b = ops[3];
auto &type = get<SPIRType>(result_type);
string expr = type_to_glsl_constructor(type);
expr += "(";
for (uint32_t col = 0; col < type.columns; col++)
{
expr += to_enclosed_expression(a);
expr += " * ";
expr += to_extract_component_expression(b, col);
if (col + 1 < type.columns)
expr += ", ";
}
expr += ")";
emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
inherit_expression_dependencies(id, a);
inherit_expression_dependencies(id, b);
break;
}
case OpFMod: case OpFMod:
{ {
if (!requires_op_fmod) if (!requires_op_fmod)

View File

@ -4163,7 +4163,30 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
break; break;
} }
// OpOuterProduct case OpOuterProduct:
{
uint32_t result_type = ops[0];
uint32_t id = ops[1];
uint32_t a = ops[2];
uint32_t b = ops[3];
auto &type = get<SPIRType>(result_type);
string expr = type_to_glsl_constructor(type);
expr += "(";
for (uint32_t col = 0; col < type.columns; col++)
{
expr += to_enclosed_expression(a);
expr += " * ";
expr += to_extract_component_expression(b, col);
if (col + 1 < type.columns)
expr += ", ";
}
expr += ")";
emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
inherit_expression_dependencies(id, a);
inherit_expression_dependencies(id, b);
break;
}
case OpIAddCarry: case OpIAddCarry:
case OpISubBorrow: case OpISubBorrow: