MSL/HLSL: Support scalar reflect and refract.
This commit is contained in:
parent
9a6e2534e9
commit
041f103d44
47
reference/opt/shaders-hlsl/frag/scalar-refract-reflect.frag
Normal file
47
reference/opt/shaders-hlsl/frag/scalar-refract-reflect.frag
Normal file
@ -0,0 +1,47 @@
|
||||
static float FragColor;
|
||||
static float3 vRefract;
|
||||
|
||||
struct SPIRV_Cross_Input
|
||||
{
|
||||
float3 vRefract : TEXCOORD0;
|
||||
};
|
||||
|
||||
struct SPIRV_Cross_Output
|
||||
{
|
||||
float FragColor : SV_Target0;
|
||||
};
|
||||
|
||||
float SPIRV_Cross_Reflect(float i, float n)
|
||||
{
|
||||
return i - 2.0 * dot(n, i) * n;
|
||||
}
|
||||
|
||||
float SPIRV_Cross_Refract(float i, float n, float eta)
|
||||
{
|
||||
float k = 1.0 - eta * eta * (1.0 - dot(n, i) * dot(n, i));
|
||||
if (k < 0.0)
|
||||
{
|
||||
return 0.0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return eta * i - (eta * dot(n, i) + sqrt(k)) * n;
|
||||
}
|
||||
}
|
||||
|
||||
void frag_main()
|
||||
{
|
||||
FragColor = SPIRV_Cross_Refract(vRefract.x, vRefract.y, vRefract.z);
|
||||
FragColor += SPIRV_Cross_Reflect(vRefract.x, vRefract.y);
|
||||
FragColor += refract(vRefract.xy, vRefract.yz, vRefract.z).y;
|
||||
FragColor += reflect(vRefract.xy, vRefract.zy).y;
|
||||
}
|
||||
|
||||
SPIRV_Cross_Output main(SPIRV_Cross_Input stage_input)
|
||||
{
|
||||
vRefract = stage_input.vRefract;
|
||||
frag_main();
|
||||
SPIRV_Cross_Output stage_output;
|
||||
stage_output.FragColor = FragColor;
|
||||
return stage_output;
|
||||
}
|
47
reference/opt/shaders-msl/frag/scalar-refract-reflect.frag
Normal file
47
reference/opt/shaders-msl/frag/scalar-refract-reflect.frag
Normal file
@ -0,0 +1,47 @@
|
||||
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct main0_out
|
||||
{
|
||||
float FragColor [[color(0)]];
|
||||
};
|
||||
|
||||
struct main0_in
|
||||
{
|
||||
float3 vRefract [[user(locn0)]];
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
inline T spvReflect(T i, T n)
|
||||
{
|
||||
return i - T(2) * dot(n, i) * n;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline T spvRefract(T i, T n, T eta)
|
||||
{
|
||||
T k = T(1) - eta * eta * (T(1) - dot(n, i) * dot(n, i));
|
||||
if (k < T(0))
|
||||
{
|
||||
return T(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return eta * i - (eta * dot(n, i) + sqrt(k)) * n;
|
||||
}
|
||||
}
|
||||
|
||||
fragment main0_out main0(main0_in in [[stage_in]])
|
||||
{
|
||||
main0_out out = {};
|
||||
out.FragColor = spvRefract(in.vRefract.x, in.vRefract.y, in.vRefract.z);
|
||||
out.FragColor += spvReflect(in.vRefract.x, in.vRefract.y);
|
||||
out.FragColor += refract(in.vRefract.xy, in.vRefract.yz, in.vRefract.z).y;
|
||||
out.FragColor += reflect(in.vRefract.xy, in.vRefract.zy).y;
|
||||
return out;
|
||||
}
|
||||
|
13
reference/opt/shaders/frag/scalar-refract-reflect.frag
Normal file
13
reference/opt/shaders/frag/scalar-refract-reflect.frag
Normal file
@ -0,0 +1,13 @@
|
||||
#version 450
|
||||
|
||||
layout(location = 0) out float FragColor;
|
||||
layout(location = 0) in vec3 vRefract;
|
||||
|
||||
void main()
|
||||
{
|
||||
FragColor = refract(vRefract.x, vRefract.y, vRefract.z);
|
||||
FragColor += reflect(vRefract.x, vRefract.y);
|
||||
FragColor += refract(vRefract.xy, vRefract.yz, vRefract.z).y;
|
||||
FragColor += reflect(vRefract.xy, vRefract.zy).y;
|
||||
}
|
||||
|
47
reference/shaders-hlsl/frag/scalar-refract-reflect.frag
Normal file
47
reference/shaders-hlsl/frag/scalar-refract-reflect.frag
Normal file
@ -0,0 +1,47 @@
|
||||
static float FragColor;
|
||||
static float3 vRefract;
|
||||
|
||||
struct SPIRV_Cross_Input
|
||||
{
|
||||
float3 vRefract : TEXCOORD0;
|
||||
};
|
||||
|
||||
struct SPIRV_Cross_Output
|
||||
{
|
||||
float FragColor : SV_Target0;
|
||||
};
|
||||
|
||||
float SPIRV_Cross_Reflect(float i, float n)
|
||||
{
|
||||
return i - 2.0 * dot(n, i) * n;
|
||||
}
|
||||
|
||||
float SPIRV_Cross_Refract(float i, float n, float eta)
|
||||
{
|
||||
float k = 1.0 - eta * eta * (1.0 - dot(n, i) * dot(n, i));
|
||||
if (k < 0.0)
|
||||
{
|
||||
return 0.0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return eta * i - (eta * dot(n, i) + sqrt(k)) * n;
|
||||
}
|
||||
}
|
||||
|
||||
void frag_main()
|
||||
{
|
||||
FragColor = SPIRV_Cross_Refract(vRefract.x, vRefract.y, vRefract.z);
|
||||
FragColor += SPIRV_Cross_Reflect(vRefract.x, vRefract.y);
|
||||
FragColor += refract(vRefract.xy, vRefract.yz, vRefract.z).y;
|
||||
FragColor += reflect(vRefract.xy, vRefract.zy).y;
|
||||
}
|
||||
|
||||
SPIRV_Cross_Output main(SPIRV_Cross_Input stage_input)
|
||||
{
|
||||
vRefract = stage_input.vRefract;
|
||||
frag_main();
|
||||
SPIRV_Cross_Output stage_output;
|
||||
stage_output.FragColor = FragColor;
|
||||
return stage_output;
|
||||
}
|
47
reference/shaders-msl/frag/scalar-refract-reflect.frag
Normal file
47
reference/shaders-msl/frag/scalar-refract-reflect.frag
Normal file
@ -0,0 +1,47 @@
|
||||
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct main0_out
|
||||
{
|
||||
float FragColor [[color(0)]];
|
||||
};
|
||||
|
||||
struct main0_in
|
||||
{
|
||||
float3 vRefract [[user(locn0)]];
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
inline T spvReflect(T i, T n)
|
||||
{
|
||||
return i - T(2) * dot(n, i) * n;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline T spvRefract(T i, T n, T eta)
|
||||
{
|
||||
T k = T(1) - eta * eta * (T(1) - dot(n, i) * dot(n, i));
|
||||
if (k < T(0))
|
||||
{
|
||||
return T(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return eta * i - (eta * dot(n, i) + sqrt(k)) * n;
|
||||
}
|
||||
}
|
||||
|
||||
fragment main0_out main0(main0_in in [[stage_in]])
|
||||
{
|
||||
main0_out out = {};
|
||||
out.FragColor = spvRefract(in.vRefract.x, in.vRefract.y, in.vRefract.z);
|
||||
out.FragColor += spvReflect(in.vRefract.x, in.vRefract.y);
|
||||
out.FragColor += refract(in.vRefract.xy, in.vRefract.yz, in.vRefract.z).y;
|
||||
out.FragColor += reflect(in.vRefract.xy, in.vRefract.zy).y;
|
||||
return out;
|
||||
}
|
||||
|
13
reference/shaders/frag/scalar-refract-reflect.frag
Normal file
13
reference/shaders/frag/scalar-refract-reflect.frag
Normal file
@ -0,0 +1,13 @@
|
||||
#version 450
|
||||
|
||||
layout(location = 0) out float FragColor;
|
||||
layout(location = 0) in vec3 vRefract;
|
||||
|
||||
void main()
|
||||
{
|
||||
FragColor = refract(vRefract.x, vRefract.y, vRefract.z);
|
||||
FragColor += reflect(vRefract.x, vRefract.y);
|
||||
FragColor += refract(vRefract.xy, vRefract.yz, vRefract.z).y;
|
||||
FragColor += reflect(vRefract.xy, vRefract.zy).y;
|
||||
}
|
||||
|
11
shaders-hlsl/frag/scalar-refract-reflect.frag
Normal file
11
shaders-hlsl/frag/scalar-refract-reflect.frag
Normal file
@ -0,0 +1,11 @@
|
||||
#version 450
|
||||
layout(location = 0) out float FragColor;
|
||||
layout(location = 0) in vec3 vRefract;
|
||||
|
||||
void main()
|
||||
{
|
||||
FragColor = refract(vRefract.x, vRefract.y, vRefract.z);
|
||||
FragColor += reflect(vRefract.x, vRefract.y);
|
||||
FragColor += refract(vRefract.xy, vRefract.yz, vRefract.z).y;
|
||||
FragColor += reflect(vRefract.xy, vRefract.zy).y;
|
||||
}
|
11
shaders-msl/frag/scalar-refract-reflect.frag
Normal file
11
shaders-msl/frag/scalar-refract-reflect.frag
Normal file
@ -0,0 +1,11 @@
|
||||
#version 450
|
||||
layout(location = 0) out float FragColor;
|
||||
layout(location = 0) in vec3 vRefract;
|
||||
|
||||
void main()
|
||||
{
|
||||
FragColor = refract(vRefract.x, vRefract.y, vRefract.z);
|
||||
FragColor += reflect(vRefract.x, vRefract.y);
|
||||
FragColor += refract(vRefract.xy, vRefract.yz, vRefract.z).y;
|
||||
FragColor += reflect(vRefract.xy, vRefract.zy).y;
|
||||
}
|
11
shaders/frag/scalar-refract-reflect.frag
Normal file
11
shaders/frag/scalar-refract-reflect.frag
Normal file
@ -0,0 +1,11 @@
|
||||
#version 450
|
||||
layout(location = 0) out float FragColor;
|
||||
layout(location = 0) in vec3 vRefract;
|
||||
|
||||
void main()
|
||||
{
|
||||
FragColor = refract(vRefract.x, vRefract.y, vRefract.z);
|
||||
FragColor += reflect(vRefract.x, vRefract.y);
|
||||
FragColor += refract(vRefract.xy, vRefract.yz, vRefract.z).y;
|
||||
FragColor += reflect(vRefract.xy, vRefract.zy).y;
|
||||
}
|
@ -1742,6 +1742,34 @@ void CompilerHLSL::emit_resources()
|
||||
end_scope();
|
||||
statement("");
|
||||
}
|
||||
|
||||
if (requires_scalar_reflect)
|
||||
{
|
||||
// FP16/FP64? No templates in HLSL.
|
||||
statement("float SPIRV_Cross_Reflect(float i, float n)");
|
||||
begin_scope();
|
||||
statement("return i - 2.0 * dot(n, i) * n;");
|
||||
end_scope();
|
||||
statement("");
|
||||
}
|
||||
|
||||
if (requires_scalar_refract)
|
||||
{
|
||||
// FP16/FP64? No templates in HLSL.
|
||||
statement("float SPIRV_Cross_Refract(float i, float n, float eta)");
|
||||
begin_scope();
|
||||
statement("float k = 1.0 - eta * eta * (1.0 - dot(n, i) * dot(n, i));");
|
||||
statement("if (k < 0.0)");
|
||||
begin_scope();
|
||||
statement("return 0.0;");
|
||||
end_scope();
|
||||
statement("else");
|
||||
begin_scope();
|
||||
statement("return eta * i - (eta * dot(n, i) + sqrt(k)) * n;");
|
||||
end_scope();
|
||||
end_scope();
|
||||
statement("");
|
||||
}
|
||||
}
|
||||
|
||||
string CompilerHLSL::layout_for_member(const SPIRType &type, uint32_t index)
|
||||
@ -3256,6 +3284,34 @@ void CompilerHLSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
|
||||
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
|
||||
break;
|
||||
|
||||
case GLSLstd450Reflect:
|
||||
if (get<SPIRType>(result_type).vecsize == 1)
|
||||
{
|
||||
if (!requires_scalar_reflect)
|
||||
{
|
||||
requires_scalar_reflect = true;
|
||||
force_recompile();
|
||||
}
|
||||
emit_binary_func_op(result_type, id, args[0], args[1], "SPIRV_Cross_Reflect");
|
||||
}
|
||||
else
|
||||
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
|
||||
break;
|
||||
|
||||
case GLSLstd450Refract:
|
||||
if (get<SPIRType>(result_type).vecsize == 1)
|
||||
{
|
||||
if (!requires_scalar_refract)
|
||||
{
|
||||
requires_scalar_refract = true;
|
||||
force_recompile();
|
||||
}
|
||||
emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "SPIRV_Cross_Refract");
|
||||
}
|
||||
else
|
||||
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
|
||||
break;
|
||||
|
||||
default:
|
||||
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
|
||||
break;
|
||||
|
@ -167,6 +167,8 @@ private:
|
||||
void replace_illegal_names() override;
|
||||
|
||||
Options hlsl_options;
|
||||
|
||||
// TODO: Refactor this to be more similar to MSL, maybe have some common system in place?
|
||||
bool requires_op_fmod = false;
|
||||
bool requires_fp16_packing = false;
|
||||
bool requires_explicit_fp16_packing = false;
|
||||
@ -179,6 +181,8 @@ private:
|
||||
bool requires_inverse_2x2 = false;
|
||||
bool requires_inverse_3x3 = false;
|
||||
bool requires_inverse_4x4 = false;
|
||||
bool requires_scalar_reflect = false;
|
||||
bool requires_scalar_refract = false;
|
||||
uint64_t required_textureSizeVariants = 0;
|
||||
void require_texture_query_variant(const SPIRType &type);
|
||||
|
||||
|
@ -3189,6 +3189,34 @@ void CompilerMSL::emit_custom_functions()
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplReflectScalar:
|
||||
// Metal does not support scalar versions of these functions.
|
||||
statement("template<typename T>");
|
||||
statement("inline T spvReflect(T i, T n)");
|
||||
begin_scope();
|
||||
statement("return i - T(2) * dot(n, i) * n;");
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
|
||||
case SPVFuncImplRefractScalar:
|
||||
// Metal does not support scalar versions of these functions.
|
||||
statement("template<typename T>");
|
||||
statement("inline T spvRefract(T i, T n, T eta)");
|
||||
begin_scope();
|
||||
statement("T k = T(1) - eta * eta * (T(1) - dot(n, i) * dot(n, i));");
|
||||
statement("if (k < T(0))");
|
||||
begin_scope();
|
||||
statement("return T(0);");
|
||||
end_scope();
|
||||
statement("else");
|
||||
begin_scope();
|
||||
statement("return eta * i - (eta * dot(n, i) + sqrt(k)) * n;");
|
||||
end_scope();
|
||||
end_scope();
|
||||
statement("");
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -4709,6 +4737,20 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
|
||||
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
|
||||
break;
|
||||
|
||||
case GLSLstd450Reflect:
|
||||
if (get<SPIRType>(result_type).vecsize == 1)
|
||||
emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
|
||||
else
|
||||
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
|
||||
break;
|
||||
|
||||
case GLSLstd450Refract:
|
||||
if (get<SPIRType>(result_type).vecsize == 1)
|
||||
emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
|
||||
else
|
||||
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
|
||||
break;
|
||||
|
||||
default:
|
||||
CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
|
||||
break;
|
||||
@ -8660,7 +8702,7 @@ CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op o
|
||||
uint32_t extension_set = args[2];
|
||||
if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
|
||||
{
|
||||
GLSLstd450 op_450 = static_cast<GLSLstd450>(args[3]);
|
||||
auto op_450 = static_cast<GLSLstd450>(args[3]);
|
||||
switch (op_450)
|
||||
{
|
||||
case GLSLstd450Radians:
|
||||
@ -8675,6 +8717,22 @@ CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op o
|
||||
return SPVFuncImplFindUMsb;
|
||||
case GLSLstd450SSign:
|
||||
return SPVFuncImplSSign;
|
||||
case GLSLstd450Reflect:
|
||||
{
|
||||
auto &type = compiler.get<SPIRType>(args[0]);
|
||||
if (type.vecsize == 1)
|
||||
return SPVFuncImplReflectScalar;
|
||||
else
|
||||
return SPVFuncImplNone;
|
||||
}
|
||||
case GLSLstd450Refract:
|
||||
{
|
||||
auto &type = compiler.get<SPIRType>(args[0]);
|
||||
if (type.vecsize == 1)
|
||||
return SPVFuncImplRefractScalar;
|
||||
else
|
||||
return SPVFuncImplNone;
|
||||
}
|
||||
case GLSLstd450MatrixInverse:
|
||||
{
|
||||
auto &mat_type = compiler.get<SPIRType>(args[0]);
|
||||
|
@ -401,6 +401,8 @@ protected:
|
||||
SPVFuncImplSubgroupBallotFindMSB,
|
||||
SPVFuncImplSubgroupBallotBitCount,
|
||||
SPVFuncImplSubgroupAllEqual,
|
||||
SPVFuncImplReflectScalar,
|
||||
SPVFuncImplRefractScalar,
|
||||
SPVFuncImplArrayCopyMultidimMax = 6
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user