MSL: Handle implicit integer promotion rules.

MSL inherits the behavior of C where arithmetic on small types are
implicitly converted to int. SPIR-V does not have this behavior, so make
sure that arithmetic results are handled correctly.
This commit is contained in:
Hans-Kristian Arntzen 2022-10-31 13:05:56 +01:00
parent c813d8d67b
commit 4de9d6c2b6
8 changed files with 259 additions and 13 deletions

View File

@ -0,0 +1,93 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct BUF0
{
half2 f16s;
ushort2 u16;
short2 i16;
ushort4 u16s;
short4 i16s;
half f16;
};
static inline __attribute__((always_inline))
void test_u16(device BUF0& v_24)
{
v_24.f16 += as_type<half>(ushort(((device ushort*)&v_24.u16)[0u] + ((device ushort*)&v_24.u16)[1u]));
v_24.f16 += as_type<half>(ushort(((device ushort*)&v_24.u16)[0u] - ((device ushort*)&v_24.u16)[1u]));
v_24.f16 += as_type<half>(ushort(((device ushort*)&v_24.u16)[0u] * ((device ushort*)&v_24.u16)[1u]));
v_24.f16 += as_type<half>(ushort(((device ushort*)&v_24.u16)[0u] / ((device ushort*)&v_24.u16)[1u]));
v_24.f16 += as_type<half>(ushort(((device ushort*)&v_24.u16)[0u] % ((device ushort*)&v_24.u16)[1u]));
v_24.f16 += as_type<half>(ushort(((device ushort*)&v_24.u16)[0u] << ((device ushort*)&v_24.u16)[1u]));
v_24.f16 += as_type<half>(ushort(((device ushort*)&v_24.u16)[0u] >> ((device ushort*)&v_24.u16)[1u]));
v_24.f16 += as_type<half>(ushort(~((device ushort*)&v_24.u16)[0u]));
v_24.f16 += as_type<half>(ushort(-((device ushort*)&v_24.u16)[0u]));
v_24.f16 += as_type<half>(ushort(((device ushort*)&v_24.u16)[0u] ^ ((device ushort*)&v_24.u16)[1u]));
v_24.f16 += as_type<half>(ushort(((device ushort*)&v_24.u16)[0u] & ((device ushort*)&v_24.u16)[1u]));
v_24.f16 += as_type<half>(ushort(((device ushort*)&v_24.u16)[0u] | ((device ushort*)&v_24.u16)[1u]));
}
static inline __attribute__((always_inline))
void test_i16(device BUF0& v_24)
{
v_24.f16 += as_type<half>(short(((device short*)&v_24.i16)[0u] + ((device short*)&v_24.i16)[1u]));
v_24.f16 += as_type<half>(short(((device short*)&v_24.i16)[0u] - ((device short*)&v_24.i16)[1u]));
v_24.f16 += as_type<half>(short(((device short*)&v_24.i16)[0u] * ((device short*)&v_24.i16)[1u]));
v_24.f16 += as_type<half>(short(((device short*)&v_24.i16)[0u] / ((device short*)&v_24.i16)[1u]));
v_24.f16 += as_type<half>(short(((device short*)&v_24.i16)[0u] % ((device short*)&v_24.i16)[1u]));
v_24.f16 += as_type<half>(short(((device short*)&v_24.i16)[0u] << ((device short*)&v_24.i16)[1u]));
v_24.f16 += as_type<half>(short(((device short*)&v_24.i16)[0u] >> ((device short*)&v_24.i16)[1u]));
v_24.f16 += as_type<half>(short(~((device short*)&v_24.i16)[0u]));
v_24.f16 += as_type<half>(short(-((device short*)&v_24.i16)[0u]));
v_24.f16 += as_type<half>(short(((device short*)&v_24.i16)[0u] ^ ((device short*)&v_24.i16)[1u]));
v_24.f16 += as_type<half>(short(((device short*)&v_24.i16)[0u] & ((device short*)&v_24.i16)[1u]));
v_24.f16 += as_type<half>(short(((device short*)&v_24.i16)[0u] | ((device short*)&v_24.i16)[1u]));
}
static inline __attribute__((always_inline))
void test_u16s(device BUF0& v_24)
{
v_24.f16s += as_type<half2>(v_24.u16s.xy + v_24.u16s.zw);
v_24.f16s += as_type<half2>(v_24.u16s.xy - v_24.u16s.zw);
v_24.f16s += as_type<half2>(v_24.u16s.xy * v_24.u16s.zw);
v_24.f16s += as_type<half2>(v_24.u16s.xy / v_24.u16s.zw);
v_24.f16s += as_type<half2>(v_24.u16s.xy % v_24.u16s.zw);
v_24.f16s += as_type<half2>(v_24.u16s.xy << v_24.u16s.zw);
v_24.f16s += as_type<half2>(v_24.u16s.xy >> v_24.u16s.zw);
v_24.f16s += as_type<half2>(~v_24.u16s.xy);
v_24.f16s += as_type<half2>(-v_24.u16s.xy);
v_24.f16s += as_type<half2>(v_24.u16s.xy ^ v_24.u16s.zw);
v_24.f16s += as_type<half2>(v_24.u16s.xy & v_24.u16s.zw);
v_24.f16s += as_type<half2>(v_24.u16s.xy | v_24.u16s.zw);
}
static inline __attribute__((always_inline))
void test_i16s(device BUF0& v_24)
{
v_24.f16s += as_type<half2>(v_24.i16s.xy + v_24.i16s.zw);
v_24.f16s += as_type<half2>(v_24.i16s.xy - v_24.i16s.zw);
v_24.f16s += as_type<half2>(v_24.i16s.xy * v_24.i16s.zw);
v_24.f16s += as_type<half2>(v_24.i16s.xy / v_24.i16s.zw);
v_24.f16s += as_type<half2>(v_24.i16s.xy % v_24.i16s.zw);
v_24.f16s += as_type<half2>(v_24.i16s.xy << v_24.i16s.zw);
v_24.f16s += as_type<half2>(v_24.i16s.xy >> v_24.i16s.zw);
v_24.f16s += as_type<half2>(~v_24.i16s.xy);
v_24.f16s += as_type<half2>(-v_24.i16s.xy);
v_24.f16s += as_type<half2>(v_24.i16s.xy ^ v_24.i16s.zw);
v_24.f16s += as_type<half2>(v_24.i16s.xy & v_24.i16s.zw);
v_24.f16s += as_type<half2>(v_24.i16s.xy | v_24.i16s.zw);
}
kernel void main0(device BUF0& v_24 [[buffer(0)]])
{
test_u16(v_24);
test_i16(v_24);
test_u16s(v_24);
test_i16s(v_24);
}

View File

@ -18,7 +18,7 @@ constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
kernel void main0(constant UBO& _12 [[buffer(0)]], device SSBO& _24 [[buffer(1)]])
{
short v = as_type<short>(_12.b);
v ^= short(-32768);
v = short(v ^ short(-32768));
_24.a = as_type<half>(v);
}

View File

@ -0,0 +1,85 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
layout(set = 0, binding = 0) buffer BUF0
{
f16vec2 f16s;
u16vec2 u16;
i16vec2 i16;
u16vec4 u16s;
i16vec4 i16s;
float16_t f16;
};
void test_i16()
{
f16 += int16BitsToFloat16(i16.x + i16.y);
f16 += int16BitsToFloat16(i16.x - i16.y);
f16 += int16BitsToFloat16(i16.x * i16.y);
f16 += int16BitsToFloat16(i16.x / i16.y);
f16 += int16BitsToFloat16(i16.x % i16.y);
f16 += int16BitsToFloat16(i16.x << i16.y);
f16 += int16BitsToFloat16(i16.x >> i16.y);
f16 += int16BitsToFloat16(~i16.x);
f16 += int16BitsToFloat16(-i16.x);
f16 += int16BitsToFloat16(i16.x ^ i16.y);
f16 += int16BitsToFloat16(i16.x & i16.y);
f16 += int16BitsToFloat16(i16.x | i16.y);
}
void test_u16()
{
f16 += uint16BitsToFloat16(u16.x + u16.y);
f16 += uint16BitsToFloat16(u16.x - u16.y);
f16 += uint16BitsToFloat16(u16.x * u16.y);
f16 += uint16BitsToFloat16(u16.x / u16.y);
f16 += uint16BitsToFloat16(u16.x % u16.y);
f16 += uint16BitsToFloat16(u16.x << u16.y);
f16 += uint16BitsToFloat16(u16.x >> u16.y);
f16 += uint16BitsToFloat16(~u16.x);
f16 += uint16BitsToFloat16(-u16.x);
f16 += uint16BitsToFloat16(u16.x ^ u16.y);
f16 += uint16BitsToFloat16(u16.x & u16.y);
f16 += uint16BitsToFloat16(u16.x | u16.y);
}
void test_u16s()
{
f16s += uint16BitsToFloat16(u16s.xy + u16s.zw);
f16s += uint16BitsToFloat16(u16s.xy - u16s.zw);
f16s += uint16BitsToFloat16(u16s.xy * u16s.zw);
f16s += uint16BitsToFloat16(u16s.xy / u16s.zw);
f16s += uint16BitsToFloat16(u16s.xy % u16s.zw);
f16s += uint16BitsToFloat16(u16s.xy << u16s.zw);
f16s += uint16BitsToFloat16(u16s.xy >> u16s.zw);
f16s += uint16BitsToFloat16(~u16s.xy);
f16s += uint16BitsToFloat16(-u16s.xy);
f16s += uint16BitsToFloat16(u16s.xy ^ u16s.zw);
f16s += uint16BitsToFloat16(u16s.xy & u16s.zw);
f16s += uint16BitsToFloat16(u16s.xy | u16s.zw);
}
void test_i16s()
{
f16s += int16BitsToFloat16(i16s.xy + i16s.zw);
f16s += int16BitsToFloat16(i16s.xy - i16s.zw);
f16s += int16BitsToFloat16(i16s.xy * i16s.zw);
f16s += int16BitsToFloat16(i16s.xy / i16s.zw);
f16s += int16BitsToFloat16(i16s.xy % i16s.zw);
f16s += int16BitsToFloat16(i16s.xy << i16s.zw);
f16s += int16BitsToFloat16(i16s.xy >> i16s.zw);
f16s += int16BitsToFloat16(~i16s.xy);
f16s += int16BitsToFloat16(-i16s.xy);
f16s += int16BitsToFloat16(i16s.xy ^ i16s.zw);
f16s += int16BitsToFloat16(i16s.xy & i16s.zw);
f16s += int16BitsToFloat16(i16s.xy | i16s.zw);
}
void main()
{
test_u16();
test_i16();
test_u16s();
test_i16s();
}

View File

@ -1796,6 +1796,33 @@ static inline bool opcode_is_sign_invariant(spv::Op opcode)
}
}
static inline bool opcode_can_promote_integer_implicitly(spv::Op opcode)
{
switch (opcode)
{
case spv::OpSNegate:
case spv::OpNot:
case spv::OpBitwiseAnd:
case spv::OpBitwiseOr:
case spv::OpBitwiseXor:
case spv::OpShiftLeftLogical:
case spv::OpShiftRightLogical:
case spv::OpShiftRightArithmetic:
case spv::OpIAdd:
case spv::OpISub:
case spv::OpIMul:
case spv::OpSDiv:
case spv::OpUDiv:
case spv::OpSRem:
case spv::OpUMod:
case spv::OpSMod:
return true;
default:
return false;
}
}
struct SetBindingPair
{
uint32_t desc_set;

View File

@ -5984,6 +5984,14 @@ void CompilerGLSL::emit_unary_op(uint32_t result_type, uint32_t result_id, uint3
inherit_expression_dependencies(result_id, op0);
}
void CompilerGLSL::emit_unary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op)
{
auto &type = get<SPIRType>(result_type);
bool forward = should_forward(op0);
emit_op(result_type, result_id, join(type_to_glsl(type), "(", op, to_enclosed_unpacked_expression(op0), ")"), forward);
inherit_expression_dependencies(result_id, op0);
}
void CompilerGLSL::emit_binary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op)
{
// Various FP arithmetic opcodes such as add, sub, mul will hit this.
@ -6127,7 +6135,9 @@ bool CompilerGLSL::emit_complex_bitcast(uint32_t result_type, uint32_t id, uint3
}
void CompilerGLSL::emit_binary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
const char *op, SPIRType::BaseType input_type, bool skip_cast_if_equal_type)
const char *op, SPIRType::BaseType input_type,
bool skip_cast_if_equal_type,
bool implicit_integer_promotion)
{
string cast_op0, cast_op1;
auto expected_type = binary_op_bitcast_helper(cast_op0, cast_op1, input_type, op0, op1, skip_cast_if_equal_type);
@ -6136,17 +6146,23 @@ void CompilerGLSL::emit_binary_op_cast(uint32_t result_type, uint32_t result_id,
// We might have casted away from the result type, so bitcast again.
// For example, arithmetic right shift with uint inputs.
// Special case boolean outputs since relational opcodes output booleans instead of int/uint.
auto bitop = join(cast_op0, " ", op, " ", cast_op1);
string expr;
if (out_type.basetype != input_type && out_type.basetype != SPIRType::Boolean)
if (implicit_integer_promotion)
{
// Simple value cast.
expr = join(type_to_glsl(out_type), '(', bitop, ')');
}
else if (out_type.basetype != input_type && out_type.basetype != SPIRType::Boolean)
{
expected_type.basetype = input_type;
expr = bitcast_glsl_op(out_type, expected_type);
expr += '(';
expr += join(cast_op0, " ", op, " ", cast_op1);
expr += ')';
expr = join(bitcast_glsl_op(out_type, expected_type), '(', bitop, ')');
}
else
expr += join(cast_op0, " ", op, " ", cast_op1);
{
expr = std::move(bitop);
}
emit_op(result_type, result_id, expr, should_forward(op0) && should_forward(op1));
inherit_expression_dependencies(result_id, op0);
@ -10751,8 +10767,10 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
#define GLSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
#define GLSL_BOP_CAST(op, type) \
emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, \
opcode_is_sign_invariant(opcode), implicit_integer_promotion)
#define GLSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
#define GLSL_UOP_CAST(op) emit_unary_op_cast(ops[0], ops[1], ops[2], #op)
#define GLSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
#define GLSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
#define GLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
@ -10766,6 +10784,13 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
auto int_type = to_signed_basetype(integer_width);
auto uint_type = to_unsigned_basetype(integer_width);
// Handle C implicit integer promotion rules.
// If we get implicit promotion to int, need to make sure we cast by value to intended return type,
// otherwise, future sign-dependent operations and bitcasts will break.
bool implicit_integer_promotion = integer_width < 32 && backend.implicit_c_integer_promotion_rules &&
opcode_can_promote_integer_implicitly(opcode) &&
get<SPIRType>(ops[0]).vecsize == 1;
opcode = get_remapped_spirv_op(opcode);
switch (opcode)
@ -11600,6 +11625,12 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
break;
case OpSNegate:
if (implicit_integer_promotion || expression_type_id(ops[2]) != ops[0])
GLSL_UOP_CAST(-);
else
GLSL_UOP(-);
break;
case OpFNegate:
GLSL_UOP(-);
break;
@ -11744,6 +11775,9 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
auto expr = join(to_enclosed_expression(op0), " - ", to_enclosed_expression(op1), " * ", "(",
to_enclosed_expression(op0), " / ", to_enclosed_expression(op1), ")");
if (implicit_integer_promotion)
expr = join(type_to_glsl(get<SPIRType>(result_type)), '(', expr, ')');
emit_op(result_type, result_id, expr, forward);
inherit_expression_dependencies(result_id, op0);
inherit_expression_dependencies(result_id, op1);
@ -11841,7 +11875,10 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
}
case OpNot:
GLSL_UOP(~);
if (implicit_integer_promotion || expression_type_id(ops[2]) != ops[0])
GLSL_UOP_CAST(~);
else
GLSL_UOP(~);
break;
case OpUMod:

View File

@ -619,6 +619,7 @@ protected:
bool support_64bit_switch = false;
bool workgroup_size_is_hidden = false;
bool requires_relaxed_precision_analysis = false;
bool implicit_c_integer_promotion_rules = false;
} backend;
void emit_struct(SPIRType &type);
@ -691,7 +692,7 @@ protected:
void emit_unrolled_binary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op,
bool negate, SPIRType::BaseType expected_type);
void emit_binary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op,
SPIRType::BaseType input_type, bool skip_cast_if_equal_type);
SPIRType::BaseType input_type, bool skip_cast_if_equal_type, bool implicit_integer_promotion);
SPIRType binary_op_bitcast_helper(std::string &cast_op0, std::string &cast_op1, SPIRType::BaseType &input_type,
uint32_t op0, uint32_t op1, bool skip_cast_if_equal_type);
@ -702,6 +703,7 @@ protected:
uint32_t false_value);
void emit_unary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op);
void emit_unary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op);
bool expression_is_forwarded(uint32_t id) const;
bool expression_suppresses_usage_tracking(uint32_t id) const;
bool expression_read_implies_multiple_reads(uint32_t id) const;

View File

@ -4965,7 +4965,7 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
#define HLSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
#define HLSL_BOP_CAST(op, type) \
emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode), false)
#define HLSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
#define HLSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
#define HLSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)

View File

@ -1439,6 +1439,7 @@ string CompilerMSL::compile()
// Arrays which are part of buffer objects are never considered to be value types (just plain C-style).
backend.array_is_value_type_in_buffer_blocks = false;
backend.support_pointer_to_pointer = true;
backend.implicit_c_integer_promotion_rules = true;
capture_output_to_buffer = msl_options.capture_output_to_buffer;
is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer;
@ -8167,8 +8168,9 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
{
#define MSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
#define MSL_PTR_BOP(op) emit_binary_ptr_op(ops[0], ops[1], ops[2], ops[3], #op)
// MSL does care about implicit integer promotion, but those cases are all handled in common code.
#define MSL_BOP_CAST(op, type) \
emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode), false)
#define MSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
#define MSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
#define MSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)