From b8fce5f9e6dabf59f8bf495229a98d53bdbf89b0 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 27 Aug 2021 14:43:23 -0400 Subject: [PATCH] spirv-lint: Add lint based on divergence analysis (#4488) This PR is a rebased version of #4479 by James Dong. --- The primary purpose of this PR is to add the code from my prototype as a PR, for licensing reasons. The commit history is messy, and the code is not especially clean. Fixes https://github.com/KhronosGroup/SPIRV-Tools/issues/3196. --- include/spirv-tools/linter.hpp | 2 +- source/lint/CMakeLists.txt | 2 + source/lint/lint_divergent_derivatives.cpp | 169 +++++++++++++++++++++ source/lint/linter.cpp | 21 ++- source/lint/lints.h | 34 +++++ 5 files changed, 221 insertions(+), 7 deletions(-) create mode 100644 source/lint/lint_divergent_derivatives.cpp create mode 100644 source/lint/lints.h diff --git a/include/spirv-tools/linter.hpp b/include/spirv-tools/linter.hpp index 57d1b4e98..52ed5a467 100644 --- a/include/spirv-tools/linter.hpp +++ b/include/spirv-tools/linter.hpp @@ -35,7 +35,7 @@ class Linter { void SetMessageConsumer(MessageConsumer consumer); // Returns a reference to the registered message consumer. - const MessageConsumer& consumer() const; + const MessageConsumer& Consumer() const; bool Run(const uint32_t* binary, size_t binary_size); diff --git a/source/lint/CMakeLists.txt b/source/lint/CMakeLists.txt index f9cae28a7..1feae3f94 100644 --- a/source/lint/CMakeLists.txt +++ b/source/lint/CMakeLists.txt @@ -13,9 +13,11 @@ # limitations under the License. set(SPIRV_TOOLS_LINT_SOURCES divergence_analysis.h + lints.h linter.cpp divergence_analysis.cpp + lint_divergent_derivatives.cpp ) if(MSVC AND (NOT ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang"))) diff --git a/source/lint/lint_divergent_derivatives.cpp b/source/lint/lint_divergent_derivatives.cpp new file mode 100644 index 000000000..512847b0c --- /dev/null +++ b/source/lint/lint_divergent_derivatives.cpp @@ -0,0 +1,169 @@ +// Copyright (c) 2021 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "source/diagnostic.h" +#include "source/lint/divergence_analysis.h" +#include "source/lint/lints.h" +#include "source/opt/basic_block.h" +#include "source/opt/cfg.h" +#include "source/opt/control_dependence.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/instruction.h" +#include "source/opt/ir_context.h" +#include "spirv-tools/libspirv.h" +#include "spirv/unified1/spirv.h" + +namespace spvtools { +namespace lint { +namespace lints { +namespace { +// Returns the %name[id], where `name` is the first name associated with the +// given id, or just %id if one is not found. +std::string GetFriendlyName(opt::IRContext* context, uint32_t id) { + auto names = context->GetNames(id); + std::stringstream ss; + ss << "%"; + if (names.empty()) { + ss << id; + } else { + opt::Instruction* inst_name = names.begin()->second; + if (inst_name->opcode() == SpvOpName) { + ss << names.begin()->second->GetInOperand(0).AsString(); + ss << "[" << id << "]"; + } else { + ss << id; + } + } + return ss.str(); +} + +bool InstructionHasDerivative(const opt::Instruction& inst) { + static const SpvOp derivative_opcodes[] = { + // Implicit derivatives. + SpvOpImageSampleImplicitLod, + SpvOpImageSampleDrefImplicitLod, + SpvOpImageSampleProjImplicitLod, + SpvOpImageSampleProjDrefImplicitLod, + SpvOpImageSparseSampleImplicitLod, + SpvOpImageSparseSampleDrefImplicitLod, + SpvOpImageSparseSampleProjImplicitLod, + SpvOpImageSparseSampleProjDrefImplicitLod, + // Explicit derivatives. + SpvOpDPdx, + SpvOpDPdy, + SpvOpFwidth, + SpvOpDPdxFine, + SpvOpDPdyFine, + SpvOpFwidthFine, + SpvOpDPdxCoarse, + SpvOpDPdyCoarse, + SpvOpFwidthCoarse, + }; + return std::find(std::begin(derivative_opcodes), std::end(derivative_opcodes), + inst.opcode()) != std::end(derivative_opcodes); +} + +spvtools::DiagnosticStream Warn(opt::IRContext* context, + opt::Instruction* inst) { + if (inst == nullptr) { + return DiagnosticStream({0, 0, 0}, context->consumer(), "", SPV_WARNING); + } else { + // TODO(kuhar): Use line numbers based on debug info. + return DiagnosticStream( + {0, 0, 0}, context->consumer(), + inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES), + SPV_WARNING); + } +} + +void PrintDivergenceFlow(opt::IRContext* context, DivergenceAnalysis div, + uint32_t id) { + opt::analysis::DefUseManager* def_use = context->get_def_use_mgr(); + opt::CFG* cfg = context->cfg(); + while (id != 0) { + bool is_block = def_use->GetDef(id)->opcode() == SpvOpLabel; + if (is_block) { + Warn(context, nullptr) + << "block " << GetFriendlyName(context, id) << " is divergent"; + uint32_t source = div.GetDivergenceSource(id); + // Skip intermediate blocks. + while (source != 0 && def_use->GetDef(source)->opcode() == SpvOpLabel) { + id = source; + source = div.GetDivergenceSource(id); + } + if (source == 0) break; + spvtools::opt::Instruction* branch = + cfg->block(div.GetDivergenceDependenceSource(id))->terminator(); + Warn(context, branch) + << "because it depends on a conditional branch on divergent value " + << GetFriendlyName(context, source) << ""; + id = source; + } else { + Warn(context, nullptr) + << "value " << GetFriendlyName(context, id) << " is divergent"; + uint32_t source = div.GetDivergenceSource(id); + opt::Instruction* def = def_use->GetDef(id); + opt::Instruction* source_def = + source == 0 ? nullptr : def_use->GetDef(source); + // First print data -> data dependencies. + while (source != 0 && source_def->opcode() != SpvOpLabel) { + Warn(context, def_use->GetDef(id)) + << "because " << GetFriendlyName(context, id) << " uses value " + << GetFriendlyName(context, source) + << "in its definition, which is divergent"; + id = source; + def = source_def; + source = div.GetDivergenceSource(id); + source_def = def_use->GetDef(source); + } + if (source == 0) { + Warn(context, def) << "because it has a divergent definition"; + break; + } + Warn(context, def) << "because it is conditionally set in block " + << GetFriendlyName(context, source); + id = source; + } + } +} +} // namespace + +bool CheckDivergentDerivatives(opt::IRContext* context) { + DivergenceAnalysis div(*context); + for (opt::Function& func : *context->module()) { + div.Run(&func); + for (const opt::BasicBlock& bb : func) { + for (const opt::Instruction& inst : bb) { + if (InstructionHasDerivative(inst) && + div.GetDivergenceLevel(bb.id()) > + DivergenceAnalysis::DivergenceLevel::kPartiallyUniform) { + Warn(context, nullptr) + << "derivative with divergent control flow" + << " located in block " << GetFriendlyName(context, bb.id()); + PrintDivergenceFlow(context, div, bb.id()); + } + } + } + } + return true; +} + +} // namespace lints +} // namespace lint +} // namespace spvtools diff --git a/source/lint/linter.cpp b/source/lint/linter.cpp index 0f8479537..e4ed04ea4 100644 --- a/source/lint/linter.cpp +++ b/source/lint/linter.cpp @@ -14,6 +14,13 @@ #include "spirv-tools/linter.hpp" +#include "source/lint/lints.h" +#include "source/opt/build_module.h" +#include "source/opt/ir_context.h" +#include "spirv-tools/libspirv.h" +#include "spirv-tools/libspirv.hpp" +#include "spirv/unified1/spirv.h" + namespace spvtools { struct Linter::Impl { @@ -32,20 +39,22 @@ Linter::Linter(spv_target_env env) : impl_(new Impl(env)) {} Linter::~Linter() {} void Linter::SetMessageConsumer(MessageConsumer consumer) { - impl_->message_consumer = consumer; + impl_->message_consumer = std::move(consumer); } -const MessageConsumer& Linter::consumer() const { +const MessageConsumer& Linter::Consumer() const { return impl_->message_consumer; } bool Linter::Run(const uint32_t* binary, size_t binary_size) { - (void)binary; - (void)binary_size; + std::unique_ptr context = + BuildModule(SPV_ENV_VULKAN_1_2, Consumer(), binary, binary_size); + if (context == nullptr) return false; - consumer()(SPV_MSG_INFO, "", {0, 0, 0}, "Hello, world!"); + bool result = true; + result &= lint::lints::CheckDivergentDerivatives(context.get()); - return true; + return result; } } // namespace spvtools diff --git a/source/lint/lints.h b/source/lint/lints.h new file mode 100644 index 000000000..a1995d2fb --- /dev/null +++ b/source/lint/lints.h @@ -0,0 +1,34 @@ +// Copyright (c) 2021 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_LINT_LINTS_H_ +#define SOURCE_LINT_LINTS_H_ + +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace lint { + +// All of the functions in this namespace output to the error consumer in the +// |context| argument and return |true| if no errors are found. They do not +// modify the IR. +namespace lints { + +bool CheckDivergentDerivatives(opt::IRContext* context); + +} // namespace lints +} // namespace lint +} // namespace spvtools + +#endif // SOURCE_LINT_LINTS_H_