Add folding of OpCompositeExtract and OpConstantComposite constant instructions.

Create files for constant folding rules.

Add the rules for OpConstantComposite and OpCompositeExtract.
This commit is contained in:
Steven Perron 2018-02-08 10:59:03 -05:00 committed by David Neto
parent 886859159e
commit 1d7b1423f9
13 changed files with 292 additions and 34 deletions

View File

@ -63,9 +63,10 @@ SPVTOOLS_OPT_SRC_FILES := \
source/opt/cfg.cpp \
source/opt/cfg_cleanup_pass.cpp \
source/opt/ccp_pass.cpp \
source/opt/common_uniform_elim_pass.cpp \
source/opt/compact_ids_pass.cpp \
source/opt/composite.cpp \
source/opt/common_uniform_elim_pass.cpp \
source/opt/const_folding_rules.cpp \
source/opt/constants.cpp \
source/opt/dead_branch_elim_pass.cpp \
source/opt/dead_insert_elim_pass.cpp \

View File

@ -22,6 +22,7 @@ add_library(SPIRV-Tools-opt
common_uniform_elim_pass.h
compact_ids_pass.h
composite.h
const_folding_rules.h
constants.h
dead_branch_elim_pass.h
dead_insert_elim_pass.h
@ -94,6 +95,7 @@ add_library(SPIRV-Tools-opt
common_uniform_elim_pass.cpp
compact_ids_pass.cpp
composite.cpp
const_folding_rules.cpp
constants.cpp
dead_branch_elim_pass.cpp
dead_insert_elim_pass.cpp

View File

@ -0,0 +1,98 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "const_folding_rules.h"
namespace spvtools {
namespace opt {
namespace {
const uint32_t kExtractCompositeIdInIdx = 0;
ConstantFoldingRule FoldExtractWithConstants() {
// Folds an OpcompositeExtract where input is a composite constant.
return [](ir::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
if (c == nullptr) {
return nullptr;
}
for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
uint32_t element_index = inst->GetSingleWordInOperand(i);
if (c->AsNullConstant()) {
// Return Null for the return type.
ir::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
const analysis::NullConstant null_const(
type_mgr->GetType(inst->type_id()));
const analysis::Constant* real_const =
const_mgr->FindConstant(&null_const);
if (real_const == nullptr) {
ir::Instruction* const_inst =
const_mgr->GetDefiningInstruction(&null_const);
real_const = const_mgr->GetConstantFromInst(const_inst);
}
return real_const;
}
auto cc = c->AsCompositeConstant();
assert(cc != nullptr);
auto components = cc->GetComponents();
c = components[element_index];
}
return c;
};
}
ConstantFoldingRule FoldCompositeWithConstants() {
// Folds an OpCompositeConstruct where all of the inputs are constants to a
// constant. A new constant is created if necessary.
return [](ir::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
ir::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
std::vector<uint32_t> ids;
for (const analysis::Constant* element_const : constants) {
if (element_const == nullptr) {
return nullptr;
}
uint32_t element_id = const_mgr->FindDeclaredConstant(element_const);
if (element_id == 0) {
return nullptr;
}
ids.push_back(element_id);
}
return const_mgr->GetConstant(new_type, ids);
};
}
} // namespace
spvtools::opt::ConstantFoldingRules::ConstantFoldingRules() {
// Add all folding rules to the list for the opcodes to which they apply.
// Note that the order in which rules are added to the list matters. If a rule
// applies to the instruction, the rest of the rules will not be attempted.
// Take that into consideration.
rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
}
} // namespace opt
} // namespace spvtools

View File

@ -0,0 +1,84 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef LIBSPIRV_OPT_CONST_FOLDING_RULES_H_
#define LIBSPIRV_OPT_CONST_FOLDING_RULES_H_
#include <vector>
#include "../../external/spirv-headers/include/spirv/1.2/spirv.h"
#include "constants.h"
#include "def_use_manager.h"
#include "folding_rules.h"
#include "ir_builder.h"
#include "ir_context.h"
namespace spvtools {
namespace opt {
// Constant Folding Rules:
//
// The folding mechanism is built around the concept of a |ConstantFoldingRule|.
// A constant folding rule is a function that implements a method of simplifying
// an instruction to a constant.
//
// The inputs to a folding rule are:
// |inst| - the instruction to be simplified.
// |constants| - if an in-operands is an id of a constant, then the
// corresponding value in |constants| contains that
// constant value. Otherwise, the corresponding entry in
// |constants| is |nullptr|.
//
// A constant folding rule returns a pointer to an Constant if |inst| can be
// simplified using this rule. Otherwise, it returns |nullptr|.
//
// See const_folding_rules.cpp for examples on how to write a constant folding
// rule.
//
// Be sure to add new constant folding rules to the table of constant folding
// rules in the constructor for ConstantFoldingRules. The new rule should be
// added to the list for every opcode that it applies to. Note that earlier
// rules in the list are given priority. That is, if an earlier rule is able to
// fold an instruction, the later rules will not be attempted.
using ConstantFoldingRule = std::function<const analysis::Constant*(
ir::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)>;
class ConstantFoldingRules {
public:
ConstantFoldingRules();
// Returns true if there is at least 1 folding rule for |opcode|.
bool HasFoldingRule(SpvOp opcode) const { return rules_.count(opcode); }
// Returns an vector of constant folding rules for |opcode|.
const std::vector<ConstantFoldingRule>& GetRulesForOpcode(
SpvOp opcode) const {
auto it = rules_.find(opcode);
if (it != rules_.end()) {
return it->second;
}
return empty_vector_;
}
private:
std::unordered_map<uint32_t, std::vector<ConstantFoldingRule>> rules_;
std::vector<ConstantFoldingRule> empty_vector_;
};
} // namespace opt
} // namespace spvtools
#endif // LIBSPIRV_OPT_CONST_FOLDING_RULES_H_

View File

@ -204,7 +204,7 @@ class CompositeConstant : public Constant {
CompositeConstant* AsCompositeConstant() override { return this; }
const CompositeConstant* AsCompositeConstant() const override { return this; }
// Returns a const reference of the components holded in this composite
// Returns a const reference of the components held in this composite
// constant.
virtual const std::vector<const Constant*>& GetComponents() const {
return components_;

View File

@ -71,6 +71,18 @@ void DefUseManager::AnalyzeInstDefUse(ir::Instruction* inst) {
AnalyzeInstUse(inst);
}
void DefUseManager::UpdateDefUse(ir::Instruction* inst) {
const uint32_t def_id = inst->result_id();
if (def_id != 0) {
auto iter = id_to_def_.find(def_id);
if (iter != id_to_def_.end()) {
AnalyzeInstDef(inst);
} else {
}
}
AnalyzeInstUse(inst);
}
ir::Instruction* DefUseManager::GetDef(uint32_t id) {
auto iter = id_to_def_.find(id);
if (iter == id_to_def_.end()) return nullptr;

View File

@ -214,6 +214,10 @@ class DefUseManager {
return !(lhs == rhs);
}
// If |inst| has not already been analysed, then analyses its defintion and
// uses.
void UpdateDefUse(ir::Instruction* inst);
private:
using InstToUsedIdsMap =
std::unordered_map<const ir::Instruction*, std::vector<uint32_t>>;

View File

@ -14,15 +14,16 @@
#include "fold.h"
#include <cassert>
#include <cstdint>
#include <vector>
#include "const_folding_rules.h"
#include "def_use_manager.h"
#include "folding_rules.h"
#include "ir_builder.h"
#include "ir_context.h"
#include <cassert>
#include <cstdint>
#include <vector>
namespace spvtools {
namespace opt {
@ -40,6 +41,11 @@ namespace {
#define UINT32_MAX 0xffffffff /* 4294967295U */
#endif
const ConstantFoldingRules& GetConstantFoldingRules() {
static ConstantFoldingRules* rules = new ConstantFoldingRules();
return *rules;
}
// Returns the single-word result from performing the given unary operation on
// the operand value which is passed in as a 32-bit word.
uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) {
@ -603,10 +609,6 @@ bool IsFoldableConstant(const analysis::Constant* cst) {
ir::Instruction* FoldInstructionToConstant(
ir::Instruction* inst, std::function<uint32_t(uint32_t)> id_map) {
if (!inst->IsFoldable()) {
return nullptr;
}
ir::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
@ -617,7 +619,7 @@ ir::Instruction* FoldInstructionToConstant(
&id_map](uint32_t* op_id) {
uint32_t id = id_map(*op_id);
const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id);
if (!const_op || !IsFoldableConstant(const_op)) {
if (!const_op) {
constants.push_back(nullptr);
missing_constants = true;
return;
@ -625,15 +627,30 @@ ir::Instruction* FoldInstructionToConstant(
constants.push_back(const_op);
});
if (GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
const analysis::Constant* folded_const = nullptr;
for (auto rule :
GetConstantFoldingRules().GetRulesForOpcode(inst->opcode())) {
folded_const = rule(inst, constants);
if (folded_const != nullptr) {
ir::Instruction* const_inst =
const_mgr->GetDefiningInstruction(folded_const);
// May be a new instruction that needs to be analysed.
context->UpdateDefUse(const_inst);
return const_inst;
}
}
}
uint32_t result_val = 0;
bool successful = false;
// If all parameters are constant, fold the instruction to a constant.
if (!missing_constants) {
if (!missing_constants && inst->IsFoldable()) {
result_val = FoldScalars(inst->opcode(), constants);
successful = true;
}
if (!successful) {
if (!successful && inst->IsFoldable()) {
successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
}

View File

@ -69,8 +69,8 @@ FoldingRule CompositeConstructFeedingExtract() {
// Add the remaining indices for extraction.
for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
operands.push_back(
{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(i)}});
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
{inst->GetSingleWordInOperand(i)}});
}
} else {
@ -302,14 +302,14 @@ spvtools::opt::FoldingRules::FoldingRules() {
// applies to the instruction, the rest of the rules will not be attempted.
// Take that into consideration.
rules[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct());
rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct());
rules[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
rules[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
rules[SpvOpIMul].push_back(IntMultipleBy1());
rules_[SpvOpIMul].push_back(IntMultipleBy1());
rules[SpvOpPhi].push_back(RedundantPhi());
rules_[SpvOpPhi].push_back(RedundantPhi());
}
} // namespace opt
} // namespace spvtools

View File

@ -63,15 +63,15 @@ class FoldingRules {
FoldingRules();
const std::vector<FoldingRule>& GetRulesForOpcode(SpvOp opcode) {
auto it = rules.find(opcode);
if (it != rules.end()) {
auto it = rules_.find(opcode);
if (it != rules_.end()) {
return it->second;
}
return empty_vector_;
}
private:
std::unordered_map<uint32_t, std::vector<FoldingRule>> rules;
std::unordered_map<uint32_t, std::vector<FoldingRule>> rules_;
std::vector<FoldingRule> empty_vector_;
};

View File

@ -399,6 +399,10 @@ class IRContext {
// Returns the grammar for this context.
const libspirv::AssemblyGrammar& grammar() const { return grammar_; }
// If |inst| has not yet been analysed by the def-use manager, then analyse
// its definitions and uses.
inline void UpdateDefUse(Instruction* inst);
private:
// Builds the def-use manager from scratch, even if it was already valid.
void BuildDefUseManager() {
@ -723,6 +727,12 @@ void IRContext::AnalyzeDefUse(Instruction* inst) {
}
}
void IRContext::UpdateDefUse(Instruction* inst) {
if (AreAnalysesValid(kAnalysisDefUse)) {
get_def_use_mgr()->UpdateDefUse(inst);
}
}
} // namespace ir
} // namespace spvtools
#endif // SPIRV_TOOLS_IR_CONTEXT_H

View File

@ -78,6 +78,8 @@ TEST_P(IntegerInstructionFoldingTest, Case) {
// Returns a common SPIR-V header for all of the test that follow.
#define INT_0_ID 100
#define TRUE_ID 101
#define VEC2_0_ID 102
#define INT_7_ID 103
const std::string& Header() {
static const std::string header = R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
@ -89,8 +91,8 @@ OpName %main "main"
%void = OpTypeVoid
%void_func = OpTypeFunction %void
%bool = OpTypeBool
%true = OpConstantTrue %bool
%101 = OpConstantTrue %bool ; Need a def with an numerical id to define id maps.
%true = OpConstantTrue %bool
%false = OpConstantFalse %bool
%short = OpTypeInt 16 1
%int = OpTypeInt 32 1
@ -104,9 +106,10 @@ OpName %main "main"
%_ptr_bool = OpTypePointer Function %bool
%short_0 = OpConstant %short 0
%short_3 = OpConstant %short 3
%100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps.
%103 = OpConstant %int 7 ; Need a def with an numerical id to define id maps.
%int_0 = OpConstant %int 0
%int_1 = OpConstant %int 1
%100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps.
%int_3 = OpConstant %int 3
%int_min = OpConstant %int -2147483648
%int_max = OpConstant %int 2147483647
@ -116,8 +119,11 @@ OpName %main "main"
%uint_3 = OpConstant %uint 3
%uint_32 = OpConstant %uint 32
%uint_max = OpConstant %uint -1
%v2int_undef = OpUndef %v2int
%struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int
%102 = OpConstantComposite %v2int %103 %103
%v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
%struct_undef_0_0 = OpConstantComposite %struct_v2int_int_int %v2int_undef %int_0 %int_0
)";
return header;
@ -1227,7 +1233,23 @@ INSTANTIATE_TEST_CASE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingTe
"%5 = OpCompositeExtract %int %4 0 1\n" +
"OpReturn\n" +
"OpFunctionEnd",
5, 2)
5, 2),
// Test case 7: fold constant extract.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpCompositeExtract %int %102 1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, INT_7_ID),
// Test case 8: constant struct has OpUndef
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpCompositeExtract %int %struct_undef_0_0 0 1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0)
));
INSTANTIATE_TEST_CASE_P(CompositeConstructFoldingTest, GeneralInstructionFoldingTest,
@ -1282,7 +1304,15 @@ INSTANTIATE_TEST_CASE_P(CompositeConstructFoldingTest, GeneralInstructionFolding
"%7 = OpCompositeConstruct %v4int %3 %4 %5\n" +
"OpReturn\n" +
"OpFunctionEnd",
7, 0)
7, 0),
// Test case 4: Fold construct with constants to constant.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpCompositeConstruct %v2int %103 %103\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, VEC2_0_ID)
));
INSTANTIATE_TEST_CASE_P(PhiFoldingTest, GeneralInstructionFoldingTest,

View File

@ -92,8 +92,6 @@ TEST_F(SimplificationTest, AcrossBasicBlocks) {
%int = OpTypeInt 32 1
%v4int = OpTypeVector %int 4
%int_0 = OpConstant %int 0
; CHECK: [[constant:%[a-zA-Z_\d]+]] = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
%13 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
%_ptr_Input_v4int = OpTypePointer Input %v4int
%i = OpVariable %_ptr_Input_v4int Input
%uint = OpTypeInt 32 0
@ -115,14 +113,14 @@ TEST_F(SimplificationTest, AcrossBasicBlocks) {
OpSelectionMerge %30 None
OpBranchConditional %29 %31 %32
%31 = OpLabel
%43 = OpCopyObject %v4int %13
%43 = OpCopyObject %v4int %25
OpBranch %30
%32 = OpLabel
%45 = OpCopyObject %v4int %13
%45 = OpCopyObject %v4int %25
OpBranch %30
%30 = OpLabel
%50 = OpPhi %v4int %43 %31 %45 %32
; CHECK: [[extract1:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[constant]] 0
; CHECK: [[extract1:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[load]] 0
%47 = OpCompositeExtract %int %50 0
; CHECK: [[extract2:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[load]] 1
%49 = OpCompositeExtract %int %41 1
@ -170,9 +168,11 @@ TEST_F(SimplificationTest, ThroughLoops) {
%68 = OpUndef %v4int
%main = OpFunction %void None %8
%23 = OpLabel
; CHECK: [[load:%[a-zA-Z_\d]+]] = OpLoad %v4int %i
%load = OpLoad %v4int %i
OpBranch %24
%24 = OpLabel
%67 = OpPhi %v4int %13 %23 %64 %26
%67 = OpPhi %v4int %load %23 %64 %26
; CHECK: OpLoopMerge [[merge_lab:%[a-zA-Z_\d]+]]
OpLoopMerge %25 %26 None
OpBranch %27
@ -191,7 +191,7 @@ TEST_F(SimplificationTest, ThroughLoops) {
OpBranch %24
%25 = OpLabel
; CHECK: [[merge_lab]] = OpLabel
; CHECK: [[extract:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[constant]] 0
; CHECK: [[extract:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[load]] 0
%66 = OpCompositeExtract %int %67 0
; CHECK-NEXT: OpStore %o [[extract]]
OpStore %o %66