mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2025-01-13 09:50:06 +00:00
Fold OpDot
Adding three rules to fold OpDot (implemented as two). - When an OpDot has two constants, then fold to the resulting const. - When one of the inputs is the 0 vector, then fold to zero. - When one of the inputs is a single 1 with 0s, then rewrite to an OpCompositeExtract of the appropriate element. This will help find even more folding opportunities. Contributes to #709.
This commit is contained in:
parent
3020104ff2
commit
53bc1623ec
@ -142,29 +142,6 @@ using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
|
||||
const analysis::Type* result_type, const analysis::Constant* a,
|
||||
const analysis::Constant* b, analysis::ConstantManager*)>;
|
||||
|
||||
// Returns an std::vector containing the elements of |constant|. The type of
|
||||
// |constant| must be |Vector|.
|
||||
std::vector<const analysis::Constant*> GetVectorComponents(
|
||||
const analysis::Constant* constant, analysis::ConstantManager* const_mgr) {
|
||||
std::vector<const analysis::Constant*> components;
|
||||
const analysis::VectorConstant* a = constant->AsVectorConstant();
|
||||
const analysis::Vector* vector_type = constant->type()->AsVector();
|
||||
assert(vector_type != nullptr);
|
||||
if (a != nullptr) {
|
||||
for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
|
||||
components.push_back(a->GetComponents()[i]);
|
||||
}
|
||||
} else {
|
||||
const analysis::Type* element_type = vector_type->element_type();
|
||||
const analysis::Constant* element_null_const =
|
||||
const_mgr->GetConstant(element_type, {});
|
||||
for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
|
||||
components.push_back(element_null_const);
|
||||
}
|
||||
}
|
||||
return components;
|
||||
}
|
||||
|
||||
// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
|
||||
// using |scalar_rule| and unary float point vectors ops by applying
|
||||
// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
|
||||
@ -193,7 +170,7 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
|
||||
std::vector<const analysis::Constant*> a_components;
|
||||
std::vector<const analysis::Constant*> results_components;
|
||||
|
||||
a_components = GetVectorComponents(constants[0], const_mgr);
|
||||
a_components = constants[0]->GetVectorComponents(const_mgr);
|
||||
|
||||
// Fold each component of the vector.
|
||||
for (uint32_t i = 0; i < a_components.size(); ++i) {
|
||||
@ -244,8 +221,8 @@ ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
|
||||
std::vector<const analysis::Constant*> b_components;
|
||||
std::vector<const analysis::Constant*> results_components;
|
||||
|
||||
a_components = GetVectorComponents(constants[0], const_mgr);
|
||||
b_components = GetVectorComponents(constants[1], const_mgr);
|
||||
a_components = constants[0]->GetVectorComponents(const_mgr);
|
||||
b_components = constants[1]->GetVectorComponents(const_mgr);
|
||||
|
||||
// Fold each component of the vector.
|
||||
for (uint32_t i = 0; i < a_components.size(); ++i) {
|
||||
@ -334,28 +311,29 @@ UnaryScalarFoldingRule FoldIToFOp() {
|
||||
|
||||
// This macro defines a |BinaryScalarFoldingRule| that applies |op|. The
|
||||
// operator |op| must work for both float and double, and use syntax "f1 op f2".
|
||||
#define FOLD_FPARITH_OP(op) \
|
||||
[](const analysis::Type* result_type, const analysis::Constant* a, \
|
||||
const analysis::Constant* b, \
|
||||
analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
|
||||
assert(result_type != nullptr && a != nullptr && b != nullptr); \
|
||||
assert(result_type == a->type() && result_type == b->type()); \
|
||||
const analysis::Float* float_type = result_type->AsFloat(); \
|
||||
assert(float_type != nullptr); \
|
||||
if (float_type->width() == 32) { \
|
||||
float fa = a->GetFloat(); \
|
||||
float fb = b->GetFloat(); \
|
||||
spvutils::FloatProxy<float> result(fa op fb); \
|
||||
std::vector<uint32_t> words = result.GetWords(); \
|
||||
return const_mgr->GetConstant(result_type, words); \
|
||||
} else if (float_type->width() == 64) { \
|
||||
double fa = a->GetDouble(); \
|
||||
double fb = b->GetDouble(); \
|
||||
spvutils::FloatProxy<double> result(fa op fb); \
|
||||
std::vector<uint32_t> words = result.GetWords(); \
|
||||
return const_mgr->GetConstant(result_type, words); \
|
||||
} \
|
||||
return nullptr; \
|
||||
#define FOLD_FPARITH_OP(op) \
|
||||
[](const analysis::Type* result_type, const analysis::Constant* a, \
|
||||
const analysis::Constant* b, \
|
||||
analysis::ConstantManager* const_mgr_in_macro) \
|
||||
-> const analysis::Constant* { \
|
||||
assert(result_type != nullptr && a != nullptr && b != nullptr); \
|
||||
assert(result_type == a->type() && result_type == b->type()); \
|
||||
const analysis::Float* float_type_in_macro = result_type->AsFloat(); \
|
||||
assert(float_type_in_macro != nullptr); \
|
||||
if (float_type_in_macro->width() == 32) { \
|
||||
float fa = a->GetFloat(); \
|
||||
float fb = b->GetFloat(); \
|
||||
spvutils::FloatProxy<float> result_in_macro(fa op fb); \
|
||||
std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
|
||||
return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
|
||||
} else if (float_type_in_macro->width() == 64) { \
|
||||
double fa = a->GetDouble(); \
|
||||
double fb = b->GetDouble(); \
|
||||
spvutils::FloatProxy<double> result_in_macro(fa op fb); \
|
||||
std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
|
||||
return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
|
||||
} \
|
||||
return nullptr; \
|
||||
}
|
||||
|
||||
// Define the folding rule for conversion between floating point and integer
|
||||
@ -447,6 +425,79 @@ ConstantFoldingRule FoldFOrdGreaterThanEqual() {
|
||||
ConstantFoldingRule FoldFUnordGreaterThanEqual() {
|
||||
return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
|
||||
}
|
||||
|
||||
// Folds an OpDot where all of the inputs are constants to a
|
||||
// constant. A new constant is created if necessary.
|
||||
ConstantFoldingRule FoldOpDotWithConstants() {
|
||||
return [](ir::Instruction* inst,
|
||||
const std::vector<const analysis::Constant*>& constants)
|
||||
-> const analysis::Constant* {
|
||||
ir::IRContext* context = inst->context();
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
analysis::TypeManager* type_mgr = context->get_type_mgr();
|
||||
const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
|
||||
assert(new_type->AsFloat() && "OpDot should have a float return type.");
|
||||
const analysis::Float* float_type = new_type->AsFloat();
|
||||
|
||||
if (!inst->IsFloatingPointFoldingAllowed()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// If one of the operands is 0, then the result is 0.
|
||||
bool has_zero_operand = false;
|
||||
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
if (constants[i]) {
|
||||
if (constants[i]->AsNullConstant() ||
|
||||
constants[i]->AsVectorConstant()->IsZero()) {
|
||||
has_zero_operand = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (has_zero_operand) {
|
||||
if (float_type->width() == 32) {
|
||||
spvutils::FloatProxy<float> result(0.0f);
|
||||
std::vector<uint32_t> words = result.GetWords();
|
||||
return const_mgr->GetConstant(float_type, words);
|
||||
}
|
||||
if (float_type->width() == 64) {
|
||||
spvutils::FloatProxy<double> result(0.0);
|
||||
std::vector<uint32_t> words = result.GetWords();
|
||||
return const_mgr->GetConstant(float_type, words);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (constants[0] == nullptr || constants[1] == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<const analysis::Constant*> a_components;
|
||||
std::vector<const analysis::Constant*> b_components;
|
||||
|
||||
a_components = constants[0]->GetVectorComponents(const_mgr);
|
||||
b_components = constants[1]->GetVectorComponents(const_mgr);
|
||||
|
||||
spvutils::FloatProxy<double> result(0.0);
|
||||
std::vector<uint32_t> words = result.GetWords();
|
||||
const analysis::Constant* result_const =
|
||||
const_mgr->GetConstant(float_type, words);
|
||||
for (uint32_t i = 0; i < a_components.size(); ++i) {
|
||||
if (a_components[i] == nullptr || b_components[i] == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const analysis::Constant* component = FOLD_FPARITH_OP(*)(
|
||||
new_type, a_components[i], b_components[i], const_mgr);
|
||||
result_const =
|
||||
FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
|
||||
}
|
||||
return result_const;
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
spvtools::opt::ConstantFoldingRules::ConstantFoldingRules() {
|
||||
@ -464,6 +515,7 @@ spvtools::opt::ConstantFoldingRules::ConstantFoldingRules() {
|
||||
rules_[SpvOpConvertSToF].push_back(FoldIToF());
|
||||
rules_[SpvOpConvertUToF].push_back(FoldIToF());
|
||||
|
||||
rules_[SpvOpDot].push_back(FoldOpDotWithConstants());
|
||||
rules_[SpvOpFAdd].push_back(FoldFAdd());
|
||||
rules_[SpvOpFDiv].push_back(FoldFDiv());
|
||||
rules_[SpvOpFMul].push_back(FoldFMul());
|
||||
|
@ -306,6 +306,37 @@ const Constant* ConstantManager::GetConstant(
|
||||
return cst ? RegisterConstant(cst) : nullptr;
|
||||
}
|
||||
|
||||
bool VectorConstant::IsZero() const {
|
||||
for (const Constant* component : GetComponents()) {
|
||||
if (!component->AsNullConstant() &&
|
||||
!component->AsScalarConstant()->IsZero()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<const analysis::Constant*> Constant::GetVectorComponents(
|
||||
analysis::ConstantManager* const_mgr) const {
|
||||
std::vector<const analysis::Constant*> components;
|
||||
const analysis::VectorConstant* a = this->AsVectorConstant();
|
||||
const analysis::Vector* vector_type = this->type()->AsVector();
|
||||
assert(vector_type != nullptr);
|
||||
if (a != nullptr) {
|
||||
for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
|
||||
components.push_back(a->GetComponents()[i]);
|
||||
}
|
||||
} else {
|
||||
const analysis::Type* element_type = vector_type->element_type();
|
||||
const analysis::Constant* element_null_const =
|
||||
const_mgr->GetConstant(element_type, {});
|
||||
for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
|
||||
components.push_back(element_null_const);
|
||||
}
|
||||
}
|
||||
return components;
|
||||
}
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
||||
|
@ -47,6 +47,7 @@ class VectorConstant;
|
||||
class MatrixConstant;
|
||||
class ArrayConstant;
|
||||
class NullConstant;
|
||||
class ConstantManager;
|
||||
|
||||
// Abstract class for a SPIR-V constant. It has a bunch of As<subclass> methods,
|
||||
// which is used as a way to probe the actual <subclass>
|
||||
@ -109,6 +110,11 @@ class Constant {
|
||||
|
||||
const Type* type() const { return type_; }
|
||||
|
||||
// Returns an std::vector containing the elements of |constant|. The type of
|
||||
// |constant| must be |Vector|.
|
||||
std::vector<const Constant*> GetVectorComponents(
|
||||
ConstantManager* const_mgr) const;
|
||||
|
||||
protected:
|
||||
Constant(const Type* ty) : type_(ty) {}
|
||||
|
||||
@ -334,6 +340,9 @@ class VectorConstant : public CompositeConstant {
|
||||
|
||||
const Type* component_type() const { return component_type_; }
|
||||
|
||||
// Returns true if the vector is all zeros.
|
||||
bool IsZero() const;
|
||||
|
||||
private:
|
||||
const Type* component_type_;
|
||||
};
|
||||
|
@ -13,6 +13,9 @@
|
||||
// limitations under the License.
|
||||
|
||||
#include "folding_rules.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "latest_version_glsl_std_450_header.h"
|
||||
|
||||
namespace spvtools {
|
||||
@ -1812,6 +1815,79 @@ FoldingRule RedundantFMix() {
|
||||
};
|
||||
}
|
||||
|
||||
// This rule look for a dot with a constant vector containing a single 1 and
|
||||
// the rest 0s. This is the same as doing an extract.
|
||||
FoldingRule DotProductDoingExtract() {
|
||||
return [](ir::Instruction* inst,
|
||||
const std::vector<const analysis::Constant*>& constants) {
|
||||
assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");
|
||||
|
||||
ir::IRContext* context = inst->context();
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
|
||||
if (!inst->IsFloatingPointFoldingAllowed()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
if (!constants[i]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const analysis::Vector* vector_type = constants[i]->type()->AsVector();
|
||||
assert(vector_type && "Inputs to OpDot must be vectors.");
|
||||
const analysis::Float* element_type =
|
||||
vector_type->element_type()->AsFloat();
|
||||
assert(element_type && "Inputs to OpDot must be vectors of floats.");
|
||||
uint32_t element_width = element_type->width();
|
||||
if (element_width != 32 && element_width != 64) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<const analysis::Constant*> components;
|
||||
components = constants[i]->GetVectorComponents(const_mgr);
|
||||
|
||||
const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
|
||||
|
||||
uint32_t component_with_one = kNotFound;
|
||||
bool all_others_zero = true;
|
||||
for (uint32_t j = 0; j < components.size(); ++j) {
|
||||
const analysis::Constant* element = components[j];
|
||||
double value =
|
||||
(element_width == 32 ? element->GetFloat() : element->GetDouble());
|
||||
if (value == 0.0) {
|
||||
continue;
|
||||
} else if (value == 1.0) {
|
||||
if (component_with_one == kNotFound) {
|
||||
component_with_one = j;
|
||||
} else {
|
||||
component_with_one = kNotFound;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
all_others_zero = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!all_others_zero || component_with_one == kNotFound) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<ir::Operand> operands;
|
||||
operands.push_back(
|
||||
{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
|
||||
operands.push_back(
|
||||
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
|
||||
|
||||
inst->SetOpcode(SpvOpCompositeExtract);
|
||||
inst->SetInOperands(std::move(operands));
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
spvtools::opt::FoldingRules::FoldingRules() {
|
||||
@ -1826,6 +1902,8 @@ spvtools::opt::FoldingRules::FoldingRules() {
|
||||
rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
|
||||
rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
|
||||
|
||||
rules_[SpvOpDot].push_back(DotProductDoingExtract());
|
||||
|
||||
rules_[SpvOpExtInst].push_back(RedundantFMix());
|
||||
|
||||
rules_[SpvOpFAdd].push_back(RedundantFAdd());
|
||||
|
@ -143,6 +143,7 @@ OpName %main "main"
|
||||
%v4float = OpTypeVector %float 4
|
||||
%v4double = OpTypeVector %double 4
|
||||
%v2float = OpTypeVector %float 2
|
||||
%v2double = OpTypeVector %double 2
|
||||
%v2bool = OpTypeVector %bool 2
|
||||
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
|
||||
%_ptr_int = OpTypePointer Function %int
|
||||
@ -158,6 +159,7 @@ OpName %main "main"
|
||||
%_ptr_v4double = OpTypePointer Function %v4double
|
||||
%_ptr_struct_v2int_int_int = OpTypePointer Function %struct_v2int_int_int
|
||||
%_ptr_v2float = OpTypePointer Function %v2float
|
||||
%_ptr_v2double = OpTypePointer Function %v2double
|
||||
%short_0 = OpConstant %short 0
|
||||
%short_3 = OpConstant %short 3
|
||||
%100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps.
|
||||
@ -196,6 +198,7 @@ OpName %main "main"
|
||||
%float16_2 = OpConstant %float16 2
|
||||
%float_n1 = OpConstant %float -1
|
||||
%104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps.
|
||||
%float_null = OpConstantNull %float
|
||||
%float_0 = OpConstant %float 0
|
||||
%float_half = OpConstant %float 0.5
|
||||
%float_1 = OpConstant %float 1
|
||||
@ -203,6 +206,7 @@ OpName %main "main"
|
||||
%float_3 = OpConstant %float 3
|
||||
%float_4 = OpConstant %float 4
|
||||
%float_0p5 = OpConstant %float 0.5
|
||||
%v2float_0_0 = OpConstantComposite %v2float %float_0 %float_0
|
||||
%v2float_2_2 = OpConstantComposite %v2float %float_2 %float_2
|
||||
%v2float_2_3 = OpConstantComposite %v2float %float_2 %float_3
|
||||
%v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
|
||||
@ -211,10 +215,20 @@ OpName %main "main"
|
||||
%v2float_null = OpConstantNull %v2float
|
||||
%double_n1 = OpConstant %double -1
|
||||
%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
|
||||
%double_null = OpConstantNull %double
|
||||
%double_0 = OpConstant %double 0
|
||||
%double_1 = OpConstant %double 1
|
||||
%double_2 = OpConstant %double 2
|
||||
%double_3 = OpConstant %double 3
|
||||
%double_4 = OpConstant %double 4
|
||||
%double_0p5 = OpConstant %double 0.5
|
||||
%v2double_0_0 = OpConstantComposite %v2double %double_0 %double_0
|
||||
%v2double_2_2 = OpConstantComposite %v2double %double_2 %double_2
|
||||
%v2double_2_3 = OpConstantComposite %v2double %double_2 %double_3
|
||||
%v2double_3_2 = OpConstantComposite %v2double %double_3 %double_2
|
||||
%v2double_4_4 = OpConstantComposite %v2double %double_4 %double_4
|
||||
%v2double_2_0p5 = OpConstantComposite %v2double %double_2 %double_0p5
|
||||
%v2double_null = OpConstantNull %v2double
|
||||
%float_nan = OpConstant %float -0x1.8p+128
|
||||
%double_nan = OpConstant %double -0x1.8p+1024
|
||||
%108 = OpConstant %half 0
|
||||
@ -222,10 +236,12 @@ OpName %main "main"
|
||||
%106 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
|
||||
%v4float_0_0_0_0 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
|
||||
%v4float_0_0_0_1 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_1
|
||||
%v4float_0_1_0_0 = OpConstantComposite %v4float %float_0 %float_1 %float_null %float_0
|
||||
%v4float_1_1_1_1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
|
||||
%107 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0
|
||||
%v4double_0_0_0_0 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0
|
||||
%v4double_0_0_0_1 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_1
|
||||
%v4double_0_1_0_0 = OpConstantComposite %v4double %double_0 %double_1 %double_null %double_0
|
||||
%v4double_1_1_1_1 = OpConstantComposite %v4double %double_1 %double_1 %double_1 %double_1
|
||||
)";
|
||||
|
||||
@ -799,7 +815,55 @@ INSTANTIATE_TEST_CASE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
|
||||
"%2 = OpFDiv %float %float_n1 %float_0\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, -std::numeric_limits<float>::infinity())
|
||||
2, -std::numeric_limits<float>::infinity()),
|
||||
// Test case 6: Fold (2.0, 3.0) dot (2.0, 0.5)
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpDot %float %v2float_2_3 %v2float_2_0p5\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 5.5f),
|
||||
// Test case 7: Fold (0.0, 0.0) dot v
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%v = OpVariable %_ptr_v2float Function\n" +
|
||||
"%2 = OpLoad %v2float %v\n" +
|
||||
"%3 = OpDot %float %v2float_0_0 %2\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, 0.0f),
|
||||
// Test case 8: Fold v dot (0.0, 0.0)
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%v = OpVariable %_ptr_v2float Function\n" +
|
||||
"%2 = OpLoad %v2float %v\n" +
|
||||
"%3 = OpDot %float %2 %v2float_0_0\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, 0.0f),
|
||||
// Test case 9: Fold Null dot v
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%v = OpVariable %_ptr_v2float Function\n" +
|
||||
"%2 = OpLoad %v2float %v\n" +
|
||||
"%3 = OpDot %float %v2float_null %2\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, 0.0f),
|
||||
// Test case 10: Fold v dot Null
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%v = OpVariable %_ptr_v2float Function\n" +
|
||||
"%2 = OpLoad %v2float %v\n" +
|
||||
"%3 = OpDot %float %2 %v2float_null\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, 0.0f)
|
||||
));
|
||||
// clang-format on
|
||||
|
||||
@ -886,7 +950,55 @@ INSTANTIATE_TEST_CASE_P(DoubleConstantFoldingTest, DoubleInstructionFoldingTest,
|
||||
"%2 = OpFDiv %double %double_n1 %double_0\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, -std::numeric_limits<double>::infinity())
|
||||
2, -std::numeric_limits<double>::infinity()),
|
||||
// Test case 5: Fold (2.0, 3.0) dot (2.0, 0.5)
|
||||
InstructionFoldingCase<double>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpDot %double %v2double_2_3 %v2double_2_0p5\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 5.5f),
|
||||
// Test case 6: Fold (0.0, 0.0) dot v
|
||||
InstructionFoldingCase<double>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%v = OpVariable %_ptr_v2double Function\n" +
|
||||
"%2 = OpLoad %v2double %v\n" +
|
||||
"%3 = OpDot %double %v2double_0_0 %2\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, 0.0f),
|
||||
// Test case 7: Fold v dot (0.0, 0.0)
|
||||
InstructionFoldingCase<double>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%v = OpVariable %_ptr_v2double Function\n" +
|
||||
"%2 = OpLoad %v2double %v\n" +
|
||||
"%3 = OpDot %double %2 %v2double_0_0\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, 0.0f),
|
||||
// Test case 8: Fold Null dot v
|
||||
InstructionFoldingCase<double>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%v = OpVariable %_ptr_v2double Function\n" +
|
||||
"%2 = OpLoad %v2double %v\n" +
|
||||
"%3 = OpDot %double %v2double_null %2\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, 0.0f),
|
||||
// Test case 9: Fold v dot Null
|
||||
InstructionFoldingCase<double>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%v = OpVariable %_ptr_v2double Function\n" +
|
||||
"%2 = OpLoad %v2double %v\n" +
|
||||
"%3 = OpDot %double %2 %v2double_null\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, 0.0f)
|
||||
));
|
||||
// clang-format on
|
||||
|
||||
@ -4466,5 +4578,87 @@ INSTANTIATE_TEST_CASE_P(CompositeExtractMatchingTest, MatchingInstructionFolding
|
||||
"OpFunctionEnd",
|
||||
4, true)
|
||||
));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(DotProductMatchingTest, MatchingInstructionFoldingTest,
|
||||
::testing::Values(
|
||||
// Test case 0: Using OpDot to extract last element.
|
||||
InstructionFoldingCase<bool>(
|
||||
Header() +
|
||||
"; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
|
||||
"; CHECK: %3 = OpCompositeExtract [[float]] %2 3\n" +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%n = OpVariable %_ptr_v4float Function\n" +
|
||||
"%2 = OpLoad %v4float %n\n" +
|
||||
"%3 = OpDot %float %2 %v4float_0_0_0_1\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, true),
|
||||
// Test case 1: Using OpDot to extract last element.
|
||||
InstructionFoldingCase<bool>(
|
||||
Header() +
|
||||
"; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
|
||||
"; CHECK: %3 = OpCompositeExtract [[float]] %2 3\n" +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%n = OpVariable %_ptr_v4float Function\n" +
|
||||
"%2 = OpLoad %v4float %n\n" +
|
||||
"%3 = OpDot %float %v4float_0_0_0_1 %2\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, true),
|
||||
// Test case 2: Using OpDot to extract second element.
|
||||
InstructionFoldingCase<bool>(
|
||||
Header() +
|
||||
"; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
|
||||
"; CHECK: %3 = OpCompositeExtract [[float]] %2 1\n" +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%n = OpVariable %_ptr_v4float Function\n" +
|
||||
"%2 = OpLoad %v4float %n\n" +
|
||||
"%3 = OpDot %float %v4float_0_1_0_0 %2\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, true),
|
||||
// Test case 3: Using OpDot to extract last element.
|
||||
InstructionFoldingCase<bool>(
|
||||
Header() +
|
||||
"; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" +
|
||||
"; CHECK: %3 = OpCompositeExtract [[double]] %2 3\n" +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%n = OpVariable %_ptr_v4double Function\n" +
|
||||
"%2 = OpLoad %v4double %n\n" +
|
||||
"%3 = OpDot %double %2 %v4double_0_0_0_1\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, true),
|
||||
// Test case 4: Using OpDot to extract last element.
|
||||
InstructionFoldingCase<bool>(
|
||||
Header() +
|
||||
"; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" +
|
||||
"; CHECK: %3 = OpCompositeExtract [[double]] %2 3\n" +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%n = OpVariable %_ptr_v4double Function\n" +
|
||||
"%2 = OpLoad %v4double %n\n" +
|
||||
"%3 = OpDot %double %v4double_0_0_0_1 %2\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, true),
|
||||
// Test case 5: Using OpDot to extract second element.
|
||||
InstructionFoldingCase<bool>(
|
||||
Header() +
|
||||
"; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" +
|
||||
"; CHECK: %3 = OpCompositeExtract [[double]] %2 1\n" +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%n = OpVariable %_ptr_v4double Function\n" +
|
||||
"%2 = OpLoad %v4double %n\n" +
|
||||
"%3 = OpDot %double %v4double_0_1_0_0 %2\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, true)
|
||||
));
|
||||
#endif
|
||||
} // anonymous namespace
|
||||
|
Loading…
Reference in New Issue
Block a user