Merge pull request #2266 from KhronosGroup/pr-2257

Land MSL integer dot products
This commit is contained in:
Hans-Kristian Arntzen 2024-01-16 16:24:33 +01:00 committed by GitHub
commit 64f64c837a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 331 additions and 1 deletions

View File

@ -0,0 +1,65 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
template <typename T>
T reduce_add(vec<T, 2> v) { return v.x + v.y; }
template <typename T>
T reduce_add(vec<T, 3> v) { return v.x + v.y + v.z; }
template <typename T>
T reduce_add(vec<T, 4> v) { return v.x + v.y + v.z + v.w; }
struct InOut3
{
ushort4 x;
ushort4 y;
int acc;
int result;
};
struct InOut2
{
uint x;
uint y;
uint result;
};
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
struct InOut
{
uint4 x;
uint4 y;
int result;
};
kernel void main0(device void* spvBufferAliasSet0Binding1 [[buffer(0)]])
{
device auto& comp3 = *(device InOut3*)spvBufferAliasSet0Binding1;
device auto& comp2 = *(device InOut2*)spvBufferAliasSet0Binding1;
int sdot_int = reduce_add(int4(short4(comp3.x)) * int4(short4(comp3.y)));
uint sdot_uint = reduce_add(uint4(short4(comp3.x)) * uint4(short4(comp3.y)));
uint udot_uint = reduce_add(uint4(comp3.x) * uint4(comp3.y));
int sudot_int = reduce_add(int4(short4(comp3.x)) * int4(comp3.y));
uint sudot_uint = reduce_add(uint4(short4(comp3.x)) * uint4(comp3.y));
uchar spdot8 = reduce_add(uchar4(as_type<char4>(comp2.x)) * uchar4(as_type<char4>(comp2.y)));
ushort spdot16 = reduce_add(ushort4(as_type<char4>(comp2.x)) * ushort4(as_type<char4>(comp2.y)));
uint spdot32 = reduce_add(uint4(as_type<char4>(comp2.x)) * uint4(as_type<char4>(comp2.y)));
int spdoti32 = reduce_add(int4(as_type<char4>(comp2.x)) * int4(as_type<char4>(comp2.y)));
uchar updot8 = reduce_add(uchar4(as_type<uchar4>(comp2.x)) * uchar4(as_type<uchar4>(comp2.y)));
ushort updot16 = reduce_add(ushort4(as_type<uchar4>(comp2.x)) * ushort4(as_type<uchar4>(comp2.y)));
uint updot32 = reduce_add(uint4(as_type<uchar4>(comp2.x)) * uint4(as_type<uchar4>(comp2.y)));
uchar supdot8 = reduce_add(uchar4(as_type<char4>(comp2.x)) * uchar4(as_type<uchar4>(comp2.y)));
ushort supdot16 = reduce_add(ushort4(as_type<char4>(comp2.x)) * ushort4(as_type<uchar4>(comp2.y)));
uint supdot32 = reduce_add(uint4(as_type<char4>(comp2.x)) * uint4(as_type<uchar4>(comp2.y)));
int supdoti32 = reduce_add(int4(as_type<char4>(comp2.x)) * int4(as_type<uchar4>(comp2.y)));
int sdotaddsat_int = int(addsat(reduce_add(int4(short4(comp3.x)) * int4(short4(comp3.y))), comp3.acc));
uint sdotaddsat_uint = uint(addsat(reduce_add(int4(short4(comp3.x)) * int4(short4(comp3.y))), comp3.acc));
uint udotaddsat_uint = uint(addsat(reduce_add(uint4(comp3.x) * uint4(comp3.y)), uint(comp3.acc)));
int sudotaddsat_int = int(addsat(reduce_add(int4(short4(comp3.x)) * int4(comp3.y)), comp3.acc));
uint sudotaddsat_uint = uint(addsat(reduce_add(int4(short4(comp3.x)) * int4(comp3.y)), comp3.acc));
}

View File

@ -0,0 +1,114 @@
#version 450
#extension GL_EXT_shader_8bit_storage : require
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_spirv_intrinsics : require
layout(local_size_x = 1) in;
layout(std430, binding = 0) buffer InOut {
uvec4 x;
uvec4 y;
int result;
} comp;
layout(std430, binding = 1) buffer InOut2 {
uint x;
uint y;
uint result;
} comp2;
layout(std430, binding = 1) buffer InOut3 {
u16vec4 x;
u16vec4 y;
int acc;
int result;
} comp3;
// Signed integer dot with unsigned integer
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
int sdot_int_result(u16vec4 x, u16vec4 y);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
uint sdot_uint_result(u16vec4 x, u16vec4 y);
// Unsigned integer dot with signed integer. Only unsigned result is allowed in SPIR-V.
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451)
uint udot_uint_result(u16vec4 x, u16vec4 y);
// Mixed integer dot with unsigned integer
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
int sudot_int_result(u16vec4 x, u16vec4 y);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
uint sudot_uint_result(u16vec4 x, u16vec4 y);
// Signed packed dot product with different output widths.
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
uint8_t spdot_to_8(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
uint16_t spdot_to_16(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
uint spdot_to_32(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
int spdot_to_i32(uint x, uint y, spirv_literal uint packedFormat);
// Unsigned packed dot product with different output widths.
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451)
uint8_t updot_to_8(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451)
uint16_t updot_to_16(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451)
uint updot_to_32(uint x, uint y, spirv_literal uint packedFormat);
// Mixed packed dot product with different output widths.
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
uint8_t supdot_to_8(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
uint16_t supdot_to_16(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
uint supdot_to_32(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
int supdot_to_i32(uint x, uint y, spirv_literal uint packedFormat);
// SDotAccSat with unsigned input and result type
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4453)
int sdotaddsat_int_result(u16vec4 x, u16vec4 y, int acc);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4453)
uint sdotaddsat_uint_result(u16vec4 x, u16vec4 y, int acc);
// UDotAccSat. Result type must be unsigned in SPIR-V.
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4454)
uint udotaddsat(u16vec4 x, u16vec4 y, int acc);
// SUDotAccSat
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4455)
int sudotaddsat_int_result(u16vec4 x, u16vec4 y, int acc);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4455)
uint sudotaddsat_uint_result(u16vec4 x, u16vec4 y, int acc);
void main() {
int sdot_int = sdot_int_result(comp3.x, comp3.y);
uint sdot_uint = sdot_uint_result(comp3.x, comp3.y);
uint udot_uint = udot_uint_result(comp3.x, comp3.y);
int sudot_int = sudot_int_result(comp3.x, comp3.y);
uint sudot_uint = sudot_uint_result(comp3.x, comp3.y);
uint8_t spdot8 = spdot_to_8(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
uint16_t spdot16 = spdot_to_16(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
uint spdot32 = spdot_to_32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
int spdoti32 = spdot_to_i32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
uint8_t updot8 = updot_to_8(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
uint16_t updot16 = updot_to_16(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
uint updot32 = updot_to_32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
uint8_t supdot8 = supdot_to_8(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
uint16_t supdot16 = supdot_to_16(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
uint supdot32 = supdot_to_32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
int supdoti32 = supdot_to_i32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
int sdotaddsat_int = sdotaddsat_int_result(comp3.x, comp3.y, comp3.acc);
uint sdotaddsat_uint = sdotaddsat_uint_result(comp3.x, comp3.y, comp3.acc);
uint udotaddsat_uint = udotaddsat(comp3.x, comp3.y, comp3.acc);
int sudotaddsat_int = sudotaddsat_int_result(comp3.x, comp3.y, comp3.acc);
uint sudotaddsat_uint = sudotaddsat_uint_result(comp3.x, comp3.y, comp3.acc);
}

View File

@ -7458,6 +7458,22 @@ void CompilerMSL::emit_custom_functions()
statement("");
break;
case SPVFuncImplReduceAdd:
// Metal doesn't support __builtin_reduce_add or simd_reduce_add, so we need this.
// Metal also doesn't support the other vector builtins, which would have been useful to make this a single template.
statement("template <typename T>");
statement("T reduce_add(vec<T, 2> v) { return v.x + v.y; }");
statement("template <typename T>");
statement("T reduce_add(vec<T, 3> v) { return v.x + v.y + v.z; }");
statement("template <typename T>");
statement("T reduce_add(vec<T, 4> v) { return v.x + v.y + v.z + v.w; }");
statement("");
break;
default:
break;
}
@ -9641,6 +9657,132 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
break;
}
case OpSDot:
case OpUDot:
case OpSUDot:
{
uint32_t result_type = ops[0];
uint32_t id = ops[1];
uint32_t vec1 = ops[2];
uint32_t vec2 = ops[3];
auto &input_type1 = expression_type(vec1);
auto &input_type2 = expression_type(vec2);
string vec1input, vec2input;
auto input_size = input_type1.vecsize;
if (instruction.length == 5)
{
if (ops[4] == PackedVectorFormatPackedVectorFormat4x8Bit)
{
string type = opcode == OpSDot || opcode == OpSUDot ? "char4" : "uchar4";
vec1input = join("as_type<", type, ">(", to_expression(vec1), ")");
type = opcode == OpSDot ? "char4" : "uchar4";
vec2input = join("as_type<", type, ">(", to_expression(vec2), ")");
input_size = 4;
}
else
SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit for integer dot product is not supported.");
}
else
{
// Inputs are sign or zero-extended to their target width.
SPIRType::BaseType vec1_expected_type =
opcode != OpUDot ?
to_signed_basetype(input_type1.width) :
to_unsigned_basetype(input_type1.width);
SPIRType::BaseType vec2_expected_type =
opcode != OpSDot ?
to_unsigned_basetype(input_type2.width) :
to_signed_basetype(input_type2.width);
vec1input = bitcast_expression(vec1_expected_type, vec1);
vec2input = bitcast_expression(vec2_expected_type, vec2);
}
auto &type = get<SPIRType>(result_type);
// We'll get the appropriate sign-extend or zero-extend, no matter which type we cast to here.
// The addition in reduce_add is sign-invariant.
auto result_type_cast = join(type_to_glsl(type), input_size);
string exp = join("reduce_add(",
result_type_cast, "(", vec1input, ") * ",
result_type_cast, "(", vec2input, "))");
emit_op(result_type, id, exp, should_forward(vec1) && should_forward(vec2));
inherit_expression_dependencies(id, vec1);
inherit_expression_dependencies(id, vec2);
break;
}
case OpSDotAccSat:
case OpUDotAccSat:
case OpSUDotAccSat:
{
uint32_t result_type = ops[0];
uint32_t id = ops[1];
uint32_t vec1 = ops[2];
uint32_t vec2 = ops[3];
uint32_t acc = ops[4];
auto input_type1 = expression_type(vec1);
auto input_type2 = expression_type(vec2);
string vec1input, vec2input;
if (instruction.length == 6)
{
if (ops[5] == PackedVectorFormatPackedVectorFormat4x8Bit)
{
string type = opcode == OpSDotAccSat || opcode == OpSUDotAccSat ? "char4" : "uchar4";
vec1input = join("as_type<", type, ">(", to_expression(vec1), ")");
type = opcode == OpSDotAccSat ? "char4" : "uchar4";
vec2input = join("as_type<", type, ">(", to_expression(vec2), ")");
input_type1.vecsize = 4;
input_type2.vecsize = 4;
}
else
SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit for integer dot product is not supported.");
}
else
{
// Inputs are sign or zero-extended to their target width.
SPIRType::BaseType vec1_expected_type =
opcode != OpUDotAccSat ?
to_signed_basetype(input_type1.width) :
to_unsigned_basetype(input_type1.width);
SPIRType::BaseType vec2_expected_type =
opcode != OpSDotAccSat ?
to_unsigned_basetype(input_type2.width) :
to_signed_basetype(input_type2.width);
vec1input = bitcast_expression(vec1_expected_type, vec1);
vec2input = bitcast_expression(vec2_expected_type, vec2);
}
auto &type = get<SPIRType>(result_type);
SPIRType::BaseType pre_saturate_type =
opcode != OpUDotAccSat ?
to_signed_basetype(type.width) :
to_unsigned_basetype(type.width);
input_type1.basetype = pre_saturate_type;
input_type2.basetype = pre_saturate_type;
string exp = join(type_to_glsl(type), "(addsat(reduce_add(",
type_to_glsl(input_type1), "(", vec1input, ") * ",
type_to_glsl(input_type2), "(", vec2input, ")), ",
bitcast_expression(pre_saturate_type, acc), "))");
emit_op(result_type, id, exp, should_forward(vec1) && should_forward(vec2));
inherit_expression_dependencies(id, vec1);
inherit_expression_dependencies(id, vec2);
break;
}
default:
CompilerGLSL::emit_instruction(instruction);
break;
@ -17266,6 +17408,14 @@ CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op o
case OpGroupNonUniformQuadSwap:
return SPVFuncImplQuadSwap;
case OpSDot:
case OpUDot:
case OpSUDot:
case OpSDotAccSat:
case OpUDotAccSat:
case OpSUDotAccSat:
return SPVFuncImplReduceAdd;
default:
break;
}

View File

@ -823,7 +823,8 @@ protected:
SPVFuncImplVariableDescriptor,
SPVFuncImplVariableSizedDescriptor,
SPVFuncImplVariableDescriptorArray,
SPVFuncImplPaddedStd140
SPVFuncImplPaddedStd140,
SPVFuncImplReduceAdd
};
// If the underlying resource has been used for comparison then duplicate loads of that resource must be too