Fold divisions by 0. (#1963)

The current implementation in the folder when seeing a division by zero
is to assert.  In the release build, the compiler will attempt to
compute the value, which causes its own problems.

The solution I will go with is to fold the division, and just give it
the value of 0.  The same goes for remainder and mod operations.

Fixes #1961.
This commit is contained in:
Steven Perron 2018-10-10 11:17:26 -04:00 committed by GitHub
parent fae1e61ab8
commit 4e266f775a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 12 deletions

View File

@ -69,29 +69,49 @@ uint32_t InstructionFolder::BinaryOperate(SpvOp opcode, uint32_t a,
case SpvOp::SpvOpIMul: case SpvOp::SpvOpIMul:
return a * b; return a * b;
case SpvOp::SpvOpUDiv: case SpvOp::SpvOpUDiv:
assert(b != 0); if (b != 0) {
return a / b; return a / b;
} else {
// Dividing by 0 is undefined, so we will just pick 0.
return 0;
}
case SpvOp::SpvOpSDiv: case SpvOp::SpvOpSDiv:
assert(b != 0u); if (b != 0u) {
return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b)); return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
} else {
// Dividing by 0 is undefined, so we will just pick 0.
return 0;
}
case SpvOp::SpvOpSRem: { case SpvOp::SpvOpSRem: {
// The sign of non-zero result comes from the first operand: a. This is // The sign of non-zero result comes from the first operand: a. This is
// guaranteed by C++11 rules for integer division operator. The division // guaranteed by C++11 rules for integer division operator. The division
// result is rounded toward zero, so the result of '%' has the sign of // result is rounded toward zero, so the result of '%' has the sign of
// the first operand. // the first operand.
assert(b != 0u); if (b != 0u) {
return static_cast<int32_t>(a) % static_cast<int32_t>(b); return static_cast<int32_t>(a) % static_cast<int32_t>(b);
} else {
// Remainder when dividing with 0 is undefined, so we will just pick 0.
return 0;
}
} }
case SpvOp::SpvOpSMod: { case SpvOp::SpvOpSMod: {
// The sign of non-zero result comes from the second operand: b // The sign of non-zero result comes from the second operand: b
assert(b != 0u); if (b != 0u) {
int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b); int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b);
int32_t b_prim = static_cast<int32_t>(b); int32_t b_prim = static_cast<int32_t>(b);
return (rem + b_prim) % b_prim; return (rem + b_prim) % b_prim;
} else {
// Mod with 0 is undefined, so we will just pick 0.
return 0;
}
} }
case SpvOp::SpvOpUMod: case SpvOp::SpvOpUMod:
assert(b != 0u); if (b != 0u) {
return (a % b); return (a % b);
} else {
// Mod with 0 is undefined, so we will just pick 0.
return 0;
}
// Shifting // Shifting
case SpvOp::SpvOpShiftRightLogical: { case SpvOp::SpvOpShiftRightLogical: {

View File

@ -434,6 +434,46 @@ INSTANTIATE_TEST_CASE_P(TestCase, IntegerInstructionFoldingTest,
"%2 = OpBitwiseAnd %uint %load %uint_0\n" + "%2 = OpBitwiseAnd %uint %load %uint_0\n" +
"OpReturn\n" + "OpReturn\n" +
"OpFunctionEnd", "OpFunctionEnd",
2, 0),
// Test case 17: fold 1/0 (signed)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpSDiv %int %int_1 %int_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 18: fold 1/0 (unsigned)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpUDiv %uint %uint_1 %uint_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 19: fold OpSRem 1 0 (signed)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpSRem %int %int_1 %int_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 20: fold 1%0 (signed)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpSMod %int %int_1 %int_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 21: fold 1%0 (unsigned)
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpUMod %uint %uint_1 %uint_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0) 2, 0)
)); ));
// clang-format on // clang-format on