Add strength reduction; for now replace multiply by power of 2

Create a new optimization pass, strength reduction, which will replace
integer multiplication by a constant power of 2 with an equivalent bit
shift.  More changes could be added later.

- Does not duplicate constants

- Adds vector |Concat| utility function to a common test header.
This commit is contained in:
Steven Perron 2017-09-08 12:08:03 -04:00 committed by David Neto
parent 7be791aaaa
commit e4c7d8e748
12 changed files with 761 additions and 31 deletions

View File

@ -185,6 +185,12 @@ Optimizer::PassToken CreateUnifyConstantPass();
// OpSpecConstantOp. // OpSpecConstantOp.
Optimizer::PassToken CreateEliminateDeadConstantPass(); Optimizer::PassToken CreateEliminateDeadConstantPass();
// Creates a strength-reduction pass.
// A strength-reduction pass will look for opportunities to replace an
// instruction with an equivalent and less expensive one. For example,
// multiplying by a power of 2 can be replaced by a bit shift.
Optimizer::PassToken CreateStrengthReductionPass();
// Creates a block merge pass. // Creates a block merge pass.
// This pass searches for blocks with a single Branch to a block with no // This pass searches for blocks with a single Branch to a block with no
// other predecessors and merges the blocks into a single block. Continue // other predecessors and merges the blocks into a single block. Continue

View File

@ -45,6 +45,7 @@ add_library(SPIRV-Tools-opt
passes.h passes.h
pass_manager.h pass_manager.h
set_spec_constant_default_value_pass.h set_spec_constant_default_value_pass.h
strength_reduction_pass.h
strip_debug_info_pass.h strip_debug_info_pass.h
types.h types.h
type_manager.h type_manager.h
@ -79,6 +80,7 @@ add_library(SPIRV-Tools-opt
mem_pass.cpp mem_pass.cpp
pass.cpp pass.cpp
pass_manager.cpp pass_manager.cpp
strength_reduction_pass.cpp
strip_debug_info_pass.cpp strip_debug_info_pass.cpp
types.cpp types.cpp
type_manager.cpp type_manager.cpp

View File

@ -132,6 +132,11 @@ Optimizer::PassToken CreateEliminateDeadConstantPass() {
MakeUnique<opt::EliminateDeadConstantPass>()); MakeUnique<opt::EliminateDeadConstantPass>());
} }
Optimizer::PassToken CreateStrengthReductionPass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::StrengthReductionPass>());
}
Optimizer::PassToken CreateBlockMergePass() { Optimizer::PassToken CreateBlockMergePass() {
return MakeUnique<Optimizer::PassToken::Impl>( return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::BlockMergePass>()); MakeUnique<opt::BlockMergePass>());

View File

@ -35,6 +35,7 @@
#include "aggressive_dead_code_elim_pass.h" #include "aggressive_dead_code_elim_pass.h"
#include "null_pass.h" #include "null_pass.h"
#include "set_spec_constant_default_value_pass.h" #include "set_spec_constant_default_value_pass.h"
#include "strength_reduction_pass.h"
#include "strip_debug_info_pass.h" #include "strip_debug_info_pass.h"
#include "unify_const_pass.h" #include "unify_const_pass.h"

View File

@ -0,0 +1,210 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "strength_reduction_pass.h"
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <unordered_map>
#include <unordered_set>
#include "def_use_manager.h"
#include "log.h"
#include "reflect.h"
namespace {
// Count the number of trailing zeros in the binary representation of
// |constVal|.
uint32_t CountTrailingZeros(uint32_t constVal) {
// Faster if we use the hardware count trailing zeros instruction.
// If not available, we could create a table.
uint32_t shiftAmount = 0;
while ((constVal & 1) == 0) {
++shiftAmount;
constVal = (constVal >> 1);
}
return shiftAmount;
}
// Return true if |val| is a power of 2.
bool IsPowerOf2(uint32_t val) {
// The idea is that the & will clear out the least
// significant 1 bit. If it is a power of 2, then
// there is exactly 1 bit set, and the value becomes 0.
if (val == 0) return false;
return ((val - 1) & val) == 0;
}
} // namespace
namespace spvtools {
namespace opt {
Pass::Status StrengthReductionPass::Process(ir::Module* module) {
// Initialize the member variables on a per module basis.
bool modified = false;
def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module));
int32_type_id_ = 0;
uint32_type_id_ = 0;
std::memset(constant_ids_, 0, sizeof(constant_ids_));
next_id_ = module->IdBound();
module_ = module;
FindIntTypesAndConstants();
modified = ScanFunctions();
// Have to reset the id bound.
module->SetIdBound(next_id_);
return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
}
bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
ir::BasicBlock::iterator* instPtr) {
ir::BasicBlock::iterator& inst = *instPtr;
assert(inst->opcode() == SpvOp::SpvOpIMul &&
"Only works for multiplication of integers.");
bool modified = false;
// Currently only works on 32-bit integers.
if (inst->type_id() != int32_type_id_ && inst->type_id() != uint32_type_id_) {
return modified;
}
// Check the operands for a constant that is a power of 2.
for (int i = 0; i < 2; i++) {
uint32_t opId = inst->GetSingleWordInOperand(i);
ir::Instruction* opInst = def_use_mgr_->GetDef(opId);
if (opInst->opcode() == SpvOp::SpvOpConstant) {
// We found a constant operand.
uint32_t constVal = opInst->GetSingleWordOperand(2);
if (IsPowerOf2(constVal)) {
modified = true;
uint32_t shiftAmount = CountTrailingZeros(constVal);
uint32_t shiftConstResultId = GetConstantId(shiftAmount);
// Create the new instruction.
uint32_t newResultId = next_id_++;
std::vector<ir::Operand> newOperands;
newOperands.push_back(inst->GetInOperand(1 - i));
ir::Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
{shiftConstResultId});
newOperands.push_back(shiftOperand);
std::unique_ptr<ir::Instruction> newInstruction(
new ir::Instruction(SpvOp::SpvOpShiftLeftLogical, inst->type_id(),
newResultId, newOperands));
// Insert the new instruction and update the data structures.
def_use_mgr_->AnalyzeInstDefUse(&*newInstruction);
inst = inst.InsertBefore(std::move(newInstruction));
++inst;
def_use_mgr_->ReplaceAllUsesWith(inst->result_id(), newResultId);
// Remove the old instruction.
def_use_mgr_->KillInst(&*inst);
// We do not want to replace the instruction twice if both operands
// are constants that are a power of 2. So we break here.
break;
}
}
}
return modified;
}
void StrengthReductionPass::FindIntTypesAndConstants() {
for (auto iter = module_->types_values_begin();
iter != module_->types_values_end(); ++iter) {
switch (iter->opcode()) {
case SpvOp::SpvOpTypeInt:
if (iter->GetSingleWordOperand(1) == 32) {
if (iter->GetSingleWordOperand(2) == 1) {
int32_type_id_ = iter->result_id();
} else {
uint32_type_id_ = iter->result_id();
}
}
break;
case SpvOp::SpvOpConstant:
if (iter->type_id() == uint32_type_id_) {
uint32_t value = iter->GetSingleWordOperand(2);
if (value <= 32) constant_ids_[value] = iter->result_id();
}
break;
default:
break;
}
}
}
uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
assert(val <= 32 &&
"This function does not handle constants larger than 32.");
if (constant_ids_[val] == 0) {
if (uint32_type_id_ == 0) {
uint32_type_id_ = CreateUint32Type();
}
// Construct the constant.
uint32_t resultId = next_id_++;
ir::Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
{val});
std::unique_ptr<ir::Instruction> newConstant(new ir::Instruction(
SpvOp::SpvOpConstant, uint32_type_id_, resultId, {constant}));
module_->AddGlobalValue(std::move(newConstant));
// Store the result id for next time.
constant_ids_[val] = resultId;
}
return constant_ids_[val];
}
bool StrengthReductionPass::ScanFunctions() {
// I did not use |ForEachInst| in the module because the function that acts on
// the instruction gets a pointer to the instruction. We cannot use that to
// insert a new instruction. I want an iterator.
bool modified = false;
for (auto& func : *module_) {
for (auto& bb : func) {
for (auto inst = bb.begin(); inst != bb.end(); ++inst) {
switch (inst->opcode()) {
case SpvOp::SpvOpIMul:
if (ReplaceMultiplyByPowerOf2(&inst)) modified = true;
break;
default:
break;
}
}
}
}
return modified;
}
uint32_t StrengthReductionPass::CreateUint32Type() {
uint32_t type_id = next_id_++;
ir::Operand widthOperand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
{32});
ir::Operand signOperand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
{0});
std::unique_ptr<ir::Instruction> newType(new ir::Instruction(
SpvOp::SpvOpTypeInt, type_id, 0, {widthOperand, signOperand}));
module_->AddType(std::move(newType));
return type_id;
}
} // namespace opt
} // namespace spvtools

View File

@ -0,0 +1,75 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef LIBSPIRV_OPT_STRENGTH_REDUCTION_PASS_H_
#define LIBSPIRV_OPT_STRENGTH_REDUCTION_PASS_H_
#include "def_use_manager.h"
#include "module.h"
#include "pass.h"
namespace spvtools {
namespace opt {
// See optimizer.hpp for documentation.
class StrengthReductionPass : public Pass {
public:
const char* name() const override { return "strength-reduction"; }
Status Process(ir::Module*) override;
private:
// Replaces multiple by power of 2 with an equivalent bit shift.
// Returns true if something changed.
bool ReplaceMultiplyByPowerOf2(ir::BasicBlock::iterator*);
// Scan the types and constants in the module looking for the the integer types that we are
// interested in. The shift operation needs a small unsigned integer. We need to find
// them or create them. We do not want duplicates.
void FindIntTypesAndConstants();
// Get the id for the given constant. If it does not exist, it will be
// created. The parameter must be between 0 and 32 inclusive.
uint32_t GetConstantId(uint32_t);
// Replaces certain instructions in function bodies with presumably cheaper
// ones. Returns true if something changed.
bool ScanFunctions();
// Will create the type for an unsigned 32-bit integer and return the id.
// This functions assumes one does not already exist.
uint32_t CreateUint32Type();
// Def-Uses for the module we are processing
std::unique_ptr<analysis::DefUseManager> def_use_mgr_;
// Type ids for the types of interest, or 0 if they do not exist.
uint32_t int32_type_id_;
uint32_t uint32_type_id_;
// constant_ids[i] is the id for unsigned integer constant i.
// We set the limit at 32 because a bit shift of a 32-bit integer does not
// need a value larger than 32.
uint32_t constant_ids_[33];
// Next unused ID
uint32_t next_id_;
// The module that the pass is being applied to.
ir::Module* module_;
};
} // namespace opt
} // namespace spvtools
#endif // LIBSPIRV_OPT_STRENGTH_REDUCTION_PASS_H_

View File

@ -168,3 +168,8 @@ add_spvtools_unittest(TARGET line_debug_info
SRCS line_debug_info_test.cpp pass_utils.cpp SRCS line_debug_info_test.cpp pass_utils.cpp
LIBS SPIRV-Tools-opt LIBS SPIRV-Tools-opt
) )
add_spvtools_unittest(TARGET pass_strength_reduction
SRCS strength_reduction_test.cpp pass_utils.cpp
LIBS SPIRV-Tools-opt
)

View File

@ -16,13 +16,6 @@
#include "pass_fixture.h" #include "pass_fixture.h"
#include "pass_utils.h" #include "pass_utils.h"
template <typename T> std::vector<T> concat(const std::vector<T> &a, const std::vector<T> &b) {
std::vector<T> ret = std::vector<T>();
std::copy(a.begin(), a.end(), back_inserter(ret));
std::copy(b.begin(), b.end(), back_inserter(ret));
return ret;
}
namespace { namespace {
using namespace spvtools; using namespace spvtools;
@ -134,8 +127,8 @@ TEST_F(InlineTest, Simple) {
// clang-format on // clang-format on
}; };
SinglePassRunAndCheck<opt::InlineExhaustivePass>( SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true); /* skip_nop = */ false, /* do_validate = */ true);
} }
@ -284,8 +277,8 @@ TEST_F(InlineTest, Nested) {
// clang-format on // clang-format on
}; };
SinglePassRunAndCheck<opt::InlineExhaustivePass>( SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true); /* skip_nop = */ false, /* do_validate = */ true);
} }
@ -413,8 +406,8 @@ TEST_F(InlineTest, InOutParameter) {
// clang-format on // clang-format on
}; };
SinglePassRunAndCheck<opt::InlineExhaustivePass>( SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true); /* skip_nop = */ false, /* do_validate = */ true);
} }
@ -549,8 +542,8 @@ TEST_F(InlineTest, BranchInCallee) {
// clang-format on // clang-format on
}; };
SinglePassRunAndCheck<opt::InlineExhaustivePass>( SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true); /* skip_nop = */ false, /* do_validate = */ true);
} }
@ -744,8 +737,8 @@ TEST_F(InlineTest, PhiAfterCall) {
// clang-format on // clang-format on
}; };
SinglePassRunAndCheck<opt::InlineExhaustivePass>( SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true); /* skip_nop = */ false, /* do_validate = */ true);
} }
@ -941,8 +934,8 @@ TEST_F(InlineTest, OpSampledImageOutOfBlock) {
// clang-format on // clang-format on
}; };
SinglePassRunAndCheck<opt::InlineExhaustivePass>( SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true); /* skip_nop = */ false, /* do_validate = */ true);
} }
@ -1147,8 +1140,8 @@ TEST_F(InlineTest, OpImageOutOfBlock) {
// clang-format on // clang-format on
}; };
SinglePassRunAndCheck<opt::InlineExhaustivePass>( SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true); /* skip_nop = */ false, /* do_validate = */ true);
} }
@ -1353,8 +1346,8 @@ TEST_F(InlineTest, OpImageAndOpSampledImageOutOfBlock) {
// clang-format on // clang-format on
}; };
SinglePassRunAndCheck<opt::InlineExhaustivePass>( SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true); /* skip_nop = */ false, /* do_validate = */ true);
} }

View File

@ -16,14 +16,6 @@
#include "pass_fixture.h" #include "pass_fixture.h"
#include "pass_utils.h" #include "pass_utils.h"
template <typename T>
std::vector<T> concat(const std::vector<T>& a, const std::vector<T>& b) {
std::vector<T> ret;
std::copy(a.begin(), a.end(), back_inserter(ret));
std::copy(b.begin(), b.end(), back_inserter(ret));
return ret;
}
namespace { namespace {
using namespace spvtools; using namespace spvtools;

View File

@ -49,6 +49,16 @@ std::string JoinAllInsts(const std::vector<const char*>& insts);
// will be ignored. // will be ignored.
std::string JoinNonDebugInsts(const std::vector<const char*>& insts); std::string JoinNonDebugInsts(const std::vector<const char*>& insts);
// Returns a vector that contains the contents of |a| followed by the contents
// of |b|.
template <typename T>
std::vector<T> Concat(const std::vector<T>& a, const std::vector<T>& b) {
std::vector<T> ret;
std::copy(a.begin(), a.end(), back_inserter(ret));
std::copy(b.begin(), b.end(), back_inserter(ret));
return ret;
}
} // namespace spvtools } // namespace spvtools
#endif // LIBSPIRV_TEST_OPT_PASS_UTILS_H_ #endif // LIBSPIRV_TEST_OPT_PASS_UTILS_H_

View File

@ -0,0 +1,427 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "assembly_builder.h"
#include "gmock/gmock.h"
#include "pass_fixture.h"
#include "pass_utils.h"
#include <algorithm>
#include <cstdarg>
#include <iostream>
#include <sstream>
#include <unordered_set>
namespace {
using namespace spvtools;
using ::testing::HasSubstr;
using ::testing::MatchesRegex;
using StrengthReductionBasicTest = PassTest<::testing::Test>;
// Test to make sure we replace 5*8.
TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy8) {
const std::vector<const char*> text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Vertex %main \"main\"",
"OpName %main \"main\"",
"%void = OpTypeVoid",
"%4 = OpTypeFunction %void",
"%uint = OpTypeInt 32 0",
"%uint_5 = OpConstant %uint 5",
"%uint_8 = OpConstant %uint 8",
"%main = OpFunction %void None %4",
"%8 = OpLabel",
"%9 = OpIMul %uint %uint_5 %uint_8",
"OpReturn",
"OpFunctionEnd"
// clang-format on
};
auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
JoinAllInsts(text), /* skip_nop = */ true);
EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result));
const std::string& output = std::get<0>(result);
EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_3"));
}
// Test to make sure we replace 16*5.
TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy16) {
const std::vector<const char*> text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Vertex %main \"main\"",
"OpName %main \"main\"",
"%void = OpTypeVoid",
"%4 = OpTypeFunction %void",
"%int = OpTypeInt 32 1",
"%int_16 = OpConstant %int 16",
"%int_5 = OpConstant %int 5",
"%main = OpFunction %void None %4",
"%8 = OpLabel",
"%9 = OpIMul %int %int_16 %int_5",
"OpReturn",
"OpFunctionEnd"
// clang-format on
};
auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
JoinAllInsts(text), /* skip_nop = */ true);
EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result));
const std::string& output = std::get<0>(result);
EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %int %int_5 %uint_4"));
}
// Test to make sure we replace a multiple of 32 and 4.
TEST_F(StrengthReductionBasicTest, BasicTwoPowersOf2) {
// In this case, we have two powers of 2. Need to make sure we replace only
// one of them for the bit shift.
// clang-format off
const std::string text = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %main "main"
OpName %main "main"
%void = OpTypeVoid
%4 = OpTypeFunction %void
%int = OpTypeInt 32 1
%int_32 = OpConstant %int 32
%int_4 = OpConstant %int 4
%main = OpFunction %void None %4
%8 = OpLabel
%9 = OpIMul %int %int_32 %int_4
OpReturn
OpFunctionEnd
)";
// clang-format on
auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
text, /* skip_nop = */ true);
EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result));
const std::string& output = std::get<0>(result);
EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %int %int_4 %uint_5"));
}
// Test to make sure we don't replace 0*5.
TEST_F(StrengthReductionBasicTest, BasicDontReplace0) {
const std::vector<const char*> text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Vertex %main \"main\"",
"OpName %main \"main\"",
"%void = OpTypeVoid",
"%4 = OpTypeFunction %void",
"%int = OpTypeInt 32 1",
"%int_0 = OpConstant %int 0",
"%int_5 = OpConstant %int 5",
"%main = OpFunction %void None %4",
"%8 = OpLabel",
"%9 = OpIMul %int %int_0 %int_5",
"OpReturn",
"OpFunctionEnd"
// clang-format on
};
auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
JoinAllInsts(text), /* skip_nop = */ true);
EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result));
}
// Test to make sure we do not replace a multiple of 5 and 7.
TEST_F(StrengthReductionBasicTest, BasicNoChange) {
const std::vector<const char*> text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Vertex %2 \"main\"",
"OpName %2 \"main\"",
"%3 = OpTypeVoid",
"%4 = OpTypeFunction %3",
"%5 = OpTypeInt 32 1",
"%6 = OpTypeInt 32 0",
"%7 = OpConstant %5 5",
"%8 = OpConstant %5 7",
"%2 = OpFunction %3 None %4",
"%9 = OpLabel",
"%10 = OpIMul %5 %7 %8",
"OpReturn",
"OpFunctionEnd",
// clang-format on
};
auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
JoinAllInsts(text), /* skip_nop = */ true);
EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result));
}
// Test to make sure constants and types are reused and not duplicated.
TEST_F(StrengthReductionBasicTest, NoDuplicateConstantsAndTypes) {
const std::vector<const char*> text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Vertex %main \"main\"",
"OpName %main \"main\"",
"%void = OpTypeVoid",
"%4 = OpTypeFunction %void",
"%uint = OpTypeInt 32 0",
"%uint_8 = OpConstant %uint 8",
"%uint_3 = OpConstant %uint 3",
"%main = OpFunction %void None %4",
"%8 = OpLabel",
"%9 = OpIMul %uint %uint_8 %uint_3",
"OpReturn",
"OpFunctionEnd",
// clang-format on
};
auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
JoinAllInsts(text), /* skip_nop = */ true);
EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result));
const std::string& output = std::get<0>(result);
EXPECT_THAT(output,
Not(MatchesRegex(".*OpConstant %uint 3.*OpConstant %uint 3.*")));
EXPECT_THAT(output, Not(MatchesRegex(".*OpTypeInt 32 0.*OpTypeInt 32 0.*")));
}
// Test to make sure we generate the constants only once
TEST_F(StrengthReductionBasicTest, BasicCreateOneConst) {
const std::vector<const char*> text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Vertex %main \"main\"",
"OpName %main \"main\"",
"%void = OpTypeVoid",
"%4 = OpTypeFunction %void",
"%uint = OpTypeInt 32 0",
"%uint_5 = OpConstant %uint 5",
"%uint_9 = OpConstant %uint 9",
"%uint_128 = OpConstant %uint 128",
"%main = OpFunction %void None %4",
"%8 = OpLabel",
"%9 = OpIMul %uint %uint_5 %uint_128",
"%10 = OpIMul %uint %uint_9 %uint_128",
"OpReturn",
"OpFunctionEnd"
// clang-format on
};
auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
JoinAllInsts(text), /* skip_nop = */ true);
EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result));
const std::string& output = std::get<0>(result);
EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_7"));
EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_9 %uint_7"));
}
// Test to make sure we generate the instructions in the correct position and
// that the uses get replaced as well. Here we check that the use in the return
// is replaced, we also check that we can replace two OpIMuls when one feeds the
// other.
TEST_F(StrengthReductionBasicTest, BasicCheckPositionAndReplacement) {
// This is just the preamble to set up the test.
const std::vector<const char*> common_text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Fragment %main \"main\" %gl_FragColor",
"OpExecutionMode %main OriginUpperLeft",
"OpName %main \"main\"",
"OpName %foo_i1_ \"foo(i1;\"",
"OpName %n \"n\"",
"OpName %gl_FragColor \"gl_FragColor\"",
"OpName %param \"param\"",
"OpDecorate %gl_FragColor Location 0",
"%void = OpTypeVoid",
"%3 = OpTypeFunction %void",
"%int = OpTypeInt 32 1",
"%_ptr_Function_int = OpTypePointer Function %int",
"%8 = OpTypeFunction %int %_ptr_Function_int",
"%int_256 = OpConstant %int 256",
"%int_2 = OpConstant %int 2",
"%float = OpTypeFloat 32",
"%v4float = OpTypeVector %float 4",
"%_ptr_Output_v4float = OpTypePointer Output %v4float",
"%gl_FragColor = OpVariable %_ptr_Output_v4float Output",
"%float_1 = OpConstant %float 1",
"%int_10 = OpConstant %int 10",
"%float_0_4 = OpConstant %float 0.4",
"%float_0_8 = OpConstant %float 0.8",
"%uint = OpTypeInt 32 0",
"%uint_8 = OpConstant %uint 8",
"%uint_1 = OpConstant %uint 1",
"%main = OpFunction %void None %3",
"%5 = OpLabel",
"%param = OpVariable %_ptr_Function_int Function",
"OpStore %param %int_10",
"%26 = OpFunctionCall %int %foo_i1_ %param",
"%27 = OpConvertSToF %float %26",
"%28 = OpFDiv %float %float_1 %27",
"%31 = OpCompositeConstruct %v4float %28 %float_0_4 %float_0_8 %float_1",
"OpStore %gl_FragColor %31",
"OpReturn",
"OpFunctionEnd"
// clang-format on
};
// This is the real test. The two OpIMul should be replaced. The expected
// output is in |foo_after|.
const std::vector<const char*> foo_before = {
// clang-format off
"%foo_i1_ = OpFunction %int None %8",
"%n = OpFunctionParameter %_ptr_Function_int",
"%11 = OpLabel",
"%12 = OpLoad %int %n",
"%14 = OpIMul %int %12 %int_256",
"%16 = OpIMul %int %14 %int_2",
"OpReturnValue %16",
"OpFunctionEnd",
// clang-format on
};
const std::vector<const char*> foo_after = {
// clang-format off
"%foo_i1_ = OpFunction %int None %8",
"%n = OpFunctionParameter %_ptr_Function_int",
"%11 = OpLabel",
"%12 = OpLoad %int %n",
"%33 = OpShiftLeftLogical %int %12 %uint_8",
"%34 = OpShiftLeftLogical %int %33 %uint_1",
"OpReturnValue %34",
"OpFunctionEnd",
// clang-format on
};
SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
SinglePassRunAndCheck<opt::StrengthReductionPass>(
JoinAllInsts(Concat(common_text, foo_before)),
JoinAllInsts(Concat(common_text, foo_after)),
/* skip_nop = */ true, /* do_validate = */ true);
}
// Test that, when the result of an OpIMul instruction has more than 1 use, and
// the instruction is replaced, all of the uses of the results are replace with
// the new result.
TEST_F(StrengthReductionBasicTest, BasicTestMultipleReplacements) {
// This is just the preamble to set up the test.
const std::vector<const char*> common_text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Fragment %main \"main\" %gl_FragColor",
"OpExecutionMode %main OriginUpperLeft",
"OpName %main \"main\"",
"OpName %foo_i1_ \"foo(i1;\"",
"OpName %n \"n\"",
"OpName %gl_FragColor \"gl_FragColor\"",
"OpName %param \"param\"",
"OpDecorate %gl_FragColor Location 0",
"%void = OpTypeVoid",
"%3 = OpTypeFunction %void",
"%int = OpTypeInt 32 1",
"%_ptr_Function_int = OpTypePointer Function %int",
"%8 = OpTypeFunction %int %_ptr_Function_int",
"%int_256 = OpConstant %int 256",
"%int_2 = OpConstant %int 2",
"%float = OpTypeFloat 32",
"%v4float = OpTypeVector %float 4",
"%_ptr_Output_v4float = OpTypePointer Output %v4float",
"%gl_FragColor = OpVariable %_ptr_Output_v4float Output",
"%float_1 = OpConstant %float 1",
"%int_10 = OpConstant %int 10",
"%float_0_4 = OpConstant %float 0.4",
"%float_0_8 = OpConstant %float 0.8",
"%uint = OpTypeInt 32 0",
"%uint_8 = OpConstant %uint 8",
"%uint_1 = OpConstant %uint 1",
"%main = OpFunction %void None %3",
"%5 = OpLabel",
"%param = OpVariable %_ptr_Function_int Function",
"OpStore %param %int_10",
"%26 = OpFunctionCall %int %foo_i1_ %param",
"%27 = OpConvertSToF %float %26",
"%28 = OpFDiv %float %float_1 %27",
"%31 = OpCompositeConstruct %v4float %28 %float_0_4 %float_0_8 %float_1",
"OpStore %gl_FragColor %31",
"OpReturn",
"OpFunctionEnd"
// clang-format on
};
// This is the real test. The two OpIMul instructions should be replaced. In
// particular, we want to be sure that both uses of %16 are changed to use the
// new result.
const std::vector<const char*> foo_before = {
// clang-format off
"%foo_i1_ = OpFunction %int None %8",
"%n = OpFunctionParameter %_ptr_Function_int",
"%11 = OpLabel",
"%12 = OpLoad %int %n",
"%14 = OpIMul %int %12 %int_256",
"%16 = OpIMul %int %14 %int_2",
"%17 = OpIAdd %int %14 %16",
"OpReturnValue %17",
"OpFunctionEnd",
// clang-format on
};
const std::vector<const char*> foo_after = {
// clang-format off
"%foo_i1_ = OpFunction %int None %8",
"%n = OpFunctionParameter %_ptr_Function_int",
"%11 = OpLabel",
"%12 = OpLoad %int %n",
"%34 = OpShiftLeftLogical %int %12 %uint_8",
"%35 = OpShiftLeftLogical %int %34 %uint_1",
"%17 = OpIAdd %int %34 %35",
"OpReturnValue %17",
"OpFunctionEnd",
// clang-format on
};
SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
SinglePassRunAndCheck<opt::StrengthReductionPass>(
JoinAllInsts(Concat(common_text, foo_before)),
JoinAllInsts(Concat(common_text, foo_after)),
/* skip_nop = */ true, /* do_validate = */ true);
}
} // anonymous namespace

View File

@ -112,6 +112,8 @@ Options:
Join two blocks into a single block if the second has the Join two blocks into a single block if the second has the
first as its only predecessor. Performed only on entry point first as its only predecessor. Performed only on entry point
call tree functions. call tree functions.
--strength-reduction
Replaces instructions with equivalent and less expensive ones.
-h, --help -h, --help
Print this help. Print this help.
--version --version
@ -200,6 +202,8 @@ int main(int argc, char** argv) {
optimizer.RegisterPass(CreateEliminateDeadConstantPass()); optimizer.RegisterPass(CreateEliminateDeadConstantPass());
} else if (0 == strcmp(cur_arg, "--fold-spec-const-op-composite")) { } else if (0 == strcmp(cur_arg, "--fold-spec-const-op-composite")) {
optimizer.RegisterPass(CreateFoldSpecConstantOpAndCompositePass()); optimizer.RegisterPass(CreateFoldSpecConstantOpAndCompositePass());
} else if (0 == strcmp(cur_arg, "--strength-reduction")) {
optimizer.RegisterPass(CreateStrengthReductionPass());
} else if (0 == strcmp(cur_arg, "--unify-const")) { } else if (0 == strcmp(cur_arg, "--unify-const")) {
optimizer.RegisterPass(CreateUnifyConstantPass()); optimizer.RegisterPass(CreateUnifyConstantPass());
} else if (0 == strcmp(cur_arg, "--flatten-decorations")) { } else if (0 == strcmp(cur_arg, "--flatten-decorations")) {