mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2025-01-13 09:50:06 +00:00
Fixes #1357. Support null constants better in folding
* getFloatConstantKind() now handles OpConstantNull
* PerformOperation() now handles OpConstantNull for vectors
* Fixed some instances where we would attempt to merge a division by 0
* Added tests
This commit is contained in:
parent
bdaf8d56fb
commit
ce5941a642
@ -126,6 +126,18 @@ class ScalarConstant : public Constant {
|
|||||||
// Returns a const reference of the value of this constant in 32-bit words.
|
// Returns a const reference of the value of this constant in 32-bit words.
|
||||||
virtual const std::vector<uint32_t>& words() const { return words_; }
|
virtual const std::vector<uint32_t>& words() const { return words_; }
|
||||||
|
|
||||||
|
// Returns true if the value is zero.
|
||||||
|
bool IsZero() const {
|
||||||
|
bool is_zero = true;
|
||||||
|
for (uint32_t v : words()) {
|
||||||
|
if (v != 0) {
|
||||||
|
is_zero = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return is_zero;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
ScalarConstant(const Type* ty, const std::vector<uint32_t>& w)
|
ScalarConstant(const Type* ty, const std::vector<uint32_t>& w)
|
||||||
: Constant(ty), words_(w) {}
|
: Constant(ty), words_(w) {}
|
||||||
@ -175,17 +187,6 @@ class IntConstant : public ScalarConstant {
|
|||||||
static_cast<uint64_t>(words()[0]);
|
static_cast<uint64_t>(words()[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsZero() const {
|
|
||||||
bool is_zero = true;
|
|
||||||
for (uint32_t v : words()) {
|
|
||||||
if (v != 0) {
|
|
||||||
is_zero = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return is_zero;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make a copy of this IntConstant instance.
|
// Make a copy of this IntConstant instance.
|
||||||
std::unique_ptr<IntConstant> CopyIntConstant() const {
|
std::unique_ptr<IntConstant> CopyIntConstant() const {
|
||||||
return MakeUnique<IntConstant>(type_->AsInteger(), words_);
|
return MakeUnique<IntConstant>(type_->AsInteger(), words_);
|
||||||
|
@ -218,9 +218,12 @@ FoldingRule ReciprocalFDiv() {
|
|||||||
const analysis::Constant* negated_const =
|
const analysis::Constant* negated_const =
|
||||||
const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
|
const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
|
||||||
id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
|
id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
|
||||||
} else {
|
} else if (constants[1]->AsFloatConstant()) {
|
||||||
id = Reciprocal(const_mgr, constants[1]);
|
id = Reciprocal(const_mgr, constants[1]);
|
||||||
if (id == 0) return false;
|
if (id == 0) return false;
|
||||||
|
} else {
|
||||||
|
// Don't fold a null constant.
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
inst->SetOpcode(SpvOpFMul);
|
inst->SetOpcode(SpvOpFMul);
|
||||||
inst->SetInOperands(
|
inst->SetInOperands(
|
||||||
@ -384,6 +387,22 @@ FoldingRule MergeNegateAddSubArithmetic() {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns true if |c| has a zero element.
|
||||||
|
bool HasZero(const analysis::Constant* c) {
|
||||||
|
if (c->AsNullConstant()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
|
||||||
|
for (auto& comp : vec_const->GetComponents())
|
||||||
|
if (HasZero(comp)) return true;
|
||||||
|
} else {
|
||||||
|
assert(c->AsScalarConstant());
|
||||||
|
return c->AsScalarConstant()->IsZero();
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// Performs |input1| |opcode| |input2| and returns the merged constant result
|
// Performs |input1| |opcode| |input2| and returns the merged constant result
|
||||||
// id. Returns 0 if the result is not a valid value. The input types must be
|
// id. Returns 0 if the result is not a valid value. The input types must be
|
||||||
// Float.
|
// Float.
|
||||||
@ -415,6 +434,7 @@ uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
|
|||||||
FOLD_OP(*);
|
FOLD_OP(*);
|
||||||
break;
|
break;
|
||||||
case SpvOpFDiv:
|
case SpvOpFDiv:
|
||||||
|
if (HasZero(input2)) return 0;
|
||||||
FOLD_OP(/);
|
FOLD_OP(/);
|
||||||
break;
|
break;
|
||||||
case SpvOpFAdd:
|
case SpvOpFAdd:
|
||||||
@ -498,10 +518,25 @@ uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
|
|||||||
const analysis::Type* ele_type = vector_type->element_type();
|
const analysis::Type* ele_type = vector_type->element_type();
|
||||||
for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
|
for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
|
||||||
uint32_t id = 0;
|
uint32_t id = 0;
|
||||||
const analysis::Constant* input1_comp =
|
|
||||||
input1->AsVectorConstant()->GetComponents()[i];
|
const analysis::Constant* input1_comp = nullptr;
|
||||||
const analysis::Constant* input2_comp =
|
if (const analysis::VectorConstant* input1_vector =
|
||||||
input2->AsVectorConstant()->GetComponents()[i];
|
input1->AsVectorConstant()) {
|
||||||
|
input1_comp = input1_vector->GetComponents()[i];
|
||||||
|
} else {
|
||||||
|
assert(input1->AsNullConstant());
|
||||||
|
input1_comp = const_mgr->GetConstant(ele_type, {});
|
||||||
|
}
|
||||||
|
|
||||||
|
const analysis::Constant* input2_comp = nullptr;
|
||||||
|
if (const analysis::VectorConstant* input2_vector =
|
||||||
|
input2->AsVectorConstant()) {
|
||||||
|
input2_comp = input2_vector->GetComponents()[i];
|
||||||
|
} else {
|
||||||
|
assert(input2->AsNullConstant());
|
||||||
|
input2_comp = const_mgr->GetConstant(ele_type, {});
|
||||||
|
}
|
||||||
|
|
||||||
if (ele_type->AsFloat()) {
|
if (ele_type->AsFloat()) {
|
||||||
id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
|
id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
|
||||||
input2_comp);
|
input2_comp);
|
||||||
@ -603,7 +638,7 @@ FoldingRule MergeMulDivArithmetic() {
|
|||||||
std::vector<const analysis::Constant*> other_constants =
|
std::vector<const analysis::Constant*> other_constants =
|
||||||
const_mgr->GetOperandConstants(other_inst);
|
const_mgr->GetOperandConstants(other_inst);
|
||||||
const analysis::Constant* const_input2 = ConstInput(other_constants);
|
const analysis::Constant* const_input2 = ConstInput(other_constants);
|
||||||
if (!const_input2) return false;
|
if (!const_input2 || HasZero(const_input2)) return false;
|
||||||
|
|
||||||
bool other_first_is_variable = other_constants[0] == nullptr;
|
bool other_first_is_variable = other_constants[0] == nullptr;
|
||||||
// If the variable value is the second operand of the divide, multiply
|
// If the variable value is the second operand of the divide, multiply
|
||||||
@ -695,7 +730,7 @@ FoldingRule MergeDivDivArithmetic() {
|
|||||||
if (width != 32 && width != 64) return false;
|
if (width != 32 && width != 64) return false;
|
||||||
|
|
||||||
const analysis::Constant* const_input1 = ConstInput(constants);
|
const analysis::Constant* const_input1 = ConstInput(constants);
|
||||||
if (!const_input1) return false;
|
if (!const_input1 || HasZero(const_input1)) return false;
|
||||||
ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
|
ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
|
||||||
if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
|
if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
|
||||||
|
|
||||||
@ -704,7 +739,7 @@ FoldingRule MergeDivDivArithmetic() {
|
|||||||
std::vector<const analysis::Constant*> other_constants =
|
std::vector<const analysis::Constant*> other_constants =
|
||||||
const_mgr->GetOperandConstants(other_inst);
|
const_mgr->GetOperandConstants(other_inst);
|
||||||
const analysis::Constant* const_input2 = ConstInput(other_constants);
|
const analysis::Constant* const_input2 = ConstInput(other_constants);
|
||||||
if (!const_input2) return false;
|
if (!const_input2 || HasZero(const_input2)) return false;
|
||||||
|
|
||||||
bool other_first_is_variable = other_constants[0] == nullptr;
|
bool other_first_is_variable = other_constants[0] == nullptr;
|
||||||
|
|
||||||
@ -765,7 +800,7 @@ FoldingRule MergeDivMulArithmetic() {
|
|||||||
if (width != 32 && width != 64) return false;
|
if (width != 32 && width != 64) return false;
|
||||||
|
|
||||||
const analysis::Constant* const_input1 = ConstInput(constants);
|
const analysis::Constant* const_input1 = ConstInput(constants);
|
||||||
if (!const_input1) return false;
|
if (!const_input1 || HasZero(const_input1)) return false;
|
||||||
ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
|
ir::Instruction* other_inst = NonConstInput(context, constants[0], inst);
|
||||||
if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
|
if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
|
||||||
|
|
||||||
@ -1543,7 +1578,12 @@ FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
|
|||||||
return FloatConstantKind::Unknown;
|
return FloatConstantKind::Unknown;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (const analysis::VectorConstant* vc = constant->AsVectorConstant()) {
|
assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
|
||||||
|
|
||||||
|
if (constant->AsNullConstant()) {
|
||||||
|
return FloatConstantKind::Zero;
|
||||||
|
} else if (const analysis::VectorConstant* vc =
|
||||||
|
constant->AsVectorConstant()) {
|
||||||
const std::vector<const analysis::Constant*>& components =
|
const std::vector<const analysis::Constant*>& components =
|
||||||
vc->GetComponents();
|
vc->GetComponents();
|
||||||
assert(!components.empty());
|
assert(!components.empty());
|
||||||
|
@ -198,6 +198,7 @@ OpName %main "main"
|
|||||||
%v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
|
%v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
|
||||||
%v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4
|
%v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4
|
||||||
%v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5
|
%v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5
|
||||||
|
%v2float_null = OpConstantNull %v2float
|
||||||
%double_n1 = OpConstant %double -1
|
%double_n1 = OpConstant %double -1
|
||||||
%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
|
%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
|
||||||
%double_0 = OpConstant %double 0
|
%double_0 = OpConstant %double 0
|
||||||
@ -2526,7 +2527,37 @@ INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest
|
|||||||
"%2 = OpExtInst %float %1 FMix %3 %4 %float_1\n" +
|
"%2 = OpExtInst %float %1 FMix %3 %4 %float_1\n" +
|
||||||
"OpReturn\n" +
|
"OpReturn\n" +
|
||||||
"OpFunctionEnd",
|
"OpFunctionEnd",
|
||||||
2, 4)
|
2, 4),
|
||||||
|
// Test case 15: Fold vector fadd with null
|
||||||
|
InstructionFoldingCase<uint32_t>(
|
||||||
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%a = OpVariable %_ptr_v2float Function\n" +
|
||||||
|
"%2 = OpLoad %v2float %a\n" +
|
||||||
|
"%3 = OpFAdd %v2float %2 %v2float_null\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
3, 2),
|
||||||
|
// Test case 16: Fold vector fadd with null
|
||||||
|
InstructionFoldingCase<uint32_t>(
|
||||||
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%a = OpVariable %_ptr_v2float Function\n" +
|
||||||
|
"%2 = OpLoad %v2float %a\n" +
|
||||||
|
"%3 = OpFAdd %v2float %v2float_null %2\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
3, 2),
|
||||||
|
// Test case 17: Fold vector fsub with null
|
||||||
|
InstructionFoldingCase<uint32_t>(
|
||||||
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%a = OpVariable %_ptr_v2float Function\n" +
|
||||||
|
"%2 = OpLoad %v2float %a\n" +
|
||||||
|
"%3 = OpFSub %v2float %2 %v2float_null\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
3, 2)
|
||||||
));
|
));
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTest,
|
INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTest,
|
||||||
@ -3317,7 +3348,18 @@ INSTANTIATE_TEST_CASE_P(ReciprocalFDivTest, MatchingInstructionFoldingTest,
|
|||||||
"%3 = OpFDiv %double %2 %double_2\n" +
|
"%3 = OpFDiv %double %2 %double_2\n" +
|
||||||
"OpReturn\n" +
|
"OpReturn\n" +
|
||||||
"OpFunctionEnd\n",
|
"OpFunctionEnd\n",
|
||||||
3, true)
|
3, true),
|
||||||
|
// Test case 4: don't fold x / 0.
|
||||||
|
InstructionFoldingCase<bool>(
|
||||||
|
Header() +
|
||||||
|
"%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%var = OpVariable %_ptr_v2float Function\n" +
|
||||||
|
"%2 = OpLoad %v2float %var\n" +
|
||||||
|
"%3 = OpFDiv %v2float %2 %v2float_null\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd\n",
|
||||||
|
3, false)
|
||||||
));
|
));
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(MergeMulTest, MatchingInstructionFoldingTest,
|
INSTANTIATE_TEST_CASE_P(MergeMulTest, MatchingInstructionFoldingTest,
|
||||||
@ -3812,7 +3854,20 @@ INSTANTIATE_TEST_CASE_P(MergeDivTest, MatchingInstructionFoldingTest,
|
|||||||
"%4 = OpSDiv %int %int_2 %3\n" +
|
"%4 = OpSDiv %int %int_2 %3\n" +
|
||||||
"OpReturn\n" +
|
"OpReturn\n" +
|
||||||
"OpFunctionEnd\n",
|
"OpFunctionEnd\n",
|
||||||
4, true)
|
4, true),
|
||||||
|
// Test case 13: Don't merge
|
||||||
|
// (x / {null}) / {null}
|
||||||
|
InstructionFoldingCase<bool>(
|
||||||
|
Header() +
|
||||||
|
"%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%var = OpVariable %_ptr_v2float Function\n" +
|
||||||
|
"%2 = OpLoad %float %var\n" +
|
||||||
|
"%3 = OpFDiv %float %2 %v2float_null\n" +
|
||||||
|
"%4 = OpFDiv %float %3 %v2float_null\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd\n",
|
||||||
|
4, false)
|
||||||
));
|
));
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(MergeAddTest, MatchingInstructionFoldingTest,
|
INSTANTIATE_TEST_CASE_P(MergeAddTest, MatchingInstructionFoldingTest,
|
||||||
|
Loading…
Reference in New Issue
Block a user