diff --git a/spirv_common.hpp b/spirv_common.hpp index f32ad322..ef5f4a11 100644 --- a/spirv_common.hpp +++ b/spirv_common.hpp @@ -187,6 +187,7 @@ enum Types TypeExpression, TypeConstantOp, TypeCombinedImageSampler, + TypeAccessChain, TypeUndef }; @@ -606,6 +607,38 @@ struct SPIRFunction : IVariant bool analyzed_variable_scope = false; }; +struct SPIRAccessChain : IVariant +{ + enum + { + type = TypeAccessChain + }; + + SPIRAccessChain(uint32_t basetype_, spv::StorageClass storage_, + std::string base_, std::string dynamic_index_, int32_t static_index_) + : basetype(basetype_), + storage(storage_), + base(base_), + dynamic_index(std::move(dynamic_index_)), + static_index(static_index_) + { + } + + // The access chain represents an offset into a buffer. + // Some backends need more complicated handling of access chains to be able to use buffers, like HLSL + // which has no usable buffer type ala GLSL SSBOs. + // StructuredBuffer is too limited, so our only option is to deal with ByteAddressBuffer which works with raw addresses. + + uint32_t basetype; + spv::StorageClass storage; + std::string base; + std::string dynamic_index; + int32_t static_index; + + uint32_t loaded_from = 0; + bool need_transpose = false; +}; + struct SPIRVariable : IVariant { enum diff --git a/spirv_cross.cpp b/spirv_cross.cpp index a019d558..54b044ac 100644 --- a/spirv_cross.cpp +++ b/spirv_cross.cpp @@ -359,6 +359,9 @@ uint32_t Compiler::expression_type_id(uint32_t id) const case TypeCombinedImageSampler: return get(id).combined_type; + case TypeAccessChain: + return get(get(id).basetype); + default: SPIRV_CROSS_THROW("Cannot resolve expression type."); } diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index 508b4337..005440e9 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -3833,7 +3833,7 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32 { uint32_t matrix_stride; bool need_transpose; - flattened_access_chain_offset(base, indices, count, 0, &need_transpose, &matrix_stride); + flattened_access_chain_offset(expression_type(base), indices, count, 0, 16, &need_transpose, &matrix_stride); if (out_need_transpose) *out_need_transpose = target_type.columns > 1 && need_transpose; @@ -3985,7 +3985,7 @@ std::string CompilerGLSL::flattened_access_chain_vector(uint32_t base, const uin const SPIRType &target_type, uint32_t offset, uint32_t matrix_stride, bool need_transpose) { - auto result = flattened_access_chain_offset(base, indices, count, offset); + auto result = flattened_access_chain_offset(expression_type(base), indices, count, offset, 16); auto buffer_name = to_name(expression_type(base).self); @@ -4044,12 +4044,13 @@ std::string CompilerGLSL::flattened_access_chain_vector(uint32_t base, const uin } } -std::pair CompilerGLSL::flattened_access_chain_offset(uint32_t base, const uint32_t *indices, +std::pair CompilerGLSL::flattened_access_chain_offset(const SPIRType &basetype, const uint32_t *indices, uint32_t count, uint32_t offset, + uint32_t word_stride, bool *need_transpose, uint32_t *out_matrix_stride) { - const auto *type = &expression_type(base); + const auto *type = &basetype; // Start traversing type hierarchy at the proper non-pointer types. while (type->pointer) @@ -4092,8 +4093,6 @@ std::pair CompilerGLSL::flattened_access_chain_offset(uin else { // Dynamic array access. - // FIXME: This will need to change if we support other flattening types than 32-bit. - const uint32_t word_stride = 16; if (array_stride % word_stride) { SPIRV_CROSS_THROW( @@ -4102,7 +4101,7 @@ std::pair CompilerGLSL::flattened_access_chain_offset(uin "This cannot be flattened. Try using std140 layout instead."); } - expr += to_expression(index); + expr += to_enclosed_expression(index); expr += " * "; expr += convert_to_string(array_stride / word_stride); expr += " + "; diff --git a/spirv_glsl.hpp b/spirv_glsl.hpp index 517aafe3..466ef416 100644 --- a/spirv_glsl.hpp +++ b/spirv_glsl.hpp @@ -359,8 +359,9 @@ protected: std::string flattened_access_chain_vector(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type, uint32_t offset, uint32_t matrix_stride, bool need_transpose); - std::pair flattened_access_chain_offset(uint32_t base, const uint32_t *indices, + std::pair flattened_access_chain_offset(const SPIRType &basetype, const uint32_t *indices, uint32_t count, uint32_t offset, + uint32_t word_stride, bool *need_transpose = nullptr, uint32_t *matrix_stride = nullptr); diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp index 602d947c..be1bc34d 100644 --- a/spirv_hlsl.cpp +++ b/spirv_hlsl.cpp @@ -910,9 +910,9 @@ void CompilerHLSL::emit_buffer_block(const SPIRVariable &var) { auto &type = get(var.basetype); - bool is_uav = has_decoration(type.self, DecorationBufferBlock); - if (is_uav) - SPIRV_CROSS_THROW("Buffer is SSBO (UAV). This is currently unsupported."); + //bool is_uav = has_decoration(type.self, DecorationBufferBlock); + //if (is_uav) + // SPIRV_CROSS_THROW("Buffer is SSBO (UAV). This is currently unsupported."); add_resource_name(type.self); @@ -1802,6 +1802,7 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction) { auto ops = stream(instruction); auto opcode = static_cast(instruction.op); + uint32_t length = instruction.length; #define BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op) #define BOP_CAST(op, type) \ @@ -1817,6 +1818,66 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction) switch (opcode) { + case OpAccessChain: + case OpInBoundsAccessChain: + { + bool need_byte_access_chain = false; + auto &type = expression_type(ops[2]); + const SPIRAccessChain *chain = nullptr; + if (has_decoration(type.self, DecorationBufferBlock)) + { + // If we are starting to poke into an SSBO, we are dealing with ByteAddressBuffers, and we need + // to emit SPIRAccessChain rather than a plain SPIRExpression. + uint32_t chain_arguments = length - 3; + if (chain_arguments > type.array.size()) + need_byte_access_chain = true; + } + else + { + // Keep tacking on an existing access chain. + chain = maybe_get(ops[2]); + if (chain) + need_byte_access_chain = true; + } + + if (need_byte_access_chain) + { + uint32_t to_plain_buffer_length = type.array.size(); + + string base; + if (to_plain_buffer_length != 0) + { + bool need_transpose; + base = access_chain(ops[2], &ops[3], to_plain_buffer_length, get(ops[0]), &need_transpose); + } + else + base = to_expression(ops[2]); + + auto *basetype = &type; + for (uint32_t i = 0; i < to_plain_buffer_length; i++) + basetype = &get(type.parent_type); + + uint32_t matrix_stride = 0; + bool need_transpose = false; + auto offsets = flattened_access_chain_offset(*basetype, + &ops[3 + to_plain_buffer_length], length - 3 - to_plain_buffer_length, + 0, 1, &need_transpose, &matrix_stride); + + + auto &e = set(ops[1], ops[0], type.storage, base, offsets.first, offsets.second); + if (chain) + { + e.dynamic_index += chain->dynamic_index; + e.static_index += chain->static_index; + } + } + else + { + CompilerGLSL::emit_instruction(instruction); + } + break; + } + case OpMatrixTimesVector: { emit_binary_func_op(ops[0], ops[1], ops[3], ops[2], "mul");