MSL: Work around compiler crashes when using threadgroup bool.

Promote to short and do simple casts on load/store instead.

Not a 100% complete fix since structs can contain booleans, but handling
that case is getting into pretty ridiculously complicated territory.
This commit is contained in:
Hans-Kristian Arntzen 2021-10-25 10:55:11 +02:00
parent 43eecb2360
commit edf247fb1c
7 changed files with 107 additions and 11 deletions

View File

@ -0,0 +1,20 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Layout of the device buffer that main0 receives at [[buffer(0)]].
// The kernel indexes `values` with gl_GlobalInvocationID.x, both to read and to write.
struct SSBO
{
float4 values[1];
};
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(4u, 1u, 1u);
// Compute entry point. Each invocation records, per component, whether the SSBO element at
// gl_GlobalInvocationID.x differs from 10.0, stores those flags in threadgroup memory, and
// then rewrites the same element using the flags stored at index (gl_LocalInvocationIndex ^ 3).
kernel void main0(device SSBO& _23 [[buffer(0)]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
// The flags are kept as short4 rather than bool4: threadgroup bool storage can crash MSL
// compilers, so the boolean vector is cast to short on store and back to bool on load.
threadgroup short4 foo[4];
// The isunordered() term makes the "not equal" test also come out true for NaN components.
foo[gl_LocalInvocationIndex] = short4((isunordered(_23.values[gl_GlobalInvocationID.x], float4(10.0)) || _23.values[gl_GlobalInvocationID.x] != float4(10.0)));
// Synchronize the threadgroup so flags written by other invocations can be read below.
threadgroup_barrier(mem_flags::mem_threadgroup);
// Cast the stored shorts back to bool4; select() yields 30.0 where the flag is true, 40.0 otherwise.
_23.values[gl_GlobalInvocationID.x] = select(float4(40.0), float4(30.0), bool4(foo[gl_LocalInvocationIndex ^ 3u]));
}

View File

@ -0,0 +1,28 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Layout of the device buffer that main0 receives at [[buffer(0)]] and forwards to in_function.
// `values` is indexed with gl_GlobalInvocationID.x for both the read and the write.
struct SSBO
{
float4 values[1];
};
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(4u, 1u, 1u);
// Helper holding the shader's actual work; force-inlined into the kernel.
//   foo                      - threadgroup flag array declared by the caller (short4, not bool4,
//                              because threadgroup bool storage can crash MSL compilers).
//   gl_LocalInvocationIndex  - index of this invocation's slot in `foo`.
//   v_23                     - device buffer whose `values` are read and rewritten.
//   gl_GlobalInvocationID    - its .x component selects the element of `values` to process.
static inline __attribute__((always_inline))
void in_function(threadgroup short4 (&foo)[4], thread uint& gl_LocalInvocationIndex, device SSBO& v_23, thread uint3& gl_GlobalInvocationID)
{
// Store the per-component "differs from 10.0" flags as shorts; the isunordered() term makes
// the test also come out true for NaN components.
foo[gl_LocalInvocationIndex] = short4((isunordered(v_23.values[gl_GlobalInvocationID.x], float4(10.0)) || v_23.values[gl_GlobalInvocationID.x] != float4(10.0)));
// Synchronize the threadgroup so flags written by other invocations can be read below.
threadgroup_barrier(mem_flags::mem_threadgroup);
// Cast the flags at index (gl_LocalInvocationIndex ^ 3) back to bool4; select() yields 30.0
// where the flag is true and 40.0 otherwise.
v_23.values[gl_GlobalInvocationID.x] = select(float4(40.0), float4(30.0), bool4(foo[gl_LocalInvocationIndex ^ 3u]));
}
// Compute entry point. Declares the threadgroup flag array and passes it by reference, together
// with the invocation builtins and the bound buffer, to in_function.
kernel void main0(device SSBO& v_23 [[buffer(0)]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
// Declared as short4 instead of bool4 to avoid MSL compiler crashes on threadgroup bool.
threadgroup short4 foo[4];
in_function(foo, gl_LocalInvocationIndex, v_23, gl_GlobalInvocationID);
}

View File

@ -0,0 +1,21 @@
#version 450
layout(local_size_x = 4) in;
// Workgroup-shared boolean flags: four bvec4 entries, indexed by gl_LocalInvocationIndex.
shared bvec4 foo[4];
// Storage buffer at binding 0 exposing a runtime-sized array of vec4 values.
layout(binding = 0) buffer SSBO
{
vec4 values[];
};
// Records in shared memory which components of values[gl_GlobalInvocationID.x] differ from 10.0,
// then overwrites that value using the flags stored at index (gl_LocalInvocationIndex ^ 3).
void in_function()
{
foo[gl_LocalInvocationIndex] = notEqual(values[gl_GlobalInvocationID.x], vec4(10.0));
// Wait for the other invocations in the workgroup before reading their flags.
barrier();
// mix() with a bvec4 selector picks 30.0 where the flag is true and 40.0 where it is false.
values[gl_GlobalInvocationID.x] = mix(vec4(40.0), vec4(30.0), foo[gl_LocalInvocationIndex ^ 3]);
}
// Shader entry point; all work is delegated to in_function, so the shared-variable accesses
// happen inside a called function rather than directly in main.
void main()
{
in_function();
}

View File

@ -9891,7 +9891,7 @@ void CompilerGLSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_ex
convert_non_uniform_expression(lhs, lhs_expression);
// We might need to cast in order to store to a builtin.
cast_to_builtin_store(lhs_expression, rhs, expression_type(rhs_expression));
cast_to_variable_store(lhs_expression, rhs, expression_type(rhs_expression));
// Tries to optimize assignments like "<lhs> = <lhs> op expr".
// While this is purely cosmetic, this is important for legacy ESSL where loop
@ -10056,7 +10056,7 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
expr = enclose_expression(expr + vector_swizzle(type.vecsize, 0));
// We might need to cast in order to load from a builtin.
cast_from_builtin_load(ptr, expr, type);
cast_from_variable_load(ptr, expr, type);
// We might be trying to load a gl_Position[N], where we should be
// doing float4[](gl_in[i].gl_Position, ...) instead.
@ -15385,7 +15385,7 @@ void CompilerGLSL::unroll_array_from_complex_load(uint32_t target_id, uint32_t s
}
}
void CompilerGLSL::cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
void CompilerGLSL::cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
{
// We will handle array cases elsewhere.
if (!expr_type.array.empty())
@ -15444,7 +15444,7 @@ void CompilerGLSL::cast_from_builtin_load(uint32_t source_id, std::string &expr,
expr = bitcast_expression(expr_type, expected_type, expr);
}
void CompilerGLSL::cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
void CompilerGLSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
{
auto *var = maybe_get_backing_variable(target_id);
if (var)

View File

@ -903,8 +903,8 @@ protected:
// Builtins in GLSL are always specific signedness, but the SPIR-V can declare them
// as either unsigned or signed.
// Sometimes we will need to automatically perform casts on load and store to make this work.
virtual void cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type);
virtual void cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type);
virtual void cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type);
virtual void cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type);
void unroll_array_from_complex_load(uint32_t target_id, uint32_t source_id, std::string &expr);
bool unroll_array_to_complex_store(uint32_t target_id, uint32_t source_id);
void convert_non_uniform_expression(std::string &expr, uint32_t ptr_id);

View File

@ -13351,8 +13351,23 @@ string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id)
// Scalars
case SPIRType::Boolean:
type_name = "bool";
{
auto *var = maybe_get_backing_variable(id);
if (var && var->basevariable)
var = &get<SPIRVariable>(var->basevariable);
// Need to special-case threadgroup booleans. They are supposed to be logical
// storage, but MSL compilers will sometimes crash if you use threadgroup bool.
// Work around this by using 16-bit types instead and fix up loads and stores to this data.
// FIXME: We have no sane way of working around this problem if a struct member is boolean
// and that struct is used as a threadgroup variable, but ... sigh.
if ((var && var->storage == StorageClassWorkgroup) || type.storage == StorageClassWorkgroup)
type_name = "short";
else
type_name = "bool";
break;
}
case SPIRType::Char:
case SPIRType::SByte:
type_name = "char";
@ -15413,12 +15428,16 @@ void CompilerMSL::remap_constexpr_sampler_by_binding(uint32_t desc_set, uint32_t
constexpr_samplers_by_binding[{ desc_set, binding }] = sampler;
}
void CompilerMSL::cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
void CompilerMSL::cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
{
auto *var = maybe_get_backing_variable(source_id);
if (var)
source_id = var->self;
// Type fixups for workgroup variables if they are booleans.
if (var && var->storage == StorageClassWorkgroup && expr_type.basetype == SPIRType::Boolean)
expr = join(type_to_glsl(expr_type), "(", expr, ")");
// Only interested in standalone builtin variables.
if (!has_decoration(source_id, DecorationBuiltIn))
return;
@ -15505,12 +15524,20 @@ void CompilerMSL::cast_from_builtin_load(uint32_t source_id, std::string &expr,
}
}
void CompilerMSL::cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
void CompilerMSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
{
auto *var = maybe_get_backing_variable(target_id);
if (var)
target_id = var->self;
// Type fixups for workgroup variables if they are booleans.
if (var && var->storage == StorageClassWorkgroup && expr_type.basetype == SPIRType::Boolean)
{
auto short_type = expr_type;
short_type.basetype = SPIRType::Short;
expr = join(type_to_glsl(short_type), "(", expr, ")");
}
// Only interested in standalone builtin variables.
if (!has_decoration(target_id, DecorationBuiltIn))
return;

View File

@ -960,8 +960,8 @@ protected:
bool does_shader_write_sample_mask = false;
void cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) override;
void cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type) override;
void cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) override;
void cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type) override;
void emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression) override;
void analyze_sampled_image_usage();