From cbaa86982a03719a80db58955a952f7c4d18420c Mon Sep 17 00:00:00 2001 From: Hans-Kristian Arntzen Date: Mon, 29 Apr 2024 11:48:46 +0200 Subject: [PATCH] MSL: Handle missing FP16 trancendental overloads. --- .../shaders-msl/frag/fp16-trancendentals.frag | 64 +++++++++++++++++++ .../frag/fp16.desktop.invalid.frag | 8 +-- .../shaders-msl/frag/fp16-trancendentals.frag | 64 +++++++++++++++++++ shaders-msl/frag/fp16-trancendentals.frag | 33 ++++++++++ spirv_msl.cpp | 43 +++++++++++-- 5 files changed, 204 insertions(+), 8 deletions(-) create mode 100644 reference/opt/shaders-msl/frag/fp16-trancendentals.frag create mode 100644 reference/shaders-msl/frag/fp16-trancendentals.frag create mode 100644 shaders-msl/frag/fp16-trancendentals.frag diff --git a/reference/opt/shaders-msl/frag/fp16-trancendentals.frag b/reference/opt/shaders-msl/frag/fp16-trancendentals.frag new file mode 100644 index 00000000..a661433e --- /dev/null +++ b/reference/opt/shaders-msl/frag/fp16-trancendentals.frag @@ -0,0 +1,64 @@ +#include +#include + +using namespace metal; + +struct main0_out +{ + half C [[color(0)]]; + half D [[color(1)]]; +}; + +struct main0_in +{ + half A [[user(locn0)]]; + half B [[user(locn1)]]; +}; + +fragment main0_out main0(main0_in in [[stage_in]]) +{ + main0_out out = {}; + out.D = half(0.0); + out.C = clamp(sin(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(cos(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(tan(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(asin(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(acos(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(atan(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(half(fast::sinh(in.A)), in.A, in.B); + out.D += out.C; + out.C = clamp(half(fast::cosh(in.A)), in.A, in.B); + out.D += out.C; + out.C = clamp(half(fast::tanh(in.A)), in.A, in.B); + out.D += out.C; + out.C = clamp(asinh(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(acosh(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(atanh(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(half(fast::atan2(in.A, in.B)), in.A, in.B); + out.D += out.C; + out.C = clamp(powr(in.A, in.B), in.A, in.B); + out.D += out.C; + out.C = clamp(exp(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(exp2(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(log(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(log2(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(sqrt(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(rsqrt(in.A), in.A, in.B); + out.D += out.C; + return out; +} + diff --git a/reference/shaders-msl-no-opt/frag/fp16.desktop.invalid.frag b/reference/shaders-msl-no-opt/frag/fp16.desktop.invalid.frag index b7eae127..fe6b4eb8 100644 --- a/reference/shaders-msl-no-opt/frag/fp16.desktop.invalid.frag +++ b/reference/shaders-msl-no-opt/frag/fp16.desktop.invalid.frag @@ -94,11 +94,11 @@ void test_builtins(thread half4& v4, thread half3& v3, thread half& v1) res = cos(v4); res = tan(v4); res = asin(v4); - res = precise::atan2(v4, v3.xyzz); + res = half(fast::atan2(v4, v3.xyzz)); res = atan(v4); - res = fast::sinh(v4); - res = fast::cosh(v4); - res = precise::tanh(v4); + res = half(fast::sinh(v4)); + res = half(fast::cosh(v4)); + res = half(fast::tanh(v4)); res = asinh(v4); res = acosh(v4); res = atanh(v4); diff --git a/reference/shaders-msl/frag/fp16-trancendentals.frag b/reference/shaders-msl/frag/fp16-trancendentals.frag new file mode 100644 index 00000000..a661433e --- /dev/null +++ b/reference/shaders-msl/frag/fp16-trancendentals.frag @@ -0,0 +1,64 @@ +#include +#include + +using namespace metal; + +struct main0_out +{ + half C [[color(0)]]; + half D [[color(1)]]; +}; + +struct main0_in +{ + half A [[user(locn0)]]; + half B [[user(locn1)]]; +}; + +fragment main0_out main0(main0_in in [[stage_in]]) +{ + main0_out out = {}; + out.D = half(0.0); + out.C = clamp(sin(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(cos(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(tan(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(asin(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(acos(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(atan(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(half(fast::sinh(in.A)), in.A, in.B); + out.D += out.C; + out.C = clamp(half(fast::cosh(in.A)), in.A, in.B); + out.D += out.C; + out.C = clamp(half(fast::tanh(in.A)), in.A, in.B); + out.D += out.C; + out.C = clamp(asinh(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(acosh(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(atanh(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(half(fast::atan2(in.A, in.B)), in.A, in.B); + out.D += out.C; + out.C = clamp(powr(in.A, in.B), in.A, in.B); + out.D += out.C; + out.C = clamp(exp(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(exp2(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(log(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(log2(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(sqrt(in.A), in.A, in.B); + out.D += out.C; + out.C = clamp(rsqrt(in.A), in.A, in.B); + out.D += out.C; + return out; +} + diff --git a/shaders-msl/frag/fp16-trancendentals.frag b/shaders-msl/frag/fp16-trancendentals.frag new file mode 100644 index 00000000..ea4065d2 --- /dev/null +++ b/shaders-msl/frag/fp16-trancendentals.frag @@ -0,0 +1,33 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require + +layout(location = 0) in float16_t A; +layout(location = 1) in float16_t B; +layout(location = 0) out float16_t C; +layout(location = 1) out float16_t D; + +void main() +{ + D = 0.0hf; + C = clamp(sin(A), A, B); D += C; + C = clamp(cos(A), A, B); D += C; + C = clamp(tan(A), A, B); D += C; + C = clamp(asin(A), A, B); D += C; + C = clamp(acos(A), A, B); D += C; + C = clamp(atan(A), A, B); D += C; + C = clamp(sinh(A), A, B); D += C; + C = clamp(cosh(A), A, B); D += C; + C = clamp(tanh(A), A, B); D += C; + C = clamp(asinh(A), A, B); D += C; + C = clamp(acosh(A), A, B); D += C; + C = clamp(atanh(A), A, B); D += C; + C = clamp(atan(A, B), A, B); D += C; + C = clamp(pow(A, B), A, B); D += C; + C = clamp(exp(A), A, B); D += C; + C = clamp(exp2(A), A, B); D += C; + C = clamp(log(A), A, B); D += C; + C = clamp(log2(A), A, B); D += C; + C = clamp(sqrt(A), A, B); D += C; + C = clamp(inversesqrt(A), A, B); D += C; +} + diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 258f38fb..cd2bb501 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -10405,19 +10405,54 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, op = get_remapped_glsl_op(op); + auto &restype = get(result_type); + switch (op) { case GLSLstd450Sinh: - emit_unary_func_op(result_type, id, args[0], "fast::sinh"); + if (restype.basetype == SPIRType::Half) + { + // MSL does not have overload for half. Force-cast back to half. + auto expr = join("half(fast::sinh(", to_unpacked_expression(args[0]), "))"); + emit_op(result_type, id, expr, should_forward(args[0])); + inherit_expression_dependencies(id, args[0]); + } + else + emit_unary_func_op(result_type, id, args[0], "fast::sinh"); break; case GLSLstd450Cosh: - emit_unary_func_op(result_type, id, args[0], "fast::cosh"); + if (restype.basetype == SPIRType::Half) + { + // MSL does not have overload for half. Force-cast back to half. + auto expr = join("half(fast::cosh(", to_unpacked_expression(args[0]), "))"); + emit_op(result_type, id, expr, should_forward(args[0])); + inherit_expression_dependencies(id, args[0]); + } + else + emit_unary_func_op(result_type, id, args[0], "fast::cosh"); break; case GLSLstd450Tanh: - emit_unary_func_op(result_type, id, args[0], "precise::tanh"); + if (restype.basetype == SPIRType::Half) + { + // MSL does not have overload for half. Force-cast back to half. + auto expr = join("half(fast::tanh(", to_unpacked_expression(args[0]), "))"); + emit_op(result_type, id, expr, should_forward(args[0])); + inherit_expression_dependencies(id, args[0]); + } + else + emit_unary_func_op(result_type, id, args[0], "precise::tanh"); break; case GLSLstd450Atan2: - emit_binary_func_op(result_type, id, args[0], args[1], "precise::atan2"); + if (restype.basetype == SPIRType::Half) + { + // MSL does not have overload for half. Force-cast back to half. + auto expr = join("half(fast::atan2(", to_unpacked_expression(args[0]), ", ", to_unpacked_expression(args[1]), "))"); + emit_op(result_type, id, expr, should_forward(args[0]) && should_forward(args[1])); + inherit_expression_dependencies(id, args[0]); + inherit_expression_dependencies(id, args[1]); + } + else + emit_binary_func_op(result_type, id, args[0], args[1], "precise::atan2"); break; case GLSLstd450InverseSqrt: emit_unary_func_op(result_type, id, args[0], "rsqrt");