Do not distrubute OpSNegate into OpUDiv (#5823)

We cannot apply the negate to an operand of an OpUDiv instead of it
result. This is because the operands of the OpUDiv are interpreted as
unsigned. We stop the optimizer from doing that.

There were no tests for distributing a negate into OpIMul, OpSDiv, and
OpUDiv. Tests are added for all of these.

Fixes #5822
This commit is contained in:
Steven Perron 2024-09-26 15:00:56 -04:00 committed by GitHub
parent 5c8442f7fc
commit 5b38abc877
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 218 additions and 29 deletions

View File

@ -422,36 +422,37 @@ FoldingRule MergeNegateMulDivArithmetic() {
if (width != 32 && width != 64) return false;
spv::Op opcode = op_inst->opcode();
if (opcode == spv::Op::OpFMul || opcode == spv::Op::OpFDiv ||
opcode == spv::Op::OpIMul || opcode == spv::Op::OpSDiv ||
opcode == spv::Op::OpUDiv) {
std::vector<const analysis::Constant*> op_constants =
const_mgr->GetOperandConstants(op_inst);
// Merge negate into mul or div if one operand is constant.
if (op_constants[0] || op_constants[1]) {
bool zero_is_variable = op_constants[0] == nullptr;
const analysis::Constant* c = ConstInput(op_constants);
uint32_t neg_id = NegateConstant(const_mgr, c);
uint32_t non_const_id = zero_is_variable
? op_inst->GetSingleWordInOperand(0u)
: op_inst->GetSingleWordInOperand(1u);
// Change this instruction to a mul/div.
inst->SetOpcode(op_inst->opcode());
if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv ||
opcode == spv::Op::OpSDiv) {
uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
inst->SetInOperands(
{{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
} else {
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
{SPV_OPERAND_TYPE_ID, {neg_id}}});
}
return true;
}
if (opcode != spv::Op::OpFMul && opcode != spv::Op::OpFDiv &&
opcode != spv::Op::OpIMul && opcode != spv::Op::OpSDiv) {
return false;
}
return false;
std::vector<const analysis::Constant*> op_constants =
const_mgr->GetOperandConstants(op_inst);
// Merge negate into mul or div if one operand is constant.
if (op_constants[0] == nullptr && op_constants[1] == nullptr) {
return false;
}
bool zero_is_variable = op_constants[0] == nullptr;
const analysis::Constant* c = ConstInput(op_constants);
uint32_t neg_id = NegateConstant(const_mgr, c);
uint32_t non_const_id = zero_is_variable
? op_inst->GetSingleWordInOperand(0u)
: op_inst->GetSingleWordInOperand(1u);
// Change this instruction to a mul/div.
inst->SetOpcode(op_inst->opcode());
if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv ||
opcode == spv::Op::OpSDiv) {
uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
inst->SetInOperands(
{{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
} else {
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
{SPV_OPERAND_TYPE_ID, {neg_id}}});
}
return true;
};
}

View File

@ -5940,7 +5940,195 @@ INSTANTIATE_TEST_SUITE_P(MergeNegateTest, MatchingInstructionFoldingTest,
"%2 = OpFNegate %v2double %v2double_null\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, true)
2, true),
// Test case 20: fold snegate with OpIMul.
// -(x * 2) = x * -2
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
"; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
"; CHECK: %4 = OpIMul [[long]] [[ld]] [[long_n2]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_long Function\n" +
"%2 = OpLoad %long %var\n" +
"%3 = OpIMul %long %2 %long_2\n" +
"%4 = OpSNegate %long %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 21: fold snegate with OpIMul.
// -(x * 2) = x * -2
InstructionFoldingCase<bool>(
Header() +
"; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK-DAG: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
"; CHECK: [[uint_n2:%\\w+]] = OpConstant [[uint]] 4294967294\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpIMul [[int]] [[ld]] [[uint_n2]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpIMul %int %2 %uint_2\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 22: fold snegate with OpIMul.
// -(-24 * x) = x * 24
InstructionFoldingCase<bool>(
Header() +
"; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[int_24:%\\w+]] = OpConstant [[int]] 24\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_24]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpIMul %int %int_n24 %2\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 23: fold snegate with OpIMul with UINT_MAX
// -(UINT_MAX * x) = x
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpCopyObject [[int]] [[ld]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpIMul %int %uint_max %2\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 24: fold snegate with OpIMul using -INT_MAX
// -(x * 2147483649u) = x * 2147483647u
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
"; CHECK: [[uint_2147483647:%\\w+]] = OpConstant [[uint]] 2147483647\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpIMul [[int]] [[ld]] [[uint_2147483647]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpIMul %int %2 %uint_2147483649\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 25: fold snegate with OpSDiv (long).
// -(x / 2) = x / -2
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
"; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
"; CHECK: %4 = OpSDiv [[long]] [[ld]] [[long_n2]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_long Function\n" +
"%2 = OpLoad %long %var\n" +
"%3 = OpSDiv %long %2 %long_2\n" +
"%4 = OpSNegate %long %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 26: fold snegate with OpSDiv (int).
// -(x / 2) = x / -2
InstructionFoldingCase<bool>(
Header() +
"; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK-DAG: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
"; CHECK: [[uint_n2:%\\w+]] = OpConstant [[uint]] 4294967294\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpSDiv [[int]] [[ld]] [[uint_n2]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpSDiv %int %2 %uint_2\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 27: fold snegate with OpSDiv.
// -(-24 / x) = 24 / x
InstructionFoldingCase<bool>(
Header() +
"; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[int_24:%\\w+]] = OpConstant [[int]] 24\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpSDiv [[int]] [[int_24]] [[ld]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpSDiv %int %int_n24 %2\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 28: fold snegate with OpSDiv with UINT_MAX
// -(UINT_MAX / x) = (1 / x)
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
"; CHECK: [[uint_1:%\\w+]] = OpConstant [[uint]] 1\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpSDiv [[int]] [[uint_1]] [[ld]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpSDiv %int %uint_max %2\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 29: fold snegate with OpSDiv using -INT_MAX
// -(x / 2147483647u) = x / 2147483647
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
"; CHECK: [[uint_2147483647:%\\w+]] = OpConstant [[uint]] 2147483647\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpSDiv [[int]] [[ld]] [[uint_2147483647]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpSDiv %int %2 %uint_2147483649\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 30: Don't fold snegate int OpUDiv. The operands are interpreted
// as unsigned, so negating an operand is not the same a negating the
// result.
InstructionFoldingCase<bool>(
Header() +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpUDiv %int %2 %uint_1\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, false)
));
INSTANTIATE_TEST_SUITE_P(ReciprocalFDivTest, MatchingInstructionFoldingTest,