MSL: Support extended arithmetic opcodes.

This commit is contained in:
Chip Davis 2018-11-13 14:11:50 -06:00
parent 1adaaba74e
commit cf2a890e4f
4 changed files with 461 additions and 0 deletions

View File

@ -0,0 +1,177 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Unsigned-integer buffer bound to main0 as `u` at [[buffer(0)]].
// For each width (scalar, 2-, 3- and 4-component), main0 reads the a*/b*
// members as operands and stores its results into the c*/d* members.
struct SSBOUint
{
uint a;
uint b;
uint c;
uint d;
uint2 a2;
uint2 b2;
uint2 c2;
uint2 d2;
uint3 a3;
uint3 b3;
uint3 c3;
uint3 d3;
uint4 a4;
uint4 b4;
uint4 c4;
uint4 d4;
};
// Two-member scalar uint temporary used by main0 for operations that
// produce a pair of results (_m0 and _m1).
struct ResType
{
uint _m0;
uint _m1;
};
// uint2 counterpart of ResType: a pair of two-component results.
struct ResType_1
{
uint2 _m0;
uint2 _m1;
};
// uint3 counterpart of ResType: a pair of three-component results.
struct ResType_2
{
uint3 _m0;
uint3 _m1;
};
// uint4 counterpart of ResType: a pair of four-component results.
struct ResType_3
{
uint4 _m0;
uint4 _m1;
};
// Signed-integer buffer bound to main0 as `i` at [[buffer(1)]].
// main0 reads the a*/b* members as multiplication operands and stores the
// results into the c*/d* members.
struct SSBOInt
{
int a;
int b;
int c;
int d;
int2 a2;
int2 b2;
int2 c2;
int2 d2;
int3 a3;
int3 b3;
int3 c3;
int3 d3;
int4 a4;
int4 b4;
int4 c4;
int4 d4;
};
// Two-member scalar int temporary used by main0 for the signed multiply:
// _m0 holds the product, _m1 the mulhi() result.
struct ResType_4
{
int _m0;
int _m1;
};
// int2 counterpart of ResType_4.
struct ResType_5
{
int2 _m0;
int2 _m1;
};
// int3 counterpart of ResType_4.
struct ResType_6
{
int3 _m0;
int3 _m1;
};
// int4 counterpart of ResType_4.
struct ResType_7
{
int4 _m0;
int4 _m1;
};
// Compute kernel exercising add-with-carry, subtract-with-borrow and extended
// multiplication on scalar and 2/3/4-component integers.
//   u: unsigned buffer; a*/b* are the operands, c*/d* receive the results.
//   i: signed buffer; only used by the signed extended multiplies.
// Each section overwrites the c*/d* values stored by the previous one.
kernel void main0(device SSBOUint& u [[buffer(0)]], device SSBOInt& i [[buffer(1)]])
{
    // Add with carry: c = wrapped sum, d = carry flag.
    // select(x, y, cond) yields y where cond is true, so the condition spells
    // out the "no carry" case: an unsigned sum that did not wrap is >= both
    // operands, while a wrapped sum is smaller than each of them.
    ResType _25;
    _25._m0 = u.a + u.b;
    _25._m1 = select(uint(1), uint(0), _25._m0 >= max(u.a, u.b));
    u.d = _25._m1;
    u.c = _25._m0;
    ResType_1 _40;
    _40._m0 = u.a2 + u.b2;
    _40._m1 = select(uint2(1), uint2(0), _40._m0 >= max(u.a2, u.b2));
    u.d2 = _40._m1;
    u.c2 = _40._m0;
    ResType_2 _55;
    _55._m0 = u.a3 + u.b3;
    _55._m1 = select(uint3(1), uint3(0), _55._m0 >= max(u.a3, u.b3));
    u.d3 = _55._m1;
    u.c3 = _55._m0;
    ResType_3 _70;
    _70._m0 = u.a4 + u.b4;
    _70._m1 = select(uint4(1), uint4(0), _70._m0 >= max(u.a4, u.b4));
    u.d4 = _70._m1;
    u.c4 = _70._m0;

    // Subtract with borrow: c = wrapped difference, d = borrow flag.
    // No borrow is needed exactly when the minuend is >= the subtrahend.
    ResType _79;
    _79._m0 = u.a - u.b;
    _79._m1 = select(uint(1), uint(0), u.a >= u.b);
    u.d = _79._m1;
    u.c = _79._m0;
    ResType_1 _88;
    _88._m0 = u.a2 - u.b2;
    _88._m1 = select(uint2(1), uint2(0), u.a2 >= u.b2);
    u.d2 = _88._m1;
    u.c2 = _88._m0;
    ResType_2 _97;
    _97._m0 = u.a3 - u.b3;
    _97._m1 = select(uint3(1), uint3(0), u.a3 >= u.b3);
    u.d3 = _97._m1;
    u.c3 = _97._m0;
    ResType_3 _106;
    _106._m0 = u.a4 - u.b4;
    _106._m1 = select(uint4(1), uint4(0), u.a4 >= u.b4);
    u.d4 = _106._m1;
    u.c4 = _106._m0;

    // Unsigned extended multiply: d = low half (plain product),
    // c = high half (mulhi).
    ResType _116;
    _116._m0 = u.a * u.b;
    _116._m1 = mulhi(u.a, u.b);
    u.d = _116._m0;
    u.c = _116._m1;
    ResType_1 _125;
    _125._m0 = u.a2 * u.b2;
    _125._m1 = mulhi(u.a2, u.b2);
    u.d2 = _125._m0;
    u.c2 = _125._m1;
    ResType_2 _134;
    _134._m0 = u.a3 * u.b3;
    _134._m1 = mulhi(u.a3, u.b3);
    u.d3 = _134._m0;
    u.c3 = _134._m1;
    ResType_3 _143;
    _143._m0 = u.a4 * u.b4;
    _143._m1 = mulhi(u.a4, u.b4);
    u.d4 = _143._m0;
    u.c4 = _143._m1;

    // Signed extended multiply, same layout as the unsigned case.
    ResType_4 _160;
    _160._m0 = i.a * i.b;
    _160._m1 = mulhi(i.a, i.b);
    i.d = _160._m0;
    i.c = _160._m1;
    ResType_5 _171;
    _171._m0 = i.a2 * i.b2;
    _171._m1 = mulhi(i.a2, i.b2);
    i.d2 = _171._m0;
    i.c2 = _171._m1;
    ResType_6 _182;
    _182._m0 = i.a3 * i.b3;
    _182._m1 = mulhi(i.a3, i.b3);
    i.d3 = _182._m0;
    i.c3 = _182._m1;
    ResType_7 _193;
    _193._m0 = i.a4 * i.b4;
    _193._m1 = mulhi(i.a4, i.b4);
    i.d4 = _193._m0;
    i.c4 = _193._m1;
}

View File

@ -0,0 +1,177 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Unsigned-integer buffer bound to main0 as `u` at [[buffer(0)]].
// For each width (scalar, 2-, 3- and 4-component), main0 reads the a*/b*
// members as operands and stores its results into the c*/d* members.
struct SSBOUint
{
uint a;
uint b;
uint c;
uint d;
uint2 a2;
uint2 b2;
uint2 c2;
uint2 d2;
uint3 a3;
uint3 b3;
uint3 c3;
uint3 d3;
uint4 a4;
uint4 b4;
uint4 c4;
uint4 d4;
};
// Two-member scalar uint temporary used by main0 for operations that
// produce a pair of results (_m0 and _m1).
struct ResType
{
uint _m0;
uint _m1;
};
// uint2 counterpart of ResType: a pair of two-component results.
struct ResType_1
{
uint2 _m0;
uint2 _m1;
};
// uint3 counterpart of ResType: a pair of three-component results.
struct ResType_2
{
uint3 _m0;
uint3 _m1;
};
// uint4 counterpart of ResType: a pair of four-component results.
struct ResType_3
{
uint4 _m0;
uint4 _m1;
};
// Signed-integer buffer bound to main0 as `i` at [[buffer(1)]].
// main0 reads the a*/b* members as multiplication operands and stores the
// results into the c*/d* members.
struct SSBOInt
{
int a;
int b;
int c;
int d;
int2 a2;
int2 b2;
int2 c2;
int2 d2;
int3 a3;
int3 b3;
int3 c3;
int3 d3;
int4 a4;
int4 b4;
int4 c4;
int4 d4;
};
// Two-member scalar int temporary used by main0 for the signed multiply:
// _m0 holds the product, _m1 the mulhi() result.
struct ResType_4
{
int _m0;
int _m1;
};
// int2 counterpart of ResType_4.
struct ResType_5
{
int2 _m0;
int2 _m1;
};
// int3 counterpart of ResType_4.
struct ResType_6
{
int3 _m0;
int3 _m1;
};
// int4 counterpart of ResType_4.
struct ResType_7
{
int4 _m0;
int4 _m1;
};
// Compute kernel exercising add-with-carry, subtract-with-borrow and extended
// multiplication on scalar and 2/3/4-component integers.
//   u: unsigned buffer; a*/b* are the operands, c*/d* receive the results.
//   i: signed buffer; only used by the signed extended multiplies.
// Each section overwrites the c*/d* values stored by the previous one.
kernel void main0(device SSBOUint& u [[buffer(0)]], device SSBOInt& i [[buffer(1)]])
{
    // Add with carry: c = wrapped sum, d = carry flag.
    // select(x, y, cond) yields y where cond is true, so the condition spells
    // out the "no carry" case: an unsigned sum that did not wrap is >= both
    // operands, while a wrapped sum is smaller than each of them.
    ResType _25;
    _25._m0 = u.a + u.b;
    _25._m1 = select(uint(1), uint(0), _25._m0 >= max(u.a, u.b));
    u.d = _25._m1;
    u.c = _25._m0;
    ResType_1 _40;
    _40._m0 = u.a2 + u.b2;
    _40._m1 = select(uint2(1), uint2(0), _40._m0 >= max(u.a2, u.b2));
    u.d2 = _40._m1;
    u.c2 = _40._m0;
    ResType_2 _55;
    _55._m0 = u.a3 + u.b3;
    _55._m1 = select(uint3(1), uint3(0), _55._m0 >= max(u.a3, u.b3));
    u.d3 = _55._m1;
    u.c3 = _55._m0;
    ResType_3 _70;
    _70._m0 = u.a4 + u.b4;
    _70._m1 = select(uint4(1), uint4(0), _70._m0 >= max(u.a4, u.b4));
    u.d4 = _70._m1;
    u.c4 = _70._m0;

    // Subtract with borrow: c = wrapped difference, d = borrow flag.
    // No borrow is needed exactly when the minuend is >= the subtrahend.
    ResType _79;
    _79._m0 = u.a - u.b;
    _79._m1 = select(uint(1), uint(0), u.a >= u.b);
    u.d = _79._m1;
    u.c = _79._m0;
    ResType_1 _88;
    _88._m0 = u.a2 - u.b2;
    _88._m1 = select(uint2(1), uint2(0), u.a2 >= u.b2);
    u.d2 = _88._m1;
    u.c2 = _88._m0;
    ResType_2 _97;
    _97._m0 = u.a3 - u.b3;
    _97._m1 = select(uint3(1), uint3(0), u.a3 >= u.b3);
    u.d3 = _97._m1;
    u.c3 = _97._m0;
    ResType_3 _106;
    _106._m0 = u.a4 - u.b4;
    _106._m1 = select(uint4(1), uint4(0), u.a4 >= u.b4);
    u.d4 = _106._m1;
    u.c4 = _106._m0;

    // Unsigned extended multiply: d = low half (plain product),
    // c = high half (mulhi).
    ResType _116;
    _116._m0 = u.a * u.b;
    _116._m1 = mulhi(u.a, u.b);
    u.d = _116._m0;
    u.c = _116._m1;
    ResType_1 _125;
    _125._m0 = u.a2 * u.b2;
    _125._m1 = mulhi(u.a2, u.b2);
    u.d2 = _125._m0;
    u.c2 = _125._m1;
    ResType_2 _134;
    _134._m0 = u.a3 * u.b3;
    _134._m1 = mulhi(u.a3, u.b3);
    u.d3 = _134._m0;
    u.c3 = _134._m1;
    ResType_3 _143;
    _143._m0 = u.a4 * u.b4;
    _143._m1 = mulhi(u.a4, u.b4);
    u.d4 = _143._m0;
    u.c4 = _143._m1;

    // Signed extended multiply, same layout as the unsigned case.
    ResType_4 _160;
    _160._m0 = i.a * i.b;
    _160._m1 = mulhi(i.a, i.b);
    i.d = _160._m0;
    i.c = _160._m1;
    ResType_5 _171;
    _171._m0 = i.a2 * i.b2;
    _171._m1 = mulhi(i.a2, i.b2);
    i.d2 = _171._m0;
    i.c2 = _171._m1;
    ResType_6 _182;
    _182._m0 = i.a3 * i.b3;
    _182._m1 = mulhi(i.a3, i.b3);
    i.d3 = _182._m0;
    i.c3 = _182._m1;
    ResType_7 _193;
    _193._m0 = i.a4 * i.b4;
    _193._m1 = mulhi(i.a4, i.b4);
    i.d4 = _193._m0;
    i.c4 = _193._m1;
}

View File

@ -0,0 +1,41 @@
#version 450
layout(local_size_x = 1) in;
// Unsigned operands and results, accessed through the instance name `u`.
// For each width, main() passes a*/b* as inputs and c*/d* as outputs.
layout(binding = 0, std430) buffer SSBOUint
{
uint a, b, c, d;
uvec2 a2, b2, c2, d2;
uvec3 a3, b3, c3, d3;
uvec4 a4, b4, c4, d4;
} u;
// Signed operands and results, accessed through the instance name `i`.
// main() only uses this block for the imulExtended() calls.
layout(binding = 1, std430) buffer SSBOInt
{
int a, b, c, d;
ivec2 a2, b2, c2, d2;
ivec3 a3, b3, c3, d3;
ivec4 a4, b4, c4, d4;
} i;
// Calls each extended-arithmetic built-in once per vector width (1 to 4).
// Later calls overwrite the c*/d* values written by earlier ones.
void main()
{
// Add with carry: the sum is returned into c*, the carry is written to d*.
u.c = uaddCarry(u.a, u.b, u.d);
u.c2 = uaddCarry(u.a2, u.b2, u.d2);
u.c3 = uaddCarry(u.a3, u.b3, u.d3);
u.c4 = uaddCarry(u.a4, u.b4, u.d4);
// Subtract with borrow: the difference is returned into c*, the borrow is written to d*.
u.c = usubBorrow(u.a, u.b, u.d);
u.c2 = usubBorrow(u.a2, u.b2, u.d2);
u.c3 = usubBorrow(u.a3, u.b3, u.d3);
u.c4 = usubBorrow(u.a4, u.b4, u.d4);
// Unsigned extended multiply: c* receives the high half, d* the low half.
umulExtended(u.a, u.b, u.c, u.d);
umulExtended(u.a2, u.b2, u.c2, u.d2);
umulExtended(u.a3, u.b3, u.c3, u.d3);
umulExtended(u.a4, u.b4, u.c4, u.d4);
// Signed extended multiply, same output order as the unsigned variant.
imulExtended(i.a, i.b, i.c, i.d);
imulExtended(i.a2, i.b2, i.c2, i.d2);
imulExtended(i.a3, i.b3, i.c3, i.d3);
imulExtended(i.a4, i.b4, i.c4, i.d4);
}

View File

@ -2005,6 +2005,21 @@ void CompilerMSL::emit_binary_unord_op(uint32_t result_type, uint32_t result_id,
inherit_expression_dependencies(result_id, op1);
}
// Returns the name of the limit macro holding the largest value of the given
// unsigned integer type: UCHAR_MAX, USHRT_MAX or UINT_MAX.
// Any other base type yields an empty string.
static std::string type_to_max_limit(const SPIRType &type)
{
	const auto base = type.basetype;

	if (base == SPIRType::UByte)
		return "UCHAR_MAX";
	if (base == SPIRType::UShort)
		return "USHRT_MAX";
	if (base == SPIRType::UInt)
		return "UINT_MAX";

	// No limit macro is mapped for the remaining base types.
	return std::string();
}
// Override for MSL-specific syntax instructions
void CompilerMSL::emit_instruction(const Instruction &instruction)
{
@ -2484,6 +2499,57 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
// OpOuterProduct
case OpIAddCarry:
case OpISubBorrow:
{
	// Both opcodes produce a two-member struct: member 0 is the wrapped
	// sum/difference, member 1 is the carry/borrow flag (1 when set, 0 otherwise).
	uint32_t result_type = ops[0];
	uint32_t result_id = ops[1];
	uint32_t op0 = ops[2];
	uint32_t op1 = ops[3];

	// Declare the result struct as a named temporary so its members can be
	// assigned one statement at a time.
	forced_temporaries.insert(result_id);
	auto &type = get<SPIRType>(result_type);
	statement(variable_decl(type, to_name(result_id)), ";");
	set<SPIRExpression>(result_id, to_name(result_id), result_type, true);

	auto &res_type = get<SPIRType>(type.member_types[1]);

	// Metal's select(a, b, c) returns b where c is true and a otherwise, so the
	// third argument must express the "flag is clear" condition.
	if (opcode == OpIAddCarry)
	{
		statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " + ",
		          to_enclosed_expression(op1), ";");
		// No carry iff the wrapped sum is still >= both operands. Comparing against
		// the sum itself needs no per-type limit macro and is exact even when
		// op0 + op1 equals the type's maximum value.
		statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
		          "(1), ", type_to_glsl(res_type), "(0), ", to_expression(result_id), ".", to_member_name(type, 0),
		          " >= max(", to_expression(op0), ", ", to_expression(op1), "));");
	}
	else
	{
		statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " - ",
		          to_enclosed_expression(op1), ";");
		// No borrow iff op0 >= op1 (equal operands give a zero difference without
		// borrowing).
		statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
		          "(1), ", type_to_glsl(res_type), "(0), ", to_enclosed_expression(op0),
		          " >= ", to_enclosed_expression(op1), ");");
	}
	break;
}
// Extended multiply (unsigned and signed share one code path): emits a struct
// temporary whose member 0 is the plain product and whose member 1 is the
// result of Metal's mulhi() on the same operands.
case OpUMulExtended:
case OpSMulExtended:
{
uint32_t result_type = ops[0];
uint32_t result_id = ops[1];
uint32_t op0 = ops[2];
uint32_t op1 = ops[3];
// Declare the result struct as a named temporary so its members can be
// assigned one statement at a time.
forced_temporaries.insert(result_id);
auto &type = get<SPIRType>(result_type);
statement(variable_decl(type, to_name(result_id)), ";");
set<SPIRExpression>(result_id, to_name(result_id), result_type, true);
// Member 0: op0 * op1.
statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " * ",
to_enclosed_expression(op1), ";");
// Member 1: mulhi(op0, op1).
statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(", to_expression(op0), ", ",
to_expression(op1), ");");
break;
}
default:
CompilerGLSL::emit_instruction(instruction);
break;