Merge pull request #2266 from KhronosGroup/pr-2257
Land MSL integer dot products
This commit is contained in:
commit
64f64c837a
65
reference/shaders-msl-no-opt/comp/integer-dot-product.comp
Normal file
65
reference/shaders-msl-no-opt/comp/integer-dot-product.comp
Normal file
@ -0,0 +1,65 @@
|
||||
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
template <typename T>
|
||||
T reduce_add(vec<T, 2> v) { return v.x + v.y; }
|
||||
template <typename T>
|
||||
T reduce_add(vec<T, 3> v) { return v.x + v.y + v.z; }
|
||||
template <typename T>
|
||||
T reduce_add(vec<T, 4> v) { return v.x + v.y + v.z + v.w; }
|
||||
|
||||
struct InOut3
|
||||
{
|
||||
ushort4 x;
|
||||
ushort4 y;
|
||||
int acc;
|
||||
int result;
|
||||
};
|
||||
|
||||
struct InOut2
|
||||
{
|
||||
uint x;
|
||||
uint y;
|
||||
uint result;
|
||||
};
|
||||
|
||||
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
|
||||
|
||||
struct InOut
|
||||
{
|
||||
uint4 x;
|
||||
uint4 y;
|
||||
int result;
|
||||
};
|
||||
|
||||
kernel void main0(device void* spvBufferAliasSet0Binding1 [[buffer(0)]])
|
||||
{
|
||||
device auto& comp3 = *(device InOut3*)spvBufferAliasSet0Binding1;
|
||||
device auto& comp2 = *(device InOut2*)spvBufferAliasSet0Binding1;
|
||||
int sdot_int = reduce_add(int4(short4(comp3.x)) * int4(short4(comp3.y)));
|
||||
uint sdot_uint = reduce_add(uint4(short4(comp3.x)) * uint4(short4(comp3.y)));
|
||||
uint udot_uint = reduce_add(uint4(comp3.x) * uint4(comp3.y));
|
||||
int sudot_int = reduce_add(int4(short4(comp3.x)) * int4(comp3.y));
|
||||
uint sudot_uint = reduce_add(uint4(short4(comp3.x)) * uint4(comp3.y));
|
||||
uchar spdot8 = reduce_add(uchar4(as_type<char4>(comp2.x)) * uchar4(as_type<char4>(comp2.y)));
|
||||
ushort spdot16 = reduce_add(ushort4(as_type<char4>(comp2.x)) * ushort4(as_type<char4>(comp2.y)));
|
||||
uint spdot32 = reduce_add(uint4(as_type<char4>(comp2.x)) * uint4(as_type<char4>(comp2.y)));
|
||||
int spdoti32 = reduce_add(int4(as_type<char4>(comp2.x)) * int4(as_type<char4>(comp2.y)));
|
||||
uchar updot8 = reduce_add(uchar4(as_type<uchar4>(comp2.x)) * uchar4(as_type<uchar4>(comp2.y)));
|
||||
ushort updot16 = reduce_add(ushort4(as_type<uchar4>(comp2.x)) * ushort4(as_type<uchar4>(comp2.y)));
|
||||
uint updot32 = reduce_add(uint4(as_type<uchar4>(comp2.x)) * uint4(as_type<uchar4>(comp2.y)));
|
||||
uchar supdot8 = reduce_add(uchar4(as_type<char4>(comp2.x)) * uchar4(as_type<uchar4>(comp2.y)));
|
||||
ushort supdot16 = reduce_add(ushort4(as_type<char4>(comp2.x)) * ushort4(as_type<uchar4>(comp2.y)));
|
||||
uint supdot32 = reduce_add(uint4(as_type<char4>(comp2.x)) * uint4(as_type<uchar4>(comp2.y)));
|
||||
int supdoti32 = reduce_add(int4(as_type<char4>(comp2.x)) * int4(as_type<uchar4>(comp2.y)));
|
||||
int sdotaddsat_int = int(addsat(reduce_add(int4(short4(comp3.x)) * int4(short4(comp3.y))), comp3.acc));
|
||||
uint sdotaddsat_uint = uint(addsat(reduce_add(int4(short4(comp3.x)) * int4(short4(comp3.y))), comp3.acc));
|
||||
uint udotaddsat_uint = uint(addsat(reduce_add(uint4(comp3.x) * uint4(comp3.y)), uint(comp3.acc)));
|
||||
int sudotaddsat_int = int(addsat(reduce_add(int4(short4(comp3.x)) * int4(comp3.y)), comp3.acc));
|
||||
uint sudotaddsat_uint = uint(addsat(reduce_add(int4(short4(comp3.x)) * int4(comp3.y)), comp3.acc));
|
||||
}
|
||||
|
114
shaders-msl-no-opt/comp/integer-dot-product.comp
Normal file
114
shaders-msl-no-opt/comp/integer-dot-product.comp
Normal file
@ -0,0 +1,114 @@
|
||||
#version 450
|
||||
#extension GL_EXT_shader_8bit_storage : require
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types : require
|
||||
#extension GL_EXT_spirv_intrinsics : require
|
||||
|
||||
layout(local_size_x = 1) in;
|
||||
|
||||
layout(std430, binding = 0) buffer InOut {
|
||||
uvec4 x;
|
||||
uvec4 y;
|
||||
int result;
|
||||
} comp;
|
||||
|
||||
layout(std430, binding = 1) buffer InOut2 {
|
||||
uint x;
|
||||
uint y;
|
||||
uint result;
|
||||
} comp2;
|
||||
|
||||
layout(std430, binding = 1) buffer InOut3 {
|
||||
u16vec4 x;
|
||||
u16vec4 y;
|
||||
int acc;
|
||||
int result;
|
||||
} comp3;
|
||||
|
||||
// Signed integer dot with unsigned integer
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
|
||||
int sdot_int_result(u16vec4 x, u16vec4 y);
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
|
||||
uint sdot_uint_result(u16vec4 x, u16vec4 y);
|
||||
|
||||
// Unsigned integer dot with signed integer. Only unsigned result is allowed in SPIR-V.
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451)
|
||||
uint udot_uint_result(u16vec4 x, u16vec4 y);
|
||||
|
||||
// Mixed integer dot with unsigned integer
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
|
||||
int sudot_int_result(u16vec4 x, u16vec4 y);
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
|
||||
uint sudot_uint_result(u16vec4 x, u16vec4 y);
|
||||
|
||||
// Signed packed dot product with different output widths.
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
|
||||
uint8_t spdot_to_8(uint x, uint y, spirv_literal uint packedFormat);
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
|
||||
uint16_t spdot_to_16(uint x, uint y, spirv_literal uint packedFormat);
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
|
||||
uint spdot_to_32(uint x, uint y, spirv_literal uint packedFormat);
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
|
||||
int spdot_to_i32(uint x, uint y, spirv_literal uint packedFormat);
|
||||
|
||||
// Unsigned packed dot product with different output widths.
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451)
|
||||
uint8_t updot_to_8(uint x, uint y, spirv_literal uint packedFormat);
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451)
|
||||
uint16_t updot_to_16(uint x, uint y, spirv_literal uint packedFormat);
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451)
|
||||
uint updot_to_32(uint x, uint y, spirv_literal uint packedFormat);
|
||||
|
||||
// Mixed packed dot product with different output widths.
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
|
||||
uint8_t supdot_to_8(uint x, uint y, spirv_literal uint packedFormat);
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
|
||||
uint16_t supdot_to_16(uint x, uint y, spirv_literal uint packedFormat);
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
|
||||
uint supdot_to_32(uint x, uint y, spirv_literal uint packedFormat);
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
|
||||
int supdot_to_i32(uint x, uint y, spirv_literal uint packedFormat);
|
||||
|
||||
// SDotAccSat with unsigned input and result type
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4453)
|
||||
int sdotaddsat_int_result(u16vec4 x, u16vec4 y, int acc);
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4453)
|
||||
uint sdotaddsat_uint_result(u16vec4 x, u16vec4 y, int acc);
|
||||
|
||||
// UDotAccSat. Result type must be unsigned in SPIR-V.
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4454)
|
||||
uint udotaddsat(u16vec4 x, u16vec4 y, int acc);
|
||||
|
||||
// SUDotAccSat
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4455)
|
||||
int sudotaddsat_int_result(u16vec4 x, u16vec4 y, int acc);
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4455)
|
||||
uint sudotaddsat_uint_result(u16vec4 x, u16vec4 y, int acc);
|
||||
|
||||
void main() {
|
||||
int sdot_int = sdot_int_result(comp3.x, comp3.y);
|
||||
uint sdot_uint = sdot_uint_result(comp3.x, comp3.y);
|
||||
uint udot_uint = udot_uint_result(comp3.x, comp3.y);
|
||||
int sudot_int = sudot_int_result(comp3.x, comp3.y);
|
||||
uint sudot_uint = sudot_uint_result(comp3.x, comp3.y);
|
||||
|
||||
uint8_t spdot8 = spdot_to_8(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
|
||||
uint16_t spdot16 = spdot_to_16(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
|
||||
uint spdot32 = spdot_to_32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
|
||||
int spdoti32 = spdot_to_i32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
|
||||
|
||||
uint8_t updot8 = updot_to_8(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
|
||||
uint16_t updot16 = updot_to_16(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
|
||||
uint updot32 = updot_to_32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
|
||||
|
||||
uint8_t supdot8 = supdot_to_8(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
|
||||
uint16_t supdot16 = supdot_to_16(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
|
||||
uint supdot32 = supdot_to_32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
|
||||
int supdoti32 = supdot_to_i32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
|
||||
|
||||
int sdotaddsat_int = sdotaddsat_int_result(comp3.x, comp3.y, comp3.acc);
|
||||
uint sdotaddsat_uint = sdotaddsat_uint_result(comp3.x, comp3.y, comp3.acc);
|
||||
uint udotaddsat_uint = udotaddsat(comp3.x, comp3.y, comp3.acc);
|
||||
int sudotaddsat_int = sudotaddsat_int_result(comp3.x, comp3.y, comp3.acc);
|
||||
uint sudotaddsat_uint = sudotaddsat_uint_result(comp3.x, comp3.y, comp3.acc);
|
||||
}
|
150
spirv_msl.cpp
150
spirv_msl.cpp
@ -7458,6 +7458,22 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplReduceAdd:
|
||||
// Metal doesn't support __builtin_reduce_add or simd_reduce_add, so we need this.
|
||||
// Metal also doesn't support the other vector builtins, which would have been useful to make this a single template.
|
||||
|
||||
statement("template <typename T>");
|
||||
statement("T reduce_add(vec<T, 2> v) { return v.x + v.y; }");
|
||||
|
||||
statement("template <typename T>");
|
||||
statement("T reduce_add(vec<T, 3> v) { return v.x + v.y + v.z; }");
|
||||
|
||||
statement("template <typename T>");
|
||||
statement("T reduce_add(vec<T, 4> v) { return v.x + v.y + v.z + v.w; }");
|
||||
|
||||
statement("");
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -9641,6 +9657,132 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
|
||||
break;
|
||||
}
|
||||
|
||||
case OpSDot:
|
||||
case OpUDot:
|
||||
case OpSUDot:
|
||||
{
|
||||
uint32_t result_type = ops[0];
|
||||
uint32_t id = ops[1];
|
||||
uint32_t vec1 = ops[2];
|
||||
uint32_t vec2 = ops[3];
|
||||
|
||||
auto &input_type1 = expression_type(vec1);
|
||||
auto &input_type2 = expression_type(vec2);
|
||||
|
||||
string vec1input, vec2input;
|
||||
auto input_size = input_type1.vecsize;
|
||||
if (instruction.length == 5)
|
||||
{
|
||||
if (ops[4] == PackedVectorFormatPackedVectorFormat4x8Bit)
|
||||
{
|
||||
string type = opcode == OpSDot || opcode == OpSUDot ? "char4" : "uchar4";
|
||||
vec1input = join("as_type<", type, ">(", to_expression(vec1), ")");
|
||||
type = opcode == OpSDot ? "char4" : "uchar4";
|
||||
vec2input = join("as_type<", type, ">(", to_expression(vec2), ")");
|
||||
input_size = 4;
|
||||
}
|
||||
else
|
||||
SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit for integer dot product is not supported.");
|
||||
}
|
||||
else
|
||||
{
|
||||
// Inputs are sign or zero-extended to their target width.
|
||||
SPIRType::BaseType vec1_expected_type =
|
||||
opcode != OpUDot ?
|
||||
to_signed_basetype(input_type1.width) :
|
||||
to_unsigned_basetype(input_type1.width);
|
||||
|
||||
SPIRType::BaseType vec2_expected_type =
|
||||
opcode != OpSDot ?
|
||||
to_unsigned_basetype(input_type2.width) :
|
||||
to_signed_basetype(input_type2.width);
|
||||
|
||||
vec1input = bitcast_expression(vec1_expected_type, vec1);
|
||||
vec2input = bitcast_expression(vec2_expected_type, vec2);
|
||||
}
|
||||
|
||||
auto &type = get<SPIRType>(result_type);
|
||||
|
||||
// We'll get the appropriate sign-extend or zero-extend, no matter which type we cast to here.
|
||||
// The addition in reduce_add is sign-invariant.
|
||||
auto result_type_cast = join(type_to_glsl(type), input_size);
|
||||
|
||||
string exp = join("reduce_add(",
|
||||
result_type_cast, "(", vec1input, ") * ",
|
||||
result_type_cast, "(", vec2input, "))");
|
||||
|
||||
emit_op(result_type, id, exp, should_forward(vec1) && should_forward(vec2));
|
||||
inherit_expression_dependencies(id, vec1);
|
||||
inherit_expression_dependencies(id, vec2);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpSDotAccSat:
|
||||
case OpUDotAccSat:
|
||||
case OpSUDotAccSat:
|
||||
{
|
||||
uint32_t result_type = ops[0];
|
||||
uint32_t id = ops[1];
|
||||
uint32_t vec1 = ops[2];
|
||||
uint32_t vec2 = ops[3];
|
||||
uint32_t acc = ops[4];
|
||||
|
||||
auto input_type1 = expression_type(vec1);
|
||||
auto input_type2 = expression_type(vec2);
|
||||
|
||||
string vec1input, vec2input;
|
||||
if (instruction.length == 6)
|
||||
{
|
||||
if (ops[5] == PackedVectorFormatPackedVectorFormat4x8Bit)
|
||||
{
|
||||
string type = opcode == OpSDotAccSat || opcode == OpSUDotAccSat ? "char4" : "uchar4";
|
||||
vec1input = join("as_type<", type, ">(", to_expression(vec1), ")");
|
||||
type = opcode == OpSDotAccSat ? "char4" : "uchar4";
|
||||
vec2input = join("as_type<", type, ">(", to_expression(vec2), ")");
|
||||
input_type1.vecsize = 4;
|
||||
input_type2.vecsize = 4;
|
||||
}
|
||||
else
|
||||
SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit for integer dot product is not supported.");
|
||||
}
|
||||
else
|
||||
{
|
||||
// Inputs are sign or zero-extended to their target width.
|
||||
SPIRType::BaseType vec1_expected_type =
|
||||
opcode != OpUDotAccSat ?
|
||||
to_signed_basetype(input_type1.width) :
|
||||
to_unsigned_basetype(input_type1.width);
|
||||
|
||||
SPIRType::BaseType vec2_expected_type =
|
||||
opcode != OpSDotAccSat ?
|
||||
to_unsigned_basetype(input_type2.width) :
|
||||
to_signed_basetype(input_type2.width);
|
||||
|
||||
vec1input = bitcast_expression(vec1_expected_type, vec1);
|
||||
vec2input = bitcast_expression(vec2_expected_type, vec2);
|
||||
}
|
||||
|
||||
auto &type = get<SPIRType>(result_type);
|
||||
|
||||
SPIRType::BaseType pre_saturate_type =
|
||||
opcode != OpUDotAccSat ?
|
||||
to_signed_basetype(type.width) :
|
||||
to_unsigned_basetype(type.width);
|
||||
|
||||
input_type1.basetype = pre_saturate_type;
|
||||
input_type2.basetype = pre_saturate_type;
|
||||
|
||||
string exp = join(type_to_glsl(type), "(addsat(reduce_add(",
|
||||
type_to_glsl(input_type1), "(", vec1input, ") * ",
|
||||
type_to_glsl(input_type2), "(", vec2input, ")), ",
|
||||
bitcast_expression(pre_saturate_type, acc), "))");
|
||||
|
||||
emit_op(result_type, id, exp, should_forward(vec1) && should_forward(vec2));
|
||||
inherit_expression_dependencies(id, vec1);
|
||||
inherit_expression_dependencies(id, vec2);
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
CompilerGLSL::emit_instruction(instruction);
|
||||
break;
|
||||
@ -17266,6 +17408,14 @@ CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op o
|
||||
case OpGroupNonUniformQuadSwap:
|
||||
return SPVFuncImplQuadSwap;
|
||||
|
||||
case OpSDot:
|
||||
case OpUDot:
|
||||
case OpSUDot:
|
||||
case OpSDotAccSat:
|
||||
case OpUDotAccSat:
|
||||
case OpSUDotAccSat:
|
||||
return SPVFuncImplReduceAdd;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -823,7 +823,8 @@ protected:
|
||||
SPVFuncImplVariableDescriptor,
|
||||
SPVFuncImplVariableSizedDescriptor,
|
||||
SPVFuncImplVariableDescriptorArray,
|
||||
SPVFuncImplPaddedStd140
|
||||
SPVFuncImplPaddedStd140,
|
||||
SPVFuncImplReduceAdd
|
||||
};
|
||||
|
||||
// If the underlying resource has been used for comparison then duplicate loads of that resource must be too
|
||||
|
Loading…
Reference in New Issue
Block a user