mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-22 19:50:05 +00:00
Add strength reduction; for now replace multiply by power of 2
Create a new optimization pass, strength reduction, which will replace integer multiplication by a constant power of 2 with an equivalent bit shift. More changes could be added later. - Does not duplicate constants - Adds vector |Concat| utility function to a common test header.
This commit is contained in:
parent
7be791aaaa
commit
e4c7d8e748
@ -185,6 +185,12 @@ Optimizer::PassToken CreateUnifyConstantPass();
|
||||
// OpSpecConstantOp.
|
||||
Optimizer::PassToken CreateEliminateDeadConstantPass();
|
||||
|
||||
// Creates a strength-reduction pass.
|
||||
// A strength-reduction pass will look for opportunities to replace an
|
||||
// instruction with an equivalent and less expensive one. For example,
|
||||
// multiplying by a power of 2 can be replaced by a bit shift.
|
||||
Optimizer::PassToken CreateStrengthReductionPass();
|
||||
|
||||
// Creates a block merge pass.
|
||||
// This pass searches for blocks with a single Branch to a block with no
|
||||
// other predecessors and merges the blocks into a single block. Continue
|
||||
|
@ -45,6 +45,7 @@ add_library(SPIRV-Tools-opt
|
||||
passes.h
|
||||
pass_manager.h
|
||||
set_spec_constant_default_value_pass.h
|
||||
strength_reduction_pass.h
|
||||
strip_debug_info_pass.h
|
||||
types.h
|
||||
type_manager.h
|
||||
@ -79,6 +80,7 @@ add_library(SPIRV-Tools-opt
|
||||
mem_pass.cpp
|
||||
pass.cpp
|
||||
pass_manager.cpp
|
||||
strength_reduction_pass.cpp
|
||||
strip_debug_info_pass.cpp
|
||||
types.cpp
|
||||
type_manager.cpp
|
||||
|
@ -132,6 +132,11 @@ Optimizer::PassToken CreateEliminateDeadConstantPass() {
|
||||
MakeUnique<opt::EliminateDeadConstantPass>());
|
||||
}
|
||||
|
||||
// Builds the optimizer token that wraps a newly constructed
// opt::StrengthReductionPass.
Optimizer::PassToken CreateStrengthReductionPass() {
  using TokenImpl = Optimizer::PassToken::Impl;
  return MakeUnique<TokenImpl>(MakeUnique<opt::StrengthReductionPass>());
}
|
||||
|
||||
Optimizer::PassToken CreateBlockMergePass() {
|
||||
return MakeUnique<Optimizer::PassToken::Impl>(
|
||||
MakeUnique<opt::BlockMergePass>());
|
||||
|
@ -35,6 +35,7 @@
|
||||
#include "aggressive_dead_code_elim_pass.h"
|
||||
#include "null_pass.h"
|
||||
#include "set_spec_constant_default_value_pass.h"
|
||||
#include "strength_reduction_pass.h"
|
||||
#include "strip_debug_info_pass.h"
|
||||
#include "unify_const_pass.h"
|
||||
|
||||
|
210
source/opt/strength_reduction_pass.cpp
Normal file
210
source/opt/strength_reduction_pass.cpp
Normal file
@ -0,0 +1,210 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "strength_reduction_pass.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "def_use_manager.h"
|
||||
#include "log.h"
|
||||
#include "reflect.h"
|
||||
|
||||
namespace {
|
||||
// Counts the number of trailing zero bits in the binary representation of
// |constVal|.  Returns 32 (the bit width of the argument) when |constVal| is
// zero, because every bit of a zero value is a trailing zero.
uint32_t CountTrailingZeros(uint32_t constVal) {
  // A zero value has no set bit to stop the scan below, so without this guard
  // the loop would never terminate.
  if (constVal == 0) return 32;

  // Faster if we use the hardware count trailing zeros instruction.
  // If not available, we could create a table.
  uint32_t shiftAmount = 0;
  while ((constVal & 1) == 0) {
    ++shiftAmount;
    constVal >>= 1;
  }
  return shiftAmount;
}
|
||||
|
||||
// Returns true if |val| is a power of 2, i.e. exactly one of its bits is set.
bool IsPowerOf2(uint32_t val) {
  // |val & (val - 1)| clears the least significant set bit of |val|.  The
  // result is zero only when that was the sole bit set.  Zero itself has no
  // bit set and is rejected explicitly.
  return val != 0 && (val & (val - 1)) == 0;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace spvtools {
|
||||
namespace opt {
|
||||
|
||||
// Runs the pass over |module|: resets the per-module state, records the
// existing 32-bit integer types and small unsigned constants, rewrites the
// function bodies, and finally publishes the new id bound.
Pass::Status StrengthReductionPass::Process(ir::Module* module) {
  // Initialize the member variables on a per module basis.
  def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module));
  std::memset(constant_ids_, 0, sizeof(constant_ids_));
  uint32_type_id_ = 0;
  int32_type_id_ = 0;
  next_id_ = module->IdBound();
  module_ = module;

  FindIntTypesAndConstants();
  const bool changed = ScanFunctions();

  // Ids may have been taken from |next_id_| while rewriting, so the bound
  // stored in the module has to be refreshed.
  module->SetIdBound(next_id_);

  if (changed) return Status::SuccessWithChange;
  return Status::SuccessWithoutChange;
}
|
||||
|
||||
// Tries to replace the OpIMul that |*instPtr| refers to with an
// OpShiftLeftLogical.  The rewrite happens when the result type is one of the
// recorded 32-bit integer types and one operand is an OpConstant whose value
// is a power of 2.  On success the shift is inserted in front of the multiply,
// all uses of the multiply are redirected to the shift, the multiply is killed,
// and |*instPtr| is left referring to the killed multiply.  Returns true if
// the instruction was replaced.
bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
    ir::BasicBlock::iterator* instPtr) {
  ir::BasicBlock::iterator& mulIter = *instPtr;
  assert(mulIter->opcode() == SpvOp::SpvOpIMul &&
         "Only works for multiplication of integers.");

  // Currently only works on 32-bit integers.
  const uint32_t resultTypeId = mulIter->type_id();
  const bool is32BitInt =
      resultTypeId == int32_type_id_ || resultTypeId == uint32_type_id_;
  if (!is32BitInt) return false;

  // Look at both operands for a constant that is a power of 2.
  for (int constIdx = 0; constIdx < 2; ++constIdx) {
    const uint32_t candidateId = mulIter->GetSingleWordInOperand(constIdx);
    ir::Instruction* candidateDef = def_use_mgr_->GetDef(candidateId);
    if (candidateDef->opcode() != SpvOp::SpvOpConstant) continue;

    const uint32_t factor = candidateDef->GetSingleWordOperand(2);
    if (!IsPowerOf2(factor)) continue;

    // The shift-amount constant is requested before the result id is taken:
    // creating it may itself consume ids.
    const uint32_t shiftAmount = CountTrailingZeros(factor);
    const uint32_t shiftConstId = GetConstantId(shiftAmount);
    const uint32_t shiftResultId = next_id_++;

    // Operands of the shift: the other multiply operand, then the amount.
    std::vector<ir::Operand> shiftOperands;
    shiftOperands.push_back(mulIter->GetInOperand(1 - constIdx));
    ir::Operand amountOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
                              {shiftConstId});
    shiftOperands.push_back(amountOperand);

    std::unique_ptr<ir::Instruction> shiftInst(
        new ir::Instruction(SpvOp::SpvOpShiftLeftLogical, resultTypeId,
                            shiftResultId, shiftOperands));

    // Register the shift, place it before the multiply, then step the
    // iterator forward again so that it refers to the multiply.
    def_use_mgr_->AnalyzeInstDefUse(shiftInst.get());
    mulIter = mulIter.InsertBefore(std::move(shiftInst));
    ++mulIter;
    def_use_mgr_->ReplaceAllUsesWith(mulIter->result_id(), shiftResultId);

    // Remove the old instruction.
    def_use_mgr_->KillInst(&*mulIter);

    // Stop after the first match so that the instruction is not replaced
    // twice when both operands are powers of 2.
    return true;
  }

  return false;
}
|
||||
|
||||
void StrengthReductionPass::FindIntTypesAndConstants() {
|
||||
for (auto iter = module_->types_values_begin();
|
||||
iter != module_->types_values_end(); ++iter) {
|
||||
switch (iter->opcode()) {
|
||||
case SpvOp::SpvOpTypeInt:
|
||||
if (iter->GetSingleWordOperand(1) == 32) {
|
||||
if (iter->GetSingleWordOperand(2) == 1) {
|
||||
int32_type_id_ = iter->result_id();
|
||||
} else {
|
||||
uint32_type_id_ = iter->result_id();
|
||||
}
|
||||
}
|
||||
break;
|
||||
case SpvOp::SpvOpConstant:
|
||||
if (iter->type_id() == uint32_type_id_) {
|
||||
uint32_t value = iter->GetSingleWordOperand(2);
|
||||
if (value <= 32) constant_ids_[value] = iter->result_id();
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
|
||||
assert(val <= 32 &&
|
||||
"This function does not handle constants larger than 32.");
|
||||
|
||||
if (constant_ids_[val] == 0) {
|
||||
if (uint32_type_id_ == 0) {
|
||||
uint32_type_id_ = CreateUint32Type();
|
||||
}
|
||||
|
||||
// Construct the constant.
|
||||
uint32_t resultId = next_id_++;
|
||||
ir::Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
|
||||
{val});
|
||||
std::unique_ptr<ir::Instruction> newConstant(new ir::Instruction(
|
||||
SpvOp::SpvOpConstant, uint32_type_id_, resultId, {constant}));
|
||||
module_->AddGlobalValue(std::move(newConstant));
|
||||
|
||||
// Store the result id for next time.
|
||||
constant_ids_[val] = resultId;
|
||||
}
|
||||
|
||||
return constant_ids_[val];
|
||||
}
|
||||
|
||||
bool StrengthReductionPass::ScanFunctions() {
|
||||
// I did not use |ForEachInst| in the module because the function that acts on
|
||||
// the instruction gets a pointer to the instruction. We cannot use that to
|
||||
// insert a new instruction. I want an iterator.
|
||||
bool modified = false;
|
||||
for (auto& func : *module_) {
|
||||
for (auto& bb : func) {
|
||||
for (auto inst = bb.begin(); inst != bb.end(); ++inst) {
|
||||
switch (inst->opcode()) {
|
||||
case SpvOp::SpvOpIMul:
|
||||
if (ReplaceMultiplyByPowerOf2(&inst)) modified = true;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return modified;
|
||||
}
|
||||
|
||||
uint32_t StrengthReductionPass::CreateUint32Type() {
|
||||
uint32_t type_id = next_id_++;
|
||||
ir::Operand widthOperand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
|
||||
{32});
|
||||
ir::Operand signOperand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
|
||||
{0});
|
||||
std::unique_ptr<ir::Instruction> newType(new ir::Instruction(
|
||||
SpvOp::SpvOpTypeInt, type_id, 0, {widthOperand, signOperand}));
|
||||
module_->AddType(std::move(newType));
|
||||
return type_id;
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
75
source/opt/strength_reduction_pass.h
Normal file
75
source/opt/strength_reduction_pass.h
Normal file
@ -0,0 +1,75 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef LIBSPIRV_OPT_STRENGTH_REDUCTION_PASS_H_
|
||||
#define LIBSPIRV_OPT_STRENGTH_REDUCTION_PASS_H_
|
||||
|
||||
#include "def_use_manager.h"
|
||||
#include "module.h"
|
||||
#include "pass.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace opt {
|
||||
|
||||
// See optimizer.hpp for documentation.
|
||||
class StrengthReductionPass : public Pass {
|
||||
public:
|
||||
const char* name() const override { return "strength-reduction"; }
|
||||
Status Process(ir::Module*) override;
|
||||
|
||||
private:
|
||||
// Replaces multiple by power of 2 with an equivalent bit shift.
|
||||
// Returns true if something changed.
|
||||
bool ReplaceMultiplyByPowerOf2(ir::BasicBlock::iterator*);
|
||||
|
||||
// Scan the types and constants in the module looking for the the integer types that we are
|
||||
// interested in. The shift operation needs a small unsigned integer. We need to find
|
||||
// them or create them. We do not want duplicates.
|
||||
void FindIntTypesAndConstants();
|
||||
|
||||
// Get the id for the given constant. If it does not exist, it will be
|
||||
// created. The parameter must be between 0 and 32 inclusive.
|
||||
uint32_t GetConstantId(uint32_t);
|
||||
|
||||
// Replaces certain instructions in function bodies with presumably cheaper
|
||||
// ones. Returns true if something changed.
|
||||
bool ScanFunctions();
|
||||
|
||||
// Will create the type for an unsigned 32-bit integer and return the id.
|
||||
// This functions assumes one does not already exist.
|
||||
uint32_t CreateUint32Type();
|
||||
|
||||
// Def-Uses for the module we are processing
|
||||
std::unique_ptr<analysis::DefUseManager> def_use_mgr_;
|
||||
|
||||
// Type ids for the types of interest, or 0 if they do not exist.
|
||||
uint32_t int32_type_id_;
|
||||
uint32_t uint32_type_id_;
|
||||
|
||||
// constant_ids[i] is the id for unsigned integer constant i.
|
||||
// We set the limit at 32 because a bit shift of a 32-bit integer does not
|
||||
// need a value larger than 32.
|
||||
uint32_t constant_ids_[33];
|
||||
|
||||
// Next unused ID
|
||||
uint32_t next_id_;
|
||||
|
||||
// The module that the pass is being applied to.
|
||||
ir::Module* module_;
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // LIBSPIRV_OPT_STRENGTH_REDUCTION_PASS_H_
|
@ -168,3 +168,8 @@ add_spvtools_unittest(TARGET line_debug_info
|
||||
SRCS line_debug_info_test.cpp pass_utils.cpp
|
||||
LIBS SPIRV-Tools-opt
|
||||
)
|
||||
|
||||
add_spvtools_unittest(TARGET pass_strength_reduction
|
||||
SRCS strength_reduction_test.cpp pass_utils.cpp
|
||||
LIBS SPIRV-Tools-opt
|
||||
)
|
||||
|
@ -16,13 +16,6 @@
|
||||
#include "pass_fixture.h"
|
||||
#include "pass_utils.h"
|
||||
|
||||
template <typename T> std::vector<T> concat(const std::vector<T> &a, const std::vector<T> &b) {
|
||||
std::vector<T> ret = std::vector<T>();
|
||||
std::copy(a.begin(), a.end(), back_inserter(ret));
|
||||
std::copy(b.begin(), b.end(), back_inserter(ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace spvtools;
|
||||
@ -134,8 +127,8 @@ TEST_F(InlineTest, Simple) {
|
||||
// clang-format on
|
||||
};
|
||||
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
|
||||
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
|
||||
/* skip_nop = */ false, /* do_validate = */ true);
|
||||
}
|
||||
|
||||
@ -284,8 +277,8 @@ TEST_F(InlineTest, Nested) {
|
||||
// clang-format on
|
||||
};
|
||||
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
|
||||
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
|
||||
/* skip_nop = */ false, /* do_validate = */ true);
|
||||
}
|
||||
|
||||
@ -413,8 +406,8 @@ TEST_F(InlineTest, InOutParameter) {
|
||||
// clang-format on
|
||||
};
|
||||
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
|
||||
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
|
||||
/* skip_nop = */ false, /* do_validate = */ true);
|
||||
}
|
||||
|
||||
@ -549,8 +542,8 @@ TEST_F(InlineTest, BranchInCallee) {
|
||||
// clang-format on
|
||||
};
|
||||
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
|
||||
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
|
||||
/* skip_nop = */ false, /* do_validate = */ true);
|
||||
}
|
||||
|
||||
@ -744,8 +737,8 @@ TEST_F(InlineTest, PhiAfterCall) {
|
||||
// clang-format on
|
||||
};
|
||||
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
|
||||
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
|
||||
/* skip_nop = */ false, /* do_validate = */ true);
|
||||
}
|
||||
|
||||
@ -941,8 +934,8 @@ TEST_F(InlineTest, OpSampledImageOutOfBlock) {
|
||||
// clang-format on
|
||||
};
|
||||
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
|
||||
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
|
||||
/* skip_nop = */ false, /* do_validate = */ true);
|
||||
}
|
||||
|
||||
@ -1147,8 +1140,8 @@ TEST_F(InlineTest, OpImageOutOfBlock) {
|
||||
// clang-format on
|
||||
};
|
||||
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
|
||||
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
|
||||
/* skip_nop = */ false, /* do_validate = */ true);
|
||||
}
|
||||
|
||||
@ -1353,8 +1346,8 @@ TEST_F(InlineTest, OpImageAndOpSampledImageOutOfBlock) {
|
||||
// clang-format on
|
||||
};
|
||||
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
|
||||
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
|
||||
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
|
||||
/* skip_nop = */ false, /* do_validate = */ true);
|
||||
}
|
||||
|
||||
|
@ -16,14 +16,6 @@
|
||||
#include "pass_fixture.h"
|
||||
#include "pass_utils.h"
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> concat(const std::vector<T>& a, const std::vector<T>& b) {
|
||||
std::vector<T> ret;
|
||||
std::copy(a.begin(), a.end(), back_inserter(ret));
|
||||
std::copy(b.begin(), b.end(), back_inserter(ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace spvtools;
|
||||
|
@ -49,6 +49,16 @@ std::string JoinAllInsts(const std::vector<const char*>& insts);
|
||||
// will be ignored.
|
||||
std::string JoinNonDebugInsts(const std::vector<const char*>& insts);
|
||||
|
||||
// Returns a vector that contains the contents of |a| followed by the contents
// of |b|.
template <typename T>
std::vector<T> Concat(const std::vector<T>& a, const std::vector<T>& b) {
  std::vector<T> ret;
  // Size the result once so that appending never reallocates.
  ret.reserve(a.size() + b.size());
  ret.insert(ret.end(), a.begin(), a.end());
  ret.insert(ret.end(), b.begin(), b.end());
  return ret;
}
|
||||
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // LIBSPIRV_TEST_OPT_PASS_UTILS_H_
|
||||
|
427
test/opt/strength_reduction_test.cpp
Normal file
427
test/opt/strength_reduction_test.cpp
Normal file
@ -0,0 +1,427 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "assembly_builder.h"
|
||||
#include "gmock/gmock.h"
|
||||
#include "pass_fixture.h"
|
||||
#include "pass_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdarg>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace spvtools;
|
||||
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::MatchesRegex;
|
||||
|
||||
using StrengthReductionBasicTest = PassTest<::testing::Test>;
|
||||
|
||||
// Test to make sure we replace 5*8.
|
||||
TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy8) {
  // Module whose only computation is the unsigned product 5 * 8.
  const std::vector<const char*> assembly = {
      // clang-format off
               "OpCapability Shader",
          "%1 = OpExtInstImport \"GLSL.std.450\"",
               "OpMemoryModel Logical GLSL450",
               "OpEntryPoint Vertex %main \"main\"",
               "OpName %main \"main\"",
       "%void = OpTypeVoid",
          "%4 = OpTypeFunction %void",
       "%uint = OpTypeInt 32 0",
     "%uint_5 = OpConstant %uint 5",
     "%uint_8 = OpConstant %uint 8",
       "%main = OpFunction %void None %4",
          "%8 = OpLabel",
          "%9 = OpIMul %uint %uint_5 %uint_8",
               "OpReturn",
               "OpFunctionEnd"
      // clang-format on
  };

  const auto outcome = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
      JoinAllInsts(assembly), /* skip_nop = */ true);
  const std::string& disassembly = std::get<0>(outcome);
  const auto status = std::get<1>(outcome);

  // The multiply must be gone, replaced by a left shift of 5 by 3.
  EXPECT_EQ(opt::Pass::Status::SuccessWithChange, status);
  EXPECT_THAT(disassembly, Not(HasSubstr("OpIMul")));
  EXPECT_THAT(disassembly,
              HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_3"));
}
|
||||
|
||||
// Test to make sure we replace 16*5.
|
||||
TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy16) {
  // Module whose only computation is the signed product 16 * 5; the power of
  // 2 is the first operand and the module declares no unsigned type.
  const std::vector<const char*> assembly = {
      // clang-format off
               "OpCapability Shader",
          "%1 = OpExtInstImport \"GLSL.std.450\"",
               "OpMemoryModel Logical GLSL450",
               "OpEntryPoint Vertex %main \"main\"",
               "OpName %main \"main\"",
       "%void = OpTypeVoid",
          "%4 = OpTypeFunction %void",
        "%int = OpTypeInt 32 1",
     "%int_16 = OpConstant %int 16",
      "%int_5 = OpConstant %int 5",
       "%main = OpFunction %void None %4",
          "%8 = OpLabel",
          "%9 = OpIMul %int %int_16 %int_5",
               "OpReturn",
               "OpFunctionEnd"
      // clang-format on
  };

  const auto outcome = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
      JoinAllInsts(assembly), /* skip_nop = */ true);
  const std::string& disassembly = std::get<0>(outcome);
  const auto status = std::get<1>(outcome);

  // The multiply must be gone, replaced by a left shift of 5 by 4.
  EXPECT_EQ(opt::Pass::Status::SuccessWithChange, status);
  EXPECT_THAT(disassembly, Not(HasSubstr("OpIMul")));
  EXPECT_THAT(disassembly,
              HasSubstr("OpShiftLeftLogical %int %int_5 %uint_4"));
}
|
||||
|
||||
// Test to make sure we replace a multiplication of 32 and 4.
|
||||
TEST_F(StrengthReductionBasicTest, BasicTwoPowersOf2) {
  // Both operands of the multiply are powers of 2.  Only one of them may be
  // turned into the shift amount.
  // clang-format off
  const std::string assembly = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %main "main"
OpName %main "main"
%void = OpTypeVoid
%4 = OpTypeFunction %void
%int = OpTypeInt 32 1
%int_32 = OpConstant %int 32
%int_4 = OpConstant %int 4
%main = OpFunction %void None %4
%8 = OpLabel
%9 = OpIMul %int %int_32 %int_4
OpReturn
OpFunctionEnd
)";
  // clang-format on
  const auto outcome = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
      assembly, /* skip_nop = */ true);
  const std::string& disassembly = std::get<0>(outcome);
  const auto status = std::get<1>(outcome);

  // The first operand (32) becomes the shift amount 5; 4 is kept as the value.
  EXPECT_EQ(opt::Pass::Status::SuccessWithChange, status);
  EXPECT_THAT(disassembly, Not(HasSubstr("OpIMul")));
  EXPECT_THAT(disassembly,
              HasSubstr("OpShiftLeftLogical %int %int_4 %uint_5"));
}
|
||||
|
||||
// Test to make sure we don't replace 0*5.
|
||||
TEST_F(StrengthReductionBasicTest, BasicDontReplace0) {
  // Module whose only computation is 0 * 5: neither operand is a power of 2.
  const std::vector<const char*> assembly = {
      // clang-format off
               "OpCapability Shader",
          "%1 = OpExtInstImport \"GLSL.std.450\"",
               "OpMemoryModel Logical GLSL450",
               "OpEntryPoint Vertex %main \"main\"",
               "OpName %main \"main\"",
       "%void = OpTypeVoid",
          "%4 = OpTypeFunction %void",
        "%int = OpTypeInt 32 1",
      "%int_0 = OpConstant %int 0",
      "%int_5 = OpConstant %int 5",
       "%main = OpFunction %void None %4",
          "%8 = OpLabel",
          "%9 = OpIMul %int %int_0 %int_5",
               "OpReturn",
               "OpFunctionEnd"
      // clang-format on
  };

  const auto outcome = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
      JoinAllInsts(assembly), /* skip_nop = */ true);
  const auto status = std::get<1>(outcome);

  // The pass must leave the module alone.
  EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, status);
}
|
||||
|
||||
// Test to make sure we do not replace a multiplication of 5 and 7.
|
||||
TEST_F(StrengthReductionBasicTest, BasicNoChange) {
  // Module whose only computation is 5 * 7: neither operand is a power of 2.
  const std::vector<const char*> assembly = {
      // clang-format off
             "OpCapability Shader",
        "%1 = OpExtInstImport \"GLSL.std.450\"",
             "OpMemoryModel Logical GLSL450",
             "OpEntryPoint Vertex %2 \"main\"",
             "OpName %2 \"main\"",
        "%3 = OpTypeVoid",
        "%4 = OpTypeFunction %3",
        "%5 = OpTypeInt 32 1",
        "%6 = OpTypeInt 32 0",
        "%7 = OpConstant %5 5",
        "%8 = OpConstant %5 7",
        "%2 = OpFunction %3 None %4",
        "%9 = OpLabel",
       "%10 = OpIMul %5 %7 %8",
             "OpReturn",
             "OpFunctionEnd",
      // clang-format on
  };

  const auto outcome = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
      JoinAllInsts(assembly), /* skip_nop = */ true);
  const auto status = std::get<1>(outcome);

  // The pass must leave the module alone.
  EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, status);
}
|
||||
|
||||
// Test to make sure constants and types are reused and not duplicated.
|
||||
TEST_F(StrengthReductionBasicTest, NoDuplicateConstantsAndTypes) {
  // The module already declares the unsigned type and the constant 3 that
  // the replacement of "8 * 3" needs as its shift amount.
  const std::vector<const char*> assembly = {
      // clang-format off
               "OpCapability Shader",
          "%1 = OpExtInstImport \"GLSL.std.450\"",
               "OpMemoryModel Logical GLSL450",
               "OpEntryPoint Vertex %main \"main\"",
               "OpName %main \"main\"",
       "%void = OpTypeVoid",
          "%4 = OpTypeFunction %void",
       "%uint = OpTypeInt 32 0",
     "%uint_8 = OpConstant %uint 8",
     "%uint_3 = OpConstant %uint 3",
       "%main = OpFunction %void None %4",
          "%8 = OpLabel",
          "%9 = OpIMul %uint %uint_8 %uint_3",
               "OpReturn",
               "OpFunctionEnd",
      // clang-format on
  };

  const auto outcome = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
      JoinAllInsts(assembly), /* skip_nop = */ true);
  const std::string& disassembly = std::get<0>(outcome);
  const auto status = std::get<1>(outcome);

  // A change is expected, but neither the constant nor the type may show up
  // twice in the output.
  EXPECT_EQ(opt::Pass::Status::SuccessWithChange, status);
  EXPECT_THAT(disassembly,
              Not(MatchesRegex(".*OpConstant %uint 3.*OpConstant %uint 3.*")));
  EXPECT_THAT(disassembly,
              Not(MatchesRegex(".*OpTypeInt 32 0.*OpTypeInt 32 0.*")));
}
|
||||
|
||||
// Test to make sure we generate the constants only once
|
||||
TEST_F(StrengthReductionBasicTest, BasicCreateOneConst) {
  // Two multiplies by 128 need the same shift amount, which the module does
  // not declare.
  const std::vector<const char*> assembly = {
      // clang-format off
               "OpCapability Shader",
          "%1 = OpExtInstImport \"GLSL.std.450\"",
               "OpMemoryModel Logical GLSL450",
               "OpEntryPoint Vertex %main \"main\"",
               "OpName %main \"main\"",
       "%void = OpTypeVoid",
          "%4 = OpTypeFunction %void",
       "%uint = OpTypeInt 32 0",
     "%uint_5 = OpConstant %uint 5",
     "%uint_9 = OpConstant %uint 9",
   "%uint_128 = OpConstant %uint 128",
       "%main = OpFunction %void None %4",
          "%8 = OpLabel",
          "%9 = OpIMul %uint %uint_5 %uint_128",
         "%10 = OpIMul %uint %uint_9 %uint_128",
               "OpReturn",
               "OpFunctionEnd"
      // clang-format on
  };

  const auto outcome = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
      JoinAllInsts(assembly), /* skip_nop = */ true);
  const std::string& disassembly = std::get<0>(outcome);
  const auto status = std::get<1>(outcome);

  // Both multiplies become shifts by the same constant 7.
  EXPECT_EQ(opt::Pass::Status::SuccessWithChange, status);
  EXPECT_THAT(disassembly, Not(HasSubstr("OpIMul")));
  EXPECT_THAT(disassembly,
              HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_7"));
  EXPECT_THAT(disassembly,
              HasSubstr("OpShiftLeftLogical %uint %uint_9 %uint_7"));
}
|
||||
|
||||
// Test to make sure we generate the instructions in the correct position and
|
||||
// that the uses get replaced as well. Here we check that the use in the return
|
||||
// is replaced, we also check that we can replace two OpIMuls when one feeds the
|
||||
// other.
|
||||
TEST_F(StrengthReductionBasicTest, BasicCheckPositionAndReplacement) {
  // Shared preamble and entry point.  It is identical in the input and in
  // the expected output.
  const std::vector<const char*> preamble = {
      // clang-format off
               "OpCapability Shader",
          "%1 = OpExtInstImport \"GLSL.std.450\"",
               "OpMemoryModel Logical GLSL450",
               "OpEntryPoint Fragment %main \"main\" %gl_FragColor",
               "OpExecutionMode %main OriginUpperLeft",
               "OpName %main \"main\"",
               "OpName %foo_i1_ \"foo(i1;\"",
               "OpName %n \"n\"",
               "OpName %gl_FragColor \"gl_FragColor\"",
               "OpName %param \"param\"",
               "OpDecorate %gl_FragColor Location 0",
       "%void = OpTypeVoid",
          "%3 = OpTypeFunction %void",
        "%int = OpTypeInt 32 1",
"%_ptr_Function_int = OpTypePointer Function %int",
          "%8 = OpTypeFunction %int %_ptr_Function_int",
    "%int_256 = OpConstant %int 256",
      "%int_2 = OpConstant %int 2",
      "%float = OpTypeFloat 32",
    "%v4float = OpTypeVector %float 4",
"%_ptr_Output_v4float = OpTypePointer Output %v4float",
"%gl_FragColor = OpVariable %_ptr_Output_v4float Output",
    "%float_1 = OpConstant %float 1",
     "%int_10 = OpConstant %int 10",
  "%float_0_4 = OpConstant %float 0.4",
  "%float_0_8 = OpConstant %float 0.8",
       "%uint = OpTypeInt 32 0",
     "%uint_8 = OpConstant %uint 8",
     "%uint_1 = OpConstant %uint 1",
       "%main = OpFunction %void None %3",
          "%5 = OpLabel",
      "%param = OpVariable %_ptr_Function_int Function",
               "OpStore %param %int_10",
         "%26 = OpFunctionCall %int %foo_i1_ %param",
         "%27 = OpConvertSToF %float %26",
         "%28 = OpFDiv %float %float_1 %27",
         "%31 = OpCompositeConstruct %v4float %28 %float_0_4 %float_0_8 %float_1",
               "OpStore %gl_FragColor %31",
               "OpReturn",
               "OpFunctionEnd"
      // clang-format on
  };

  // The function under test: two chained multiplies, both by a power of 2.
  const std::vector<const char*> callee_before = {
      // clang-format off
    "%foo_i1_ = OpFunction %int None %8",
          "%n = OpFunctionParameter %_ptr_Function_int",
         "%11 = OpLabel",
         "%12 = OpLoad %int %n",
         "%14 = OpIMul %int %12 %int_256",
         "%16 = OpIMul %int %14 %int_2",
               "OpReturnValue %16",
               "OpFunctionEnd",
      // clang-format on
  };

  // Expected result: each multiply is replaced in place by a shift, and the
  // users (the second shift and the return) refer to the new ids.
  const std::vector<const char*> callee_after = {
      // clang-format off
    "%foo_i1_ = OpFunction %int None %8",
          "%n = OpFunctionParameter %_ptr_Function_int",
         "%11 = OpLabel",
         "%12 = OpLoad %int %n",
         "%33 = OpShiftLeftLogical %int %12 %uint_8",
         "%34 = OpShiftLeftLogical %int %33 %uint_1",
               "OpReturnValue %34",
               "OpFunctionEnd",
      // clang-format on
  };

  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  const std::string input = JoinAllInsts(Concat(preamble, callee_before));
  const std::string expected = JoinAllInsts(Concat(preamble, callee_after));
  SinglePassRunAndCheck<opt::StrengthReductionPass>(
      input, expected, /* skip_nop = */ true, /* do_validate = */ true);
}
|
||||
|
||||
// Test that, when the result of an OpIMul instruction has more than 1 use, and
|
||||
// the instruction is replaced, all of the uses of the result are replaced with
|
||||
// the new result.
|
||||
TEST_F(StrengthReductionBasicTest, BasicTestMultipleReplacements) {
|
||||
// This is just the preamble to set up the test.
|
||||
const std::vector<const char*> common_text = {
|
||||
// clang-format off
|
||||
"OpCapability Shader",
|
||||
"%1 = OpExtInstImport \"GLSL.std.450\"",
|
||||
"OpMemoryModel Logical GLSL450",
|
||||
"OpEntryPoint Fragment %main \"main\" %gl_FragColor",
|
||||
"OpExecutionMode %main OriginUpperLeft",
|
||||
"OpName %main \"main\"",
|
||||
"OpName %foo_i1_ \"foo(i1;\"",
|
||||
"OpName %n \"n\"",
|
||||
"OpName %gl_FragColor \"gl_FragColor\"",
|
||||
"OpName %param \"param\"",
|
||||
"OpDecorate %gl_FragColor Location 0",
|
||||
"%void = OpTypeVoid",
|
||||
"%3 = OpTypeFunction %void",
|
||||
"%int = OpTypeInt 32 1",
|
||||
"%_ptr_Function_int = OpTypePointer Function %int",
|
||||
"%8 = OpTypeFunction %int %_ptr_Function_int",
|
||||
"%int_256 = OpConstant %int 256",
|
||||
"%int_2 = OpConstant %int 2",
|
||||
"%float = OpTypeFloat 32",
|
||||
"%v4float = OpTypeVector %float 4",
|
||||
"%_ptr_Output_v4float = OpTypePointer Output %v4float",
|
||||
"%gl_FragColor = OpVariable %_ptr_Output_v4float Output",
|
||||
"%float_1 = OpConstant %float 1",
|
||||
"%int_10 = OpConstant %int 10",
|
||||
"%float_0_4 = OpConstant %float 0.4",
|
||||
"%float_0_8 = OpConstant %float 0.8",
|
||||
"%uint = OpTypeInt 32 0",
|
||||
"%uint_8 = OpConstant %uint 8",
|
||||
"%uint_1 = OpConstant %uint 1",
|
||||
"%main = OpFunction %void None %3",
|
||||
"%5 = OpLabel",
|
||||
"%param = OpVariable %_ptr_Function_int Function",
|
||||
"OpStore %param %int_10",
|
||||
"%26 = OpFunctionCall %int %foo_i1_ %param",
|
||||
"%27 = OpConvertSToF %float %26",
|
||||
"%28 = OpFDiv %float %float_1 %27",
|
||||
"%31 = OpCompositeConstruct %v4float %28 %float_0_4 %float_0_8 %float_1",
|
||||
"OpStore %gl_FragColor %31",
|
||||
"OpReturn",
|
||||
"OpFunctionEnd"
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
// This is the real test. The two OpIMul instructions should be replaced. In
|
||||
// particular, we want to be sure that both uses of %14 are changed to use the
|
||||
// new result.
|
||||
const std::vector<const char*> foo_before = {
|
||||
// clang-format off
|
||||
"%foo_i1_ = OpFunction %int None %8",
|
||||
"%n = OpFunctionParameter %_ptr_Function_int",
|
||||
"%11 = OpLabel",
|
||||
"%12 = OpLoad %int %n",
|
||||
"%14 = OpIMul %int %12 %int_256",
|
||||
"%16 = OpIMul %int %14 %int_2",
|
||||
"%17 = OpIAdd %int %14 %16",
|
||||
"OpReturnValue %17",
|
||||
"OpFunctionEnd",
|
||||
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
// Expected output: the multiply by 256 becomes a shift by %uint_8 (new id
// %34) and the multiply by 2 becomes a shift by %uint_1 (new id %35). The
// OpIAdd keeps its id %17 but now reads %34 and %35 instead of %14 and %16.
const std::vector<const char*> foo_after = {
|
||||
// clang-format off
|
||||
"%foo_i1_ = OpFunction %int None %8",
|
||||
"%n = OpFunctionParameter %_ptr_Function_int",
|
||||
"%11 = OpLabel",
|
||||
"%12 = OpLoad %int %n",
|
||||
"%34 = OpShiftLeftLogical %int %12 %uint_8",
|
||||
"%35 = OpShiftLeftLogical %int %34 %uint_1",
|
||||
"%17 = OpIAdd %int %34 %35",
|
||||
"OpReturnValue %17",
|
||||
"OpFunctionEnd",
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
|
||||
SinglePassRunAndCheck<opt::StrengthReductionPass>(
|
||||
JoinAllInsts(Concat(common_text, foo_before)),
|
||||
JoinAllInsts(Concat(common_text, foo_after)),
|
||||
/* skip_nop = */ true, /* do_validate = */ true);
|
||||
}
|
||||
} // anonymous namespace
|
@ -112,6 +112,8 @@ Options:
|
||||
Join two blocks into a single block if the second has the
|
||||
first as its only predecessor. Performed only on entry point
|
||||
call tree functions.
|
||||
--strength-reduction
|
||||
Replaces instructions with equivalent and less expensive ones.
|
||||
-h, --help
|
||||
Print this help.
|
||||
--version
|
||||
@ -200,6 +202,8 @@ int main(int argc, char** argv) {
|
||||
optimizer.RegisterPass(CreateEliminateDeadConstantPass());
|
||||
} else if (0 == strcmp(cur_arg, "--fold-spec-const-op-composite")) {
|
||||
optimizer.RegisterPass(CreateFoldSpecConstantOpAndCompositePass());
|
||||
} else if (0 == strcmp(cur_arg, "--strength-reduction")) {
|
||||
optimizer.RegisterPass(CreateStrengthReductionPass());
|
||||
} else if (0 == strcmp(cur_arg, "--unify-const")) {
|
||||
optimizer.RegisterPass(CreateUnifyConstantPass());
|
||||
} else if (0 == strcmp(cur_arg, "--flatten-decorations")) {
|
||||
|
Loading…
Reference in New Issue
Block a user