diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 4e8db282c..b76620ea1 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -292,6 +292,9 @@ void ValidationState_t::RegisterCapability(SpvCapability cap) {
   }
 
   switch (cap) {
+    case SpvCapabilityKernel:
+      features_.group_ops_reduce_and_scans = true;
+      break;
     case SpvCapabilityInt16:
      features_.declare_int16_type = true;
       break;
@@ -323,6 +326,18 @@ void ValidationState_t::RegisterExtension(Extension ext) {
   if (module_extensions_.Contains(ext)) return;
 
   module_extensions_.Add(ext);
+
+  switch (ext) {
+    case kSPV_AMD_shader_ballot:
+      // The grammar doesn't encode the fact that SPV_AMD_shader_ballot
+      // enables the use of group operations Reduce, InclusiveScan,
+      // and ExclusiveScan. Enable it manually.
+      // https://github.com/KhronosGroup/SPIRV-Tools/issues/991
+      features_.group_ops_reduce_and_scans = true;
+      break;
+    default:
+      break;
+  }
 }
 
 bool ValidationState_t::HasAnyOfCapabilities(
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index 48bc3be59..600560514 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -69,6 +69,9 @@ class ValidationState_t {
     // Allow functionalities enabled by VariablePointersStorageBuffer
     // capability.
     bool variable_pointers_storage_buffer = false;
+
+    // Permit group operations Reduce, InclusiveScan, ExclusiveScan.
+    bool group_ops_reduce_and_scans = false;
   };
 
   ValidationState_t(const spv_const_context context,
diff --git a/source/validate_instruction.cpp b/source/validate_instruction.cpp
index 018150789..dc013cbd7 100644
--- a/source/validate_instruction.cpp
+++ b/source/validate_instruction.cpp
@@ -67,6 +67,39 @@ spv_result_t CapabilityError(ValidationState_t& _, int which_operand,
          << " requires one of these capabilities: " << required_capabilities;
 }
 
+// Returns capabilities that enable an opcode. An empty result is interpreted
+// as no prohibition of use of the opcode. If the result is non-empty, then
+// the opcode may only be used if at least one of the capabilities is specified
+// by the module.
+CapabilitySet EnablingCapabilitiesForOp(const ValidationState_t& state,
+                                        SpvOp opcode) {
+  // Exceptions for SPV_AMD_shader_ballot
+  switch (opcode) {
+    // Normally these would require the Groups capability
+    case SpvOpGroupIAddNonUniformAMD:
+    case SpvOpGroupFAddNonUniformAMD:
+    case SpvOpGroupFMinNonUniformAMD:
+    case SpvOpGroupUMinNonUniformAMD:
+    case SpvOpGroupSMinNonUniformAMD:
+    case SpvOpGroupFMaxNonUniformAMD:
+    case SpvOpGroupUMaxNonUniformAMD:
+    case SpvOpGroupSMaxNonUniformAMD:
+      if (state.HasExtension(libspirv::kSPV_AMD_shader_ballot))
+        return CapabilitySet();
+      break;
+    default:
+      break;
+  }
+  // Look it up in the grammar.
+  spv_opcode_desc opcode_desc = {};
+  if (SPV_SUCCESS == state.grammar().lookupOpcode(opcode, &opcode_desc)) {
+    CapabilitySet opcode_caps(opcode_desc->numCapabilities,
+                              opcode_desc->capabilities);
+    return opcode_caps;
+  }
+  return CapabilitySet();
+}
+
 // Returns an operand's required capabilities.
 CapabilitySet RequiredCapabilities(const ValidationState_t& state,
                                    spv_operand_type_t type, uint32_t operand) {
@@ -97,12 +130,18 @@ CapabilitySet RequiredCapabilities(const ValidationState_t& state,
   CapabilitySet result(operand_desc->numCapabilities,
                        operand_desc->capabilities);
 
-  // Allow FPRoundingMode decoration if requested
+  // Allow FPRoundingMode decoration if requested.
   if (state.features().free_fp_rounding_mode &&
       type == SPV_OPERAND_TYPE_DECORATION &&
       operand_desc->value == SpvDecorationFPRoundingMode) {
     return CapabilitySet();
   }
+  // Allow certain group operations if requested.
+  if (state.features().group_ops_reduce_and_scans &&
+      type == SPV_OPERAND_TYPE_GROUP_OPERATION &&
+      (operand <= uint32_t(SpvGroupOperationExclusiveScan))) {
+    return CapabilitySet();
+  }
 
   return result;
 }
@@ -128,16 +167,13 @@ namespace libspirv {
 
 spv_result_t CapabilityCheck(ValidationState_t& _,
                              const spv_parsed_instruction_t* inst) {
-  spv_opcode_desc opcode_desc = {};
   const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
-  if (SPV_SUCCESS == _.grammar().lookupOpcode(opcode, &opcode_desc)) {
-    CapabilitySet opcode_caps(opcode_desc->numCapabilities,
-                              opcode_desc->capabilities);
-    if (!_.HasAnyOfCapabilities(opcode_caps))
-      return _.diag(SPV_ERROR_INVALID_CAPABILITY)
-             << "Opcode " << spvOpcodeString(opcode)
-             << " requires one of these capabilities: "
-             << ToString(opcode_caps, _.grammar());
+  CapabilitySet opcode_caps = EnablingCapabilitiesForOp(_, opcode);
+  if (!_.HasAnyOfCapabilities(opcode_caps)) {
+    return _.diag(SPV_ERROR_INVALID_CAPABILITY)
+           << "Opcode " << spvOpcodeString(opcode)
+           << " requires one of these capabilities: "
+           << ToString(opcode_caps, _.grammar());
   }
   for (int i = 0; i < inst->num_operands; ++i) {
     const auto& operand = inst->operands[i];
diff --git a/test/val/val_extensions_test.cpp b/test/val/val_extensions_test.cpp
index e4993ca00..bd3998cae 100644
--- a/test/val/val_extensions_test.cpp
+++ b/test/val/val_extensions_test.cpp
@@ -30,6 +30,7 @@ using ::libspirv::Extension;
 using ::testing::HasSubstr;
 using ::testing::Not;
 using ::testing::Values;
+using ::testing::ValuesIn;
 
 using std::string;
 
@@ -106,4 +107,106 @@ TEST_F(ValidateExtensionCapabilities, DeclCapabilityFailure) {
   EXPECT_THAT(getDiagnosticString(), HasSubstr("SPV_KHR_device_group"));
 }
+
+using ValidateAMDShaderBallotCapabilities = spvtest::ValidateBase<std::string>;
+
+// Returns a vector of strings for the prefix of a SPIR-V assembly shader
+// that can use the group instructions introduced by SPV_AMD_shader_ballot.
+std::vector<std::string> ShaderPartsForAMDShaderBallot() {
+  return std::vector<std::string>{R"(
+  OpCapability Shader
+  OpCapability Linkage
+  )",
+                                  R"(
+  OpMemoryModel Logical GLSL450
+  %float = OpTypeFloat 32
+  %uint = OpTypeInt 32 0
+  %int = OpTypeInt 32 1
+  %scope = OpConstant %uint 3
+  %uint_const = OpConstant %uint 42
+  %int_const = OpConstant %uint 45
+  %float_const = OpConstant %float 3.5
+
+  %void = OpTypeVoid
+  %fn_ty = OpTypeFunction %void
+  %fn = OpFunction %void None %fn_ty
+  %entry = OpLabel
+  )"};
+}
+
+// Returns a list of SPIR-V assembly strings, where each uses only types
+// and IDs that fit with a shader made from the parts returned by
+// ShaderPartsForAMDShaderBallot.
+std::vector<std::string> AMDShaderBallotGroupInstructions() {
+  return std::vector<std::string>{
+      "%iadd_reduce = OpGroupIAddNonUniformAMD %uint %scope Reduce %uint_const",
+      "%iadd_iscan = OpGroupIAddNonUniformAMD %uint %scope InclusiveScan %uint_const",
+      "%iadd_escan = OpGroupIAddNonUniformAMD %uint %scope ExclusiveScan %uint_const",
+
+      "%fadd_reduce = OpGroupFAddNonUniformAMD %float %scope Reduce %float_const",
+      "%fadd_iscan = OpGroupFAddNonUniformAMD %float %scope InclusiveScan %float_const",
+      "%fadd_escan = OpGroupFAddNonUniformAMD %float %scope ExclusiveScan %float_const",
+
+      "%fmin_reduce = OpGroupFMinNonUniformAMD %float %scope Reduce %float_const",
+      "%fmin_iscan = OpGroupFMinNonUniformAMD %float %scope InclusiveScan %float_const",
+      "%fmin_escan = OpGroupFMinNonUniformAMD %float %scope ExclusiveScan %float_const",
+
+      "%umin_reduce = OpGroupUMinNonUniformAMD %uint %scope Reduce %uint_const",
+      "%umin_iscan = OpGroupUMinNonUniformAMD %uint %scope InclusiveScan %uint_const",
+      "%umin_escan = OpGroupUMinNonUniformAMD %uint %scope ExclusiveScan %uint_const",
+
+      "%smin_reduce = OpGroupSMinNonUniformAMD %int %scope Reduce %int_const",
+      "%smin_iscan = OpGroupSMinNonUniformAMD %int %scope InclusiveScan %int_const",
+      "%smin_escan = OpGroupSMinNonUniformAMD %int %scope ExclusiveScan %int_const",
+
+      "%fmax_reduce = OpGroupFMaxNonUniformAMD %float %scope Reduce %float_const",
+      "%fmax_iscan = OpGroupFMaxNonUniformAMD %float %scope InclusiveScan %float_const",
+      "%fmax_escan = OpGroupFMaxNonUniformAMD %float %scope ExclusiveScan %float_const",
+
+      "%umax_reduce = OpGroupUMaxNonUniformAMD %uint %scope Reduce %uint_const",
+      "%umax_iscan = OpGroupUMaxNonUniformAMD %uint %scope InclusiveScan %uint_const",
+      "%umax_escan = OpGroupUMaxNonUniformAMD %uint %scope ExclusiveScan %uint_const",
+
+      "%smax_reduce = OpGroupSMaxNonUniformAMD %int %scope Reduce %int_const",
+      "%smax_iscan = OpGroupSMaxNonUniformAMD %int %scope InclusiveScan %int_const",
+      "%smax_escan = OpGroupSMaxNonUniformAMD %int %scope ExclusiveScan %int_const"
+  };
+}
+
+TEST_P(ValidateAMDShaderBallotCapabilities, ExpectSuccess) {
+  // Succeed because the module specifies the SPV_AMD_shader_ballot extension.
+  auto parts = ShaderPartsForAMDShaderBallot();
+
+  const string assembly = parts[0] + "OpExtension \"SPV_AMD_shader_ballot\"\n" +
+                          parts[1] + GetParam() + "\nOpReturn OpFunctionEnd";
+
+  CompileSuccessfully(assembly.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString();
+}
+
+INSTANTIATE_TEST_CASE_P(ExpectSuccess, ValidateAMDShaderBallotCapabilities,
+                        ValuesIn(AMDShaderBallotGroupInstructions()));
+
+TEST_P(ValidateAMDShaderBallotCapabilities, ExpectFailure) {
+  // Fail because the module does not specify the SPV_AMD_shader_ballot extension.
+  auto parts = ShaderPartsForAMDShaderBallot();
+
+  const string assembly =
+      parts[0] + parts[1] + GetParam() + "\nOpReturn OpFunctionEnd";
+
+  CompileSuccessfully(assembly.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
+
+  // Make sure we get an appropriate error message.
+  // Find just the opcode name, skipping over the "Op" part.
+  auto prefix_with_opcode = GetParam().substr(GetParam().find("Group"));
+  auto opcode = prefix_with_opcode.substr(0, prefix_with_opcode.find(' '));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr(string("Opcode " + opcode +
+                               " requires one of these capabilities: Groups")));
+}
+
+INSTANTIATE_TEST_CASE_P(ExpectFailure, ValidateAMDShaderBallotCapabilities,
+                        ValuesIn(AMDShaderBallotGroupInstructions()));
+
 
 }  // anonymous namespace