MSL: Refactor member reference in terms of one boolean.

ptr_chain was really just masking the proper i == 0 check.
Be more explicit about what the check is actually doing and comment
this.
This commit is contained in:
Hans-Kristian Arntzen 2022-11-21 13:40:27 +01:00
parent e75c496ec6
commit df76a14056
6 changed files with 64 additions and 9 deletions

View File

@ -0,0 +1,28 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct SSBO;
struct S
{
float3 v;
};
struct SSBO
{
S s[1];
};
struct PC
{
uint2 ptr;
};
kernel void main0(constant PC& pc [[buffer(0)]])
{
device SSBO* ssbo = reinterpret_cast<device SSBO*>(as_type<ulong>(pc.ptr));
ssbo->s[0].v = float3(1.0);
}

View File

@ -0,0 +1,21 @@
#version 460
#extension GL_EXT_buffer_reference: enable
#extension GL_EXT_buffer_reference_uvec2: enable
struct S {
vec3 v;
};
layout(buffer_reference) buffer SSBO{
S s[];
};
layout(push_constant) uniform PC {
uvec2 ptr;
} pc;
void main(){
SSBO ssbo = SSBO(pc.ptr);
ssbo.s[0].v = vec3(1.0);
}

View File

@ -9421,7 +9421,13 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
else if (flatten_member_reference)
expr += join("_", to_member_name(*type, index));
else
expr += to_member_reference(base, *type, index, i, ptr_chain);
{
// Any pointer de-refences for values are handled in the first access chain.
// For pointer chains, the pointer-ness is resolved through an array access.
// The only time this is not true is when accessing array of SSBO/UBO.
// This case is explicitly handled.
expr += to_member_reference(base, *type, index, ptr_chain || i != 0);
}
}
if (has_member_decoration(type->self, index, DecorationInvariant))
@ -13843,9 +13849,9 @@ string CompilerGLSL::to_member_name(const SPIRType &type, uint32_t index)
return join("_m", index);
}
string CompilerGLSL::to_member_reference(uint32_t, const SPIRType &type, uint32_t member_index, uint32_t chain_index, bool)
string CompilerGLSL::to_member_reference(uint32_t, const SPIRType &type, uint32_t index, bool)
{
return join(".", to_member_name(type, member_index));
return join(".", to_member_name(type, index));
}
string CompilerGLSL::to_multi_member_reference(const SPIRType &type, const SmallVector<uint32_t> &indices)

View File

@ -770,7 +770,7 @@ protected:
std::string address_of_expression(const std::string &expr);
void strip_enclosed_expression(std::string &expr);
std::string to_member_name(const SPIRType &type, uint32_t index);
virtual std::string to_member_reference(uint32_t base, const SPIRType &type, uint32_t member_index, uint32_t chain_index, bool ptr_chain);
virtual std::string to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain_is_resolved);
std::string to_multi_member_reference(const SPIRType &type, const SmallVector<uint32_t> &indices);
std::string type_to_glsl_constructor(const SPIRType &type);
std::string argument_decl(const SPIRFunction::Parameter &arg);

View File

@ -14432,7 +14432,7 @@ void CompilerMSL::sync_entry_point_aliases_and_names()
entry.second.name = ir.meta[entry.first].decoration.alias;
}
string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t member_index, uint32_t chain_index, bool ptr_chain)
string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain_is_resolved)
{
auto *var = maybe_get_backing_variable(base);
// If this is a buffer array, we have to dereference the buffer pointers.
@ -14451,10 +14451,10 @@ string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uin
declared_as_pointer = is_buffer_variable && is_array(get<SPIRType>(var->basetype));
}
if (declared_as_pointer || (!ptr_chain && should_dereference(base) && chain_index == 0))
return join("->", to_member_name(type, member_index));
if (declared_as_pointer || (!ptr_chain_is_resolved && should_dereference(base)))
return join("->", to_member_name(type, index));
else
return join(".", to_member_name(type, member_index));
return join(".", to_member_name(type, index));
}
string CompilerMSL::to_qualifiers_glsl(uint32_t id)

View File

@ -831,7 +831,7 @@ protected:
std::string bitcast_glsl_op(const SPIRType &result_type, const SPIRType &argument_type) override;
bool emit_complex_bitcast(uint32_t result_id, uint32_t id, uint32_t op0) override;
bool skip_argument(uint32_t id) const override;
std::string to_member_reference(uint32_t base, const SPIRType &type, uint32_t member_index, uint32_t chain_index, bool ptr_chain) override;
std::string to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain_is_resolved) override;
std::string to_qualifiers_glsl(uint32_t id) override;
void replace_illegal_names() override;
void declare_undefined_values() override;