diff --git a/reference/shaders-hlsl/comp/rwbuffer-matrix.comp b/reference/shaders-hlsl/comp/rwbuffer-matrix.comp new file mode 100644 index 00000000..4ae22364 --- /dev/null +++ b/reference/shaders-hlsl/comp/rwbuffer-matrix.comp @@ -0,0 +1,136 @@ +RWByteAddressBuffer _28 : register(u0); +cbuffer _68 : register(b1) +{ + int UBO_index0 : packoffset(c0); + int UBO_index1 : packoffset(c0.y); +}; + +void row_to_col() +{ + float4x4 _55 = asfloat(uint4x4(_28.Load(64), _28.Load(80), _28.Load(96), _28.Load(112), _28.Load(68), _28.Load(84), _28.Load(100), _28.Load(116), _28.Load(72), _28.Load(88), _28.Load(104), _28.Load(120), _28.Load(76), _28.Load(92), _28.Load(108), _28.Load(124))); + _28.Store4(0, asuint(_55[0])); + _28.Store4(16, asuint(_55[1])); + _28.Store4(32, asuint(_55[2])); + _28.Store4(48, asuint(_55[3])); + float2x2 _58 = asfloat(uint2x2(_28.Load(144), _28.Load(152), _28.Load(148), _28.Load(156))); + _28.Store2(128, asuint(_58[0])); + _28.Store2(136, asuint(_58[1])); + float2x3 _61 = asfloat(uint2x3(_28.Load(192), _28.Load(200), _28.Load(208), _28.Load(196), _28.Load(204), _28.Load(212))); + _28.Store3(160, asuint(_61[0])); + _28.Store3(176, asuint(_61[1])); + float3x2 _64 = asfloat(uint3x2(_28.Load(240), _28.Load(256), _28.Load(244), _28.Load(260), _28.Load(248), _28.Load(264))); + _28.Store2(216, asuint(_64[0])); + _28.Store2(224, asuint(_64[1])); + _28.Store2(232, asuint(_64[2])); +} + +void col_to_row() +{ + float4x4 _34 = asfloat(uint4x4(_28.Load4(0), _28.Load4(16), _28.Load4(32), _28.Load4(48))); + _28.Store(64, asuint(_34[0].x)); + _28.Store(68, asuint(_34[1].x)); + _28.Store(72, asuint(_34[2].x)); + _28.Store(76, asuint(_34[3].x)); + _28.Store(80, asuint(_34[0].y)); + _28.Store(84, asuint(_34[1].y)); + _28.Store(88, asuint(_34[2].y)); + _28.Store(92, asuint(_34[3].y)); + _28.Store(96, asuint(_34[0].z)); + _28.Store(100, asuint(_34[1].z)); + _28.Store(104, asuint(_34[2].z)); + _28.Store(108, asuint(_34[3].z)); + _28.Store(112, asuint(_34[0].w)); + _28.Store(116, asuint(_34[1].w)); + _28.Store(120, asuint(_34[2].w)); + _28.Store(124, asuint(_34[3].w)); + float2x2 _40 = asfloat(uint2x2(_28.Load2(128), _28.Load2(136))); + _28.Store(144, asuint(_40[0].x)); + _28.Store(148, asuint(_40[1].x)); + _28.Store(152, asuint(_40[0].y)); + _28.Store(156, asuint(_40[1].y)); + float2x3 _46 = asfloat(uint2x3(_28.Load3(160), _28.Load3(176))); + _28.Store(192, asuint(_46[0].x)); + _28.Store(196, asuint(_46[1].x)); + _28.Store(200, asuint(_46[0].y)); + _28.Store(204, asuint(_46[1].y)); + _28.Store(208, asuint(_46[0].z)); + _28.Store(212, asuint(_46[1].z)); + float3x2 _52 = asfloat(uint3x2(_28.Load2(216), _28.Load2(224), _28.Load2(232))); + _28.Store(240, asuint(_52[0].x)); + _28.Store(244, asuint(_52[1].x)); + _28.Store(248, asuint(_52[2].x)); + _28.Store(256, asuint(_52[0].y)); + _28.Store(260, asuint(_52[1].y)); + _28.Store(264, asuint(_52[2].y)); +} + +void write_dynamic_index_row() +{ + _28.Store(UBO_index0 * 4 + UBO_index1 * 16 + 64, asuint(1.0f)); + _28.Store(UBO_index0 * 4 + UBO_index1 * 8 + 144, asuint(2.0f)); + _28.Store(UBO_index0 * 4 + UBO_index1 * 8 + 192, asuint(3.0f)); + _28.Store(UBO_index0 * 4 + UBO_index1 * 16 + 240, asuint(4.0f)); + _28.Store(UBO_index0 * 4 + 64, asuint(float4(1.0f, 1.0f, 1.0f, 1.0f).x)); + _28.Store(UBO_index0 * 4 + 80, asuint(float4(1.0f, 1.0f, 1.0f, 1.0f).y)); + _28.Store(UBO_index0 * 4 + 96, asuint(float4(1.0f, 1.0f, 1.0f, 1.0f).z)); + _28.Store(UBO_index0 * 4 + 112, asuint(float4(1.0f, 1.0f, 1.0f, 1.0f).w)); + _28.Store(UBO_index0 * 4 + 144, asuint(float2(2.0f, 2.0f).x)); + _28.Store(UBO_index0 * 4 + 152, asuint(float2(2.0f, 2.0f).y)); + _28.Store(UBO_index0 * 4 + 192, asuint(float3(3.0f, 3.0f, 3.0f).x)); + _28.Store(UBO_index0 * 4 + 200, asuint(float3(3.0f, 3.0f, 3.0f).y)); + _28.Store(UBO_index0 * 4 + 208, asuint(float3(3.0f, 3.0f, 3.0f).z)); + _28.Store(UBO_index0 * 4 + 240, asuint(float2(4.0f, 4.0f).x)); + _28.Store(UBO_index0 * 4 + 256, asuint(float2(4.0f, 4.0f).y)); +} + +void write_dynamic_index_col() +{ + _28.Store(UBO_index0 * 16 + UBO_index1 * 4 + 0, asuint(1.0f)); + _28.Store(UBO_index0 * 8 + UBO_index1 * 4 + 128, asuint(2.0f)); + _28.Store(UBO_index0 * 16 + UBO_index1 * 4 + 160, asuint(3.0f)); + _28.Store(UBO_index0 * 8 + UBO_index1 * 4 + 216, asuint(4.0f)); + _28.Store4(UBO_index0 * 16 + 0, asuint(float4(1.0f, 1.0f, 1.0f, 1.0f))); + _28.Store2(UBO_index0 * 8 + 128, asuint(float2(2.0f, 2.0f))); + _28.Store3(UBO_index0 * 16 + 160, asuint(float3(3.0f, 3.0f, 3.0f))); + _28.Store2(UBO_index0 * 8 + 216, asuint(float2(4.0f, 4.0f))); +} + +void read_dynamic_index_row() +{ + float a0 = asfloat(_28.Load(UBO_index0 * 4 + UBO_index1 * 16 + 64)); + float a1 = asfloat(_28.Load(UBO_index0 * 4 + UBO_index1 * 8 + 144)); + float a2 = asfloat(_28.Load(UBO_index0 * 4 + UBO_index1 * 8 + 192)); + float a3 = asfloat(_28.Load(UBO_index0 * 4 + UBO_index1 * 16 + 240)); + float4 v0 = asfloat(uint4(_28.Load(UBO_index0 * 4 + 64), _28.Load(UBO_index0 * 4 + 80), _28.Load(UBO_index0 * 4 + 96), _28.Load(UBO_index0 * 4 + 112))); + float2 v1 = asfloat(uint2(_28.Load(UBO_index0 * 4 + 144), _28.Load(UBO_index0 * 4 + 152))); + float3 v2 = asfloat(uint3(_28.Load(UBO_index0 * 4 + 192), _28.Load(UBO_index0 * 4 + 200), _28.Load(UBO_index0 * 4 + 208))); + float2 v3 = asfloat(uint2(_28.Load(UBO_index0 * 4 + 240), _28.Load(UBO_index0 * 4 + 256))); +} + +void read_dynamic_index_col() +{ + float a0 = asfloat(_28.Load(UBO_index0 * 16 + UBO_index1 * 4 + 0)); + float a1 = asfloat(_28.Load(UBO_index0 * 8 + UBO_index1 * 4 + 128)); + float a2 = asfloat(_28.Load(UBO_index0 * 16 + UBO_index1 * 4 + 160)); + float a3 = asfloat(_28.Load(UBO_index0 * 8 + UBO_index1 * 4 + 216)); + float4 v0 = asfloat(_28.Load4(UBO_index0 * 16 + 0)); + float2 v1 = asfloat(_28.Load2(UBO_index0 * 8 + 128)); + float3 v2 = asfloat(_28.Load3(UBO_index0 * 16 + 160)); + float2 v3 = asfloat(_28.Load2(UBO_index0 * 8 + 216)); +} + +void comp_main() +{ + row_to_col(); + col_to_row(); + write_dynamic_index_row(); + write_dynamic_index_col(); + read_dynamic_index_row(); + read_dynamic_index_col(); +} + +[numthreads(1, 1, 1)] +void main() +{ + comp_main(); +} diff --git a/shaders-hlsl/comp/rwbuffer-matrix.comp b/shaders-hlsl/comp/rwbuffer-matrix.comp new file mode 100644 index 00000000..0e722e0a --- /dev/null +++ b/shaders-hlsl/comp/rwbuffer-matrix.comp @@ -0,0 +1,104 @@ +#version 310 es +layout(local_size_x = 1) in; + +layout(std140, binding = 1) uniform UBO +{ + int index0; + int index1; +}; + +layout(binding = 0, std430) buffer SSBO +{ + layout(column_major) mat4 mcol; + layout(row_major) mat4 mrow; + + layout(column_major) mat2 mcol2x2; + layout(row_major) mat2 mrow2x2; + + layout(column_major) mat2x3 mcol2x3; + layout(row_major) mat2x3 mrow2x3; + + layout(column_major) mat3x2 mcol3x2; + layout(row_major) mat3x2 mrow3x2; +}; + +void col_to_row() +{ + // Load column-major, store row-major. + mrow = mcol; + mrow2x2 = mcol2x2; + mrow2x3 = mcol2x3; + mrow3x2 = mcol3x2; +} + +void row_to_col() +{ + // Load row-major, store column-major. + mcol = mrow; + mcol2x2 = mrow2x2; + mcol2x3 = mrow2x3; + mcol3x2 = mrow3x2; +} + +void write_dynamic_index_row() +{ + mrow[index0][index1] = 1.0; + mrow2x2[index0][index1] = 2.0; + mrow2x3[index0][index1] = 3.0; + mrow3x2[index0][index1] = 4.0; + + mrow[index0] = vec4(1.0); + mrow2x2[index0] = vec2(2.0); + mrow2x3[index0] = vec3(3.0); + mrow3x2[index0] = vec2(4.0); +} + +void write_dynamic_index_col() +{ + mcol[index0][index1] = 1.0; + mcol2x2[index0][index1] = 2.0; + mcol2x3[index0][index1] = 3.0; + mcol3x2[index0][index1] = 4.0; + + mcol[index0] = vec4(1.0); + mcol2x2[index0] = vec2(2.0); + mcol2x3[index0] = vec3(3.0); + mcol3x2[index0] = vec2(4.0); +} + +void read_dynamic_index_row() +{ + float a0 = mrow[index0][index1]; + float a1 = mrow2x2[index0][index1]; + float a2 = mrow2x3[index0][index1]; + float a3 = mrow3x2[index0][index1]; + + vec4 v0 = mrow[index0]; + vec2 v1 = mrow2x2[index0]; + vec3 v2 = mrow2x3[index0]; + vec2 v3 = mrow3x2[index0]; +} + +void read_dynamic_index_col() +{ + float a0 = mcol[index0][index1]; + float a1 = mcol2x2[index0][index1]; + float a2 = mcol2x3[index0][index1]; + float a3 = mcol3x2[index0][index1]; + + vec4 v0 = mcol[index0]; + vec2 v1 = mcol2x2[index0]; + vec3 v2 = mcol2x3[index0]; + vec2 v3 = mcol3x2[index0]; +} + +void main() +{ + row_to_col(); + col_to_row(); + write_dynamic_index_row(); + write_dynamic_index_col(); + read_dynamic_index_row(); + read_dynamic_index_col(); +} + diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp index aa3f265f..ab77f6c1 100644 --- a/spirv_hlsl.cpp +++ b/spirv_hlsl.cpp @@ -2230,15 +2230,21 @@ string CompilerHLSL::read_access_chain(const SPIRAccessChain &chain) else if (type.columns == 1) { // Strided load since we are loading a column from a row-major matrix. - load_expr = type_to_glsl(target_type); - load_expr += "("; + if (type.vecsize > 1) + { + load_expr = type_to_glsl(target_type); + load_expr += "("; + } + for (uint32_t r = 0; r < type.vecsize; r++) { load_expr += join(chain.base, ".Load(", chain.dynamic_index, chain.static_index + r * chain.matrix_stride, ")"); if (r + 1 < type.vecsize) load_expr += ", "; } - load_expr += ")"; + + if (type.vecsize > 1) + load_expr += ")"; } else if (!chain.row_major_matrix) { @@ -2316,6 +2322,12 @@ void CompilerHLSL::emit_load(const Instruction &instruction) auto load_expr = read_access_chain(*chain); bool forward = should_forward(ptr) && forced_temporaries.find(id) == end(forced_temporaries); + + // Do not forward complex load sequences like matrices, structs and arrays. + auto &type = get(result_type); + if (type.columns > 1 || !type.array.empty() || type.basetype == SPIRType::Struct) + forward = false; + auto &e = emit_op(result_type, id, load_expr, forward, true); e.need_transpose = false; register_read(id, ptr, forward); @@ -2372,7 +2384,13 @@ void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t val // Strided store. for (uint32_t r = 0; r < type.vecsize; r++) { - auto store_expr = join(to_enclosed_expression(value), ".", index_to_swizzle(r)); + auto store_expr = to_enclosed_expression(value); + if (type.vecsize > 1) + { + store_expr += "."; + store_expr += index_to_swizzle(r); + } + auto bitcast_op = bitcast_glsl_op(target_type, type); if (!bitcast_op.empty()) store_expr = join(bitcast_op, "(", store_expr, ")");