From ce5941a6425e2b0f8128d02e830a9609b3f18709 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Wed, 28 Feb 2018 15:23:19 -0500 Subject: [PATCH] Fixes #1357. Support null constants better in folding * getFloatConstantKind() now handles OpConstantNull * PerformOperation() now handles OpConstantNull for vectors * Fixed some instances where we would attempt to merge a division by 0 * added tests --- source/opt/constants.h | 23 +++++++------- source/opt/folding_rules.cpp | 60 +++++++++++++++++++++++++++++------ test/opt/fold_test.cpp | 61 ++++++++++++++++++++++++++++++++++-- 3 files changed, 120 insertions(+), 24 deletions(-) diff --git a/source/opt/constants.h b/source/opt/constants.h index cd3134b11..999dc52c4 100644 --- a/source/opt/constants.h +++ b/source/opt/constants.h @@ -126,6 +126,18 @@ class ScalarConstant : public Constant { // Returns a const reference of the value of this constant in 32-bit words. virtual const std::vector& words() const { return words_; } + // Returns true if the value is zero. + bool IsZero() const { + bool is_zero = true; + for (uint32_t v : words()) { + if (v != 0) { + is_zero = false; + break; + } + } + return is_zero; + } + protected: ScalarConstant(const Type* ty, const std::vector& w) : Constant(ty), words_(w) {} @@ -175,17 +187,6 @@ class IntConstant : public ScalarConstant { static_cast(words()[0]); } - bool IsZero() const { - bool is_zero = true; - for (uint32_t v : words()) { - if (v != 0) { - is_zero = false; - break; - } - } - return is_zero; - } - // Make a copy of this IntConstant instance. std::unique_ptr CopyIntConstant() const { return MakeUnique(type_->AsInteger(), words_); diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index f94ba7b01..7e4dddba3 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -218,9 +218,12 @@ FoldingRule ReciprocalFDiv() { const analysis::Constant* negated_const = const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids)); id = const_mgr->GetDefiningInstruction(negated_const)->result_id(); - } else { + } else if (constants[1]->AsFloatConstant()) { id = Reciprocal(const_mgr, constants[1]); if (id == 0) return false; + } else { + // Don't fold a null constant. + return false; } inst->SetOpcode(SpvOpFMul); inst->SetInOperands( @@ -384,6 +387,22 @@ FoldingRule MergeNegateAddSubArithmetic() { }; } +// Returns true if |c| has a zero element. +bool HasZero(const analysis::Constant* c) { + if (c->AsNullConstant()) { + return true; + } + if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) { + for (auto& comp : vec_const->GetComponents()) + if (HasZero(comp)) return true; + } else { + assert(c->AsScalarConstant()); + return c->AsScalarConstant()->IsZero(); + } + + return false; +} + // Performs |input1| |opcode| |input2| and returns the merged constant result // id. Returns 0 if the result is not a valid value. The input types must be // Float. @@ -415,6 +434,7 @@ uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr, FOLD_OP(*); break; case SpvOpFDiv: + if (HasZero(input2)) return 0; FOLD_OP(/); break; case SpvOpFAdd: @@ -498,10 +518,25 @@ uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode, const analysis::Type* ele_type = vector_type->element_type(); for (uint32_t i = 0; i != vector_type->element_count(); ++i) { uint32_t id = 0; - const analysis::Constant* input1_comp = - input1->AsVectorConstant()->GetComponents()[i]; - const analysis::Constant* input2_comp = - input2->AsVectorConstant()->GetComponents()[i]; + + const analysis::Constant* input1_comp = nullptr; + if (const analysis::VectorConstant* input1_vector = + input1->AsVectorConstant()) { + input1_comp = input1_vector->GetComponents()[i]; + } else { + assert(input1->AsNullConstant()); + input1_comp = const_mgr->GetConstant(ele_type, {}); + } + + const analysis::Constant* input2_comp = nullptr; + if (const analysis::VectorConstant* input2_vector = + input2->AsVectorConstant()) { + input2_comp = input2_vector->GetComponents()[i]; + } else { + assert(input2->AsNullConstant()); + input2_comp = const_mgr->GetConstant(ele_type, {}); + } + if (ele_type->AsFloat()) { id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp, input2_comp); @@ -603,7 +638,7 @@ FoldingRule MergeMulDivArithmetic() { std::vector other_constants = const_mgr->GetOperandConstants(other_inst); const analysis::Constant* const_input2 = ConstInput(other_constants); - if (!const_input2) return false; + if (!const_input2 || HasZero(const_input2)) return false; bool other_first_is_variable = other_constants[0] == nullptr; // If the variable value is the second operand of the divide, multiply @@ -695,7 +730,7 @@ FoldingRule MergeDivDivArithmetic() { if (width != 32 && width != 64) return false; const analysis::Constant* const_input1 = ConstInput(constants); - if (!const_input1) return false; + if (!const_input1 || HasZero(const_input1)) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (!other_inst->IsFloatingPointFoldingAllowed()) return false; @@ -704,7 +739,7 @@ FoldingRule MergeDivDivArithmetic() { std::vector other_constants = const_mgr->GetOperandConstants(other_inst); const analysis::Constant* const_input2 = ConstInput(other_constants); - if (!const_input2) return false; + if (!const_input2 || HasZero(const_input2)) return false; bool other_first_is_variable = other_constants[0] == nullptr; @@ -765,7 +800,7 @@ FoldingRule MergeDivMulArithmetic() { if (width != 32 && width != 64) return false; const analysis::Constant* const_input1 = ConstInput(constants); - if (!const_input1) return false; + if (!const_input1 || HasZero(const_input1)) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (!other_inst->IsFloatingPointFoldingAllowed()) return false; @@ -1543,7 +1578,12 @@ FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) { return FloatConstantKind::Unknown; } - if (const analysis::VectorConstant* vc = constant->AsVectorConstant()) { + assert(HasFloatingPoint(constant->type()) && "Unexpected constant type"); + + if (constant->AsNullConstant()) { + return FloatConstantKind::Zero; + } else if (const analysis::VectorConstant* vc = + constant->AsVectorConstant()) { const std::vector& components = vc->GetComponents(); assert(!components.empty()); diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 4e418b925..345cfeacb 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -198,6 +198,7 @@ OpName %main "main" %v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2 %v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4 %v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5 +%v2float_null = OpConstantNull %v2float %double_n1 = OpConstant %double -1 %105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps. %double_0 = OpConstant %double 0 @@ -2526,7 +2527,37 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest "%2 = OpExtInst %float %1 FMix %3 %4 %float_1\n" + "OpReturn\n" + "OpFunctionEnd", - 2, 4) + 2, 4), + // Test case 15: Fold vector fadd with null + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%a = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %a\n" + + "%3 = OpFAdd %v2float %2 %v2float_null\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 2), + // Test case 16: Fold vector fadd with null + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%a = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %a\n" + + "%3 = OpFAdd %v2float %v2float_null %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 2), + // Test case 15: Fold vector fsub with null + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%a = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %a\n" + + "%3 = OpFSub %v2float %2 %v2float_null\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 2) )); INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTest, @@ -3317,7 +3348,18 @@ INSTANTIATE_TEST_CASE_P(ReciprocalFDivTest, MatchingInstructionFoldingTest, "%3 = OpFDiv %double %2 %double_2\n" + "OpReturn\n" + "OpFunctionEnd\n", - 3, true) + 3, true), + // Test case 4: don't fold x / 0. + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %var\n" + + "%3 = OpFDiv %v2float %2 %v2float_null\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 3, false) )); INSTANTIATE_TEST_CASE_P(MergeMulTest, MatchingInstructionFoldingTest, @@ -3812,7 +3854,20 @@ INSTANTIATE_TEST_CASE_P(MergeDivTest, MatchingInstructionFoldingTest, "%4 = OpSDiv %int %int_2 %3\n" + "OpReturn\n" + "OpFunctionEnd\n", - 4, true) + 4, true), + // Test case 13: Don't merge + // (x / {null}) / {null} + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFDiv %float %2 %v2float_null\n" + + "%4 = OpFDiv %float %3 %v2float_null\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, false) )); INSTANTIATE_TEST_CASE_P(MergeAddTest, MatchingInstructionFoldingTest,