Add strength reduction; for now replace multiply by power of 2

Create a new optimization pass, strength reduction, which will replace
integer multiplication by a constant power of 2 with an equivalent bit
shift.  More changes could be added later.

- Does not duplicate constants

- Adds vector |Concat| utility function to a common test header.
This commit is contained in:
Steven Perron 2017-09-08 12:08:03 -04:00 committed by David Neto
parent 7be791aaaa
commit e4c7d8e748
12 changed files with 761 additions and 31 deletions

View File

@ -185,6 +185,12 @@ Optimizer::PassToken CreateUnifyConstantPass();
// OpSpecConstantOp.
Optimizer::PassToken CreateEliminateDeadConstantPass();
// Creates a strength-reduction pass.
// A strength-reduction pass will look for opportunities to replace an
// instruction with an equivalent and less expensive one. For example,
// multiplying by a power of 2 can be replaced by a bit shift.
Optimizer::PassToken CreateStrengthReductionPass();
// Creates a block merge pass.
// This pass searches for blocks with a single Branch to a block with no
// other predecessors and merges the blocks into a single block. Continue

View File

@ -45,6 +45,7 @@ add_library(SPIRV-Tools-opt
passes.h
pass_manager.h
set_spec_constant_default_value_pass.h
strength_reduction_pass.h
strip_debug_info_pass.h
types.h
type_manager.h
@ -79,6 +80,7 @@ add_library(SPIRV-Tools-opt
mem_pass.cpp
pass.cpp
pass_manager.cpp
strength_reduction_pass.cpp
strip_debug_info_pass.cpp
types.cpp
type_manager.cpp

View File

@ -132,6 +132,11 @@ Optimizer::PassToken CreateEliminateDeadConstantPass() {
MakeUnique<opt::EliminateDeadConstantPass>());
}
// Builds a PassToken that owns a freshly created opt::StrengthReductionPass.
Optimizer::PassToken CreateStrengthReductionPass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::StrengthReductionPass>());
}
Optimizer::PassToken CreateBlockMergePass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::BlockMergePass>());

View File

@ -35,6 +35,7 @@
#include "aggressive_dead_code_elim_pass.h"
#include "null_pass.h"
#include "set_spec_constant_default_value_pass.h"
#include "strength_reduction_pass.h"
#include "strip_debug_info_pass.h"
#include "unify_const_pass.h"

View File

@ -0,0 +1,210 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "strength_reduction_pass.h"
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <unordered_map>
#include <unordered_set>
#include "def_use_manager.h"
#include "log.h"
#include "reflect.h"
namespace {
// Returns the number of trailing zero bits in the binary representation of
// |constVal|.  A value of 0 has no set bit at all, so every one of its 32
// bits counts as a trailing zero and 32 is returned.
uint32_t CountTrailingZeros(uint32_t constVal) {
  // Without this guard the loop below would never terminate: shifting 0 to
  // the right can never produce a set low bit.
  if (constVal == 0) return 32;

  // Faster if we use the hardware count trailing zeros instruction.
  // If not available, we could create a table.
  uint32_t shiftAmount = 0;
  while ((constVal & 1) == 0) {
    ++shiftAmount;
    constVal >>= 1;
  }
  return shiftAmount;
}
// Returns true when exactly one bit of |val| is set, i.e. |val| is a power
// of 2.  Zero is not a power of 2.
bool IsPowerOf2(uint32_t val) {
  // Subtracting 1 flips the lowest set bit and every bit below it, so and-ing
  // the result with the original value strips that lowest set bit.  Nothing
  // remains only when it was the sole bit set.
  const uint32_t withoutLowestBit = val & (val - 1);
  return val != 0 && withoutLowestBit == 0;
}
} // namespace
namespace spvtools {
namespace opt {
// Runs strength reduction over |module|.  Returns SuccessWithChange when at
// least one instruction was replaced and SuccessWithoutChange otherwise.
Pass::Status StrengthReductionPass::Process(ir::Module* module) {
  // All member state is per module, so set it up from scratch.
  def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module));
  next_id_ = module->IdBound();
  module_ = module;
  int32_type_id_ = 0;
  uint32_type_id_ = 0;
  std::memset(constant_ids_, 0, sizeof(constant_ids_));

  // Record the types and constants that already exist, then rewrite.
  FindIntTypesAndConstants();
  const bool changed = ScanFunctions();

  // Ids may have been taken from |next_id_|, so the id bound has to be
  // written back to the module.
  module->SetIdBound(next_id_);

  if (changed) return Status::SuccessWithChange;
  return Status::SuccessWithoutChange;
}
// Replaces the OpIMul that |*instPtr| points at with an OpShiftLeftLogical
// when one of its two operands is an OpConstant holding a power of 2.  Only
// results of the recorded 32-bit integer types are handled.  Returns true if
// the instruction was replaced.
bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
    ir::BasicBlock::iterator* instPtr) {
  ir::BasicBlock::iterator& inst = *instPtr;
  assert(inst->opcode() == SpvOp::SpvOpIMul &&
         "Only works for multiplication of integers.");

  // Currently only works on 32-bit integers.
  if (inst->type_id() != int32_type_id_ &&
      inst->type_id() != uint32_type_id_) {
    return false;
  }

  // Look at both operands for a constant that is a power of 2.
  for (int operandIdx = 0; operandIdx < 2; ++operandIdx) {
    const uint32_t operandId = inst->GetSingleWordInOperand(operandIdx);
    ir::Instruction* operandDef = def_use_mgr_->GetDef(operandId);
    if (operandDef->opcode() != SpvOp::SpvOpConstant) continue;

    // The operand is a constant; word 2 of OpConstant is its literal value.
    const uint32_t literal = operandDef->GetSingleWordOperand(2);
    if (!IsPowerOf2(literal)) continue;

    // The shift-amount constant is obtained before the id of the new
    // instruction is taken, so that ids are handed out in the same order as
    // before.
    const uint32_t shiftAmount = CountTrailingZeros(literal);
    const uint32_t shiftConstResultId = GetConstantId(shiftAmount);
    const uint32_t newResultId = next_id_++;

    // Operands of the shift: the other multiplicand, then the shift amount.
    std::vector<ir::Operand> shiftOperands;
    shiftOperands.push_back(inst->GetInOperand(1 - operandIdx));
    ir::Operand shiftAmountOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
                                   {shiftConstResultId});
    shiftOperands.push_back(shiftAmountOperand);
    std::unique_ptr<ir::Instruction> shiftInst(
        new ir::Instruction(SpvOp::SpvOpShiftLeftLogical, inst->type_id(),
                            newResultId, shiftOperands));

    // Register the shift, place it in front of the multiply, and step the
    // iterator back onto the multiply.
    def_use_mgr_->AnalyzeInstDefUse(&*shiftInst);
    inst = inst.InsertBefore(std::move(shiftInst));
    ++inst;

    // Redirect every user of the multiply to the shift, then remove the
    // multiply.
    def_use_mgr_->ReplaceAllUsesWith(inst->result_id(), newResultId);
    def_use_mgr_->KillInst(&*inst);

    // Stop at the first match: if both operands are powers of 2 the
    // instruction must not be replaced twice.
    return true;
  }
  return false;
}
// Walks the module's types and values and records the ids of the 32-bit
// signed and unsigned integer types, plus the ids of the unsigned 32-bit
// constants whose value is at most 32.
void StrengthReductionPass::FindIntTypesAndConstants() {
  for (auto it = module_->types_values_begin();
       it != module_->types_values_end(); ++it) {
    const auto opcode = it->opcode();
    if (opcode == SpvOp::SpvOpTypeInt) {
      // Operand 1 is the bit width, operand 2 the signedness.
      if (it->GetSingleWordOperand(1) != 32) continue;
      const bool isSigned = it->GetSingleWordOperand(2) == 1;
      if (isSigned) {
        int32_type_id_ = it->result_id();
      } else {
        uint32_type_id_ = it->result_id();
      }
    } else if (opcode == SpvOp::SpvOpConstant) {
      // Only constants of the unsigned type recorded so far are of interest.
      if (it->type_id() != uint32_type_id_) continue;
      const uint32_t literal = it->GetSingleWordOperand(2);
      if (literal <= 32) constant_ids_[literal] = it->result_id();
    }
  }
}
// Returns the id of the unsigned 32-bit constant with value |val|, which must
// not exceed 32.  The constant (and, if needed, its type) is added to the
// module the first time it is requested; later requests reuse the stored id.
uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
  assert(val <= 32 &&
         "This function does not handle constants larger than 32.");

  uint32_t& cachedId = constant_ids_[val];
  if (cachedId != 0) return cachedId;

  // The constant needs a type; make one if the module does not have it yet.
  if (uint32_type_id_ == 0) {
    uint32_type_id_ = CreateUint32Type();
  }

  // Build the OpConstant and append it to the module's global values.
  const uint32_t resultId = next_id_++;
  ir::Operand literalOperand(
      spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val});
  std::unique_ptr<ir::Instruction> constantInst(new ir::Instruction(
      SpvOp::SpvOpConstant, uint32_type_id_, resultId, {literalOperand}));
  module_->AddGlobalValue(std::move(constantInst));

  // Remember the id so the constant is not created a second time.
  cachedId = resultId;
  return cachedId;
}
// Visits every instruction of every basic block of every function in the
// module and tries to replace each OpIMul.  Returns true if any replacement
// took place.
bool StrengthReductionPass::ScanFunctions() {
  // |ForEachInst| is not used here because its callback only receives a
  // pointer to the instruction, which cannot be used to insert a new
  // instruction.  An iterator is required for that.
  bool changed = false;
  for (auto& function : *module_) {
    for (auto& block : function) {
      for (auto it = block.begin(); it != block.end(); ++it) {
        if (it->opcode() != SpvOp::SpvOpIMul) continue;
        // The call comes first so it always runs, whatever |changed| holds.
        changed = ReplaceMultiplyByPowerOf2(&it) || changed;
      }
    }
  }
  return changed;
}
// Adds the unsigned 32-bit integer type (OpTypeInt 32 0) to the module and
// returns its id.  The caller is expected to have checked that the module
// does not already contain this type.
uint32_t StrengthReductionPass::CreateUint32Type() {
  const uint32_t type_id = next_id_++;
  ir::Operand widthOperand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
                           {32});
  ir::Operand signOperand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
                          {0});
  // ir::Instruction is constructed as (opcode, result type id, result id,
  // operands), as the OpConstant and OpShiftLeftLogical built elsewhere in
  // this file show.  OpTypeInt defines a result id and has no result type, so
  // the fresh id belongs in the result-id slot and the type slot stays 0.
  std::unique_ptr<ir::Instruction> newType(new ir::Instruction(
      SpvOp::SpvOpTypeInt, 0, type_id, {widthOperand, signOperand}));
  module_->AddType(std::move(newType));
  return type_id;
}
} // namespace opt
} // namespace spvtools

View File

@ -0,0 +1,75 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef LIBSPIRV_OPT_STRENGTH_REDUCTION_PASS_H_
#define LIBSPIRV_OPT_STRENGTH_REDUCTION_PASS_H_
#include "def_use_manager.h"
#include "module.h"
#include "pass.h"
namespace spvtools {
namespace opt {
// A pass that replaces instructions with equivalent, presumably cheaper ones.
// See optimizer.hpp for documentation.
class StrengthReductionPass : public Pass {
public:
// The name of this pass.
const char* name() const override { return "strength-reduction"; }
// Applies the pass to the given module.
Status Process(ir::Module*) override;
private:
// Replaces a multiply by a power of 2 with an equivalent bit shift.
// Returns true if something changed.
bool ReplaceMultiplyByPowerOf2(ir::BasicBlock::iterator*);
// Scans the types and constants in the module looking for the integer types
// that we are interested in. The shift operation needs a small unsigned
// integer. We need to find them or create them. We do not want duplicates.
void FindIntTypesAndConstants();
// Get the id for the given constant. If it does not exist, it will be
// created. The parameter must be between 0 and 32 inclusive.
uint32_t GetConstantId(uint32_t);
// Replaces certain instructions in function bodies with presumably cheaper
// ones. Returns true if something changed.
bool ScanFunctions();
// Will create the type for an unsigned 32-bit integer and return the id.
// This function assumes one does not already exist.
uint32_t CreateUint32Type();
// Def-Uses for the module we are processing
std::unique_ptr<analysis::DefUseManager> def_use_mgr_;
// Type ids for the types of interest, or 0 if they do not exist.
uint32_t int32_type_id_;
uint32_t uint32_type_id_;
// constant_ids_[i] is the id for unsigned integer constant i, or 0 if that
// constant is not known yet.
// We set the limit at 32 because a bit shift of a 32-bit integer does not
// need a value larger than 32.
uint32_t constant_ids_[33];
// Next unused ID
uint32_t next_id_;
// The module that the pass is being applied to.
ir::Module* module_;
};
} // namespace opt
} // namespace spvtools
#endif // LIBSPIRV_OPT_STRENGTH_REDUCTION_PASS_H_

View File

@ -168,3 +168,8 @@ add_spvtools_unittest(TARGET line_debug_info
SRCS line_debug_info_test.cpp pass_utils.cpp
LIBS SPIRV-Tools-opt
)
add_spvtools_unittest(TARGET pass_strength_reduction
SRCS strength_reduction_test.cpp pass_utils.cpp
LIBS SPIRV-Tools-opt
)

View File

@ -16,13 +16,6 @@
#include "pass_fixture.h"
#include "pass_utils.h"
template <typename T> std::vector<T> concat(const std::vector<T> &a, const std::vector<T> &b) {
std::vector<T> ret = std::vector<T>();
std::copy(a.begin(), a.end(), back_inserter(ret));
std::copy(b.begin(), b.end(), back_inserter(ret));
return ret;
}
namespace {
using namespace spvtools;
@ -134,8 +127,8 @@ TEST_F(InlineTest, Simple) {
// clang-format on
};
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true);
}
@ -284,8 +277,8 @@ TEST_F(InlineTest, Nested) {
// clang-format on
};
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true);
}
@ -413,8 +406,8 @@ TEST_F(InlineTest, InOutParameter) {
// clang-format on
};
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true);
}
@ -549,8 +542,8 @@ TEST_F(InlineTest, BranchInCallee) {
// clang-format on
};
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true);
}
@ -744,8 +737,8 @@ TEST_F(InlineTest, PhiAfterCall) {
// clang-format on
};
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true);
}
@ -941,8 +934,8 @@ TEST_F(InlineTest, OpSampledImageOutOfBlock) {
// clang-format on
};
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true);
}
@ -1147,8 +1140,8 @@ TEST_F(InlineTest, OpImageOutOfBlock) {
// clang-format on
};
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true);
}
@ -1353,8 +1346,8 @@ TEST_F(InlineTest, OpImageAndOpSampledImageOutOfBlock) {
// clang-format on
};
SinglePassRunAndCheck<opt::InlineExhaustivePass>(
JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)),
JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)),
/* skip_nop = */ false, /* do_validate = */ true);
}

View File

@ -16,14 +16,6 @@
#include "pass_fixture.h"
#include "pass_utils.h"
template <typename T>
std::vector<T> concat(const std::vector<T>& a, const std::vector<T>& b) {
std::vector<T> ret;
std::copy(a.begin(), a.end(), back_inserter(ret));
std::copy(b.begin(), b.end(), back_inserter(ret));
return ret;
}
namespace {
using namespace spvtools;

View File

@ -49,6 +49,16 @@ std::string JoinAllInsts(const std::vector<const char*>& insts);
// will be ignored.
std::string JoinNonDebugInsts(const std::vector<const char*>& insts);
// Returns a vector that contains the contents of |a| followed by the contents
// of |b|.
template <typename T>
std::vector<T> Concat(const std::vector<T>& a, const std::vector<T>& b) {
  std::vector<T> ret;
  // The final size is known, so allocate once.  Range insertion also keeps
  // this header independent of <algorithm> and <iterator>, and of argument
  // dependent lookup finding an unqualified back_inserter.
  ret.reserve(a.size() + b.size());
  ret.insert(ret.end(), a.begin(), a.end());
  ret.insert(ret.end(), b.begin(), b.end());
  return ret;
}
} // namespace spvtools
#endif // LIBSPIRV_TEST_OPT_PASS_UTILS_H_

View File

@ -0,0 +1,427 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "assembly_builder.h"
#include "gmock/gmock.h"
#include "pass_fixture.h"
#include "pass_utils.h"
#include <algorithm>
#include <cstdarg>
#include <iostream>
#include <sstream>
#include <unordered_set>
namespace {
using namespace spvtools;
using ::testing::HasSubstr;
using ::testing::MatchesRegex;
using StrengthReductionBasicTest = PassTest<::testing::Test>;
// Test to make sure we replace 5*8.
TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy8) {
const std::vector<const char*> text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Vertex %main \"main\"",
"OpName %main \"main\"",
"%void = OpTypeVoid",
"%4 = OpTypeFunction %void",
"%uint = OpTypeInt 32 0",
"%uint_5 = OpConstant %uint 5",
"%uint_8 = OpConstant %uint 8",
"%main = OpFunction %void None %4",
"%8 = OpLabel",
"%9 = OpIMul %uint %uint_5 %uint_8",
"OpReturn",
"OpFunctionEnd"
// clang-format on
};
auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
JoinAllInsts(text), /* skip_nop = */ true);
EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result));
const std::string& output = std::get<0>(result);
EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_3"));
}
// Test to make sure we replace 16*5.
TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy16) {
const std::vector<const char*> text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Vertex %main \"main\"",
"OpName %main \"main\"",
"%void = OpTypeVoid",
"%4 = OpTypeFunction %void",
"%int = OpTypeInt 32 1",
"%int_16 = OpConstant %int 16",
"%int_5 = OpConstant %int 5",
"%main = OpFunction %void None %4",
"%8 = OpLabel",
"%9 = OpIMul %int %int_16 %int_5",
"OpReturn",
"OpFunctionEnd"
// clang-format on
};
auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
JoinAllInsts(text), /* skip_nop = */ true);
EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result));
const std::string& output = std::get<0>(result);
EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %int %int_5 %uint_4"));
}
// Test to make sure we replace a multiplication of 32 and 4.
TEST_F(StrengthReductionBasicTest, BasicTwoPowersOf2) {
// In this case, we have two powers of 2. Need to make sure we replace only
// one of them for the bit shift.
// clang-format off
const std::string text = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %main "main"
OpName %main "main"
%void = OpTypeVoid
%4 = OpTypeFunction %void
%int = OpTypeInt 32 1
%int_32 = OpConstant %int 32
%int_4 = OpConstant %int 4
%main = OpFunction %void None %4
%8 = OpLabel
%9 = OpIMul %int %int_32 %int_4
OpReturn
OpFunctionEnd
)";
// clang-format on
auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
text, /* skip_nop = */ true);
EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result));
const std::string& output = std::get<0>(result);
EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %int %int_4 %uint_5"));
}
// Test to make sure we don't replace 0*5.
TEST_F(StrengthReductionBasicTest, BasicDontReplace0) {
const std::vector<const char*> text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Vertex %main \"main\"",
"OpName %main \"main\"",
"%void = OpTypeVoid",
"%4 = OpTypeFunction %void",
"%int = OpTypeInt 32 1",
"%int_0 = OpConstant %int 0",
"%int_5 = OpConstant %int 5",
"%main = OpFunction %void None %4",
"%8 = OpLabel",
"%9 = OpIMul %int %int_0 %int_5",
"OpReturn",
"OpFunctionEnd"
// clang-format on
};
auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
JoinAllInsts(text), /* skip_nop = */ true);
EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result));
}
// Test to make sure we do not replace a multiplication of 5 and 7.
TEST_F(StrengthReductionBasicTest, BasicNoChange) {
const std::vector<const char*> text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Vertex %2 \"main\"",
"OpName %2 \"main\"",
"%3 = OpTypeVoid",
"%4 = OpTypeFunction %3",
"%5 = OpTypeInt 32 1",
"%6 = OpTypeInt 32 0",
"%7 = OpConstant %5 5",
"%8 = OpConstant %5 7",
"%2 = OpFunction %3 None %4",
"%9 = OpLabel",
"%10 = OpIMul %5 %7 %8",
"OpReturn",
"OpFunctionEnd",
// clang-format on
};
auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
JoinAllInsts(text), /* skip_nop = */ true);
EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result));
}
// Test to make sure constants and types are reused and not duplicated.
TEST_F(StrengthReductionBasicTest, NoDuplicateConstantsAndTypes) {
const std::vector<const char*> text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Vertex %main \"main\"",
"OpName %main \"main\"",
"%void = OpTypeVoid",
"%4 = OpTypeFunction %void",
"%uint = OpTypeInt 32 0",
"%uint_8 = OpConstant %uint 8",
"%uint_3 = OpConstant %uint 3",
"%main = OpFunction %void None %4",
"%8 = OpLabel",
"%9 = OpIMul %uint %uint_8 %uint_3",
"OpReturn",
"OpFunctionEnd",
// clang-format on
};
auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
JoinAllInsts(text), /* skip_nop = */ true);
EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result));
const std::string& output = std::get<0>(result);
EXPECT_THAT(output,
Not(MatchesRegex(".*OpConstant %uint 3.*OpConstant %uint 3.*")));
EXPECT_THAT(output, Not(MatchesRegex(".*OpTypeInt 32 0.*OpTypeInt 32 0.*")));
}
// Test to make sure we generate the constants only once
TEST_F(StrengthReductionBasicTest, BasicCreateOneConst) {
const std::vector<const char*> text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Vertex %main \"main\"",
"OpName %main \"main\"",
"%void = OpTypeVoid",
"%4 = OpTypeFunction %void",
"%uint = OpTypeInt 32 0",
"%uint_5 = OpConstant %uint 5",
"%uint_9 = OpConstant %uint 9",
"%uint_128 = OpConstant %uint 128",
"%main = OpFunction %void None %4",
"%8 = OpLabel",
"%9 = OpIMul %uint %uint_5 %uint_128",
"%10 = OpIMul %uint %uint_9 %uint_128",
"OpReturn",
"OpFunctionEnd"
// clang-format on
};
auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
JoinAllInsts(text), /* skip_nop = */ true);
EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result));
const std::string& output = std::get<0>(result);
EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_7"));
EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_9 %uint_7"));
}
// Test to make sure we generate the instructions in the correct position and
// that the uses get replaced as well. Here we check that the use in the return
// is replaced, we also check that we can replace two OpIMuls when one feeds the
// other.
TEST_F(StrengthReductionBasicTest, BasicCheckPositionAndReplacement) {
// This is just the preamble to set up the test.
const std::vector<const char*> common_text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Fragment %main \"main\" %gl_FragColor",
"OpExecutionMode %main OriginUpperLeft",
"OpName %main \"main\"",
"OpName %foo_i1_ \"foo(i1;\"",
"OpName %n \"n\"",
"OpName %gl_FragColor \"gl_FragColor\"",
"OpName %param \"param\"",
"OpDecorate %gl_FragColor Location 0",
"%void = OpTypeVoid",
"%3 = OpTypeFunction %void",
"%int = OpTypeInt 32 1",
"%_ptr_Function_int = OpTypePointer Function %int",
"%8 = OpTypeFunction %int %_ptr_Function_int",
"%int_256 = OpConstant %int 256",
"%int_2 = OpConstant %int 2",
"%float = OpTypeFloat 32",
"%v4float = OpTypeVector %float 4",
"%_ptr_Output_v4float = OpTypePointer Output %v4float",
"%gl_FragColor = OpVariable %_ptr_Output_v4float Output",
"%float_1 = OpConstant %float 1",
"%int_10 = OpConstant %int 10",
"%float_0_4 = OpConstant %float 0.4",
"%float_0_8 = OpConstant %float 0.8",
"%uint = OpTypeInt 32 0",
"%uint_8 = OpConstant %uint 8",
"%uint_1 = OpConstant %uint 1",
"%main = OpFunction %void None %3",
"%5 = OpLabel",
"%param = OpVariable %_ptr_Function_int Function",
"OpStore %param %int_10",
"%26 = OpFunctionCall %int %foo_i1_ %param",
"%27 = OpConvertSToF %float %26",
"%28 = OpFDiv %float %float_1 %27",
"%31 = OpCompositeConstruct %v4float %28 %float_0_4 %float_0_8 %float_1",
"OpStore %gl_FragColor %31",
"OpReturn",
"OpFunctionEnd"
// clang-format on
};
// This is the real test. The two OpIMul should be replaced. The expected
// output is in |foo_after|.
const std::vector<const char*> foo_before = {
// clang-format off
"%foo_i1_ = OpFunction %int None %8",
"%n = OpFunctionParameter %_ptr_Function_int",
"%11 = OpLabel",
"%12 = OpLoad %int %n",
"%14 = OpIMul %int %12 %int_256",
"%16 = OpIMul %int %14 %int_2",
"OpReturnValue %16",
"OpFunctionEnd",
// clang-format on
};
const std::vector<const char*> foo_after = {
// clang-format off
"%foo_i1_ = OpFunction %int None %8",
"%n = OpFunctionParameter %_ptr_Function_int",
"%11 = OpLabel",
"%12 = OpLoad %int %n",
"%33 = OpShiftLeftLogical %int %12 %uint_8",
"%34 = OpShiftLeftLogical %int %33 %uint_1",
"OpReturnValue %34",
"OpFunctionEnd",
// clang-format on
};
SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
SinglePassRunAndCheck<opt::StrengthReductionPass>(
JoinAllInsts(Concat(common_text, foo_before)),
JoinAllInsts(Concat(common_text, foo_after)),
/* skip_nop = */ true, /* do_validate = */ true);
}
// Test that, when the result of an OpIMul instruction has more than 1 use, and
// the instruction is replaced, all of the uses of the result are replaced with
// the new result.
TEST_F(StrengthReductionBasicTest, BasicTestMultipleReplacements) {
// This is just the preamble to set up the test.
const std::vector<const char*> common_text = {
// clang-format off
"OpCapability Shader",
"%1 = OpExtInstImport \"GLSL.std.450\"",
"OpMemoryModel Logical GLSL450",
"OpEntryPoint Fragment %main \"main\" %gl_FragColor",
"OpExecutionMode %main OriginUpperLeft",
"OpName %main \"main\"",
"OpName %foo_i1_ \"foo(i1;\"",
"OpName %n \"n\"",
"OpName %gl_FragColor \"gl_FragColor\"",
"OpName %param \"param\"",
"OpDecorate %gl_FragColor Location 0",
"%void = OpTypeVoid",
"%3 = OpTypeFunction %void",
"%int = OpTypeInt 32 1",
"%_ptr_Function_int = OpTypePointer Function %int",
"%8 = OpTypeFunction %int %_ptr_Function_int",
"%int_256 = OpConstant %int 256",
"%int_2 = OpConstant %int 2",
"%float = OpTypeFloat 32",
"%v4float = OpTypeVector %float 4",
"%_ptr_Output_v4float = OpTypePointer Output %v4float",
"%gl_FragColor = OpVariable %_ptr_Output_v4float Output",
"%float_1 = OpConstant %float 1",
"%int_10 = OpConstant %int 10",
"%float_0_4 = OpConstant %float 0.4",
"%float_0_8 = OpConstant %float 0.8",
"%uint = OpTypeInt 32 0",
"%uint_8 = OpConstant %uint 8",
"%uint_1 = OpConstant %uint 1",
"%main = OpFunction %void None %3",
"%5 = OpLabel",
"%param = OpVariable %_ptr_Function_int Function",
"OpStore %param %int_10",
"%26 = OpFunctionCall %int %foo_i1_ %param",
"%27 = OpConvertSToF %float %26",
"%28 = OpFDiv %float %float_1 %27",
"%31 = OpCompositeConstruct %v4float %28 %float_0_4 %float_0_8 %float_1",
"OpStore %gl_FragColor %31",
"OpReturn",
"OpFunctionEnd"
// clang-format on
};
// This is the real test. The two OpIMul instructions should be replaced. In
// particular, we want to be sure that both uses of %14 are changed to use the
// new result.
const std::vector<const char*> foo_before = {
// clang-format off
"%foo_i1_ = OpFunction %int None %8",
"%n = OpFunctionParameter %_ptr_Function_int",
"%11 = OpLabel",
"%12 = OpLoad %int %n",
"%14 = OpIMul %int %12 %int_256",
"%16 = OpIMul %int %14 %int_2",
"%17 = OpIAdd %int %14 %16",
"OpReturnValue %17",
"OpFunctionEnd",
// clang-format on
};
const std::vector<const char*> foo_after = {
// clang-format off
"%foo_i1_ = OpFunction %int None %8",
"%n = OpFunctionParameter %_ptr_Function_int",
"%11 = OpLabel",
"%12 = OpLoad %int %n",
"%34 = OpShiftLeftLogical %int %12 %uint_8",
"%35 = OpShiftLeftLogical %int %34 %uint_1",
"%17 = OpIAdd %int %34 %35",
"OpReturnValue %17",
"OpFunctionEnd",
// clang-format on
};
SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
SinglePassRunAndCheck<opt::StrengthReductionPass>(
JoinAllInsts(Concat(common_text, foo_before)),
JoinAllInsts(Concat(common_text, foo_after)),
/* skip_nop = */ true, /* do_validate = */ true);
}
} // anonymous namespace

View File

@ -112,6 +112,8 @@ Options:
Join two blocks into a single block if the second has the
first as its only predecessor. Performed only on entry point
call tree functions.
--strength-reduction
Replaces instructions with equivalent and less expensive ones.
-h, --help
Print this help.
--version
@ -200,6 +202,8 @@ int main(int argc, char** argv) {
optimizer.RegisterPass(CreateEliminateDeadConstantPass());
} else if (0 == strcmp(cur_arg, "--fold-spec-const-op-composite")) {
optimizer.RegisterPass(CreateFoldSpecConstantOpAndCompositePass());
} else if (0 == strcmp(cur_arg, "--strength-reduction")) {
optimizer.RegisterPass(CreateStrengthReductionPass());
} else if (0 == strcmp(cur_arg, "--unify-const")) {
optimizer.RegisterPass(CreateUnifyConstantPass());
} else if (0 == strcmp(cur_arg, "--flatten-decorations")) {