Fold Fmix should accept vector operands. (#2826)

Fixes #2819
This commit is contained in:
Steven Perron 2019-09-03 09:17:18 -04:00 committed by GitHub
parent 2c5ed16ba9
commit b54d950298
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 154 additions and 70 deletions

View File

@ -296,6 +296,51 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
}; };
} }
// Returns the result of folding the constants in |constants| according the
// |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied
// per component.
const analysis::Constant* FoldFPBinaryOp(
BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
const std::vector<const analysis::Constant*>& constants,
IRContext* context) {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
const analysis::Type* result_type = type_mgr->GetType(result_type_id);
const analysis::Vector* vector_type = result_type->AsVector();
if (constants[0] == nullptr || constants[1] == nullptr) {
return nullptr;
}
if (vector_type != nullptr) {
std::vector<const analysis::Constant*> a_components;
std::vector<const analysis::Constant*> b_components;
std::vector<const analysis::Constant*> results_components;
a_components = constants[0]->GetVectorComponents(const_mgr);
b_components = constants[1]->GetVectorComponents(const_mgr);
// Fold each component of the vector.
for (uint32_t i = 0; i < a_components.size(); ++i) {
results_components.push_back(scalar_rule(vector_type->element_type(),
a_components[i], b_components[i],
const_mgr));
if (results_components[i] == nullptr) {
return nullptr;
}
}
// Build the constant object and return it.
std::vector<uint32_t> ids;
for (const analysis::Constant* member : results_components) {
ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
}
return const_mgr->GetConstant(vector_type, ids);
} else {
return scalar_rule(result_type, constants[0], constants[1], const_mgr);
}
}
// Returns a |ConstantFoldingRule| that folds floating point scalars using // Returns a |ConstantFoldingRule| that folds floating point scalars using
// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
// elements of the vector. The |ConstantFoldingRule| that is returned assumes // elements of the vector. The |ConstantFoldingRule| that is returned assumes
@ -305,46 +350,10 @@ ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
return [scalar_rule](IRContext* context, Instruction* inst, return [scalar_rule](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* { -> const analysis::Constant* {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
const analysis::Vector* vector_type = result_type->AsVector();
if (!inst->IsFloatingPointFoldingAllowed()) { if (!inst->IsFloatingPointFoldingAllowed()) {
return nullptr; return nullptr;
} }
return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
if (constants[0] == nullptr || constants[1] == nullptr) {
return nullptr;
}
if (vector_type != nullptr) {
std::vector<const analysis::Constant*> a_components;
std::vector<const analysis::Constant*> b_components;
std::vector<const analysis::Constant*> results_components;
a_components = constants[0]->GetVectorComponents(const_mgr);
b_components = constants[1]->GetVectorComponents(const_mgr);
// Fold each component of the vector.
for (uint32_t i = 0; i < a_components.size(); ++i) {
results_components.push_back(scalar_rule(vector_type->element_type(),
a_components[i],
b_components[i], const_mgr));
if (results_components[i] == nullptr) {
return nullptr;
}
}
// Build the constant object and return it.
std::vector<uint32_t> ids;
for (const analysis::Constant* member : results_components) {
ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
}
return const_mgr->GetConstant(vector_type, ids);
} else {
return scalar_rule(result_type, constants[0], constants[1], const_mgr);
}
}; };
} }
@ -435,29 +444,33 @@ UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
// This macro defines a |BinaryScalarFoldingRule| that applies |op|. The // This macro defines a |BinaryScalarFoldingRule| that applies |op|. The
// operator |op| must work for both float and double, and use syntax "f1 op f2". // operator |op| must work for both float and double, and use syntax "f1 op f2".
#define FOLD_FPARITH_OP(op) \ #define FOLD_FPARITH_OP(op) \
[](const analysis::Type* result_type, const analysis::Constant* a, \ [](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \
const analysis::Constant* b, \ const analysis::Constant* b, \
analysis::ConstantManager* const_mgr_in_macro) \ analysis::ConstantManager* const_mgr_in_macro) \
-> const analysis::Constant* { \ -> const analysis::Constant* { \
assert(result_type != nullptr && a != nullptr && b != nullptr); \ assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr); \
assert(result_type == a->type() && result_type == b->type()); \ assert(result_type_in_macro == a->type() && \
const analysis::Float* float_type_in_macro = result_type->AsFloat(); \ result_type_in_macro == b->type()); \
assert(float_type_in_macro != nullptr); \ const analysis::Float* float_type_in_macro = \
if (float_type_in_macro->width() == 32) { \ result_type_in_macro->AsFloat(); \
float fa = a->GetFloat(); \ assert(float_type_in_macro != nullptr); \
float fb = b->GetFloat(); \ if (float_type_in_macro->width() == 32) { \
utils::FloatProxy<float> result_in_macro(fa op fb); \ float fa = a->GetFloat(); \
std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \ float fb = b->GetFloat(); \
return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \ utils::FloatProxy<float> result_in_macro(fa op fb); \
} else if (float_type_in_macro->width() == 64) { \ std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
double fa = a->GetDouble(); \ return const_mgr_in_macro->GetConstant(result_type_in_macro, \
double fb = b->GetDouble(); \ words_in_macro); \
utils::FloatProxy<double> result_in_macro(fa op fb); \ } else if (float_type_in_macro->width() == 64) { \
std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \ double fa = a->GetDouble(); \
return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \ double fb = b->GetDouble(); \
} \ utils::FloatProxy<double> result_in_macro(fa op fb); \
return nullptr; \ std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
return const_mgr_in_macro->GetConstant(result_type_in_macro, \
words_in_macro); \
} \
return nullptr; \
} }
// Define the folding rule for conversion between floating point and integer // Define the folding rule for conversion between floating point and integer
@ -834,31 +847,49 @@ ConstantFoldingRule FoldFMix() {
} }
const analysis::Constant* one; const analysis::Constant* one;
if (constants[1]->type()->AsFloat()->width() == 32) { bool is_vector = false;
one = const_mgr->GetConstant(constants[1]->type(), const analysis::Type* result_type = constants[1]->type();
const analysis::Type* base_type = result_type;
if (base_type->AsVector()) {
is_vector = true;
base_type = base_type->AsVector()->element_type();
}
assert(base_type->AsFloat() != nullptr &&
"FMix is suppose to act on floats or vectors of floats.");
if (base_type->AsFloat()->width() == 32) {
one = const_mgr->GetConstant(base_type,
utils::FloatProxy<float>(1.0f).GetWords()); utils::FloatProxy<float>(1.0f).GetWords());
} else { } else {
one = const_mgr->GetConstant(constants[1]->type(), one = const_mgr->GetConstant(base_type,
utils::FloatProxy<double>(1.0).GetWords()); utils::FloatProxy<double>(1.0).GetWords());
} }
const analysis::Constant* temp1 = if (is_vector) {
FOLD_FPARITH_OP(-)(constants[1]->type(), one, constants[3], const_mgr); uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id();
one =
const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id));
}
const analysis::Constant* temp1 = FoldFPBinaryOp(
FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
if (temp1 == nullptr) { if (temp1 == nullptr) {
return nullptr; return nullptr;
} }
const analysis::Constant* temp2 = FOLD_FPARITH_OP(*)( const analysis::Constant* temp2 = FoldFPBinaryOp(
constants[1]->type(), constants[1], temp1, const_mgr); FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
if (temp2 == nullptr) { if (temp2 == nullptr) {
return nullptr; return nullptr;
} }
const analysis::Constant* temp3 = FOLD_FPARITH_OP(*)( const analysis::Constant* temp3 =
constants[2]->type(), constants[2], constants[3], const_mgr); FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
{constants[2], constants[3]}, context);
if (temp3 == nullptr) { if (temp3 == nullptr) {
return nullptr; return nullptr;
} }
return FOLD_FPARITH_OP(+)(temp2->type(), temp2, temp3, const_mgr); return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
context);
}; };
} }

View File

@ -222,6 +222,7 @@ OpName %main "main"
%v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2 %v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
%v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4 %v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4
%v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5 %v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5
%v2float_0p2_0p5 = OpConstantComposite %v2float %float_0p2 %float_0p5
%v2float_null = OpConstantNull %v2float %v2float_null = OpConstantNull %v2float
%double_n1 = OpConstant %double -1 %double_n1 = OpConstant %double -1
%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps. %105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
@ -643,6 +644,58 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntVectorInstructionFoldingTest,
)); ));
// clang-format on // clang-format on
using FloatVectorInstructionFoldingTest =
::testing::TestWithParam<InstructionFoldingCase<std::vector<float>>>;
TEST_P(FloatVectorInstructionFoldingTest, Case) {
const auto& tc = GetParam();
// Build module.
std::unique_ptr<IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
ASSERT_NE(nullptr, context);
// Fold the instruction to test.
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
SpvOp original_opcode = inst->opcode();
bool succeeded = context->get_instruction_folder().FoldInstruction(inst);
// Make sure the instruction folded as expected.
EXPECT_EQ(succeeded, inst == nullptr || inst->opcode() != original_opcode);
if (succeeded && inst != nullptr) {
EXPECT_EQ(inst->opcode(), SpvOpCopyObject);
inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
std::vector<SpvOp> opcodes = {SpvOpConstantComposite};
EXPECT_THAT(opcodes, Contains(inst->opcode()));
analysis::ConstantManager* const_mrg = context->get_constant_mgr();
const analysis::Constant* result = const_mrg->GetConstantFromInst(inst);
EXPECT_NE(result, nullptr);
if (result != nullptr) {
const std::vector<const analysis::Constant*>& componenets =
result->AsVectorConstant()->GetComponents();
EXPECT_EQ(componenets.size(), tc.expected_result.size());
for (size_t i = 0; i < componenets.size(); i++) {
EXPECT_EQ(tc.expected_result[i], componenets[i]->GetFloat());
}
}
}
}
// clang-format off
INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest,
::testing::Values(
// Test case 0: FMix {2.0, 2.0}, {2.0, 3.0} {0.2,0.5}
InstructionFoldingCase<std::vector<float>>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %v2float %1 FMix %v2float_2_3 %v2float_0_0 %v2float_0p2_0p5\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {1.6f,1.5f})
));
// clang-format on
using BooleanInstructionFoldingTest = using BooleanInstructionFoldingTest =
::testing::TestWithParam<InstructionFoldingCase<bool>>; ::testing::TestWithParam<InstructionFoldingCase<bool>>;