From df18d98beada9c8e333c9a1bffc59b09c1a894ba Mon Sep 17 00:00:00 2001 From: Chip Davis Date: Fri, 26 Jul 2019 01:06:35 -0500 Subject: [PATCH] MSL: Unify the get_*_address_space() methods. These methods have largely the same logic, with minor differences. That I felt compelled to duplicate the logic into another method was one of the things that bothered me about the variable pointers change. This cleans that part of the code up; now we don't have two places to change. --- spirv_msl.cpp | 96 ++++++++++++--------------------------------------- spirv_msl.hpp | 2 +- 2 files changed, 23 insertions(+), 75 deletions(-) diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 1805e12a..9693fd2b 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -6576,80 +6576,15 @@ string CompilerMSL::func_type_decl(SPIRType &type) string CompilerMSL::get_argument_address_space(const SPIRVariable &argument) { const auto &type = get(argument.basetype); - Bitset flags; - if (type.basetype == SPIRType::Struct && - (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock))) - flags = ir.get_buffer_block_flags(argument); - else - flags = get_decoration_bitset(argument.self); - const char *addr_space = nullptr; - - switch (type.storage) - { - case StorageClassWorkgroup: - addr_space = "threadgroup"; - break; - - case StorageClassStorageBuffer: - { - // For arguments from variable pointers, we use the write count deduction, so - // we should not assume any constness here. Only for global SSBOs. - bool readonly = false; - if (has_decoration(type.self, DecorationBlock)) - readonly = flags.get(DecorationNonWritable); - - addr_space = readonly ? "const device" : "device"; - break; - } - - case StorageClassUniform: - case StorageClassUniformConstant: - case StorageClassPushConstant: - if (type.basetype == SPIRType::Struct) - { - bool ssbo = has_decoration(type.self, DecorationBufferBlock); - if (ssbo) - { - bool readonly = flags.get(DecorationNonWritable); - addr_space = readonly ? "const device" : "device"; - } - else - addr_space = "constant"; - break; - } - break; - - case StorageClassFunction: - case StorageClassGeneric: - // No address space for plain values. - addr_space = type.pointer ? "thread" : ""; - break; - - case StorageClassInput: - if (get_execution_model() == ExecutionModelTessellationControl && argument.basevariable == stage_in_ptr_var_id) - addr_space = "threadgroup"; - break; - - case StorageClassOutput: - if (capture_output_to_buffer) - addr_space = "device"; - break; - - default: - break; - } - - if (!addr_space) - addr_space = "thread"; - - return join(flags.get(DecorationVolatile) || flags.get(DecorationCoherent) ? "volatile " : "", addr_space); + return get_type_address_space(type, argument.self, true); } -string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id) +string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bool argument) { // This can be called for variable pointer contexts as well, so be very careful about which method we choose. Bitset flags; - if (ir.ids[id].get_type() == TypeVariable && type.basetype == SPIRType::Struct && + auto *var = maybe_get(id); + if (var && type.basetype == SPIRType::Struct && (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock))) flags = get_buffer_block_flags(id); else @@ -6663,8 +6598,16 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id) break; case StorageClassStorageBuffer: - addr_space = flags.get(DecorationNonWritable) ? "const device" : "device"; + { + // For arguments from variable pointers, we use the write count deduction, so + // we should not assume any constness here. Only for global SSBOs. + bool readonly = false; + if (!var || has_decoration(type.self, DecorationBlock)) + readonly = flags.get(DecorationNonWritable); + + addr_space = readonly ? "const device" : "device"; break; + } case StorageClassUniform: case StorageClassUniformConstant: @@ -6677,14 +6620,18 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id) else addr_space = "constant"; } - else + else if (!argument) addr_space = "constant"; break; case StorageClassFunction: case StorageClassGeneric: - // No address space for plain values. - addr_space = type.pointer ? "thread" : ""; + break; + + case StorageClassInput: + if (get_execution_model() == ExecutionModelTessellationControl && var && + var->basevariable == stage_in_ptr_var_id) + addr_space = "threadgroup"; break; case StorageClassOutput: @@ -6697,7 +6644,8 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id) } if (!addr_space) - addr_space = "thread"; + // No address space for plain values. + addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : ""; return join(flags.get(DecorationVolatile) || flags.get(DecorationCoherent) ? "volatile " : "", addr_space); } diff --git a/spirv_msl.hpp b/spirv_msl.hpp index 21473e2d..f0858c9d 100644 --- a/spirv_msl.hpp +++ b/spirv_msl.hpp @@ -541,7 +541,7 @@ protected: void ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t index); bool validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const; std::string get_argument_address_space(const SPIRVariable &argument); - std::string get_type_address_space(const SPIRType &type, uint32_t id); + std::string get_type_address_space(const SPIRType &type, uint32_t id, bool argument = false); const char *to_restrict(uint32_t id, bool space = true); SPIRType &get_stage_in_struct_type(); SPIRType &get_stage_out_struct_type();