MSL: Add support for SPV_EXT_integer_dot_product

This commit is contained in:
sean 2024-01-06 16:59:09 +01:00 committed by Hans-Kristian Arntzen
parent 0a5e7b0f6a
commit 6c24be197f
4 changed files with 216 additions and 1 deletions

View File

@ -0,0 +1,44 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
template <typename T>
T reduce_add(vec<T, 2> v) { return v.x + v.y; }
template <typename T>
T reduce_add(vec<T, 3> v) { return v.x + v.y + v.z; }
template <typename T>
T reduce_add(vec<T, 4> v) { return v.x + v.y + v.z + v.w; }
struct InOut
{
uint4 x;
uint4 y;
int result;
};
struct InOut2
{
uint x;
uint y;
uint result;
};
struct InOut3
{
ushort4 x;
ushort4 y;
int acc;
int result;
};
kernel void main0(device InOut& comp [[buffer(0)]], device void* spvBufferAliasSet0Binding1 [[buffer(1)]])
{
device auto& comp2 = *(device InOut2*)spvBufferAliasSet0Binding1;
device auto& comp3 = *(device InOut3*)spvBufferAliasSet0Binding1;
comp.result = reduce_add(int4(comp.x) * int4(comp.y));
comp2.result = uint(reduce_add(uchar4(as_type<uchar4>(comp2.x)) * uchar4(as_type<uchar4>(comp2.y))));
comp3.result = addsat(reduce_add(int4(comp3.x) * int4(comp3.y)), comp3.acc);
}

View File

@ -0,0 +1,43 @@
#version 450
#extension GL_EXT_shader_8bit_storage : require
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_spirv_intrinsics : require
layout(local_size_x = 1) in;
layout(std430, binding = 0) buffer InOut {
uvec4 x;
uvec4 y;
int result;
} comp;
layout(std430, binding = 1) buffer InOut2 {
uint x;
uint y;
uint result;
} comp2;
layout(std430, binding = 1) buffer InOut3 {
u16vec4 x;
u16vec4 y;
int acc;
int result;
} comp3;
// Signed integer dot with unsigned integer
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
int sdot(uvec4 x, uvec4 y);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451)
uint8_t updot(uint x, uint y, spirv_literal uint packedFormat);
// SDotAccSat with unsigned input and result type
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4453)
int sdotaddsat(u16vec4 x, u16vec4 y, int acc);
void main() {
comp.result = sdot(comp.x, comp.y);
comp2.result = uint(updot(comp2.x, comp2.y, 0x0)); // PackedVectorFormat4x8Bit
comp3.result = sdotaddsat(comp3.x, comp3.y, comp3.acc);
}

View File

@ -7458,6 +7458,22 @@ void CompilerMSL::emit_custom_functions()
statement("");
break;
case SPVFuncImplReduceAdd:
// Metal doesn't support __builtin_reduce_add or simd_reduce_add, so we need this.
// Metal also doesn't support the other vector builtins, which would have been useful to make this a single template.
statement("template <typename T>");
statement("T reduce_add(vec<T, 2> v) { return v.x + v.y; }");
statement("template <typename T>");
statement("T reduce_add(vec<T, 3> v) { return v.x + v.y + v.z; }");
statement("template <typename T>");
statement("T reduce_add(vec<T, 4> v) { return v.x + v.y + v.z + v.w; }");
statement("");
break;
default:
break;
}
@ -9641,6 +9657,109 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
break;
}
case OpSDot:
case OpUDot:
case OpSUDot:
{
uint32_t result_type = ops[0];
uint32_t id = ops[1];
uint32_t vec1 = ops[2];
uint32_t vec2 = ops[3];
auto& input_type1 = expression_type(vec1);
string vec1input, vec2input;
auto input_size = input_type1.vecsize;
if (instruction.length == 5)
{
if (ops[4] == PackedVectorFormatPackedVectorFormat4x8Bit)
{
string type = opcode == OpSDot || opcode == OpSUDot ? "char4" : "uchar4";
vec1input = join("as_type<", type, ">(", to_expression(vec1), ")");
type = opcode == OpSDot ? "char4" : "uchar4";
vec2input = join("as_type<", type, ">(", to_expression(vec2), ")");
input_size = 4;
}
else
SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit for integer dot product is not supported.");
}
else
{
vec1input = to_expression(vec1);
vec2input = to_expression(vec2);
}
auto& type = get<SPIRType>(result_type);
auto result_type_cast = join(type_to_glsl(type), input_size);
// When the opcode specifies signed integers, we always cast to the signed integer type, regardless of the output type.
string_view type_cast1(result_type_cast);
if (type_cast1[0] == 'u' && (opcode == OpSDot || opcode == OpSUDot))
type_cast1 = type_cast1.substr(1);
string_view type_cast2(result_type_cast);
if (type_cast2[0] == 'u' && opcode == OpSDot)
type_cast2 = type_cast2.substr(1);
string exp = join("reduce_add(", std::string(type_cast1), "(", vec1input, ") * ", std::string(type_cast2), "(", vec2input, "))");
emit_op(result_type, id, exp, should_forward(vec1) && should_forward(vec2));
inherit_expression_dependencies(id, vec1);
inherit_expression_dependencies(id, vec2);
break;
}
case OpSDotAccSat:
case OpUDotAccSat:
case OpSUDotAccSat:
{
uint32_t result_type = ops[0];
uint32_t id = ops[1];
uint32_t vec1 = ops[2];
uint32_t vec2 = ops[3];
uint32_t acc = ops[4];
auto& input_type1 = expression_type(vec1);
auto& input_type2 = expression_type(vec2);
string vec1input, vec2input;
auto input_size = input_type1.vecsize;
if (instruction.length == 6)
{
if (ops[5] == PackedVectorFormatPackedVectorFormat4x8Bit)
{
string type = opcode == OpSDotAccSat || opcode == OpSUDotAccSat ? "char4" : "uchar4";
vec1input = join("as_type<", type, ">(", to_expression(vec1), ")");
type = opcode == OpSDotAccSat ? "char4" : "uchar4";
vec2input = join("as_type<", type, ">(", to_expression(vec2), ")");
input_size = 4;
}
else
SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit for integer dot product is not supported.");
}
else
{
vec1input = to_expression(vec1);
vec2input = to_expression(vec2);
}
auto& type = get<SPIRType>(result_type);
auto result_type_cast = join(type_to_glsl(type), input_size);
string_view type_cast1(result_type_cast);
if (type_cast1[0] == 'u' && (opcode == OpSDotAccSat || opcode == OpSUDotAccSat))
type_cast1 = type_cast1.substr(1);
string_view type_cast2(result_type_cast);
if (type_cast2[0] == 'u' && opcode == OpSDotAccSat)
type_cast2 = type_cast2.substr(1);
string exp = join("addsat(reduce_add(", std::string(type_cast1), "(", vec1input, ") * ", std::string(type_cast2), "(", vec2input, ")), ", to_expression(acc), ")");
emit_op(result_type, id, exp, should_forward(vec1) && should_forward(vec2));
inherit_expression_dependencies(id, vec1);
inherit_expression_dependencies(id, vec2);
break;
}
default:
CompilerGLSL::emit_instruction(instruction);
break;
@ -17266,6 +17385,14 @@ CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op o
case OpGroupNonUniformQuadSwap:
return SPVFuncImplQuadSwap;
case OpSDot:
case OpUDot:
case OpSUDot:
case OpSDotAccSat:
case OpUDotAccSat:
case OpSUDotAccSat:
return SPVFuncImplReduceAdd;
default:
break;
}

View File

@ -823,7 +823,8 @@ protected:
SPVFuncImplVariableDescriptor,
SPVFuncImplVariableSizedDescriptor,
SPVFuncImplVariableDescriptorArray,
SPVFuncImplPaddedStd140
SPVFuncImplPaddedStd140,
SPVFuncImplReduceAdd
};
// If the underlying resource has been used for comparison then duplicate loads of that resource must be too