mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-26 13:20:05 +00:00
Refactor instruction folders (#2815)
* Refactor instruction folders We want to refactor the instruction folder to allow different sets of rules to be added to the instruction folder. We might want different sets of rules in different circumstances. We also need a way to add rules for extended instructions. Changes are made to the FoldingRules class and ConstFoldingRules class to enable that. We added tests to check that we can fold extended instructions using the new framework. At the same time, I noticed that there were two tests that did not tests what they were suppose to. They could not be easily salvaged. #2813 was opened to track adding the new tests.
This commit is contained in:
parent
1eb89172a8
commit
15fc19d091
@ -809,9 +809,62 @@ ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) {
|
||||
};
|
||||
}
|
||||
|
||||
ConstantFoldingRule FoldFMix() {
|
||||
return [](IRContext* context, Instruction* inst,
|
||||
const std::vector<const analysis::Constant*>& constants)
|
||||
-> const analysis::Constant* {
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
assert(inst->opcode() == SpvOpExtInst &&
|
||||
"Expecting an extended instruction.");
|
||||
assert(inst->GetSingleWordInOperand(0) ==
|
||||
context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
|
||||
"Expecting a GLSLstd450 extended instruction.");
|
||||
assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix &&
|
||||
"Expecting and FMix instruction.");
|
||||
|
||||
if (!inst->IsFloatingPointFoldingAllowed()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Make sure all FMix operands are constants.
|
||||
for (uint32_t i = 1; i < 4; i++) {
|
||||
if (constants[i] == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
const analysis::Constant* one;
|
||||
if (constants[1]->type()->AsFloat()->width() == 32) {
|
||||
one = const_mgr->GetConstant(constants[1]->type(),
|
||||
utils::FloatProxy<float>(1.0f).GetWords());
|
||||
} else {
|
||||
one = const_mgr->GetConstant(constants[1]->type(),
|
||||
utils::FloatProxy<double>(1.0).GetWords());
|
||||
}
|
||||
|
||||
const analysis::Constant* temp1 =
|
||||
FOLD_FPARITH_OP(-)(constants[1]->type(), one, constants[3], const_mgr);
|
||||
if (temp1 == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const analysis::Constant* temp2 = FOLD_FPARITH_OP(*)(
|
||||
constants[1]->type(), constants[1], temp1, const_mgr);
|
||||
if (temp2 == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
const analysis::Constant* temp3 = FOLD_FPARITH_OP(*)(
|
||||
constants[2]->type(), constants[2], constants[3], const_mgr);
|
||||
if (temp3 == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return FOLD_FPARITH_OP(+)(temp2->type(), temp2, temp3, const_mgr);
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ConstantFoldingRules::ConstantFoldingRules() {
|
||||
void ConstantFoldingRules::AddFoldingRules() {
|
||||
// Add all folding rules to the list for the opcodes to which they apply.
|
||||
// Note that the order in which rules are added to the list matters. If a rule
|
||||
// applies to the instruction, the rest of the rules will not be attempted.
|
||||
@ -877,6 +930,14 @@ ConstantFoldingRules::ConstantFoldingRules() {
|
||||
|
||||
rules_[SpvOpFNegate].push_back(FoldFNegate());
|
||||
rules_[SpvOpQuantizeToF16].push_back(FoldQuantizeToF16());
|
||||
|
||||
// Add rules for GLSLstd450
|
||||
FeatureManager* feature_manager = context_->get_feature_mgr();
|
||||
uint32_t ext_inst_glslstd450_id =
|
||||
feature_manager->GetExtInstImportId_GLSLstd450();
|
||||
if (ext_inst_glslstd450_id != 0) {
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
|
||||
}
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
||||
|
@ -53,24 +53,74 @@ using ConstantFoldingRule = std::function<const analysis::Constant*(
|
||||
const std::vector<const analysis::Constant*>& constants)>;
|
||||
|
||||
class ConstantFoldingRules {
|
||||
protected:
|
||||
// The |Key| and |Value| structs are used to by-pass a "decorated name length
|
||||
// exceeded, name was truncated" warning on VS2013 and VS2015.
|
||||
struct Key {
|
||||
uint32_t instruction_set;
|
||||
uint32_t opcode;
|
||||
};
|
||||
|
||||
friend bool operator<(const Key& a, const Key& b) {
|
||||
if (a.instruction_set < b.instruction_set) {
|
||||
return true;
|
||||
}
|
||||
if (a.instruction_set > b.instruction_set) {
|
||||
return false;
|
||||
}
|
||||
return a.opcode < b.opcode;
|
||||
}
|
||||
|
||||
struct Value {
|
||||
std::vector<ConstantFoldingRule> value;
|
||||
void push_back(ConstantFoldingRule rule) { value.push_back(rule); }
|
||||
};
|
||||
|
||||
public:
|
||||
ConstantFoldingRules();
|
||||
ConstantFoldingRules(IRContext* ctx) : context_(ctx) {}
|
||||
virtual ~ConstantFoldingRules() = default;
|
||||
|
||||
// Returns true if there is at least 1 folding rule for |opcode|.
|
||||
bool HasFoldingRule(SpvOp opcode) const { return rules_.count(opcode); }
|
||||
bool HasFoldingRule(const Instruction* inst) const {
|
||||
return !GetRulesForInstruction(inst).empty();
|
||||
}
|
||||
|
||||
// Returns an vector of constant folding rules for |opcode|.
|
||||
const std::vector<ConstantFoldingRule>& GetRulesForOpcode(
|
||||
SpvOp opcode) const {
|
||||
auto it = rules_.find(opcode);
|
||||
if (it != rules_.end()) {
|
||||
return it->second;
|
||||
// Returns true if there is at least 1 folding rule for |inst|.
|
||||
const std::vector<ConstantFoldingRule>& GetRulesForInstruction(
|
||||
const Instruction* inst) const {
|
||||
if (inst->opcode() != SpvOpExtInst) {
|
||||
auto it = rules_.find(inst->opcode());
|
||||
if (it != rules_.end()) {
|
||||
return it->second.value;
|
||||
}
|
||||
} else {
|
||||
uint32_t ext_inst_id = inst->GetSingleWordInOperand(0);
|
||||
uint32_t ext_opcode = inst->GetSingleWordInOperand(1);
|
||||
auto it = ext_rules_.find({ext_inst_id, ext_opcode});
|
||||
if (it != ext_rules_.end()) {
|
||||
return it->second.value;
|
||||
}
|
||||
}
|
||||
return empty_vector_;
|
||||
}
|
||||
|
||||
// Add the folding rules.
|
||||
virtual void AddFoldingRules();
|
||||
|
||||
protected:
|
||||
// |rules[opcode]| is the set of rules that can be applied to instructions
|
||||
// with |opcode| as the opcode.
|
||||
std::unordered_map<uint32_t, Value> rules_;
|
||||
|
||||
// The folding rules for extended instructions.
|
||||
std::map<Key, Value> ext_rules_;
|
||||
|
||||
private:
|
||||
std::unordered_map<uint32_t, std::vector<ConstantFoldingRule>> rules_;
|
||||
// The context that the instruction to be folded will be a part of.
|
||||
IRContext* context_;
|
||||
|
||||
// The empty set of rules to be used as the default return value in
|
||||
// |GetRulesForInstruction|.
|
||||
std::vector<ConstantFoldingRule> empty_vector_;
|
||||
};
|
||||
|
||||
|
@ -234,13 +234,12 @@ bool InstructionFolder::FoldInstructionInternal(Instruction* inst) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
SpvOp opcode = inst->opcode();
|
||||
analysis::ConstantManager* const_manager = context_->get_constant_mgr();
|
||||
|
||||
std::vector<const analysis::Constant*> constants =
|
||||
const_manager->GetOperandConstants(inst);
|
||||
|
||||
for (const FoldingRule& rule : GetFoldingRules().GetRulesForOpcode(opcode)) {
|
||||
for (const FoldingRule& rule :
|
||||
GetFoldingRules().GetRulesForInstruction(inst)) {
|
||||
if (rule(context_, inst, constants)) {
|
||||
return true;
|
||||
}
|
||||
@ -623,7 +622,7 @@ Instruction* InstructionFolder::FoldInstructionToConstant(
|
||||
analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
|
||||
|
||||
if (!inst->IsFoldableByFoldScalar() &&
|
||||
!GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
|
||||
!GetConstantFoldingRules().HasFoldingRule(inst)) {
|
||||
return nullptr;
|
||||
}
|
||||
// Collect the values of the constant parameters.
|
||||
@ -641,19 +640,16 @@ Instruction* InstructionFolder::FoldInstructionToConstant(
|
||||
}
|
||||
});
|
||||
|
||||
if (GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
|
||||
const analysis::Constant* folded_const = nullptr;
|
||||
for (auto rule :
|
||||
GetConstantFoldingRules().GetRulesForOpcode(inst->opcode())) {
|
||||
folded_const = rule(context_, inst, constants);
|
||||
if (folded_const != nullptr) {
|
||||
Instruction* const_inst =
|
||||
const_mgr->GetDefiningInstruction(folded_const, inst->type_id());
|
||||
assert(const_inst->type_id() == inst->type_id());
|
||||
// May be a new instruction that needs to be analysed.
|
||||
context_->UpdateDefUse(const_inst);
|
||||
return const_inst;
|
||||
}
|
||||
const analysis::Constant* folded_const = nullptr;
|
||||
for (auto rule : GetConstantFoldingRules().GetRulesForInstruction(inst)) {
|
||||
folded_const = rule(context_, inst, constants);
|
||||
if (folded_const != nullptr) {
|
||||
Instruction* const_inst =
|
||||
const_mgr->GetDefiningInstruction(folded_const, inst->type_id());
|
||||
assert(const_inst->type_id() == inst->type_id());
|
||||
// May be a new instruction that needs to be analysed.
|
||||
context_->UpdateDefUse(const_inst);
|
||||
return const_inst;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -28,7 +28,23 @@ namespace opt {
|
||||
|
||||
class InstructionFolder {
|
||||
public:
|
||||
explicit InstructionFolder(IRContext* context) : context_(context) {}
|
||||
explicit InstructionFolder(IRContext* context)
|
||||
: context_(context),
|
||||
const_folding_rules_(new ConstantFoldingRules(context)),
|
||||
folding_rules_(new FoldingRules(context)) {
|
||||
folding_rules_->AddFoldingRules();
|
||||
const_folding_rules_->AddFoldingRules();
|
||||
}
|
||||
|
||||
explicit InstructionFolder(
|
||||
IRContext* context, std::unique_ptr<FoldingRules>&& folding_rules,
|
||||
std::unique_ptr<ConstantFoldingRules>&& constant_folding_rules)
|
||||
: context_(context),
|
||||
const_folding_rules_(std::move(constant_folding_rules)),
|
||||
folding_rules_(std::move(folding_rules)) {
|
||||
folding_rules_->AddFoldingRules();
|
||||
const_folding_rules_->AddFoldingRules();
|
||||
}
|
||||
|
||||
// Returns the result of folding a scalar instruction with the given |opcode|
|
||||
// and |operands|. Each entry in |operands| is a pointer to an
|
||||
@ -95,18 +111,18 @@ class InstructionFolder {
|
||||
bool FoldInstruction(Instruction* inst) const;
|
||||
|
||||
// Return true if this opcode has a const folding rule associtated with it.
|
||||
bool HasConstFoldingRule(SpvOp opcode) const {
|
||||
return GetConstantFoldingRules().HasFoldingRule(opcode);
|
||||
bool HasConstFoldingRule(const Instruction* inst) const {
|
||||
return GetConstantFoldingRules().HasFoldingRule(inst);
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns a reference to the ConstnatFoldingRules instance.
|
||||
const ConstantFoldingRules& GetConstantFoldingRules() const {
|
||||
return const_folding_rules;
|
||||
return *const_folding_rules_;
|
||||
}
|
||||
|
||||
// Returns a reference to the FoldingRules instance.
|
||||
const FoldingRules& GetFoldingRules() const { return folding_rules; }
|
||||
const FoldingRules& GetFoldingRules() const { return *folding_rules_; }
|
||||
|
||||
// Returns the single-word result from performing the given unary operation on
|
||||
// the operand value which is passed in as a 32-bit word.
|
||||
@ -159,10 +175,10 @@ class InstructionFolder {
|
||||
IRContext* context_;
|
||||
|
||||
// Folding rules used by |FoldInstructionToConstant| and |FoldInstruction|.
|
||||
ConstantFoldingRules const_folding_rules;
|
||||
std::unique_ptr<ConstantFoldingRules> const_folding_rules_;
|
||||
|
||||
// Folding rules used by |FoldInstruction|.
|
||||
FoldingRules folding_rules;
|
||||
std::unique_ptr<FoldingRules> folding_rules_;
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
|
@ -2200,7 +2200,7 @@ FoldingRule RemoveRedundantOperands() {
|
||||
|
||||
} // namespace
|
||||
|
||||
FoldingRules::FoldingRules() {
|
||||
void FoldingRules::AddFoldingRules() {
|
||||
// Add all folding rules to the list for the opcodes to which they apply.
|
||||
// Note that the order in which rules are added to the list matters. If a rule
|
||||
// applies to the instruction, the rest of the rules will not be attempted.
|
||||
@ -2216,8 +2216,6 @@ FoldingRules::FoldingRules() {
|
||||
|
||||
rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
|
||||
|
||||
rules_[SpvOpExtInst].push_back(RedundantFMix());
|
||||
|
||||
rules_[SpvOpFAdd].push_back(RedundantFAdd());
|
||||
rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
|
||||
rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
|
||||
@ -2271,6 +2269,15 @@ FoldingRules::FoldingRules() {
|
||||
rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
|
||||
|
||||
rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
|
||||
|
||||
FeatureManager* feature_manager = context_->get_feature_mgr();
|
||||
// Add rules for GLSLstd450
|
||||
uint32_t ext_inst_glslstd450_id =
|
||||
feature_manager->GetExtInstImportId_GLSLstd450();
|
||||
if (ext_inst_glslstd450_id != 0) {
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
|
||||
RedundantFMix());
|
||||
}
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
||||
|
@ -58,19 +58,58 @@ using FoldingRule = std::function<bool(
|
||||
|
||||
class FoldingRules {
|
||||
public:
|
||||
FoldingRules();
|
||||
using FoldingRuleSet = std::vector<FoldingRule>;
|
||||
|
||||
const std::vector<FoldingRule>& GetRulesForOpcode(SpvOp opcode) const {
|
||||
auto it = rules_.find(opcode);
|
||||
if (it != rules_.end()) {
|
||||
return it->second;
|
||||
explicit FoldingRules(IRContext* ctx) : context_(ctx) {}
|
||||
virtual ~FoldingRules() = default;
|
||||
|
||||
const FoldingRuleSet& GetRulesForInstruction(Instruction* inst) const {
|
||||
if (inst->opcode() != SpvOpExtInst) {
|
||||
auto it = rules_.find(inst->opcode());
|
||||
if (it != rules_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
} else {
|
||||
uint32_t ext_inst_id = inst->GetSingleWordInOperand(0);
|
||||
uint32_t ext_opcode = inst->GetSingleWordInOperand(1);
|
||||
auto it = ext_rules_.find({ext_inst_id, ext_opcode});
|
||||
if (it != ext_rules_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
return empty_vector_;
|
||||
}
|
||||
|
||||
IRContext* context() { return context_; }
|
||||
|
||||
// Adds the folding rules for the object.
|
||||
virtual void AddFoldingRules();
|
||||
|
||||
protected:
|
||||
// The folding rules for core instructions.
|
||||
std::unordered_map<uint32_t, FoldingRuleSet> rules_;
|
||||
|
||||
// The folding rules for extended instructions.
|
||||
struct Key {
|
||||
uint32_t instruction_set;
|
||||
uint32_t opcode;
|
||||
};
|
||||
|
||||
friend bool operator<(const Key& a, const Key& b) {
|
||||
if (a.instruction_set < b.instruction_set) {
|
||||
return true;
|
||||
}
|
||||
if (a.instruction_set > b.instruction_set) {
|
||||
return false;
|
||||
}
|
||||
return a.opcode < b.opcode;
|
||||
}
|
||||
|
||||
std::map<Key, FoldingRuleSet> ext_rules_;
|
||||
|
||||
private:
|
||||
std::unordered_map<uint32_t, std::vector<FoldingRule>> rules_;
|
||||
std::vector<FoldingRule> empty_vector_;
|
||||
IRContext* context_;
|
||||
FoldingRuleSet empty_vector_;
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
|
@ -469,7 +469,7 @@ bool Instruction::IsOpaqueType() const {
|
||||
|
||||
bool Instruction::IsFoldable() const {
|
||||
return IsFoldableByFoldScalar() ||
|
||||
context()->get_instruction_folder().HasConstFoldingRule(opcode());
|
||||
context()->get_instruction_folder().HasConstFoldingRule(this);
|
||||
}
|
||||
|
||||
bool Instruction::IsFoldableByFoldScalar() const {
|
||||
|
@ -49,7 +49,7 @@ bool SimplificationPass::SimplifyFunction(Function* function) {
|
||||
cfg()->ForEachBlockInReversePostOrder(
|
||||
function->entry().get(),
|
||||
[&modified, &process_phis, &work_list, &in_work_list, &inst_to_kill,
|
||||
folder, this](BasicBlock* bb) {
|
||||
&folder, this](BasicBlock* bb) {
|
||||
for (Instruction* inst = &*bb->begin(); inst; inst = inst->NextNode()) {
|
||||
if (inst->opcode() == SpvOpPhi) {
|
||||
process_phis.insert(inst);
|
||||
|
@ -210,6 +210,7 @@ OpName %main "main"
|
||||
%float_2049 = OpConstant %float 2049
|
||||
%float_n2049 = OpConstant %float -2049
|
||||
%float_0p5 = OpConstant %float 0.5
|
||||
%float_0p2 = OpConstant %float 0.2
|
||||
%float_pi = OpConstant %float 1.5555
|
||||
%float_1e16 = OpConstant %float 1e16
|
||||
%float_n1e16 = OpConstant %float -1e16
|
||||
@ -1465,24 +1466,14 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, std::numeric_limits<float>::quiet_NaN()),
|
||||
// Test case 20: QuantizeToF16 inf
|
||||
// Test case 20: FMix 1.0 4.0 0.2
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpFDiv %float %float_1 %float_0\n" +
|
||||
"%3 = OpQuantizeToF16 %float %3\n" +
|
||||
"%2 = OpExtInst %float %1 FMix %float_1 %float_4 %float_0p2\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, std::numeric_limits<float>::infinity()),
|
||||
// Test case 21: QuantizeToF16 -inf
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpFDiv %float %float_n1 %float_0\n" +
|
||||
"%3 = OpQuantizeToF16 %float %3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, -std::numeric_limits<float>::infinity())
|
||||
2, 1.6f)
|
||||
));
|
||||
// clang-format on
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user