diff --git a/source/opt/eliminate_dead_members_pass.cpp b/source/opt/eliminate_dead_members_pass.cpp index 0b73b2dbc..5b8f4ec54 100644 --- a/source/opt/eliminate_dead_members_pass.cpp +++ b/source/opt/eliminate_dead_members_pass.cpp @@ -19,6 +19,7 @@ namespace { const uint32_t kRemovedMember = 0xFFFFFFFF; +const uint32_t kSpecConstOpOpcodeIdx = 0; } namespace spvtools { @@ -40,7 +41,22 @@ void EliminateDeadMembersPass::FindLiveMembers() { // we have to mark them as fully used just to be safe. for (auto& inst : get_module()->types_values()) { if (inst.opcode() == SpvOpSpecConstantOp) { - MarkTypeAsFullyUsed(inst.type_id()); + switch (inst.GetSingleWordInOperand(kSpecConstOpOpcodeIdx)) { + case SpvOpCompositeExtract: + MarkMembersAsLiveForExtract(&inst); + break; + case SpvOpCompositeInsert: + // Nothing specific to do. + break; + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpPtrAccessChain: + case SpvOpInBoundsPtrAccessChain: + assert(false && "Not implemented yet."); + break; + default: + break; + } } else if (inst.opcode() == SpvOpVariable) { switch (inst.GetSingleWordInOperand(0)) { case SpvStorageClassInput: @@ -153,13 +169,17 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForCopyMemory( void EliminateDeadMembersPass::MarkMembersAsLiveForExtract( const Instruction* inst) { - assert(inst->opcode() == SpvOpCompositeExtract); + assert(inst->opcode() == SpvOpCompositeExtract || + (inst->opcode() == SpvOpSpecConstantOp && + inst->GetSingleWordInOperand(kSpecConstOpOpcodeIdx) == + SpvOpCompositeExtract)); - uint32_t composite_id = inst->GetSingleWordInOperand(0); + uint32_t first_operand = (inst->opcode() == SpvOpSpecConstantOp ? 1 : 0); + uint32_t composite_id = inst->GetSingleWordInOperand(first_operand); Instruction* composite_inst = get_def_use_mgr()->GetDef(composite_id); uint32_t type_id = composite_inst->type_id(); - for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { + for (uint32_t i = first_operand + 1; i < inst->NumInOperands(); ++i) { Instruction* type_inst = get_def_use_mgr()->GetDef(type_id); uint32_t member_idx = inst->GetSingleWordInOperand(i); switch (type_inst->opcode()) { @@ -295,10 +315,22 @@ bool EliminateDeadMembersPass::RemoveDeadMembers() { modified |= UpdateOpArrayLength(inst); break; case SpvOpSpecConstantOp: - assert(false && "Not yet implemented."); - // with OpCompositeExtract, OpCompositeInsert - // For kernels: OpAccessChain, OpInBoundsAccessChain, OpPtrAccessChain, - // OpInBoundsPtrAccessChain + switch (inst->GetSingleWordInOperand(kSpecConstOpOpcodeIdx)) { + case SpvOpCompositeExtract: + modified |= UpdateCompsiteExtract(inst); + break; + case SpvOpCompositeInsert: + modified |= UpdateCompositeInsert(inst); + break; + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpPtrAccessChain: + case SpvOpInBoundsPtrAccessChain: + assert(false && "Not implemented yet."); + break; + default: + break; + } break; default: break; @@ -393,7 +425,8 @@ bool EliminateDeadMembersPass::UpdateOpGroupMemberDecorate(Instruction* inst) { } bool EliminateDeadMembersPass::UpdateConstantComposite(Instruction* inst) { - assert(inst->opcode() == SpvOpConstantComposite || + assert(inst->opcode() == SpvOpSpecConstantComposite || + inst->opcode() == SpvOpConstantComposite || inst->opcode() == SpvOpCompositeConstruct); uint32_t type_id = inst->type_id(); @@ -506,14 +539,25 @@ uint32_t EliminateDeadMembersPass::GetNewMemberIndex(uint32_t type_id, } bool EliminateDeadMembersPass::UpdateCompsiteExtract(Instruction* inst) { - uint32_t object_id = inst->GetSingleWordInOperand(0); + assert(inst->opcode() == SpvOpCompositeExtract || + (inst->opcode() == SpvOpSpecConstantOp && + inst->GetSingleWordInOperand(kSpecConstOpOpcodeIdx) == + SpvOpCompositeExtract)); + + uint32_t first_operand = 0; + if (inst->opcode() == SpvOpSpecConstantOp) { + first_operand = 1; + } + uint32_t object_id = inst->GetSingleWordInOperand(first_operand); Instruction* object_inst = get_def_use_mgr()->GetDef(object_id); uint32_t type_id = object_inst->type_id(); Instruction::OperandList new_operands; bool modified = false; - new_operands.emplace_back(inst->GetInOperand(0)); - for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { + for (uint32_t i = 0; i < first_operand + 1; i++) { + new_operands.emplace_back(inst->GetInOperand(i)); + } + for (uint32_t i = first_operand + 1; i < inst->NumInOperands(); ++i) { uint32_t member_idx = inst->GetSingleWordInOperand(i); uint32_t new_member_idx = GetNewMemberIndex(type_id, member_idx); assert(new_member_idx != kRemovedMember); @@ -526,8 +570,6 @@ bool EliminateDeadMembersPass::UpdateCompsiteExtract(Instruction* inst) { Instruction* type_inst = get_def_use_mgr()->GetDef(type_id); switch (type_inst->opcode()) { case SpvOpTypeStruct: - assert(i != 1 || (inst->opcode() != SpvOpPtrAccessChain && - inst->opcode() != SpvOpInBoundsPtrAccessChain)); // The type will have already been rewriten, so use the new member // index. type_id = type_inst->GetSingleWordInOperand(new_member_idx); @@ -552,15 +594,27 @@ bool EliminateDeadMembersPass::UpdateCompsiteExtract(Instruction* inst) { } bool EliminateDeadMembersPass::UpdateCompositeInsert(Instruction* inst) { - uint32_t composite_id = inst->GetSingleWordInOperand(1); + assert(inst->opcode() == SpvOpCompositeInsert || + (inst->opcode() == SpvOpSpecConstantOp && + inst->GetSingleWordInOperand(kSpecConstOpOpcodeIdx) == + SpvOpCompositeInsert)); + + uint32_t first_operand = 0; + if (inst->opcode() == SpvOpSpecConstantOp) { + first_operand = 1; + } + + uint32_t composite_id = inst->GetSingleWordInOperand(first_operand + 1); Instruction* composite_inst = get_def_use_mgr()->GetDef(composite_id); uint32_t type_id = composite_inst->type_id(); Instruction::OperandList new_operands; bool modified = false; - new_operands.emplace_back(inst->GetInOperand(0)); - new_operands.emplace_back(inst->GetInOperand(1)); - for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { + + for (uint32_t i = 0; i < first_operand + 2; ++i) { + new_operands.emplace_back(inst->GetInOperand(i)); + } + for (uint32_t i = first_operand + 2; i < inst->NumInOperands(); ++i) { uint32_t member_idx = inst->GetSingleWordInOperand(i); uint32_t new_member_idx = GetNewMemberIndex(type_id, member_idx); if (new_member_idx == kRemovedMember) { diff --git a/test/opt/eliminate_dead_member_test.cpp b/test/opt/eliminate_dead_member_test.cpp index b6925d7d7..a9b0f28c7 100644 --- a/test/opt/eliminate_dead_member_test.cpp +++ b/test/opt/eliminate_dead_member_test.cpp @@ -1085,4 +1085,103 @@ TEST_F(EliminateDeadMemberTest, DontChangeOutputStructs) { EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result)); } +TEST_F(EliminateDeadMemberTest, UpdateSpecConstOpExtract) { + // Test that an extract in an OpSpecConstantOp is correctly updated. + const std::string text = R"( +; CHECK: OpName +; CHECK-NEXT: OpMemberName %type__Globals 0 "y" +; CHECK-NOT: OpMemberName +; CHECK: OpDecorate [[spec_const:%\w+]] SpecId 1 +; CHECK: OpMemberDecorate %type__Globals 0 Offset 4 +; CHECK: %type__Globals = OpTypeStruct %uint +; CHECK: [[struct:%\w+]] = OpSpecConstantComposite %type__Globals [[spec_const]] +; CHECK: OpSpecConstantOp %uint CompositeExtract [[struct]] 0 + OpCapability Shader + OpCapability Addresses + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource HLSL 600 + OpName %type__Globals "type.$Globals" + OpMemberName %type__Globals 0 "x" + OpMemberName %type__Globals 1 "y" + OpMemberName %type__Globals 2 "z" + OpName %main "main" + OpDecorate %c_0 SpecId 0 + OpDecorate %c_1 SpecId 1 + OpDecorate %c_2 SpecId 2 + OpMemberDecorate %type__Globals 0 Offset 0 + OpMemberDecorate %type__Globals 1 Offset 4 + OpMemberDecorate %type__Globals 2 Offset 16 + %uint = OpTypeInt 32 0 + %c_0 = OpSpecConstant %uint 0 + %c_1 = OpSpecConstant %uint 1 + %c_2 = OpSpecConstant %uint 2 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %uint_2 = OpConstant %uint 2 + %uint_3 = OpConstant %uint 3 +%type__Globals = OpTypeStruct %uint %uint %uint +%spec_const_global = OpSpecConstantComposite %type__Globals %c_0 %c_1 %c_2 +%extract = OpSpecConstantOp %uint CompositeExtract %spec_const_global 1 + %void = OpTypeVoid + %14 = OpTypeFunction %void + %main = OpFunction %void None %14 + %16 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(EliminateDeadMemberTest, UpdateSpecConstOpInsert) { + // Test that an insert in an OpSpecConstantOp is correctly updated. + const std::string text = R"( +; CHECK: OpName +; CHECK-NEXT: OpMemberName %type__Globals 0 "y" +; CHECK-NOT: OpMemberName +; CHECK: OpDecorate [[spec_const:%\w+]] SpecId 1 +; CHECK: OpMemberDecorate %type__Globals 0 Offset 4 +; CHECK: %type__Globals = OpTypeStruct %uint +; CHECK: [[struct:%\w+]] = OpSpecConstantComposite %type__Globals [[spec_const]] +; CHECK: OpSpecConstantOp %type__Globals CompositeInsert %uint_3 [[struct]] 0 + OpCapability Shader + OpCapability Addresses + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource HLSL 600 + OpName %type__Globals "type.$Globals" + OpMemberName %type__Globals 0 "x" + OpMemberName %type__Globals 1 "y" + OpMemberName %type__Globals 2 "z" + OpName %main "main" + OpDecorate %c_0 SpecId 0 + OpDecorate %c_1 SpecId 1 + OpDecorate %c_2 SpecId 2 + OpMemberDecorate %type__Globals 0 Offset 0 + OpMemberDecorate %type__Globals 1 Offset 4 + OpMemberDecorate %type__Globals 2 Offset 16 + %uint = OpTypeInt 32 0 + %c_0 = OpSpecConstant %uint 0 + %c_1 = OpSpecConstant %uint 1 + %c_2 = OpSpecConstant %uint 2 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %uint_2 = OpConstant %uint 2 + %uint_3 = OpConstant %uint 3 +%type__Globals = OpTypeStruct %uint %uint %uint +%spec_const_global = OpSpecConstantComposite %type__Globals %c_0 %c_1 %c_2 +%insert = OpSpecConstantOp %type__Globals CompositeInsert %uint_3 %spec_const_global 1 +%extract = OpSpecConstantOp %uint CompositeExtract %insert 1 + %void = OpTypeVoid + %14 = OpTypeFunction %void + %main = OpFunction %void None %14 + %16 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + } // namespace