MSL: Handle missing FP16 trancendental overloads.
This commit is contained in:
parent
71fe131ed0
commit
cbaa86982a
64
reference/opt/shaders-msl/frag/fp16-trancendentals.frag
Normal file
64
reference/opt/shaders-msl/frag/fp16-trancendentals.frag
Normal file
@ -0,0 +1,64 @@
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct main0_out
|
||||
{
|
||||
half C [[color(0)]];
|
||||
half D [[color(1)]];
|
||||
};
|
||||
|
||||
struct main0_in
|
||||
{
|
||||
half A [[user(locn0)]];
|
||||
half B [[user(locn1)]];
|
||||
};
|
||||
|
||||
fragment main0_out main0(main0_in in [[stage_in]])
|
||||
{
|
||||
main0_out out = {};
|
||||
out.D = half(0.0);
|
||||
out.C = clamp(sin(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(cos(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(tan(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(asin(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(acos(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(atan(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(half(fast::sinh(in.A)), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(half(fast::cosh(in.A)), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(half(fast::tanh(in.A)), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(asinh(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(acosh(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(atanh(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(half(fast::atan2(in.A, in.B)), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(powr(in.A, in.B), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(exp(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(exp2(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(log(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(log2(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(sqrt(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(rsqrt(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
return out;
|
||||
}
|
||||
|
@ -94,11 +94,11 @@ void test_builtins(thread half4& v4, thread half3& v3, thread half& v1)
|
||||
res = cos(v4);
|
||||
res = tan(v4);
|
||||
res = asin(v4);
|
||||
res = precise::atan2(v4, v3.xyzz);
|
||||
res = half(fast::atan2(v4, v3.xyzz));
|
||||
res = atan(v4);
|
||||
res = fast::sinh(v4);
|
||||
res = fast::cosh(v4);
|
||||
res = precise::tanh(v4);
|
||||
res = half(fast::sinh(v4));
|
||||
res = half(fast::cosh(v4));
|
||||
res = half(fast::tanh(v4));
|
||||
res = asinh(v4);
|
||||
res = acosh(v4);
|
||||
res = atanh(v4);
|
||||
|
64
reference/shaders-msl/frag/fp16-trancendentals.frag
Normal file
64
reference/shaders-msl/frag/fp16-trancendentals.frag
Normal file
@ -0,0 +1,64 @@
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct main0_out
|
||||
{
|
||||
half C [[color(0)]];
|
||||
half D [[color(1)]];
|
||||
};
|
||||
|
||||
struct main0_in
|
||||
{
|
||||
half A [[user(locn0)]];
|
||||
half B [[user(locn1)]];
|
||||
};
|
||||
|
||||
fragment main0_out main0(main0_in in [[stage_in]])
|
||||
{
|
||||
main0_out out = {};
|
||||
out.D = half(0.0);
|
||||
out.C = clamp(sin(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(cos(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(tan(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(asin(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(acos(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(atan(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(half(fast::sinh(in.A)), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(half(fast::cosh(in.A)), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(half(fast::tanh(in.A)), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(asinh(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(acosh(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(atanh(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(half(fast::atan2(in.A, in.B)), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(powr(in.A, in.B), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(exp(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(exp2(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(log(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(log2(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(sqrt(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
out.C = clamp(rsqrt(in.A), in.A, in.B);
|
||||
out.D += out.C;
|
||||
return out;
|
||||
}
|
||||
|
33
shaders-msl/frag/fp16-trancendentals.frag
Normal file
33
shaders-msl/frag/fp16-trancendentals.frag
Normal file
@ -0,0 +1,33 @@
|
||||
#version 450
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
|
||||
layout(location = 0) in float16_t A;
|
||||
layout(location = 1) in float16_t B;
|
||||
layout(location = 0) out float16_t C;
|
||||
layout(location = 1) out float16_t D;
|
||||
|
||||
void main()
|
||||
{
|
||||
D = 0.0hf;
|
||||
C = clamp(sin(A), A, B); D += C;
|
||||
C = clamp(cos(A), A, B); D += C;
|
||||
C = clamp(tan(A), A, B); D += C;
|
||||
C = clamp(asin(A), A, B); D += C;
|
||||
C = clamp(acos(A), A, B); D += C;
|
||||
C = clamp(atan(A), A, B); D += C;
|
||||
C = clamp(sinh(A), A, B); D += C;
|
||||
C = clamp(cosh(A), A, B); D += C;
|
||||
C = clamp(tanh(A), A, B); D += C;
|
||||
C = clamp(asinh(A), A, B); D += C;
|
||||
C = clamp(acosh(A), A, B); D += C;
|
||||
C = clamp(atanh(A), A, B); D += C;
|
||||
C = clamp(atan(A, B), A, B); D += C;
|
||||
C = clamp(pow(A, B), A, B); D += C;
|
||||
C = clamp(exp(A), A, B); D += C;
|
||||
C = clamp(exp2(A), A, B); D += C;
|
||||
C = clamp(log(A), A, B); D += C;
|
||||
C = clamp(log2(A), A, B); D += C;
|
||||
C = clamp(sqrt(A), A, B); D += C;
|
||||
C = clamp(inversesqrt(A), A, B); D += C;
|
||||
}
|
||||
|
@ -10405,19 +10405,54 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
|
||||
|
||||
op = get_remapped_glsl_op(op);
|
||||
|
||||
auto &restype = get<SPIRType>(result_type);
|
||||
|
||||
switch (op)
|
||||
{
|
||||
case GLSLstd450Sinh:
|
||||
emit_unary_func_op(result_type, id, args[0], "fast::sinh");
|
||||
if (restype.basetype == SPIRType::Half)
|
||||
{
|
||||
// MSL does not have overload for half. Force-cast back to half.
|
||||
auto expr = join("half(fast::sinh(", to_unpacked_expression(args[0]), "))");
|
||||
emit_op(result_type, id, expr, should_forward(args[0]));
|
||||
inherit_expression_dependencies(id, args[0]);
|
||||
}
|
||||
else
|
||||
emit_unary_func_op(result_type, id, args[0], "fast::sinh");
|
||||
break;
|
||||
case GLSLstd450Cosh:
|
||||
emit_unary_func_op(result_type, id, args[0], "fast::cosh");
|
||||
if (restype.basetype == SPIRType::Half)
|
||||
{
|
||||
// MSL does not have overload for half. Force-cast back to half.
|
||||
auto expr = join("half(fast::cosh(", to_unpacked_expression(args[0]), "))");
|
||||
emit_op(result_type, id, expr, should_forward(args[0]));
|
||||
inherit_expression_dependencies(id, args[0]);
|
||||
}
|
||||
else
|
||||
emit_unary_func_op(result_type, id, args[0], "fast::cosh");
|
||||
break;
|
||||
case GLSLstd450Tanh:
|
||||
emit_unary_func_op(result_type, id, args[0], "precise::tanh");
|
||||
if (restype.basetype == SPIRType::Half)
|
||||
{
|
||||
// MSL does not have overload for half. Force-cast back to half.
|
||||
auto expr = join("half(fast::tanh(", to_unpacked_expression(args[0]), "))");
|
||||
emit_op(result_type, id, expr, should_forward(args[0]));
|
||||
inherit_expression_dependencies(id, args[0]);
|
||||
}
|
||||
else
|
||||
emit_unary_func_op(result_type, id, args[0], "precise::tanh");
|
||||
break;
|
||||
case GLSLstd450Atan2:
|
||||
emit_binary_func_op(result_type, id, args[0], args[1], "precise::atan2");
|
||||
if (restype.basetype == SPIRType::Half)
|
||||
{
|
||||
// MSL does not have overload for half. Force-cast back to half.
|
||||
auto expr = join("half(fast::atan2(", to_unpacked_expression(args[0]), ", ", to_unpacked_expression(args[1]), "))");
|
||||
emit_op(result_type, id, expr, should_forward(args[0]) && should_forward(args[1]));
|
||||
inherit_expression_dependencies(id, args[0]);
|
||||
inherit_expression_dependencies(id, args[1]);
|
||||
}
|
||||
else
|
||||
emit_binary_func_op(result_type, id, args[0], args[1], "precise::atan2");
|
||||
break;
|
||||
case GLSLstd450InverseSqrt:
|
||||
emit_unary_func_op(result_type, id, args[0], "rsqrt");
|
||||
|
Loading…
Reference in New Issue
Block a user