Add general folding infrastructure.

Create the folding engine that will

1) attempt to fold an instruction.
2) iterates on the folding so small folding rules can be easily combined.
3) insert new instructions when needed.

I've added the minimum number of rules needed to test the features above.
This commit is contained in:
Steven Perron 2018-01-24 13:26:33 -05:00 committed by David Neto
parent 1c0056c339
commit bc1ec9418b
8 changed files with 797 additions and 469 deletions

View File

@ -78,6 +78,7 @@ SPVTOOLS_OPT_SRC_FILES := \
source/opt/feature_manager.cpp \
source/opt/flatten_decoration_pass.cpp \
source/opt/fold.cpp \
source/opt/folding_rules.cpp \
source/opt/fold_spec_constant_op_and_composite_pass.cpp \
source/opt/freeze_spec_constant_value_pass.cpp \
source/opt/function.cpp \

View File

@ -35,6 +35,7 @@ add_library(SPIRV-Tools-opt
feature_manager.h
flatten_decoration_pass.h
fold.h
folding_rules.h
fold_spec_constant_op_and_composite_pass.h
freeze_spec_constant_value_pass.h
function.h
@ -101,6 +102,7 @@ add_library(SPIRV-Tools-opt
feature_manager.cpp
flatten_decoration_pass.cpp
fold.cpp
folding_rules.cpp
fold_spec_constant_op_and_composite_pass.cpp
freeze_spec_constant_value_pass.cpp
function.cpp

View File

@ -15,6 +15,8 @@
#include "fold.h"
#include "def_use_manager.h"
#include "folding_rules.h"
#include "ir_builder.h"
#include "ir_context.h"
#include <cassert>
@ -180,6 +182,42 @@ uint32_t OperateWords(SpvOp opcode,
}
}
bool FoldInstructionInternal(ir::Instruction* inst,
std::function<uint32_t(uint32_t)> id_map) {
ir::IRContext* context = inst->context();
ir::Instruction* folded_inst = FoldInstructionToConstant(inst, id_map);
if (folded_inst != nullptr) {
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}});
return true;
}
SpvOp opcode = inst->opcode();
analysis::ConstantManager* const_manger = context->get_constant_mgr();
std::vector<const analysis::Constant*> constants;
for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
const ir::Operand* operand = &inst->GetInOperand(i);
if (operand->type != SPV_OPERAND_TYPE_ID) {
constants.push_back(nullptr);
} else {
uint32_t id = id_map(operand->words[0]);
inst->SetInOperand(i, {id});
const analysis::Constant* constant =
const_manger->FindDeclaredConstant(id);
constants.push_back(constant);
}
}
static FoldingRules* rules = new FoldingRules();
for (FoldingRule rule : rules->GetRulesForOpcode(opcode)) {
if (rule(inst, constants)) {
return true;
}
}
return false;
}
} // namespace
// Returns the result of performing an operation on scalar constant operands.
@ -624,13 +662,26 @@ bool IsFoldableType(ir::Instruction* type_inst) {
ir::Instruction* FoldInstruction(ir::Instruction* inst,
std::function<uint32_t(uint32_t)> id_map) {
ir::Instruction* folded_inst = FoldInstructionToConstant(inst, id_map);
if (folded_inst != nullptr) {
return folded_inst;
ir::IRContext* context = inst->context();
bool modified = false;
std::unique_ptr<ir::Instruction> folded_inst(inst->Clone(context));
while (FoldInstructionInternal(&*folded_inst, id_map)) {
modified = true;
}
// TODO: Add other folding opportunities that do not necessarily fold to a
// constant.
if (modified) {
if (folded_inst->opcode() == SpvOpCopyObject) {
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
return def_use_mgr->GetDef(folded_inst->GetSingleWordInOperand(0));
} else {
InstructionBuilder ir_builder(
context, inst,
ir::IRContext::kAnalysisDefUse |
ir::IRContext::kAnalysisInstrToBlockMapping);
folded_inst->SetResultId(context->TakeNextId());
return ir_builder.AddInstruction(std::move(folded_inst));
}
}
return nullptr;
}

View File

@ -0,0 +1,129 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "folding_rules.h"
namespace spvtools {
namespace opt {
namespace {
const uint32_t kExtractCompositeIdInIdx = 0;
const uint32_t kInsertObjectIdInIdx = 0;
const uint32_t kInsertCompositeIdInIdx = 1;
FoldingRule IntMultipleBy1() {
return [](ir::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul.");
for (uint32_t i = 0; i < 2; i++) {
if (constants[i] == nullptr) {
continue;
}
const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
if (int_constant->GetU32BitValue() == 1) {
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands(
{{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
return true;
}
}
return false;
};
}
FoldingRule InsertFeedingExtract() {
return [](ir::Instruction* inst,
const std::vector<const analysis::Constant*>&) {
assert(inst->opcode() == SpvOpCompositeExtract &&
"Wrong opcode. Should be OpCompositeExtract.");
analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
ir::Instruction* cinst = def_use_mgr->GetDef(cid);
if (cinst->opcode() != SpvOpCompositeInsert) {
return false;
}
// Find the first position where the list of insert and extract indicies
// differ, if at all.
uint32_t i;
for (i = 1; i < inst->NumInOperands(); ++i) {
if (i + 1 >= cinst->NumInOperands()) {
break;
}
if (inst->GetSingleWordInOperand(i) !=
cinst->GetSingleWordInOperand(i + 1)) {
break;
}
}
// We are extracting the element that was inserted.
if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands(
{{SPV_OPERAND_TYPE_ID,
{cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
return true;
}
// Extracting the value that was inserted along with values for the base
// composite. Cannot do anything.
if (i + 1 == cinst->NumInOperands()) {
return false;
}
// Extracting an element of the value that was inserted. Extract from
// that value directly.
if (i == inst->NumInOperands()) {
std::vector<ir::Operand> operands;
operands.push_back(
{SPV_OPERAND_TYPE_ID,
{cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
for (i = i + 1; i < cinst->NumInOperands(); ++i) {
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
{cinst->GetSingleWordInOperand(i)}});
}
inst->SetInOperands(std::move(operands));
return true;
}
// Extracting a value that is disjoint from the element being inserted.
// Rewrite the extract to use the composite input to the insert.
std::vector<ir::Operand> operands;
operands.push_back(
{SPV_OPERAND_TYPE_ID,
{cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
for (i = 1; i < inst->NumInOperands(); ++i) {
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
{inst->GetSingleWordInOperand(i)}});
}
inst->SetInOperands(std::move(operands));
return true;
};
}
} // namespace
spvtools::opt::FoldingRules::FoldingRules() {
// Add all folding rules to the list for the opcodes to which they apply.
// Note that the order in which rules are added to the list matters. If a rule
// applies to the instruction, the rest of the rules will not be attempted.
// Take that into consideration.
rules[SpvOpIMul].push_back(IntMultipleBy1());
rules[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
}
} // namespace opt
} // namespace spvtools

View File

@ -0,0 +1,81 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef LIBSPIRV_UTIL_FOLDING_RULES_H_
#define LIBSPIRV_UTIL_FOLDING_RULES_H_
#include <cstdint>
#include <vector>
#include "constants.h"
#include "def_use_manager.h"
#include "ir_builder.h"
#include "ir_context.h"
namespace spvtools {
namespace opt {
// Folding Rules:
//
// The folding mechanism is built around the concept of a |FoldingRule|. A
// folding rule is a function that implements a method of simplifying an
// instruction.
//
// The inputs to a folding rule are:
// |inst| - the instruction to be simplified.
// |constants| - if an in-operands is an id of a constant, then the
// corresponding value in |constants| contains that
// constant value. Otherwise, the corresponding entry in
// |constants| is |nullptr|.
//
// A folding rule returns true if |inst| can be simplified using this rule. If
// the instruction can be simplified, then |inst| is changed to the simplified
// instruction. Otherwise, |inst| remains the same.
//
// See folding_rules.cpp for examples on how to write a folding rule. It is
// important to note that if |inst| can be folded to the result of an
// instruction that feed it, then |inst| should be changed to an OpCopyObject
// that copies that id.
//
// Be sure to add new folding rules to the table of folding rules in the
// constructor for FoldingRules. The new rule should be added to the list for
// every opcode that it applies to. Note that earlier rules in the list are
// given priority. That is, if an earlier rule is able to fold an instruction,
// the later rules will not be attempted.
using FoldingRule = std::function<bool(
ir::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)>;
class FoldingRules {
public:
FoldingRules();
const std::vector<FoldingRule>& GetRulesForOpcode(SpvOp opcode) {
auto it = rules.find(opcode);
if (it != rules.end()) {
return it->second;
}
return empty_vector_;
}
private:
std::unordered_map<uint32_t, std::vector<FoldingRule>> rules;
std::vector<FoldingRule> empty_vector_;
};
} // namespace opt
} // namespace spvtools
#endif // LIBSPIRV_UTIL_FOLDING_RULES_H_

View File

@ -201,6 +201,8 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
// This is for in-operands modification only, but with |index| expressed in
// terms of operand index rather than in-operand index.
inline void SetOperand(uint32_t index, std::vector<uint32_t>&& data);
// Replace all of the in operands with those in |new_operands|.
inline void SetInOperands(std::vector<Operand>&& new_operands);
// Sets the result type id.
inline void SetResultType(uint32_t ty_id);
// Sets the result id
@ -465,6 +467,13 @@ inline void Instruction::SetOperand(uint32_t index,
operands_[index].words = std::move(data);
}
inline void Instruction::SetInOperands(std::vector<Operand>&& new_operands) {
// Remove the old in operands.
operands_.erase(operands_.begin() + TypeResultIdCount(), operands_.end());
// Add the new in operands.
operands_.insert(operands_.end(), new_operands.begin(), new_operands.end());
}
inline void Instruction::SetResultId(uint32_t res_id) {
result_id_ = res_id;
auto ridx = (type_id_ != 0) ? 1 : 0;

View File

@ -152,6 +152,22 @@ class InstructionBuilder {
return AddInstruction(std::move(construct));
}
ir::Instruction* AddCompositeExtract(
uint32_t type, uint32_t id_of_composite,
const std::vector<uint32_t>& index_list) {
std::vector<ir::Operand> operands;
operands.push_back({SPV_OPERAND_TYPE_ID, {id_of_composite}});
for (uint32_t index : index_list) {
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}});
}
std::unique_ptr<ir::Instruction> new_inst(
new ir::Instruction(GetContext(), SpvOpCompositeExtract, type,
GetContext()->TakeNextId(), operands));
return AddInstruction(std::move(new_inst));
}
// Inserts the new instruction before the insertion point.
ir::Instruction* AddInstruction(std::unique_ptr<ir::Instruction>&& insn) {
ir::Instruction* insn_ptr = &*insert_before_.InsertBefore(std::move(insn));

View File

@ -1,5 +1,4 @@
// Copyright (c) 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -95,12 +94,14 @@ OpName %main "main"
%int = OpTypeInt 32 1
%long = OpTypeInt 64 1
%uint = OpTypeInt 32 1
%v4int = OpTypeVector %int 4
%_ptr_int = OpTypePointer Function %int
%_ptr_uint = OpTypePointer Function %uint
%_ptr_bool = OpTypePointer Function %bool
%short_0 = OpConstant %short 0
%short_3 = OpConstant %short 3
%int_0 = OpConstant %int 0
%int_1 = OpConstant %int 1
%100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps.
%int_3 = OpConstant %int 3
%int_min = OpConstant %int -2147483648
@ -111,6 +112,7 @@ OpName %main "main"
%uint_3 = OpConstant %uint 3
%uint_32 = OpConstant %uint 32
%uint_max = OpConstant %uint -1
%v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
)";
return header;
@ -530,469 +532,6 @@ INSTANTIATE_TEST_CASE_P(TestCase, BooleanInstructionFoldingTest,
));
// clang-format on
using InstructionNotFoldedTest =
::testing::TestWithParam<InstructionFoldingCase<void*>>;
TEST_P(InstructionNotFoldedTest, Case) {
const auto& tc = GetParam();
// Build module.
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
ASSERT_NE(nullptr, context);
// Fold the instruction to test.
opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
inst = opt::FoldInstruction(inst);
// Make sure the instruction folded as expected.
EXPECT_EQ(inst, nullptr);
}
// clang-format off
INSTANTIATE_TEST_CASE_P(TestCase, InstructionNotFoldedTest,
::testing::Values(
// Test case 0: Don't fold n * m
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpIMul %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 1: Don't fold n / m (unsigned)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpUDiv %uint %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 2: Don't fold n / m (signed)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpSDiv %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 3: Don't fold n remainder m
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpSRem %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 4: Don't fold n % m (signed)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpSMod %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 5: Don't fold n % m (unsigned)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpUMod %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 6: Don't fold n << m
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpShiftRightLogical %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 7: Don't fold n >> m
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpShiftLeftLogical %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 8: Don't fold n | m
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpBitwiseOr %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 9: Don't fold n & m
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpBitwiseAnd %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 10: Don't fold n < m (unsigned)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpULessThan %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 11: Don't fold n > m (unsigned)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpUGreaterThan %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 12: Don't fold n <= m (unsigned)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpULessThanEqual %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 13: Don't fold n >= m (unsigned)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpUGreaterThanEqual %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 14: Don't fold n < m (signed)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpULessThan %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 15: Don't fold n > m (signed)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpUGreaterThan %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 16: Don't fold n <= m (signed)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpULessThanEqual %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 17: Don't fold n >= m (signed)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpUGreaterThanEqual %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 18: Don't fold n || m
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_bool Function\n" +
"%m = OpVariable %_ptr_bool Function\n" +
"%load_n = OpLoad %bool %n\n" +
"%load_m = OpLoad %bool %m\n" +
"%2 = OpLogicalOr %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 19: Don't fold n && m
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_bool Function\n" +
"%m = OpVariable %_ptr_bool Function\n" +
"%load_n = OpLoad %bool %n\n" +
"%load_m = OpLoad %bool %m\n" +
"%2 = OpLogicalAnd %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 20: Don't fold n * 3
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpIMul %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 21: Don't fold n / 3 (unsigned)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpUDiv %uint %load_n %uint_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 22: Don't fold n / 3 (signed)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpSDiv %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 23: Don't fold n remainder 3
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpSRem %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 24: Don't fold n % 3 (signed)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpSMod %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 25: Don't fold n % 3 (unsigned)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpUMod %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 26: Don't fold n << 3
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpShiftRightLogical %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 27: Don't fold n >> 3
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpShiftLeftLogical %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 28: Don't fold n | 3
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpBitwiseOr %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 29: Don't fold n & 3
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpBitwiseAnd %uint %load_n %uint_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 30: Don't fold n < 3 (unsigned)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpULessThan %bool %load_n %uint_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 31: Don't fold n > 3 (unsigned)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpUGreaterThan %bool %load_n %uint_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 32: Don't fold n <= 3 (unsigned)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpULessThanEqual %bool %load_n %uint_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 33: Don't fold n >= 3 (unsigned)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpUGreaterThanEqual %bool %load_n %uint_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 34: Don't fold n < 3 (signed)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpULessThan %bool %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 35: Don't fold n > 3 (signed)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpUGreaterThan %bool %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 36: Don't fold n <= 3 (signed)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpULessThanEqual %bool %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 37: Don't fold n >= 3 (signed)
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpUGreaterThanEqual %bool %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 38: Don't fold 0 + 3 (long), bad length
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpIAdd %long %long_0 %long_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr),
// Test case 39: Don't fold 0 + 3 (short), bad length
InstructionFoldingCase<void*>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpIAdd %short %short_0 %short_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, nullptr)
));
// clang-format on
template <class ResultType>
struct InstructionFoldingCaseWithMap {
InstructionFoldingCaseWithMap(const std::string& tb, uint32_t id,
@ -1101,4 +640,504 @@ INSTANTIATE_TEST_CASE_P(TestCase, BooleanInstructionFoldingTestWithMap,
2, true, [](uint32_t id) {return (id == 3 ? TRUE_ID : id);})
));
// clang-format on
using GeneralInstructionFoldingTest =
::testing::TestWithParam<InstructionFoldingCase<uint32_t>>;
TEST_P(GeneralInstructionFoldingTest, Case) {
const auto& tc = GetParam();
// Build module.
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
ASSERT_NE(nullptr, context);
// Fold the instruction to test.
opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
ir::Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
inst = opt::FoldInstruction(inst);
// Make sure the instruction folded as expected.
EXPECT_TRUE((inst == nullptr) == (tc.expected_result == 0));
if (inst != nullptr) {
EXPECT_EQ(inst->result_id(), tc.expected_result);
}
}
// clang-format off
INSTANTIATE_TEST_CASE_P(TestCase, GeneralInstructionFoldingTest,
::testing::Values(
// Test case 0: Don't fold n * m
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpIMul %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 1: Don't fold n / m (unsigned)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpUDiv %uint %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 2: Don't fold n / m (signed)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpSDiv %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 3: Don't fold n remainder m
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpSRem %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 4: Don't fold n % m (signed)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpSMod %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 5: Don't fold n % m (unsigned)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpUMod %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 6: Don't fold n << m
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpShiftRightLogical %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 7: Don't fold n >> m
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpShiftLeftLogical %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 8: Don't fold n | m
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpBitwiseOr %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 9: Don't fold n & m
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpBitwiseAnd %int %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 10: Don't fold n < m (unsigned)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpULessThan %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 11: Don't fold n > m (unsigned)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpUGreaterThan %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 12: Don't fold n <= m (unsigned)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpULessThanEqual %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 13: Don't fold n >= m (unsigned)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%m = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%load_m = OpLoad %uint %m\n" +
"%2 = OpUGreaterThanEqual %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 14: Don't fold n < m (signed)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpULessThan %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 15: Don't fold n > m (signed)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpUGreaterThan %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 16: Don't fold n <= m (signed)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpULessThanEqual %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 17: Don't fold n >= m (signed)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%m = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%load_m = OpLoad %int %m\n" +
"%2 = OpUGreaterThanEqual %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 18: Don't fold n || m
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_bool Function\n" +
"%m = OpVariable %_ptr_bool Function\n" +
"%load_n = OpLoad %bool %n\n" +
"%load_m = OpLoad %bool %m\n" +
"%2 = OpLogicalOr %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 19: Don't fold n && m
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_bool Function\n" +
"%m = OpVariable %_ptr_bool Function\n" +
"%load_n = OpLoad %bool %n\n" +
"%load_m = OpLoad %bool %m\n" +
"%2 = OpLogicalAnd %bool %load_n %load_m\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 20: Don't fold n * 3
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpIMul %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 21: Don't fold n / 3 (unsigned)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpUDiv %uint %load_n %uint_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 22: Don't fold n / 3 (signed)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpSDiv %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 23: Don't fold n remainder 3
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpSRem %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 24: Don't fold n % 3 (signed)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpSMod %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 25: Don't fold n % 3 (unsigned)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpUMod %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 26: Don't fold n << 3
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpShiftRightLogical %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 27: Don't fold n >> 3
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpShiftLeftLogical %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 28: Don't fold n | 3
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpBitwiseOr %int %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 29: Don't fold n & 3
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpBitwiseAnd %uint %load_n %uint_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 30: Don't fold n < 3 (unsigned)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpULessThan %bool %load_n %uint_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 31: Don't fold n > 3 (unsigned)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpUGreaterThan %bool %load_n %uint_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 32: Don't fold n <= 3 (unsigned)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpULessThanEqual %bool %load_n %uint_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 33: Don't fold n >= 3 (unsigned)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_uint Function\n" +
"%load_n = OpLoad %uint %n\n" +
"%2 = OpUGreaterThanEqual %bool %load_n %uint_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 34: Don't fold n < 3 (signed)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpULessThan %bool %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 35: Don't fold n > 3 (signed)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpUGreaterThan %bool %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 36: Don't fold n <= 3 (signed)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpULessThanEqual %bool %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 37: Don't fold n >= 3 (signed)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%load_n = OpLoad %int %n\n" +
"%2 = OpUGreaterThanEqual %bool %load_n %int_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 38: Don't fold 0 + 3 (long), bad length
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpIAdd %long %long_0 %long_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 39: Don't fold 0 + 3 (short), bad length
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpIAdd %short %short_0 %short_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 40: fold 1*n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%3 = OpLoad %int %n\n" +
"%2 = OpIMul %int %int_1 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 41: fold n*1
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%3 = OpLoad %int %n\n" +
"%2 = OpIMul %int %3 %int_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 42: fold Insert feeding extract
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %n\n" +
"%3 = OpCompositeInsert %v4int %2 %v4int_0_0_0_0 0\n" +
"%4 = OpCompositeInsert %v4int %int_1 %3 1\n" +
"%5 = OpCompositeInsert %v4int %int_1 %4 2\n" +
"%6 = OpCompositeInsert %v4int %int_1 %5 3\n" +
"%7 = OpCompositeExtract %int %6 0\n" +
"OpReturn\n" +
"OpFunctionEnd",
7, 2)
));
// clang-format off
} // anonymous namespace