Fixes #1357. Support null constants better in folding

* getFloatConstantKind() now handles OpConstantNull
* PerformOperation() now handles OpConstantNull for vectors
* Fixed some instances where we would attempt to merge a division by 0
* added tests
This commit is contained in:
Alan Baker 2018-02-28 15:23:19 -05:00 committed by Steven Perron
parent bdaf8d56fb
commit ce5941a642
3 changed files with 120 additions and 24 deletions

View File

@ -126,6 +126,18 @@ class ScalarConstant : public Constant {
// Returns a const reference of the value of this constant in 32-bit words. // Returns a const reference of the value of this constant in 32-bit words.
virtual const std::vector<uint32_t>& words() const { return words_; } virtual const std::vector<uint32_t>& words() const { return words_; }
// Returns true if the value is zero.
bool IsZero() const {
bool is_zero = true;
for (uint32_t v : words()) {
if (v != 0) {
is_zero = false;
break;
}
}
return is_zero;
}
protected: protected:
ScalarConstant(const Type* ty, const std::vector<uint32_t>& w) ScalarConstant(const Type* ty, const std::vector<uint32_t>& w)
: Constant(ty), words_(w) {} : Constant(ty), words_(w) {}
@ -175,17 +187,6 @@ class IntConstant : public ScalarConstant {
static_cast<uint64_t>(words()[0]); static_cast<uint64_t>(words()[0]);
} }
bool IsZero() const {
bool is_zero = true;
for (uint32_t v : words()) {
if (v != 0) {
is_zero = false;
break;
}
}
return is_zero;
}
// Make a copy of this IntConstant instance. // Make a copy of this IntConstant instance.
std::unique_ptr<IntConstant> CopyIntConstant() const { std::unique_ptr<IntConstant> CopyIntConstant() const {
return MakeUnique<IntConstant>(type_->AsInteger(), words_); return MakeUnique<IntConstant>(type_->AsInteger(), words_);

View File

@ -218,9 +218,12 @@ FoldingRule ReciprocalFDiv() {
const analysis::Constant* negated_const = const analysis::Constant* negated_const =
const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids)); const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
id = const_mgr->GetDefiningInstruction(negated_const)->result_id(); id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
} else { } else if (constants[1]->AsFloatConstant()) {
id = Reciprocal(const_mgr, constants[1]); id = Reciprocal(const_mgr, constants[1]);
if (id == 0) return false; if (id == 0) return false;
} else {
// Don't fold a null constant.
return false;
} }
inst->SetOpcode(SpvOpFMul); inst->SetOpcode(SpvOpFMul);
inst->SetInOperands( inst->SetInOperands(
@ -384,6 +387,22 @@ FoldingRule MergeNegateAddSubArithmetic() {
}; };
} }
// Returns true if |c| has a zero element.
bool HasZero(const analysis::Constant* c) {
if (c->AsNullConstant()) {
return true;
}
if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
for (auto& comp : vec_const->GetComponents())
if (HasZero(comp)) return true;
} else {
assert(c->AsScalarConstant());
return c->AsScalarConstant()->IsZero();
}
return false;
}
// Performs |input1| |opcode| |input2| and returns the merged constant result // Performs |input1| |opcode| |input2| and returns the merged constant result
// id. Returns 0 if the result is not a valid value. The input types must be // id. Returns 0 if the result is not a valid value. The input types must be
// Float. // Float.
@ -415,6 +434,7 @@ uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
FOLD_OP(*); FOLD_OP(*);
break; break;
case SpvOpFDiv: case SpvOpFDiv:
if (HasZero(input2)) return 0;
FOLD_OP(/); FOLD_OP(/);
break; break;
case SpvOpFAdd: case SpvOpFAdd:
@ -498,10 +518,25 @@ uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
const analysis::Type* ele_type = vector_type->element_type(); const analysis::Type* ele_type = vector_type->element_type();
for (uint32_t i = 0; i != vector_type->element_count(); ++i) { for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
uint32_t id = 0; uint32_t id = 0;
const analysis::Constant* input1_comp =
input1->AsVectorConstant()->GetComponents()[i]; const analysis::Constant* input1_comp = nullptr;
const analysis::Constant* input2_comp = if (const analysis::VectorConstant* input1_vector =
input2->AsVectorConstant()->GetComponents()[i]; input1->AsVectorConstant()) {
input1_comp = input1_vector->GetComponents()[i];
} else {
assert(input1->AsNullConstant());
input1_comp = const_mgr->GetConstant(ele_type, {});
}
const analysis::Constant* input2_comp = nullptr;
if (const analysis::VectorConstant* input2_vector =
input2->AsVectorConstant()) {
input2_comp = input2_vector->GetComponents()[i];
} else {
assert(input2->AsNullConstant());
input2_comp = const_mgr->GetConstant(ele_type, {});
}
if (ele_type->AsFloat()) { if (ele_type->AsFloat()) {
id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp, id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
input2_comp); input2_comp);
@ -603,7 +638,7 @@ FoldingRule MergeMulDivArithmetic() {
std::vector<const analysis::Constant*> other_constants = std::vector<const analysis::Constant*> other_constants =
const_mgr->GetOperandConstants(other_inst); const_mgr->GetOperandConstants(other_inst);
const analysis::Constant* const_input2 = ConstInput(other_constants); const analysis::Constant* const_input2 = ConstInput(other_constants);
if (!const_input2) return false; if (!const_input2 || HasZero(const_input2)) return false;
bool other_first_is_variable = other_constants[0] == nullptr; bool other_first_is_variable = other_constants[0] == nullptr;
// If the variable value is the second operand of the divide, multiply // If the variable value is the second operand of the divide, multiply
@ -695,7 +730,7 @@ FoldingRule MergeDivDivArithmetic() {
if (width != 32 && width != 64) return false; if (width != 32 && width != 64) return false;
const analysis::Constant* const_input1 = ConstInput(constants); const analysis::Constant* const_input1 = ConstInput(constants);
if (!const_input1) return false; if (!const_input1 || HasZero(const_input1)) return false;
ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
if (!other_inst->IsFloatingPointFoldingAllowed()) return false; if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
@ -704,7 +739,7 @@ FoldingRule MergeDivDivArithmetic() {
std::vector<const analysis::Constant*> other_constants = std::vector<const analysis::Constant*> other_constants =
const_mgr->GetOperandConstants(other_inst); const_mgr->GetOperandConstants(other_inst);
const analysis::Constant* const_input2 = ConstInput(other_constants); const analysis::Constant* const_input2 = ConstInput(other_constants);
if (!const_input2) return false; if (!const_input2 || HasZero(const_input2)) return false;
bool other_first_is_variable = other_constants[0] == nullptr; bool other_first_is_variable = other_constants[0] == nullptr;
@ -765,7 +800,7 @@ FoldingRule MergeDivMulArithmetic() {
if (width != 32 && width != 64) return false; if (width != 32 && width != 64) return false;
const analysis::Constant* const_input1 = ConstInput(constants); const analysis::Constant* const_input1 = ConstInput(constants);
if (!const_input1) return false; if (!const_input1 || HasZero(const_input1)) return false;
ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
if (!other_inst->IsFloatingPointFoldingAllowed()) return false; if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
@ -1543,7 +1578,12 @@ FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
return FloatConstantKind::Unknown; return FloatConstantKind::Unknown;
} }
if (const analysis::VectorConstant* vc = constant->AsVectorConstant()) { assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
if (constant->AsNullConstant()) {
return FloatConstantKind::Zero;
} else if (const analysis::VectorConstant* vc =
constant->AsVectorConstant()) {
const std::vector<const analysis::Constant*>& components = const std::vector<const analysis::Constant*>& components =
vc->GetComponents(); vc->GetComponents();
assert(!components.empty()); assert(!components.empty());

View File

@ -198,6 +198,7 @@ OpName %main "main"
%v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2 %v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
%v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4 %v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4
%v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5 %v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5
%v2float_null = OpConstantNull %v2float
%double_n1 = OpConstant %double -1 %double_n1 = OpConstant %double -1
%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps. %105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
%double_0 = OpConstant %double 0 %double_0 = OpConstant %double 0
@ -2526,7 +2527,37 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
"%2 = OpExtInst %float %1 FMix %3 %4 %float_1\n" + "%2 = OpExtInst %float %1 FMix %3 %4 %float_1\n" +
"OpReturn\n" + "OpReturn\n" +
"OpFunctionEnd", "OpFunctionEnd",
2, 4) 2, 4),
// Test case 15: Fold vector fadd with null
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%a = OpVariable %_ptr_v2float Function\n" +
"%2 = OpLoad %v2float %a\n" +
"%3 = OpFAdd %v2float %2 %v2float_null\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, 2),
// Test case 16: Fold vector fadd with null
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%a = OpVariable %_ptr_v2float Function\n" +
"%2 = OpLoad %v2float %a\n" +
"%3 = OpFAdd %v2float %v2float_null %2\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, 2),
// Test case 15: Fold vector fsub with null
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%a = OpVariable %_ptr_v2float Function\n" +
"%2 = OpLoad %v2float %a\n" +
"%3 = OpFSub %v2float %2 %v2float_null\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, 2)
)); ));
INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTest, INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTest,
@ -3317,7 +3348,18 @@ INSTANTIATE_TEST_CASE_P(ReciprocalFDivTest, MatchingInstructionFoldingTest,
"%3 = OpFDiv %double %2 %double_2\n" + "%3 = OpFDiv %double %2 %double_2\n" +
"OpReturn\n" + "OpReturn\n" +
"OpFunctionEnd\n", "OpFunctionEnd\n",
3, true) 3, true),
// Test case 4: don't fold x / 0.
InstructionFoldingCase<bool>(
Header() +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_v2float Function\n" +
"%2 = OpLoad %v2float %var\n" +
"%3 = OpFDiv %v2float %2 %v2float_null\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
3, false)
)); ));
INSTANTIATE_TEST_CASE_P(MergeMulTest, MatchingInstructionFoldingTest, INSTANTIATE_TEST_CASE_P(MergeMulTest, MatchingInstructionFoldingTest,
@ -3812,7 +3854,20 @@ INSTANTIATE_TEST_CASE_P(MergeDivTest, MatchingInstructionFoldingTest,
"%4 = OpSDiv %int %int_2 %3\n" + "%4 = OpSDiv %int %int_2 %3\n" +
"OpReturn\n" + "OpReturn\n" +
"OpFunctionEnd\n", "OpFunctionEnd\n",
4, true) 4, true),
// Test case 13: Don't merge
// (x / {null}) / {null}
InstructionFoldingCase<bool>(
Header() +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_v2float Function\n" +
"%2 = OpLoad %float %var\n" +
"%3 = OpFDiv %float %2 %v2float_null\n" +
"%4 = OpFDiv %float %3 %v2float_null\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
4, false)
)); ));
INSTANTIATE_TEST_CASE_P(MergeAddTest, MatchingInstructionFoldingTest, INSTANTIATE_TEST_CASE_P(MergeAddTest, MatchingInstructionFoldingTest,