MSL: Image gather ConstOffsets supports multiple address spaces.

Required when using descriptor set argument buffers.

- Output overloaded functions for each address space.
- Update test shaders.
This commit is contained in:
Bill Hollings 2024-06-15 07:21:46 -04:00
parent 2d990d355a
commit b5ccb0cf2c
5 changed files with 213 additions and 49 deletions

View File

@ -56,7 +56,31 @@ template<typename T> inline constexpr thread T&& spvForward(thread typename spvR
return static_cast<thread T&&>(x);
}
// Wrapper function that processes a texture gather with a constant offset array.
// Wrapper function that processes a device texture gather with a constant offset array.
// Performs one gather_compare() per entry of the four-element offset array and keeps
// only the w component (the i0j0 texel of the gather footprint) of each result.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherCompareConstOffsets(const device Tex<T>& t, sampler s, Toff coffsets, Tp... params)
{
    vec<T, 4> texels;
    for (uint k = 0; k < 4; k++)
        texels[k] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[k]).w;
    return texels;
}
// Wrapper function that processes a constant texture gather with a constant offset array.
// Performs one gather_compare() per entry of the four-element offset array and keeps
// only the w component (the i0j0 texel of the gather footprint) of each result.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherCompareConstOffsets(const constant Tex<T>& t, sampler s, Toff coffsets, Tp... params)
{
    vec<T, 4> texels;
    for (uint k = 0; k < 4; k++)
        texels[k] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[k]).w;
    return texels;
}
// Wrapper function that processes a thread texture gather with a constant offset array.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherCompareConstOffsets(const thread Tex<T>& t, sampler s, Toff coffsets, Tp... params)
{

View File

@ -56,7 +56,59 @@ template<typename T> inline constexpr thread T&& spvForward(thread typename spvR
return static_cast<thread T&&>(x);
}
// Wrapper function that processes a texture gather with a constant offset array.
// Wrapper function that processes a device texture gather with a constant offset array.
// Performs one gather() per entry of the four-element offset array, reading component c,
// and keeps only the w component (the i0j0 texel of the gather footprint) of each result.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherConstOffsets(const device Tex<T>& t, sampler s, Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)
{
    vec<T, 4> texels;
    for (uint k = 0; k < 4; k++)
    {
        // texture::gather() needs its component argument as a constant expression,
        // so every case spells out the literal enumerator instead of forwarding c.
        switch (c)
        {
        case component::x:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::x).w;
            break;
        case component::y:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::y).w;
            break;
        case component::z:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::z).w;
            break;
        case component::w:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::w).w;
            break;
        }
    }
    return texels;
}
// Wrapper function that processes a constant texture gather with a constant offset array.
// Performs one gather() per entry of the four-element offset array, reading component c,
// and keeps only the w component (the i0j0 texel of the gather footprint) of each result.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherConstOffsets(const constant Tex<T>& t, sampler s, Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)
{
    vec<T, 4> texels;
    for (uint k = 0; k < 4; k++)
    {
        // texture::gather() needs its component argument as a constant expression,
        // so every case spells out the literal enumerator instead of forwarding c.
        switch (c)
        {
        case component::x:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::x).w;
            break;
        case component::y:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::y).w;
            break;
        case component::z:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::z).w;
            break;
        case component::w:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::w).w;
            break;
        }
    }
    return texels;
}
// Wrapper function that processes a thread texture gather with a constant offset array.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherConstOffsets(const thread Tex<T>& t, sampler s, Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)
{

View File

@ -56,7 +56,31 @@ template<typename T> inline constexpr thread T&& spvForward(thread typename spvR
return static_cast<thread T&&>(x);
}
// Wrapper function that processes a texture gather with a constant offset array.
// Wrapper function that processes a device texture gather with a constant offset array.
// Performs one gather_compare() per entry of the four-element offset array and keeps
// only the w component (the i0j0 texel of the gather footprint) of each result.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherCompareConstOffsets(const device Tex<T>& t, sampler s, Toff coffsets, Tp... params)
{
    vec<T, 4> texels;
    for (uint k = 0; k < 4; k++)
        texels[k] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[k]).w;
    return texels;
}
// Wrapper function that processes a constant texture gather with a constant offset array.
// Performs one gather_compare() per entry of the four-element offset array and keeps
// only the w component (the i0j0 texel of the gather footprint) of each result.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherCompareConstOffsets(const constant Tex<T>& t, sampler s, Toff coffsets, Tp... params)
{
    vec<T, 4> texels;
    for (uint k = 0; k < 4; k++)
        texels[k] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[k]).w;
    return texels;
}
// Wrapper function that processes a thread texture gather with a constant offset array.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherCompareConstOffsets(const thread Tex<T>& t, sampler s, Toff coffsets, Tp... params)
{

View File

@ -56,7 +56,59 @@ template<typename T> inline constexpr thread T&& spvForward(thread typename spvR
return static_cast<thread T&&>(x);
}
// Wrapper function that processes a texture gather with a constant offset array.
// Wrapper function that processes a device texture gather with a constant offset array.
// Performs one gather() per entry of the four-element offset array, reading component c,
// and keeps only the w component (the i0j0 texel of the gather footprint) of each result.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherConstOffsets(const device Tex<T>& t, sampler s, Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)
{
    vec<T, 4> texels;
    for (uint k = 0; k < 4; k++)
    {
        // texture::gather() needs its component argument as a constant expression,
        // so every case spells out the literal enumerator instead of forwarding c.
        switch (c)
        {
        case component::x:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::x).w;
            break;
        case component::y:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::y).w;
            break;
        case component::z:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::z).w;
            break;
        case component::w:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::w).w;
            break;
        }
    }
    return texels;
}
// Wrapper function that processes a constant texture gather with a constant offset array.
// Performs one gather() per entry of the four-element offset array, reading component c,
// and keeps only the w component (the i0j0 texel of the gather footprint) of each result.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherConstOffsets(const constant Tex<T>& t, sampler s, Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)
{
    vec<T, 4> texels;
    for (uint k = 0; k < 4; k++)
    {
        // texture::gather() needs its component argument as a constant expression,
        // so every case spells out the literal enumerator instead of forwarding c.
        switch (c)
        {
        case component::x:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::x).w;
            break;
        case component::y:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::y).w;
            break;
        case component::z:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::z).w;
            break;
        case component::w:
            texels[k] = t.gather(s, spvForward<Tp>(params)..., coffsets[k], component::w).w;
            break;
        }
    }
    return texels;
}
// Wrapper function that processes a thread texture gather with a constant offset array.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherConstOffsets(const thread Tex<T>& t, sampler s, Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)
{

View File

@ -5615,6 +5615,10 @@ void CompilerMSL::emit_custom_templates()
// otherwise they will cause problems when linked together in a single Metallib.
void CompilerMSL::emit_custom_functions()
{
// Address-space qualifiers used when outputting overloaded functions to cover different address spaces.
// A texture passed by reference needs one overload of each wrapper per entry listed here.
static const char *texture_addr_spaces[] = { "device", "constant", "thread" };
// Derived from the array's own element size so it tracks the table, and const because it never changes.
static const uint32_t texture_addr_space_count = sizeof(texture_addr_spaces) / sizeof(texture_addr_spaces[0]);
if (spv_function_implementations.count(SPVFuncImplArrayCopyMultidim))
spv_function_implementations.insert(SPVFuncImplArrayCopy);
@ -6264,54 +6268,62 @@ void CompilerMSL::emit_custom_functions()
break;
case SPVFuncImplGatherConstOffsets:
statement("// Wrapper function that processes a texture gather with a constant offset array.");
statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
"typename Toff, typename... Tp>");
statement("inline vec<T, 4> spvGatherConstOffsets(const thread Tex<T>& t, sampler s, "
"Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)");
begin_scope();
statement("vec<T, 4> rslts[4];");
statement("for (uint i = 0; i < 4; i++)");
begin_scope();
statement("switch (c)");
begin_scope();
// Work around texture::gather() requiring its component parameter to be a constant expression
statement("case component::x:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::x);");
statement(" break;");
statement("case component::y:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::y);");
statement(" break;");
statement("case component::z:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::z);");
statement(" break;");
statement("case component::w:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::w);");
statement(" break;");
end_scope();
end_scope();
// Pull all values from the i0j0 component of each gather footprint
statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
end_scope();
statement("");
// Because we are passing a texture reference, we have to output an overloaded version of this function for each address space.
for (uint32_t i = 0; i < texture_addr_space_count; i++)
{
statement("// Wrapper function that processes a ", texture_addr_spaces[i], " texture gather with a constant offset array.");
statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
"typename Toff, typename... Tp>");
statement("inline vec<T, 4> spvGatherConstOffsets(const ", texture_addr_spaces[i], " Tex<T>& t, sampler s, "
"Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)");
begin_scope();
statement("vec<T, 4> rslts[4];");
statement("for (uint i = 0; i < 4; i++)");
begin_scope();
statement("switch (c)");
begin_scope();
// Work around texture::gather() requiring its component parameter to be a constant expression
statement("case component::x:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::x);");
statement(" break;");
statement("case component::y:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::y);");
statement(" break;");
statement("case component::z:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::z);");
statement(" break;");
statement("case component::w:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::w);");
statement(" break;");
end_scope();
end_scope();
// Pull all values from the i0j0 component of each gather footprint
statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
end_scope();
statement("");
}
break;
case SPVFuncImplGatherCompareConstOffsets:
statement("// Wrapper function that processes a texture gather with a constant offset array.");
statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
"typename Toff, typename... Tp>");
statement("inline vec<T, 4> spvGatherCompareConstOffsets(const thread Tex<T>& t, sampler s, "
"Toff coffsets, Tp... params)");
begin_scope();
statement("vec<T, 4> rslts[4];");
statement("for (uint i = 0; i < 4; i++)");
begin_scope();
statement(" rslts[i] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[i]);");
end_scope();
// Pull all values from the i0j0 component of each gather footprint
statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
end_scope();
statement("");
// Because we are passing a texture reference, we have to output an overloaded version of this function for each address space.
for (uint32_t i = 0; i < texture_addr_space_count; i++)
{
statement("// Wrapper function that processes a ", texture_addr_spaces[i], " texture gather with a constant offset array.");
statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
"typename Toff, typename... Tp>");
statement("inline vec<T, 4> spvGatherCompareConstOffsets(const ", texture_addr_spaces[i], " Tex<T>& t, sampler s, "
"Toff coffsets, Tp... params)");
begin_scope();
statement("vec<T, 4> rslts[4];");
statement("for (uint i = 0; i < 4; i++)");
begin_scope();
statement(" rslts[i] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[i]);");
end_scope();
// Pull all values from the i0j0 component of each gather footprint
statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
end_scope();
statement("");
}
break;
case SPVFuncImplSubgroupBroadcast: