diff --git a/main.cpp b/main.cpp index 6fa1fc17..c1b9f632 100644 --- a/main.cpp +++ b/main.cpp @@ -347,6 +347,16 @@ static void print_push_constant_resources(const Compiler &compiler, const vector } } +static void print_spec_constants(const Compiler &compiler) +{ + auto spec_constants = compiler.get_specialization_constants(); + fprintf(stderr, "Specialization constants\n"); + fprintf(stderr, "==================\n\n"); + for (auto &c : spec_constants) + fprintf(stderr, "ID: %u, Spec ID: %u\n", c.id, c.constant_id); + fprintf(stderr, "==================\n\n"); +} + struct PLSArg { PlsFormat format; @@ -650,6 +660,7 @@ int main(int argc, char *argv[]) { print_resources(*compiler, res); print_push_constant_resources(*compiler, res.push_constant_buffers); + print_spec_constants(*compiler); } if (combined_image_samplers) diff --git a/reference/shaders/vulkan/frag/spec-constant.vk.frag b/reference/shaders/vulkan/frag/spec-constant.vk.frag new file mode 100644 index 00000000..0785cfd5 --- /dev/null +++ b/reference/shaders/vulkan/frag/spec-constant.vk.frag @@ -0,0 +1,59 @@ +#version 310 es +precision mediump float; +precision highp int; + +struct Foo +{ + float elems[(4 + 2)]; +}; + +layout(location = 0) out vec4 FragColor; + +void main() +{ + float t0 = 1.0; + float t1 = 2.0; + mediump uint c0 = (uint(3) + 0u); + mediump int c1 = (-3); + mediump int c2 = (~3); + mediump int c3 = (3 + 4); + mediump int c4 = (3 - 4); + mediump int c5 = (3 * 4); + mediump int c6 = (3 / 4); + mediump uint c7 = (5u / 6u); + mediump int c8 = (3 % 4); + mediump uint c9 = (5u % 6u); + mediump int c10 = (3 >> 4); + mediump uint c11 = (5u >> 6u); + mediump int c12 = (3 << 4); + mediump int c13 = (3 | 4); + mediump int c14 = (3 ^ 4); + mediump int c15 = (3 & 4); + bool c16 = (false || true); + bool c17 = (false && true); + bool c18 = (!false); + bool c19 = (false == true); + bool c20 = (false != true); + bool c21 = (3 == 4); + bool c22 = (3 != 4); + bool c23 = (3 < 4); + bool c24 = (5u < 6u); + bool c25 = (3 > 4); + bool c26 = (5u > 6u); + bool c27 = (3 <= 4); + bool c28 = (5u <= 6u); + bool c29 = (3 >= 4); + bool c30 = (5u >= 6u); + mediump int c31 = (c8 + c3); + mediump int c32 = int(5u + 0u); + bool c33 = (3 != int(0u)); + bool c34 = (5u != 0u); + mediump int c35 = int(false); + mediump uint c36 = uint(false); + float c37 = float(false); + float vec0[4][(3 + 3)]; + float vec1[(3 + 2)][(4 + 5)]; + Foo foo; + FragColor = (((vec4((t0 + t1)) + vec4(vec0[0][0])) + vec4(vec1[0][0])) + vec4(foo.elems[3])); +} + diff --git a/reference/shaders/vulkan/frag/spec-constant.vk.frag.vk b/reference/shaders/vulkan/frag/spec-constant.vk.frag.vk new file mode 100644 index 00000000..f8e241be --- /dev/null +++ b/reference/shaders/vulkan/frag/spec-constant.vk.frag.vk @@ -0,0 +1,68 @@ +#version 310 es +precision mediump float; +precision highp int; + +layout(constant_id = 1) const float _9 = 1.0; +layout(constant_id = 2) const float _11 = 2.0; +layout(constant_id = 3) const int _16 = 3; +layout(constant_id = 4) const int _25 = 4; +layout(constant_id = 5) const uint _34 = 5u; +layout(constant_id = 6) const uint _35 = 6u; +layout(constant_id = 7) const bool _56 = false; +layout(constant_id = 8) const bool _57 = true; + +struct Foo +{ + float elems[(_25 + 2)]; +}; + +layout(location = 0) out vec4 FragColor; + +void main() +{ + float t0 = _9; + float t1 = _11; + mediump uint c0 = (uint(_16) + 0u); + mediump int c1 = (-_16); + mediump int c2 = (~_16); + mediump int c3 = (_16 + _25); + mediump int c4 = (_16 - _25); + mediump int c5 = (_16 * _25); + mediump int c6 = (_16 / _25); + mediump uint c7 = (_34 / _35); + mediump int c8 = (_16 % _25); + mediump uint c9 = (_34 % _35); + mediump int c10 = (_16 >> _25); + mediump uint c11 = (_34 >> _35); + mediump int c12 = (_16 << _25); + mediump int c13 = (_16 | _25); + mediump int c14 = (_16 ^ _25); + mediump int c15 = (_16 & _25); + bool c16 = (_56 || _57); + bool c17 = (_56 && _57); + bool c18 = (!_56); + bool c19 = (_56 == _57); + bool c20 = (_56 != _57); + bool c21 = (_16 == _25); + bool c22 = (_16 != _25); + bool c23 = (_16 < _25); + bool c24 = (_34 < _35); + bool c25 = (_16 > _25); + bool c26 = (_34 > _35); + bool c27 = (_16 <= _25); + bool c28 = (_34 <= _35); + bool c29 = (_16 >= _25); + bool c30 = (_34 >= _35); + mediump int c31 = (c8 + c3); + mediump int c32 = int(_34 + 0u); + bool c33 = (_16 != int(0u)); + bool c34 = (_34 != 0u); + mediump int c35 = int(_56); + mediump uint c36 = uint(_56); + float c37 = float(_56); + float vec0[_25][(_16 + 3)]; + float vec1[(_16 + 2)][(_25 + 5)]; + Foo foo; + FragColor = (((vec4((t0 + t1)) + vec4(vec0[0][0])) + vec4(vec1[0][0])) + vec4(foo.elems[_16])); +} + diff --git a/shaders/vulkan/frag/spec-constant.vk.frag b/shaders/vulkan/frag/spec-constant.vk.frag new file mode 100644 index 00000000..03b625bf --- /dev/null +++ b/shaders/vulkan/frag/spec-constant.vk.frag @@ -0,0 +1,77 @@ +#version 310 es +precision mediump float; + +layout(location = 0) out vec4 FragColor; +layout(constant_id = 1) const float a = 1.0; +layout(constant_id = 2) const float b = 2.0; +layout(constant_id = 3) const int c = 3; +layout(constant_id = 4) const int d = 4; +layout(constant_id = 5) const uint e = 5u; +layout(constant_id = 6) const uint f = 6u; +layout(constant_id = 7) const bool g = false; +layout(constant_id = 8) const bool h = true; +// glslang doesn't seem to support partial spec constants or composites yet, so only test the basics. + +struct Foo +{ + float elems[d + 2]; +}; + +void main() +{ + float t0 = a; + float t1 = b; + + uint c0 = uint(c); // OpIAdd with different types. + // FConvert, float-to-double. + int c1 = -c; // SNegate + int c2 = ~c; // OpNot + int c3 = c + d; // OpIAdd + int c4 = c - d; // OpISub + int c5 = c * d; // OpIMul + int c6 = c / d; // OpSDiv + uint c7 = e / f; // OpUDiv + int c8 = c % d; // OpSMod + uint c9 = e % f; // OpUMod + // TODO: OpSRem, any way to access this in GLSL? + int c10 = c >> d; // OpShiftRightArithmetic + uint c11 = e >> f; // OpShiftRightLogical + int c12 = c << d; // OpShiftLeftLogical + int c13 = c | d; // OpBitwiseOr + int c14 = c ^ d; // OpBitwiseXor + int c15 = c & d; // OpBitwiseAnd + // VectorShuffle, CompositeExtract, CompositeInsert, not testable atm. + bool c16 = g || h; // OpLogicalOr + bool c17 = g && h; // OpLogicalAnd + bool c18 = !g; // OpLogicalNot + bool c19 = g == h; // OpLogicalEqual + bool c20 = g != h; // OpLogicalNotEqual + // OpSelect not testable atm. + bool c21 = c == d; // OpIEqual + bool c22 = c != d; // OpINotEqual + bool c23 = c < d; // OpSLessThan + bool c24 = e < f; // OpULessThan + bool c25 = c > d; // OpSGreaterThan + bool c26 = e > f; // OpUGreaterThan + bool c27 = c <= d; // OpSLessThanEqual + bool c28 = e <= f; // OpULessThanEqual + bool c29 = c >= d; // OpSGreaterThanEqual + bool c30 = e >= f; // OpUGreaterThanEqual + // OpQuantizeToF16 not testable atm. + + int c31 = c8 + c3; + + int c32 = int(e); // OpIAdd with different types. + bool c33 = bool(c); // int -> bool + bool c34 = bool(e); // uint -> bool + int c35 = int(g); // bool -> int + uint c36 = uint(g); // bool -> uint + float c37 = float(g); // bool -> float + + // Flexible sized arrays with spec constants and spec constant ops. + float vec0[d][c + 3]; + float vec1[c + 2][d + 5]; + + Foo foo; + FragColor = vec4(t0 + t1) + vec0[0][0] + vec1[0][0] + foo.elems[c]; +} diff --git a/spirv_common.hpp b/spirv_common.hpp index 714b6ec6..0b06c69a 100644 --- a/spirv_common.hpp +++ b/spirv_common.hpp @@ -134,6 +134,7 @@ enum Types TypeBlock, TypeExtension, TypeExpression, + TypeConstantOp, TypeUndef }; @@ -150,6 +151,25 @@ struct SPIRUndef : IVariant uint32_t basetype; }; +struct SPIRConstantOp : IVariant +{ + enum + { + type = TypeConstantOp + }; + + SPIRConstantOp(uint32_t result_type, spv::Op op, const uint32_t *args, uint32_t length) + : opcode(op) + , arguments(args, args + length) + , basetype(result_type) + { + } + + spv::Op opcode; + std::vector arguments; + uint32_t basetype; +}; + struct SPIRType : IVariant { enum @@ -182,9 +202,16 @@ struct SPIRType : IVariant uint32_t vecsize = 1; uint32_t columns = 1; - // Arrays, suport array of arrays by having a vector of array sizes. + // Arrays, support array of arrays by having a vector of array sizes. std::vector array; + // Array elements can be either specialization constants or specialization ops. + // This array determines how to interpret the array size. + // If an element is true, the element is a literal, + // otherwise, it's an expression, which must be resolved on demand. + // The actual size is not really known until runtime. + std::vector array_size_literal; + // Pointers bool pointer = false; spv::StorageClass storage = spv::StorageClassGeneric; @@ -826,6 +853,7 @@ struct Meta uint32_t offset = 0; uint32_t array_stride = 0; uint32_t input_attachment = 0; + uint32_t spec_id = 0; bool builtin = false; bool per_instance = false; }; diff --git a/spirv_cpp.cpp b/spirv_cpp.cpp index cc4c05c0..ae1737f6 100644 --- a/spirv_cpp.cpp +++ b/spirv_cpp.cpp @@ -410,8 +410,8 @@ string CompilerCPP::argument_decl(const SPIRFunction::Parameter &arg) string variable_name = to_name(var.self); remap_variable_type_name(type, variable_name, base); - for (auto &array : type.array) - base = join("std::array<", base, ", ", array, ">"); + for (uint32_t i = 0; i < type.array.size(); i++) + base = join("std::array<", base, ", ", to_array_size(type, i), ">"); return join(constref ? "const " : "", base, " &", variable_name); } @@ -421,16 +421,18 @@ string CompilerCPP::variable_decl(const SPIRType &type, const string &name) string base = type_to_glsl(type); remap_variable_type_name(type, name, base); bool runtime = false; - for (auto &array : type.array) + + for (uint32_t i = 0; i < type.array.size(); i++) { - if (array) - base = join("std::array<", base, ", ", array, ">"); - else + auto &array = type.array[i]; + if (!array && type.array_size_literal[i]) { // Avoid using runtime arrays with std::array since this is undefined. // Runtime arrays cannot be passed around as values, so this is fine. runtime = true; } + else + base = join("std::array<", base, ", ", to_array_size(type, i), ">"); } base += ' '; return base + name + (runtime ? "[1]" : ""); diff --git a/spirv_cross.cpp b/spirv_cross.cpp index b8121ca2..69320a18 100644 --- a/spirv_cross.cpp +++ b/spirv_cross.cpp @@ -319,6 +319,9 @@ const SPIRType &Compiler::expression_type(uint32_t id) const case TypeConstant: return get(get(id).constant_type); + case TypeConstantOp: + return get(get(id).basetype); + case TypeUndef: return get(get(id).basetype); @@ -354,7 +357,8 @@ bool Compiler::is_immutable(uint32_t id) const } else if (ids[id].get_type() == TypeExpression) return get(id).immutable; - else if (ids[id].get_type() == TypeConstant || ids[id].get_type() == TypeUndef) + else if (ids[id].get_type() == TypeConstant || ids[id].get_type() == TypeConstantOp || + ids[id].get_type() == TypeUndef) return true; else return false; @@ -814,6 +818,10 @@ void Compiler::set_member_decoration(uint32_t id, uint32_t index, Decoration dec dec.offset = argument; break; + case DecorationSpecId: + dec.spec_id = argument; + break; + default: break; } @@ -855,6 +863,8 @@ uint32_t Compiler::get_member_decoration(uint32_t id, uint32_t index, Decoration return dec.location; case DecorationOffset: return dec.offset; + case DecorationSpecId: + return dec.spec_id; default: return 0; } @@ -892,6 +902,10 @@ void Compiler::unset_member_decoration(uint32_t id, uint32_t index, Decoration d dec.offset = 0; break; + case DecorationSpecId: + dec.spec_id = 0; + break; + default: break; } @@ -933,6 +947,10 @@ void Compiler::set_decoration(uint32_t id, Decoration decoration, uint32_t argum dec.input_attachment = argument; break; + case DecorationSpecId: + dec.spec_id = argument; + break; + default: break; } @@ -974,6 +992,8 @@ uint32_t Compiler::get_decoration(uint32_t id, Decoration decoration) const return dec.set; case DecorationInputAttachmentIndex: return dec.input_attachment; + case DecorationSpecId: + return dec.spec_id; default: return 0; } @@ -1005,6 +1025,14 @@ void Compiler::unset_decoration(uint32_t id, Decoration decoration) dec.set = 0; break; + case DecorationInputAttachmentIndex: + dec.input_attachment = 0; + break; + + case DecorationSpecId: + dec.spec_id = 0; + break; + default: break; } @@ -1237,7 +1265,12 @@ void Compiler::parse(const Instruction &instruction) auto &arraybase = set(id); arraybase = base; - arraybase.array.push_back(get(ops[2]).scalar()); + + auto *c = maybe_get(ops[2]); + bool literal = c && !c->specialization; + + arraybase.array_size_literal.push_back(literal); + arraybase.array.push_back(literal ? c->scalar() : ops[2]); // Do NOT set arraybase.self! break; } @@ -1251,6 +1284,7 @@ void Compiler::parse(const Instruction &instruction) arraybase = base; arraybase.array.push_back(0); + arraybase.array_size_literal.push_back(true); // Do NOT set arraybase.self! break; } @@ -1729,6 +1763,19 @@ void Compiler::parse(const Instruction &instruction) break; } + case OpSpecConstantOp: + { + if (length < 3) + throw CompilerError("OpSpecConstantOp not enough arguments."); + + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + auto spec_op = static_cast(ops[2]); + + set(id, result_type, spec_op, ops + 3, length - 3); + break; + } + // Actual opcodes. default: { @@ -2616,3 +2663,30 @@ void Compiler::build_combined_image_samplers() CombinedImageSamplerHandler handler(*this); traverse_all_reachable_opcodes(get(entry_point), handler); } + +vector Compiler::get_specialization_constants() const +{ + vector spec_consts; + for (auto &id : ids) + { + if (id.get_type() == TypeConstant) + { + auto &c = id.get(); + if (c.specialization) + { + spec_consts.push_back({ c.self, get_decoration(c.self, DecorationSpecId) }); + } + } + } + return spec_consts; +} + +SPIRConstant &Compiler::get_constant(uint32_t id) +{ + return get(id); +} + +const SPIRConstant &Compiler::get_constant(uint32_t id) const +{ + return get(id); +} diff --git a/spirv_cross.hpp b/spirv_cross.hpp index 160e5340..baec7f8c 100644 --- a/spirv_cross.hpp +++ b/spirv_cross.hpp @@ -91,6 +91,14 @@ struct CombinedImageSampler uint32_t sampler_id; }; +struct SpecializationConstant +{ + // The ID of the specialization constant. + uint32_t id; + // The constant ID of the constant, used in Vulkan during pipeline creation. + uint32_t constant_id; +}; + struct BufferRange { unsigned index; @@ -288,6 +296,16 @@ public: variable_remap_callback = std::move(cb); } + // API for querying which specialization constants exist. + // To modify a specialization constant before compile(), use get_constant(constant.id), + // then update constants directly in the SPIRConstant data structure. + // For composite types, the subconstants can be iterated over and modified. + // constant_type is the SPIRType for the specialization constant, + // which can be queried to determine which fields in the unions should be poked at. + std::vector get_specialization_constants() const; + SPIRConstant &get_constant(uint32_t id); + const SPIRConstant &get_constant(uint32_t id) const; + protected: const uint32_t *stream(const Instruction &instr) const { diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index 614707e6..c39d4eff 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -23,6 +23,27 @@ using namespace spv; using namespace spirv_cross; using namespace std; +// Returns true if an arithmetic operation does not change behavior depending on signedness. +static bool opcode_is_sign_invariant(Op opcode) +{ + switch (opcode) + { + case OpIEqual: + case OpINotEqual: + case OpISub: + case OpIAdd: + case OpIMul: + case OpShiftLeftLogical: + case OpBitwiseOr: + case OpBitwiseXor: + case OpBitwiseAnd: + return true; + + default: + return false; + } +} + static const char *to_pls_layout(PlsFormat format) { switch (format) @@ -708,6 +729,7 @@ uint32_t CompilerGLSL::type_to_std430_array_stride(const SPIRType &type, uint64_ // Array stride is equal to aligned size of the underlying type. SPIRType tmp = type; tmp.array.pop_back(); + tmp.array_size_literal.pop_back(); uint32_t size = type_to_std430_size(tmp, flags); uint32_t alignment = type_to_std430_alignment(tmp, flags); return (size + alignment - 1) & ~(alignment - 1); @@ -716,7 +738,7 @@ uint32_t CompilerGLSL::type_to_std430_array_stride(const SPIRType &type, uint64_ uint32_t CompilerGLSL::type_to_std430_size(const SPIRType &type, uint64_t flags) { if (!type.array.empty()) - return type.array.back() * type_to_std430_array_stride(type, flags); + return to_array_size_literal(type, type.array.size() - 1) * type_to_std430_array_stride(type, flags); const uint32_t base_alignment = type_to_std430_base_size(type); uint32_t size = 0; @@ -1046,6 +1068,15 @@ void CompilerGLSL::emit_uniform(const SPIRVariable &var) statement(layout_for_variable(var), "uniform ", variable_decl(var), ";"); } +void CompilerGLSL::emit_specialization_constant(const SPIRConstant &constant) +{ + auto &type = get(constant.constant_type); + auto name = to_name(constant.self); + + statement("layout(constant_id = ", get_decoration(constant.self, DecorationSpecId), ") const ", + variable_decl(type, name), " = ", constant_expression(constant), ";"); +} + void CompilerGLSL::replace_illegal_names() { for (auto &id : ids) @@ -1185,6 +1216,34 @@ void CompilerGLSL::emit_resources() if (!pls_inputs.empty() || !pls_outputs.empty()) emit_pls(); + bool emitted = false; + + // If emitted Vulkan GLSL, + // emit specialization constants as actual floats, + // spec op expressions will redirect to the constant name. + // + // TODO: If we have the fringe case that we create a spec constant which depends on a struct type, + // we'll have to deal with that, but there's currently no known way to express that. + if (options.vulkan_semantics) + { + for (auto &id : ids) + { + if (id.get_type() == TypeConstant) + { + auto &c = id.get(); + if (!c.specialization) + continue; + + emit_specialization_constant(c); + emitted = true; + } + } + } + + if (emitted) + statement(""); + emitted = false; + // Output all basic struct types which are not Block or BufferBlock as these are declared inplace // when such variables are instantiated. for (auto &id : ids) @@ -1233,8 +1292,6 @@ void CompilerGLSL::emit_resources() } } - bool emitted = false; - bool skip_separate_image_sampler = !combined_image_samplers.empty() || !options.vulkan_semantics; // Output Uniform Constants (values, samplers, images, etc). @@ -1366,7 +1423,16 @@ string CompilerGLSL::to_expression(uint32_t id) } case TypeConstant: - return constant_expression(get(id)); + { + auto &c = get(id); + if (c.specialization && options.vulkan_semantics) + return to_name(id); + else + return constant_expression(c); + } + + case TypeConstantOp: + return constant_op_expression(get(id)); case TypeVariable: { @@ -1393,6 +1459,145 @@ string CompilerGLSL::to_expression(uint32_t id) } } +string CompilerGLSL::constant_op_expression(const SPIRConstantOp &cop) +{ + auto &type = get(cop.basetype); + bool binary = false; + bool unary = false; + string op; + + // TODO: Find a clean way to reuse emit_instruction. + switch (cop.opcode) + { + case OpSConvert: + case OpUConvert: + case OpFConvert: + op = type_to_glsl_constructor(type); + break; + +#define BOP(opname, x) \ + case Op##opname: \ + binary = true; \ + op = x; \ + break + +#define UOP(opname, x) \ + case Op##opname: \ + unary = true; \ + op = x; \ + break + + UOP(SNegate, "-"); + UOP(Not, "~"); + BOP(IAdd, "+"); + BOP(ISub, "-"); + BOP(IMul, "*"); + BOP(SDiv, "/"); + BOP(UDiv, "/"); + BOP(UMod, "%"); + BOP(SMod, "%"); + BOP(ShiftRightLogical, ">>"); + BOP(ShiftRightArithmetic, ">>"); + BOP(ShiftLeftLogical, "<<"); + BOP(BitwiseOr, "|"); + BOP(BitwiseXor, "^"); + BOP(BitwiseAnd, "&"); + BOP(LogicalOr, "||"); + BOP(LogicalAnd, "&&"); + UOP(LogicalNot, "!"); + BOP(LogicalEqual, "=="); + BOP(LogicalNotEqual, "!="); + BOP(IEqual, "=="); + BOP(INotEqual, "!="); + BOP(ULessThan, "<"); + BOP(SLessThan, "<"); + BOP(ULessThanEqual, "<="); + BOP(SLessThanEqual, "<="); + BOP(UGreaterThan, ">"); + BOP(SGreaterThan, ">"); + BOP(UGreaterThanEqual, ">="); + BOP(SGreaterThanEqual, ">="); + + case OpSelect: + { + if (cop.arguments.size() < 3) + throw CompilerError("Not enough arguments to OpSpecConstantOp."); + + // This one is pretty annoying. It's triggered from + // uint(bool), int(bool) from spec constants. + // In order to preserve its compile-time constness in Vulkan GLSL, + // we need to reduce the OpSelect expression back to this simplified model. + // If we cannot, fail. + if (!to_trivial_mix_op(type, op, cop.arguments[2], cop.arguments[1], cop.arguments[0])) + { + throw CompilerError( + "Cannot implement specialization constant op OpSelect. " + "Need trivial select implementation which can be resolved to a simple cast from boolean."); + } + break; + } + + default: + // Some opcodes are unimplemented here, these are currently not possible to test from glslang. + throw CompilerError("Unimplemented spec constant op."); + } + + SPIRType::BaseType input_type; + bool skip_cast_if_equal_type = opcode_is_sign_invariant(cop.opcode); + + switch (cop.opcode) + { + case OpIEqual: + case OpINotEqual: + input_type = SPIRType::Int; + break; + + default: + input_type = type.basetype; + break; + } + +#undef BOP +#undef UOP + if (binary) + { + if (cop.arguments.size() < 2) + throw CompilerError("Not enough arguments to OpSpecConstantOp."); + + string cast_op0; + string cast_op1; + auto expected_type = binary_op_bitcast_helper(cast_op0, cast_op1, input_type, cop.arguments[0], + cop.arguments[1], skip_cast_if_equal_type); + + if (type.basetype != input_type && type.basetype != SPIRType::Boolean) + { + expected_type.basetype = input_type; + auto expr = bitcast_glsl_op(type, expected_type); + expr += '('; + expr += join(cast_op0, " ", op, " ", cast_op1); + expr += ')'; + return expr; + } + else + return join("(", cast_op0, " ", op, " ", cast_op1, ")"); + } + else if (unary) + { + if (cop.arguments.size() < 1) + throw CompilerError("Not enough arguments to OpSpecConstantOp."); + + // Auto-bitcast to result type as needed. + // Works around various casting scenarios in glslang as there is no OpBitcast for specialization constants. + return join("(", op, bitcast_glsl(type, cop.arguments[0]), ")"); + } + else + { + if (cop.arguments.size() < 1) + throw CompilerError("Not enough arguments to OpSpecConstantOp."); + return join(op, "(", to_expression(cop.arguments[0]), ")"); + } +} + string CompilerGLSL::constant_expression(const SPIRConstant &c) { if (!c.subconstants.empty()) @@ -1406,7 +1611,12 @@ string CompilerGLSL::constant_expression(const SPIRConstant &c) for (auto &elem : c.subconstants) { - res += constant_expression(get(elem)); + auto &subc = get(elem); + if (subc.specialization && options.vulkan_semantics) + res += to_name(elem); + else + res += constant_expression(get(elem)); + if (&elem != &c.subconstants.back()) res += ", "; } @@ -1742,9 +1952,7 @@ void CompilerGLSL::emit_binary_op_cast(uint32_t result_type, uint32_t result_id, extra_parens = false; } else - { expr += join(cast_op0, " ", op, " ", cast_op1); - } emit_op(result_type, result_id, expr, should_forward(op0) && should_forward(op1), extra_parens); } @@ -1870,17 +2078,77 @@ string CompilerGLSL::legacy_tex_op(const std::string &op, const SPIRType &imgtyp throw CompilerError(join("Unsupported legacy texture op: ", op)); } +bool CompilerGLSL::to_trivial_mix_op(const SPIRType &type, string &op, uint32_t left, uint32_t right, uint32_t lerp) +{ + auto *cleft = maybe_get(left); + auto *cright = maybe_get(right); + auto &lerptype = expression_type(lerp); + + // If our targets aren't constants, we cannot use construction. + if (!cleft || !cright) + return false; + + // If our targets are spec constants, we cannot use construction. + if (cleft->specialization || cright->specialization) + return false; + + // We can only use trivial construction if we have a scalar + // (should be possible to do it for vectors as well, but that is overkill for now). + if (lerptype.basetype != SPIRType::Boolean || lerptype.vecsize > 1) + return false; + + // If our bool selects between 0 and 1, we can cast from bool instead, making our trivial constructor. + bool ret = false; + switch (type.basetype) + { + case SPIRType::Int: + case SPIRType::UInt: + ret = cleft->scalar() == 0 && cright->scalar() == 1; + break; + + case SPIRType::Float: + ret = cleft->scalar_f32() == 0.0f && cright->scalar_f32() == 1.0f; + break; + + case SPIRType::Double: + ret = cleft->scalar_f64() == 0.0 && cright->scalar_f64() == 1.0; + break; + + case SPIRType::Int64: + case SPIRType::UInt64: + ret = cleft->scalar_u64() == 0 && cright->scalar_u64() == 1; + break; + + default: + break; + } + + if (ret) + op = type_to_glsl_constructor(type); + return ret; +} + void CompilerGLSL::emit_mix_op(uint32_t result_type, uint32_t id, uint32_t left, uint32_t right, uint32_t lerp) { auto &lerptype = expression_type(lerp); auto &restype = get(result_type); + string mix_op; bool has_boolean_mix = (options.es && options.version >= 310) || (!options.es && options.version >= 450); + bool trivial_mix = to_trivial_mix_op(restype, mix_op, left, right, lerp); - // Boolean mix not supported on desktop without extension. - // Was added in OpenGL 4.5 with ES 3.1 compat. - if (!has_boolean_mix && lerptype.basetype == SPIRType::Boolean) + // If we can reduce the mix to a simple cast, do so. + // This helps for cases like int(bool), uint(bool) which is implemented with + // OpSelect bool 1 0. + if (trivial_mix) { + emit_unary_func_op(result_type, id, lerp, mix_op.c_str()); + } + else if (!has_boolean_mix && lerptype.basetype == SPIRType::Boolean) + { + // Boolean mix not supported on desktop without extension. + // Was added in OpenGL 4.5 with ES 3.1 compat. + // // Could use GL_EXT_shader_integer_mix on desktop at least, // but Apple doesn't support it. :( // Just implement it as ternary expressions. @@ -3013,12 +3281,14 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) uint32_t length = instruction.length; #define BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op) -#define BOP_CAST(op, type, skip_cast) emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, skip_cast) +#define BOP_CAST(op, type) \ + emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode)) #define UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op) #define QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op) #define TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op) #define BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op) -#define BFOP_CAST(op, type, skip_cast) emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, skip_cast) +#define BFOP_CAST(op, type) \ + emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode)) #define BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op) #define UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op) @@ -3433,7 +3703,7 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) { // For simple arith ops, prefer the output type if there's a mismatch to avoid extra bitcasts. auto type = get(ops[0]).basetype; - BOP_CAST(+, type, true); + BOP_CAST(+, type); break; } @@ -3444,7 +3714,7 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) case OpISub: { auto type = get(ops[0]).basetype; - BOP_CAST(-, type, true); + BOP_CAST(-, type); break; } @@ -3455,7 +3725,7 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) case OpIMul: { auto type = get(ops[0]).basetype; - BOP_CAST(*, type, true); + BOP_CAST(*, type); break; } @@ -3481,11 +3751,11 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) break; case OpSDiv: - BOP_CAST(/, SPIRType::Int, false); + BOP_CAST(/, SPIRType::Int); break; case OpUDiv: - BOP_CAST(/, SPIRType::UInt, false); + BOP_CAST(/, SPIRType::UInt); break; case OpFDiv: @@ -3493,38 +3763,38 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) break; case OpShiftRightLogical: - BOP_CAST(>>, SPIRType::UInt, false); + BOP_CAST(>>, SPIRType::UInt); break; case OpShiftRightArithmetic: - BOP_CAST(>>, SPIRType::Int, false); + BOP_CAST(>>, SPIRType::Int); break; case OpShiftLeftLogical: { auto type = get(ops[0]).basetype; - BOP_CAST(<<, type, true); + BOP_CAST(<<, type); break; } case OpBitwiseOr: { auto type = get(ops[0]).basetype; - BOP_CAST(|, type, true); + BOP_CAST(|, type); break; } case OpBitwiseXor: { auto type = get(ops[0]).basetype; - BOP_CAST (^, type, true); + BOP_CAST (^, type); break; } case OpBitwiseAnd: { auto type = get(ops[0]).basetype; - BOP_CAST(&, type, true); + BOP_CAST(&, type); break; } @@ -3533,11 +3803,11 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) break; case OpUMod: - BOP_CAST(%, SPIRType::UInt, false); + BOP_CAST(%, SPIRType::UInt); break; case OpSMod: - BOP_CAST(%, SPIRType::Int, false); + BOP_CAST(%, SPIRType::Int); break; case OpFMod: @@ -3572,9 +3842,9 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) case OpIEqual: { if (expression_type(ops[2]).vecsize > 1) - BFOP_CAST(equal, SPIRType::Int, true); + BFOP_CAST(equal, SPIRType::Int); else - BOP_CAST(==, SPIRType::Int, true); + BOP_CAST(==, SPIRType::Int); break; } @@ -3591,9 +3861,9 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) case OpINotEqual: { if (expression_type(ops[2]).vecsize > 1) - BFOP_CAST(notEqual, SPIRType::Int, true); + BFOP_CAST(notEqual, SPIRType::Int); else - BOP_CAST(!=, SPIRType::Int, true); + BOP_CAST(!=, SPIRType::Int); break; } @@ -3612,9 +3882,9 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) { auto type = opcode == OpUGreaterThan ? SPIRType::UInt : SPIRType::Int; if (expression_type(ops[2]).vecsize > 1) - BFOP_CAST(greaterThan, type, false); + BFOP_CAST(greaterThan, type); else - BOP_CAST(>, type, false); + BOP_CAST(>, type); break; } @@ -3632,9 +3902,9 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) { auto type = opcode == OpUGreaterThanEqual ? SPIRType::UInt : SPIRType::Int; if (expression_type(ops[2]).vecsize > 1) - BFOP_CAST(greaterThanEqual, type, false); + BFOP_CAST(greaterThanEqual, type); else - BOP_CAST(>=, type, false); + BOP_CAST(>=, type); break; } @@ -3652,9 +3922,9 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) { auto type = opcode == OpULessThan ? SPIRType::UInt : SPIRType::Int; if (expression_type(ops[2]).vecsize > 1) - BFOP_CAST(lessThan, type, false); + BFOP_CAST(lessThan, type); else - BOP_CAST(<, type, false); + BOP_CAST(<, type); break; } @@ -3672,9 +3942,9 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) { auto type = opcode == OpULessThanEqual ? SPIRType::UInt : SPIRType::Int; if (expression_type(ops[2]).vecsize > 1) - BFOP_CAST(lessThanEqual, type, false); + BFOP_CAST(lessThanEqual, type); else - BOP_CAST(<=, type, false); + BOP_CAST(<=, type); break; } @@ -4435,6 +4705,39 @@ string CompilerGLSL::pls_decl(const PlsRemap &var) to_name(variable.self)); } +uint32_t CompilerGLSL::to_array_size_literal(const SPIRType &type, uint32_t index) const +{ + assert(type.array.size() == type.array_size_literal.size()); + + if (!type.array_size_literal[index]) + throw CompilerError("The array size is not a literal, but a specialization constant or spec constant op."); + + return type.array[index]; +} + +string CompilerGLSL::to_array_size(const SPIRType &type, uint32_t index) +{ + assert(type.array.size() == type.array_size_literal.size()); + + auto &size = type.array[index]; + if (!type.array_size_literal[index]) + return to_expression(size); + else if (size) + return convert_to_string(size); + else if (!backend.flexible_member_array_supported) + { + // For runtime-sized arrays, we can work around + // lack of standard support for this by simply having + // a single element array. + // + // Runtime length arrays must always be the last element + // in an interface block. + return "1"; + } + else + return ""; +} + string CompilerGLSL::type_to_array_glsl(const SPIRType &type) { if (type.array.empty()) @@ -4443,23 +4746,8 @@ string CompilerGLSL::type_to_array_glsl(const SPIRType &type) string res; for (size_t i = type.array.size(); i; i--) { - auto &size = type.array[i - 1]; - res += "["; - if (size) - { - res += convert_to_string(size); - } - else if (!backend.flexible_member_array_supported) - { - // For runtime-sized arrays, we can work around - // lack of standard support for this by simply having - // a single element array. - // - // Runtime length arrays must always be the last element - // in an interface block. - res += '1'; - } + res += to_array_size(type, i - 1); res += "]"; } return res; diff --git a/spirv_glsl.hpp b/spirv_glsl.hpp index 35ec98cc..fa948560 100644 --- a/spirv_glsl.hpp +++ b/spirv_glsl.hpp @@ -144,6 +144,7 @@ protected: virtual std::string member_decl(const SPIRType &type, const SPIRType &member_type, uint32_t member); virtual std::string image_type_glsl(const SPIRType &type); virtual std::string constant_expression(const SPIRConstant &c); + std::string constant_op_expression(const SPIRConstantOp &cop); virtual std::string constant_expression_vector(const SPIRConstant &c, uint32_t vector); virtual void emit_fixup(); virtual std::string variable_decl(const SPIRType &type, const std::string &name); @@ -203,6 +204,8 @@ protected: Options options; std::string type_to_array_glsl(const SPIRType &type); + std::string to_array_size(const SPIRType &type, uint32_t index); + uint32_t to_array_size_literal(const SPIRType &type, uint32_t index) const; std::string variable_decl(const SPIRVariable &variable); void add_local_variable_name(uint32_t id); @@ -240,6 +243,7 @@ protected: void emit_push_constant_block_glsl(const SPIRVariable &var); void emit_interface_block(const SPIRVariable &type); void emit_block_chain(SPIRBlock &block); + void emit_specialization_constant(const SPIRConstant &constant); std::string emit_continue_block(uint32_t continue_block); bool attempt_emit_loop_header(SPIRBlock &block, SPIRBlock::Method method); void emit_uniform(const SPIRVariable &var); @@ -254,6 +258,7 @@ protected: bool should_forward(uint32_t id); void emit_mix_op(uint32_t result_type, uint32_t id, uint32_t left, uint32_t right, uint32_t lerp); + bool to_trivial_mix_op(const SPIRType &type, std::string &op, uint32_t left, uint32_t right, uint32_t lerp); void emit_glsl_op(uint32_t result_type, uint32_t result_id, uint32_t op, const uint32_t *args, uint32_t count); void emit_quaternary_func_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, uint32_t op2, uint32_t op3, const char *op); diff --git a/spirv_msl.cpp b/spirv_msl.cpp index a66b8a00..b5d5cd26 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -1260,6 +1260,7 @@ SPIRType &CompilerMSL::get_pad_type(uint32_t pad_len) ib_type.basetype = SPIRType::Char; ib_type.width = 8; ib_type.array.push_back(pad_len); + ib_type.array_size_literal.push_back(true); set_decoration(ib_type.self, DecorationArrayStride, pad_len); pad_type_ids_by_pad_len[pad_len] = pad_type_id; @@ -1616,7 +1617,7 @@ size_t CompilerMSL::get_declared_type_size(const SPIRType &type, uint64_t dec_ma // ArrayStride is part of the array type not OpMemberDecorate. auto &dec = meta[type.self].decoration; if (dec.decoration_flags & (1ull << DecorationArrayStride)) - return dec.array_stride * type.array.back(); + return dec.array_stride * to_array_size_literal(type, type.array.size() - 1); else throw CompilerError("Type does not have ArrayStride set."); }