MSL: Fix restrict vs __restrict incompatibility.

The restrict keyword used to be supported, but it broke in MSL 3.0.
__restrict works on all versions, so use that instead.

Also check for the RestrictPointer decoration, and refactor to_restrict()
so that it no longer takes an optional parameter, which makes it more
obvious when the implied space character is added.
This commit is contained in:
Hans-Kristian Arntzen 2022-10-26 12:00:34 +02:00
parent 3ea057a303
commit 2a49f7e82d
14 changed files with 67 additions and 25 deletions

View File

@ -15,7 +15,7 @@ struct _4
int4 _m1;
};
kernel void main0(device _3& restrict _5 [[buffer(0)]], device _4& restrict _6 [[buffer(1)]])
kernel void main0(device _3& __restrict _5 [[buffer(0)]], device _4& __restrict _6 [[buffer(1)]])
{
_6._m0 = _5._m1 + uint4(_5._m0);
_6._m0 = uint4(_5._m0) + _5._m1;

View File

@ -15,7 +15,7 @@ struct _4
int4 _m1;
};
kernel void main0(device _3& restrict _5 [[buffer(0)]], device _4& restrict _6 [[buffer(1)]])
kernel void main0(device _3& __restrict _5 [[buffer(0)]], device _4& __restrict _6 [[buffer(1)]])
{
_6._m0 = uint4(int4(_5._m1) < _5._m0);
_6._m0 = uint4(int4(_5._m1) <= _5._m0);

View File

@ -15,7 +15,7 @@ struct _7
int4 _m1;
};
kernel void main0(device _6& restrict _8 [[buffer(0)]], device _7& restrict _9 [[buffer(1)]])
kernel void main0(device _6& __restrict _8 [[buffer(0)]], device _7& __restrict _9 [[buffer(1)]])
{
_9._m0 = _8._m1 + uint4(_8._m0);
_9._m0 = uint4(_8._m0) + _8._m1;

View File

@ -32,7 +32,7 @@ kernel void main0(constant Registers& registers [[buffer(0)]], uint3 gl_GlobalIn
uint _30 = ((_19 * 8u) * gl_NumWorkGroups.x) + _29;
uint local_index = _30;
uint slice = gl_WorkGroupID.z;
device Position* positions = registers.references->buffers[gl_WorkGroupID.z];
device Position* __restrict positions = registers.references->buffers[gl_WorkGroupID.z];
float _66 = float(gl_WorkGroupID.z);
float _70 = fract(fma(_66, 0.100000001490116119384765625, registers.fract_time));
float _71 = 6.283125400543212890625 * _70;

View File

@ -32,7 +32,7 @@ vertex main0_out main0(constant Registers& registers [[buffer(0)]], uint gl_Inst
{
main0_out out = {};
int slice = int(gl_InstanceIndex);
const device Position* positions = registers.references->buffers[int(gl_InstanceIndex)];
const device Position* __restrict positions = registers.references->buffers[int(gl_InstanceIndex)];
float2 _45 = registers.references->buffers[int(gl_InstanceIndex)]->positions[int(gl_VertexIndex)] * 2.5;
float2 pos = _45;
float2 _60 = _45 + ((float2(float(int(gl_InstanceIndex) % 8), float(int(gl_InstanceIndex) / 8)) - float2(3.5)) * 3.0);

View File

@ -0,0 +1,23 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Forward declaration followed by the definition of Ref, a struct with a
// single float4 member `v`.
struct Ref;
struct Ref
{
float4 v;
};
// Holds one device-address-space pointer to a Ref.
struct Registers
{
device Ref* foo;
};
// Kernel entry point. Takes the Registers struct bound at buffer index 0,
// copies its `foo` pointer into a local pointer declared with the
// `__restrict` qualifier, and stores float4(1.0) into the pointee's `v` member.
kernel void main0(constant Registers& _14 [[buffer(0)]])
{
// Local pointer declaration carrying the `__restrict` spelling of the qualifier.
device Ref* __restrict ref = _14.foo;
ref->v = float4(1.0);
}

View File

@ -15,7 +15,7 @@ struct _4
int4 _m1;
};
kernel void main0(device _3& restrict _5 [[buffer(0)]], device _4& restrict _6 [[buffer(1)]])
kernel void main0(device _3& __restrict _5 [[buffer(0)]], device _4& __restrict _6 [[buffer(1)]])
{
_6._m0 = _5._m1 + uint4(_5._m0);
_6._m0 = uint4(_5._m0) + _5._m1;

View File

@ -15,7 +15,7 @@ struct _4
int4 _m1;
};
kernel void main0(device _3& restrict _5 [[buffer(0)]], device _4& restrict _6 [[buffer(1)]])
kernel void main0(device _3& __restrict _5 [[buffer(0)]], device _4& __restrict _6 [[buffer(1)]])
{
_6._m0 = uint4(int4(_5._m1) < _5._m0);
_6._m0 = uint4(int4(_5._m1) <= _5._m0);

View File

@ -15,7 +15,7 @@ struct _7
int4 _m1;
};
kernel void main0(device _6& restrict _8 [[buffer(0)]], device _7& restrict _9 [[buffer(1)]])
kernel void main0(device _6& __restrict _8 [[buffer(0)]], device _7& __restrict _9 [[buffer(1)]])
{
_9._m0 = _8._m1 + uint4(_8._m0);
_9._m0 = uint4(_8._m0) + _8._m1;

View File

@ -29,7 +29,7 @@ kernel void main0(constant Registers& registers [[buffer(0)]], uint3 gl_GlobalIn
uint2 local_offset = gl_GlobalInvocationID.xy;
uint local_index = ((local_offset.y * 8u) * gl_NumWorkGroups.x) + local_offset.x;
uint slice = gl_WorkGroupID.z;
device Position* positions = registers.references->buffers[slice];
device Position* __restrict positions = registers.references->buffers[slice];
float offset = 6.283125400543212890625 * fract(registers.fract_time + (float(slice) * 0.100000001490116119384765625));
float2 pos = float2(local_offset);
pos.x += (0.20000000298023223876953125 * sin((2.2000000476837158203125 * pos.x) + offset));

View File

@ -32,7 +32,7 @@ vertex main0_out main0(constant Registers& registers [[buffer(0)]], uint gl_Inst
{
main0_out out = {};
int slice = int(gl_InstanceIndex);
const device Position* positions = registers.references->buffers[slice];
const device Position* __restrict positions = registers.references->buffers[slice];
float2 pos = positions->positions[int(gl_VertexIndex)] * 2.5;
pos += ((float2(float(slice % 8), float(slice / 8)) - float2(3.5)) * 3.0);
out.gl_Position = registers.view_projection * float4(pos, 0.0, 1.0);

View File

@ -0,0 +1,18 @@
#version 450
#extension GL_EXT_buffer_reference : require
// Buffer-reference block type (GL_EXT_buffer_reference) with a single vec4 member `v`.
layout(buffer_reference) buffer Ref
{
vec4 v;
};
// Push-constant block holding one Ref buffer reference named `foo`.
layout(push_constant) uniform Registers
{
Ref foo;
};
// Entry point. Copies the push-constant buffer reference `foo` into a local
// variable qualified with `restrict`, then writes vec4(1.0) to its `v` member.
void main()
{
// The `restrict` qualifier sits on a local buffer-reference variable, not on a block declaration.
restrict Ref ref = foo;
ref.v = vec4(1.0);
}

View File

@ -1275,7 +1275,7 @@ void CompilerMSL::emit_entry_point_declarations()
else
{
is_using_builtin_array = true;
statement(get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id), name,
statement(get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id, true), name,
type_to_array_glsl(type), " =");
uint32_t dim = uint32_t(type.array.size());
@ -1310,7 +1310,7 @@ void CompilerMSL::emit_entry_point_declarations()
}
else
{
statement(get_argument_address_space(var), " auto& ", to_restrict(var_id), name, " = *(",
statement(get_argument_address_space(var), " auto& ", to_restrict(var_id, true), name, " = *(",
get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id, false), ")((",
get_argument_address_space(var), " char* ", to_restrict(var_id, false), ")", to_name(arg_id), ".",
ensure_valid_name(name, "m"), " + ", to_name(dynamic_offsets_buffer_id), "[", base_index, "]);");
@ -1324,7 +1324,7 @@ void CompilerMSL::emit_entry_point_declarations()
const auto &type = get_variable_data_type(var);
const auto &buffer_type = get_variable_element_type(var);
string name = to_name(array_id);
statement(get_argument_address_space(var), " ", type_to_glsl(buffer_type), "* ", to_restrict(array_id), name,
statement(get_argument_address_space(var), " ", type_to_glsl(buffer_type), "* ", to_restrict(array_id, true), name,
"[] =");
begin_scope();
for (uint32_t i = 0; i < to_array_size_literal(type); ++i)
@ -1347,7 +1347,7 @@ void CompilerMSL::emit_entry_point_declarations()
uint32_t desc_binding = get_decoration(var_id, DecorationBinding);
auto alias_name = join("spvBufferAliasSet", desc_set, "Binding", desc_binding);
statement(addr_space, " auto& ", to_restrict(var_id),
statement(addr_space, " auto& ", to_restrict(var_id, true),
name,
" = *(", addr_space, " ", type_to_glsl(type), "*)", alias_name, ";");
}
@ -1365,7 +1365,7 @@ void CompilerMSL::emit_entry_point_declarations()
if (type.array.empty())
{
statement(addr_space, " auto& ", to_restrict(var_id), to_name(var_id), " = (", addr_space, " ",
statement(addr_space, " auto& ", to_restrict(var_id, true), to_name(var_id), " = (", addr_space, " ",
type_to_glsl(type), "&)", ir.meta[alias_id].decoration.qualified_alias, ";");
}
else
@ -1377,7 +1377,7 @@ void CompilerMSL::emit_entry_point_declarations()
// address space of the argument buffer itself, which is usually constant, but can be const device for
// large argument buffers.
is_using_builtin_array = true;
statement(desc_addr_space, " auto& ", to_restrict(var_id), to_name(var_id), " = (", addr_space, " ",
statement(desc_addr_space, " auto& ", to_restrict(var_id, true), to_name(var_id), " = (", addr_space, " ",
type_to_glsl(type), "* ", desc_addr_space, " (&)",
type_to_array_glsl(type), ")", ir.meta[alias_id].decoration.qualified_alias, ";");
is_using_builtin_array = false;
@ -12120,7 +12120,8 @@ const char *CompilerMSL::to_restrict(uint32_t id, bool space)
else
flags = get_decoration_bitset(id);
return flags.get(DecorationRestrict) ? (space ? "restrict " : "restrict") : "";
return flags.get(DecorationRestrict) || flags.get(DecorationRestrictPointerEXT) ?
(space ? "__restrict " : "__restrict") : "";
}
string CompilerMSL::entry_point_arg_stage_in()
@ -12525,7 +12526,7 @@ string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
claimed_bindings.set(buffer_binding);
ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id) + to_name(id);
ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id);
ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]";
next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1);
@ -12737,7 +12738,7 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
{
if (!ep_args.empty())
ep_args += ", ";
ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "* " + to_restrict(var_id) +
ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "* " + to_restrict(var_id, true) +
r.name + "_" + convert_to_string(i);
ep_args += " [[buffer(" + convert_to_string(r.index + i) + ")";
if (interlocked_resources.count(var_id))
@ -12751,7 +12752,7 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
if (!ep_args.empty())
ep_args += ", ";
ep_args +=
get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(var_id) + r.name;
get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(var_id, true) + r.name;
ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
if (interlocked_resources.count(var_id))
ep_args += ", raster_order_group(0)";
@ -13687,7 +13688,7 @@ string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
// non-constant arrays, but we can create thread const from constant.
decl = string("thread const ") + decl;
decl += " (&";
const char *restrict_kw = to_restrict(name_id);
const char *restrict_kw = to_restrict(name_id, true);
if (*restrict_kw)
{
decl += " ";
@ -13744,7 +13745,7 @@ string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
else
decl += " (&";
const char *restrict_kw = to_restrict(name_id);
const char *restrict_kw = to_restrict(name_id, true);
if (*restrict_kw)
{
decl += " ";
@ -13776,7 +13777,7 @@ string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
}
decl += "&";
decl += " ";
decl += to_restrict(name_id);
decl += to_restrict(name_id, true);
decl += to_expression(name_id);
}
else if (type_is_image)
@ -14276,7 +14277,7 @@ string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id, bool member)
default:
// Anything else can be a raw pointer.
type_name += "*";
restrict_kw = to_restrict(id);
restrict_kw = to_restrict(id, false);
if (*restrict_kw)
{
type_name += " ";

View File

@ -980,7 +980,7 @@ protected:
bool validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const;
std::string get_argument_address_space(const SPIRVariable &argument);
std::string get_type_address_space(const SPIRType &type, uint32_t id, bool argument = false);
const char *to_restrict(uint32_t id, bool space = true);
const char *to_restrict(uint32_t id, bool space);
SPIRType &get_stage_in_struct_type();
SPIRType &get_stage_out_struct_type();
SPIRType &get_patch_stage_in_struct_type();