diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index 06a1a81e6..e0a17e991 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -296,6 +296,51 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) { }; } +// Returns the result of folding the constants in |constants| according the +// |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied +// per component. +const analysis::Constant* FoldFPBinaryOp( + BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id, + const std::vector& constants, + IRContext* context) { + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); + const analysis::Type* result_type = type_mgr->GetType(result_type_id); + const analysis::Vector* vector_type = result_type->AsVector(); + + if (constants[0] == nullptr || constants[1] == nullptr) { + return nullptr; + } + + if (vector_type != nullptr) { + std::vector a_components; + std::vector b_components; + std::vector results_components; + + a_components = constants[0]->GetVectorComponents(const_mgr); + b_components = constants[1]->GetVectorComponents(const_mgr); + + // Fold each component of the vector. + for (uint32_t i = 0; i < a_components.size(); ++i) { + results_components.push_back(scalar_rule(vector_type->element_type(), + a_components[i], b_components[i], + const_mgr)); + if (results_components[i] == nullptr) { + return nullptr; + } + } + + // Build the constant object and return it. + std::vector ids; + for (const analysis::Constant* member : results_components) { + ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); + } + return const_mgr->GetConstant(vector_type, ids); + } else { + return scalar_rule(result_type, constants[0], constants[1], const_mgr); + } +} + // Returns a |ConstantFoldingRule| that folds floating point scalars using // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the // elements of the vector. The |ConstantFoldingRule| that is returned assumes @@ -305,46 +350,10 @@ ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) { return [scalar_rule](IRContext* context, Instruction* inst, const std::vector& constants) -> const analysis::Constant* { - analysis::ConstantManager* const_mgr = context->get_constant_mgr(); - analysis::TypeManager* type_mgr = context->get_type_mgr(); - const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); - const analysis::Vector* vector_type = result_type->AsVector(); - if (!inst->IsFloatingPointFoldingAllowed()) { return nullptr; } - - if (constants[0] == nullptr || constants[1] == nullptr) { - return nullptr; - } - - if (vector_type != nullptr) { - std::vector a_components; - std::vector b_components; - std::vector results_components; - - a_components = constants[0]->GetVectorComponents(const_mgr); - b_components = constants[1]->GetVectorComponents(const_mgr); - - // Fold each component of the vector. - for (uint32_t i = 0; i < a_components.size(); ++i) { - results_components.push_back(scalar_rule(vector_type->element_type(), - a_components[i], - b_components[i], const_mgr)); - if (results_components[i] == nullptr) { - return nullptr; - } - } - - // Build the constant object and return it. - std::vector ids; - for (const analysis::Constant* member : results_components) { - ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); - } - return const_mgr->GetConstant(vector_type, ids); - } else { - return scalar_rule(result_type, constants[0], constants[1], const_mgr); - } + return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context); }; } @@ -435,29 +444,33 @@ UnaryScalarFoldingRule FoldQuantizeToF16Scalar() { // This macro defines a |BinaryScalarFoldingRule| that applies |op|. The // operator |op| must work for both float and double, and use syntax "f1 op f2". -#define FOLD_FPARITH_OP(op) \ - [](const analysis::Type* result_type, const analysis::Constant* a, \ - const analysis::Constant* b, \ - analysis::ConstantManager* const_mgr_in_macro) \ - -> const analysis::Constant* { \ - assert(result_type != nullptr && a != nullptr && b != nullptr); \ - assert(result_type == a->type() && result_type == b->type()); \ - const analysis::Float* float_type_in_macro = result_type->AsFloat(); \ - assert(float_type_in_macro != nullptr); \ - if (float_type_in_macro->width() == 32) { \ - float fa = a->GetFloat(); \ - float fb = b->GetFloat(); \ - utils::FloatProxy result_in_macro(fa op fb); \ - std::vector words_in_macro = result_in_macro.GetWords(); \ - return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \ - } else if (float_type_in_macro->width() == 64) { \ - double fa = a->GetDouble(); \ - double fb = b->GetDouble(); \ - utils::FloatProxy result_in_macro(fa op fb); \ - std::vector words_in_macro = result_in_macro.GetWords(); \ - return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \ - } \ - return nullptr; \ +#define FOLD_FPARITH_OP(op) \ + [](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \ + const analysis::Constant* b, \ + analysis::ConstantManager* const_mgr_in_macro) \ + -> const analysis::Constant* { \ + assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr); \ + assert(result_type_in_macro == a->type() && \ + result_type_in_macro == b->type()); \ + const analysis::Float* float_type_in_macro = \ + result_type_in_macro->AsFloat(); \ + assert(float_type_in_macro != nullptr); \ + if (float_type_in_macro->width() == 32) { \ + float fa = a->GetFloat(); \ + float fb = b->GetFloat(); \ + utils::FloatProxy result_in_macro(fa op fb); \ + std::vector words_in_macro = result_in_macro.GetWords(); \ + return const_mgr_in_macro->GetConstant(result_type_in_macro, \ + words_in_macro); \ + } else if (float_type_in_macro->width() == 64) { \ + double fa = a->GetDouble(); \ + double fb = b->GetDouble(); \ + utils::FloatProxy result_in_macro(fa op fb); \ + std::vector words_in_macro = result_in_macro.GetWords(); \ + return const_mgr_in_macro->GetConstant(result_type_in_macro, \ + words_in_macro); \ + } \ + return nullptr; \ } // Define the folding rule for conversion between floating point and integer @@ -834,31 +847,49 @@ ConstantFoldingRule FoldFMix() { } const analysis::Constant* one; - if (constants[1]->type()->AsFloat()->width() == 32) { - one = const_mgr->GetConstant(constants[1]->type(), + bool is_vector = false; + const analysis::Type* result_type = constants[1]->type(); + const analysis::Type* base_type = result_type; + if (base_type->AsVector()) { + is_vector = true; + base_type = base_type->AsVector()->element_type(); + } + assert(base_type->AsFloat() != nullptr && + "FMix is suppose to act on floats or vectors of floats."); + + if (base_type->AsFloat()->width() == 32) { + one = const_mgr->GetConstant(base_type, utils::FloatProxy(1.0f).GetWords()); } else { - one = const_mgr->GetConstant(constants[1]->type(), + one = const_mgr->GetConstant(base_type, utils::FloatProxy(1.0).GetWords()); } - const analysis::Constant* temp1 = - FOLD_FPARITH_OP(-)(constants[1]->type(), one, constants[3], const_mgr); + if (is_vector) { + uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id(); + one = + const_mgr->GetConstant(result_type, std::vector(4, one_id)); + } + + const analysis::Constant* temp1 = FoldFPBinaryOp( + FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context); if (temp1 == nullptr) { return nullptr; } - const analysis::Constant* temp2 = FOLD_FPARITH_OP(*)( - constants[1]->type(), constants[1], temp1, const_mgr); + const analysis::Constant* temp2 = FoldFPBinaryOp( + FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context); if (temp2 == nullptr) { return nullptr; } - const analysis::Constant* temp3 = FOLD_FPARITH_OP(*)( - constants[2]->type(), constants[2], constants[3], const_mgr); + const analysis::Constant* temp3 = + FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(), + {constants[2], constants[3]}, context); if (temp3 == nullptr) { return nullptr; } - return FOLD_FPARITH_OP(+)(temp2->type(), temp2, temp3, const_mgr); + return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3}, + context); }; } diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index b5998c722..f24f08e60 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -222,6 +222,7 @@ OpName %main "main" %v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2 %v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4 %v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5 +%v2float_0p2_0p5 = OpConstantComposite %v2float %float_0p2 %float_0p5 %v2float_null = OpConstantNull %v2float %double_n1 = OpConstant %double -1 %105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps. @@ -643,6 +644,58 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntVectorInstructionFoldingTest, )); // clang-format on +using FloatVectorInstructionFoldingTest = + ::testing::TestWithParam>>; + +TEST_P(FloatVectorInstructionFoldingTest, Case) { + const auto& tc = GetParam(); + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Fold the instruction to test. + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + SpvOp original_opcode = inst->opcode(); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); + + // Make sure the instruction folded as expected. + EXPECT_EQ(succeeded, inst == nullptr || inst->opcode() != original_opcode); + if (succeeded && inst != nullptr) { + EXPECT_EQ(inst->opcode(), SpvOpCopyObject); + inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + std::vector opcodes = {SpvOpConstantComposite}; + EXPECT_THAT(opcodes, Contains(inst->opcode())); + analysis::ConstantManager* const_mrg = context->get_constant_mgr(); + const analysis::Constant* result = const_mrg->GetConstantFromInst(inst); + EXPECT_NE(result, nullptr); + if (result != nullptr) { + const std::vector& componenets = + result->AsVectorConstant()->GetComponents(); + EXPECT_EQ(componenets.size(), tc.expected_result.size()); + for (size_t i = 0; i < componenets.size(); i++) { + EXPECT_EQ(tc.expected_result[i], componenets[i]->GetFloat()); + } + } + } +} + +// clang-format off +INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest, +::testing::Values( + // Test case 0: FMix {2.0, 2.0}, {2.0, 3.0} {0.2,0.5} + InstructionFoldingCase>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpExtInst %v2float %1 FMix %v2float_2_3 %v2float_0_0 %v2float_0p2_0p5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, {1.6f,1.5f}) +)); +// clang-format on using BooleanInstructionFoldingTest = ::testing::TestWithParam>;