From d9c00e1d2de10043f1d4968c4bced1863d1893c1 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Wed, 22 May 2019 02:15:02 -0400 Subject: [PATCH] Add folding rules for OpQuantizeToF16 (#2614) Adding the folding rules for OpQuantizeToF16, and fixed some matching tests to check identify new lines. --- source/opt/const_folding_rules.cpp | 26 ++++ ...ld_spec_constant_op_and_composite_pass.cpp | 1 + .../opt/fold_spec_const_op_composite_test.cpp | 13 ++ test/opt/fold_test.cpp | 143 +++++++++++++++--- 4 files changed, 158 insertions(+), 25 deletions(-) diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index 3df5a83ec..10fcde408 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -408,6 +408,28 @@ UnaryScalarFoldingRule FoldIToFOp() { }; } +// This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|. +UnaryScalarFoldingRule FoldQuantizeToF16Scalar() { + return [](const analysis::Type* result_type, const analysis::Constant* a, + analysis::ConstantManager* const_mgr) -> const analysis::Constant* { + assert(result_type != nullptr && a != nullptr); + const analysis::Float* float_type = a->type()->AsFloat(); + assert(float_type != nullptr); + if (float_type->width() != 32) { + return nullptr; + } + + float fa = a->GetFloat(); + utils::HexFloat> orignal(fa); + utils::HexFloat> quantized(0); + utils::HexFloat> result(0.0f); + orignal.castTo(quantized, utils::round_direction::kToZero); + quantized.castTo(result, utils::round_direction::kToZero); + std::vector words = {result.getBits()}; + return const_mgr->GetConstant(result_type, words); + }; +} + // This macro defines a |BinaryScalarFoldingRule| that applies |op|. The // operator |op| must work for both float and double, and use syntax "f1 op f2". #define FOLD_FPARITH_OP(op) \ @@ -438,6 +460,9 @@ UnaryScalarFoldingRule FoldIToFOp() { // Define the folding rule for conversion between floating point and integer ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); } ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); } +ConstantFoldingRule FoldQuantizeToF16() { + return FoldFPUnaryOp(FoldQuantizeToF16Scalar()); +} // Define the folding rules for subtraction, addition, multiplication, and // division for floating point values. @@ -848,6 +873,7 @@ ConstantFoldingRules::ConstantFoldingRules() { rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar()); rules_[SpvOpFNegate].push_back(FoldFNegate()); + rules_[SpvOpQuantizeToF16].push_back(FoldQuantizeToF16()); } } // namespace opt } // namespace spvtools diff --git a/source/opt/fold_spec_constant_op_and_composite_pass.cpp b/source/opt/fold_spec_constant_op_and_composite_pass.cpp index d61daaecf..56d0137f3 100644 --- a/source/opt/fold_spec_constant_op_and_composite_pass.cpp +++ b/source/opt/fold_spec_constant_op_and_composite_pass.cpp @@ -122,6 +122,7 @@ bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp( case SpvOp::SpvOpCompositeExtract: case SpvOp::SpvOpVectorShuffle: case SpvOp::SpvOpCompositeInsert: + case SpvOp::SpvOpQuantizeToF16: folded_inst = FoldWithInstructionFolder(pos); break; default: diff --git a/test/opt/fold_spec_const_op_composite_test.cpp b/test/opt/fold_spec_const_op_composite_test.cpp index a9e7dd33d..c96dc8c5a 100644 --- a/test/opt/fold_spec_const_op_composite_test.cpp +++ b/test/opt/fold_spec_const_op_composite_test.cpp @@ -325,6 +325,19 @@ INSTANTIATE_TEST_SUITE_P( "%inner = OpConstantComposite %inner_struct %bool_true %signed_one %undef", "%outer = OpSpecConstantComposite %outer_struct %inner %signed_one", }, + }, + // Fold an QuantizetoF16 instruction + { + // original + { + "%float_1 = OpConstant %float 1", + "%quant_float = OpSpecConstantOp %float QuantizeToF16 %float_1", + }, + // expected + { + "%float_1 = OpConstant %float 1", + "%quant_float = OpConstant %float 1", + }, } // clang-format on }))); diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index d7b59171d..3ea320463 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -206,7 +206,14 @@ OpName %main "main" %float_2 = OpConstant %float 2 %float_3 = OpConstant %float 3 %float_4 = OpConstant %float 4 +%float_2049 = OpConstant %float 2049 +%float_n2049 = OpConstant %float -2049 %float_0p5 = OpConstant %float 0.5 +%float_pi = OpConstant %float 1.5555 +%float_1e16 = OpConstant %float 1e16 +%float_n1e16 = OpConstant %float -1e16 +%float_1en16 = OpConstant %float 1e-16 +%float_n1en16 = OpConstant %float -1e-16 %v2float_0_0 = OpConstantComposite %v2float %float_0 %float_0 %v2float_2_2 = OpConstantComposite %v2float %float_2 %float_2 %v2float_2_3 = OpConstantComposite %v2float %float_2 %float_3 @@ -1273,7 +1280,11 @@ TEST_P(FloatInstructionFoldingTest, Case) { const_mrg->GetConstantFromInst(inst)->AsFloatConstant(); EXPECT_NE(result, nullptr); if (result != nullptr) { - EXPECT_EQ(result->GetFloatValue(), tc.expected_result); + if (!std::isnan(tc.expected_result)) { + EXPECT_EQ(result->GetFloatValue(), tc.expected_result); + } else { + EXPECT_TRUE(std::isnan(result->GetFloatValue())); + } } } } @@ -1388,7 +1399,89 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest, "%2 = OpFNegate %float %float_2\n" + "OpReturn\n" + "OpFunctionEnd", - 2, -2) + 2, -2), + // Test case 12: QuantizeToF16 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpQuantizeToF16 %float %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 1.0), + // Test case 13: QuantizeToF16 positive non exact + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpQuantizeToF16 %float %float_2049\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 2048), + // Test case 14: QuantizeToF16 negative non exact + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpQuantizeToF16 %float %float_n2049\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, -2048), + // Test case 15: QuantizeToF16 large positive + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpQuantizeToF16 %float %float_1e16\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, std::numeric_limits::infinity()), + // Test case 16: QuantizeToF16 large negative + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpQuantizeToF16 %float %float_n1e16\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, -std::numeric_limits::infinity()), + // Test case 17: QuantizeToF16 small positive + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpQuantizeToF16 %float %float_1en16\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0.0), + // Test case 18: QuantizeToF16 small negative + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpQuantizeToF16 %float %float_n1en16\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0.0), + // Test case 19: QuantizeToF16 nan + InstructionFoldingCase( + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpQuantizeToF16 %float %float_nan\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, std::numeric_limits::quiet_NaN()), + // Test case 20: QuantizeToF16 inf + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFDiv %float %float_1 %float_0\n" + + "%3 = OpQuantizeToF16 %float %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, std::numeric_limits::infinity()), + // Test case 21: QuantizeToF16 -inf + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFDiv %float %float_n1 %float_0\n" + + "%3 = OpQuantizeToF16 %float %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, -std::numeric_limits::infinity()) )); // clang-format on @@ -4051,7 +4144,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest, InstructionFoldingCase( Header() + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + - "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" + + "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2{{[[:space:]]}}\n" + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_n2]]\n" + "%main = OpFunction %void None %void_func\n" + @@ -4068,7 +4161,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest, InstructionFoldingCase( Header() + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + - "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" + + "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2{{[[:space:]]}}\n" + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_n2]]\n" + "%main = OpFunction %void None %void_func\n" + @@ -4102,7 +4195,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest, InstructionFoldingCase( Header() + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + - "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" + + "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2{{[[:space:]]}}\n" + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + "; CHECK: %4 = OpFDiv [[float]] [[float_n2]] [[ld]]\n" + "%main = OpFunction %void None %void_func\n" + @@ -4119,7 +4212,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest, InstructionFoldingCase( Header() + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + - "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" + + "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2{{[[:space:]]}}\n" + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + "; CHECK: %4 = OpFSub [[float]] [[float_n2]] [[ld]]\n" + "%main = OpFunction %void None %void_func\n" + @@ -4136,7 +4229,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest, InstructionFoldingCase( Header() + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + - "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" + + "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2{{[[:space:]]}}\n" + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + "; CHECK: %4 = OpFSub [[float]] [[float_n2]] [[ld]]\n" + "%main = OpFunction %void None %void_func\n" + @@ -4218,7 +4311,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest, Header() + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + "; CHECK: OpConstant [[int]] -2147483648\n" + - "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" + + "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2{{[[:space:]]}}\n" + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + "; CHECK: %4 = OpISub [[int]] [[int_n2]] [[ld]]\n" + "%main = OpFunction %void None %void_func\n" + @@ -4236,7 +4329,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest, Header() + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + "; CHECK: OpConstant [[int]] -2147483648\n" + - "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" + + "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2{{[[:space:]]}}\n" + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + "; CHECK: %4 = OpISub [[int]] [[int_n2]] [[ld]]\n" + "%main = OpFunction %void None %void_func\n" + @@ -4287,7 +4380,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest, InstructionFoldingCase( Header() + "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" + - "; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" + + "; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2{{[[:space:]]}}\n" + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + "; CHECK: %4 = OpISub [[long]] [[long_n2]] [[ld]]\n" + "%main = OpFunction %void None %void_func\n" + @@ -4337,11 +4430,11 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest, InstructionFoldingCase( Header() + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + - "; CHECK: [[v4float:%\\w+]] = OpTypeVector [[float]] 4\n" + - "; CHECK: [[float_n1:%\\w+]] = OpConstant [[float]] -1\n" + - "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" + - "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" + - "; CHECK: [[float_n3:%\\w+]] = OpConstant [[float]] -3\n" + + "; CHECK: [[v4float:%\\w+]] = OpTypeVector [[float]] 4{{[[:space:]]}}\n" + + "; CHECK: [[float_n1:%\\w+]] = OpConstant [[float]] -1{{[[:space:]]}}\n" + + "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1{{[[:space:]]}}\n" + + "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2{{[[:space:]]}}\n" + + "; CHECK: [[float_n3:%\\w+]] = OpConstant [[float]] -3{{[[:space:]]}}\n" + "; CHECK: [[v4float_1_n2_n1_n3:%\\w+]] = OpConstantComposite [[v4float]] [[float_1]] [[float_n2]] [[float_n1]] [[float_n3]]\n" + "; CHECK: %2 = OpCopyObject [[v4float]] [[v4float_1_n2_n1_n3]]\n" + "%main = OpFunction %void None %void_func\n" + @@ -4710,9 +4803,9 @@ INSTANTIATE_TEST_SUITE_P(MergeMulTest, MatchingInstructionFoldingTest, InstructionFoldingCase( Header() + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + - "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" + - "; CHECK: OpConstant [[int]] -2147483648\n" + - "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" + + "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2{{[[:space:]]}}\n" + + "; CHECK: OpConstant [[int]] -2147483648{{[[:space:]]}}\n" + + "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2{{[[:space:]]}}\n" + "; CHECK: [[v2int_n2_n2:%\\w+]] = OpConstantComposite [[v2int]] [[int_n2]] [[int_n2]]\n" + "; CHECK: [[ld:%\\w+]] = OpLoad [[v2int]]\n" + "; CHECK: %4 = OpIMul [[v2int]] [[ld]] [[v2int_n2_n2]]\n" + @@ -4730,9 +4823,9 @@ INSTANTIATE_TEST_SUITE_P(MergeMulTest, MatchingInstructionFoldingTest, InstructionFoldingCase( Header() + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + - "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" + - "; CHECK: OpConstant [[int]] -2147483648\n" + - "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" + + "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2{{[[:space:]]}}\n" + + "; CHECK: OpConstant [[int]] -2147483648{{[[:space:]]}}\n" + + "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2{{[[:space:]]}}\n" + "; CHECK: [[v2int_n2_n2:%\\w+]] = OpConstantComposite [[v2int]] [[int_n2]] [[int_n2]]\n" + "; CHECK: [[ld:%\\w+]] = OpLoad [[v2int]]\n" + "; CHECK: %4 = OpIMul [[v2int]] [[ld]] [[v2int_n2_n2]]\n" + @@ -5026,7 +5119,7 @@ INSTANTIATE_TEST_SUITE_P(MergeDivTest, MatchingInstructionFoldingTest, Header() + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + "; CHECK: OpConstant [[int]] -2147483648\n" + - "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" + + "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2{{[[:space:]]}}\n" + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + "; CHECK: %4 = OpSDiv [[int]] [[ld]] [[int_n2]]\n" + "%main = OpFunction %void None %void_func\n" + @@ -5044,7 +5137,7 @@ INSTANTIATE_TEST_SUITE_P(MergeDivTest, MatchingInstructionFoldingTest, Header() + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + "; CHECK: OpConstant [[int]] -2147483648\n" + - "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" + + "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2{{[[:space:]]}}\n" + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + "; CHECK: %4 = OpSDiv [[int]] [[int_n2]] [[ld]]\n" + "%main = OpFunction %void None %void_func\n" + @@ -5324,7 +5417,7 @@ INSTANTIATE_TEST_SUITE_P(MergeSubTest, MatchingInstructionFoldingTest, InstructionFoldingCase( Header() + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + - "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" + + "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2{{[[:space:]]}}\n" + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + "; CHECK: %4 = OpFSub [[float]] [[float_n2]] [[ld]]\n" + "%main = OpFunction %void None %void_func\n" + @@ -5358,7 +5451,7 @@ INSTANTIATE_TEST_SUITE_P(MergeSubTest, MatchingInstructionFoldingTest, InstructionFoldingCase( Header() + "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" + - "; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" + + "; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2{{[[:space:]]}}\n" + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + "; CHECK: %4 = OpISub [[long]] [[long_n2]] [[ld]]\n" + "%main = OpFunction %void None %void_func\n" +