From 09c206b6fb79ae1c2c542e30225739b938c880a5 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Tue, 17 Apr 2018 10:18:59 -0400 Subject: [PATCH] Fixes #1480. Validate group non-uniform scopes. * Adds new pass for validating non-uniform group instructions * Currently on checks execution scope for Vulkan 1.1 and SPIR-V 1.3 * Added test framework --- Android.mk | 1 + source/CMakeLists.txt | 1 + source/opcode.cpp | 42 +++++ source/opcode.h | 3 + source/validate.cpp | 1 + source/validate.h | 4 + source/validate_non_uniform.cpp | 84 ++++++++++ test/val/CMakeLists.txt | 6 + test/val/val_non_uniform_test.cpp | 247 ++++++++++++++++++++++++++++++ 9 files changed, 389 insertions(+) create mode 100644 source/validate_non_uniform.cpp create mode 100644 test/val/val_non_uniform_test.cpp diff --git a/Android.mk b/Android.mk index 30be9b724..f6ac8298b 100644 --- a/Android.mk +++ b/Android.mk @@ -54,6 +54,7 @@ SPVTOOLS_SRC_FILES := \ source/validate_layout.cpp \ source/validate_literals.cpp \ source/validate_logicals.cpp \ + source/validate_non_uniform.cpp \ source/validate_primitives.cpp \ source/validate_type_unique.cpp diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index d9f595587..df376ca44 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -301,6 +301,7 @@ set(SPIRV_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/validate_layout.cpp ${CMAKE_CURRENT_SOURCE_DIR}/validate_literals.cpp ${CMAKE_CURRENT_SOURCE_DIR}/validate_logicals.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/validate_non_uniform.cpp ${CMAKE_CURRENT_SOURCE_DIR}/validate_primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/validate_type_unique.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/decoration.h diff --git a/source/opcode.cpp b/source/opcode.cpp index c73f14d3a..98c4bb9c8 100644 --- a/source/opcode.cpp +++ b/source/opcode.cpp @@ -454,3 +454,45 @@ bool spvOpcodeIsBaseOpaqueType(SpvOp opcode) { return false; } } + +bool spvOpcodeIsNonUniformGroupOperation(SpvOp opcode) { + switch (opcode) { + case SpvOpGroupNonUniformElect: + case SpvOpGroupNonUniformAll: + case SpvOpGroupNonUniformAny: + case SpvOpGroupNonUniformAllEqual: + case SpvOpGroupNonUniformBroadcast: + case SpvOpGroupNonUniformBroadcastFirst: + case SpvOpGroupNonUniformBallot: + case SpvOpGroupNonUniformInverseBallot: + case SpvOpGroupNonUniformBallotBitExtract: + case SpvOpGroupNonUniformBallotBitCount: + case SpvOpGroupNonUniformBallotFindLSB: + case SpvOpGroupNonUniformBallotFindMSB: + case SpvOpGroupNonUniformShuffle: + case SpvOpGroupNonUniformShuffleXor: + case SpvOpGroupNonUniformShuffleUp: + case SpvOpGroupNonUniformShuffleDown: + case SpvOpGroupNonUniformIAdd: + case SpvOpGroupNonUniformFAdd: + case SpvOpGroupNonUniformIMul: + case SpvOpGroupNonUniformFMul: + case SpvOpGroupNonUniformSMin: + case SpvOpGroupNonUniformUMin: + case SpvOpGroupNonUniformFMin: + case SpvOpGroupNonUniformSMax: + case SpvOpGroupNonUniformUMax: + case SpvOpGroupNonUniformFMax: + case SpvOpGroupNonUniformBitwiseAnd: + case SpvOpGroupNonUniformBitwiseOr: + case SpvOpGroupNonUniformBitwiseXor: + case SpvOpGroupNonUniformLogicalAnd: + case SpvOpGroupNonUniformLogicalOr: + case SpvOpGroupNonUniformLogicalXor: + case SpvOpGroupNonUniformQuadBroadcast: + case SpvOpGroupNonUniformQuadSwap: + return true; + default: + return false; + } +} diff --git a/source/opcode.h b/source/opcode.h index 9b585137e..7aadf30ce 100644 --- a/source/opcode.h +++ b/source/opcode.h @@ -118,4 +118,7 @@ bool spvOpcodeIsBlockTerminator(SpvOp opcode); // Returns true if the given opcode always defines an opaque type. bool spvOpcodeIsBaseOpaqueType(SpvOp opcode); + +// Returns true if the given opcode is a non-uniform group operation. +bool spvOpcodeIsNonUniformGroupOperation(SpvOp opcode); #endif // LIBSPIRV_OPCODE_H_ diff --git a/source/validate.cpp b/source/validate.cpp index ea7300487..953aad1ed 100644 --- a/source/validate.cpp +++ b/source/validate.cpp @@ -189,6 +189,7 @@ spv_result_t ProcessInstruction(void* user_data, if (auto error = BarriersPass(_, inst)) return error; if (auto error = PrimitivesPass(_, inst)) return error; if (auto error = LiteralsPass(_, inst)) return error; + if (auto error = NonUniformPass(_, inst)) return error; return SPV_SUCCESS; } diff --git a/source/validate.h b/source/validate.h index a4f6dde28..983b30da3 100644 --- a/source/validate.h +++ b/source/validate.h @@ -170,6 +170,10 @@ spv_result_t LiteralsPass(ValidationState_t& _, spv_result_t ExtInstPass(ValidationState_t& _, const spv_parsed_instruction_t* inst); +/// Validates correctness of non-uniform group instructions. +spv_result_t NonUniformPass(ValidationState_t& _, + const spv_parsed_instruction_t* inst); + // Validates that capability declarations use operands allowed in the current // context. spv_result_t CapabilityPass(ValidationState_t& _, diff --git a/source/validate_non_uniform.cpp b/source/validate_non_uniform.cpp new file mode 100644 index 000000000..66c2b4286 --- /dev/null +++ b/source/validate_non_uniform.cpp @@ -0,0 +1,84 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of barrier SPIR-V instructions. + +#include "validate.h" + +#include "diagnostic.h" +#include "opcode.h" +#include "spirv_constant.h" +#include "spirv_target_env.h" +#include "util/bitutils.h" +#include "val/instruction.h" +#include "val/validation_state.h" + +namespace libspirv { + +namespace { + +spv_result_t ValidateExecutionScope(ValidationState_t& _, + const spv_parsed_instruction_t* inst, + uint32_t scope) { + SpvOp opcode = static_cast(inst->opcode); + bool is_int32 = false, is_const_int32 = false; + uint32_t value = 0; + std::tie(is_int32, is_const_int32, value) = _.EvalInt32IfConst(scope); + + if (!is_int32) { + return _.diag(SPV_ERROR_INVALID_DATA) + << spvOpcodeString(opcode) + << ": expected Execution Scope to be a 32-bit int"; + } + + if (!is_const_int32) { + return SPV_SUCCESS; + } + + if (spvIsVulkanEnv(_.context()->target_env) && + _.context()->target_env != SPV_ENV_VULKAN_1_0 && + value != SpvScopeSubgroup) { + return _.diag(SPV_ERROR_INVALID_DATA) + << spvOpcodeString(opcode) + << ": in Vulkan environment Execution scope is limited to " + "Subgroup"; + } + + if (value != SpvScopeSubgroup && value != SpvScopeWorkgroup) { + return _.diag(SPV_ERROR_INVALID_DATA) << spvOpcodeString(opcode) + << ": Execution scope is limited to " + "Subgroup or Workgroup"; + } + + return SPV_SUCCESS; +} + +} // namespace + +// Validates correctness of non-uniform group instructions. +spv_result_t NonUniformPass(ValidationState_t& _, + const spv_parsed_instruction_t* inst) { + const SpvOp opcode = static_cast(inst->opcode); + + if (spvOpcodeIsNonUniformGroupOperation(opcode)) { + const uint32_t execution_scope = inst->words[3]; + if (auto error = ValidateExecutionScope(_, inst, execution_scope)) { + return error; + } + } + + return SPV_SUCCESS; +} + +} // namespace libspirv diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt index 093a04a4d..86d547051 100644 --- a/test/val/CMakeLists.txt +++ b/test/val/CMakeLists.txt @@ -186,3 +186,9 @@ add_spvtools_unittest(TARGET val_version ${VAL_TEST_COMMON_SRCS} LIBS ${SPIRV_TOOLS} ) + +add_spvtools_unittest(TARGET val_non_uniform + SRCS val_non_uniform_test.cpp + ${VAL_TEST_COMMON_SRCS} + LIBS ${SPIRV_TOOLS} +) diff --git a/test/val/val_non_uniform_test.cpp b/test/val/val_non_uniform_test.cpp new file mode 100644 index 000000000..1548f8f38 --- /dev/null +++ b/test/val/val_non_uniform_test.cpp @@ -0,0 +1,247 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "unit_spirv.h" +#include "val_fixtures.h" + +namespace { + +using ::testing::Combine; +using ::testing::HasSubstr; +using ::testing::Values; +using ::testing::ValuesIn; + +std::string GenerateShaderCode( + const std::string& body, + const std::string& capabilities_and_extensions = "", + const std::string& execution_model = "GLCompute") { + std::ostringstream ss; + ss << R"( +OpCapability Shader +OpCapability GroupNonUniform +OpCapability GroupNonUniformVote +OpCapability GroupNonUniformBallot +OpCapability GroupNonUniformShuffle +OpCapability GroupNonUniformShuffleRelative +OpCapability GroupNonUniformArithmetic +OpCapability GroupNonUniformClustered +OpCapability GroupNonUniformQuad +)"; + + ss << capabilities_and_extensions; + ss << "OpMemoryModel Logical GLSL450\n"; + ss << "OpEntryPoint " << execution_model << " %main \"main\"\n"; + + ss << R"( +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%u32 = OpTypeInt 32 0 +%float = OpTypeFloat 32 +%u32vec4 = OpTypeVector %u32 4 + +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool + +%u32_0 = OpConstant %u32 0 + +%float_0 = OpConstant %float 0 + +%u32vec4_null = OpConstantComposite %u32vec4 %u32_0 %u32_0 %u32_0 %u32_0 + +%cross_device = OpConstant %u32 0 +%device = OpConstant %u32 1 +%workgroup = OpConstant %u32 2 +%subgroup = OpConstant %u32 3 +%invocation = OpConstant %u32 4 + +%reduce = OpConstant %u32 0 +%inclusive_scan = OpConstant %u32 1 +%exclusive_scan = OpConstant %u32 2 +%clustered_reduce = OpConstant %u32 3 + +%main = OpFunction %void None %func +%main_entry = OpLabel +)"; + + ss << body; + + ss << R"( +OpReturn +OpFunctionEnd)"; + + return ss.str(); +} + +SpvScope scopes[] = {SpvScopeCrossDevice, SpvScopeDevice, SpvScopeWorkgroup, + SpvScopeSubgroup, SpvScopeInvocation}; + +using GroupNonUniformScope = spvtest::ValidateBase< + std::tuple>; + +std::string ConvertScope(SpvScope scope) { + switch (scope) { + case SpvScopeCrossDevice: + return "%cross_device"; + case SpvScopeDevice: + return "%device"; + case SpvScopeWorkgroup: + return "%workgroup"; + case SpvScopeSubgroup: + return "%subgroup"; + case SpvScopeInvocation: + return "%invocation"; + default: + return ""; + } +} + +TEST_P(GroupNonUniformScope, Vulkan1p1) { + std::string opcode = std::get<0>(GetParam()); + std::string type = std::get<1>(GetParam()); + SpvScope execution_scope = std::get<2>(GetParam()); + std::string args = std::get<3>(GetParam()); + + std::ostringstream sstr; + sstr << "%result = " << opcode << " "; + sstr << type << " "; + sstr << ConvertScope(execution_scope) << " "; + sstr << args << "\n"; + + CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_VULKAN_1_1); + spv_result_t result = ValidateInstructions(SPV_ENV_VULKAN_1_1); + if (execution_scope == SpvScopeSubgroup) { + EXPECT_EQ(SPV_SUCCESS, result); + } else { + EXPECT_EQ(SPV_ERROR_INVALID_DATA, result); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "in Vulkan environment Execution scope is limited to Subgroup")); + } +} + +TEST_P(GroupNonUniformScope, Spirv1p3) { + std::string opcode = std::get<0>(GetParam()); + std::string type = std::get<1>(GetParam()); + SpvScope execution_scope = std::get<2>(GetParam()); + std::string args = std::get<3>(GetParam()); + + std::ostringstream sstr; + sstr << "%result = " << opcode << " "; + sstr << type << " "; + sstr << ConvertScope(execution_scope) << " "; + sstr << args << "\n"; + + CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_UNIVERSAL_1_3); + spv_result_t result = ValidateInstructions(SPV_ENV_UNIVERSAL_1_3); + if (execution_scope == SpvScopeSubgroup || + execution_scope == SpvScopeWorkgroup) { + EXPECT_EQ(SPV_SUCCESS, result); + } else { + EXPECT_EQ(SPV_ERROR_INVALID_DATA, result); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Execution scope is limited to Subgroup or Workgroup")); + } +} + +INSTANTIATE_TEST_CASE_P(GroupNonUniformElect, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformElect"), + Values("%bool"), ValuesIn(scopes), Values(""))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformVote, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformAll", + "OpGroupNonUniformAny", + "OpGroupNonUniformAllEqual"), + Values("%bool"), ValuesIn(scopes), + Values("%true"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformBroadcast, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformBroadcast"), + Values("%bool"), ValuesIn(scopes), + Values("%true %u32_0"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformBroadcastFirst, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformBroadcastFirst"), + Values("%bool"), ValuesIn(scopes), + Values("%true"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformBallot, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformBallot"), + Values("%u32vec4"), ValuesIn(scopes), + Values("%true"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformInverseBallot, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformInverseBallot"), + Values("%bool"), ValuesIn(scopes), + Values("%u32vec4_null"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitExtract, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformBallotBitExtract"), + Values("%bool"), ValuesIn(scopes), + Values("%u32vec4_null %u32_0"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitCount, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformBallotBitCount"), + Values("%u32"), ValuesIn(scopes), + Values("Reduce %u32vec4_null"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotFind, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformBallotFindLSB", + "OpGroupNonUniformBallotFindMSB"), + Values("%u32"), ValuesIn(scopes), + Values("%u32vec4_null"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformShuffle, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformShuffle", + "OpGroupNonUniformShuffleXor", + "OpGroupNonUniformShuffleUp", + "OpGroupNonUniformShuffleDown"), + Values("%u32"), ValuesIn(scopes), + Values("%u32_0 %u32_0"))); + +INSTANTIATE_TEST_CASE_P( + GroupNonUniformIntegerArithmetic, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformIAdd", "OpGroupNonUniformIMul", + "OpGroupNonUniformSMin", "OpGroupNonUniformUMin", + "OpGroupNonUniformSMax", "OpGroupNonUniformUMax", + "OpGroupNonUniformBitwiseAnd", "OpGroupNonUniformBitwiseOr", + "OpGroupNonUniformBitwiseXor"), + Values("%u32"), ValuesIn(scopes), Values("Reduce %u32_0"))); + +INSTANTIATE_TEST_CASE_P( + GroupNonUniformFloatArithmetic, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformFAdd", "OpGroupNonUniformFMul", + "OpGroupNonUniformFMin", "OpGroupNonUniformFMax"), + Values("%float"), ValuesIn(scopes), Values("Reduce %float_0"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformLogicalArithmetic, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformLogicalAnd", + "OpGroupNonUniformLogicalOr", + "OpGroupNonUniformLogicalXor"), + Values("%bool"), ValuesIn(scopes), + Values("Reduce %true"))); + +INSTANTIATE_TEST_CASE_P(GroupNonUniformQuad, GroupNonUniformScope, + Combine(Values("OpGroupNonUniformQuadBroadcast", + "OpGroupNonUniformQuadSwap"), + Values("%u32"), ValuesIn(scopes), + Values("%u32_0 %u32_0"))); + +} // anonymous namespace