Add folding for redundant add/sub/mul/div/mix operations

This change implements instruction folding for arithmetic operations
that are redundant, specifically:

  x + 0 = 0 + x = x
  x - 0 = x
  0 - x = -x
  x * 0 = 0 * x = 0
  x * 1 = 1 * x = x
  0 / x = 0
  x / 1 = x
  mix(a, b, 0) = a
  mix(a, b, 1) = b

Cache ExtInst import id in feature manager

This allows us to avoid string lookups during optimization; for now we
just cache GLSL std450 import id but I can imagine caching more sets as
they become utilized by the optimizer.

Add tests for add/sub/mul/div/mix folding

The tests cover scalar float/double cases, and some vector cases.

Since most of the code for floating point folding is shared, the tests
for vector folding are not as exhaustive as scalar.

To test sub->negate folding I had to implement a custom fixture.
This commit is contained in:
Arseny Kapoulkine 2018-02-17 11:55:54 -08:00 committed by Steven Perron
parent fa3ac3cc33
commit 309be423cc
9 changed files with 800 additions and 19 deletions

View File

@ -30,23 +30,6 @@ inline std::vector<uint32_t> ExtractInts(uint64_t a) {
return result;
}
// Returns true if we are allowed to fold or otherwise manipulate the
// instruction that defines |id| in the given context.
bool CanFoldFloatingPoint(ir::IRContext* context, uint32_t id) {
// TODO: Add the rules for kernels. For now it will be pessimistic.
if (!context->get_feature_mgr()->HasCapability(SpvCapabilityShader)) {
return false;
}
bool is_nocontract = false;
context->get_decoration_mgr()->WhileEachDecoration(
id, SpvDecorationNoContraction, [&is_nocontract](const ir::Instruction&) {
is_nocontract = true;
return false;
});
return !is_nocontract;
}
// Folds an OpcompositeExtract where input is a composite constant.
ConstantFoldingRule FoldExtractWithConstants() {
return [](ir::Instruction* inst,
@ -147,7 +130,7 @@ ConstantFoldingRule FoldFloatingPointOp(FloatScalarFoldingRule scalar_rule) {
const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
const analysis::Vector* vector_type = result_type->AsVector();
if (!CanFoldFloatingPoint(context, inst->result_id())) {
if (!inst->IsFloatingPointFoldingAllowed()) {
return nullptr;
}

View File

@ -24,6 +24,7 @@ namespace opt {
void FeatureManager::Analyze(ir::Module* module) {
AddExtensions(module);
AddCapabilities(module);
AddExtInstImportIds(module);
}
void FeatureManager::AddExtensions(ir::Module* module) {
@ -56,5 +57,9 @@ void FeatureManager::AddCapabilities(ir::Module* module) {
}
}
void FeatureManager::AddExtInstImportIds(ir::Module* module) {
extinst_importid_GLSLstd450_ = module->GetExtInstImportId("GLSL.std.450");
}
} // namespace opt
} // namespace spvtools

View File

@ -46,6 +46,10 @@ class FeatureManager {
return &capabilities_;
}
uint32_t GetExtInstImportId_GLSLstd450() const {
return extinst_importid_GLSLstd450_;
}
private:
// Analyzes |module| and records enabled extensions.
void AddExtensions(ir::Module* module);
@ -57,6 +61,9 @@ class FeatureManager {
// Analyzes |module| and records enabled capabilities.
void AddCapabilities(ir::Module* module);
// Analyzes |module| and records imported external instruction sets.
void AddExtInstImportIds(ir::Module* module);
// Auxiliary object for querying SPIR-V grammar facts.
const libspirv::AssemblyGrammar& grammar_;
@ -65,6 +72,9 @@ class FeatureManager {
// The enabled capabilities.
libspirv::CapabilitySet capabilities_;
// Common external instruction import ids, cached for performance.
uint32_t extinst_importid_GLSLstd450_ = 0;
};
} // namespace opt

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include "folding_rules.h"
#include "latest_version_glsl_std_450_header.h"
namespace spvtools {
namespace opt {
@ -21,6 +22,10 @@ namespace {
const uint32_t kExtractCompositeIdInIdx = 0;
const uint32_t kInsertObjectIdInIdx = 0;
const uint32_t kInsertCompositeIdInIdx = 1;
const uint32_t kExtInstSetIdInIdx = 0;
const uint32_t kExtInstInstructionInIdx = 1;
const uint32_t kFMixXIdInIdx = 2;
const uint32_t kFMixYIdInIdx = 3;
FoldingRule IntMultipleBy1() {
return [](ir::Instruction* inst,
@ -326,6 +331,199 @@ FoldingRule RedundantSelect() {
}
};
}
enum class FloatConstantKind { Unknown, Zero, One };
FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
if (constant == nullptr) {
return FloatConstantKind::Unknown;
}
if (const analysis::VectorConstant* vc = constant->AsVectorConstant()) {
const std::vector<const analysis::Constant*>& components =
vc->GetComponents();
assert(!components.empty());
FloatConstantKind kind = getFloatConstantKind(components[0]);
for (size_t i = 1; i < components.size(); ++i) {
if (getFloatConstantKind(components[i]) != kind) {
return FloatConstantKind::Unknown;
}
}
return kind;
} else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
double value = (fc->type()->AsFloat()->width() == 64) ? fc->GetDoubleValue()
: fc->GetFloatValue();
if (value == 0.0) {
return FloatConstantKind::Zero;
} else if (value == 1.0) {
return FloatConstantKind::One;
} else {
return FloatConstantKind::Unknown;
}
} else {
return FloatConstantKind::Unknown;
}
}
FoldingRule RedundantFAdd() {
return [](ir::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
assert(constants.size() == 2);
if (!inst->IsFloatingPointFoldingAllowed()) {
return false;
}
FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
{inst->GetSingleWordInOperand(
kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
return true;
}
return false;
};
}
FoldingRule RedundantFSub() {
return [](ir::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
assert(constants.size() == 2);
if (!inst->IsFloatingPointFoldingAllowed()) {
return false;
}
FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
if (kind0 == FloatConstantKind::Zero) {
inst->SetOpcode(SpvOpFNegate);
inst->SetInOperands(
{{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
return true;
}
if (kind1 == FloatConstantKind::Zero) {
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands(
{{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
return true;
}
return false;
};
}
FoldingRule RedundantFMul() {
return [](ir::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
assert(constants.size() == 2);
if (!inst->IsFloatingPointFoldingAllowed()) {
return false;
}
FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
{inst->GetSingleWordInOperand(
kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
return true;
}
if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
{inst->GetSingleWordInOperand(
kind0 == FloatConstantKind::One ? 1 : 0)}}});
return true;
}
return false;
};
}
FoldingRule RedundantFDiv() {
return [](ir::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
assert(constants.size() == 2);
if (!inst->IsFloatingPointFoldingAllowed()) {
return false;
}
FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
if (kind0 == FloatConstantKind::Zero) {
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands(
{{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
return true;
}
if (kind1 == FloatConstantKind::One) {
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands(
{{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
return true;
}
return false;
};
}
FoldingRule RedundantFMix() {
return [](ir::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpExtInst &&
"Wrong opcode. Should be OpExtInst.");
if (!inst->IsFloatingPointFoldingAllowed()) {
return false;
}
uint32_t instSetId =
inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
GLSLstd450FMix) {
assert(constants.size() == 5);
FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands(
{{SPV_OPERAND_TYPE_ID,
{inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
? kFMixXIdInIdx
: kFMixYIdInIdx)}}});
return true;
}
}
return false;
};
}
} // namespace
spvtools::opt::FoldingRules::FoldingRules() {
@ -339,11 +537,19 @@ spvtools::opt::FoldingRules::FoldingRules() {
rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
rules_[SpvOpExtInst].push_back(RedundantFMix());
rules_[SpvOpFAdd].push_back(RedundantFAdd());
rules_[SpvOpFDiv].push_back(RedundantFDiv());
rules_[SpvOpFMul].push_back(RedundantFMul());
rules_[SpvOpFSub].push_back(RedundantFSub());
rules_[SpvOpIMul].push_back(IntMultipleBy1());
rules_[SpvOpPhi].push_back(RedundantPhi());
rules_[SpvOpSelect].push_back(RedundantSelect());
}
} // namespace opt
} // namespace spvtools

View File

@ -96,7 +96,7 @@ uint32_t InsertExtractElimPass::DoExtract(ir::Instruction* compInst,
}
} else if (cinst->opcode() == SpvOpExtInst &&
cinst->GetSingleWordInOperand(kExtInstSetIdInIdx) ==
get_module()->GetExtInstImportId("GLSL.std.450") &&
get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
cinst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
GLSLstd450FMix) {
// If mixing value component is 0 or 1 we just match with x or y.

View File

@ -483,6 +483,22 @@ bool Instruction::IsFoldableByFoldScalar() const {
return opt::IsFoldableType(type);
}
bool Instruction::IsFloatingPointFoldingAllowed() const {
// TODO: Add the rules for kernels. For now it will be pessimistic.
if (!context_->get_feature_mgr()->HasCapability(SpvCapabilityShader)) {
return false;
}
bool is_nocontract = false;
context_->get_decoration_mgr()->WhileEachDecoration(
opcode_, SpvDecorationNoContraction,
[&is_nocontract](const ir::Instruction&) {
is_nocontract = true;
return false;
});
return !is_nocontract;
}
std::string Instruction::PrettyPrint(uint32_t options) const {
// Convert the module to binary.
std::vector<uint32_t> module_binary;

View File

@ -372,6 +372,11 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
// constant value by |FoldScalar|.
bool IsFoldableByFoldScalar() const;
// Returns true if we are allowed to fold or otherwise manipulate the
// instruction that defines |id| in the given context. This includes not
// handling NaN values.
bool IsFloatingPointFoldingAllowed() const;
inline bool operator==(const Instruction&) const;
inline bool operator!=(const Instruction&) const;
inline bool operator<(const Instruction&) const;

View File

@ -82,6 +82,10 @@ class Pass {
return context()->get_decoration_mgr();
}
FeatureManager* get_feature_mgr() const {
return context()->get_feature_mgr();
}
// Returns a pointer to the current module for this pass.
ir::Module* get_module() const { return context_->module(); }

View File

@ -80,6 +80,10 @@ TEST_P(IntegerInstructionFoldingTest, Case) {
#define TRUE_ID 101
#define VEC2_0_ID 102
#define INT_7_ID 103
#define FLOAT_0_ID 104
#define DOUBLE_0_ID 105
#define VEC4_0_ID 106
#define DVEC4_0_ID 106
const std::string& Header() {
static const std::string header = R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
@ -103,10 +107,16 @@ OpName %main "main"
%uint = OpTypeInt 32 1
%v2int = OpTypeVector %int 2
%v4int = OpTypeVector %int 4
%v4float = OpTypeVector %float 4
%v4double = OpTypeVector %double 4
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
%_ptr_int = OpTypePointer Function %int
%_ptr_uint = OpTypePointer Function %uint
%_ptr_bool = OpTypePointer Function %bool
%_ptr_float = OpTypePointer Function %float
%_ptr_double = OpTypePointer Function %double
%_ptr_v4float = OpTypePointer Function %v4float
%_ptr_v4double = OpTypePointer Function %v4double
%_ptr_struct_v2int_int_int = OpTypePointer Function %struct_v2int_int_int
%short_0 = OpConstant %short 0
%short_3 = OpConstant %short 3
@ -132,17 +142,27 @@ OpName %main "main"
%float16_1 = OpConstant %float16 1
%float16_2 = OpConstant %float16 2
%float_n1 = OpConstant %float -1
%104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps.
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%float_2 = OpConstant %float 2
%float_3 = OpConstant %float 3
%double_n1 = OpConstant %double -1
%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
%double_0 = OpConstant %double 0
%double_1 = OpConstant %double 1
%double_2 = OpConstant %double 2
%double_3 = OpConstant %double 3
%float_nan = OpConstant %float -0x1.8p+128
%double_nan = OpConstant %double -0x1.8p+1024
%106 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
%v4float_0_0_0_0 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
%v4float_0_0_0_1 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_1
%v4float_1_1_1_1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
%107 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0
%v4double_0_0_0_0 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0
%v4double_0_0_0_1 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_1
%v4double_1_1_1_1 = OpConstantComposite %v4double %double_1 %double_1 %double_1 %double_1
)";
return header;
@ -2211,5 +2231,537 @@ INSTANTIATE_TEST_CASE_P(SelectFoldingTest, GeneralInstructionFoldingTest,
"OpFunctionEnd",
2, INT_0_ID)
));
INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest,
::testing::Values(
// Test case 0: Don't fold n + 1.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %n\n" +
"%2 = OpFAdd %float %3 %float_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 1: Don't fold n - 1.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %n\n" +
"%2 = OpFSub %float %3 %float_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 2: Don't fold n * 2.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %n\n" +
"%2 = OpFMul %float %3 %float_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 3: Don't fold n / 2.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %n\n" +
"%2 = OpFDiv %float %3 %float_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 4: Fold n + 0.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %n\n" +
"%2 = OpFAdd %float %3 %float_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 5: Fold 0.0 + n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %n\n" +
"%2 = OpFAdd %float %float_0 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 6: Fold n - 0.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %n\n" +
"%2 = OpFSub %float %3 %float_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 7: Fold n * 1.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %n\n" +
"%2 = OpFMul %float %3 %float_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 8: Fold 1.0 * n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %n\n" +
"%2 = OpFMul %float %float_1 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 9: Fold n / 1.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %n\n" +
"%2 = OpFDiv %float %3 %float_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 10: Fold n * 0.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %n\n" +
"%2 = OpFMul %float %3 %104\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, FLOAT_0_ID),
// Test case 11: Fold 0.0 * n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %n\n" +
"%2 = OpFMul %float %104 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, FLOAT_0_ID),
// Test case 12: Fold 0.0 / n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %n\n" +
"%2 = OpFDiv %float %104 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, FLOAT_0_ID),
// Test case 13: Don't fold mix(a, b, 2.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%a = OpVariable %_ptr_float Function\n" +
"%b = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %a\n" +
"%4 = OpLoad %float %b\n" +
"%2 = OpExtInst %float %1 FMix %3 %4 %float_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 14: Fold mix(a, b, 0.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%a = OpVariable %_ptr_float Function\n" +
"%b = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %a\n" +
"%4 = OpLoad %float %b\n" +
"%2 = OpExtInst %float %1 FMix %3 %4 %float_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 15: Fold mix(a, b, 1.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%a = OpVariable %_ptr_float Function\n" +
"%b = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %a\n" +
"%4 = OpLoad %float %b\n" +
"%2 = OpExtInst %float %1 FMix %3 %4 %float_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 4)
));
INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTest,
::testing::Values(
// Test case 0: Don't fold n + 1.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %n\n" +
"%2 = OpFAdd %double %3 %double_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 1: Don't fold n - 1.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %n\n" +
"%2 = OpFSub %double %3 %double_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 2: Don't fold n * 2.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %n\n" +
"%2 = OpFMul %double %3 %double_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 3: Don't fold n / 2.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %n\n" +
"%2 = OpFDiv %double %3 %double_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 4: Fold n + 0.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %n\n" +
"%2 = OpFAdd %double %3 %double_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 5: Fold 0.0 + n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %n\n" +
"%2 = OpFAdd %double %double_0 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 6: Fold n - 0.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %n\n" +
"%2 = OpFSub %double %3 %double_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 7: Fold n * 1.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %n\n" +
"%2 = OpFMul %double %3 %double_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 8: Fold 1.0 * n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %n\n" +
"%2 = OpFMul %double %double_1 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 9: Fold n / 1.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %n\n" +
"%2 = OpFDiv %double %3 %double_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 10: Fold n * 0.0
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %n\n" +
"%2 = OpFMul %double %3 %105\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, DOUBLE_0_ID),
// Test case 11: Fold 0.0 * n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %n\n" +
"%2 = OpFMul %double %105 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, DOUBLE_0_ID),
// Test case 12: Fold 0.0 / n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %n\n" +
"%2 = OpFDiv %double %105 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, DOUBLE_0_ID),
// Test case 13: Don't fold mix(a, b, 2.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%a = OpVariable %_ptr_double Function\n" +
"%b = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %a\n" +
"%4 = OpLoad %double %b\n" +
"%2 = OpExtInst %double %1 FMix %3 %4 %double_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 14: Fold mix(a, b, 0.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%a = OpVariable %_ptr_double Function\n" +
"%b = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %a\n" +
"%4 = OpLoad %double %b\n" +
"%2 = OpExtInst %double %1 FMix %3 %4 %double_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 15: Fold mix(a, b, 1.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%a = OpVariable %_ptr_double Function\n" +
"%b = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %a\n" +
"%4 = OpLoad %double %b\n" +
"%2 = OpExtInst %double %1 FMix %3 %4 %double_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 4)
));
INSTANTIATE_TEST_CASE_P(FloatVectorRedundantFoldingTest, GeneralInstructionFoldingTest,
::testing::Values(
// Test case 0: Don't fold a * vec4(0.0, 0.0, 0.0, 1.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4float Function\n" +
"%3 = OpLoad %v4float %n\n" +
"%2 = OpFMul %v4float %3 %v4float_0_0_0_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 1: Fold a * vec4(0.0, 0.0, 0.0, 0.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4float Function\n" +
"%3 = OpLoad %v4float %n\n" +
"%2 = OpFMul %v4float %3 %106\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, VEC4_0_ID),
// Test case 2: Fold a * vec4(1.0, 1.0, 1.0, 1.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4float Function\n" +
"%3 = OpLoad %v4float %n\n" +
"%2 = OpFMul %v4float %3 %v4float_1_1_1_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3)
));
INSTANTIATE_TEST_CASE_P(DoubleVectorRedundantFoldingTest, GeneralInstructionFoldingTest,
::testing::Values(
// Test case 0: Don't fold a * vec4(0.0, 0.0, 0.0, 1.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4double Function\n" +
"%3 = OpLoad %v4double %n\n" +
"%2 = OpFMul %v4double %3 %v4double_0_0_0_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 1: Fold a * vec4(0.0, 0.0, 0.0, 0.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4double Function\n" +
"%3 = OpLoad %v4double %n\n" +
"%2 = OpFMul %v4double %3 %106\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, DVEC4_0_ID),
// Test case 2: Fold a * vec4(1.0, 1.0, 1.0, 1.0)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4double Function\n" +
"%3 = OpLoad %v4double %n\n" +
"%2 = OpFMul %v4double %3 %v4double_1_1_1_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3)
));
// clang-format on
using ToNegateFoldingTest =
::testing::TestWithParam<InstructionFoldingCase<uint32_t>>;
TEST_P(ToNegateFoldingTest, Case) {
const auto& tc = GetParam();
// Build module.
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
ASSERT_NE(nullptr, context);
// Fold the instruction to test.
opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
std::unique_ptr<ir::Instruction> original_inst(inst->Clone(context.get()));
bool succeeded = opt::FoldInstruction(inst);
// Make sure the instruction folded as expected.
EXPECT_EQ(inst->result_id(), original_inst->result_id());
EXPECT_EQ(inst->type_id(), original_inst->type_id());
EXPECT_TRUE((!succeeded) == (tc.expected_result == 0));
if (succeeded) {
EXPECT_EQ(inst->opcode(), SpvOpFNegate);
EXPECT_EQ(inst->GetSingleWordInOperand(0), tc.expected_result);
} else {
EXPECT_EQ(inst->NumInOperands(), original_inst->NumInOperands());
for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
EXPECT_EQ(inst->GetOperand(i), original_inst->GetOperand(i));
}
}
}
// clang-format off
INSTANTIATE_TEST_CASE_P(FloatRedundantSubFoldingTest, ToNegateFoldingTest,
::testing::Values(
// Test case 0: Don't fold 1.0 - n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %n\n" +
"%2 = OpFSub %float %float_1 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 1: Fold 0.0 - n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_float Function\n" +
"%3 = OpLoad %float %n\n" +
"%2 = OpFSub %float %float_0 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 2: Don't fold (0,0,0,1) - n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4float Function\n" +
"%3 = OpLoad %v4float %n\n" +
"%2 = OpFSub %v4float %v4float_0_0_0_1 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 3: Fold (0,0,0,0) - n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4float Function\n" +
"%3 = OpLoad %v4float %n\n" +
"%2 = OpFSub %v4float %v4float_0_0_0_0 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3)
));
INSTANTIATE_TEST_CASE_P(DoubleRedundantSubFoldingTest, ToNegateFoldingTest,
::testing::Values(
// Test case 0: Don't fold 1.0 - n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %n\n" +
"%2 = OpFSub %double %double_1 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 1: Fold 0.0 - n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_double Function\n" +
"%3 = OpLoad %double %n\n" +
"%2 = OpFSub %double %double_0 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 2: Don't fold (0,0,0,1) - n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4double Function\n" +
"%3 = OpLoad %v4double %n\n" +
"%2 = OpFSub %v4double %v4double_0_0_0_1 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 3: Fold (0,0,0,0) - n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4double Function\n" +
"%3 = OpLoad %v4double %n\n" +
"%2 = OpFSub %v4double %v4double_0_0_0_0 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3)
));
// clang-format on
} // anonymous namespace