Refactor instruction folders (#2815)

* Refactor instruction folders

We want to refactor the instruction folder to allow different sets of
rules to be added to the instruction folder.  We might want different
sets of rules in different circumstances.

We also need a way to add rules for extended instructions.  Changes are
made to the FoldingRules class and ConstFoldingRules class to enable
that.

We added tests to check that we can fold extended instructions using the
new framework.

At the same time, I noticed that there were two tests that did not tests
what they were suppose to.  They could not be easily salvaged. #2813 was
opened to track adding the new tests.
This commit is contained in:
Steven Perron 2019-08-26 18:54:11 -04:00 committed by GitHub
parent 1eb89172a8
commit 15fc19d091
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 219 additions and 59 deletions

View File

@ -809,9 +809,62 @@ ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) {
};
}
ConstantFoldingRule FoldFMix() {
return [](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
assert(inst->opcode() == SpvOpExtInst &&
"Expecting an extended instruction.");
assert(inst->GetSingleWordInOperand(0) ==
context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
"Expecting a GLSLstd450 extended instruction.");
assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix &&
"Expecting and FMix instruction.");
if (!inst->IsFloatingPointFoldingAllowed()) {
return nullptr;
}
// Make sure all FMix operands are constants.
for (uint32_t i = 1; i < 4; i++) {
if (constants[i] == nullptr) {
return nullptr;
}
}
const analysis::Constant* one;
if (constants[1]->type()->AsFloat()->width() == 32) {
one = const_mgr->GetConstant(constants[1]->type(),
utils::FloatProxy<float>(1.0f).GetWords());
} else {
one = const_mgr->GetConstant(constants[1]->type(),
utils::FloatProxy<double>(1.0).GetWords());
}
const analysis::Constant* temp1 =
FOLD_FPARITH_OP(-)(constants[1]->type(), one, constants[3], const_mgr);
if (temp1 == nullptr) {
return nullptr;
}
const analysis::Constant* temp2 = FOLD_FPARITH_OP(*)(
constants[1]->type(), constants[1], temp1, const_mgr);
if (temp2 == nullptr) {
return nullptr;
}
const analysis::Constant* temp3 = FOLD_FPARITH_OP(*)(
constants[2]->type(), constants[2], constants[3], const_mgr);
if (temp3 == nullptr) {
return nullptr;
}
return FOLD_FPARITH_OP(+)(temp2->type(), temp2, temp3, const_mgr);
};
}
} // namespace
ConstantFoldingRules::ConstantFoldingRules() {
void ConstantFoldingRules::AddFoldingRules() {
// Add all folding rules to the list for the opcodes to which they apply.
// Note that the order in which rules are added to the list matters. If a rule
// applies to the instruction, the rest of the rules will not be attempted.
@ -877,6 +930,14 @@ ConstantFoldingRules::ConstantFoldingRules() {
rules_[SpvOpFNegate].push_back(FoldFNegate());
rules_[SpvOpQuantizeToF16].push_back(FoldQuantizeToF16());
// Add rules for GLSLstd450
FeatureManager* feature_manager = context_->get_feature_mgr();
uint32_t ext_inst_glslstd450_id =
feature_manager->GetExtInstImportId_GLSLstd450();
if (ext_inst_glslstd450_id != 0) {
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
}
}
} // namespace opt
} // namespace spvtools

View File

@ -53,24 +53,74 @@ using ConstantFoldingRule = std::function<const analysis::Constant*(
const std::vector<const analysis::Constant*>& constants)>;
class ConstantFoldingRules {
protected:
// The |Key| and |Value| structs are used to by-pass a "decorated name length
// exceeded, name was truncated" warning on VS2013 and VS2015.
struct Key {
uint32_t instruction_set;
uint32_t opcode;
};
friend bool operator<(const Key& a, const Key& b) {
if (a.instruction_set < b.instruction_set) {
return true;
}
if (a.instruction_set > b.instruction_set) {
return false;
}
return a.opcode < b.opcode;
}
struct Value {
std::vector<ConstantFoldingRule> value;
void push_back(ConstantFoldingRule rule) { value.push_back(rule); }
};
public:
ConstantFoldingRules();
ConstantFoldingRules(IRContext* ctx) : context_(ctx) {}
virtual ~ConstantFoldingRules() = default;
// Returns true if there is at least 1 folding rule for |opcode|.
bool HasFoldingRule(SpvOp opcode) const { return rules_.count(opcode); }
bool HasFoldingRule(const Instruction* inst) const {
return !GetRulesForInstruction(inst).empty();
}
// Returns an vector of constant folding rules for |opcode|.
const std::vector<ConstantFoldingRule>& GetRulesForOpcode(
SpvOp opcode) const {
auto it = rules_.find(opcode);
if (it != rules_.end()) {
return it->second;
// Returns true if there is at least 1 folding rule for |inst|.
const std::vector<ConstantFoldingRule>& GetRulesForInstruction(
const Instruction* inst) const {
if (inst->opcode() != SpvOpExtInst) {
auto it = rules_.find(inst->opcode());
if (it != rules_.end()) {
return it->second.value;
}
} else {
uint32_t ext_inst_id = inst->GetSingleWordInOperand(0);
uint32_t ext_opcode = inst->GetSingleWordInOperand(1);
auto it = ext_rules_.find({ext_inst_id, ext_opcode});
if (it != ext_rules_.end()) {
return it->second.value;
}
}
return empty_vector_;
}
// Add the folding rules.
virtual void AddFoldingRules();
protected:
// |rules[opcode]| is the set of rules that can be applied to instructions
// with |opcode| as the opcode.
std::unordered_map<uint32_t, Value> rules_;
// The folding rules for extended instructions.
std::map<Key, Value> ext_rules_;
private:
std::unordered_map<uint32_t, std::vector<ConstantFoldingRule>> rules_;
// The context that the instruction to be folded will be a part of.
IRContext* context_;
// The empty set of rules to be used as the default return value in
// |GetRulesForInstruction|.
std::vector<ConstantFoldingRule> empty_vector_;
};

View File

@ -234,13 +234,12 @@ bool InstructionFolder::FoldInstructionInternal(Instruction* inst) const {
return true;
}
SpvOp opcode = inst->opcode();
analysis::ConstantManager* const_manager = context_->get_constant_mgr();
std::vector<const analysis::Constant*> constants =
const_manager->GetOperandConstants(inst);
for (const FoldingRule& rule : GetFoldingRules().GetRulesForOpcode(opcode)) {
for (const FoldingRule& rule :
GetFoldingRules().GetRulesForInstruction(inst)) {
if (rule(context_, inst, constants)) {
return true;
}
@ -623,7 +622,7 @@ Instruction* InstructionFolder::FoldInstructionToConstant(
analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
if (!inst->IsFoldableByFoldScalar() &&
!GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
!GetConstantFoldingRules().HasFoldingRule(inst)) {
return nullptr;
}
// Collect the values of the constant parameters.
@ -641,19 +640,16 @@ Instruction* InstructionFolder::FoldInstructionToConstant(
}
});
if (GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
const analysis::Constant* folded_const = nullptr;
for (auto rule :
GetConstantFoldingRules().GetRulesForOpcode(inst->opcode())) {
folded_const = rule(context_, inst, constants);
if (folded_const != nullptr) {
Instruction* const_inst =
const_mgr->GetDefiningInstruction(folded_const, inst->type_id());
assert(const_inst->type_id() == inst->type_id());
// May be a new instruction that needs to be analysed.
context_->UpdateDefUse(const_inst);
return const_inst;
}
const analysis::Constant* folded_const = nullptr;
for (auto rule : GetConstantFoldingRules().GetRulesForInstruction(inst)) {
folded_const = rule(context_, inst, constants);
if (folded_const != nullptr) {
Instruction* const_inst =
const_mgr->GetDefiningInstruction(folded_const, inst->type_id());
assert(const_inst->type_id() == inst->type_id());
// May be a new instruction that needs to be analysed.
context_->UpdateDefUse(const_inst);
return const_inst;
}
}

View File

@ -28,7 +28,23 @@ namespace opt {
class InstructionFolder {
public:
explicit InstructionFolder(IRContext* context) : context_(context) {}
explicit InstructionFolder(IRContext* context)
: context_(context),
const_folding_rules_(new ConstantFoldingRules(context)),
folding_rules_(new FoldingRules(context)) {
folding_rules_->AddFoldingRules();
const_folding_rules_->AddFoldingRules();
}
explicit InstructionFolder(
IRContext* context, std::unique_ptr<FoldingRules>&& folding_rules,
std::unique_ptr<ConstantFoldingRules>&& constant_folding_rules)
: context_(context),
const_folding_rules_(std::move(constant_folding_rules)),
folding_rules_(std::move(folding_rules)) {
folding_rules_->AddFoldingRules();
const_folding_rules_->AddFoldingRules();
}
// Returns the result of folding a scalar instruction with the given |opcode|
// and |operands|. Each entry in |operands| is a pointer to an
@ -95,18 +111,18 @@ class InstructionFolder {
bool FoldInstruction(Instruction* inst) const;
// Return true if this opcode has a const folding rule associtated with it.
bool HasConstFoldingRule(SpvOp opcode) const {
return GetConstantFoldingRules().HasFoldingRule(opcode);
bool HasConstFoldingRule(const Instruction* inst) const {
return GetConstantFoldingRules().HasFoldingRule(inst);
}
private:
// Returns a reference to the ConstnatFoldingRules instance.
const ConstantFoldingRules& GetConstantFoldingRules() const {
return const_folding_rules;
return *const_folding_rules_;
}
// Returns a reference to the FoldingRules instance.
const FoldingRules& GetFoldingRules() const { return folding_rules; }
const FoldingRules& GetFoldingRules() const { return *folding_rules_; }
// Returns the single-word result from performing the given unary operation on
// the operand value which is passed in as a 32-bit word.
@ -159,10 +175,10 @@ class InstructionFolder {
IRContext* context_;
// Folding rules used by |FoldInstructionToConstant| and |FoldInstruction|.
ConstantFoldingRules const_folding_rules;
std::unique_ptr<ConstantFoldingRules> const_folding_rules_;
// Folding rules used by |FoldInstruction|.
FoldingRules folding_rules;
std::unique_ptr<FoldingRules> folding_rules_;
};
} // namespace opt

View File

@ -2200,7 +2200,7 @@ FoldingRule RemoveRedundantOperands() {
} // namespace
FoldingRules::FoldingRules() {
void FoldingRules::AddFoldingRules() {
// Add all folding rules to the list for the opcodes to which they apply.
// Note that the order in which rules are added to the list matters. If a rule
// applies to the instruction, the rest of the rules will not be attempted.
@ -2216,8 +2216,6 @@ FoldingRules::FoldingRules() {
rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
rules_[SpvOpExtInst].push_back(RedundantFMix());
rules_[SpvOpFAdd].push_back(RedundantFAdd());
rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
@ -2271,6 +2269,15 @@ FoldingRules::FoldingRules() {
rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
FeatureManager* feature_manager = context_->get_feature_mgr();
// Add rules for GLSLstd450
uint32_t ext_inst_glslstd450_id =
feature_manager->GetExtInstImportId_GLSLstd450();
if (ext_inst_glslstd450_id != 0) {
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
RedundantFMix());
}
}
} // namespace opt
} // namespace spvtools

View File

@ -58,19 +58,58 @@ using FoldingRule = std::function<bool(
class FoldingRules {
public:
FoldingRules();
using FoldingRuleSet = std::vector<FoldingRule>;
const std::vector<FoldingRule>& GetRulesForOpcode(SpvOp opcode) const {
auto it = rules_.find(opcode);
if (it != rules_.end()) {
return it->second;
explicit FoldingRules(IRContext* ctx) : context_(ctx) {}
virtual ~FoldingRules() = default;
const FoldingRuleSet& GetRulesForInstruction(Instruction* inst) const {
if (inst->opcode() != SpvOpExtInst) {
auto it = rules_.find(inst->opcode());
if (it != rules_.end()) {
return it->second;
}
} else {
uint32_t ext_inst_id = inst->GetSingleWordInOperand(0);
uint32_t ext_opcode = inst->GetSingleWordInOperand(1);
auto it = ext_rules_.find({ext_inst_id, ext_opcode});
if (it != ext_rules_.end()) {
return it->second;
}
}
return empty_vector_;
}
IRContext* context() { return context_; }
// Adds the folding rules for the object.
virtual void AddFoldingRules();
protected:
// The folding rules for core instructions.
std::unordered_map<uint32_t, FoldingRuleSet> rules_;
// The folding rules for extended instructions.
struct Key {
uint32_t instruction_set;
uint32_t opcode;
};
friend bool operator<(const Key& a, const Key& b) {
if (a.instruction_set < b.instruction_set) {
return true;
}
if (a.instruction_set > b.instruction_set) {
return false;
}
return a.opcode < b.opcode;
}
std::map<Key, FoldingRuleSet> ext_rules_;
private:
std::unordered_map<uint32_t, std::vector<FoldingRule>> rules_;
std::vector<FoldingRule> empty_vector_;
IRContext* context_;
FoldingRuleSet empty_vector_;
};
} // namespace opt

View File

@ -469,7 +469,7 @@ bool Instruction::IsOpaqueType() const {
bool Instruction::IsFoldable() const {
return IsFoldableByFoldScalar() ||
context()->get_instruction_folder().HasConstFoldingRule(opcode());
context()->get_instruction_folder().HasConstFoldingRule(this);
}
bool Instruction::IsFoldableByFoldScalar() const {

View File

@ -49,7 +49,7 @@ bool SimplificationPass::SimplifyFunction(Function* function) {
cfg()->ForEachBlockInReversePostOrder(
function->entry().get(),
[&modified, &process_phis, &work_list, &in_work_list, &inst_to_kill,
folder, this](BasicBlock* bb) {
&folder, this](BasicBlock* bb) {
for (Instruction* inst = &*bb->begin(); inst; inst = inst->NextNode()) {
if (inst->opcode() == SpvOpPhi) {
process_phis.insert(inst);

View File

@ -210,6 +210,7 @@ OpName %main "main"
%float_2049 = OpConstant %float 2049
%float_n2049 = OpConstant %float -2049
%float_0p5 = OpConstant %float 0.5
%float_0p2 = OpConstant %float 0.2
%float_pi = OpConstant %float 1.5555
%float_1e16 = OpConstant %float 1e16
%float_n1e16 = OpConstant %float -1e16
@ -1465,24 +1466,14 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, std::numeric_limits<float>::quiet_NaN()),
// Test case 20: QuantizeToF16 inf
// Test case 20: FMix 1.0 4.0 0.2
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFDiv %float %float_1 %float_0\n" +
"%3 = OpQuantizeToF16 %float %3\n" +
"%2 = OpExtInst %float %1 FMix %float_1 %float_4 %float_0p2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, std::numeric_limits<float>::infinity()),
// Test case 21: QuantizeToF16 -inf
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFDiv %float %float_n1 %float_0\n" +
"%3 = OpQuantizeToF16 %float %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, -std::numeric_limits<float>::infinity())
2, 1.6f)
));
// clang-format on