diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp index d7021978..3310e119 100644 --- a/spirv_hlsl.cpp +++ b/spirv_hlsl.cpp @@ -2254,6 +2254,20 @@ void CompilerHLSL::emit_resources() end_scope(); statement(""); } + + if (is_mesh_shader && options.vertex.flip_vert_y) + { + statement("float4 spvFlipVertY(float4 v)"); + begin_scope(); + statement("return float4(v.x, -v.y, v.z, v.w);"); + end_scope(); + statement(""); + statement("float spvFlipVertY(float v)"); + begin_scope(); + statement("return -v;"); + end_scope(); + statement(""); + } } void CompilerHLSL::emit_texture_size_variants(uint64_t variant_mask, const char *vecsize_qualifier, bool uav, @@ -2483,6 +2497,30 @@ void CompilerHLSL::analyze_meshlet_writes(uint32_t func_id, uint32_t id_per_vert bool already_declared = false; auto builtin_type = BuiltIn(get_decoration(var->self, DecorationBuiltIn)); + if (op == OpAccessChain && i.length > 4) + { + auto &type = get_variable_element_type(*var); + auto *c = maybe_get(ops[4]); + if (c != nullptr && c->scalar() == 0 && + has_member_decoration(type.self, 0, DecorationBuiltIn) && + get_member_decoration(type.self, 0, DecorationBuiltIn) == spv::BuiltInPosition) + { + const SPIRType &dst_type = get_pointee_type(ops[0]); + if (dst_type.vecsize == 4) + { + // full vec4 write + access_to_meshlet_position.insert(ops[1]); + } + else if (i.length > 5) + { + // gl_Position.y + auto *cY = maybe_get(ops[5]); + if (cY != nullptr && cY->scalar() == 1) + access_to_meshlet_position.insert(ops[1]); + } + } + } + uint32_t var_id = var->self; if (var->storage != StorageClassTaskPayloadWorkgroupEXT && builtin_type != BuiltInPrimitivePointIndicesEXT && @@ -4950,6 +4988,15 @@ void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t val void CompilerHLSL::emit_store(const Instruction &instruction) { auto ops = stream(instruction); + if (options.vertex.flip_vert_y && access_to_meshlet_position.count(ops[0]) != 0) + { + auto lhs = to_dereferenced_expression(ops[0]); + auto rhs = to_pointer_expression(ops[1]); + statement(lhs, " = spvFlipVertY(", rhs, ");"); + register_write(ops[0]); + return; + } + auto *chain = maybe_get(ops[0]); if (chain) write_access_chain(*chain, ops[1], {}); diff --git a/spirv_hlsl.hpp b/spirv_hlsl.hpp index a5d30b1b..4322c493 100644 --- a/spirv_hlsl.hpp +++ b/spirv_hlsl.hpp @@ -363,6 +363,7 @@ private: void analyze_meshlet_writes(); void analyze_meshlet_writes(uint32_t func_id, uint32_t id_per_vertex, uint32_t id_per_primitive, std::unordered_set &processed_func_ids); + std::unordered_set access_to_meshlet_position; BitcastType get_bitcast_type(uint32_t result_type, uint32_t op0);