MSL/HLSL: Support scalar reflect and refract.

This commit is contained in:
Hans-Kristian Arntzen 2019-07-03 12:24:58 +02:00
parent 9a6e2534e9
commit 041f103d44
13 changed files with 368 additions and 1 deletions

View File

@ -0,0 +1,47 @@
static float FragColor;
static float3 vRefract;
struct SPIRV_Cross_Input
{
float3 vRefract : TEXCOORD0;
};
struct SPIRV_Cross_Output
{
float FragColor : SV_Target0;
};
float SPIRV_Cross_Reflect(float i, float n)
{
return i - 2.0 * dot(n, i) * n;
}
float SPIRV_Cross_Refract(float i, float n, float eta)
{
float k = 1.0 - eta * eta * (1.0 - dot(n, i) * dot(n, i));
if (k < 0.0)
{
return 0.0;
}
else
{
return eta * i - (eta * dot(n, i) + sqrt(k)) * n;
}
}
void frag_main()
{
FragColor = SPIRV_Cross_Refract(vRefract.x, vRefract.y, vRefract.z);
FragColor += SPIRV_Cross_Reflect(vRefract.x, vRefract.y);
FragColor += refract(vRefract.xy, vRefract.yz, vRefract.z).y;
FragColor += reflect(vRefract.xy, vRefract.zy).y;
}
SPIRV_Cross_Output main(SPIRV_Cross_Input stage_input)
{
vRefract = stage_input.vRefract;
frag_main();
SPIRV_Cross_Output stage_output;
stage_output.FragColor = FragColor;
return stage_output;
}

View File

@ -0,0 +1,47 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct main0_out
{
float FragColor [[color(0)]];
};
struct main0_in
{
float3 vRefract [[user(locn0)]];
};
template<typename T>
inline T spvReflect(T i, T n)
{
return i - T(2) * dot(n, i) * n;
}
template<typename T>
inline T spvRefract(T i, T n, T eta)
{
T k = T(1) - eta * eta * (T(1) - dot(n, i) * dot(n, i));
if (k < T(0))
{
return T(0);
}
else
{
return eta * i - (eta * dot(n, i) + sqrt(k)) * n;
}
}
fragment main0_out main0(main0_in in [[stage_in]])
{
main0_out out = {};
out.FragColor = spvRefract(in.vRefract.x, in.vRefract.y, in.vRefract.z);
out.FragColor += spvReflect(in.vRefract.x, in.vRefract.y);
out.FragColor += refract(in.vRefract.xy, in.vRefract.yz, in.vRefract.z).y;
out.FragColor += reflect(in.vRefract.xy, in.vRefract.zy).y;
return out;
}

View File

@ -0,0 +1,13 @@
#version 450
layout(location = 0) out float FragColor;
layout(location = 0) in vec3 vRefract;
void main()
{
FragColor = refract(vRefract.x, vRefract.y, vRefract.z);
FragColor += reflect(vRefract.x, vRefract.y);
FragColor += refract(vRefract.xy, vRefract.yz, vRefract.z).y;
FragColor += reflect(vRefract.xy, vRefract.zy).y;
}

View File

@ -0,0 +1,47 @@
static float FragColor;
static float3 vRefract;
struct SPIRV_Cross_Input
{
float3 vRefract : TEXCOORD0;
};
struct SPIRV_Cross_Output
{
float FragColor : SV_Target0;
};
float SPIRV_Cross_Reflect(float i, float n)
{
return i - 2.0 * dot(n, i) * n;
}
float SPIRV_Cross_Refract(float i, float n, float eta)
{
float k = 1.0 - eta * eta * (1.0 - dot(n, i) * dot(n, i));
if (k < 0.0)
{
return 0.0;
}
else
{
return eta * i - (eta * dot(n, i) + sqrt(k)) * n;
}
}
void frag_main()
{
FragColor = SPIRV_Cross_Refract(vRefract.x, vRefract.y, vRefract.z);
FragColor += SPIRV_Cross_Reflect(vRefract.x, vRefract.y);
FragColor += refract(vRefract.xy, vRefract.yz, vRefract.z).y;
FragColor += reflect(vRefract.xy, vRefract.zy).y;
}
SPIRV_Cross_Output main(SPIRV_Cross_Input stage_input)
{
vRefract = stage_input.vRefract;
frag_main();
SPIRV_Cross_Output stage_output;
stage_output.FragColor = FragColor;
return stage_output;
}

View File

@ -0,0 +1,47 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct main0_out
{
float FragColor [[color(0)]];
};
struct main0_in
{
float3 vRefract [[user(locn0)]];
};
template<typename T>
inline T spvReflect(T i, T n)
{
return i - T(2) * dot(n, i) * n;
}
template<typename T>
inline T spvRefract(T i, T n, T eta)
{
T k = T(1) - eta * eta * (T(1) - dot(n, i) * dot(n, i));
if (k < T(0))
{
return T(0);
}
else
{
return eta * i - (eta * dot(n, i) + sqrt(k)) * n;
}
}
fragment main0_out main0(main0_in in [[stage_in]])
{
main0_out out = {};
out.FragColor = spvRefract(in.vRefract.x, in.vRefract.y, in.vRefract.z);
out.FragColor += spvReflect(in.vRefract.x, in.vRefract.y);
out.FragColor += refract(in.vRefract.xy, in.vRefract.yz, in.vRefract.z).y;
out.FragColor += reflect(in.vRefract.xy, in.vRefract.zy).y;
return out;
}

View File

@ -0,0 +1,13 @@
#version 450
layout(location = 0) out float FragColor;
layout(location = 0) in vec3 vRefract;
void main()
{
FragColor = refract(vRefract.x, vRefract.y, vRefract.z);
FragColor += reflect(vRefract.x, vRefract.y);
FragColor += refract(vRefract.xy, vRefract.yz, vRefract.z).y;
FragColor += reflect(vRefract.xy, vRefract.zy).y;
}

View File

@ -0,0 +1,11 @@
#version 450
layout(location = 0) out float FragColor;
layout(location = 0) in vec3 vRefract;
void main()
{
FragColor = refract(vRefract.x, vRefract.y, vRefract.z);
FragColor += reflect(vRefract.x, vRefract.y);
FragColor += refract(vRefract.xy, vRefract.yz, vRefract.z).y;
FragColor += reflect(vRefract.xy, vRefract.zy).y;
}

View File

@ -0,0 +1,11 @@
#version 450
layout(location = 0) out float FragColor;
layout(location = 0) in vec3 vRefract;
void main()
{
FragColor = refract(vRefract.x, vRefract.y, vRefract.z);
FragColor += reflect(vRefract.x, vRefract.y);
FragColor += refract(vRefract.xy, vRefract.yz, vRefract.z).y;
FragColor += reflect(vRefract.xy, vRefract.zy).y;
}

View File

@ -0,0 +1,11 @@
#version 450
layout(location = 0) out float FragColor;
layout(location = 0) in vec3 vRefract;
void main()
{
FragColor = refract(vRefract.x, vRefract.y, vRefract.z);
FragColor += reflect(vRefract.x, vRefract.y);
FragColor += refract(vRefract.xy, vRefract.yz, vRefract.z).y;
FragColor += reflect(vRefract.xy, vRefract.zy).y;
}

View File

@ -1742,6 +1742,34 @@ void CompilerHLSL::emit_resources()
end_scope();
statement("");
}
if (requires_scalar_reflect)
{
// FP16/FP64? No templates in HLSL.
statement("float SPIRV_Cross_Reflect(float i, float n)");
begin_scope();
statement("return i - 2.0 * dot(n, i) * n;");
end_scope();
statement("");
}
if (requires_scalar_refract)
{
// FP16/FP64? No templates in HLSL.
statement("float SPIRV_Cross_Refract(float i, float n, float eta)");
begin_scope();
statement("float k = 1.0 - eta * eta * (1.0 - dot(n, i) * dot(n, i));");
statement("if (k < 0.0)");
begin_scope();
statement("return 0.0;");
end_scope();
statement("else");
begin_scope();
statement("return eta * i - (eta * dot(n, i) + sqrt(k)) * n;");
end_scope();
end_scope();
statement("");
}
}
string CompilerHLSL::layout_for_member(const SPIRType &type, uint32_t index)
@ -3256,6 +3284,34 @@ void CompilerHLSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
break;
case GLSLstd450Reflect:
if (get<SPIRType>(result_type).vecsize == 1)
{
if (!requires_scalar_reflect)
{
requires_scalar_reflect = true;
force_recompile();
}
emit_binary_func_op(result_type, id, args[0], args[1], "SPIRV_Cross_Reflect");
}
else
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
break;
case GLSLstd450Refract:
if (get<SPIRType>(result_type).vecsize == 1)
{
if (!requires_scalar_refract)
{
requires_scalar_refract = true;
force_recompile();
}
emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "SPIRV_Cross_Refract");
}
else
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
break;
default:
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
break;

View File

@ -167,6 +167,8 @@ private:
void replace_illegal_names() override;
Options hlsl_options;
// TODO: Refactor this to be more similar to MSL, maybe have some common system in place?
bool requires_op_fmod = false;
bool requires_fp16_packing = false;
bool requires_explicit_fp16_packing = false;
@ -179,6 +181,8 @@ private:
bool requires_inverse_2x2 = false;
bool requires_inverse_3x3 = false;
bool requires_inverse_4x4 = false;
bool requires_scalar_reflect = false;
bool requires_scalar_refract = false;
uint64_t required_textureSizeVariants = 0;
void require_texture_query_variant(const SPIRType &type);

View File

@ -3189,6 +3189,34 @@ void CompilerMSL::emit_custom_functions()
statement("");
break;
case SPVFuncImplReflectScalar:
// Metal does not support scalar versions of these functions.
statement("template<typename T>");
statement("inline T spvReflect(T i, T n)");
begin_scope();
statement("return i - T(2) * dot(n, i) * n;");
end_scope();
statement("");
break;
case SPVFuncImplRefractScalar:
// Metal does not support scalar versions of these functions.
statement("template<typename T>");
statement("inline T spvRefract(T i, T n, T eta)");
begin_scope();
statement("T k = T(1) - eta * eta * (T(1) - dot(n, i) * dot(n, i));");
statement("if (k < T(0))");
begin_scope();
statement("return T(0);");
end_scope();
statement("else");
begin_scope();
statement("return eta * i - (eta * dot(n, i) + sqrt(k)) * n;");
end_scope();
end_scope();
statement("");
break;
default:
break;
}
@ -4709,6 +4737,20 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
break;
case GLSLstd450Reflect:
if (get<SPIRType>(result_type).vecsize == 1)
emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
else
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
break;
case GLSLstd450Refract:
if (get<SPIRType>(result_type).vecsize == 1)
emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
else
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
break;
default:
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
break;
@ -8660,7 +8702,7 @@ CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op o
uint32_t extension_set = args[2];
if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
{
GLSLstd450 op_450 = static_cast<GLSLstd450>(args[3]);
auto op_450 = static_cast<GLSLstd450>(args[3]);
switch (op_450)
{
case GLSLstd450Radians:
@ -8675,6 +8717,22 @@ CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op o
return SPVFuncImplFindUMsb;
case GLSLstd450SSign:
return SPVFuncImplSSign;
case GLSLstd450Reflect:
{
auto &type = compiler.get<SPIRType>(args[0]);
if (type.vecsize == 1)
return SPVFuncImplReflectScalar;
else
return SPVFuncImplNone;
}
case GLSLstd450Refract:
{
auto &type = compiler.get<SPIRType>(args[0]);
if (type.vecsize == 1)
return SPVFuncImplRefractScalar;
else
return SPVFuncImplNone;
}
case GLSLstd450MatrixInverse:
{
auto &mat_type = compiler.get<SPIRType>(args[0]);

View File

@ -401,6 +401,8 @@ protected:
SPVFuncImplSubgroupBallotFindMSB,
SPVFuncImplSubgroupBallotBitCount,
SPVFuncImplSubgroupAllEqual,
SPVFuncImplReflectScalar,
SPVFuncImplRefractScalar,
SPVFuncImplArrayCopyMultidimMax = 6
};