MSL: Add support for SPV_KHR_integer_dot_product
This commit is contained in:
parent
0a5e7b0f6a
commit
6c24be197f
44
reference/shaders-msl/comp/integer-dot-product.comp
Normal file
44
reference/shaders-msl/comp/integer-dot-product.comp
Normal file
@ -0,0 +1,44 @@
|
||||
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
template <typename T>
|
||||
T reduce_add(vec<T, 2> v) { return v.x + v.y; }
|
||||
template <typename T>
|
||||
T reduce_add(vec<T, 3> v) { return v.x + v.y + v.z; }
|
||||
template <typename T>
|
||||
T reduce_add(vec<T, 4> v) { return v.x + v.y + v.z + v.w; }
|
||||
|
||||
struct InOut
|
||||
{
|
||||
uint4 x;
|
||||
uint4 y;
|
||||
int result;
|
||||
};
|
||||
|
||||
struct InOut2
|
||||
{
|
||||
uint x;
|
||||
uint y;
|
||||
uint result;
|
||||
};
|
||||
|
||||
struct InOut3
|
||||
{
|
||||
ushort4 x;
|
||||
ushort4 y;
|
||||
int acc;
|
||||
int result;
|
||||
};
|
||||
|
||||
kernel void main0(device InOut& comp [[buffer(0)]], device void* spvBufferAliasSet0Binding1 [[buffer(1)]])
|
||||
{
|
||||
device auto& comp2 = *(device InOut2*)spvBufferAliasSet0Binding1;
|
||||
device auto& comp3 = *(device InOut3*)spvBufferAliasSet0Binding1;
|
||||
comp.result = reduce_add(int4(comp.x) * int4(comp.y));
|
||||
comp2.result = uint(reduce_add(uchar4(as_type<uchar4>(comp2.x)) * uchar4(as_type<uchar4>(comp2.y))));
|
||||
comp3.result = addsat(reduce_add(int4(comp3.x) * int4(comp3.y)), comp3.acc);
|
||||
}
|
43
shaders-msl/comp/integer-dot-product.comp
Normal file
43
shaders-msl/comp/integer-dot-product.comp
Normal file
@ -0,0 +1,43 @@
|
||||
#version 450
|
||||
#extension GL_EXT_shader_8bit_storage : require
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types : require
|
||||
#extension GL_EXT_spirv_intrinsics : require
|
||||
|
||||
layout(local_size_x = 1) in;
|
||||
|
||||
// Buffer for the unpacked sdot() call in main(): two uvec4 inputs and an int result.
layout(std430, binding = 0) buffer InOut {
|
||||
uvec4 x;
|
||||
uvec4 y;
|
||||
int result;
|
||||
} comp;
|
||||
|
||||
// Buffer for the packed updot() call in main(): x and y are plain uints that main()
// passes together with packed format 0x0 (PackedVectorFormat4x8Bit).
// Declared on binding 1, the same binding as InOut3.
layout(std430, binding = 1) buffer InOut2 {
|
||||
uint x;
|
||||
uint y;
|
||||
uint result;
|
||||
} comp2;
|
||||
|
||||
// Buffer for the sdotaddsat() call in main(): two u16vec4 inputs, an int accumulator
// and an int result. Declared on binding 1, the same binding as InOut2.
layout(std430, binding = 1) buffer InOut3 {
|
||||
u16vec4 x;
|
||||
u16vec4 y;
|
||||
int acc;
|
||||
int result;
|
||||
} comp3;
|
||||
|
||||
// Signed integer dot product (OpSDot, opcode 4450) applied to unsigned integer vectors;
// the declared result type is a signed int.
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
|
||||
int sdot(uvec4 x, uvec4 y);
|
||||
|
||||
// Unsigned integer dot product (OpUDot, opcode 4451) over two scalar uint operands.
// packedFormat is declared spirv_literal, i.e. it is passed as a literal operand; the
// declared result type is uint8_t.
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451)
|
||||
uint8_t updot(uint x, uint y, spirv_literal uint packedFormat);
|
||||
|
||||
// SDotAccSat (opcode 4453) with unsigned 16-bit input vectors and a signed int
// accumulator and result type.
|
||||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4453)
|
||||
int sdotaddsat(u16vec4 x, u16vec4 y, int acc);
|
||||
|
||||
// Calls each of the three dot product intrinsics once and stores its result into the
// `result` member of the buffer the operands were read from.
void main() {
|
||||
comp.result = sdot(comp.x, comp.y);
|
||||
comp2.result = uint(updot(comp2.x, comp2.y, 0x0)); // PackedVectorFormat4x8Bit
|
||||
comp3.result = sdotaddsat(comp3.x, comp3.y, comp3.acc);
|
||||
}
|
127
spirv_msl.cpp
127
spirv_msl.cpp
@ -7458,6 +7458,22 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplReduceAdd:
|
||||
// Emits reduce_add() overloads for 2-, 3- and 4-component vectors; each returns the
// sum of the vector's components. The integer dot product opcodes emit calls to it.
// Metal doesn't support __builtin_reduce_add or simd_reduce_add, so we need this.
|
||||
// Metal also doesn't support the other vector builtins, which would have been useful to make this a single template.
|
||||
|
||||
statement("template <typename T>");
|
||||
statement("T reduce_add(vec<T, 2> v) { return v.x + v.y; }");
|
||||
|
||||
statement("template <typename T>");
|
||||
statement("T reduce_add(vec<T, 3> v) { return v.x + v.y + v.z; }");
|
||||
|
||||
statement("template <typename T>");
|
||||
statement("T reduce_add(vec<T, 4> v) { return v.x + v.y + v.z + v.w; }");
|
||||
|
||||
// Blank line after the emitted helper.
statement("");
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -9641,6 +9657,109 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
|
||||
break;
|
||||
}
|
||||
|
||||
case OpSDot:
case OpUDot:
case OpSUDot:
{
	// Integer dot product, emitted as
	//     reduce_add(<cast>(vec1) * <cast>(vec2))
	// Operand layout: ops[0] result type, ops[1] result id, ops[2] / ops[3] the two
	// inputs, ops[4] the packed vector format (only present when instruction.length == 5).
	//
	// NOTE(review): in the unpacked path the operands are only converted to the
	// result-typed vector. If an operand has an unsigned type narrower than the result,
	// that conversion presumably zero-extends, whereas OpSDot / OpSUDot treat the operand
	// as signed -- confirm against the SPIR-V specification.
	const uint32_t result_type = ops[0];
	const uint32_t id = ops[1];
	const uint32_t vec1 = ops[2];
	const uint32_t vec2 = ops[3];

	// The first operand is signed for OpSDot and OpSUDot, the second only for OpSDot.
	const bool lhs_signed = opcode == OpSDot || opcode == OpSUDot;
	const bool rhs_signed = opcode == OpSDot;

	const auto &input_type1 = expression_type(vec1);
	auto input_size = input_type1.vecsize;

	string lhs_expr, rhs_expr;
	if (instruction.length != 5)
	{
		// Plain vector operands are used as-is.
		lhs_expr = to_expression(vec1);
		rhs_expr = to_expression(vec2);
	}
	else if (ops[4] == PackedVectorFormatPackedVectorFormat4x8Bit)
	{
		// Packed operands are reinterpreted as four 8-bit lanes of the matching signedness.
		const string lhs_lanes = lhs_signed ? "char4" : "uchar4";
		lhs_expr = join("as_type<", lhs_lanes, ">(", to_expression(vec1), ")");
		const string rhs_lanes = rhs_signed ? "char4" : "uchar4";
		rhs_expr = join("as_type<", rhs_lanes, ">(", to_expression(vec2), ")");
		input_size = 4;
	}
	else
	{
		SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit for integer dot product is not supported.");
	}

	// Vector type the operands are converted to: result scalar type name + component count.
	const auto &out_type = get<SPIRType>(result_type);
	const string vector_cast = join(type_to_glsl(out_type), input_size);

	// When the opcode specifies a signed operand, that operand is always cast to the signed
	// integer type, regardless of the output type: the leading 'u' of the type name is dropped.
	const auto cast_for = [&vector_cast](bool operand_signed) -> string {
		if (operand_signed && vector_cast[0] == 'u')
			return vector_cast.substr(1);
		return vector_cast;
	};

	const string dot_expr =
	    join("reduce_add(", cast_for(lhs_signed), "(", lhs_expr, ") * ", cast_for(rhs_signed), "(", rhs_expr, "))");

	emit_op(result_type, id, dot_expr, should_forward(vec1) && should_forward(vec2));
	inherit_expression_dependencies(id, vec1);
	inherit_expression_dependencies(id, vec2);
	break;
}
|
||||
|
||||
case OpSDotAccSat:
case OpUDotAccSat:
case OpSUDotAccSat:
{
	// Integer dot product with saturating accumulate, emitted as
	//     addsat(reduce_add(<cast>(vec1) * <cast>(vec2)), acc)
	// Operand layout: ops[0] result type, ops[1] result id, ops[2] / ops[3] the two
	// inputs, ops[4] the accumulator, ops[5] the packed vector format (only present when
	// instruction.length == 6).
	//
	// NOTE(review): in the unpacked path the operands are only converted to the
	// result-typed vector. If an operand has an unsigned type narrower than the result,
	// that conversion presumably zero-extends, whereas the signed variants treat the
	// operand as signed -- confirm against the SPIR-V specification.
	const uint32_t result_type = ops[0];
	const uint32_t id = ops[1];
	const uint32_t vec1 = ops[2];
	const uint32_t vec2 = ops[3];
	const uint32_t acc = ops[4];

	// The first operand is signed for OpSDotAccSat and OpSUDotAccSat, the second only
	// for OpSDotAccSat.
	const bool lhs_signed = opcode == OpSDotAccSat || opcode == OpSUDotAccSat;
	const bool rhs_signed = opcode == OpSDotAccSat;

	const auto &input_type1 = expression_type(vec1);
	auto input_size = input_type1.vecsize;

	string vec1input, vec2input;
	if (instruction.length != 6)
	{
		// Plain vector operands are used as-is.
		vec1input = to_expression(vec1);
		vec2input = to_expression(vec2);
	}
	else if (ops[5] == PackedVectorFormatPackedVectorFormat4x8Bit)
	{
		// Packed operands are reinterpreted as four 8-bit lanes of the matching signedness.
		const string lhs_lanes = lhs_signed ? "char4" : "uchar4";
		vec1input = join("as_type<", lhs_lanes, ">(", to_expression(vec1), ")");
		const string rhs_lanes = rhs_signed ? "char4" : "uchar4";
		vec2input = join("as_type<", rhs_lanes, ">(", to_expression(vec2), ")");
		input_size = 4;
	}
	else
	{
		SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit for integer dot product is not supported.");
	}

	// Vector type the operands are converted to: result scalar type name + component count.
	const auto &type = get<SPIRType>(result_type);
	const string result_type_cast = join(type_to_glsl(type), input_size);

	// When the opcode specifies a signed operand, that operand is always cast to the signed
	// integer type, regardless of the output type: the leading 'u' of the type name is dropped.
	const auto cast_for = [&result_type_cast](bool operand_signed) -> string {
		if (operand_signed && result_type_cast[0] == 'u')
			return result_type_cast.substr(1);
		return result_type_cast;
	};

	const string exp = join("addsat(reduce_add(", cast_for(lhs_signed), "(", vec1input, ") * ",
	                        cast_for(rhs_signed), "(", vec2input, ")), ", to_expression(acc), ")");

	// The accumulator is part of the emitted expression, so it takes part in the
	// forwarding decision and dependency tracking exactly like the two vector operands.
	const bool forward = should_forward(vec1) && should_forward(vec2) && should_forward(acc);
	emit_op(result_type, id, exp, forward);
	inherit_expression_dependencies(id, vec1);
	inherit_expression_dependencies(id, vec2);
	inherit_expression_dependencies(id, acc);
	break;
}
|
||||
|
||||
default:
|
||||
CompilerGLSL::emit_instruction(instruction);
|
||||
break;
|
||||
@ -17266,6 +17385,14 @@ CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op o
|
||||
case OpGroupNonUniformQuadSwap:
|
||||
return SPVFuncImplQuadSwap;
|
||||
|
||||
// Every integer dot product opcode is emitted as a call to reduce_add(), so each of
// them requests the SPVFuncImplReduceAdd helper.
case OpSDot:
|
||||
case OpUDot:
|
||||
case OpSUDot:
|
||||
case OpSDotAccSat:
|
||||
case OpUDotAccSat:
|
||||
case OpSUDotAccSat:
|
||||
return SPVFuncImplReduceAdd;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -823,7 +823,8 @@ protected:
|
||||
SPVFuncImplVariableDescriptor,
|
||||
SPVFuncImplVariableSizedDescriptor,
|
||||
SPVFuncImplVariableDescriptorArray,
|
||||
SPVFuncImplPaddedStd140
|
||||
SPVFuncImplPaddedStd140,
|
||||
SPVFuncImplReduceAdd
|
||||
};
|
||||
|
||||
// If the underlying resource has been used for comparison then duplicate loads of that resource must be too
|
||||
|
Loading…
Reference in New Issue
Block a user