HLSL: Add native support for 16-bit types.

Adds support for templated load/store in SM 6.2 to deal with small
types.
This commit is contained in:
Hans-Kristian Arntzen 2020-06-04 11:35:21 +02:00
parent d385bf096f
commit 2d5200650a
9 changed files with 316 additions and 35 deletions

View File

@ -323,7 +323,7 @@ if (SPIRV_CROSS_STATIC)
endif() endif()
set(spirv-cross-abi-major 0) set(spirv-cross-abi-major 0)
set(spirv-cross-abi-minor 33) set(spirv-cross-abi-minor 34)
set(spirv-cross-abi-patch 0) set(spirv-cross-abi-patch 0)
if (SPIRV_CROSS_SHARED) if (SPIRV_CROSS_SHARED)

View File

@ -596,6 +596,7 @@ struct CLIArguments
bool hlsl_support_nonzero_base = false; bool hlsl_support_nonzero_base = false;
bool hlsl_force_storage_buffer_as_uav = false; bool hlsl_force_storage_buffer_as_uav = false;
bool hlsl_nonwritable_uav_texture_as_srv = false; bool hlsl_nonwritable_uav_texture_as_srv = false;
bool hlsl_enable_16bit_types = false;
HLSLBindingFlags hlsl_binding_flags = 0; HLSLBindingFlags hlsl_binding_flags = 0;
bool vulkan_semantics = false; bool vulkan_semantics = false;
bool flatten_multidimensional_arrays = false; bool flatten_multidimensional_arrays = false;
@ -687,6 +688,7 @@ static void print_help_hlsl()
"\t\tShader must ensure that read/write state is consistent at all call sites.\n" "\t\tShader must ensure that read/write state is consistent at all call sites.\n"
"\t[--set-hlsl-vertex-input-semantic <location> <semantic>]:\n\t\tEmits a specific vertex input semantic for a given location.\n" "\t[--set-hlsl-vertex-input-semantic <location> <semantic>]:\n\t\tEmits a specific vertex input semantic for a given location.\n"
"\t\tOtherwise, TEXCOORD# is used as semantics, where # is location.\n" "\t\tOtherwise, TEXCOORD# is used as semantics, where # is location.\n"
"\t[--hlsl-enable-16bit-types]:\n\t\tEnables native use of half/int16_t/uint16_t and ByteAddressBuffer interaction with these types. Requires SM 6.2.\n"
); );
} }
@ -1135,6 +1137,7 @@ static string compile_iteration(const CLIArguments &args, std::vector<uint32_t>
hlsl_opts.support_nonzero_base_vertex_base_instance = args.hlsl_support_nonzero_base; hlsl_opts.support_nonzero_base_vertex_base_instance = args.hlsl_support_nonzero_base;
hlsl_opts.force_storage_buffer_as_uav = args.hlsl_force_storage_buffer_as_uav; hlsl_opts.force_storage_buffer_as_uav = args.hlsl_force_storage_buffer_as_uav;
hlsl_opts.nonwritable_uav_texture_as_srv = args.hlsl_nonwritable_uav_texture_as_srv; hlsl_opts.nonwritable_uav_texture_as_srv = args.hlsl_nonwritable_uav_texture_as_srv;
hlsl_opts.enable_16bit_types = args.hlsl_enable_16bit_types;
hlsl->set_hlsl_options(hlsl_opts); hlsl->set_hlsl_options(hlsl_opts);
hlsl->set_resource_binding_flags(args.hlsl_binding_flags); hlsl->set_resource_binding_flags(args.hlsl_binding_flags);
} }
@ -1305,6 +1308,7 @@ static int main_inner(int argc, char *argv[])
[&args](CLIParser &) { args.hlsl_force_storage_buffer_as_uav = true; }); [&args](CLIParser &) { args.hlsl_force_storage_buffer_as_uav = true; });
cbs.add("--hlsl-nonwritable-uav-texture-as-srv", cbs.add("--hlsl-nonwritable-uav-texture-as-srv",
[&args](CLIParser &) { args.hlsl_nonwritable_uav_texture_as_srv = true; }); [&args](CLIParser &) { args.hlsl_nonwritable_uav_texture_as_srv = true; });
cbs.add("--hlsl-enable-16bit-types", [&args](CLIParser &) { args.hlsl_enable_16bit_types = true; });
cbs.add("--vulkan-semantics", [&args](CLIParser &) { args.vulkan_semantics = true; }); cbs.add("--vulkan-semantics", [&args](CLIParser &) { args.vulkan_semantics = true; });
cbs.add("-V", [&args](CLIParser &) { args.vulkan_semantics = true; }); cbs.add("-V", [&args](CLIParser &) { args.vulkan_semantics = true; });
cbs.add("--flatten-multidimensional-arrays", [&args](CLIParser &) { args.flatten_multidimensional_arrays = true; }); cbs.add("--flatten-multidimensional-arrays", [&args](CLIParser &) { args.flatten_multidimensional_arrays = true; });

View File

@ -0,0 +1,78 @@
RWByteAddressBuffer _62 : register(u0, space0);
static float4 gl_FragCoord;
static half4 Output;
static half4 Input;
static int16_t4 OutputI;
static int16_t4 InputI;
static uint16_t4 OutputU;
static uint16_t4 InputU;
struct SPIRV_Cross_Input
{
half4 Input : TEXCOORD0;
nointerpolation int16_t4 InputI : TEXCOORD1;
nointerpolation uint16_t4 InputU : TEXCOORD2;
float4 gl_FragCoord : SV_Position;
};
struct SPIRV_Cross_Output
{
half4 Output : SV_Target0;
int16_t4 OutputI : SV_Target1;
uint16_t4 OutputU : SV_Target2;
};
void frag_main()
{
int index = int(gl_FragCoord.x);
Output = Input + half(20.0).xxxx;
OutputI = InputI + int16_t4(int16_t(-40), int16_t(-40), int16_t(-40), int16_t(-40));
OutputU = InputU + uint16_t4(20u, 20u, 20u, 20u);
Output += _62.Load<half>(index * 2 + 0).xxxx;
OutputI += _62.Load<int16_t>(index * 2 + 8).xxxx;
OutputU += _62.Load<uint16_t>(index * 2 + 16).xxxx;
Output += _62.Load<half4>(index * 8 + 24);
OutputI += _62.Load<int16_t4>(index * 8 + 56);
OutputU += _62.Load<uint16_t4>(index * 8 + 88);
Output += _62.Load<half3>(index * 16 + 128).xyzz;
Output += half3(_62.Load<half>(index * 12 + 186), _62.Load<half>(index * 12 + 190), _62.Load<half>(index * 12 + 194)).xyzz;
half2x3 _128 = half2x3(_62.Load<half3>(index * 16 + 120), _62.Load<half3>(index * 16 + 128));
half2x3 m0 = _128;
half2x3 _132 = half2x3(_62.Load<half>(index * 12 + 184), _62.Load<half>(index * 12 + 188), _62.Load<half>(index * 12 + 192), _62.Load<half>(index * 12 + 186), _62.Load<half>(index * 12 + 190), _62.Load<half>(index * 12 + 194));
half2x3 m1 = _132;
_62.Store<half>(index * 2 + 0, Output.x);
_62.Store<int16_t>(index * 2 + 8, OutputI.y);
_62.Store<uint16_t>(index * 2 + 16, OutputU.z);
_62.Store<half4>(index * 8 + 24, Output);
_62.Store<int16_t4>(index * 8 + 56, OutputI);
_62.Store<uint16_t4>(index * 8 + 88, OutputU);
_62.Store<half3>(index * 16 + 128, Output.xyz);
_62.Store<half>(index * 12 + 186, Output.x);
_62.Store<half>(index * 12 + 190, Output.xyz.y);
_62.Store<half>(index * 12 + 194, Output.xyz.z);
half2x3 _182 = half2x3(half3(Output.xyz), half3(Output.wzy));
_62.Store<half3>(index * 16 + 120, _182[0]);
_62.Store<half3>(index * 16 + 128, _182[1]);
half2x3 _197 = half2x3(half3(Output.xyz), half3(Output.wzy));
_62.Store<half>(index * 12 + 184, _197[0].x);
_62.Store<half>(index * 12 + 186, _197[1].x);
_62.Store<half>(index * 12 + 188, _197[0].y);
_62.Store<half>(index * 12 + 190, _197[1].y);
_62.Store<half>(index * 12 + 192, _197[0].z);
_62.Store<half>(index * 12 + 194, _197[1].z);
}
SPIRV_Cross_Output main(SPIRV_Cross_Input stage_input)
{
gl_FragCoord = stage_input.gl_FragCoord;
Input = stage_input.Input;
InputI = stage_input.InputI;
InputU = stage_input.InputU;
frag_main();
SPIRV_Cross_Output stage_output;
stage_output.Output = Output;
stage_output.OutputI = OutputI;
stage_output.OutputU = OutputU;
return stage_output;
}

View File

@ -0,0 +1,72 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require
layout(location = 0) out f16vec4 Output;
layout(location = 0) in f16vec4 Input;
layout(location = 1) out i16vec4 OutputI;
layout(location = 1) flat in i16vec4 InputI;
layout(location = 2) out u16vec4 OutputU;
layout(location = 2) flat in u16vec4 InputU;
layout(set = 0, binding = 0) buffer Buf
{
float16_t foo0[4];
int16_t foo1[4];
uint16_t foo2[4];
f16vec4 foo3[4];
i16vec4 foo4[4];
u16vec4 foo5[4];
f16mat2x3 foo6[4];
layout(row_major) f16mat2x3 foo7[4];
};
void main()
{
int index = int(gl_FragCoord.x);
Output = Input + float16_t(20.0);
OutputI = InputI + int16_t(-40);
OutputU = InputU + uint16_t(20);
// Load 16-bit scalar.
Output += foo0[index];
OutputI += foo1[index];
OutputU += foo2[index];
// Load 16-bit vector.
Output += foo3[index];
OutputI += foo4[index];
OutputU += foo5[index];
// Load 16-bit vector from ColMajor matrix.
Output += foo6[index][1].xyzz;
// Load 16-bit vector from RowMajor matrix.
Output += foo7[index][1].xyzz;
// Load 16-bit matrix from ColMajor.
f16mat2x3 m0 = foo6[index];
// Load 16-bit matrix from RowMajor.
f16mat2x3 m1 = foo7[index];
// Store 16-bit scalar
foo0[index] = Output.x;
foo1[index] = OutputI.y;
foo2[index] = OutputU.z;
// Store 16-bit vector
foo3[index] = Output;
foo4[index] = OutputI;
foo5[index] = OutputU;
// Store 16-bit vector to ColMajor matrix.
foo6[index][1] = Output.xyz;
// Store 16-bit vector to RowMajor matrix.
foo7[index][1] = Output.xyz;
// Store 16-bit matrix to ColMajor.
foo6[index] = f16mat2x3(Output.xyz, Output.wzy);
// Store 16-bit matrix to RowMajor.
foo7[index] = f16mat2x3(Output.xyz, Output.wzy);
}

View File

@ -485,6 +485,10 @@ spvc_result spvc_compiler_options_set_uint(spvc_compiler_options options, spvc_c
case SPVC_COMPILER_OPTION_HLSL_NONWRITABLE_UAV_TEXTURE_AS_SRV: case SPVC_COMPILER_OPTION_HLSL_NONWRITABLE_UAV_TEXTURE_AS_SRV:
options->hlsl.nonwritable_uav_texture_as_srv = value != 0; options->hlsl.nonwritable_uav_texture_as_srv = value != 0;
break; break;
case SPVC_COMPILER_OPTION_HLSL_ENABLE_16BIT_TYPES:
options->hlsl.enable_16bit_types = value != 0;
break;
#endif #endif
#if SPIRV_CROSS_C_API_MSL #if SPIRV_CROSS_C_API_MSL

View File

@ -33,7 +33,7 @@ extern "C" {
/* Bumped if ABI or API breaks backwards compatibility. */ /* Bumped if ABI or API breaks backwards compatibility. */
#define SPVC_C_API_VERSION_MAJOR 0 #define SPVC_C_API_VERSION_MAJOR 0
/* Bumped if APIs or enumerations are added in a backwards compatible way. */ /* Bumped if APIs or enumerations are added in a backwards compatible way. */
#define SPVC_C_API_VERSION_MINOR 33 #define SPVC_C_API_VERSION_MINOR 34
/* Bumped if internal implementation details change. */ /* Bumped if internal implementation details change. */
#define SPVC_C_API_VERSION_PATCH 0 #define SPVC_C_API_VERSION_PATCH 0
@ -588,6 +588,8 @@ typedef enum spvc_compiler_option
SPVC_COMPILER_OPTION_MSL_ENABLE_FRAG_STENCIL_REF_BUILTIN = 58 | SPVC_COMPILER_OPTION_MSL_BIT, SPVC_COMPILER_OPTION_MSL_ENABLE_FRAG_STENCIL_REF_BUILTIN = 58 | SPVC_COMPILER_OPTION_MSL_BIT,
SPVC_COMPILER_OPTION_MSL_ENABLE_CLIP_DISTANCE_USER_VARYING = 59 | SPVC_COMPILER_OPTION_MSL_BIT, SPVC_COMPILER_OPTION_MSL_ENABLE_CLIP_DISTANCE_USER_VARYING = 59 | SPVC_COMPILER_OPTION_MSL_BIT,
SPVC_COMPILER_OPTION_HLSL_ENABLE_16BIT_TYPES = 60 | SPVC_COMPILER_OPTION_HLSL_BIT,
SPVC_COMPILER_OPTION_INT_MAX = 0x7fffffff SPVC_COMPILER_OPTION_INT_MAX = 0x7fffffff
} spvc_compiler_option; } spvc_compiler_option;

View File

@ -430,7 +430,20 @@ string CompilerHLSL::type_to_glsl(const SPIRType &type, uint32_t id)
case SPIRType::AtomicCounter: case SPIRType::AtomicCounter:
return "atomic_uint"; return "atomic_uint";
case SPIRType::Half: case SPIRType::Half:
return "min16float"; if (hlsl_options.enable_16bit_types)
return "half";
else
return "min16float";
case SPIRType::Short:
if (hlsl_options.enable_16bit_types)
return "int16_t";
else
return "min16int";
case SPIRType::UShort:
if (hlsl_options.enable_16bit_types)
return "uint16_t";
else
return "min16uint";
case SPIRType::Float: case SPIRType::Float:
return "float"; return "float";
case SPIRType::Double: case SPIRType::Double:
@ -458,7 +471,11 @@ string CompilerHLSL::type_to_glsl(const SPIRType &type, uint32_t id)
case SPIRType::UInt: case SPIRType::UInt:
return join("uint", type.vecsize); return join("uint", type.vecsize);
case SPIRType::Half: case SPIRType::Half:
return join("min16float", type.vecsize); return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.vecsize);
case SPIRType::Short:
return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.vecsize);
case SPIRType::UShort:
return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.vecsize);
case SPIRType::Float: case SPIRType::Float:
return join("float", type.vecsize); return join("float", type.vecsize);
case SPIRType::Double: case SPIRType::Double:
@ -482,7 +499,11 @@ string CompilerHLSL::type_to_glsl(const SPIRType &type, uint32_t id)
case SPIRType::UInt: case SPIRType::UInt:
return join("uint", type.columns, "x", type.vecsize); return join("uint", type.columns, "x", type.vecsize);
case SPIRType::Half: case SPIRType::Half:
return join("min16float", type.columns, "x", type.vecsize); return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.columns, "x", type.vecsize);
case SPIRType::Short:
return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.columns, "x", type.vecsize);
case SPIRType::UShort:
return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.columns, "x", type.vecsize);
case SPIRType::Float: case SPIRType::Float:
return join("float", type.columns, "x", type.vecsize); return join("float", type.columns, "x", type.vecsize);
case SPIRType::Double: case SPIRType::Double:
@ -3647,11 +3668,16 @@ void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIR
read_access_chain_struct(lhs, chain); read_access_chain_struct(lhs, chain);
return; return;
} }
else if (type.width != 32) else if (type.width != 32 && !hlsl_options.enable_16bit_types)
SPIRV_CROSS_THROW("Reading types other than 32-bit from ByteAddressBuffer not yet supported."); SPIRV_CROSS_THROW("Reading types other than 32-bit from ByteAddressBuffer not yet supported, unless SM 6.2 and native 16-bit types are enabled.");
bool templated_load = hlsl_options.shader_model >= 62;
string load_expr; string load_expr;
string template_expr;
if (templated_load)
template_expr = join("<", type_to_glsl(type), ">");
// Load a vector or scalar. // Load a vector or scalar.
if (type.columns == 1 && !chain.row_major_matrix) if (type.columns == 1 && !chain.row_major_matrix)
{ {
@ -3674,12 +3700,24 @@ void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIR
SPIRV_CROSS_THROW("Unknown vector size."); SPIRV_CROSS_THROW("Unknown vector size.");
} }
load_expr = join(chain.base, ".", load_op, "(", chain.dynamic_index, chain.static_index, ")"); if (templated_load)
load_op = "Load";
load_expr = join(chain.base, ".", load_op, template_expr, "(", chain.dynamic_index, chain.static_index, ")");
} }
else if (type.columns == 1) else if (type.columns == 1)
{ {
// Strided load since we are loading a column from a row-major matrix. // Strided load since we are loading a column from a row-major matrix.
if (type.vecsize > 1) if (templated_load)
{
auto scalar_type = type;
scalar_type.vecsize = 1;
scalar_type.columns = 1;
template_expr = join("<", type_to_glsl(scalar_type), ">");
if (type.vecsize > 1)
load_expr += type_to_glsl(type) + "(";
}
else if (type.vecsize > 1)
{ {
load_expr = type_to_glsl(target_type); load_expr = type_to_glsl(target_type);
load_expr += "("; load_expr += "(";
@ -3688,7 +3726,7 @@ void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIR
for (uint32_t r = 0; r < type.vecsize; r++) for (uint32_t r = 0; r < type.vecsize; r++)
{ {
load_expr += load_expr +=
join(chain.base, ".Load(", chain.dynamic_index, chain.static_index + r * chain.matrix_stride, ")"); join(chain.base, ".Load", template_expr, "(", chain.dynamic_index, chain.static_index + r * chain.matrix_stride, ")");
if (r + 1 < type.vecsize) if (r + 1 < type.vecsize)
load_expr += ", "; load_expr += ", ";
} }
@ -3718,13 +3756,25 @@ void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIR
SPIRV_CROSS_THROW("Unknown vector size."); SPIRV_CROSS_THROW("Unknown vector size.");
} }
// Note, this loading style in HLSL is *actually* row-major, but we always treat matrices as transposed in this backend, if (templated_load)
// so row-major is technically column-major ... {
load_expr = type_to_glsl(target_type); auto vector_type = type;
vector_type.columns = 1;
template_expr = join("<", type_to_glsl(vector_type), ">");
load_expr = type_to_glsl(type);
load_op = "Load";
}
else
{
// Note, this loading style in HLSL is *actually* row-major, but we always treat matrices as transposed in this backend,
// so row-major is technically column-major ...
load_expr = type_to_glsl(target_type);
}
load_expr += "("; load_expr += "(";
for (uint32_t c = 0; c < type.columns; c++) for (uint32_t c = 0; c < type.columns; c++)
{ {
load_expr += join(chain.base, ".", load_op, "(", chain.dynamic_index, load_expr += join(chain.base, ".", load_op, template_expr, "(", chain.dynamic_index,
chain.static_index + c * chain.matrix_stride, ")"); chain.static_index + c * chain.matrix_stride, ")");
if (c + 1 < type.columns) if (c + 1 < type.columns)
load_expr += ", "; load_expr += ", ";
@ -3736,13 +3786,24 @@ void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIR
// Pick out elements one by one ... Hopefully compilers are smart enough to recognize this pattern // Pick out elements one by one ... Hopefully compilers are smart enough to recognize this pattern
// considering HLSL is "row-major decl", but "column-major" memory layout (basically implicit transpose model, ugh) ... // considering HLSL is "row-major decl", but "column-major" memory layout (basically implicit transpose model, ugh) ...
load_expr = type_to_glsl(target_type); if (templated_load)
{
load_expr = type_to_glsl(type);
auto scalar_type = type;
scalar_type.vecsize = 1;
scalar_type.columns = 1;
template_expr = join("<", type_to_glsl(scalar_type), ">");
}
else
load_expr = type_to_glsl(target_type);
load_expr += "("; load_expr += "(";
for (uint32_t c = 0; c < type.columns; c++) for (uint32_t c = 0; c < type.columns; c++)
{ {
for (uint32_t r = 0; r < type.vecsize; r++) for (uint32_t r = 0; r < type.vecsize; r++)
{ {
load_expr += join(chain.base, ".Load(", chain.dynamic_index, load_expr += join(chain.base, ".Load", template_expr, "(", chain.dynamic_index,
chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ")"); chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ")");
if ((r + 1 < type.vecsize) || (c + 1 < type.columns)) if ((r + 1 < type.vecsize) || (c + 1 < type.columns))
@ -3752,9 +3813,12 @@ void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIR
load_expr += ")"; load_expr += ")";
} }
auto bitcast_op = bitcast_glsl_op(type, target_type); if (!templated_load)
if (!bitcast_op.empty()) {
load_expr = join(bitcast_op, "(", load_expr, ")"); auto bitcast_op = bitcast_glsl_op(type, target_type);
if (!bitcast_op.empty())
load_expr = join(bitcast_op, "(", load_expr, ")");
}
if (lhs.empty()) if (lhs.empty())
{ {
@ -3937,8 +4001,14 @@ void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t val
register_write(chain.self); register_write(chain.self);
return; return;
} }
else if (type.width != 32) else if (type.width != 32 && !hlsl_options.enable_16bit_types)
SPIRV_CROSS_THROW("Writing types other than 32-bit to RWByteAddressBuffer not yet supported."); SPIRV_CROSS_THROW("Writing types other than 32-bit to RWByteAddressBuffer not yet supported, unless SM 6.2 and native 16-bit types are enabled.");
bool templated_store = hlsl_options.shader_model >= 62;
string template_expr;
if (templated_store)
template_expr = join("<", type_to_glsl(type), ">");
if (type.columns == 1 && !chain.row_major_matrix) if (type.columns == 1 && !chain.row_major_matrix)
{ {
@ -3962,13 +4032,27 @@ void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t val
} }
auto store_expr = write_access_chain_value(value, composite_chain, false); auto store_expr = write_access_chain_value(value, composite_chain, false);
auto bitcast_op = bitcast_glsl_op(target_type, type);
if (!bitcast_op.empty()) if (!templated_store)
store_expr = join(bitcast_op, "(", store_expr, ")"); {
statement(chain.base, ".", store_op, "(", chain.dynamic_index, chain.static_index, ", ", store_expr, ");"); auto bitcast_op = bitcast_glsl_op(target_type, type);
if (!bitcast_op.empty())
store_expr = join(bitcast_op, "(", store_expr, ")");
}
else
store_op = "Store";
statement(chain.base, ".", store_op, template_expr, "(", chain.dynamic_index, chain.static_index, ", ", store_expr, ");");
} }
else if (type.columns == 1) else if (type.columns == 1)
{ {
if (templated_store)
{
auto scalar_type = type;
scalar_type.vecsize = 1;
scalar_type.columns = 1;
template_expr = join("<", type_to_glsl(scalar_type), ">");
}
// Strided store. // Strided store.
for (uint32_t r = 0; r < type.vecsize; r++) for (uint32_t r = 0; r < type.vecsize; r++)
{ {
@ -3980,10 +4064,14 @@ void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t val
} }
remove_duplicate_swizzle(store_expr); remove_duplicate_swizzle(store_expr);
auto bitcast_op = bitcast_glsl_op(target_type, type); if (!templated_store)
if (!bitcast_op.empty()) {
store_expr = join(bitcast_op, "(", store_expr, ")"); auto bitcast_op = bitcast_glsl_op(target_type, type);
statement(chain.base, ".Store(", chain.dynamic_index, chain.static_index + chain.matrix_stride * r, ", ", if (!bitcast_op.empty())
store_expr = join(bitcast_op, "(", store_expr, ")");
}
statement(chain.base, ".Store", template_expr, "(", chain.dynamic_index, chain.static_index + chain.matrix_stride * r, ", ",
store_expr, ");"); store_expr, ");");
} }
} }
@ -4008,18 +4096,39 @@ void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t val
SPIRV_CROSS_THROW("Unknown vector size."); SPIRV_CROSS_THROW("Unknown vector size.");
} }
if (templated_store)
{
store_op = "Store";
auto vector_type = type;
vector_type.columns = 1;
template_expr = join("<", type_to_glsl(vector_type), ">");
}
for (uint32_t c = 0; c < type.columns; c++) for (uint32_t c = 0; c < type.columns; c++)
{ {
auto store_expr = join(write_access_chain_value(value, composite_chain, true), "[", c, "]"); auto store_expr = join(write_access_chain_value(value, composite_chain, true), "[", c, "]");
auto bitcast_op = bitcast_glsl_op(target_type, type);
if (!bitcast_op.empty()) if (!templated_store)
store_expr = join(bitcast_op, "(", store_expr, ")"); {
statement(chain.base, ".", store_op, "(", chain.dynamic_index, chain.static_index + c * chain.matrix_stride, auto bitcast_op = bitcast_glsl_op(target_type, type);
if (!bitcast_op.empty())
store_expr = join(bitcast_op, "(", store_expr, ")");
}
statement(chain.base, ".", store_op, template_expr, "(", chain.dynamic_index, chain.static_index + c * chain.matrix_stride,
", ", store_expr, ");"); ", ", store_expr, ");");
} }
} }
else else
{ {
if (templated_store)
{
auto scalar_type = type;
scalar_type.vecsize = 1;
scalar_type.columns = 1;
template_expr = join("<", type_to_glsl(scalar_type), ">");
}
for (uint32_t r = 0; r < type.vecsize; r++) for (uint32_t r = 0; r < type.vecsize; r++)
{ {
for (uint32_t c = 0; c < type.columns; c++) for (uint32_t c = 0; c < type.columns; c++)
@ -4030,7 +4139,7 @@ void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t val
auto bitcast_op = bitcast_glsl_op(target_type, type); auto bitcast_op = bitcast_glsl_op(target_type, type);
if (!bitcast_op.empty()) if (!bitcast_op.empty())
store_expr = join(bitcast_op, "(", store_expr, ")"); store_expr = join(bitcast_op, "(", store_expr, ")");
statement(chain.base, ".Store(", chain.dynamic_index, statement(chain.base, ".Store", template_expr, "(", chain.dynamic_index,
chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ", ", store_expr, ");"); chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ", ", store_expr, ");");
} }
} }
@ -5423,6 +5532,9 @@ void CompilerHLSL::validate_shader_model()
if (ir.addressing_model != AddressingModelLogical) if (ir.addressing_model != AddressingModelLogical)
SPIRV_CROSS_THROW("Only Logical addressing model can be used with HLSL."); SPIRV_CROSS_THROW("Only Logical addressing model can be used with HLSL.");
if (hlsl_options.enable_16bit_types && hlsl_options.shader_model < 62)
SPIRV_CROSS_THROW("Need at least shader model 6.2 when enabling native 16-bit type support.");
} }
string CompilerHLSL::compile() string CompilerHLSL::compile()

View File

@ -119,6 +119,11 @@ public:
// For this to work with function call parameters, NonWritable must be considered to be part of the type system // For this to work with function call parameters, NonWritable must be considered to be part of the type system
// so that NonWritable image arguments are also translated to Texture rather than RWTexture. // so that NonWritable image arguments are also translated to Texture rather than RWTexture.
bool nonwritable_uav_texture_as_srv = false; bool nonwritable_uav_texture_as_srv = false;
// Enables native 16-bit types. Needs SM 6.2.
// Uses half/int16_t/uint16_t instead of min16* types.
// Also adds support for 16-bit load-store from (RW)ByteAddressBuffer.
bool enable_16bit_types = false;
}; };
explicit CompilerHLSL(std::vector<uint32_t> spirv_) explicit CompilerHLSL(std::vector<uint32_t> spirv_)

View File

@ -336,7 +336,9 @@ def validate_shader_hlsl(shader, force_no_external_validation, paths):
raise RuntimeError('Failed compiling HLSL shader') raise RuntimeError('Failed compiling HLSL shader')
def shader_to_sm(shader): def shader_to_sm(shader):
if '.sm60.' in shader: if '.sm62.' in shader:
return '62'
elif '.sm60.' in shader:
return '60' return '60'
elif '.sm51.' in shader: elif '.sm51.' in shader:
return '51' return '51'
@ -374,6 +376,8 @@ def cross_compile_hlsl(shader, spirv, opt, force_no_external_validation, iterati
hlsl_args.append('--force-zero-initialized-variables') hlsl_args.append('--force-zero-initialized-variables')
if '.nonwritable-uav-texture.' in shader: if '.nonwritable-uav-texture.' in shader:
hlsl_args.append('--hlsl-nonwritable-uav-texture-as-srv') hlsl_args.append('--hlsl-nonwritable-uav-texture-as-srv')
if '.native-16bit.' in shader:
hlsl_args.append('--hlsl-enable-16bit-types')
subprocess.check_call(hlsl_args) subprocess.check_call(hlsl_args)