From fde69dcd80cc1ca548300702adf01eeb25441f3e Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Wed, 20 Feb 2019 20:05:08 -0500 Subject: [PATCH] Fix OpDot folding of half float vectors. (#2411) * Fix OpDot folding of half float vectors. The code that folds OpDot does not handle half floats correctly. After trying to multiple the first components, we get a nullptr because we don't fold half float values. This nullptr gets passed to the code that does the addition, and causes an assert. Fixes #2405. --- source/opt/const_folding_rules.cpp | 6 +++++- test/opt/fold_test.cpp | 26 ++++++++++++++++++++++---- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index f6013a3d7..3df5a83ec 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -582,13 +582,17 @@ ConstantFoldingRule FoldOpDotWithConstants() { std::vector words = result.GetWords(); const analysis::Constant* result_const = const_mgr->GetConstant(float_type, words); - for (uint32_t i = 0; i < a_components.size(); ++i) { + for (uint32_t i = 0; i < a_components.size() && result_const != nullptr; + ++i) { if (a_components[i] == nullptr || b_components[i] == nullptr) { return nullptr; } const analysis::Constant* component = FOLD_FPARITH_OP(*)( new_type, a_components[i], b_components[i], const_mgr); + if (component == nullptr) { + return nullptr; + } result_const = FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr); } diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index d874953af..b3c344115 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -142,6 +142,7 @@ OpName %main "main" %v4double = OpTypeVector %double 4 %v2float = OpTypeVector %float 2 %v2double = OpTypeVector %double 2 +%v2half = OpTypeVector %half 2 %v2bool = OpTypeVector %bool 2 %struct_v2int_int_int = OpTypeStruct %v2int %int %int %_ptr_int = OpTypePointer Function %int @@ -231,6 +232,7 @@ OpName %main "main" %v2double_null = OpConstantNull %v2double %108 = OpConstant %half 0 %half_1 = OpConstant %half 1 +%half_0_1 = OpConstantComposite %v2half %108 %half_1 %106 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 %v4float_0_0_0_0 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 %v4float_0_0_0_1 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_1 @@ -3171,7 +3173,7 @@ INSTANTIATE_TEST_SUITE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTes "OpReturn\n" + "OpFunctionEnd", 3, 2), - // Test case 15: Fold vector fsub with null + // Test case 17: Fold vector fsub with null InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + @@ -3181,7 +3183,7 @@ INSTANTIATE_TEST_SUITE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTes "OpReturn\n" + "OpFunctionEnd", 3, 2), - // Test case 16: Fold 0.0(half) * n + // Test case 18: Fold 0.0(half) * n InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + @@ -3191,7 +3193,7 @@ INSTANTIATE_TEST_SUITE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTes "OpReturn\n" + "OpFunctionEnd", 2, HALF_0_ID), - // Test case 17: Don't fold 1.0(half) * n + // Test case 19: Don't fold 1.0(half) * n InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + @@ -3201,13 +3203,29 @@ INSTANTIATE_TEST_SUITE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTes "OpReturn\n" + "OpFunctionEnd", 2, 0), - // Test case 18: Don't fold 1.0 * 1.0 (half) + // Test case 20: Don't fold 1.0 * 1.0 (half) InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%2 = OpFMul %half %half_1 %half_1\n" + "OpReturn\n" + "OpFunctionEnd", + 2, 0), + // Test case 21: Don't fold (0.0, 1.0) * (0.0, 1.0) (half) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFMul %v2half %half_0_1 %half_0_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 22: Don't fold (0.0, 1.0) dotp (0.0, 1.0) (half) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpDot %half %half_0_1 %half_0_1\n" + + "OpReturn\n" + + "OpFunctionEnd", 2, 0) ));