Do not fold mul and adds to generate fmas (#5682)

This removes the folding rules added in #4783 and #4808. They lead to
poor code generation on Adreno devices when 16-bit floating point values
were used. Since this transformation is supposed to be neutral,
there is no general reason to continue doing it.

I have talked to the owners of SwiftShader, and they do not mind if the
transform is removed. They were the ones that requested the change in the
first place.

Fixes #5658
This commit is contained in:
Steven Perron 2024-05-22 13:01:26 -04:00 committed by GitHub
parent ee749f5057
commit 336b5710a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 242 deletions

View File

@ -1459,132 +1459,6 @@ FoldingRule FactorAddMuls() {
};
}
// Rewrites |inst| in place into the GLSL.std.450 extended instruction
// |Fma x y a|, which stands for |(x*y)+a|.
//
// |x|, |y| and |a| are the ids used as the three Fma arguments, in that
// order. If the feature manager reports no GLSL.std.450 import (id 0), the
// import is added to the context before |inst| is modified.
void ReplaceWithFma(Instruction* inst, uint32_t x, uint32_t y, uint32_t a) {
  auto* ctx = inst->context();

  uint32_t glsl_set_id =
      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  if (glsl_set_id == 0) {
    // A zero id is treated as "not imported yet": add the import and ask the
    // feature manager again for its id.
    ctx->AddExtInstImport("GLSL.std.450");
    glsl_set_id = ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
    assert(glsl_set_id != 0 &&
           "Could not add the GLSL.std.450 extended instruction set");
  }

  // In-operands of the OpExtInst: the instruction set id, the Fma opcode
  // within that set, then the three argument ids.
  std::vector<Operand> fma_operands;
  fma_operands.reserve(5);
  fma_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl_set_id}});
  fma_operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
  for (uint32_t arg_id : {x, y, a}) {
    fma_operands.push_back({SPV_OPERAND_TYPE_ID, {arg_id}});
  }

  inst->SetOpcode(spv::Op::OpExtInst);
  inst->SetInOperands(std::move(fma_operands));
}
// Folding rule for OpFAdd: when one of the two addends is defined by an
// OpFMul, the add is rewritten in place as a single Fma.
//
// Cases:
//   (x * y) + a = Fma x y a
//   a + (x * y) = Fma x y a
//
// Returns true when |inst| was rewritten and false when it was left alone.
// Nothing is folded unless IsFloatingPointFoldingAllowed() holds for the add
// and for the candidate multiply. The constants argument is not used.
bool MergeMulAddArithmetic(IRContext* context, Instruction* inst,
                           const std::vector<const analysis::Constant*>&) {
  assert(inst->opcode() == spv::Op::OpFAdd);
  if (!inst->IsFloatingPointFoldingAllowed()) return false;

  analysis::DefUseManager* defs = context->get_def_use_mgr();
  // Try the first in-operand as the multiply, then the second.
  for (uint32_t mul_idx = 0; mul_idx < 2; ++mul_idx) {
    Instruction* candidate =
        defs->GetDef(inst->GetSingleWordInOperand(mul_idx));
    const bool foldable_mul =
        candidate->opcode() == spv::Op::OpFMul &&
        candidate->IsFloatingPointFoldingAllowed();
    if (!foldable_mul) continue;

    const uint32_t lhs_factor = candidate->GetSingleWordInOperand(0);
    const uint32_t rhs_factor = candidate->GetSingleWordInOperand(1);
    // The in-operand that is not the multiply becomes the Fma addend.
    const uint32_t addend = inst->GetSingleWordInOperand(1 - mul_idx);
    ReplaceWithFma(inst, lhs_factor, rhs_factor, addend);
    return true;
  }
  return false;
}
// Rewrites |sub| in place into the GLSL.std.450 instruction |Fma|, negating
// one of the inputs first so that the result matches the subtraction:
//   negate_addition == true  -> Fma x y (-a)
//   negate_addition == false -> Fma (-x) y a
//
// The negation is a new OpFNegate with the same type id as |sub|, created
// through an InstructionBuilder constructed on |sub|. If the feature manager
// reports no GLSL.std.450 import (id 0), the import is added first.
void ReplaceWithFmaAndNegate(Instruction* sub, uint32_t x, uint32_t y,
                             uint32_t a, bool negate_addition) {
  auto* ctx = sub->context();

  uint32_t glsl_set_id =
      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  if (glsl_set_id == 0) {
    // A zero id is treated as "not imported yet": add the import and ask the
    // feature manager again for its id.
    ctx->AddExtInstImport("GLSL.std.450");
    glsl_set_id = ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
    assert(glsl_set_id != 0 &&
           "Could not add the GLSL.std.450 extended instruction set");
  }

  InstructionBuilder builder(
      ctx, sub,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  const uint32_t id_to_negate = negate_addition ? a : x;
  Instruction* negation =
      builder.AddUnaryOp(sub->type_id(), spv::Op::OpFNegate, id_to_negate);
  const uint32_t negated_id = negation->result_id();

  // Substitute the negated id for whichever input it replaces.
  uint32_t first_factor = x;
  uint32_t addend = a;
  if (negate_addition) {
    addend = negated_id;
  } else {
    first_factor = negated_id;
  }

  // In-operands of the OpExtInst: the instruction set id, the Fma opcode
  // within that set, then the three argument ids.
  std::vector<Operand> fma_operands;
  fma_operands.reserve(5);
  fma_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl_set_id}});
  fma_operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
  for (uint32_t arg_id : {first_factor, y, addend}) {
    fma_operands.push_back({SPV_OPERAND_TYPE_ID, {arg_id}});
  }

  sub->SetOpcode(spv::Op::OpExtInst);
  sub->SetInOperands(std::move(fma_operands));
}
// Folding rule for OpFSub: when one of the two operands is defined by an
// OpFMul, the subtraction is rewritten in place as an Fma plus a negation.
//
// Cases:
//   (x * y) - a = Fma x y -a
//   a - (x * y) = Fma -x y a
//
// Returns true when |sub| was rewritten and false when it was left alone.
// Nothing is folded unless IsFloatingPointFoldingAllowed() holds for the
// subtraction and for the candidate multiply. The constants argument is not
// used.
bool MergeMulSubArithmetic(IRContext* context, Instruction* sub,
                           const std::vector<const analysis::Constant*>&) {
  assert(sub->opcode() == spv::Op::OpFSub);
  if (!sub->IsFloatingPointFoldingAllowed()) return false;

  analysis::DefUseManager* defs = context->get_def_use_mgr();
  // Try the first in-operand as the multiply, then the second.
  for (uint32_t mul_idx = 0; mul_idx < 2; ++mul_idx) {
    Instruction* candidate =
        defs->GetDef(sub->GetSingleWordInOperand(mul_idx));
    const bool foldable_mul =
        candidate->opcode() == spv::Op::OpFMul &&
        candidate->IsFloatingPointFoldingAllowed();
    if (!foldable_mul) continue;

    const uint32_t lhs_factor = candidate->GetSingleWordInOperand(0);
    const uint32_t rhs_factor = candidate->GetSingleWordInOperand(1);
    const uint32_t other = sub->GetSingleWordInOperand(1 - mul_idx);
    // If the multiply is the left operand, the subtracted value is negated;
    // otherwise the first multiply factor is negated.
    const bool mul_is_minuend = (mul_idx == 0);
    ReplaceWithFmaAndNegate(sub, lhs_factor, rhs_factor, other,
                            mul_is_minuend);
    return true;
  }
  return false;
}
FoldingRule IntMultipleBy1() {
return [](IRContext*, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
@ -2941,7 +2815,6 @@ void FoldingRules::AddFoldingRules() {
rules_[spv::Op::OpFAdd].push_back(MergeAddSubArithmetic());
rules_[spv::Op::OpFAdd].push_back(MergeGenericAddSubArithmetic());
rules_[spv::Op::OpFAdd].push_back(FactorAddMuls());
rules_[spv::Op::OpFAdd].push_back(MergeMulAddArithmetic);
rules_[spv::Op::OpFDiv].push_back(RedundantFDiv());
rules_[spv::Op::OpFDiv].push_back(ReciprocalFDiv());
@ -2962,7 +2835,6 @@ void FoldingRules::AddFoldingRules() {
rules_[spv::Op::OpFSub].push_back(MergeSubNegateArithmetic());
rules_[spv::Op::OpFSub].push_back(MergeSubAddArithmetic());
rules_[spv::Op::OpFSub].push_back(MergeSubSubArithmetic());
rules_[spv::Op::OpFSub].push_back(MergeMulSubArithmetic);
rules_[spv::Op::OpIAdd].push_back(RedundantIAdd());
rules_[spv::Op::OpIAdd].push_back(MergeAddNegateArithmetic());

View File

@ -7933,21 +7933,15 @@ INSTANTIATE_TEST_SUITE_P(VectorShuffleMatchingTest, MatchingInstructionFoldingTe
3, true)
));
// Issue #5658: The Adreno compiler does not handle 16-bit FMA instructions well.
// We want to avoid this by not generating FMA. We decided to never generate
// FMAs because, from a SPIR-V perspective, it is neutral. The ICD can generate
// the FMA if it wants. The simplest code is no code.
INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTest,
::testing::Values(
// Test case 0: (x * y) + a = Fma(x, y, a)
// Test case 0: Don't fold (x * y) + a
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
@ -7961,20 +7955,10 @@ INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTe
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 1: a + (x * y) = Fma(x, y, a)
3, false),
// Test case 1: Don't fold a + (x * y)
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
@ -7988,20 +7972,10 @@ INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTe
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 2: (x * y) + a = Fma(x, y, a) with vectors
3, false),
// Test case 2: Don't fold (x * y) + a with vectors
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_v4float Function\n" +
@ -8015,20 +7989,10 @@ INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTe
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 3: a + (x * y) = Fma(x, y, a) with vectors
3,false),
// Test case 3: Don't fold a + (x * y) with vectors
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
@ -8042,46 +8006,8 @@ INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTe
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test 4: that the OpExtInstImport instruction is generated if it is missing.
InstructionFoldingCase<bool>(
std::string() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"OpCapability Shader\n" +
"OpMemoryModel Logical GLSL450\n" +
"OpEntryPoint Fragment %main \"main\"\n" +
"OpExecutionMode %main OriginUpperLeft\n" +
"OpSource GLSL 140\n" +
"OpName %main \"main\"\n" +
"%void = OpTypeVoid\n" +
"%void_func = OpTypeFunction %void\n" +
"%bool = OpTypeBool\n" +
"%float = OpTypeFloat 32\n" +
"%_ptr_float = OpTypePointer Function %float\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
"%y = OpVariable %_ptr_float Function\n" +
"%a = OpVariable %_ptr_float Function\n" +
"%lx = OpLoad %float %x\n" +
"%ly = OpLoad %float %y\n" +
"%mul = OpFMul %float %lx %ly\n" +
"%la = OpLoad %float %a\n" +
"%3 = OpFAdd %float %mul %la\n" +
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test 5: Don't fold if the multiple is marked no contract.
3, false),
// Test 4: Don't fold if the multiple is marked no contract.
InstructionFoldingCase<bool>(
std::string() +
"OpCapability Shader\n" +
@ -8110,7 +8036,7 @@ INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTe
"OpReturn\n" +
"OpFunctionEnd",
3, false),
// Test 6: Don't fold if the add is marked no contract.
// Test 5: Don't fold if the add is marked no contract.
InstructionFoldingCase<bool>(
std::string() +
"OpCapability Shader\n" +
@ -8139,20 +8065,9 @@ INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTe
"OpReturn\n" +
"OpFunctionEnd",
3, false),
// Test case 7: (x * y) - a = Fma(x, y, -a)
// Test case 6: Don't fold (x * y) - a
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[na:%\\w+]] = OpFNegate {{%\\w+}} [[la]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[na]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
@ -8166,21 +8081,10 @@ INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTe
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 8: a - (x * y) = Fma(-x, y, a)
3, false),
// Test case 7: Don't fold a - (x * y)
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[nx:%\\w+]] = OpFNegate {{%\\w+}} [[lx]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[nx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
@ -8194,7 +8098,7 @@ INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTe
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true)
3, false)
));
using MatchingInstructionWithNoResultFoldingTest =