diff --git a/source/opt/fold.cpp b/source/opt/fold.cpp index 0cda772cf..d6b583f91 100644 --- a/source/opt/fold.cpp +++ b/source/opt/fold.cpp @@ -114,12 +114,23 @@ uint32_t InstructionFolder::BinaryOperate(SpvOp opcode, uint32_t a, } // Shifting - case SpvOp::SpvOpShiftRightLogical: { + case SpvOp::SpvOpShiftRightLogical: + if (b > 32) { + // This is undefined behaviour. Choose 0 for consistency. + return 0; + } return a >> b; - } case SpvOp::SpvOpShiftRightArithmetic: + if (b > 32) { + // This is undefined behaviour. Choose 0 for consistency. + return 0; + } return (static_cast(a)) >> b; case SpvOp::SpvOpShiftLeftLogical: + if (b > 32) { + // This is undefined behaviour. Choose 0 for consistency. + return 0; + } return a << b; // Bitwise operations diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 9aae338b6..1a5442181 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -179,6 +179,7 @@ OpName %main "main" %uint_3 = OpConstant %uint 3 %uint_4 = OpConstant %uint 4 %uint_32 = OpConstant %uint 32 +%uint_42 = OpConstant %uint 42 %uint_max = OpConstant %uint 4294967295 %v2int_undef = OpUndef %v2int %v2int_0_0 = OpConstantComposite %v2int %int_0 %int_0 @@ -474,6 +475,36 @@ INSTANTIATE_TEST_CASE_P(TestCase, IntegerInstructionFoldingTest, "%2 = OpUMod %uint %uint_1 %uint_0\n" + "OpReturn\n" + "OpFunctionEnd", + 2, 0), + // Test case 22: fold unsigned n >> 42 (undefined, so set to zero). + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpShiftRightLogical %uint %load %uint_42\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 21: fold signed n >> 42 (undefined, so set to zero). + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpShiftRightLogical %int %load %uint_42\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 22: fold n << 42 (undefined, so set to zero). + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpShiftLeftLogical %int %load %uint_42\n" + + "OpReturn\n" + + "OpFunctionEnd", 2, 0) )); // clang-format on