mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2025-01-13 09:50:06 +00:00
Fixes #1357. Support null constants better in folding
* getFloatConstantKind() now handles OpConstantNull * PerformOperation() now handles OpConstantNull for vectors * Fixed some instances where we would attempt to merge a division by 0 * added tests
This commit is contained in:
parent
bdaf8d56fb
commit
ce5941a642
@ -126,6 +126,18 @@ class ScalarConstant : public Constant {
|
||||
// Returns a const reference of the value of this constant in 32-bit words.
|
||||
virtual const std::vector<uint32_t>& words() const { return words_; }
|
||||
|
||||
// Returns true if the value is zero.
|
||||
bool IsZero() const {
|
||||
bool is_zero = true;
|
||||
for (uint32_t v : words()) {
|
||||
if (v != 0) {
|
||||
is_zero = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return is_zero;
|
||||
}
|
||||
|
||||
protected:
|
||||
ScalarConstant(const Type* ty, const std::vector<uint32_t>& w)
|
||||
: Constant(ty), words_(w) {}
|
||||
@ -175,17 +187,6 @@ class IntConstant : public ScalarConstant {
|
||||
static_cast<uint64_t>(words()[0]);
|
||||
}
|
||||
|
||||
bool IsZero() const {
|
||||
bool is_zero = true;
|
||||
for (uint32_t v : words()) {
|
||||
if (v != 0) {
|
||||
is_zero = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return is_zero;
|
||||
}
|
||||
|
||||
// Make a copy of this IntConstant instance.
|
||||
std::unique_ptr<IntConstant> CopyIntConstant() const {
|
||||
return MakeUnique<IntConstant>(type_->AsInteger(), words_);
|
||||
|
@ -218,9 +218,12 @@ FoldingRule ReciprocalFDiv() {
|
||||
const analysis::Constant* negated_const =
|
||||
const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
|
||||
id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
|
||||
} else {
|
||||
} else if (constants[1]->AsFloatConstant()) {
|
||||
id = Reciprocal(const_mgr, constants[1]);
|
||||
if (id == 0) return false;
|
||||
} else {
|
||||
// Don't fold a null constant.
|
||||
return false;
|
||||
}
|
||||
inst->SetOpcode(SpvOpFMul);
|
||||
inst->SetInOperands(
|
||||
@ -384,6 +387,22 @@ FoldingRule MergeNegateAddSubArithmetic() {
|
||||
};
|
||||
}
|
||||
|
||||
// Returns true if |c| has a zero element.
|
||||
bool HasZero(const analysis::Constant* c) {
|
||||
if (c->AsNullConstant()) {
|
||||
return true;
|
||||
}
|
||||
if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
|
||||
for (auto& comp : vec_const->GetComponents())
|
||||
if (HasZero(comp)) return true;
|
||||
} else {
|
||||
assert(c->AsScalarConstant());
|
||||
return c->AsScalarConstant()->IsZero();
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Performs |input1| |opcode| |input2| and returns the merged constant result
|
||||
// id. Returns 0 if the result is not a valid value. The input types must be
|
||||
// Float.
|
||||
@ -415,6 +434,7 @@ uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
|
||||
FOLD_OP(*);
|
||||
break;
|
||||
case SpvOpFDiv:
|
||||
if (HasZero(input2)) return 0;
|
||||
FOLD_OP(/);
|
||||
break;
|
||||
case SpvOpFAdd:
|
||||
@ -498,10 +518,25 @@ uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
|
||||
const analysis::Type* ele_type = vector_type->element_type();
|
||||
for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
|
||||
uint32_t id = 0;
|
||||
const analysis::Constant* input1_comp =
|
||||
input1->AsVectorConstant()->GetComponents()[i];
|
||||
const analysis::Constant* input2_comp =
|
||||
input2->AsVectorConstant()->GetComponents()[i];
|
||||
|
||||
const analysis::Constant* input1_comp = nullptr;
|
||||
if (const analysis::VectorConstant* input1_vector =
|
||||
input1->AsVectorConstant()) {
|
||||
input1_comp = input1_vector->GetComponents()[i];
|
||||
} else {
|
||||
assert(input1->AsNullConstant());
|
||||
input1_comp = const_mgr->GetConstant(ele_type, {});
|
||||
}
|
||||
|
||||
const analysis::Constant* input2_comp = nullptr;
|
||||
if (const analysis::VectorConstant* input2_vector =
|
||||
input2->AsVectorConstant()) {
|
||||
input2_comp = input2_vector->GetComponents()[i];
|
||||
} else {
|
||||
assert(input2->AsNullConstant());
|
||||
input2_comp = const_mgr->GetConstant(ele_type, {});
|
||||
}
|
||||
|
||||
if (ele_type->AsFloat()) {
|
||||
id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
|
||||
input2_comp);
|
||||
@ -603,7 +638,7 @@ FoldingRule MergeMulDivArithmetic() {
|
||||
std::vector<const analysis::Constant*> other_constants =
|
||||
const_mgr->GetOperandConstants(other_inst);
|
||||
const analysis::Constant* const_input2 = ConstInput(other_constants);
|
||||
if (!const_input2) return false;
|
||||
if (!const_input2 || HasZero(const_input2)) return false;
|
||||
|
||||
bool other_first_is_variable = other_constants[0] == nullptr;
|
||||
// If the variable value is the second operand of the divide, multiply
|
||||
@ -695,7 +730,7 @@ FoldingRule MergeDivDivArithmetic() {
|
||||
if (width != 32 && width != 64) return false;
|
||||
|
||||
const analysis::Constant* const_input1 = ConstInput(constants);
|
||||
if (!const_input1) return false;
|
||||
if (!const_input1 || HasZero(const_input1)) return false;
|
||||
ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
|
||||
if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
|
||||
|
||||
@ -704,7 +739,7 @@ FoldingRule MergeDivDivArithmetic() {
|
||||
std::vector<const analysis::Constant*> other_constants =
|
||||
const_mgr->GetOperandConstants(other_inst);
|
||||
const analysis::Constant* const_input2 = ConstInput(other_constants);
|
||||
if (!const_input2) return false;
|
||||
if (!const_input2 || HasZero(const_input2)) return false;
|
||||
|
||||
bool other_first_is_variable = other_constants[0] == nullptr;
|
||||
|
||||
@ -765,7 +800,7 @@ FoldingRule MergeDivMulArithmetic() {
|
||||
if (width != 32 && width != 64) return false;
|
||||
|
||||
const analysis::Constant* const_input1 = ConstInput(constants);
|
||||
if (!const_input1) return false;
|
||||
if (!const_input1 || HasZero(const_input1)) return false;
|
||||
ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
|
||||
if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
|
||||
|
||||
@ -1543,7 +1578,12 @@ FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
|
||||
return FloatConstantKind::Unknown;
|
||||
}
|
||||
|
||||
if (const analysis::VectorConstant* vc = constant->AsVectorConstant()) {
|
||||
assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
|
||||
|
||||
if (constant->AsNullConstant()) {
|
||||
return FloatConstantKind::Zero;
|
||||
} else if (const analysis::VectorConstant* vc =
|
||||
constant->AsVectorConstant()) {
|
||||
const std::vector<const analysis::Constant*>& components =
|
||||
vc->GetComponents();
|
||||
assert(!components.empty());
|
||||
|
@ -198,6 +198,7 @@ OpName %main "main"
|
||||
%v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
|
||||
%v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4
|
||||
%v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5
|
||||
%v2float_null = OpConstantNull %v2float
|
||||
%double_n1 = OpConstant %double -1
|
||||
%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
|
||||
%double_0 = OpConstant %double 0
|
||||
@ -2526,7 +2527,37 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
|
||||
"%2 = OpExtInst %float %1 FMix %3 %4 %float_1\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 4)
|
||||
2, 4),
|
||||
// Test case 15: Fold vector fadd with null
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%a = OpVariable %_ptr_v2float Function\n" +
|
||||
"%2 = OpLoad %v2float %a\n" +
|
||||
"%3 = OpFAdd %v2float %2 %v2float_null\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, 2),
|
||||
// Test case 16: Fold vector fadd with null
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%a = OpVariable %_ptr_v2float Function\n" +
|
||||
"%2 = OpLoad %v2float %a\n" +
|
||||
"%3 = OpFAdd %v2float %v2float_null %2\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, 2),
|
||||
// Test case 15: Fold vector fsub with null
|
||||
InstructionFoldingCase<uint32_t>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%a = OpVariable %_ptr_v2float Function\n" +
|
||||
"%2 = OpLoad %v2float %a\n" +
|
||||
"%3 = OpFSub %v2float %2 %v2float_null\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, 2)
|
||||
));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTest,
|
||||
@ -3317,7 +3348,18 @@ INSTANTIATE_TEST_CASE_P(ReciprocalFDivTest, MatchingInstructionFoldingTest,
|
||||
"%3 = OpFDiv %double %2 %double_2\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd\n",
|
||||
3, true)
|
||||
3, true),
|
||||
// Test case 4: don't fold x / 0.
|
||||
InstructionFoldingCase<bool>(
|
||||
Header() +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%var = OpVariable %_ptr_v2float Function\n" +
|
||||
"%2 = OpLoad %v2float %var\n" +
|
||||
"%3 = OpFDiv %v2float %2 %v2float_null\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd\n",
|
||||
3, false)
|
||||
));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(MergeMulTest, MatchingInstructionFoldingTest,
|
||||
@ -3812,7 +3854,20 @@ INSTANTIATE_TEST_CASE_P(MergeDivTest, MatchingInstructionFoldingTest,
|
||||
"%4 = OpSDiv %int %int_2 %3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd\n",
|
||||
4, true)
|
||||
4, true),
|
||||
// Test case 13: Don't merge
|
||||
// (x / {null}) / {null}
|
||||
InstructionFoldingCase<bool>(
|
||||
Header() +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%var = OpVariable %_ptr_v2float Function\n" +
|
||||
"%2 = OpLoad %float %var\n" +
|
||||
"%3 = OpFDiv %float %2 %v2float_null\n" +
|
||||
"%4 = OpFDiv %float %3 %v2float_null\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd\n",
|
||||
4, false)
|
||||
));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(MergeAddTest, MatchingInstructionFoldingTest,
|
||||
|
Loading…
Reference in New Issue
Block a user