Fold OpDot

Adding three rules to fold OpDot (implemented as two).

- When an OpDot has two constants, then fold to the resulting const.

- When one of the inputs is the 0 vector, then fold to zero.

- When one of the inputs is a single 1 with 0s, then rewrite to an
OpCompositeExtract of the appropriate element.  This will help find
even more folding opportunities.

Contributes to #709.
This commit is contained in:
Steven Perron 2018-04-05 15:01:10 -04:00
parent 3020104ff2
commit 53bc1623ec
5 changed files with 414 additions and 50 deletions

View File

@ -142,29 +142,6 @@ using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
const analysis::Type* result_type, const analysis::Constant* a,
const analysis::Constant* b, analysis::ConstantManager*)>;
// Returns an std::vector containing the elements of |constant|. The type of
// |constant| must be |Vector|.
std::vector<const analysis::Constant*> GetVectorComponents(
const analysis::Constant* constant, analysis::ConstantManager* const_mgr) {
std::vector<const analysis::Constant*> components;
const analysis::VectorConstant* a = constant->AsVectorConstant();
const analysis::Vector* vector_type = constant->type()->AsVector();
assert(vector_type != nullptr);
if (a != nullptr) {
for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
components.push_back(a->GetComponents()[i]);
}
} else {
const analysis::Type* element_type = vector_type->element_type();
const analysis::Constant* element_null_const =
const_mgr->GetConstant(element_type, {});
for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
components.push_back(element_null_const);
}
}
return components;
}
// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
// using |scalar_rule| and unary float point vectors ops by applying
// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
@ -193,7 +170,7 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
std::vector<const analysis::Constant*> a_components;
std::vector<const analysis::Constant*> results_components;
a_components = GetVectorComponents(constants[0], const_mgr);
a_components = constants[0]->GetVectorComponents(const_mgr);
// Fold each component of the vector.
for (uint32_t i = 0; i < a_components.size(); ++i) {
@ -244,8 +221,8 @@ ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
std::vector<const analysis::Constant*> b_components;
std::vector<const analysis::Constant*> results_components;
a_components = GetVectorComponents(constants[0], const_mgr);
b_components = GetVectorComponents(constants[1], const_mgr);
a_components = constants[0]->GetVectorComponents(const_mgr);
b_components = constants[1]->GetVectorComponents(const_mgr);
// Fold each component of the vector.
for (uint32_t i = 0; i < a_components.size(); ++i) {
@ -334,28 +311,29 @@ UnaryScalarFoldingRule FoldIToFOp() {
// This macro defines a |BinaryScalarFoldingRule| that applies |op|. The
// operator |op| must work for both float and double, and use syntax "f1 op f2".
#define FOLD_FPARITH_OP(op) \
[](const analysis::Type* result_type, const analysis::Constant* a, \
const analysis::Constant* b, \
analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
assert(result_type != nullptr && a != nullptr && b != nullptr); \
assert(result_type == a->type() && result_type == b->type()); \
const analysis::Float* float_type = result_type->AsFloat(); \
assert(float_type != nullptr); \
if (float_type->width() == 32) { \
float fa = a->GetFloat(); \
float fb = b->GetFloat(); \
spvutils::FloatProxy<float> result(fa op fb); \
std::vector<uint32_t> words = result.GetWords(); \
return const_mgr->GetConstant(result_type, words); \
} else if (float_type->width() == 64) { \
double fa = a->GetDouble(); \
double fb = b->GetDouble(); \
spvutils::FloatProxy<double> result(fa op fb); \
std::vector<uint32_t> words = result.GetWords(); \
return const_mgr->GetConstant(result_type, words); \
} \
return nullptr; \
#define FOLD_FPARITH_OP(op) \
[](const analysis::Type* result_type, const analysis::Constant* a, \
const analysis::Constant* b, \
analysis::ConstantManager* const_mgr_in_macro) \
-> const analysis::Constant* { \
assert(result_type != nullptr && a != nullptr && b != nullptr); \
assert(result_type == a->type() && result_type == b->type()); \
const analysis::Float* float_type_in_macro = result_type->AsFloat(); \
assert(float_type_in_macro != nullptr); \
if (float_type_in_macro->width() == 32) { \
float fa = a->GetFloat(); \
float fb = b->GetFloat(); \
spvutils::FloatProxy<float> result_in_macro(fa op fb); \
std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
} else if (float_type_in_macro->width() == 64) { \
double fa = a->GetDouble(); \
double fb = b->GetDouble(); \
spvutils::FloatProxy<double> result_in_macro(fa op fb); \
std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
} \
return nullptr; \
}
// Define the folding rule for conversion between floating point and integer
@ -447,6 +425,79 @@ ConstantFoldingRule FoldFOrdGreaterThanEqual() {
ConstantFoldingRule FoldFUnordGreaterThanEqual() {
return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
}
// Folds an OpDot where all of the inputs are constants to a
// constant. A new constant is created if necessary.
ConstantFoldingRule FoldOpDotWithConstants() {
return [](ir::Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
ir::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
assert(new_type->AsFloat() && "OpDot should have a float return type.");
const analysis::Float* float_type = new_type->AsFloat();
if (!inst->IsFloatingPointFoldingAllowed()) {
return nullptr;
}
// If one of the operands is 0, then the result is 0.
bool has_zero_operand = false;
for (int i = 0; i < 2; ++i) {
if (constants[i]) {
if (constants[i]->AsNullConstant() ||
constants[i]->AsVectorConstant()->IsZero()) {
has_zero_operand = true;
break;
}
}
}
if (has_zero_operand) {
if (float_type->width() == 32) {
spvutils::FloatProxy<float> result(0.0f);
std::vector<uint32_t> words = result.GetWords();
return const_mgr->GetConstant(float_type, words);
}
if (float_type->width() == 64) {
spvutils::FloatProxy<double> result(0.0);
std::vector<uint32_t> words = result.GetWords();
return const_mgr->GetConstant(float_type, words);
}
return nullptr;
}
if (constants[0] == nullptr || constants[1] == nullptr) {
return nullptr;
}
std::vector<const analysis::Constant*> a_components;
std::vector<const analysis::Constant*> b_components;
a_components = constants[0]->GetVectorComponents(const_mgr);
b_components = constants[1]->GetVectorComponents(const_mgr);
spvutils::FloatProxy<double> result(0.0);
std::vector<uint32_t> words = result.GetWords();
const analysis::Constant* result_const =
const_mgr->GetConstant(float_type, words);
for (uint32_t i = 0; i < a_components.size(); ++i) {
if (a_components[i] == nullptr || b_components[i] == nullptr) {
return nullptr;
}
const analysis::Constant* component = FOLD_FPARITH_OP(*)(
new_type, a_components[i], b_components[i], const_mgr);
result_const =
FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
}
return result_const;
};
}
} // namespace
spvtools::opt::ConstantFoldingRules::ConstantFoldingRules() {
@ -464,6 +515,7 @@ spvtools::opt::ConstantFoldingRules::ConstantFoldingRules() {
rules_[SpvOpConvertSToF].push_back(FoldIToF());
rules_[SpvOpConvertUToF].push_back(FoldIToF());
rules_[SpvOpDot].push_back(FoldOpDotWithConstants());
rules_[SpvOpFAdd].push_back(FoldFAdd());
rules_[SpvOpFDiv].push_back(FoldFDiv());
rules_[SpvOpFMul].push_back(FoldFMul());

View File

@ -306,6 +306,37 @@ const Constant* ConstantManager::GetConstant(
return cst ? RegisterConstant(cst) : nullptr;
}
bool VectorConstant::IsZero() const {
for (const Constant* component : GetComponents()) {
if (!component->AsNullConstant() &&
!component->AsScalarConstant()->IsZero()) {
return false;
}
}
return true;
}
std::vector<const analysis::Constant*> Constant::GetVectorComponents(
analysis::ConstantManager* const_mgr) const {
std::vector<const analysis::Constant*> components;
const analysis::VectorConstant* a = this->AsVectorConstant();
const analysis::Vector* vector_type = this->type()->AsVector();
assert(vector_type != nullptr);
if (a != nullptr) {
for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
components.push_back(a->GetComponents()[i]);
}
} else {
const analysis::Type* element_type = vector_type->element_type();
const analysis::Constant* element_null_const =
const_mgr->GetConstant(element_type, {});
for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
components.push_back(element_null_const);
}
}
return components;
}
} // namespace analysis
} // namespace opt
} // namespace spvtools

View File

@ -47,6 +47,7 @@ class VectorConstant;
class MatrixConstant;
class ArrayConstant;
class NullConstant;
class ConstantManager;
// Abstract class for a SPIR-V constant. It has a bunch of As<subclass> methods,
// which is used as a way to probe the actual <subclass>
@ -109,6 +110,11 @@ class Constant {
const Type* type() const { return type_; }
// Returns an std::vector containing the elements of |constant|. The type of
// |constant| must be |Vector|.
std::vector<const Constant*> GetVectorComponents(
ConstantManager* const_mgr) const;
protected:
Constant(const Type* ty) : type_(ty) {}
@ -334,6 +340,9 @@ class VectorConstant : public CompositeConstant {
const Type* component_type() const { return component_type_; }
// Returns true if the vector is all zeros.
bool IsZero() const;
private:
const Type* component_type_;
};

View File

@ -13,6 +13,9 @@
// limitations under the License.
#include "folding_rules.h"
#include <limits>
#include "latest_version_glsl_std_450_header.h"
namespace spvtools {
@ -1812,6 +1815,79 @@ FoldingRule RedundantFMix() {
};
}
// This rule look for a dot with a constant vector containing a single 1 and
// the rest 0s. This is the same as doing an extract.
FoldingRule DotProductDoingExtract() {
return [](ir::Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");
ir::IRContext* context = inst->context();
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
if (!inst->IsFloatingPointFoldingAllowed()) {
return false;
}
for (int i = 0; i < 2; ++i) {
if (!constants[i]) {
continue;
}
const analysis::Vector* vector_type = constants[i]->type()->AsVector();
assert(vector_type && "Inputs to OpDot must be vectors.");
const analysis::Float* element_type =
vector_type->element_type()->AsFloat();
assert(element_type && "Inputs to OpDot must be vectors of floats.");
uint32_t element_width = element_type->width();
if (element_width != 32 && element_width != 64) {
return false;
}
std::vector<const analysis::Constant*> components;
components = constants[i]->GetVectorComponents(const_mgr);
const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
uint32_t component_with_one = kNotFound;
bool all_others_zero = true;
for (uint32_t j = 0; j < components.size(); ++j) {
const analysis::Constant* element = components[j];
double value =
(element_width == 32 ? element->GetFloat() : element->GetDouble());
if (value == 0.0) {
continue;
} else if (value == 1.0) {
if (component_with_one == kNotFound) {
component_with_one = j;
} else {
component_with_one = kNotFound;
break;
}
} else {
all_others_zero = false;
break;
}
}
if (!all_others_zero || component_with_one == kNotFound) {
continue;
}
std::vector<ir::Operand> operands;
operands.push_back(
{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
operands.push_back(
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
inst->SetOpcode(SpvOpCompositeExtract);
inst->SetInOperands(std::move(operands));
return true;
}
return false;
};
}
} // namespace
spvtools::opt::FoldingRules::FoldingRules() {
@ -1826,6 +1902,8 @@ spvtools::opt::FoldingRules::FoldingRules() {
rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
rules_[SpvOpDot].push_back(DotProductDoingExtract());
rules_[SpvOpExtInst].push_back(RedundantFMix());
rules_[SpvOpFAdd].push_back(RedundantFAdd());

View File

@ -143,6 +143,7 @@ OpName %main "main"
%v4float = OpTypeVector %float 4
%v4double = OpTypeVector %double 4
%v2float = OpTypeVector %float 2
%v2double = OpTypeVector %double 2
%v2bool = OpTypeVector %bool 2
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
%_ptr_int = OpTypePointer Function %int
@ -158,6 +159,7 @@ OpName %main "main"
%_ptr_v4double = OpTypePointer Function %v4double
%_ptr_struct_v2int_int_int = OpTypePointer Function %struct_v2int_int_int
%_ptr_v2float = OpTypePointer Function %v2float
%_ptr_v2double = OpTypePointer Function %v2double
%short_0 = OpConstant %short 0
%short_3 = OpConstant %short 3
%100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps.
@ -196,6 +198,7 @@ OpName %main "main"
%float16_2 = OpConstant %float16 2
%float_n1 = OpConstant %float -1
%104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps.
%float_null = OpConstantNull %float
%float_0 = OpConstant %float 0
%float_half = OpConstant %float 0.5
%float_1 = OpConstant %float 1
@ -203,6 +206,7 @@ OpName %main "main"
%float_3 = OpConstant %float 3
%float_4 = OpConstant %float 4
%float_0p5 = OpConstant %float 0.5
%v2float_0_0 = OpConstantComposite %v2float %float_0 %float_0
%v2float_2_2 = OpConstantComposite %v2float %float_2 %float_2
%v2float_2_3 = OpConstantComposite %v2float %float_2 %float_3
%v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
@ -211,10 +215,20 @@ OpName %main "main"
%v2float_null = OpConstantNull %v2float
%double_n1 = OpConstant %double -1
%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
%double_null = OpConstantNull %double
%double_0 = OpConstant %double 0
%double_1 = OpConstant %double 1
%double_2 = OpConstant %double 2
%double_3 = OpConstant %double 3
%double_4 = OpConstant %double 4
%double_0p5 = OpConstant %double 0.5
%v2double_0_0 = OpConstantComposite %v2double %double_0 %double_0
%v2double_2_2 = OpConstantComposite %v2double %double_2 %double_2
%v2double_2_3 = OpConstantComposite %v2double %double_2 %double_3
%v2double_3_2 = OpConstantComposite %v2double %double_3 %double_2
%v2double_4_4 = OpConstantComposite %v2double %double_4 %double_4
%v2double_2_0p5 = OpConstantComposite %v2double %double_2 %double_0p5
%v2double_null = OpConstantNull %v2double
%float_nan = OpConstant %float -0x1.8p+128
%double_nan = OpConstant %double -0x1.8p+1024
%108 = OpConstant %half 0
@ -222,10 +236,12 @@ OpName %main "main"
%106 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
%v4float_0_0_0_0 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
%v4float_0_0_0_1 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_1
%v4float_0_1_0_0 = OpConstantComposite %v4float %float_0 %float_1 %float_null %float_0
%v4float_1_1_1_1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
%107 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0
%v4double_0_0_0_0 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0
%v4double_0_0_0_1 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_1
%v4double_0_1_0_0 = OpConstantComposite %v4double %double_0 %double_1 %double_null %double_0
%v4double_1_1_1_1 = OpConstantComposite %v4double %double_1 %double_1 %double_1 %double_1
)";
@ -799,7 +815,55 @@ INSTANTIATE_TEST_CASE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
"%2 = OpFDiv %float %float_n1 %float_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, -std::numeric_limits<float>::infinity())
2, -std::numeric_limits<float>::infinity()),
// Test case 6: Fold (2.0, 3.0) dot (2.0, 0.5)
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpDot %float %v2float_2_3 %v2float_2_0p5\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 5.5f),
// Test case 7: Fold (0.0, 0.0) dot v
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%v = OpVariable %_ptr_v2float Function\n" +
"%2 = OpLoad %v2float %v\n" +
"%3 = OpDot %float %v2float_0_0 %2\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, 0.0f),
// Test case 8: Fold v dot (0.0, 0.0)
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%v = OpVariable %_ptr_v2float Function\n" +
"%2 = OpLoad %v2float %v\n" +
"%3 = OpDot %float %2 %v2float_0_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, 0.0f),
// Test case 9: Fold Null dot v
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%v = OpVariable %_ptr_v2float Function\n" +
"%2 = OpLoad %v2float %v\n" +
"%3 = OpDot %float %v2float_null %2\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, 0.0f),
// Test case 10: Fold v dot Null
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%v = OpVariable %_ptr_v2float Function\n" +
"%2 = OpLoad %v2float %v\n" +
"%3 = OpDot %float %2 %v2float_null\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, 0.0f)
));
// clang-format on
@ -886,7 +950,55 @@ INSTANTIATE_TEST_CASE_P(DoubleConstantFoldingTest, DoubleInstructionFoldingTest,
"%2 = OpFDiv %double %double_n1 %double_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, -std::numeric_limits<double>::infinity())
2, -std::numeric_limits<double>::infinity()),
// Test case 5: Fold (2.0, 3.0) dot (2.0, 0.5)
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpDot %double %v2double_2_3 %v2double_2_0p5\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 5.5f),
// Test case 6: Fold (0.0, 0.0) dot v
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%v = OpVariable %_ptr_v2double Function\n" +
"%2 = OpLoad %v2double %v\n" +
"%3 = OpDot %double %v2double_0_0 %2\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, 0.0f),
// Test case 7: Fold v dot (0.0, 0.0)
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%v = OpVariable %_ptr_v2double Function\n" +
"%2 = OpLoad %v2double %v\n" +
"%3 = OpDot %double %2 %v2double_0_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, 0.0f),
// Test case 8: Fold Null dot v
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%v = OpVariable %_ptr_v2double Function\n" +
"%2 = OpLoad %v2double %v\n" +
"%3 = OpDot %double %v2double_null %2\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, 0.0f),
// Test case 9: Fold v dot Null
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%v = OpVariable %_ptr_v2double Function\n" +
"%2 = OpLoad %v2double %v\n" +
"%3 = OpDot %double %2 %v2double_null\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, 0.0f)
));
// clang-format on
@ -4466,5 +4578,87 @@ INSTANTIATE_TEST_CASE_P(CompositeExtractMatchingTest, MatchingInstructionFolding
"OpFunctionEnd",
4, true)
));
INSTANTIATE_TEST_CASE_P(DotProductMatchingTest, MatchingInstructionFoldingTest,
::testing::Values(
// Test case 0: Using OpDot to extract last element.
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
"; CHECK: %3 = OpCompositeExtract [[float]] %2 3\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4float Function\n" +
"%2 = OpLoad %v4float %n\n" +
"%3 = OpDot %float %2 %v4float_0_0_0_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 1: Using OpDot to extract last element.
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
"; CHECK: %3 = OpCompositeExtract [[float]] %2 3\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4float Function\n" +
"%2 = OpLoad %v4float %n\n" +
"%3 = OpDot %float %v4float_0_0_0_1 %2\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 2: Using OpDot to extract second element.
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
"; CHECK: %3 = OpCompositeExtract [[float]] %2 1\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4float Function\n" +
"%2 = OpLoad %v4float %n\n" +
"%3 = OpDot %float %v4float_0_1_0_0 %2\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 3: Using OpDot to extract last element.
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" +
"; CHECK: %3 = OpCompositeExtract [[double]] %2 3\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4double Function\n" +
"%2 = OpLoad %v4double %n\n" +
"%3 = OpDot %double %2 %v4double_0_0_0_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 4: Using OpDot to extract last element.
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" +
"; CHECK: %3 = OpCompositeExtract [[double]] %2 3\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4double Function\n" +
"%2 = OpLoad %v4double %n\n" +
"%3 = OpDot %double %v4double_0_0_0_1 %2\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 5: Using OpDot to extract second element.
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" +
"; CHECK: %3 = OpCompositeExtract [[double]] %2 1\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v4double Function\n" +
"%2 = OpLoad %v4double %n\n" +
"%3 = OpDot %double %v4double_0_1_0_0 %2\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true)
));
#endif
} // anonymous namespace