mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-22 19:50:05 +00:00
Add folding rule to generate Fma instructions (#4783)
Adding Fma instruction can speed up the code. This was requested by swiftshader, so they do not have to do this analysis themselves. It can also help reduce the code size, and the work the ICD compilers have to do.
This commit is contained in:
parent
cb96abbf7a
commit
2b2b0282af
@ -1430,6 +1430,64 @@ FoldingRule FactorAddMuls() {
|
||||
};
|
||||
}
|
||||
|
||||
// Replaces |inst| inplace with an FMA instruction |(x*y)+a|.
|
||||
void ReplaceWithFma(Instruction* inst, uint32_t x, uint32_t y, uint32_t a) {
|
||||
uint32_t ext =
|
||||
inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
|
||||
|
||||
if (ext == 0) {
|
||||
inst->context()->AddExtInstImport("GLSL.std.450");
|
||||
ext = inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
|
||||
assert(ext != 0 &&
|
||||
"Could not add the GLSL.std.450 extended instruction set");
|
||||
}
|
||||
|
||||
std::vector<Operand> operands;
|
||||
operands.push_back({SPV_OPERAND_TYPE_ID, {ext}});
|
||||
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
|
||||
operands.push_back({SPV_OPERAND_TYPE_ID, {x}});
|
||||
operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
|
||||
operands.push_back({SPV_OPERAND_TYPE_ID, {a}});
|
||||
|
||||
inst->SetOpcode(SpvOpExtInst);
|
||||
inst->SetInOperands(std::move(operands));
|
||||
}
|
||||
|
||||
// Folds a multiple and add into an Fma.
|
||||
//
|
||||
// Cases:
|
||||
// (x * y) + a = Fma x y a
|
||||
// a + (x * y) = Fma x y a
|
||||
bool MergeMulAddArithmetic(IRContext* context, Instruction* inst,
|
||||
const std::vector<const analysis::Constant*>&) {
|
||||
assert(inst->opcode() == SpvOpFAdd);
|
||||
|
||||
if (!inst->IsFloatingPointFoldingAllowed()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
|
||||
for (int i = 0; i < 2; i++) {
|
||||
uint32_t op_id = inst->GetSingleWordInOperand(i);
|
||||
Instruction* op_inst = def_use_mgr->GetDef(op_id);
|
||||
|
||||
if (op_inst->opcode() != SpvOpFMul) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!op_inst->IsFloatingPointFoldingAllowed()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t x = op_inst->GetSingleWordInOperand(0);
|
||||
uint32_t y = op_inst->GetSingleWordInOperand(1);
|
||||
uint32_t a = inst->GetSingleWordInOperand((i + 1) % 2);
|
||||
ReplaceWithFma(inst, x, y, a);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
FoldingRule IntMultipleBy1() {
|
||||
return [](IRContext*, Instruction* inst,
|
||||
const std::vector<const analysis::Constant*>& constants) {
|
||||
@ -2543,6 +2601,7 @@ void FoldingRules::AddFoldingRules() {
|
||||
rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
|
||||
rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
|
||||
rules_[SpvOpFAdd].push_back(FactorAddMuls());
|
||||
rules_[SpvOpFAdd].push_back(MergeMulAddArithmetic);
|
||||
|
||||
rules_[SpvOpFDiv].push_back(RedundantFDiv());
|
||||
rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
|
||||
|
@ -7108,6 +7108,214 @@ INSTANTIATE_TEST_SUITE_P(VectorShuffleMatchingTest, MatchingInstructionFoldingTe
|
||||
3, true)
|
||||
));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTest,
|
||||
::testing::Values(
|
||||
// Test case 0: (x * y) + a = Fma(x, y, a)
|
||||
InstructionFoldingCase<bool>(
|
||||
Header() +
|
||||
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
|
||||
"; CHECK: OpFunction\n" +
|
||||
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
|
||||
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
|
||||
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
|
||||
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
|
||||
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%x = OpVariable %_ptr_float Function\n" +
|
||||
"%y = OpVariable %_ptr_float Function\n" +
|
||||
"%a = OpVariable %_ptr_float Function\n" +
|
||||
"%lx = OpLoad %float %x\n" +
|
||||
"%ly = OpLoad %float %y\n" +
|
||||
"%mul = OpFMul %float %lx %ly\n" +
|
||||
"%la = OpLoad %float %a\n" +
|
||||
"%3 = OpFAdd %float %mul %la\n" +
|
||||
"OpStore %a %3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, true),
|
||||
// Test case 1: a + (x * y) = Fma(x, y, a)
|
||||
InstructionFoldingCase<bool>(
|
||||
Header() +
|
||||
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
|
||||
"; CHECK: OpFunction\n" +
|
||||
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
|
||||
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
|
||||
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
|
||||
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
|
||||
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%x = OpVariable %_ptr_float Function\n" +
|
||||
"%y = OpVariable %_ptr_float Function\n" +
|
||||
"%a = OpVariable %_ptr_float Function\n" +
|
||||
"%lx = OpLoad %float %x\n" +
|
||||
"%ly = OpLoad %float %y\n" +
|
||||
"%mul = OpFMul %float %lx %ly\n" +
|
||||
"%la = OpLoad %float %a\n" +
|
||||
"%3 = OpFAdd %float %la %mul\n" +
|
||||
"OpStore %a %3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, true),
|
||||
// Test case 2: (x * y) + a = Fma(x, y, a) with vectors
|
||||
InstructionFoldingCase<bool>(
|
||||
Header() +
|
||||
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
|
||||
"; CHECK: OpFunction\n" +
|
||||
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
|
||||
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
|
||||
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
|
||||
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
|
||||
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%x = OpVariable %_ptr_v4float Function\n" +
|
||||
"%y = OpVariable %_ptr_v4float Function\n" +
|
||||
"%a = OpVariable %_ptr_v4float Function\n" +
|
||||
"%lx = OpLoad %v4float %x\n" +
|
||||
"%ly = OpLoad %v4float %y\n" +
|
||||
"%mul = OpFMul %v4float %lx %ly\n" +
|
||||
"%la = OpLoad %v4float %a\n" +
|
||||
"%3 = OpFAdd %v4float %mul %la\n" +
|
||||
"OpStore %a %3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, true),
|
||||
// Test case 3: a + (x * y) = Fma(x, y, a) with vectors
|
||||
InstructionFoldingCase<bool>(
|
||||
Header() +
|
||||
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
|
||||
"; CHECK: OpFunction\n" +
|
||||
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
|
||||
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
|
||||
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
|
||||
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
|
||||
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%x = OpVariable %_ptr_float Function\n" +
|
||||
"%y = OpVariable %_ptr_float Function\n" +
|
||||
"%a = OpVariable %_ptr_float Function\n" +
|
||||
"%lx = OpLoad %float %x\n" +
|
||||
"%ly = OpLoad %float %y\n" +
|
||||
"%mul = OpFMul %float %lx %ly\n" +
|
||||
"%la = OpLoad %float %a\n" +
|
||||
"%3 = OpFAdd %float %la %mul\n" +
|
||||
"OpStore %a %3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, true),
|
||||
// Test 5: that the OpExtInstImport instruction is generated if it is missing.
|
||||
InstructionFoldingCase<bool>(
|
||||
std::string() +
|
||||
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
|
||||
"; CHECK: OpFunction\n" +
|
||||
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
|
||||
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
|
||||
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
|
||||
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
|
||||
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
|
||||
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
|
||||
"OpCapability Shader\n" +
|
||||
"OpMemoryModel Logical GLSL450\n" +
|
||||
"OpEntryPoint Fragment %main \"main\"\n" +
|
||||
"OpExecutionMode %main OriginUpperLeft\n" +
|
||||
"OpSource GLSL 140\n" +
|
||||
"OpName %main \"main\"\n" +
|
||||
"%void = OpTypeVoid\n" +
|
||||
"%void_func = OpTypeFunction %void\n" +
|
||||
"%bool = OpTypeBool\n" +
|
||||
"%float = OpTypeFloat 32\n" +
|
||||
"%_ptr_float = OpTypePointer Function %float\n" +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%x = OpVariable %_ptr_float Function\n" +
|
||||
"%y = OpVariable %_ptr_float Function\n" +
|
||||
"%a = OpVariable %_ptr_float Function\n" +
|
||||
"%lx = OpLoad %float %x\n" +
|
||||
"%ly = OpLoad %float %y\n" +
|
||||
"%mul = OpFMul %float %lx %ly\n" +
|
||||
"%la = OpLoad %float %a\n" +
|
||||
"%3 = OpFAdd %float %mul %la\n" +
|
||||
"OpStore %a %3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, true),
|
||||
// Test 5: Don't fold if the multiple is marked no contract.
|
||||
InstructionFoldingCase<bool>(
|
||||
std::string() +
|
||||
"OpCapability Shader\n" +
|
||||
"OpMemoryModel Logical GLSL450\n" +
|
||||
"OpEntryPoint Fragment %main \"main\"\n" +
|
||||
"OpExecutionMode %main OriginUpperLeft\n" +
|
||||
"OpSource GLSL 140\n" +
|
||||
"OpName %main \"main\"\n" +
|
||||
"OpDecorate %mul NoContraction\n" +
|
||||
"%void = OpTypeVoid\n" +
|
||||
"%void_func = OpTypeFunction %void\n" +
|
||||
"%bool = OpTypeBool\n" +
|
||||
"%float = OpTypeFloat 32\n" +
|
||||
"%_ptr_float = OpTypePointer Function %float\n" +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%x = OpVariable %_ptr_float Function\n" +
|
||||
"%y = OpVariable %_ptr_float Function\n" +
|
||||
"%a = OpVariable %_ptr_float Function\n" +
|
||||
"%lx = OpLoad %float %x\n" +
|
||||
"%ly = OpLoad %float %y\n" +
|
||||
"%mul = OpFMul %float %lx %ly\n" +
|
||||
"%la = OpLoad %float %a\n" +
|
||||
"%3 = OpFAdd %float %mul %la\n" +
|
||||
"OpStore %a %3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, false),
|
||||
// Test 6: Don't fold if the add is marked no contract.
|
||||
InstructionFoldingCase<bool>(
|
||||
std::string() +
|
||||
"OpCapability Shader\n" +
|
||||
"OpMemoryModel Logical GLSL450\n" +
|
||||
"OpEntryPoint Fragment %main \"main\"\n" +
|
||||
"OpExecutionMode %main OriginUpperLeft\n" +
|
||||
"OpSource GLSL 140\n" +
|
||||
"OpName %main \"main\"\n" +
|
||||
"OpDecorate %3 NoContraction\n" +
|
||||
"%void = OpTypeVoid\n" +
|
||||
"%void_func = OpTypeFunction %void\n" +
|
||||
"%bool = OpTypeBool\n" +
|
||||
"%float = OpTypeFloat 32\n" +
|
||||
"%_ptr_float = OpTypePointer Function %float\n" +
|
||||
"%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%x = OpVariable %_ptr_float Function\n" +
|
||||
"%y = OpVariable %_ptr_float Function\n" +
|
||||
"%a = OpVariable %_ptr_float Function\n" +
|
||||
"%lx = OpLoad %float %x\n" +
|
||||
"%ly = OpLoad %float %y\n" +
|
||||
"%mul = OpFMul %float %lx %ly\n" +
|
||||
"%la = OpLoad %float %a\n" +
|
||||
"%3 = OpFAdd %float %mul %la\n" +
|
||||
"OpStore %a %3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
3, false)
|
||||
));
|
||||
|
||||
using MatchingInstructionWithNoResultFoldingTest =
|
||||
::testing::TestWithParam<InstructionFoldingCase<bool>>;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user