Reapply "Add folding rule to generate Fma instructions (#4783)" (#4789)

This reverts commit 671f6e633f.

PR #4783 was reverted because it caused OpenCL CTS failures for clvk.
The bug was in clspv, which was not adding the NoContraction decoration when
it was required.  This has been fixed in
https://github.com/google/clspv/pull/845.  We can now reapply #4783.
This commit is contained in:
Steven Perron 2022-05-03 10:20:23 -04:00 committed by GitHub
parent edaf51038b
commit 1295dca8e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 267 additions and 0 deletions

View File

@ -1430,6 +1430,64 @@ FoldingRule FactorAddMuls() {
};
}
// Rewrites |inst| in place into the GLSL.std.450 extended instruction
// |Fma x y a|, i.e. |(x*y)+a|.  |x|, |y| and |a| are emitted as id operands.
// If the feature manager reports no GLSL.std.450 import (id 0), the import is
// added to the context before |inst| is rewritten.
void ReplaceWithFma(Instruction* inst, uint32_t x, uint32_t y, uint32_t a) {
  auto* ctx = inst->context();

  // The feature manager is queried again after adding the import so that the
  // freshly assigned id is picked up.
  uint32_t glsl_import_id =
      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  if (glsl_import_id == 0) {
    ctx->AddExtInstImport("GLSL.std.450");
    glsl_import_id = ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
    assert(glsl_import_id != 0 &&
           "Could not add the GLSL.std.450 extended instruction set");
  }

  // In-operand order: instruction set id, instruction number, then the three
  // Fma arguments.
  std::vector<Operand> fma_operands = {
      {SPV_OPERAND_TYPE_ID, {glsl_import_id}},
      {SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}},
      {SPV_OPERAND_TYPE_ID, {x}},
      {SPV_OPERAND_TYPE_ID, {y}},
      {SPV_OPERAND_TYPE_ID, {a}},
  };

  inst->SetOpcode(SpvOpExtInst);
  inst->SetInOperands(std::move(fma_operands));
}
// Fuses a floating-point multiply that feeds an add into a single Fma.
//
// Cases:
//   (x * y) + a = Fma x y a
//   a + (x * y) = Fma x y a
//
// Returns true when |inst| was rewritten and false otherwise.  Nothing is
// rewritten unless floating-point folding is allowed on both the add and the
// multiply that defines one of its operands.
bool MergeMulAddArithmetic(IRContext* context, Instruction* inst,
                           const std::vector<const analysis::Constant*>&) {
  assert(inst->opcode() == SpvOpFAdd);
  if (!inst->IsFloatingPointFoldingAllowed()) return false;

  analysis::DefUseManager* defs = context->get_def_use_mgr();

  // Try each operand of the add as the multiply; the first match wins.
  for (uint32_t mul_idx = 0; mul_idx < 2; ++mul_idx) {
    Instruction* candidate =
        defs->GetDef(inst->GetSingleWordInOperand(mul_idx));

    const bool foldable_mul = candidate->opcode() == SpvOpFMul &&
                              candidate->IsFloatingPointFoldingAllowed();
    if (!foldable_mul) continue;

    // The addend is whichever operand of the add is not the multiply.
    const uint32_t addend_idx = 1 - mul_idx;
    const uint32_t x = candidate->GetSingleWordInOperand(0);
    const uint32_t y = candidate->GetSingleWordInOperand(1);
    const uint32_t a = inst->GetSingleWordInOperand(addend_idx);
    ReplaceWithFma(inst, x, y, a);
    return true;
  }

  return false;
}
FoldingRule IntMultipleBy1() {
return [](IRContext*, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
@ -2543,6 +2601,7 @@ void FoldingRules::AddFoldingRules() {
rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
rules_[SpvOpFAdd].push_back(FactorAddMuls());
rules_[SpvOpFAdd].push_back(MergeMulAddArithmetic);
rules_[SpvOpFDiv].push_back(RedundantFDiv());
rules_[SpvOpFDiv].push_back(ReciprocalFDiv());

View File

@ -7108,6 +7108,214 @@ INSTANTIATE_TEST_SUITE_P(VectorShuffleMatchingTest, MatchingInstructionFoldingTe
3, true)
));
// Checks the FAdd/FMul -> Fma folding rule.  Every case folds the instruction
// with result id 3 (the OpFAdd); the trailing bool is the expected outcome of
// the fold, and the "; CHECK:" lines describe the expected module afterwards.
INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTest,
::testing::Values(
    // Test case 0: (x * y) + a = Fma(x, y, a)
    InstructionFoldingCase<bool>(
        Header() +
            "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
            "; CHECK: OpFunction\n" +
            "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
            "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
            "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
            "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
            "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
            "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
            "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
            "; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
            "%main = OpFunction %void None %void_func\n" +
            "%main_lab = OpLabel\n" +
            "%x = OpVariable %_ptr_float Function\n" +
            "%y = OpVariable %_ptr_float Function\n" +
            "%a = OpVariable %_ptr_float Function\n" +
            "%lx = OpLoad %float %x\n" +
            "%ly = OpLoad %float %y\n" +
            "%mul = OpFMul %float %lx %ly\n" +
            "%la = OpLoad %float %a\n" +
            "%3 = OpFAdd %float %mul %la\n" +
            "OpStore %a %3\n" +
            "OpReturn\n" +
            "OpFunctionEnd",
        3, true),
    // Test case 1: a + (x * y) = Fma(x, y, a)
    InstructionFoldingCase<bool>(
        Header() +
            "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
            "; CHECK: OpFunction\n" +
            "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
            "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
            "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
            "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
            "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
            "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
            "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
            "; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
            "%main = OpFunction %void None %void_func\n" +
            "%main_lab = OpLabel\n" +
            "%x = OpVariable %_ptr_float Function\n" +
            "%y = OpVariable %_ptr_float Function\n" +
            "%a = OpVariable %_ptr_float Function\n" +
            "%lx = OpLoad %float %x\n" +
            "%ly = OpLoad %float %y\n" +
            "%mul = OpFMul %float %lx %ly\n" +
            "%la = OpLoad %float %a\n" +
            "%3 = OpFAdd %float %la %mul\n" +
            "OpStore %a %3\n" +
            "OpReturn\n" +
            "OpFunctionEnd",
        3, true),
    // Test case 2: (x * y) + a = Fma(x, y, a) with vectors
    InstructionFoldingCase<bool>(
        Header() +
            "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
            "; CHECK: OpFunction\n" +
            "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
            "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
            "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
            "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
            "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
            "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
            "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
            "; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
            "%main = OpFunction %void None %void_func\n" +
            "%main_lab = OpLabel\n" +
            "%x = OpVariable %_ptr_v4float Function\n" +
            "%y = OpVariable %_ptr_v4float Function\n" +
            "%a = OpVariable %_ptr_v4float Function\n" +
            "%lx = OpLoad %v4float %x\n" +
            "%ly = OpLoad %v4float %y\n" +
            "%mul = OpFMul %v4float %lx %ly\n" +
            "%la = OpLoad %v4float %a\n" +
            "%3 = OpFAdd %v4float %mul %la\n" +
            "OpStore %a %3\n" +
            "OpReturn\n" +
            "OpFunctionEnd",
        3, true),
    // Test case 3: a + (x * y) = Fma(x, y, a) with vectors
    // Uses %v4float so that this case really covers the vector form; a scalar
    // body here would merely duplicate test case 1.
    InstructionFoldingCase<bool>(
        Header() +
            "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
            "; CHECK: OpFunction\n" +
            "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
            "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
            "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
            "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
            "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
            "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
            "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
            "; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
            "%main = OpFunction %void None %void_func\n" +
            "%main_lab = OpLabel\n" +
            "%x = OpVariable %_ptr_v4float Function\n" +
            "%y = OpVariable %_ptr_v4float Function\n" +
            "%a = OpVariable %_ptr_v4float Function\n" +
            "%lx = OpLoad %v4float %x\n" +
            "%ly = OpLoad %v4float %y\n" +
            "%mul = OpFMul %v4float %lx %ly\n" +
            "%la = OpLoad %v4float %a\n" +
            "%3 = OpFAdd %v4float %la %mul\n" +
            "OpStore %a %3\n" +
            "OpReturn\n" +
            "OpFunctionEnd",
        3, true),
    // Test case 4: The OpExtInstImport instruction is generated if it is
    // missing.  The module is spelled out in full (no Header()) so that it
    // contains no GLSL.std.450 import before folding.
    InstructionFoldingCase<bool>(
        std::string() +
            "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
            "; CHECK: OpFunction\n" +
            "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
            "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
            "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
            "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
            "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
            "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
            "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
            "; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
            "OpCapability Shader\n" +
            "OpMemoryModel Logical GLSL450\n" +
            "OpEntryPoint Fragment %main \"main\"\n" +
            "OpExecutionMode %main OriginUpperLeft\n" +
            "OpSource GLSL 140\n" +
            "OpName %main \"main\"\n" +
            "%void = OpTypeVoid\n" +
            "%void_func = OpTypeFunction %void\n" +
            "%bool = OpTypeBool\n" +
            "%float = OpTypeFloat 32\n" +
            "%_ptr_float = OpTypePointer Function %float\n" +
            "%main = OpFunction %void None %void_func\n" +
            "%main_lab = OpLabel\n" +
            "%x = OpVariable %_ptr_float Function\n" +
            "%y = OpVariable %_ptr_float Function\n" +
            "%a = OpVariable %_ptr_float Function\n" +
            "%lx = OpLoad %float %x\n" +
            "%ly = OpLoad %float %y\n" +
            "%mul = OpFMul %float %lx %ly\n" +
            "%la = OpLoad %float %a\n" +
            "%3 = OpFAdd %float %mul %la\n" +
            "OpStore %a %3\n" +
            "OpReturn\n" +
            "OpFunctionEnd",
        3, true),
    // Test case 5: Don't fold if the multiply is marked NoContraction.
    InstructionFoldingCase<bool>(
        std::string() +
            "OpCapability Shader\n" +
            "OpMemoryModel Logical GLSL450\n" +
            "OpEntryPoint Fragment %main \"main\"\n" +
            "OpExecutionMode %main OriginUpperLeft\n" +
            "OpSource GLSL 140\n" +
            "OpName %main \"main\"\n" +
            "OpDecorate %mul NoContraction\n" +
            "%void = OpTypeVoid\n" +
            "%void_func = OpTypeFunction %void\n" +
            "%bool = OpTypeBool\n" +
            "%float = OpTypeFloat 32\n" +
            "%_ptr_float = OpTypePointer Function %float\n" +
            "%main = OpFunction %void None %void_func\n" +
            "%main_lab = OpLabel\n" +
            "%x = OpVariable %_ptr_float Function\n" +
            "%y = OpVariable %_ptr_float Function\n" +
            "%a = OpVariable %_ptr_float Function\n" +
            "%lx = OpLoad %float %x\n" +
            "%ly = OpLoad %float %y\n" +
            "%mul = OpFMul %float %lx %ly\n" +
            "%la = OpLoad %float %a\n" +
            "%3 = OpFAdd %float %mul %la\n" +
            "OpStore %a %3\n" +
            "OpReturn\n" +
            "OpFunctionEnd",
        3, false),
    // Test case 6: Don't fold if the add is marked NoContraction.
    InstructionFoldingCase<bool>(
        std::string() +
            "OpCapability Shader\n" +
            "OpMemoryModel Logical GLSL450\n" +
            "OpEntryPoint Fragment %main \"main\"\n" +
            "OpExecutionMode %main OriginUpperLeft\n" +
            "OpSource GLSL 140\n" +
            "OpName %main \"main\"\n" +
            "OpDecorate %3 NoContraction\n" +
            "%void = OpTypeVoid\n" +
            "%void_func = OpTypeFunction %void\n" +
            "%bool = OpTypeBool\n" +
            "%float = OpTypeFloat 32\n" +
            "%_ptr_float = OpTypePointer Function %float\n" +
            "%main = OpFunction %void None %void_func\n" +
            "%main_lab = OpLabel\n" +
            "%x = OpVariable %_ptr_float Function\n" +
            "%y = OpVariable %_ptr_float Function\n" +
            "%a = OpVariable %_ptr_float Function\n" +
            "%lx = OpLoad %float %x\n" +
            "%ly = OpLoad %float %y\n" +
            "%mul = OpFMul %float %lx %ly\n" +
            "%la = OpLoad %float %a\n" +
            "%3 = OpFAdd %float %mul %la\n" +
            "OpStore %a %3\n" +
            "OpReturn\n" +
            "OpFunctionEnd",
        3, false)
));
// Value-parameterized gtest fixture type whose parameter is an
// InstructionFoldingCase<bool>.
using MatchingInstructionWithNoResultFoldingTest =
::testing::TestWithParam<InstructionFoldingCase<bool>>;