diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp index d393495ba..741f9476d 100644 --- a/include/spirv-tools/optimizer.hpp +++ b/include/spirv-tools/optimizer.hpp @@ -876,8 +876,10 @@ Optimizer::PassToken CreateGraphicsRobustAccessPass(); // for the first index. Optimizer::PassToken CreateDescriptorScalarReplacementPass(); -// Create a pass to replace all OpKill instruction with a function call to a -// function that has a single OpKill. This allows more code to be inlined. +// Create a pass to replace each OpKill instruction with a function call to a +// function that has a single OpKill. Also replace each OpTerminateInvocation +// instruction with a function call to a function that has a single +// OpTerminateInvocation. This allows more code to be inlined. Optimizer::PassToken CreateWrapOpKillPass(); // Replaces the extensions VK_AMD_shader_ballot,VK_AMD_gcn_shader, and diff --git a/source/opcode.cpp b/source/opcode.cpp index 3781a8db8..f93cfd371 100644 --- a/source/opcode.cpp +++ b/source/opcode.cpp @@ -446,7 +446,7 @@ bool spvOpcodeIsReturn(SpvOp opcode) { bool spvOpcodeIsReturnOrAbort(SpvOp opcode) { return spvOpcodeIsReturn(opcode) || opcode == SpvOpKill || - opcode == SpvOpUnreachable; + opcode == SpvOpUnreachable || opcode == SpvOpTerminateInvocation; } bool spvOpcodeIsBlockTerminator(SpvOp opcode) { diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp index 9fcfd3a54..b75578744 100644 --- a/source/opt/aggressive_dead_code_elim_pass.cpp +++ b/source/opt/aggressive_dead_code_elim_pass.cpp @@ -986,6 +986,7 @@ void AggressiveDCEPass::InitExtensions() { "SPV_KHR_ray_tracing", "SPV_EXT_fragment_invocation_density", "SPV_EXT_physical_storage_buffer", + "SPV_KHR_terminate_invocation", }); } diff --git a/source/opt/dominator_tree.cpp b/source/opt/dominator_tree.cpp index da5073a5e..7e61506b9 100644 --- a/source/opt/dominator_tree.cpp +++ b/source/opt/dominator_tree.cpp @@ -176,7 +176,8 @@ void BasicBlockSuccessorHelper::CreateSuccessorMap( // The tree construction requires 1 entry point, so we add a dummy node // that is connected to all function exiting basic blocks. // An exiting basic block is a block with an OpKill, OpUnreachable, - // OpReturn or OpReturnValue as terminator instruction. + // OpReturn, OpReturnValue, or OpTerminateInvocation as terminator + // instruction. for (BasicBlock& bb : f) { if (bb.hasSuccessor()) { BasicBlockListTy& pred_list = predecessors_[&bb]; diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp index cb5a1265e..ef94d0d6c 100644 --- a/source/opt/inline_pass.cpp +++ b/source/opt/inline_pass.cpp @@ -384,7 +384,8 @@ std::unique_ptr InlinePass::InlineReturn( for (auto callee_block_itr = calleeFn->begin(); callee_block_itr != calleeFn->end(); ++callee_block_itr) { if (callee_block_itr->tail()->opcode() == SpvOpUnreachable || - callee_block_itr->tail()->opcode() == SpvOpKill) { + callee_block_itr->tail()->opcode() == SpvOpKill || + callee_block_itr->tail()->opcode() == SpvOpTerminateInvocation) { returnLabelId = context()->TakeNextId(); break; } @@ -738,16 +739,18 @@ bool InlinePass::IsInlinableFunction(Function* func) { bool func_is_called_from_continue = funcs_called_from_continue_.count(func->result_id()) != 0; - if (func_is_called_from_continue && ContainsKill(func)) { + if (func_is_called_from_continue && ContainsKillOrTerminateInvocation(func)) { return false; } return true; } -bool InlinePass::ContainsKill(Function* func) const { - return !func->WhileEachInst( - [](Instruction* inst) { return inst->opcode() != SpvOpKill; }); +bool InlinePass::ContainsKillOrTerminateInvocation(Function* func) const { + return !func->WhileEachInst([](Instruction* inst) { + const auto opcode = inst->opcode(); + return (opcode != SpvOpKill) && (opcode != SpvOpTerminateInvocation); + }); } void InlinePass::InitializeInline() { diff --git a/source/opt/inline_pass.h b/source/opt/inline_pass.h index 202bc97fd..abe773af8 100644 --- a/source/opt/inline_pass.h +++ b/source/opt/inline_pass.h @@ -139,8 +139,9 @@ class InlinePass : public Pass { // Return true if |func| is a function that can be inlined. bool IsInlinableFunction(Function* func); - // Returns true if |func| contains an OpKill instruction. - bool ContainsKill(Function* func) const; + // Returns true if |func| contains an OpKill or OpTerminateInvocation + // instruction. + bool ContainsKillOrTerminateInvocation(Function* func) const; // Update phis in succeeding blocks to point to new last block void UpdateSucceedingPhis( diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp index 05704c148..9b8c112e1 100644 --- a/source/opt/local_access_chain_convert_pass.cpp +++ b/source/opt/local_access_chain_convert_pass.cpp @@ -382,6 +382,7 @@ void LocalAccessChainConvertPass::InitExtensions() { "SPV_KHR_ray_tracing", "SPV_KHR_ray_query", "SPV_EXT_fragment_invocation_density", + "SPV_KHR_terminate_invocation", }); } diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp index 57572825d..bd5d75101 100644 --- a/source/opt/local_single_block_elim_pass.cpp +++ b/source/opt/local_single_block_elim_pass.cpp @@ -267,6 +267,7 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() { "SPV_KHR_ray_query", "SPV_EXT_fragment_invocation_density", "SPV_EXT_physical_storage_buffer", + "SPV_KHR_terminate_invocation", }); } diff --git a/source/opt/local_single_store_elim_pass.cpp b/source/opt/local_single_store_elim_pass.cpp index 6626d87f3..238410755 100644 --- a/source/opt/local_single_store_elim_pass.cpp +++ b/source/opt/local_single_store_elim_pass.cpp @@ -121,6 +121,7 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() { "SPV_KHR_ray_query", "SPV_EXT_fragment_invocation_density", "SPV_EXT_physical_storage_buffer", + "SPV_KHR_terminate_invocation", }); } bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) { diff --git a/source/opt/loop_unroller.cpp b/source/opt/loop_unroller.cpp index 10fac0433..40cf6bc2a 100644 --- a/source/opt/loop_unroller.cpp +++ b/source/opt/loop_unroller.cpp @@ -997,7 +997,8 @@ bool LoopUtils::CanPerformUnroll() { const BasicBlock* block = context_->cfg()->block(label_id); if (block->ctail()->opcode() == SpvOp::SpvOpKill || block->ctail()->opcode() == SpvOp::SpvOpReturn || - block->ctail()->opcode() == SpvOp::SpvOpReturnValue) { + block->ctail()->opcode() == SpvOp::SpvOpReturnValue || + block->ctail()->opcode() == SpvOp::SpvOpTerminateInvocation) { return false; } } diff --git a/source/opt/reflect.h b/source/opt/reflect.h index 51d23a740..2e253add3 100644 --- a/source/opt/reflect.h +++ b/source/opt/reflect.h @@ -60,7 +60,8 @@ inline bool IsSpecConstantInst(SpvOp opcode) { return opcode >= SpvOpSpecConstantTrue && opcode <= SpvOpSpecConstantOp; } inline bool IsTerminatorInst(SpvOp opcode) { - return opcode >= SpvOpBranch && opcode <= SpvOpUnreachable; + return (opcode >= SpvOpBranch && opcode <= SpvOpUnreachable) || + (opcode == SpvOpTerminateInvocation); } } // namespace opt diff --git a/source/opt/replace_invalid_opc.cpp b/source/opt/replace_invalid_opc.cpp index 4e0f24f46..38b7539bf 100644 --- a/source/opt/replace_invalid_opc.cpp +++ b/source/opt/replace_invalid_opc.cpp @@ -141,6 +141,7 @@ bool ReplaceInvalidOpcodePass::IsFragmentShaderOnlyInstruction( // TODO: Teach |ReplaceInstruction| to handle block terminators. Then // uncomment the OpKill case. // case SpvOpKill: + // case SpvOpTerminateInstruction: return true; default: return false; diff --git a/source/opt/wrap_opkill.cpp b/source/opt/wrap_opkill.cpp index 3c8bae6d7..4d708405c 100644 --- a/source/opt/wrap_opkill.cpp +++ b/source/opt/wrap_opkill.cpp @@ -27,7 +27,8 @@ Pass::Status WrapOpKill::Process() { for (uint32_t func_id : func_to_process) { Function* func = context()->GetFunction(func_id); bool successful = func->WhileEachInst([this, &modified](Instruction* inst) { - if (inst->opcode() == SpvOpKill) { + const auto opcode = inst->opcode(); + if ((opcode == SpvOpKill) || (opcode == SpvOpTerminateInvocation)) { modified = true; if (!ReplaceWithFunctionCall(inst)) { return false; @@ -46,16 +47,22 @@ Pass::Status WrapOpKill::Process() { "The function should only be generated if something was modified."); context()->AddFunction(std::move(opkill_function_)); } + if (opterminateinvocation_function_ != nullptr) { + assert(modified && + "The function should only be generated if something was modified."); + context()->AddFunction(std::move(opterminateinvocation_function_)); + } return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); } bool WrapOpKill::ReplaceWithFunctionCall(Instruction* inst) { - assert(inst->opcode() == SpvOpKill && - "|inst| must be an OpKill instruction."); + assert((inst->opcode() == SpvOpKill || + inst->opcode() == SpvOpTerminateInvocation) && + "|inst| must be an OpKill or OpTerminateInvocation instruction."); InstructionBuilder ir_builder( context(), inst, IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); - uint32_t func_id = GetOpKillFuncId(); + uint32_t func_id = GetKillingFuncId(inst->opcode()); if (func_id == 0) { return false; } @@ -108,13 +115,20 @@ uint32_t WrapOpKill::GetVoidFunctionTypeId() { return type_mgr->GetTypeInstruction(&func_type); } -uint32_t WrapOpKill::GetOpKillFuncId() { - if (opkill_function_ != nullptr) { - return opkill_function_->result_id(); +uint32_t WrapOpKill::GetKillingFuncId(SpvOp opcode) { + // Parameterize by opcode + assert(opcode == SpvOpKill || opcode == SpvOpTerminateInvocation); + + std::unique_ptr* const killing_func = + (opcode == SpvOpKill) ? &opkill_function_ + : &opterminateinvocation_function_; + + if (*killing_func != nullptr) { + return (*killing_func)->result_id(); } - uint32_t opkill_func_id = TakeNextId(); - if (opkill_func_id == 0) { + uint32_t killing_func_id = TakeNextId(); + if (killing_func_id == 0) { return 0; } @@ -125,15 +139,15 @@ uint32_t WrapOpKill::GetOpKillFuncId() { // Generate the function start instruction std::unique_ptr func_start(new Instruction( - context(), SpvOpFunction, void_type_id, opkill_func_id, {})); + context(), SpvOpFunction, void_type_id, killing_func_id, {})); func_start->AddOperand({SPV_OPERAND_TYPE_FUNCTION_CONTROL, {0}}); func_start->AddOperand({SPV_OPERAND_TYPE_ID, {GetVoidFunctionTypeId()}}); - opkill_function_.reset(new Function(std::move(func_start))); + (*killing_func).reset(new Function(std::move(func_start))); // Generate the function end instruction std::unique_ptr func_end( new Instruction(context(), SpvOpFunctionEnd, 0, 0, {})); - opkill_function_->SetFunctionEnd(std::move(func_end)); + (*killing_func)->SetFunctionEnd(std::move(func_end)); // Create the one basic block for the function. uint32_t lab_id = TakeNextId(); @@ -146,21 +160,22 @@ uint32_t WrapOpKill::GetOpKillFuncId() { // Add the OpKill to the basic block std::unique_ptr kill_inst( - new Instruction(context(), SpvOpKill, 0, 0, {})); + new Instruction(context(), opcode, 0, 0, {})); bb->AddInstruction(std::move(kill_inst)); // Add the bb to the function - bb->SetParent(opkill_function_.get()); - opkill_function_->AddBasicBlock(std::move(bb)); + bb->SetParent((*killing_func).get()); + (*killing_func)->AddBasicBlock(std::move(bb)); // Add the function to the module. if (context()->AreAnalysesValid(IRContext::kAnalysisDefUse)) { - opkill_function_->ForEachInst( - [this](Instruction* inst) { context()->AnalyzeDefUse(inst); }); + (*killing_func)->ForEachInst([this](Instruction* inst) { + context()->AnalyzeDefUse(inst); + }); } if (context()->AreAnalysesValid(IRContext::kAnalysisInstrToBlockMapping)) { - for (BasicBlock& basic_block : *opkill_function_) { + for (BasicBlock& basic_block : *(*killing_func)) { context()->set_instr_block(basic_block.GetLabelInst(), &basic_block); for (Instruction& inst : basic_block) { context()->set_instr_block(&inst, &basic_block); @@ -168,7 +183,7 @@ uint32_t WrapOpKill::GetOpKillFuncId() { } } - return opkill_function_->result_id(); + return (*killing_func)->result_id(); } uint32_t WrapOpKill::GetOwningFunctionsReturnType(Instruction* inst) { diff --git a/source/opt/wrap_opkill.h b/source/opt/wrap_opkill.h index 09f2dfafd..7e43ca6cd 100644 --- a/source/opt/wrap_opkill.h +++ b/source/opt/wrap_opkill.h @@ -38,10 +38,10 @@ class WrapOpKill : public Pass { } private: - // Replaces the OpKill instruction |inst| with a function call to a function - // that contains a single instruction, which is OpKill. An OpUnreachable - // instruction will be placed after the function call. Return true if - // successful. + // Replaces the OpKill or OpTerminateInvocation instruction |inst| with a + // function call to a function that contains a single instruction, a clone of + // |inst|. An OpUnreachable instruction will be placed after the function + // call. Return true if successful. bool ReplaceWithFunctionCall(Instruction* inst); // Returns the id of the void type. @@ -51,9 +51,9 @@ class WrapOpKill : public Pass { uint32_t GetVoidFunctionTypeId(); // Return the id of a function that has return type void, has no parameters, - // and contains a single instruction, which is an OpKill. Returns 0 if the - // function could not be generated. - uint32_t GetOpKillFuncId(); + // and contains a single instruction, which is |opcode|, either OpKill or + // OpTerminateInvocation. Returns 0 if the function could not be generated. + uint32_t GetKillingFuncId(SpvOp opcode); // Returns the id of the return type for the function that contains |inst|. // Returns 0 if |inst| is not in a function. @@ -67,6 +67,11 @@ class WrapOpKill : public Pass { // function has a void return type and takes no parameters. If the function is // |nullptr|, then the function has not been generated. std::unique_ptr opkill_function_; + // The function that is a single instruction, which is an + // OpTerminateInvocation. The function has a void return type and takes no + // parameters. If the function is |nullptr|, then the function has not been + // generated. + std::unique_ptr opterminateinvocation_function_; }; } // namespace opt diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp index a2fe88279..8eb3a968f 100644 --- a/source/val/validate_cfg.cpp +++ b/source/val/validate_cfg.cpp @@ -1096,12 +1096,18 @@ spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst) { case SpvOpKill: case SpvOpReturnValue: case SpvOpUnreachable: + case SpvOpTerminateInvocation: _.current_function().RegisterBlockEnd(std::vector()); if (opcode == SpvOpKill) { _.current_function().RegisterExecutionModelLimitation( SpvExecutionModelFragment, "OpKill requires Fragment execution model"); } + if (opcode == SpvOpTerminateInvocation) { + _.current_function().RegisterExecutionModelLimitation( + SpvExecutionModelFragment, + "OpTerminateInvocation requires Fragment execution model"); + } break; default: break; diff --git a/source/val/validate_instruction.cpp b/source/val/validate_instruction.cpp index 6478b3cb6..9d395fb46 100644 --- a/source/val/validate_instruction.cpp +++ b/source/val/validate_instruction.cpp @@ -296,7 +296,12 @@ spv_result_t VersionCheck(ValidationState_t& _, const Instruction* inst) { << SPV_SPIRV_VERSION_MINOR_PART(last_version) << " or earlier"; } - if (inst_desc->numCapabilities > 0u) { + // OpTerminateInvocation is special because it is enabled by Shader + // capability, but also requries a extension and/or version check. + const bool capability_check_is_sufficient = + inst->opcode() != SpvOpTerminateInvocation; + + if (capability_check_is_sufficient && (inst_desc->numCapabilities > 0u)) { // We already checked that the direct capability dependency has been // satisfied. We don't need to check any further. return SPV_SUCCESS; diff --git a/test/opt/block_merge_test.cpp b/test/opt/block_merge_test.cpp index f1460c5f7..7381908ed 100644 --- a/test/opt/block_merge_test.cpp +++ b/test/opt/block_merge_test.cpp @@ -639,6 +639,40 @@ OpFunctionEnd SinglePassRunAndMatch(text, true); } +TEST_F(BlockMergeTest, DontMergeTerminateInvocation) { + const std::string text = R"( +; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]] None +; CHECK-NEXT: OpBranch [[ret:%\w+]] +; CHECK: [[ret:%\w+]] = OpLabel +; CHECK-NEXT: OpTerminateInvocation +; CHECK-DAG: [[cont]] = OpLabel +; CHECK-DAG: [[merge]] = OpLabel +OpCapability Shader +OpExtension "SPV_KHR_terminate_invocation" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +OpLoopMerge %3 %4 None +OpBranch %5 +%5 = OpLabel +OpTerminateInvocation +%4 = OpLabel +OpBranch %2 +%3 = OpLabel +OpUnreachable +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + TEST_F(BlockMergeTest, DontMergeUnreachable) { const std::string text = R"( ; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]] None diff --git a/test/opt/inline_test.cpp b/test/opt/inline_test.cpp index fc2197c8e..ffd3e38a5 100644 --- a/test/opt/inline_test.cpp +++ b/test/opt/inline_test.cpp @@ -2453,6 +2453,103 @@ OpFunctionEnd SinglePassRunAndCheck(before, after, false, true); } +TEST_F(InlineTest, DontInlineFuncWithOpTerminateInvocationInContinue) { + const std::string test = + R"(OpCapability Shader +OpExtension "SPV_KHR_terminate_invocation" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 330 +OpName %main "main" +OpName %kill_ "kill(" +%void = OpTypeVoid +%3 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%main = OpFunction %void None %3 +%5 = OpLabel +OpBranch %9 +%9 = OpLabel +OpLoopMerge %11 %12 None +OpBranch %13 +%13 = OpLabel +OpBranchConditional %true %10 %11 +%10 = OpLabel +OpBranch %12 +%12 = OpLabel +%16 = OpFunctionCall %void %kill_ +OpBranch %9 +%11 = OpLabel +OpReturn +OpFunctionEnd +%kill_ = OpFunction %void None %3 +%7 = OpLabel +OpTerminateInvocation +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(test, test, false, true); +} + +TEST_F(InlineTest, InlineFuncWithOpTerminateInvocationNotInContinue) { + const std::string before = + R"(OpCapability Shader +OpExtension "SPV_KHR_terminate_invocation" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 330 +OpName %main "main" +OpName %kill_ "kill(" +%void = OpTypeVoid +%3 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%main = OpFunction %void None %3 +%5 = OpLabel +%16 = OpFunctionCall %void %kill_ +OpReturn +OpFunctionEnd +%kill_ = OpFunction %void None %3 +%7 = OpLabel +OpTerminateInvocation +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Shader +OpExtension "SPV_KHR_terminate_invocation" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 330 +OpName %main "main" +OpName %kill_ "kill(" +%void = OpTypeVoid +%3 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%main = OpFunction %void None %3 +%5 = OpLabel +OpTerminateInvocation +%18 = OpLabel +OpReturn +OpFunctionEnd +%kill_ = OpFunction %void None %3 +%7 = OpLabel +OpTerminateInvocation +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, after, false, true); +} + TEST_F(InlineTest, EarlyReturnFunctionInlined) { // #version 140 // diff --git a/test/opt/loop_optimizations/unroll_assumptions.cpp b/test/opt/loop_optimizations/unroll_assumptions.cpp index 62f77d782..0f9330218 100644 --- a/test/opt/loop_optimizations/unroll_assumptions.cpp +++ b/test/opt/loop_optimizations/unroll_assumptions.cpp @@ -467,6 +467,73 @@ OpFunctionEnd SinglePassRunAndCheck(text, text, false); } +TEST_F(PassClassTest, KillInBody) { + const std::string text = R"(OpCapability Shader +OpMemoryModel Logical Simple +OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%4 = OpTypeBool +%5 = OpTypeInt 32 0 +%6 = OpConstant %5 0 +%7 = OpConstant %5 1 +%8 = OpConstant %5 5 +%1 = OpFunction %2 None %3 +%9 = OpLabel +OpBranch %10 +%10 = OpLabel +%11 = OpPhi %5 %6 %9 %12 %13 +%14 = OpULessThan %4 %11 %8 +OpLoopMerge %15 %13 Unroll +OpBranchConditional %14 %16 %15 +%16 = OpLabel +OpKill +%13 = OpLabel +%12 = OpIAdd %5 %11 %7 +OpBranch %10 +%15 = OpLabel +OpReturn +OpFunctionEnd +)"; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, text, false); +} + +TEST_F(PassClassTest, TerminateInvocationInBody) { + const std::string text = R"(OpCapability Shader +OpExtension "SPV_KHR_terminate_invocation" +OpMemoryModel Logical Simple +OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%4 = OpTypeBool +%5 = OpTypeInt 32 0 +%6 = OpConstant %5 0 +%7 = OpConstant %5 1 +%8 = OpConstant %5 5 +%1 = OpFunction %2 None %3 +%9 = OpLabel +OpBranch %10 +%10 = OpLabel +%11 = OpPhi %5 %6 %9 %12 %13 +%14 = OpULessThan %4 %11 %8 +OpLoopMerge %15 %13 Unroll +OpBranchConditional %14 %16 %15 +%16 = OpLabel +OpTerminateInvocation +%13 = OpLabel +%12 = OpIAdd %5 %11 %7 +OpBranch %10 +%15 = OpLabel +OpReturn +OpFunctionEnd +)"; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, text, false); +} + /* Generated from the following GLSL #version 440 core diff --git a/test/opt/wrap_opkill_test.cpp b/test/opt/wrap_opkill_test.cpp index 33e52f06e..e944109e8 100644 --- a/test/opt/wrap_opkill_test.cpp +++ b/test/opt/wrap_opkill_test.cpp @@ -193,6 +193,310 @@ TEST_F(WrapOpKillTest, MultipleOpKillInDifferentFunc) { SinglePassRunAndMatch(text, true); } +TEST_F(WrapOpKillTest, SingleOpTerminateInvocation) { + const std::string text = R"( +; CHECK: OpEntryPoint Fragment [[main:%\w+]] +; CHECK: [[main]] = OpFunction +; CHECK: OpFunctionCall %void [[orig_kill:%\w+]] +; CHECK: [[orig_kill]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]] +; CHECK-NEXT: OpReturn +; CHECK: [[new_kill]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpTerminateInvocation +; CHECK-NEXT: OpFunctionEnd + OpCapability Shader + OpExtension "SPV_KHR_terminate_invocation" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 330 + OpName %main "main" + %void = OpTypeVoid + %5 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %main = OpFunction %void None %5 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpLoopMerge %10 %11 None + OpBranch %12 + %12 = OpLabel + OpBranchConditional %true %13 %10 + %13 = OpLabel + OpBranch %11 + %11 = OpLabel + %14 = OpFunctionCall %void %kill_ + OpBranch %9 + %10 = OpLabel + OpReturn + OpFunctionEnd + %kill_ = OpFunction %void None %5 + %15 = OpLabel + OpTerminateInvocation + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(WrapOpKillTest, MultipleTerminateInvocationInSameFunc) { + const std::string text = R"( +; CHECK: OpEntryPoint Fragment [[main:%\w+]] +; CHECK: [[main]] = OpFunction +; CHECK: OpFunctionCall %void [[orig_kill:%\w+]] +; CHECK: [[orig_kill]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpSelectionMerge +; CHECK-NEXT: OpBranchConditional +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]] +; CHECK-NEXT: OpReturn +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpFunctionCall %void [[new_kill]] +; CHECK-NEXT: OpReturn +; CHECK: [[new_kill]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpTerminateInvocation +; CHECK-NEXT: OpFunctionEnd + OpCapability Shader + OpExtension "SPV_KHR_terminate_invocation" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 330 + OpName %main "main" + %void = OpTypeVoid + %5 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %main = OpFunction %void None %5 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpLoopMerge %10 %11 None + OpBranch %12 + %12 = OpLabel + OpBranchConditional %true %13 %10 + %13 = OpLabel + OpBranch %11 + %11 = OpLabel + %14 = OpFunctionCall %void %kill_ + OpBranch %9 + %10 = OpLabel + OpReturn + OpFunctionEnd + %kill_ = OpFunction %void None %5 + %15 = OpLabel + OpSelectionMerge %16 None + OpBranchConditional %true %17 %18 + %17 = OpLabel + OpTerminateInvocation + %18 = OpLabel + OpTerminateInvocation + %16 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(WrapOpKillTest, MultipleOpTerminateInvocationDifferentFunc) { + const std::string text = R"( +; CHECK: OpEntryPoint Fragment [[main:%\w+]] +; CHECK: [[main]] = OpFunction +; CHECK: OpFunctionCall %void [[orig_kill1:%\w+]] +; CHECK-NEXT: OpFunctionCall %void [[orig_kill2:%\w+]] +; CHECK: [[orig_kill1]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]] +; CHECK-NEXT: OpReturn +; CHECK: [[orig_kill2]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpFunctionCall %void [[new_kill]] +; CHECK-NEXT: OpReturn +; CHECK: [[new_kill]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpTerminateInvocation +; CHECK-NEXT: OpFunctionEnd + OpCapability Shader + OpExtension "SPV_KHR_terminate_invocation" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 330 + OpName %main "main" + %void = OpTypeVoid + %4 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %main = OpFunction %void None %4 + %7 = OpLabel + OpBranch %8 + %8 = OpLabel + OpLoopMerge %9 %10 None + OpBranch %11 + %11 = OpLabel + OpBranchConditional %true %12 %9 + %12 = OpLabel + OpBranch %10 + %10 = OpLabel + %13 = OpFunctionCall %void %14 + %15 = OpFunctionCall %void %16 + OpBranch %8 + %9 = OpLabel + OpReturn + OpFunctionEnd + %14 = OpFunction %void None %4 + %17 = OpLabel + OpTerminateInvocation + OpFunctionEnd + %16 = OpFunction %void None %4 + %18 = OpLabel + OpTerminateInvocation + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(WrapOpKillTest, KillAndTerminateInvocationSameFunc) { + const std::string text = R"( +; CHECK: OpEntryPoint Fragment [[main:%\w+]] +; CHECK: [[main]] = OpFunction +; CHECK: OpFunctionCall %void [[orig_kill:%\w+]] +; CHECK: [[orig_kill]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpSelectionMerge +; CHECK-NEXT: OpBranchConditional +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]] +; CHECK-NEXT: OpReturn +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpFunctionCall %void [[new_terminate:%\w+]] +; CHECK-NEXT: OpReturn +; CHECK: [[new_kill]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpKill +; CHECK-NEXT: OpFunctionEnd +; CHECK-NEXT: [[new_terminate]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpTerminateInvocation +; CHECK-NEXT: OpFunctionEnd + OpCapability Shader + OpExtension "SPV_KHR_terminate_invocation" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 330 + OpName %main "main" + %void = OpTypeVoid + %5 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %main = OpFunction %void None %5 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpLoopMerge %10 %11 None + OpBranch %12 + %12 = OpLabel + OpBranchConditional %true %13 %10 + %13 = OpLabel + OpBranch %11 + %11 = OpLabel + %14 = OpFunctionCall %void %kill_ + OpBranch %9 + %10 = OpLabel + OpReturn + OpFunctionEnd + %kill_ = OpFunction %void None %5 + %15 = OpLabel + OpSelectionMerge %16 None + OpBranchConditional %true %17 %18 + %17 = OpLabel + OpKill + %18 = OpLabel + OpTerminateInvocation + %16 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(WrapOpKillTest, KillAndTerminateInvocationDifferentFunc) { + const std::string text = R"( +; CHECK: OpEntryPoint Fragment [[main:%\w+]] +; CHECK: [[main]] = OpFunction +; CHECK: OpFunctionCall %void [[orig_kill1:%\w+]] +; CHECK-NEXT: OpFunctionCall %void [[orig_kill2:%\w+]] +; CHECK: [[orig_kill1]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpFunctionCall %void [[new_terminate:%\w+]] +; CHECK-NEXT: OpReturn +; CHECK: [[orig_kill2]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]] +; CHECK-NEXT: OpReturn +; CHECK: [[new_kill]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpKill +; CHECK-NEXT: OpFunctionEnd +; CHECK-NEXT: [[new_terminate]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpTerminateInvocation +; CHECK-NEXT: OpFunctionEnd + OpCapability Shader + OpExtension "SPV_KHR_terminate_invocation" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 330 + OpName %main "main" + %void = OpTypeVoid + %4 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %main = OpFunction %void None %4 + %7 = OpLabel + OpBranch %8 + %8 = OpLabel + OpLoopMerge %9 %10 None + OpBranch %11 + %11 = OpLabel + OpBranchConditional %true %12 %9 + %12 = OpLabel + OpBranch %10 + %10 = OpLabel + %13 = OpFunctionCall %void %14 + %15 = OpFunctionCall %void %16 + OpBranch %8 + %9 = OpLabel + OpReturn + OpFunctionEnd + %14 = OpFunction %void None %4 + %17 = OpLabel + OpTerminateInvocation + OpFunctionEnd + %16 = OpFunction %void None %4 + %18 = OpLabel + OpKill + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + TEST_F(WrapOpKillTest, FuncWithReturnValue) { const std::string text = R"( ; CHECK: OpEntryPoint Fragment [[main:%\w+]] diff --git a/test/text_to_binary.control_flow_test.cpp b/test/text_to_binary.control_flow_test.cpp index ee8fed475..3e117b8f1 100644 --- a/test/text_to_binary.control_flow_test.cpp +++ b/test/text_to_binary.control_flow_test.cpp @@ -388,12 +388,35 @@ INSTANTIATE_TEST_SUITE_P( })); // clang-format on +using OpKillTest = spvtest::TextToBinaryTest; + +INSTANTIATE_TEST_SUITE_P(OpKillTest, ControlFlowRoundTripTest, + Values("OpKill\n")); + +TEST_F(OpKillTest, ExtraArgsAssemblyError) { + const std::string input = "OpKill 1"; + EXPECT_THAT(CompileFailure(input), + Eq("Expected or at the beginning of an " + "instruction, found '1'.")); +} + +using OpTerminateInvocationTest = spvtest::TextToBinaryTest; + +INSTANTIATE_TEST_SUITE_P(OpTerminateInvocationTest, ControlFlowRoundTripTest, + Values("OpTerminateInvocation\n")); + +TEST_F(OpTerminateInvocationTest, ExtraArgsAssemblyError) { + const std::string input = "OpTerminateInvocation 1"; + EXPECT_THAT(CompileFailure(input), + Eq("Expected or at the beginning of an " + "instruction, found '1'.")); +} + // TODO(dneto): OpPhi // TODO(dneto): OpLoopMerge // TODO(dneto): OpLabel // TODO(dneto): OpBranch // TODO(dneto): OpSwitch -// TODO(dneto): OpKill // TODO(dneto): OpReturn // TODO(dneto): OpReturnValue // TODO(dneto): OpUnreachable diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt index 138e71144..23d7a19e3 100644 --- a/test/val/CMakeLists.txt +++ b/test/val/CMakeLists.txt @@ -38,6 +38,7 @@ add_spvtools_unittest(TARGET val_abcde val_entry_point.cpp val_explicit_reserved_test.cpp val_extensions_test.cpp + val_extension_spv_khr_terminate_invocation.cpp val_ext_inst_test.cpp ${VAL_TEST_COMMON_SRCS} LIBS ${SPIRV_TOOLS} diff --git a/test/val/val_extension_spv_khr_terminate_invocation.cpp b/test/val/val_extension_spv_khr_terminate_invocation.cpp new file mode 100644 index 000000000..4cabf9e21 --- /dev/null +++ b/test/val/val_extension_spv_khr_terminate_invocation.cpp @@ -0,0 +1,150 @@ +// Copyright (c) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for OpExtension validator rules. + +#include +#include + +#include "gmock/gmock.h" +#include "source/enum_string_mapping.h" +#include "source/extensions.h" +#include "source/spirv_target_env.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::Values; +using ::testing::ValuesIn; + +using ValidateSpvKHRTerminateInvocation = spvtest::ValidateBase; + +TEST_F(ValidateSpvKHRTerminateInvocation, Valid) { + const std::string str = R"( + OpCapability Shader + OpExtension "SPV_KHR_terminate_invocation" + OpMemoryModel Logical Simple + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + + %void = OpTypeVoid + %void_fn = OpTypeFunction %void + + %main = OpFunction %void None %void_fn + %entry = OpLabel + OpTerminateInvocation + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSpvKHRTerminateInvocation, RequiresExtension) { + const std::string str = R"( + OpCapability Shader + OpMemoryModel Logical Simple + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + + %void = OpTypeVoid + %void_fn = OpTypeFunction %void + + %main = OpFunction %void None %void_fn + %entry = OpLabel + OpTerminateInvocation + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_NE(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("TerminateInvocation requires one of the following " + "extensions: SPV_KHR_terminate_invocation")); +} + +TEST_F(ValidateSpvKHRTerminateInvocation, RequiresShaderCapability) { + const std::string str = R"( + OpCapability Kernel + OpCapability Addresses + OpExtension "SPV_KHR_terminate_invocation" + OpMemoryModel Physical32 OpenCL + OpEntryPoint Kernel %main "main" + + %void = OpTypeVoid + %void_fn = OpTypeFunction %void + + %main = OpFunction %void None %void_fn + %entry = OpLabel + OpTerminateInvocation + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_NE(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "TerminateInvocation requires one of these capabilities: Shader \n")); +} + +TEST_F(ValidateSpvKHRTerminateInvocation, RequiresFragmentShader) { + const std::string str = R"( + OpCapability Shader + OpExtension "SPV_KHR_terminate_invocation" + OpMemoryModel Logical Simple + OpEntryPoint GLCompute %main "main" + + %void = OpTypeVoid + %void_fn = OpTypeFunction %void + + %main = OpFunction %void None %void_fn + %entry = OpLabel + OpTerminateInvocation + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_NE(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpTerminateInvocation requires Fragment execution model")); +} + +TEST_F(ValidateSpvKHRTerminateInvocation, IsTerminatorInstruction) { + const std::string str = R"( + OpCapability Shader + OpExtension "SPV_KHR_terminate_invocation" + OpMemoryModel Logical Simple + OpEntryPoint GLCompute %main "main" + + %void = OpTypeVoid + %void_fn = OpTypeFunction %void + + %main = OpFunction %void None %void_fn + %entry = OpLabel + OpTerminateInvocation + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_NE(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Return must appear in a block")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_extensions_test.cpp b/test/val/val_extensions_test.cpp index 682c32143..491a80853 100644 --- a/test/val/val_extensions_test.cpp +++ b/test/val/val_extensions_test.cpp @@ -62,7 +62,8 @@ INSTANTIATE_TEST_SUITE_P( "SPV_EXT_shader_viewport_index_layer", "SPV_AMD_shader_image_load_store_lod", "SPV_AMD_shader_fragment_mask", "SPV_GOOGLE_decorate_string", "SPV_GOOGLE_hlsl_functionality1", - "SPV_NV_shader_subgroup_partitioned", "SPV_EXT_descriptor_indexing")); + "SPV_NV_shader_subgroup_partitioned", "SPV_EXT_descriptor_indexing", + "SPV_KHR_terminate_invocation")); INSTANTIATE_TEST_SUITE_P(FailSilently, ValidateUnknownExtensions, Values("ERROR_unknown_extension", "SPV_KHR_",