Add folding rule to generate Fma instructions (#4783)

Adding Fma instruction can speed up the code.  This was requested by
swiftshader, so they do not have to do this analysis themselves.  It can
also help reduce the code size, and the work the ICD compilers have to
do.
This commit is contained in:
Steven Perron 2022-04-19 11:25:07 -04:00 committed by GitHub
parent cb96abbf7a
commit 2b2b0282af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 267 additions and 0 deletions

View File

@ -1430,6 +1430,64 @@ FoldingRule FactorAddMuls() {
};
}
// Replaces |inst| inplace with an FMA instruction |(x*y)+a|.
void ReplaceWithFma(Instruction* inst, uint32_t x, uint32_t y, uint32_t a) {
uint32_t ext =
inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
if (ext == 0) {
inst->context()->AddExtInstImport("GLSL.std.450");
ext = inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
assert(ext != 0 &&
"Could not add the GLSL.std.450 extended instruction set");
}
std::vector<Operand> operands;
operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
operands.push_back({SPV_OPERAND_TYPE_ID, {x}});
operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
operands.push_back({SPV_OPERAND_TYPE_ID, {a}});
inst->SetOpcode(SpvOpExtInst);
inst->SetInOperands(std::move(operands));
}
// Folds a multiple and add into an Fma.
//
// Cases:
// (x * y) + a = Fma x y a
// a + (x * y) = Fma x y a
bool MergeMulAddArithmetic(IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>&) {
assert(inst->opcode() == SpvOpFAdd);
if (!inst->IsFloatingPointFoldingAllowed()) {
return false;
}
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
for (int i = 0; i < 2; i++) {
uint32_t op_id = inst->GetSingleWordInOperand(i);
Instruction* op_inst = def_use_mgr->GetDef(op_id);
if (op_inst->opcode() != SpvOpFMul) {
continue;
}
if (!op_inst->IsFloatingPointFoldingAllowed()) {
continue;
}
uint32_t x = op_inst->GetSingleWordInOperand(0);
uint32_t y = op_inst->GetSingleWordInOperand(1);
uint32_t a = inst->GetSingleWordInOperand((i + 1) % 2);
ReplaceWithFma(inst, x, y, a);
return true;
}
return false;
}
FoldingRule IntMultipleBy1() {
return [](IRContext*, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
@ -2543,6 +2601,7 @@ void FoldingRules::AddFoldingRules() {
rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
rules_[SpvOpFAdd].push_back(FactorAddMuls());
rules_[SpvOpFAdd].push_back(MergeMulAddArithmetic);
rules_[SpvOpFDiv].push_back(RedundantFDiv());
rules_[SpvOpFDiv].push_back(ReciprocalFDiv());

View File

@ -7108,6 +7108,214 @@ INSTANTIATE_TEST_SUITE_P(VectorShuffleMatchingTest, MatchingInstructionFoldingTe
3, true)
));
INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTest,
::testing::Values(
// Test case 0: (x * y) + a = Fma(x, y, a)
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
"%y = OpVariable %_ptr_float Function\n" +
"%a = OpVariable %_ptr_float Function\n" +
"%lx = OpLoad %float %x\n" +
"%ly = OpLoad %float %y\n" +
"%mul = OpFMul %float %lx %ly\n" +
"%la = OpLoad %float %a\n" +
"%3 = OpFAdd %float %mul %la\n" +
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 1: a + (x * y) = Fma(x, y, a)
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
"%y = OpVariable %_ptr_float Function\n" +
"%a = OpVariable %_ptr_float Function\n" +
"%lx = OpLoad %float %x\n" +
"%ly = OpLoad %float %y\n" +
"%mul = OpFMul %float %lx %ly\n" +
"%la = OpLoad %float %a\n" +
"%3 = OpFAdd %float %la %mul\n" +
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 2: (x * y) + a = Fma(x, y, a) with vectors
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_v4float Function\n" +
"%y = OpVariable %_ptr_v4float Function\n" +
"%a = OpVariable %_ptr_v4float Function\n" +
"%lx = OpLoad %v4float %x\n" +
"%ly = OpLoad %v4float %y\n" +
"%mul = OpFMul %v4float %lx %ly\n" +
"%la = OpLoad %v4float %a\n" +
"%3 = OpFAdd %v4float %mul %la\n" +
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 3: a + (x * y) = Fma(x, y, a) with vectors
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
"%y = OpVariable %_ptr_float Function\n" +
"%a = OpVariable %_ptr_float Function\n" +
"%lx = OpLoad %float %x\n" +
"%ly = OpLoad %float %y\n" +
"%mul = OpFMul %float %lx %ly\n" +
"%la = OpLoad %float %a\n" +
"%3 = OpFAdd %float %la %mul\n" +
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test 5: that the OpExtInstImport instruction is generated if it is missing.
InstructionFoldingCase<bool>(
std::string() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"OpCapability Shader\n" +
"OpMemoryModel Logical GLSL450\n" +
"OpEntryPoint Fragment %main \"main\"\n" +
"OpExecutionMode %main OriginUpperLeft\n" +
"OpSource GLSL 140\n" +
"OpName %main \"main\"\n" +
"%void = OpTypeVoid\n" +
"%void_func = OpTypeFunction %void\n" +
"%bool = OpTypeBool\n" +
"%float = OpTypeFloat 32\n" +
"%_ptr_float = OpTypePointer Function %float\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
"%y = OpVariable %_ptr_float Function\n" +
"%a = OpVariable %_ptr_float Function\n" +
"%lx = OpLoad %float %x\n" +
"%ly = OpLoad %float %y\n" +
"%mul = OpFMul %float %lx %ly\n" +
"%la = OpLoad %float %a\n" +
"%3 = OpFAdd %float %mul %la\n" +
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test 5: Don't fold if the multiple is marked no contract.
InstructionFoldingCase<bool>(
std::string() +
"OpCapability Shader\n" +
"OpMemoryModel Logical GLSL450\n" +
"OpEntryPoint Fragment %main \"main\"\n" +
"OpExecutionMode %main OriginUpperLeft\n" +
"OpSource GLSL 140\n" +
"OpName %main \"main\"\n" +
"OpDecorate %mul NoContraction\n" +
"%void = OpTypeVoid\n" +
"%void_func = OpTypeFunction %void\n" +
"%bool = OpTypeBool\n" +
"%float = OpTypeFloat 32\n" +
"%_ptr_float = OpTypePointer Function %float\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
"%y = OpVariable %_ptr_float Function\n" +
"%a = OpVariable %_ptr_float Function\n" +
"%lx = OpLoad %float %x\n" +
"%ly = OpLoad %float %y\n" +
"%mul = OpFMul %float %lx %ly\n" +
"%la = OpLoad %float %a\n" +
"%3 = OpFAdd %float %mul %la\n" +
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, false),
// Test 6: Don't fold if the add is marked no contract.
InstructionFoldingCase<bool>(
std::string() +
"OpCapability Shader\n" +
"OpMemoryModel Logical GLSL450\n" +
"OpEntryPoint Fragment %main \"main\"\n" +
"OpExecutionMode %main OriginUpperLeft\n" +
"OpSource GLSL 140\n" +
"OpName %main \"main\"\n" +
"OpDecorate %3 NoContraction\n" +
"%void = OpTypeVoid\n" +
"%void_func = OpTypeFunction %void\n" +
"%bool = OpTypeBool\n" +
"%float = OpTypeFloat 32\n" +
"%_ptr_float = OpTypePointer Function %float\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
"%y = OpVariable %_ptr_float Function\n" +
"%a = OpVariable %_ptr_float Function\n" +
"%lx = OpLoad %float %x\n" +
"%ly = OpLoad %float %y\n" +
"%mul = OpFMul %float %lx %ly\n" +
"%la = OpLoad %float %a\n" +
"%3 = OpFAdd %float %mul %la\n" +
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, false)
));
using MatchingInstructionWithNoResultFoldingTest =
::testing::TestWithParam<InstructionFoldingCase<bool>>;