From b230a7c7d1eb02d59c1d6015a8da00703ca413e1 Mon Sep 17 00:00:00 2001 From: alan-baker Date: Tue, 31 Jan 2023 15:40:22 -0500 Subject: [PATCH] Validate operand type before operating on it (#5092) Fixes https://crbug.com/oss-fuzz/52921 * Validate the data operand of OpBitCount before trying to get its dimension --- source/val/validate_bitwise.cpp | 5 +++-- test/val/val_bitwise_test.cpp | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/source/val/validate_bitwise.cpp b/source/val/validate_bitwise.cpp index 87c955630..6ab1faebc 100644 --- a/source/val/validate_bitwise.cpp +++ b/source/val/validate_bitwise.cpp @@ -206,13 +206,14 @@ spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) { << spvOpcodeString(opcode); const uint32_t base_type = _.GetOperandTypeId(inst, 2); - const uint32_t base_dimension = _.GetDimension(base_type); - const uint32_t result_dimension = _.GetDimension(result_type); if (spv_result_t error = ValidateBaseType(_, inst, base_type)) { return error; } + const uint32_t base_dimension = _.GetDimension(base_type); + const uint32_t result_dimension = _.GetDimension(result_type); + if (base_dimension != result_dimension) return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Base dimension to be equal to Result Type " diff --git a/test/val/val_bitwise_test.cpp b/test/val/val_bitwise_test.cpp index bebaa84fc..b849e7b77 100644 --- a/test/val/val_bitwise_test.cpp +++ b/test/val/val_bitwise_test.cpp @@ -643,6 +643,32 @@ TEST_F(ValidateBitwise, OpBitCountNot32Vulkan) { HasSubstr("Expected 32-bit int type for Base operand: BitCount")); } +TEST_F(ValidateBitwise, OpBitCountPointer) { + const std::string body = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int = OpTypePointer Function %int +%void_fn = OpTypeFunction %void +%main = OpFunction %void None %void_fn +%entry = OpLabel +%var = OpVariable %ptr_int Function +%count = OpBitCount %int %var +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected int scalar or vector type for Base operand: BitCount")); +} + } // namespace } // namespace val } // namespace spvtools