diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp index d100cb05b..07d8c01f1 100644 --- a/source/opt/scalar_replacement_pass.cpp +++ b/source/opt/scalar_replacement_pass.cpp @@ -329,6 +329,10 @@ void ScalarReplacementPass::TransferAnnotations( if (decoration == SpvDecorationInvariant || decoration == SpvDecorationRestrict) { for (auto var : *replacements) { + if (var == nullptr) { + continue; + } + std::unique_ptr annotation( new Instruction(context(), SpvOpDecorate, 0, 0, std::initializer_list{ @@ -350,6 +354,11 @@ void ScalarReplacementPass::CreateVariable( std::vector* replacements) { uint32_t ptrId = GetOrCreatePointerType(typeId); uint32_t id = TakeNextId(); + + if (id == 0) { + replacements->push_back(nullptr); + } + std::unique_ptr variable(new Instruction( context(), SpvOpVariable, ptrId, id, std::initializer_list{ diff --git a/source/opt/scalar_replacement_pass.h b/source/opt/scalar_replacement_pass.h index 5b5198155..d2eb8975d 100644 --- a/source/opt/scalar_replacement_pass.h +++ b/source/opt/scalar_replacement_pass.h @@ -143,7 +143,8 @@ class ScalarReplacementPass : public Pass { bool CheckStore(const Instruction* inst, uint32_t index) const; // Creates a variable of type |typeId| from the |index|'th element of - // |varInst|. The new variable is added to |replacements|. + // |varInst|. The new variable is added to |replacements|. If the variable + // could not be created, then |nullptr| is appended to |replacements|. void CreateVariable(uint32_t typeId, Instruction* varInst, uint32_t index, std::vector* replacements); diff --git a/test/opt/pass_fixture.h b/test/opt/pass_fixture.h index a4d749f78..53fb206fa 100644 --- a/test/opt/pass_fixture.h +++ b/test/opt/pass_fixture.h @@ -75,7 +75,9 @@ class PassTest : public TestT { const auto status = pass->Run(context()); std::vector binary; - context()->module()->ToBinary(&binary, skip_nop); + if (status != Pass::Status::Failure) { + context()->module()->ToBinary(&binary, skip_nop); + } return std::make_tuple(binary, status); } @@ -241,15 +243,18 @@ class PassTest : public TestT { context()->set_preserve_spec_constants( OptimizerOptions()->preserve_spec_constants_); - manager_->Run(context()); + auto status = manager_->Run(context()); + EXPECT_NE(status, Pass::Status::Failure); - std::vector binary; - context()->module()->ToBinary(&binary, /* skip_nop = */ false); + if (status != Pass::Status::Failure) { + std::vector binary; + context()->module()->ToBinary(&binary, /* skip_nop = */ false); - std::string optimized; - SpirvTools tools(env_); - EXPECT_TRUE(tools.Disassemble(binary, &optimized, disassemble_options_)); - EXPECT_EQ(expected, optimized); + std::string optimized; + SpirvTools tools(env_); + EXPECT_TRUE(tools.Disassemble(binary, &optimized, disassemble_options_)); + EXPECT_EQ(expected, optimized); + } } void SetAssembleOptions(uint32_t assemble_options) { diff --git a/test/opt/scalar_replacement_test.cpp b/test/opt/scalar_replacement_test.cpp index 00f8e17b6..e04332652 100644 --- a/test/opt/scalar_replacement_test.cpp +++ b/test/opt/scalar_replacement_test.cpp @@ -1621,7 +1621,7 @@ TEST_F(ScalarReplacementTest, TestAccessChainWithNoIndexes) { } // Test that id overflow is handled gracefully. -TEST_F(ScalarReplacementTest, IdBoundOverflow) { +TEST_F(ScalarReplacementTest, IdBoundOverflow1) { const std::string text = R"( OpCapability ImageQuery OpMemoryModel Logical GLSL450 @@ -1652,8 +1652,47 @@ OpFunctionEnd {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."}, {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."}}; SetMessageConsumer(GetTestMessageConsumer(messages)); - auto result = - SinglePassRunAndDisassemble(text, true, false); + auto result = SinglePassRunToBinary(text, true, false); + EXPECT_EQ(Pass::Status::Failure, std::get<1>(result)); +} + +// Test that id overflow is handled gracefully. +TEST_F(ScalarReplacementTest, IdBoundOverflow2) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %4 "main" %17 +OpExecutionMode %4 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%6 = OpTypeFloat 32 +%7 = OpTypeVector %6 4 +%8 = OpTypeStruct %7 +%9 = OpTypePointer Function %8 +%16 = OpTypePointer Output %7 +%21 = OpTypeInt 32 1 +%22 = OpConstant %21 0 +%23 = OpTypePointer Function %7 +%17 = OpVariable %16 Output +%4 = OpFunction %2 None %3 +%5 = OpLabel +%4194300 = OpVariable %23 Function +%10 = OpVariable %9 Function +%4194301 = OpAccessChain %23 %10 %22 +%4194302 = OpLoad %7 %4194301 +OpStore %4194300 %4194302 +%15 = OpLoad %7 %4194300 +OpStore %17 %15 +OpReturn +OpFunctionEnd + )"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + + std::vector messages = { + {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."}}; + SetMessageConsumer(GetTestMessageConsumer(messages)); + auto result = SinglePassRunToBinary(text, true, false); EXPECT_EQ(Pass::Status::Failure, std::get<1>(result)); }