Fix OpDot folding of half float vectors. (#2411)

* Fix OpDot folding of half float vectors.

The code that folds OpDot does not handle half floats correctly.  After
trying to multiple the first components, we get a nullptr because we
don't fold half float values.  This nullptr gets passed to the code that
does the addition, and causes an assert.

Fixes #2405.
This commit is contained in:
Steven Perron 2019-02-20 20:05:08 -05:00 committed by GitHub
parent 8eddde2e70
commit fde69dcd80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 5 deletions

View File

@ -582,13 +582,17 @@ ConstantFoldingRule FoldOpDotWithConstants() {
std::vector<uint32_t> words = result.GetWords();
const analysis::Constant* result_const =
const_mgr->GetConstant(float_type, words);
for (uint32_t i = 0; i < a_components.size(); ++i) {
for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
++i) {
if (a_components[i] == nullptr || b_components[i] == nullptr) {
return nullptr;
}
const analysis::Constant* component = FOLD_FPARITH_OP(*)(
new_type, a_components[i], b_components[i], const_mgr);
if (component == nullptr) {
return nullptr;
}
result_const =
FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
}

View File

@ -142,6 +142,7 @@ OpName %main "main"
%v4double = OpTypeVector %double 4
%v2float = OpTypeVector %float 2
%v2double = OpTypeVector %double 2
%v2half = OpTypeVector %half 2
%v2bool = OpTypeVector %bool 2
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
%_ptr_int = OpTypePointer Function %int
@ -231,6 +232,7 @@ OpName %main "main"
%v2double_null = OpConstantNull %v2double
%108 = OpConstant %half 0
%half_1 = OpConstant %half 1
%half_0_1 = OpConstantComposite %v2half %108 %half_1
%106 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
%v4float_0_0_0_0 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
%v4float_0_0_0_1 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_1
@ -3171,7 +3173,7 @@ INSTANTIATE_TEST_SUITE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTes
"OpReturn\n" +
"OpFunctionEnd",
3, 2),
// Test case 15: Fold vector fsub with null
// Test case 17: Fold vector fsub with null
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@ -3181,7 +3183,7 @@ INSTANTIATE_TEST_SUITE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTes
"OpReturn\n" +
"OpFunctionEnd",
3, 2),
// Test case 16: Fold 0.0(half) * n
// Test case 18: Fold 0.0(half) * n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@ -3191,7 +3193,7 @@ INSTANTIATE_TEST_SUITE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTes
"OpReturn\n" +
"OpFunctionEnd",
2, HALF_0_ID),
// Test case 17: Don't fold 1.0(half) * n
// Test case 19: Don't fold 1.0(half) * n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@ -3201,13 +3203,29 @@ INSTANTIATE_TEST_SUITE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTes
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 18: Don't fold 1.0 * 1.0 (half)
// Test case 20: Don't fold 1.0 * 1.0 (half)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFMul %half %half_1 %half_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 21: Don't fold (0.0, 1.0) * (0.0, 1.0) (half)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFMul %v2half %half_0_1 %half_0_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 22: Don't fold (0.0, 1.0) dotp (0.0, 1.0) (half)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpDot %half %half_0_1 %half_0_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0)
));