Add folding rules for OpQuantizeToF16 (#2614)

Adding the folding rules for OpQuantizeToF16, and fixed some matching
tests to check identify new lines.
This commit is contained in:
Steven Perron 2019-05-22 02:15:02 -04:00 committed by David Neto
parent 713da30b63
commit d9c00e1d2d
4 changed files with 158 additions and 25 deletions

View File

@ -408,6 +408,28 @@ UnaryScalarFoldingRule FoldIToFOp() {
};
}
// This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
return [](const analysis::Type* result_type, const analysis::Constant* a,
analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
assert(result_type != nullptr && a != nullptr);
const analysis::Float* float_type = a->type()->AsFloat();
assert(float_type != nullptr);
if (float_type->width() != 32) {
return nullptr;
}
float fa = a->GetFloat();
utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
orignal.castTo(quantized, utils::round_direction::kToZero);
quantized.castTo(result, utils::round_direction::kToZero);
std::vector<uint32_t> words = {result.getBits()};
return const_mgr->GetConstant(result_type, words);
};
}
// This macro defines a |BinaryScalarFoldingRule| that applies |op|. The
// operator |op| must work for both float and double, and use syntax "f1 op f2".
#define FOLD_FPARITH_OP(op) \
@ -438,6 +460,9 @@ UnaryScalarFoldingRule FoldIToFOp() {
// Define the folding rule for conversion between floating point and integer
ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
ConstantFoldingRule FoldQuantizeToF16() {
return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
}
// Define the folding rules for subtraction, addition, multiplication, and
// division for floating point values.
@ -848,6 +873,7 @@ ConstantFoldingRules::ConstantFoldingRules() {
rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar());
rules_[SpvOpFNegate].push_back(FoldFNegate());
rules_[SpvOpQuantizeToF16].push_back(FoldQuantizeToF16());
}
} // namespace opt
} // namespace spvtools

View File

@ -122,6 +122,7 @@ bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp(
case SpvOp::SpvOpCompositeExtract:
case SpvOp::SpvOpVectorShuffle:
case SpvOp::SpvOpCompositeInsert:
case SpvOp::SpvOpQuantizeToF16:
folded_inst = FoldWithInstructionFolder(pos);
break;
default:

View File

@ -325,6 +325,19 @@ INSTANTIATE_TEST_SUITE_P(
"%inner = OpConstantComposite %inner_struct %bool_true %signed_one %undef",
"%outer = OpSpecConstantComposite %outer_struct %inner %signed_one",
},
},
// Fold an QuantizetoF16 instruction
{
// original
{
"%float_1 = OpConstant %float 1",
"%quant_float = OpSpecConstantOp %float QuantizeToF16 %float_1",
},
// expected
{
"%float_1 = OpConstant %float 1",
"%quant_float = OpConstant %float 1",
},
}
// clang-format on
})));

View File

@ -206,7 +206,14 @@ OpName %main "main"
%float_2 = OpConstant %float 2
%float_3 = OpConstant %float 3
%float_4 = OpConstant %float 4
%float_2049 = OpConstant %float 2049
%float_n2049 = OpConstant %float -2049
%float_0p5 = OpConstant %float 0.5
%float_pi = OpConstant %float 1.5555
%float_1e16 = OpConstant %float 1e16
%float_n1e16 = OpConstant %float -1e16
%float_1en16 = OpConstant %float 1e-16
%float_n1en16 = OpConstant %float -1e-16
%v2float_0_0 = OpConstantComposite %v2float %float_0 %float_0
%v2float_2_2 = OpConstantComposite %v2float %float_2 %float_2
%v2float_2_3 = OpConstantComposite %v2float %float_2 %float_3
@ -1273,7 +1280,11 @@ TEST_P(FloatInstructionFoldingTest, Case) {
const_mrg->GetConstantFromInst(inst)->AsFloatConstant();
EXPECT_NE(result, nullptr);
if (result != nullptr) {
EXPECT_EQ(result->GetFloatValue(), tc.expected_result);
if (!std::isnan(tc.expected_result)) {
EXPECT_EQ(result->GetFloatValue(), tc.expected_result);
} else {
EXPECT_TRUE(std::isnan(result->GetFloatValue()));
}
}
}
}
@ -1388,7 +1399,89 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
"%2 = OpFNegate %float %float_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, -2)
2, -2),
// Test case 12: QuantizeToF16 1.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpQuantizeToF16 %float %float_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1.0),
// Test case 13: QuantizeToF16 positive non exact
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpQuantizeToF16 %float %float_2049\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 2048),
// Test case 14: QuantizeToF16 negative non exact
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpQuantizeToF16 %float %float_n2049\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, -2048),
// Test case 15: QuantizeToF16 large positive
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpQuantizeToF16 %float %float_1e16\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, std::numeric_limits<float>::infinity()),
// Test case 16: QuantizeToF16 large negative
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpQuantizeToF16 %float %float_n1e16\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, -std::numeric_limits<float>::infinity()),
// Test case 17: QuantizeToF16 small positive
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpQuantizeToF16 %float %float_1en16\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0.0),
// Test case 18: QuantizeToF16 small negative
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpQuantizeToF16 %float %float_n1en16\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0.0),
// Test case 19: QuantizeToF16 nan
InstructionFoldingCase<float>(
HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpQuantizeToF16 %float %float_nan\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, std::numeric_limits<float>::quiet_NaN()),
// Test case 20: QuantizeToF16 inf
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFDiv %float %float_1 %float_0\n" +
"%3 = OpQuantizeToF16 %float %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, std::numeric_limits<float>::infinity()),
// Test case 21: QuantizeToF16 -inf
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFDiv %float %float_n1 %float_0\n" +
"%3 = OpQuantizeToF16 %float %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, -std::numeric_limits<float>::infinity())
));
// clang-format on
@ -4051,7 +4144,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest,
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
"; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
"; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2{{[[:space:]]}}\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
"; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_n2]]\n" +
"%main = OpFunction %void None %void_func\n" +
@ -4068,7 +4161,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest,
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
"; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
"; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2{{[[:space:]]}}\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
"; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_n2]]\n" +
"%main = OpFunction %void None %void_func\n" +
@ -4102,7 +4195,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest,
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
"; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
"; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2{{[[:space:]]}}\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
"; CHECK: %4 = OpFDiv [[float]] [[float_n2]] [[ld]]\n" +
"%main = OpFunction %void None %void_func\n" +
@ -4119,7 +4212,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest,
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
"; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
"; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2{{[[:space:]]}}\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
"; CHECK: %4 = OpFSub [[float]] [[float_n2]] [[ld]]\n" +
"%main = OpFunction %void None %void_func\n" +
@ -4136,7 +4229,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest,
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
"; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
"; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2{{[[:space:]]}}\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
"; CHECK: %4 = OpFSub [[float]] [[float_n2]] [[ld]]\n" +
"%main = OpFunction %void None %void_func\n" +
@ -4218,7 +4311,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest,
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: OpConstant [[int]] -2147483648\n" +
"; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
"; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2{{[[:space:]]}}\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpISub [[int]] [[int_n2]] [[ld]]\n" +
"%main = OpFunction %void None %void_func\n" +
@ -4236,7 +4329,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest,
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: OpConstant [[int]] -2147483648\n" +
"; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
"; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2{{[[:space:]]}}\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpISub [[int]] [[int_n2]] [[ld]]\n" +
"%main = OpFunction %void None %void_func\n" +
@ -4287,7 +4380,7 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest,
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
"; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" +
"; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2{{[[:space:]]}}\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
"; CHECK: %4 = OpISub [[long]] [[long_n2]] [[ld]]\n" +
"%main = OpFunction %void None %void_func\n" +
@ -4337,11 +4430,11 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest,
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
"; CHECK: [[v4float:%\\w+]] = OpTypeVector [[float]] 4\n" +
"; CHECK: [[float_n1:%\\w+]] = OpConstant [[float]] -1\n" +
"; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" +
"; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
"; CHECK: [[float_n3:%\\w+]] = OpConstant [[float]] -3\n" +
"; CHECK: [[v4float:%\\w+]] = OpTypeVector [[float]] 4{{[[:space:]]}}\n" +
"; CHECK: [[float_n1:%\\w+]] = OpConstant [[float]] -1{{[[:space:]]}}\n" +
"; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1{{[[:space:]]}}\n" +
"; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2{{[[:space:]]}}\n" +
"; CHECK: [[float_n3:%\\w+]] = OpConstant [[float]] -3{{[[:space:]]}}\n" +
"; CHECK: [[v4float_1_n2_n1_n3:%\\w+]] = OpConstantComposite [[v4float]] [[float_1]] [[float_n2]] [[float_n1]] [[float_n3]]\n" +
"; CHECK: %2 = OpCopyObject [[v4float]] [[v4float_1_n2_n1_n3]]\n" +
"%main = OpFunction %void None %void_func\n" +
@ -4710,9 +4803,9 @@ INSTANTIATE_TEST_SUITE_P(MergeMulTest, MatchingInstructionFoldingTest,
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
"; CHECK: OpConstant [[int]] -2147483648\n" +
"; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
"; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2{{[[:space:]]}}\n" +
"; CHECK: OpConstant [[int]] -2147483648{{[[:space:]]}}\n" +
"; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2{{[[:space:]]}}\n" +
"; CHECK: [[v2int_n2_n2:%\\w+]] = OpConstantComposite [[v2int]] [[int_n2]] [[int_n2]]\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[v2int]]\n" +
"; CHECK: %4 = OpIMul [[v2int]] [[ld]] [[v2int_n2_n2]]\n" +
@ -4730,9 +4823,9 @@ INSTANTIATE_TEST_SUITE_P(MergeMulTest, MatchingInstructionFoldingTest,
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
"; CHECK: OpConstant [[int]] -2147483648\n" +
"; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
"; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2{{[[:space:]]}}\n" +
"; CHECK: OpConstant [[int]] -2147483648{{[[:space:]]}}\n" +
"; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2{{[[:space:]]}}\n" +
"; CHECK: [[v2int_n2_n2:%\\w+]] = OpConstantComposite [[v2int]] [[int_n2]] [[int_n2]]\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[v2int]]\n" +
"; CHECK: %4 = OpIMul [[v2int]] [[ld]] [[v2int_n2_n2]]\n" +
@ -5026,7 +5119,7 @@ INSTANTIATE_TEST_SUITE_P(MergeDivTest, MatchingInstructionFoldingTest,
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: OpConstant [[int]] -2147483648\n" +
"; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
"; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2{{[[:space:]]}}\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpSDiv [[int]] [[ld]] [[int_n2]]\n" +
"%main = OpFunction %void None %void_func\n" +
@ -5044,7 +5137,7 @@ INSTANTIATE_TEST_SUITE_P(MergeDivTest, MatchingInstructionFoldingTest,
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: OpConstant [[int]] -2147483648\n" +
"; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" +
"; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2{{[[:space:]]}}\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpSDiv [[int]] [[int_n2]] [[ld]]\n" +
"%main = OpFunction %void None %void_func\n" +
@ -5324,7 +5417,7 @@ INSTANTIATE_TEST_SUITE_P(MergeSubTest, MatchingInstructionFoldingTest,
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" +
"; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" +
"; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2{{[[:space:]]}}\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" +
"; CHECK: %4 = OpFSub [[float]] [[float_n2]] [[ld]]\n" +
"%main = OpFunction %void None %void_func\n" +
@ -5358,7 +5451,7 @@ INSTANTIATE_TEST_SUITE_P(MergeSubTest, MatchingInstructionFoldingTest,
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
"; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" +
"; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2{{[[:space:]]}}\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
"; CHECK: %4 = OpISub [[long]] [[long_n2]] [[ld]]\n" +
"%main = OpFunction %void None %void_func\n" +