MSL: Handle missing FP16 trancendental overloads.

This commit is contained in:
Hans-Kristian Arntzen 2024-04-29 11:48:46 +02:00
parent 71fe131ed0
commit cbaa86982a
5 changed files with 204 additions and 8 deletions

View File

@ -0,0 +1,64 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct main0_out
{
half C [[color(0)]];
half D [[color(1)]];
};
struct main0_in
{
half A [[user(locn0)]];
half B [[user(locn1)]];
};
fragment main0_out main0(main0_in in [[stage_in]])
{
main0_out out = {};
out.D = half(0.0);
out.C = clamp(sin(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(cos(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(tan(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(asin(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(acos(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(atan(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(half(fast::sinh(in.A)), in.A, in.B);
out.D += out.C;
out.C = clamp(half(fast::cosh(in.A)), in.A, in.B);
out.D += out.C;
out.C = clamp(half(fast::tanh(in.A)), in.A, in.B);
out.D += out.C;
out.C = clamp(asinh(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(acosh(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(atanh(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(half(fast::atan2(in.A, in.B)), in.A, in.B);
out.D += out.C;
out.C = clamp(powr(in.A, in.B), in.A, in.B);
out.D += out.C;
out.C = clamp(exp(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(exp2(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(log(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(log2(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(sqrt(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(rsqrt(in.A), in.A, in.B);
out.D += out.C;
return out;
}

View File

@ -94,11 +94,11 @@ void test_builtins(thread half4& v4, thread half3& v3, thread half& v1)
res = cos(v4); res = cos(v4);
res = tan(v4); res = tan(v4);
res = asin(v4); res = asin(v4);
res = precise::atan2(v4, v3.xyzz); res = half(fast::atan2(v4, v3.xyzz));
res = atan(v4); res = atan(v4);
res = fast::sinh(v4); res = half(fast::sinh(v4));
res = fast::cosh(v4); res = half(fast::cosh(v4));
res = precise::tanh(v4); res = half(fast::tanh(v4));
res = asinh(v4); res = asinh(v4);
res = acosh(v4); res = acosh(v4);
res = atanh(v4); res = atanh(v4);

View File

@ -0,0 +1,64 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct main0_out
{
half C [[color(0)]];
half D [[color(1)]];
};
struct main0_in
{
half A [[user(locn0)]];
half B [[user(locn1)]];
};
fragment main0_out main0(main0_in in [[stage_in]])
{
main0_out out = {};
out.D = half(0.0);
out.C = clamp(sin(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(cos(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(tan(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(asin(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(acos(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(atan(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(half(fast::sinh(in.A)), in.A, in.B);
out.D += out.C;
out.C = clamp(half(fast::cosh(in.A)), in.A, in.B);
out.D += out.C;
out.C = clamp(half(fast::tanh(in.A)), in.A, in.B);
out.D += out.C;
out.C = clamp(asinh(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(acosh(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(atanh(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(half(fast::atan2(in.A, in.B)), in.A, in.B);
out.D += out.C;
out.C = clamp(powr(in.A, in.B), in.A, in.B);
out.D += out.C;
out.C = clamp(exp(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(exp2(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(log(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(log2(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(sqrt(in.A), in.A, in.B);
out.D += out.C;
out.C = clamp(rsqrt(in.A), in.A, in.B);
out.D += out.C;
return out;
}

View File

@ -0,0 +1,33 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
layout(location = 0) in float16_t A;
layout(location = 1) in float16_t B;
layout(location = 0) out float16_t C;
layout(location = 1) out float16_t D;
void main()
{
D = 0.0hf;
C = clamp(sin(A), A, B); D += C;
C = clamp(cos(A), A, B); D += C;
C = clamp(tan(A), A, B); D += C;
C = clamp(asin(A), A, B); D += C;
C = clamp(acos(A), A, B); D += C;
C = clamp(atan(A), A, B); D += C;
C = clamp(sinh(A), A, B); D += C;
C = clamp(cosh(A), A, B); D += C;
C = clamp(tanh(A), A, B); D += C;
C = clamp(asinh(A), A, B); D += C;
C = clamp(acosh(A), A, B); D += C;
C = clamp(atanh(A), A, B); D += C;
C = clamp(atan(A, B), A, B); D += C;
C = clamp(pow(A, B), A, B); D += C;
C = clamp(exp(A), A, B); D += C;
C = clamp(exp2(A), A, B); D += C;
C = clamp(log(A), A, B); D += C;
C = clamp(log2(A), A, B); D += C;
C = clamp(sqrt(A), A, B); D += C;
C = clamp(inversesqrt(A), A, B); D += C;
}

View File

@ -10405,19 +10405,54 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
op = get_remapped_glsl_op(op); op = get_remapped_glsl_op(op);
auto &restype = get<SPIRType>(result_type);
switch (op) switch (op)
{ {
case GLSLstd450Sinh: case GLSLstd450Sinh:
emit_unary_func_op(result_type, id, args[0], "fast::sinh"); if (restype.basetype == SPIRType::Half)
{
// MSL does not have overload for half. Force-cast back to half.
auto expr = join("half(fast::sinh(", to_unpacked_expression(args[0]), "))");
emit_op(result_type, id, expr, should_forward(args[0]));
inherit_expression_dependencies(id, args[0]);
}
else
emit_unary_func_op(result_type, id, args[0], "fast::sinh");
break; break;
case GLSLstd450Cosh: case GLSLstd450Cosh:
emit_unary_func_op(result_type, id, args[0], "fast::cosh"); if (restype.basetype == SPIRType::Half)
{
// MSL does not have overload for half. Force-cast back to half.
auto expr = join("half(fast::cosh(", to_unpacked_expression(args[0]), "))");
emit_op(result_type, id, expr, should_forward(args[0]));
inherit_expression_dependencies(id, args[0]);
}
else
emit_unary_func_op(result_type, id, args[0], "fast::cosh");
break; break;
case GLSLstd450Tanh: case GLSLstd450Tanh:
emit_unary_func_op(result_type, id, args[0], "precise::tanh"); if (restype.basetype == SPIRType::Half)
{
// MSL does not have overload for half. Force-cast back to half.
auto expr = join("half(fast::tanh(", to_unpacked_expression(args[0]), "))");
emit_op(result_type, id, expr, should_forward(args[0]));
inherit_expression_dependencies(id, args[0]);
}
else
emit_unary_func_op(result_type, id, args[0], "precise::tanh");
break; break;
case GLSLstd450Atan2: case GLSLstd450Atan2:
emit_binary_func_op(result_type, id, args[0], args[1], "precise::atan2"); if (restype.basetype == SPIRType::Half)
{
// MSL does not have overload for half. Force-cast back to half.
auto expr = join("half(fast::atan2(", to_unpacked_expression(args[0]), ", ", to_unpacked_expression(args[1]), "))");
emit_op(result_type, id, expr, should_forward(args[0]) && should_forward(args[1]));
inherit_expression_dependencies(id, args[0]);
inherit_expression_dependencies(id, args[1]);
}
else
emit_binary_func_op(result_type, id, args[0], args[1], "precise::atan2");
break; break;
case GLSLstd450InverseSqrt: case GLSLstd450InverseSqrt:
emit_unary_func_op(result_type, id, args[0], "rsqrt"); emit_unary_func_op(result_type, id, args[0], "rsqrt");