From bdaf8d56fbe3fa22ee699e33306ffd5f77b7762f Mon Sep 17 00:00:00 2001 From: GregF Date: Fri, 23 Feb 2018 15:46:30 -0700 Subject: [PATCH] Opt: Add constant folding for FToI and IToF --- source/opt/const_folding_rules.cpp | 178 ++++++++++++++++++++++++----- test/opt/fold_test.cpp | 40 +++++++ 2 files changed, 190 insertions(+), 28 deletions(-) diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index ad0f8f79a..a2d23c172 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -131,7 +131,14 @@ ConstantFoldingRule FoldCompositeWithConstants() { // The interface for a function that returns the result of applying a scalar // floating-point binary operation on |a| and |b|. The type of the return value // will be |type|. The input constants must also be of type |type|. -using FloatScalarFoldingRule = std::function; + +// The interface for a function that returns the result of applying a scalar +// floating-point binary operation on |a| and |b|. The type of the return value +// will be |type|. The input constants must also be of type |type|. +using BinaryScalarFoldingRule = std::function; @@ -158,12 +165,63 @@ std::vector GetVectorComponents( return components; } +// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops +// using |scalar_rule| and unary float point vectors ops by applying +// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule| +// that is returned assumes that |constants| contains 1 entry. If they are +// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector| +// whose element type is |Float| or |Integer|. +ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) { + return [scalar_rule](ir::Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + ir::IRContext* context = inst->context(); + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); + const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); + const analysis::Vector* vector_type = result_type->AsVector(); + + if (!inst->IsFloatingPointFoldingAllowed()) { + return nullptr; + } + + if (constants[0] == nullptr) { + return nullptr; + } + + if (vector_type != nullptr) { + std::vector a_components; + std::vector results_components; + + a_components = GetVectorComponents(constants[0], const_mgr); + + // Fold each component of the vector. + for (uint32_t i = 0; i < a_components.size(); ++i) { + results_components.push_back(scalar_rule(vector_type->element_type(), + a_components[i], const_mgr)); + if (results_components[i] == nullptr) { + return nullptr; + } + } + + // Build the constant object and return it. + std::vector ids; + for (const analysis::Constant* member : results_components) { + ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); + } + return const_mgr->GetConstant(vector_type, ids); + } else { + return scalar_rule(result_type, constants[0], const_mgr); + } + }; +} + // Returns a |ConstantFoldingRule| that folds floating point scalars using // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the // elements of the vector. The |ConstantFoldingRule| that is returned assumes // that |constants| contains 2 entries. If they are not |nullptr|, then their // type is either |Float| or a |Vector| whose element type is |Float|. -ConstantFoldingRule FoldFloatingPointOp(FloatScalarFoldingRule scalar_rule) { +ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) { return [scalar_rule](ir::Instruction* inst, const std::vector& constants) -> const analysis::Constant* { @@ -211,7 +269,70 @@ ConstantFoldingRule FoldFloatingPointOp(FloatScalarFoldingRule scalar_rule) { }; } -// This macro defines a |FloatScalarFoldingRule| that applies |op|. The +// This macro defines a |UnaryScalarFoldingRule| that performs float to +// integer conversion. +// TODO(greg-lunarg): Support for 64-bit integer types. +UnaryScalarFoldingRule FoldFToIOp() { + return [](const analysis::Type* result_type, const analysis::Constant* a, + analysis::ConstantManager* const_mgr) -> const analysis::Constant* { + assert(result_type != nullptr && a != nullptr); + const analysis::Integer* integer_type = result_type->AsInteger(); + const analysis::Float* float_type = a->type()->AsFloat(); + assert(float_type != nullptr); + assert(integer_type != nullptr); + if (integer_type->width() != 32) return nullptr; + if (float_type->width() == 32) { + float fa = a->GetFloat(); + uint32_t result = integer_type->IsSigned() + ? static_cast(static_cast(fa)) + : static_cast(fa); + std::vector words = {result}; + return const_mgr->GetConstant(result_type, words); + } else if (float_type->width() == 64) { + double fa = a->GetDouble(); + uint32_t result = integer_type->IsSigned() + ? static_cast(static_cast(fa)) + : static_cast(fa); + std::vector words = {result}; + return const_mgr->GetConstant(result_type, words); + } + return nullptr; + }; +} + +// This macro defines a |UnaryScalarFoldingRule| that performs integer to +// float conversion. +// TODO(greg-lunarg): Support for 64-bit integer types. +UnaryScalarFoldingRule FoldIToFOp() { + return [](const analysis::Type* result_type, const analysis::Constant* a, + analysis::ConstantManager* const_mgr) -> const analysis::Constant* { + assert(result_type != nullptr && a != nullptr); + const analysis::Integer* integer_type = a->type()->AsInteger(); + const analysis::Float* float_type = result_type->AsFloat(); + assert(float_type != nullptr); + assert(integer_type != nullptr); + if (integer_type->width() != 32) return nullptr; + uint32_t ua = a->GetU32(); + if (float_type->width() == 32) { + float result_val = integer_type->IsSigned() + ? static_cast(static_cast(ua)) + : static_cast(ua); + spvutils::FloatProxy result(result_val); + std::vector words = {result.data()}; + return const_mgr->GetConstant(result_type, words); + } else if (float_type->width() == 64) { + double result_val = integer_type->IsSigned() + ? static_cast(static_cast(ua)) + : static_cast(ua); + spvutils::FloatProxy result(result_val); + std::vector words = result.GetWords(); + return const_mgr->GetConstant(result_type, words); + } + return nullptr; + }; +} + +// This macro defines a |BinaryScalarFoldingRule| that applies |op|. The // operator |op| must work for both float and double, and use syntax "f1 op f2". #define FOLD_FPARITH_OP(op) \ [](const analysis::Type* result_type, const analysis::Constant* a, \ @@ -237,20 +358,16 @@ ConstantFoldingRule FoldFloatingPointOp(FloatScalarFoldingRule scalar_rule) { return nullptr; \ } +// Define the folding rule for conversion between floating point and integer +ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); } +ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); } + // Define the folding rules for subtraction, addition, multiplication, and // division for floating point values. -ConstantFoldingRule FoldFSub() { - return FoldFloatingPointOp(FOLD_FPARITH_OP(-)); -} -ConstantFoldingRule FoldFAdd() { - return FoldFloatingPointOp(FOLD_FPARITH_OP(+)); -} -ConstantFoldingRule FoldFMul() { - return FoldFloatingPointOp(FOLD_FPARITH_OP(*)); -} -ConstantFoldingRule FoldFDiv() { - return FoldFloatingPointOp(FOLD_FPARITH_OP(/)); -} +ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); } +ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); } +ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); } +ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FOLD_FPARITH_OP(/)); } bool CompareFloatingPoint(bool op_result, bool op_unordered, bool need_ordered) { @@ -263,7 +380,7 @@ bool CompareFloatingPoint(bool op_result, bool op_unordered, } } -// This macro defines a |FloatScalarFoldingRule| that applies |op|. The +// This macro defines a |BinaryScalarFoldingRule| that applies |op|. The // operator |op| must work for both float and double, and use syntax "f1 op f2". #define FOLD_FPCMP_OP(op, ord) \ [](const analysis::Type* result_type, const analysis::Constant* a, \ @@ -295,40 +412,40 @@ bool CompareFloatingPoint(bool op_result, bool op_unordered, // Define the folding rules for ordered and unordered comparison for floating // point values. ConstantFoldingRule FoldFOrdEqual() { - return FoldFloatingPointOp(FOLD_FPCMP_OP(==, true)); + return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true)); } ConstantFoldingRule FoldFUnordEqual() { - return FoldFloatingPointOp(FOLD_FPCMP_OP(==, false)); + return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false)); } ConstantFoldingRule FoldFOrdNotEqual() { - return FoldFloatingPointOp(FOLD_FPCMP_OP(!=, true)); + return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true)); } ConstantFoldingRule FoldFUnordNotEqual() { - return FoldFloatingPointOp(FOLD_FPCMP_OP(!=, false)); + return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false)); } ConstantFoldingRule FoldFOrdLessThan() { - return FoldFloatingPointOp(FOLD_FPCMP_OP(<, true)); + return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true)); } ConstantFoldingRule FoldFUnordLessThan() { - return FoldFloatingPointOp(FOLD_FPCMP_OP(<, false)); + return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false)); } ConstantFoldingRule FoldFOrdGreaterThan() { - return FoldFloatingPointOp(FOLD_FPCMP_OP(>, true)); + return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true)); } ConstantFoldingRule FoldFUnordGreaterThan() { - return FoldFloatingPointOp(FOLD_FPCMP_OP(>, false)); + return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false)); } ConstantFoldingRule FoldFOrdLessThanEqual() { - return FoldFloatingPointOp(FOLD_FPCMP_OP(<=, true)); + return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true)); } ConstantFoldingRule FoldFUnordLessThanEqual() { - return FoldFloatingPointOp(FOLD_FPCMP_OP(<=, false)); + return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false)); } ConstantFoldingRule FoldFOrdGreaterThanEqual() { - return FoldFloatingPointOp(FOLD_FPCMP_OP(>=, true)); + return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true)); } ConstantFoldingRule FoldFUnordGreaterThanEqual() { - return FoldFloatingPointOp(FOLD_FPCMP_OP(>=, false)); + return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false)); } } // namespace @@ -342,6 +459,11 @@ spvtools::opt::ConstantFoldingRules::ConstantFoldingRules() { rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants()); + rules_[SpvOpConvertFToS].push_back(FoldFToI()); + rules_[SpvOpConvertFToU].push_back(FoldFToI()); + rules_[SpvOpConvertSToF].push_back(FoldIToF()); + rules_[SpvOpConvertUToF].push_back(FoldIToF()); + rules_[SpvOpFAdd].push_back(FoldFAdd()); rules_[SpvOpFDiv].push_back(FoldFDiv()); rules_[SpvOpFMul].push_back(FoldFMul()); diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 1816213a5..4e418b925 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -2756,6 +2756,46 @@ INSTANTIATE_TEST_CASE_P(DoubleVectorRedundantFoldingTest, GeneralInstructionFold "OpFunctionEnd", 2, 3) )); + +INSTANTIATE_TEST_CASE_P(FToIConstantFoldingTest, IntegerInstructionFoldingTest, + ::testing::Values( + // Test case 0: Fold int(3.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpConvertFToS %int %float_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 1: Fold uint(3.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpConvertFToU %int %float_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3) +)); + +INSTANTIATE_TEST_CASE_P(IToFConstantFoldingTest, FloatInstructionFoldingTest, + ::testing::Values( + // Test case 0: Fold float(3) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpConvertSToF %float %int_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3.0), + // Test case 1: Fold float(3u) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpConvertUToF %float %uint_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3.0) +)); // clang-format on using ToNegateFoldingTest =