diff --git a/source/fuzz/CMakeLists.txt b/source/fuzz/CMakeLists.txt index 97f89763c..09272bb67 100644 --- a/source/fuzz/CMakeLists.txt +++ b/source/fuzz/CMakeLists.txt @@ -48,6 +48,7 @@ if(SPIRV_BUILD_FUZZER) fuzzer_pass_construct_composites.h fuzzer_pass_copy_objects.h fuzzer_pass_obfuscate_constants.h + fuzzer_pass_outline_functions.h fuzzer_pass_permute_blocks.h fuzzer_pass_split_blocks.h fuzzer_util.h @@ -72,6 +73,7 @@ if(SPIRV_BUILD_FUZZER) transformation_composite_extract.h transformation_copy_object.h transformation_move_block_down.h + transformation_outline_function.h transformation_replace_boolean_constant_with_constant_binary.h transformation_replace_constant_with_uniform.h transformation_replace_id_with_synonym.h @@ -102,6 +104,7 @@ if(SPIRV_BUILD_FUZZER) fuzzer_pass_construct_composites.cpp fuzzer_pass_copy_objects.cpp fuzzer_pass_obfuscate_constants.cpp + fuzzer_pass_outline_functions.cpp fuzzer_pass_permute_blocks.cpp fuzzer_pass_split_blocks.cpp fuzzer_util.cpp @@ -125,6 +128,7 @@ if(SPIRV_BUILD_FUZZER) transformation_composite_extract.cpp transformation_copy_object.cpp transformation_move_block_down.cpp + transformation_outline_function.cpp transformation_replace_boolean_constant_with_constant_binary.cpp transformation_replace_constant_with_uniform.cpp transformation_replace_id_with_synonym.cpp diff --git a/source/fuzz/fuzzer.cpp b/source/fuzz/fuzzer.cpp index 20e714d7c..6b4d54a3f 100644 --- a/source/fuzz/fuzzer.cpp +++ b/source/fuzz/fuzzer.cpp @@ -32,6 +32,7 @@ #include "source/fuzz/fuzzer_pass_construct_composites.h" #include "source/fuzz/fuzzer_pass_copy_objects.h" #include "source/fuzz/fuzzer_pass_obfuscate_constants.h" +#include "source/fuzz/fuzzer_pass_outline_functions.h" #include "source/fuzz/fuzzer_pass_permute_blocks.h" #include "source/fuzz/fuzzer_pass_split_blocks.h" #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" @@ -185,6 +186,9 @@ Fuzzer::FuzzerResultStatus Fuzzer::Run( MaybeAddPass(&passes, ir_context.get(), &fact_manager, &fuzzer_context, transformation_sequence_out); + MaybeAddPass(&passes, ir_context.get(), + &fact_manager, &fuzzer_context, + transformation_sequence_out); MaybeAddPass(&passes, ir_context.get(), &fact_manager, &fuzzer_context, transformation_sequence_out); diff --git a/source/fuzz/fuzzer_context.cpp b/source/fuzz/fuzzer_context.cpp index 356cb3599..b9d0ff943 100644 --- a/source/fuzz/fuzzer_context.cpp +++ b/source/fuzz/fuzzer_context.cpp @@ -38,6 +38,7 @@ const std::pair kChanceOfCopyingObject = {20, 50}; const std::pair kChanceOfConstructingComposite = {20, 50}; const std::pair kChanceOfMovingBlockDown = {20, 50}; const std::pair kChanceOfObfuscatingConstant = {10, 90}; +const std::pair kChanceOfOutliningFunction = {10, 90}; const std::pair kChanceOfReplacingIdWithSynonym = {10, 90}; const std::pair kChanceOfSplittingBlock = {40, 95}; @@ -85,6 +86,8 @@ FuzzerContext::FuzzerContext(RandomGenerator* random_generator, ChooseBetweenMinAndMax(kChanceOfMovingBlockDown); chance_of_obfuscating_constant_ = ChooseBetweenMinAndMax(kChanceOfObfuscatingConstant); + chance_of_outlining_function_ = + ChooseBetweenMinAndMax(kChanceOfOutliningFunction); chance_of_replacing_id_with_synonym_ = ChooseBetweenMinAndMax(kChanceOfReplacingIdWithSynonym); chance_of_splitting_block_ = ChooseBetweenMinAndMax(kChanceOfSplittingBlock); diff --git a/source/fuzz/fuzzer_context.h b/source/fuzz/fuzzer_context.h index c8242e617..584c6cb85 100644 --- a/source/fuzz/fuzzer_context.h +++ b/source/fuzz/fuzzer_context.h @@ -85,6 +85,9 @@ class FuzzerContext { uint32_t GetChanceOfObfuscatingConstant() { return chance_of_obfuscating_constant_; } + uint32_t GetChanceOfOutliningFunction() { + return chance_of_outlining_function_; + } uint32_t GetChanceOfReplacingIdWithSynonym() { return chance_of_replacing_id_with_synonym_; } @@ -121,6 +124,7 @@ class FuzzerContext { uint32_t chance_of_copying_object_; uint32_t chance_of_moving_block_down_; uint32_t chance_of_obfuscating_constant_; + uint32_t chance_of_outlining_function_; uint32_t chance_of_replacing_id_with_synonym_; uint32_t chance_of_splitting_block_; diff --git a/source/fuzz/fuzzer_pass_construct_composites.cpp b/source/fuzz/fuzzer_pass_construct_composites.cpp index 9eb56316c..ff0adabcc 100644 --- a/source/fuzz/fuzzer_pass_construct_composites.cpp +++ b/source/fuzz/fuzzer_pass_construct_composites.cpp @@ -148,7 +148,6 @@ void FuzzerPassConstructComposites::Apply() { transformation.Apply(GetIRContext(), GetFactManager()); *GetTransformations()->add_transformation() = transformation.ToMessage(); - // Indicate that one instruction was added. }); } diff --git a/source/fuzz/fuzzer_pass_outline_functions.cpp b/source/fuzz/fuzzer_pass_outline_functions.cpp new file mode 100644 index 000000000..d59c195d7 --- /dev/null +++ b/source/fuzz/fuzzer_pass_outline_functions.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/fuzz/fuzzer_pass_outline_functions.h" + +#include + +#include "source/fuzz/fuzzer_util.h" +#include "source/fuzz/transformation_outline_function.h" + +namespace spvtools { +namespace fuzz { + +FuzzerPassOutlineFunctions::FuzzerPassOutlineFunctions( + opt::IRContext* ir_context, FactManager* fact_manager, + FuzzerContext* fuzzer_context, + protobufs::TransformationSequence* transformations) + : FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {} + +FuzzerPassOutlineFunctions::~FuzzerPassOutlineFunctions() = default; + +void FuzzerPassOutlineFunctions::Apply() { + std::vector original_functions; + for (auto& function : *GetIRContext()->module()) { + original_functions.push_back(&function); + } + for (auto& function : original_functions) { + if (!GetFuzzerContext()->ChoosePercentage( + GetFuzzerContext()->GetChanceOfOutliningFunction())) { + continue; + } + std::vector blocks; + for (auto& block : *function) { + blocks.push_back(&block); + } + auto entry_block = blocks[GetFuzzerContext()->RandomIndex(blocks)]; + auto dominator_analysis = GetIRContext()->GetDominatorAnalysis(function); + auto postdominator_analysis = + GetIRContext()->GetPostDominatorAnalysis(function); + std::vector candidate_exit_blocks; + for (auto postdominates_entry_block = entry_block; + postdominates_entry_block != nullptr; + postdominates_entry_block = postdominator_analysis->ImmediateDominator( + postdominates_entry_block)) { + if (dominator_analysis->Dominates(entry_block, + postdominates_entry_block)) { + candidate_exit_blocks.push_back(postdominates_entry_block); + } + } + if (candidate_exit_blocks.empty()) { + continue; + } + auto exit_block = candidate_exit_blocks[GetFuzzerContext()->RandomIndex( + candidate_exit_blocks)]; + + auto region_blocks = TransformationOutlineFunction::GetRegionBlocks( + GetIRContext(), entry_block, exit_block); + std::map input_id_to_fresh_id; + for (auto id : TransformationOutlineFunction::GetRegionInputIds( + GetIRContext(), region_blocks, exit_block)) { + input_id_to_fresh_id[id] = GetFuzzerContext()->GetFreshId(); + } + std::map output_id_to_fresh_id; + for (auto id : TransformationOutlineFunction::GetRegionOutputIds( + GetIRContext(), region_blocks, exit_block)) { + output_id_to_fresh_id[id] = GetFuzzerContext()->GetFreshId(); + } + TransformationOutlineFunction transformation( + entry_block->id(), exit_block->id(), + /*new_function_struct_return_type_id*/ + GetFuzzerContext()->GetFreshId(), + /*new_function_type_id*/ GetFuzzerContext()->GetFreshId(), + /*new_function_id*/ GetFuzzerContext()->GetFreshId(), + /*new_function_region_entry_block*/ + GetFuzzerContext()->GetFreshId(), + /*new_caller_result_id*/ GetFuzzerContext()->GetFreshId(), + /*new_callee_result_id*/ GetFuzzerContext()->GetFreshId(), + /*input_id_to_fresh_id*/ std::move(input_id_to_fresh_id), + /*output_id_to_fresh_id*/ std::move(output_id_to_fresh_id)); + if (transformation.IsApplicable(GetIRContext(), *GetFactManager())) { + transformation.Apply(GetIRContext(), GetFactManager()); + *GetTransformations()->add_transformation() = transformation.ToMessage(); + } + } +} + +} // namespace fuzz +} // namespace spvtools diff --git a/source/fuzz/fuzzer_pass_outline_functions.h b/source/fuzz/fuzzer_pass_outline_functions.h new file mode 100644 index 000000000..5448e7df7 --- /dev/null +++ b/source/fuzz/fuzzer_pass_outline_functions.h @@ -0,0 +1,40 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_FUZZ_FUZZER_PASS_OUTLINE_FUNCTIONS_H_ +#define SOURCE_FUZZ_FUZZER_PASS_OUTLINE_FUNCTIONS_H_ + +#include "source/fuzz/fuzzer_pass.h" + +namespace spvtools { +namespace fuzz { + +// A fuzzer pass for outlining single-entry single-exit regions of a control +// flow graph into their own functions. +class FuzzerPassOutlineFunctions : public FuzzerPass { + public: + FuzzerPassOutlineFunctions( + opt::IRContext* ir_context, FactManager* fact_manager, + FuzzerContext* fuzzer_context, + protobufs::TransformationSequence* transformations); + + ~FuzzerPassOutlineFunctions(); + + void Apply() override; +}; + +} // namespace fuzz +} // namespace spvtools + +#endif // SOURCE_FUZZ_FUZZER_PASS_OUTLINE_FUNCTIONS_H_ diff --git a/source/fuzz/protobufs/spvtoolsfuzz.proto b/source/fuzz/protobufs/spvtoolsfuzz.proto index b33c2e5dc..ba3a01314 100644 --- a/source/fuzz/protobufs/spvtoolsfuzz.proto +++ b/source/fuzz/protobufs/spvtoolsfuzz.proto @@ -21,6 +21,16 @@ syntax = "proto3"; package spvtools.fuzz.protobufs; +message UInt32Pair { + + // A pair of uint32s; useful for defining mappings. + + uint32 first = 1; + + uint32 second = 2; + +} + message InstructionDescriptor { // Describes an instruction in some block of a function with respect to a @@ -190,6 +200,7 @@ message Transformation { TransformationSetMemoryOperandsMask set_memory_operands_mask = 20; TransformationCompositeExtract composite_extract = 21; TransformationVectorShuffle vector_shuffle = 22; + TransformationOutlineFunction outline_function = 23; // Add additional option using the next available number. } } @@ -389,6 +400,53 @@ message TransformationMoveBlockDown { uint32 block_id = 1; } +message TransformationOutlineFunction { + + // A transformation that outlines a single-entry single-exit region of a + // control flow graph into a separate function, and replaces the region with + // a call to that function. + + // Id of the entry block of the single-entry single-exit region to be outlined + uint32 entry_block = 1; + + // Id of the exit block of the single-entry single-exit region to be outlined + uint32 exit_block = 2; + + // Id of a struct that will store the return values of the new function + uint32 new_function_struct_return_type_id = 3; + + // A fresh id for the type of the outlined function + uint32 new_function_type_id = 4; + + // A fresh id for the outlined function itself + uint32 new_function_id = 5; + + // A fresh id to represent the block in the outlined function that represents + // the first block of the outlined region. + uint32 new_function_region_entry_block = 6; + + // A fresh id for the result of the OpFunctionCall instruction that will call + // the outlined function + uint32 new_caller_result_id = 7; + + // A fresh id to capture the return value of the outlined function - the + // argument to OpReturn + uint32 new_callee_result_id = 8; + + // Ids defined outside the region and used inside the region will become + // parameters to the outlined function. This is a mapping from used ids to + // fresh parameter ids. + repeated UInt32Pair input_id_to_fresh_id = 9; + + // Ids defined inside the region and used outside the region will become + // fresh ids defined by the outlined function, which get copied into the + // function's struct return value and then copied into their destination ids + // by the caller. This is a mapping from original ids to corresponding fresh + // ids. + repeated UInt32Pair output_id_to_fresh_id = 10; + +} + message TransformationReplaceBooleanConstantWithConstantBinary { // A transformation to capture replacing a use of a boolean constant with diff --git a/source/fuzz/transformation.cpp b/source/fuzz/transformation.cpp index d8fc92fa4..531cfa1fa 100644 --- a/source/fuzz/transformation.cpp +++ b/source/fuzz/transformation.cpp @@ -29,6 +29,7 @@ #include "source/fuzz/transformation_composite_extract.h" #include "source/fuzz/transformation_copy_object.h" #include "source/fuzz/transformation_move_block_down.h" +#include "source/fuzz/transformation_outline_function.h" #include "source/fuzz/transformation_replace_boolean_constant_with_constant_binary.h" #include "source/fuzz/transformation_replace_constant_with_uniform.h" #include "source/fuzz/transformation_replace_id_with_synonym.h" @@ -83,6 +84,9 @@ std::unique_ptr Transformation::FromMessage( return MakeUnique(message.copy_object()); case protobufs::Transformation::TransformationCase::kMoveBlockDown: return MakeUnique(message.move_block_down()); + case protobufs::Transformation::TransformationCase::kOutlineFunction: + return MakeUnique( + message.outline_function()); case protobufs::Transformation::TransformationCase:: kReplaceBooleanConstantWithConstantBinary: return MakeUnique( diff --git a/source/fuzz/transformation_outline_function.cpp b/source/fuzz/transformation_outline_function.cpp new file mode 100644 index 000000000..95517f54b --- /dev/null +++ b/source/fuzz/transformation_outline_function.cpp @@ -0,0 +1,931 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/fuzz/transformation_outline_function.h" + +#include + +#include "source/fuzz/fuzzer_util.h" + +namespace spvtools { +namespace fuzz { + +namespace { + +std::map PairSequenceToMap( + const google::protobuf::RepeatedPtrField& + pair_sequence) { + std::map result; + for (auto& pair : pair_sequence) { + result[pair.first()] = pair.second(); + } + return result; +} + +} // namespace + +TransformationOutlineFunction::TransformationOutlineFunction( + const spvtools::fuzz::protobufs::TransformationOutlineFunction& message) + : message_(message) {} + +TransformationOutlineFunction::TransformationOutlineFunction( + uint32_t entry_block, uint32_t exit_block, + uint32_t new_function_struct_return_type_id, uint32_t new_function_type_id, + uint32_t new_function_id, uint32_t new_function_region_entry_block, + uint32_t new_caller_result_id, uint32_t new_callee_result_id, + std::map&& input_id_to_fresh_id, + std::map&& output_id_to_fresh_id) { + message_.set_entry_block(entry_block); + message_.set_exit_block(exit_block); + message_.set_new_function_struct_return_type_id( + new_function_struct_return_type_id); + message_.set_new_function_type_id(new_function_type_id); + message_.set_new_function_id(new_function_id); + message_.set_new_function_region_entry_block(new_function_region_entry_block); + message_.set_new_caller_result_id(new_caller_result_id); + message_.set_new_callee_result_id(new_callee_result_id); + for (auto& entry : input_id_to_fresh_id) { + protobufs::UInt32Pair pair; + pair.set_first(entry.first); + pair.set_second(entry.second); + *message_.add_input_id_to_fresh_id() = pair; + } + for (auto& entry : output_id_to_fresh_id) { + protobufs::UInt32Pair pair; + pair.set_first(entry.first); + pair.set_second(entry.second); + *message_.add_output_id_to_fresh_id() = pair; + } +} + +bool TransformationOutlineFunction::IsApplicable( + opt::IRContext* context, + const spvtools::fuzz::FactManager& /*unused*/) const { + std::set ids_used_by_this_transformation; + + // The various new ids used by the transformation must be fresh and distinct. + + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + message_.new_function_struct_return_type_id(), context, + &ids_used_by_this_transformation)) { + return false; + } + + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + message_.new_function_type_id(), context, + &ids_used_by_this_transformation)) { + return false; + } + + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + message_.new_function_id(), context, + &ids_used_by_this_transformation)) { + return false; + } + + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + message_.new_function_region_entry_block(), context, + &ids_used_by_this_transformation)) { + return false; + } + + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + message_.new_caller_result_id(), context, + &ids_used_by_this_transformation)) { + return false; + } + + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + message_.new_callee_result_id(), context, + &ids_used_by_this_transformation)) { + return false; + } + + for (auto& pair : message_.input_id_to_fresh_id()) { + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + pair.second(), context, &ids_used_by_this_transformation)) { + return false; + } + } + + for (auto& pair : message_.output_id_to_fresh_id()) { + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + pair.second(), context, &ids_used_by_this_transformation)) { + return false; + } + } + + // The entry and exit block ids must indeed refer to blocks. + for (auto block_id : {message_.entry_block(), message_.exit_block()}) { + auto block_label = context->get_def_use_mgr()->GetDef(block_id); + if (!block_label || block_label->opcode() != SpvOpLabel) { + return false; + } + } + + auto entry_block = context->cfg()->block(message_.entry_block()); + auto exit_block = context->cfg()->block(message_.exit_block()); + + // The entry block cannot start with OpVariable - this would mean that + // outlining would remove a variable from the function containing the region + // being outlined. + if (entry_block->begin()->opcode() == SpvOpVariable) { + return false; + } + + // For simplicity, we do not allow the entry block to be a loop header. + if (entry_block->GetLoopMergeInst()) { + return false; + } + + // For simplicity, we do not allow the exit block to be a merge block or + // continue target. + bool exit_block_is_merge_or_continue = false; + context->get_def_use_mgr()->WhileEachUse( + exit_block->id(), + [&exit_block_is_merge_or_continue]( + const opt::Instruction* use_instruction, + uint32_t /*unused*/) -> bool { + switch (use_instruction->opcode()) { + case SpvOpLoopMerge: + case SpvOpSelectionMerge: + exit_block_is_merge_or_continue = true; + return false; + default: + return true; + } + }); + if (exit_block_is_merge_or_continue) { + return false; + } + + // The entry block cannot start with OpPhi. This is to keep the + // transformation logic simple. (Another transformation to split the OpPhis + // from a block could be applied to avoid this scenario.) + if (entry_block->begin()->opcode() == SpvOpPhi) { + return false; + } + + // The block must be in the same function. + if (entry_block->GetParent() != exit_block->GetParent()) { + return false; + } + + // The entry block must dominate the exit block. + auto dominator_analysis = + context->GetDominatorAnalysis(entry_block->GetParent()); + if (!dominator_analysis->Dominates(entry_block, exit_block)) { + return false; + } + + // The exit block must post-dominate the entry block. + auto postdominator_analysis = + context->GetPostDominatorAnalysis(entry_block->GetParent()); + if (!postdominator_analysis->Dominates(exit_block, entry_block)) { + return false; + } + + // Find all the blocks dominated by |message_.entry_block| and post-dominated + // by |message_.exit_block|. + auto region_set = GetRegionBlocks( + context, entry_block = context->cfg()->block(message_.entry_block()), + exit_block = context->cfg()->block(message_.exit_block())); + + // Check whether |region_set| really is a single-entry single-exit region, and + // also check whether structured control flow constructs and their merge + // and continue constructs are either wholly in or wholly out of the region - + // e.g. avoid the situation where the region contains the head of a loop but + // not the loop's continue construct. + // + // This is achieved by going through every block in the function that contains + // the region. + for (auto& block : *entry_block->GetParent()) { + if (&block == exit_block) { + // It is OK (and typically expected) for the exit block of the region to + // have successors outside the region. It is also OK for the exit block + // to head a structured control flow construct - the block containing the + // call to the outlined function will end up heading this construct if + // outlining takes place. + continue; + } + + if (region_set.count(&block) != 0) { + // The block is in the region and is not the region's exit block. Let's + // see whether all of the block's successors are in the region. If they + // are not, the region is not single-entry single-exit. + bool all_successors_in_region = true; + block.WhileEachSuccessorLabel([&all_successors_in_region, context, + ®ion_set](uint32_t successor) -> bool { + if (region_set.count(context->cfg()->block(successor)) == 0) { + all_successors_in_region = false; + return false; + } + return true; + }); + if (!all_successors_in_region) { + return false; + } + } + + if (auto merge = block.GetMergeInst()) { + // The block is a loop or selection header -- the header and its + // associated merge block had better both be in the region or both be + // outside the region. + auto merge_block = context->cfg()->block(merge->GetSingleWordOperand(0)); + if (region_set.count(&block) != region_set.count(merge_block)) { + return false; + } + } + + if (auto loop_merge = block.GetLoopMergeInst()) { + // Similar to the above, but for the continue target of a loop. + auto continue_target = + context->cfg()->block(loop_merge->GetSingleWordOperand(1)); + if (continue_target != exit_block && + region_set.count(&block) != region_set.count(continue_target)) { + return false; + } + } + } + + // For each region input id -- i.e. every id defined outside the region but + // used inside the region -- there needs to be a corresponding fresh id to be + // used as a function parameter. + std::map input_id_to_fresh_id_map = + PairSequenceToMap(message_.input_id_to_fresh_id()); + for (auto id : GetRegionInputIds(context, region_set, exit_block)) { + if (input_id_to_fresh_id_map.count(id) == 0) { + return false; + } + } + + // For each region output id -- i.e. every id defined inside the region but + // used outside the region -- there needs to be a corresponding fresh id that + // can hold the value for this id computed in the outlined function. + std::map output_id_to_fresh_id_map = + PairSequenceToMap(message_.output_id_to_fresh_id()); + for (auto id : GetRegionOutputIds(context, region_set, exit_block)) { + if (output_id_to_fresh_id_map.count(id) == 0) { + return false; + } + } + + return true; +} + +void TransformationOutlineFunction::Apply( + opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { + // The entry block for the region before outlining. + auto original_region_entry_block = + context->cfg()->block(message_.entry_block()); + + // The exit block for the region before outlining. + auto original_region_exit_block = + context->cfg()->block(message_.exit_block()); + + // The single-entry single-exit region defined by |message_.entry_block| and + // |message_.exit_block|. + std::set region_blocks = GetRegionBlocks( + context, original_region_entry_block, original_region_exit_block); + + // Input and output ids for the region being outlined. + std::vector region_input_ids = + GetRegionInputIds(context, region_blocks, original_region_exit_block); + std::vector region_output_ids = + GetRegionOutputIds(context, region_blocks, original_region_exit_block); + + // Maps from input and output ids to fresh ids. + std::map input_id_to_fresh_id_map = + PairSequenceToMap(message_.input_id_to_fresh_id()); + std::map output_id_to_fresh_id_map = + PairSequenceToMap(message_.output_id_to_fresh_id()); + + UpdateModuleIdBoundForFreshIds(context, input_id_to_fresh_id_map, + output_id_to_fresh_id_map); + + // Construct a map that associates each output id with its type id. + std::map output_id_to_type_id; + for (uint32_t output_id : region_output_ids) { + output_id_to_type_id[output_id] = + context->get_def_use_mgr()->GetDef(output_id)->type_id(); + } + + // The region will be collapsed to a single block that calls a function + // containing the outlined region. This block needs to end with whatever + // the exit block of the region ended with before outlining. We thus clone + // the terminator of the region's exit block, and the merge instruction for + // the block if there is one, so that we can append them to the end of the + // collapsed block later. + std::unique_ptr cloned_exit_block_terminator = + std::unique_ptr( + original_region_exit_block->terminator()->Clone(context)); + std::unique_ptr cloned_exit_block_merge = + original_region_exit_block->GetMergeInst() + ? std::unique_ptr( + original_region_exit_block->GetMergeInst()->Clone(context)) + : nullptr; + + // Make a function prototype for the outlined function, which involves + // figuring out its required type. + std::unique_ptr outlined_function = PrepareFunctionPrototype( + context, region_input_ids, region_output_ids, input_id_to_fresh_id_map); + + // Adapt the region to be outlined so that its input ids are replaced with the + // ids of the outlined function's input parameters, and so that output ids + // are similarly remapped. + RemapInputAndOutputIdsInRegion( + context, *original_region_exit_block, region_blocks, region_input_ids, + region_output_ids, input_id_to_fresh_id_map, output_id_to_fresh_id_map); + + // Fill out the body of the outlined function according to the region that is + // being outlined. + PopulateOutlinedFunction(context, *original_region_entry_block, + *original_region_exit_block, region_blocks, + region_output_ids, output_id_to_fresh_id_map, + outlined_function.get()); + + // Collapse the region that has been outlined into a function down to a single + // block that calls said function. + ShrinkOriginalRegion( + context, region_blocks, region_input_ids, region_output_ids, + output_id_to_type_id, outlined_function->type_id(), + std::move(cloned_exit_block_merge), + std::move(cloned_exit_block_terminator), original_region_entry_block); + + // Add the outlined function to the module. + context->module()->AddFunction(std::move(outlined_function)); + + // Major surgery has been conducted on the module, so invalidate all analyses. + context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone); +} + +protobufs::Transformation TransformationOutlineFunction::ToMessage() const { + protobufs::Transformation result; + *result.mutable_outline_function() = message_; + return result; +} + +bool TransformationOutlineFunction:: + CheckIdIsFreshAndNotUsedByThisTransformation( + uint32_t id, opt::IRContext* context, + std::set* ids_used_by_this_transformation) const { + if (!fuzzerutil::IsFreshId(context, id)) { + return false; + } + if (ids_used_by_this_transformation->count(id) != 0) { + return false; + } + ids_used_by_this_transformation->insert(id); + return true; +} + +std::vector TransformationOutlineFunction::GetRegionInputIds( + opt::IRContext* context, const std::set& region_set, + opt::BasicBlock* region_exit_block) { + std::vector result; + + auto enclosing_function = region_exit_block->GetParent(); + + // Consider each parameter of the function containing the region. + enclosing_function->ForEachParam([context, ®ion_set, &result]( + opt::Instruction* function_parameter) { + // Consider every use of the parameter. + context->get_def_use_mgr()->WhileEachUse( + function_parameter, [context, function_parameter, ®ion_set, &result]( + opt::Instruction* use, uint32_t /*unused*/) { + // Get the block, if any, in which the parameter is used. + auto use_block = context->get_instr_block(use); + // If the use is in a block that lies within the region, the + // parameter is an input id for the region. + if (use_block && region_set.count(use_block) != 0) { + result.push_back(function_parameter->result_id()); + return false; + } + return true; + }); + }); + + // Consider all definitions in the function that might turn out to be input + // ids. + for (auto& block : *enclosing_function) { + std::vector candidate_input_ids_for_block; + if (region_set.count(&block) == 0) { + // All instructions in blocks outside the region are candidate's for + // generating input ids. + for (auto& inst : block) { + candidate_input_ids_for_block.push_back(&inst); + } + } else { + // Blocks in the region cannot generate input ids. + continue; + } + + // Consider each candidate input id to check whether it is used in the + // region. + for (auto& inst : candidate_input_ids_for_block) { + context->get_def_use_mgr()->WhileEachUse( + inst, + [context, &inst, region_exit_block, ®ion_set, &result]( + opt::Instruction* use, uint32_t /*unused*/) -> bool { + + // Find the block in which this id use occurs, recording the id as + // an input id if the block is outside the region, with some + // exceptions detailed below. + auto use_block = context->get_instr_block(use); + + if (!use_block) { + // There might be no containing block, e.g. if the use is in a + // decoration. + return true; + } + + if (region_set.count(use_block) == 0) { + // The use is not in the region: this does not make it an input + // id. + return true; + } + + if (use_block == region_exit_block && use->IsBlockTerminator()) { + // We do not regard uses in the exit block terminator as input + // ids, as this terminator does not get outlined. + return true; + } + + result.push_back(inst->result_id()); + return false; + }); + } + } + return result; +} + +std::vector TransformationOutlineFunction::GetRegionOutputIds( + opt::IRContext* context, const std::set& region_set, + opt::BasicBlock* region_exit_block) { + std::vector result; + + // Consider each block in the function containing the region. + for (auto& block : *region_exit_block->GetParent()) { + if (region_set.count(&block) == 0) { + // Skip blocks that are not in the region. + continue; + } + // Consider each use of each instruction defined in the block. + for (auto& inst : block) { + context->get_def_use_mgr()->WhileEachUse( + &inst, + [®ion_set, context, &inst, region_exit_block, &result]( + opt::Instruction* use, uint32_t /*unused*/) -> bool { + + // Find the block in which this id use occurs, recording the id as + // an output id if the block is outside the region, with some + // exceptions detailed below. + auto use_block = context->get_instr_block(use); + + if (!use_block) { + // There might be no containing block, e.g. if the use is in a + // decoration. + return true; + } + + if (region_set.count(use_block) != 0) { + // The use is in the region. + if (use_block != region_exit_block || !use->IsBlockTerminator()) { + // Furthermore, the use is not in the terminator of the region's + // exit block. + return true; + } + } + + result.push_back(inst.result_id()); + return false; + }); + } + } + return result; +} + +std::set TransformationOutlineFunction::GetRegionBlocks( + opt::IRContext* context, opt::BasicBlock* entry_block, + opt::BasicBlock* exit_block) { + auto enclosing_function = entry_block->GetParent(); + auto dominator_analysis = context->GetDominatorAnalysis(enclosing_function); + auto postdominator_analysis = + context->GetPostDominatorAnalysis(enclosing_function); + + std::set result; + for (auto& block : *enclosing_function) { + if (dominator_analysis->Dominates(entry_block, &block) && + postdominator_analysis->Dominates(exit_block, &block)) { + result.insert(&block); + } + } + return result; +} + +std::unique_ptr +TransformationOutlineFunction::PrepareFunctionPrototype( + opt::IRContext* context, const std::vector& region_input_ids, + const std::vector& region_output_ids, + const std::map& input_id_to_fresh_id_map) const { + uint32_t return_type_id = 0; + uint32_t function_type_id = 0; + + // First, try to find an existing function type that is suitable. This is + // only possible if the region generates no output ids; if it generates output + // ids we are going to make a new struct for those, and since that struct does + // not exist there cannot already be a function type with this struct as its + // return type. + if (region_output_ids.empty()) { + opt::analysis::Void void_type; + return_type_id = context->get_type_mgr()->GetId(&void_type); + std::vector argument_types; + for (auto id : region_input_ids) { + argument_types.push_back(context->get_type_mgr()->GetType( + context->get_def_use_mgr()->GetDef(id)->type_id())); + } + opt::analysis::Function function_type(&void_type, argument_types); + function_type_id = context->get_type_mgr()->GetId(&function_type); + } + + // If no existing function type was found, we need to create one. + if (function_type_id == 0) { + assert( + ((return_type_id == 0) == !region_output_ids.empty()) && + "We should only have set the return type if there are no output ids."); + // If the region generates output ids, we need to make a struct with one + // field per output id. + if (!region_output_ids.empty()) { + opt::Instruction::OperandList struct_member_types; + for (uint32_t output_id : region_output_ids) { + auto output_id_type = + context->get_def_use_mgr()->GetDef(output_id)->type_id(); + struct_member_types.push_back({SPV_OPERAND_TYPE_ID, {output_id_type}}); + } + // Add a new struct type to the module. + context->module()->AddType(MakeUnique( + context, SpvOpTypeStruct, 0, + message_.new_function_struct_return_type_id(), + std::move(struct_member_types))); + // The return type for the function is the newly-created struct. + return_type_id = message_.new_function_struct_return_type_id(); + } + assert( + return_type_id != 0 && + "We should either have a void return type, or have created a struct."); + + // The region's input ids dictate the parameter types to the function. + opt::Instruction::OperandList function_type_operands; + function_type_operands.push_back({SPV_OPERAND_TYPE_ID, {return_type_id}}); + for (auto id : region_input_ids) { + function_type_operands.push_back( + {SPV_OPERAND_TYPE_ID, + {context->get_def_use_mgr()->GetDef(id)->type_id()}}); + } + // Add a new function type to the module, and record that this is the type + // id for the new function. + context->module()->AddType(MakeUnique( + context, SpvOpTypeFunction, 0, message_.new_function_type_id(), + function_type_operands)); + function_type_id = message_.new_function_type_id(); + } + + // Create a new function with |message_.new_function_id| as the function id, + // and the return type and function type prepared above. + std::unique_ptr outlined_function = + MakeUnique(MakeUnique( + context, SpvOpFunction, return_type_id, message_.new_function_id(), + opt::Instruction::OperandList( + {{spv_operand_type_t ::SPV_OPERAND_TYPE_LITERAL_INTEGER, + {SpvFunctionControlMaskNone}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, + {function_type_id}}}))); + + // Add one parameter to the function for each input id, using the fresh ids + // provided in |input_id_to_fresh_id_map|. + for (auto id : region_input_ids) { + outlined_function->AddParameter(MakeUnique( + context, SpvOpFunctionParameter, + context->get_def_use_mgr()->GetDef(id)->type_id(), + input_id_to_fresh_id_map.at(id), opt::Instruction::OperandList())); + } + + return outlined_function; +} + +void TransformationOutlineFunction::UpdateModuleIdBoundForFreshIds( + opt::IRContext* context, + const std::map& input_id_to_fresh_id_map, + const std::map& output_id_to_fresh_id_map) const { + // Enlarge the module's id bound as needed to accommodate the various fresh + // ids associated with the transformation. + fuzzerutil::UpdateModuleIdBound( + context, message_.new_function_struct_return_type_id()); + fuzzerutil::UpdateModuleIdBound(context, message_.new_function_type_id()); + fuzzerutil::UpdateModuleIdBound(context, message_.new_function_id()); + fuzzerutil::UpdateModuleIdBound(context, + message_.new_function_region_entry_block()); + fuzzerutil::UpdateModuleIdBound(context, message_.new_caller_result_id()); + fuzzerutil::UpdateModuleIdBound(context, message_.new_callee_result_id()); + + for (auto& entry : input_id_to_fresh_id_map) { + fuzzerutil::UpdateModuleIdBound(context, entry.second); + } + + for (auto& entry : output_id_to_fresh_id_map) { + fuzzerutil::UpdateModuleIdBound(context, entry.second); + } +} + +void TransformationOutlineFunction::RemapInputAndOutputIdsInRegion( + opt::IRContext* context, const opt::BasicBlock& original_region_exit_block, + const std::set& region_blocks, + const std::vector& region_input_ids, + const std::vector& region_output_ids, + const std::map& input_id_to_fresh_id_map, + const std::map& output_id_to_fresh_id_map) const { + // Change all uses of input ids inside the region to the corresponding fresh + // ids that will ultimately be parameters of the outlined function. + // This is done by considering each region input id in turn. + for (uint32_t id : region_input_ids) { + // We then consider each use of the input id. + context->get_def_use_mgr()->ForEachUse( + id, [context, id, &input_id_to_fresh_id_map, region_blocks]( + opt::Instruction* use, uint32_t operand_index) { + // Find the block in which this use of the input id occurs. + opt::BasicBlock* use_block = context->get_instr_block(use); + // We want to rewrite the use id if its block occurs in the outlined + // region. + if (region_blocks.count(use_block) != 0) { + // Rewrite this use of the input id. + use->SetOperand(operand_index, {input_id_to_fresh_id_map.at(id)}); + } + }); + } + + // Change each definition of a region output id to define the corresponding + // fresh ids that will store intermediate value for the output ids. Also + // change all uses of the output id located in the outlined region. + // This is done by considering each region output id in turn. + for (uint32_t id : region_output_ids) { + // First consider each use of the output id and update the relevant uses. + context->get_def_use_mgr()->ForEachUse( + id, + [context, &original_region_exit_block, id, &output_id_to_fresh_id_map, + region_blocks](opt::Instruction* use, uint32_t operand_index) { + // Find the block in which this use of the output id occurs. + auto use_block = context->get_instr_block(use); + // We want to rewrite the use id if its block occurs in the outlined + // region, with one exception: the terminator of the exit block of + // the region is going to remain in the original function, so if the + // use appears in such a terminator instruction we leave it alone. + if ( + // The block is in the region ... + region_blocks.count(use_block) != 0 && + // ... and the use is not in the terminator instruction of the + // region's exit block. + !(use_block == &original_region_exit_block && + use->IsBlockTerminator())) { + // Rewrite this use of the output id. + use->SetOperand(operand_index, {output_id_to_fresh_id_map.at(id)}); + } + }); + + // Now change the instruction that defines the output id so that it instead + // defines the corresponding fresh id. We do this after changing all the + // uses so that the definition of the original id is still registered when + // we analyse its uses. + context->get_def_use_mgr()->GetDef(id)->SetResultId( + output_id_to_fresh_id_map.at(id)); + } +} + +void TransformationOutlineFunction::PopulateOutlinedFunction( + opt::IRContext* context, const opt::BasicBlock& original_region_entry_block, + const opt::BasicBlock& original_region_exit_block, + const std::set& region_blocks, + const std::vector& region_output_ids, + const std::map& output_id_to_fresh_id_map, + opt::Function* outlined_function) const { + // When we create the exit block for the outlined region, we use this pointer + // to track of it so that we can manipulate it later. + opt::BasicBlock* outlined_region_exit_block = nullptr; + + // The region entry block in the new function is identical to the entry block + // of the region being outlined, except that it has + // |message_.new_function_region_entry_block| as its id. + std::unique_ptr outlined_region_entry_block = + MakeUnique(MakeUnique( + context, SpvOpLabel, 0, message_.new_function_region_entry_block(), + opt::Instruction::OperandList())); + outlined_region_entry_block->SetParent(outlined_function); + if (&original_region_entry_block == &original_region_exit_block) { + outlined_region_exit_block = outlined_region_entry_block.get(); + } + + for (auto& inst : original_region_entry_block) { + outlined_region_entry_block->AddInstruction( + std::unique_ptr(inst.Clone(context))); + } + outlined_function->AddBasicBlock(std::move(outlined_region_entry_block)); + + // We now go through the single-entry single-exit region defined by the entry + // and exit blocks, adding clones of all blocks to the new function. + + // Consider every block in the enclosing function. + auto enclosing_function = original_region_entry_block.GetParent(); + for (auto block_it = enclosing_function->begin(); + block_it != enclosing_function->end();) { + // Skip the region's entry block - we already dealt with it above. + if (region_blocks.count(&*block_it) == 0 || + &*block_it == &original_region_entry_block) { + ++block_it; + continue; + } + // Clone the block so that it can be added to the new function. + auto cloned_block = + std::unique_ptr(block_it->Clone(context)); + + // If this is the region's exit block, then the cloned block is the outlined + // region's exit block. + if (&*block_it == &original_region_exit_block) { + assert(outlined_region_exit_block == nullptr && + "We should not yet have encountered the exit block."); + outlined_region_exit_block = cloned_block.get(); + } + + cloned_block->SetParent(outlined_function); + + // Redirect any OpPhi operands whose predecessors are the original region + // entry block to become the new function entry block. + cloned_block->ForEachPhiInst([this](opt::Instruction* phi_inst) { + for (uint32_t predecessor_index = 1; + predecessor_index < phi_inst->NumInOperands(); + predecessor_index += 2) { + if (phi_inst->GetSingleWordInOperand(predecessor_index) == + message_.entry_block()) { + phi_inst->SetInOperand(predecessor_index, + {message_.new_function_region_entry_block()}); + } + } + }); + + outlined_function->AddBasicBlock(std::move(cloned_block)); + block_it = block_it.Erase(); + } + assert(outlined_region_exit_block != nullptr && + "We should have encountered the region's exit block when iterating " + "through the function"); + + // We now need to adapt the exit block for the region - in the new function - + // so that it ends with a return. + + // We first eliminate the merge instruction (if any) and the terminator for + // the cloned exit block. + for (auto inst_it = outlined_region_exit_block->begin(); + inst_it != outlined_region_exit_block->end();) { + if (inst_it->opcode() == SpvOpLoopMerge || + inst_it->opcode() == SpvOpSelectionMerge) { + inst_it = inst_it.Erase(); + } else if (inst_it->IsBlockTerminator()) { + inst_it = inst_it.Erase(); + } else { + ++inst_it; + } + } + + // We now add either OpReturn or OpReturnValue as the cloned exit block's + // terminator. + if (region_output_ids.empty()) { + // The case where there are no region output ids is simple: we just add + // OpReturn. + outlined_region_exit_block->AddInstruction(MakeUnique( + context, SpvOpReturn, 0, 0, opt::Instruction::OperandList())); + } else { + // In the case where there are output ids, we add an OpCompositeConstruct + // instruction to pack all the output values into a struct, and then an + // OpReturnValue instruction to return this struct. + opt::Instruction::OperandList struct_member_operands; + for (uint32_t id : region_output_ids) { + struct_member_operands.push_back( + {SPV_OPERAND_TYPE_ID, {output_id_to_fresh_id_map.at(id)}}); + } + outlined_region_exit_block->AddInstruction(MakeUnique( + context, SpvOpCompositeConstruct, + message_.new_function_struct_return_type_id(), + message_.new_callee_result_id(), struct_member_operands)); + outlined_region_exit_block->AddInstruction(MakeUnique( + context, SpvOpReturnValue, 0, 0, + opt::Instruction::OperandList( + {{SPV_OPERAND_TYPE_ID, {message_.new_callee_result_id()}}}))); + } + + outlined_function->SetFunctionEnd(MakeUnique( + context, SpvOpFunctionEnd, 0, 0, opt::Instruction::OperandList())); +} + +void TransformationOutlineFunction::ShrinkOriginalRegion( + opt::IRContext* context, std::set& region_blocks, + const std::vector& region_input_ids, + const std::vector& region_output_ids, + const std::map& output_id_to_type_id, + uint32_t return_type_id, + std::unique_ptr cloned_exit_block_merge, + std::unique_ptr cloned_exit_block_terminator, + opt::BasicBlock* original_region_entry_block) const { + // Erase all blocks from the original function that are in the outlined + // region, except for the region's entry block. + // + // In the process, identify all references to the exit block of the region, + // as merge blocks, continue targets, or OpPhi predecessors, and rewrite them + // to refer to the region entry block (the single block to which we are + // shrinking the region). + auto enclosing_function = original_region_entry_block->GetParent(); + for (auto block_it = enclosing_function->begin(); + block_it != enclosing_function->end();) { + if (&*block_it == original_region_entry_block) { + ++block_it; + } else if (region_blocks.count(&*block_it) == 0) { + // The block is not in the region. Check whether it has the last block + // of the region as an OpPhi predecessor, and if so change the + // predecessor to be the first block of the region (i.e. the block + // containing the call to what was outlined). + assert(block_it->MergeBlockIdIfAny() != message_.exit_block() && + "Outlined region must not end with a merge block"); + assert(block_it->ContinueBlockIdIfAny() != message_.exit_block() && + "Outlined region must not end with a continue target"); + block_it->ForEachPhiInst([this](opt::Instruction* phi_inst) { + for (uint32_t predecessor_index = 1; + predecessor_index < phi_inst->NumInOperands(); + predecessor_index += 2) { + if (phi_inst->GetSingleWordInOperand(predecessor_index) == + message_.exit_block()) { + phi_inst->SetInOperand(predecessor_index, {message_.entry_block()}); + } + } + }); + ++block_it; + } else { + // The block is in the region and is not the region's entry block: kill + // it. + block_it = block_it.Erase(); + } + } + + // Now erase all instructions from the region's entry block, as they have + // been outlined. + for (auto inst_it = original_region_entry_block->begin(); + inst_it != original_region_entry_block->end();) { + inst_it = inst_it.Erase(); + } + + // Now we add a call to the outlined function to the region's entry block. + opt::Instruction::OperandList function_call_operands; + function_call_operands.push_back( + {SPV_OPERAND_TYPE_ID, {message_.new_function_id()}}); + // The function parameters are the region input ids. + for (auto input_id : region_input_ids) { + function_call_operands.push_back({SPV_OPERAND_TYPE_ID, {input_id}}); + } + + original_region_entry_block->AddInstruction(MakeUnique( + context, SpvOpFunctionCall, return_type_id, + message_.new_caller_result_id(), function_call_operands)); + + // If there are output ids, the function call will return a struct. For each + // output id, we add an extract operation to pull the appropriate struct + // member out into an output id. + for (uint32_t index = 0; index < region_output_ids.size(); ++index) { + uint32_t output_id = region_output_ids[index]; + original_region_entry_block->AddInstruction(MakeUnique( + context, SpvOpCompositeExtract, output_id_to_type_id.at(output_id), + output_id, + opt::Instruction::OperandList( + {{SPV_OPERAND_TYPE_ID, {message_.new_caller_result_id()}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}))); + } + + // Finally, we terminate the block with the merge instruction (if any) that + // used to belong to the region's exit block, and the terminator that used + // to belong to the region's exit block. + if (cloned_exit_block_merge != nullptr) { + original_region_entry_block->AddInstruction( + std::move(cloned_exit_block_merge)); + } + original_region_entry_block->AddInstruction( + std::move(cloned_exit_block_terminator)); +} + +} // namespace fuzz +} // namespace spvtools diff --git a/source/fuzz/transformation_outline_function.h b/source/fuzz/transformation_outline_function.h new file mode 100644 index 000000000..784499d28 --- /dev/null +++ b/source/fuzz/transformation_outline_function.h @@ -0,0 +1,222 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_FUZZ_TRANSFORMATION_OUTLINE_FUNCTION_H_ +#define SOURCE_FUZZ_TRANSFORMATION_OUTLINE_FUNCTION_H_ + +#include +#include +#include + +#include "source/fuzz/fact_manager.h" +#include "source/fuzz/protobufs/spirvfuzz_protobufs.h" +#include "source/fuzz/transformation.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace fuzz { + +class TransformationOutlineFunction : public Transformation { + public: + explicit TransformationOutlineFunction( + const protobufs::TransformationOutlineFunction& message); + + TransformationOutlineFunction( + uint32_t entry_block, uint32_t exit_block, + uint32_t new_function_struct_return_type_id, + uint32_t new_function_type_id, uint32_t new_function_id, + uint32_t new_function_region_entry_block, uint32_t new_caller_result_id, + uint32_t new_callee_result_id, + std::map&& input_id_to_fresh_id, + std::map&& output_id_to_fresh_id); + + // - All the fresh ids occurring in the transformation must be distinct and + // fresh + // - |message_.entry_block| and |message_.exit_block| must form a single-entry + // single-exit control flow graph region + // - |message_.entry_block| must not start with OpVariable + // - |message_.entry_block| must not be a loop header + // - |message_.exit_block| must not be a merge block or the continue target + // of a loop + // - A structured control flow construct must lie either completely within the + // region or completely outside it + // - |message.entry_block| must not start with OpPhi; this is to keep the + // transformation simple - another transformation should be used to split + // a desired entry block that starts with OpPhi if needed + // - |message_.input_id_to_fresh_id| must contain an entry for every id + // defined outside the region but used in the region + // - |message_.output_id_to_fresh_id| must contain an entry for every id + // defined in the region but used outside the region + bool IsApplicable(opt::IRContext* context, + const FactManager& fact_manager) const override; + + // - A new function with id |message_.new_function_id| is added to the module. + // - If the region generates output ids, the return type of this function is + // a new struct type with one field per output id, and with type id + // |message_.new_function_struct_return_type|, otherwise the function return + // types is void and |message_.new_function_struct_return_type| is not used. + // - If the region generates input ids, the new function has one parameter per + // input id. Fresh ids for these parameters are provided by + // |message_.input_id_to_fresh_id|. + // - Unless the type required for the new function is already known, + // |message_.new_function_type_id| is used as the type id for a new function + // type, and the new function uses this type. + // - The new function starts with a dummy block with id + // |message_.new_function_first_block|, which jumps straight to a successor + // block, to avoid violating rules on what the first block in a function may + // look like. + // - The outlined region is replaced with a single block, with the same id + // as |message_.entry_block|, and which calls the new function, passing the + // region's input ids as parameters. The result is stored in + // |message_.new_caller_result_id|, which has type + // |message_.new_function_struct_return_type| (unless there are + // no output ids, in which case the return type is void). The components + // of this returned struct are then copied out into the region's output ids. + // The block ends with the merge instruction (if any) and terminator of + // |message_.exit_block|. + // - The body of the new function is identical to the outlined region, except + // that (a) the region's entry block has id + // |message_.new_function_region_entry_block|, (b) input id uses are + // replaced with parameter accesses, (c) and definitions of output ids are + // replaced with definitions of corresponding fresh ids provided by + // |message_.output_id_to_fresh_id|, and (d) the block of the function + // ends by returning a composite of type + // |message_.new_function_struct_return_type| comprised of all the fresh + // output ids (unless the return type is void, in which case no value is + // returned. + void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + + protobufs::Transformation ToMessage() const override; + + // Returns the set of blocks dominated by |entry_block| and post-dominated + // by |exit_block|. + static std::set GetRegionBlocks( + opt::IRContext* context, opt::BasicBlock* entry_block, + opt::BasicBlock* exit_block); + + // Yields ids that are used in |region_set| and that are either parameters + // to the function containing |region_set|, or are defined by blocks of this + // function that are outside |region_set|. + // + // Special cases: OpPhi instructions in |region_entry_block| and the + // terminator of |region_exit_block| do not get outlined, therefore + // - id uses in OpPhi instructions in |region_entry_block| are ignored + // - id uses in the terminator instruction of |region_exit_block| are ignored + static std::vector GetRegionInputIds( + opt::IRContext* context, const std::set& region_set, + opt::BasicBlock* region_exit_block); + + // Yields all ids that are defined in |region_set| and used outside + // |region_set|. + // + // Special cases: for similar reasons as for |GetRegionInputIds|, + // - ids defined in the region and used in the terminator of + // |region_exit_block| count as output ids + static std::vector GetRegionOutputIds( + opt::IRContext* context, const std::set& region_set, + opt::BasicBlock* region_exit_block); + + private: + // A helper method for the applicability check. Returns true if and only if + // |id| is (a) a fresh id for the module, and (b) an id that has not + // previously been subject to this check. We use this to check whether the + // ids given for the transformation are not only fresh but also different from + // one another. + bool CheckIdIsFreshAndNotUsedByThisTransformation( + uint32_t id, opt::IRContext* context, + std::set* ids_used_by_this_transformation) const; + + // Ensures that the module's id bound is at least the maximum of any fresh id + // associated with the transformation. + void UpdateModuleIdBoundForFreshIds( + opt::IRContext* context, + const std::map& input_id_to_fresh_id_map, + const std::map& output_id_to_fresh_id_map) const; + + // Uses |input_id_to_fresh_id_map| and |output_id_to_fresh_id_map| to convert, + // in the region to be outlined, all the input ids in |region_input_ids| and + // the output ids in |region_output_ids| to their fresh counterparts. + // Parameters |region_blocks| provides access to the blocks that must be + // modified, and |original_region_exit_block| allows for some special cases + // where ids should not be remapped. + void RemapInputAndOutputIdsInRegion( + opt::IRContext* context, + const opt::BasicBlock& original_region_exit_block, + const std::set& region_blocks, + const std::vector& region_input_ids, + const std::vector& region_output_ids, + const std::map& input_id_to_fresh_id_map, + const std::map& output_id_to_fresh_id_map) const; + + // Produce a Function object that has the right function type and parameter + // declarations. The function argument types and parameter ids are dictated + // by |region_input_ids| and |input_id_to_fresh_id_map|. The function return + // type is dictated by |region_output_ids|. + // + // A new struct type to represent the function return type, and a new function + // type for the function, will be added to the module (unless suitable types + // are already present). + std::unique_ptr PrepareFunctionPrototype( + opt::IRContext* context, const std::vector& region_input_ids, + const std::vector& region_output_ids, + const std::map& input_id_to_fresh_id_map) const; + + // Creates the body of the outlined function by cloning blocks from the + // original region, given by |region_blocks|, adapting the cloned version + // of |original_region_exit_block| so that it returns something appropriate, + // and patching up branches to |original_region_entry_block| to refer to its + // clone. Parameters |region_output_ids| and |output_id_to_fresh_id_map| are + // used to determine what the function should return. + void PopulateOutlinedFunction( + opt::IRContext* context, + const opt::BasicBlock& original_region_entry_block, + const opt::BasicBlock& original_region_exit_block, + const std::set& region_blocks, + const std::vector& region_output_ids, + const std::map& output_id_to_fresh_id_map, + opt::Function* outlined_function) const; + + // Shrinks the outlined region, given by |region_blocks|, down to the single + // block |original_region_entry_block|. This block is itself shrunk to just + // contain: + // - any OpPhi instructions that were originally present + // - a call to the outlined function, with parameters provided by + // |region_input_ids| + // - instructions to route components of the call's return value into + // |region_output_ids| + // - The merge instruction (if any) and terminator of the original region's + // exit block, given by |cloned_exit_block_merge| and + // |cloned_exit_block_terminator| + // Parameters |output_id_to_type_id| and |return_type_id| provide the + // provide types for the region's output ids, and the return type of the + // outlined function: as the module is in an inconsistent state when this + // function is called, this information cannot be gotten from the def-use + // manager. + void ShrinkOriginalRegion( + opt::IRContext* context, std::set& region_blocks, + const std::vector& region_input_ids, + const std::vector& region_output_ids, + const std::map& output_id_to_type_id, + uint32_t return_type_id, + std::unique_ptr cloned_exit_block_merge, + std::unique_ptr cloned_exit_block_terminator, + opt::BasicBlock* original_region_entry_block) const; + + protobufs::TransformationOutlineFunction message_; +}; + +} // namespace fuzz +} // namespace spvtools + +#endif // SOURCE_FUZZ_TRANSFORMATION_OUTLINE_FUNCTION_H_ diff --git a/test/fuzz/CMakeLists.txt b/test/fuzz/CMakeLists.txt index b38f35e9d..a36027b59 100644 --- a/test/fuzz/CMakeLists.txt +++ b/test/fuzz/CMakeLists.txt @@ -36,6 +36,7 @@ if (${SPIRV_BUILD_FUZZER}) transformation_composite_extract_test.cpp transformation_copy_object_test.cpp transformation_move_block_down_test.cpp + transformation_outline_function_test.cpp transformation_replace_boolean_constant_with_constant_binary_test.cpp transformation_replace_constant_with_uniform_test.cpp transformation_replace_id_with_synonym_test.cpp diff --git a/test/fuzz/transformation_outline_function_test.cpp b/test/fuzz/transformation_outline_function_test.cpp new file mode 100644 index 000000000..de82ebe25 --- /dev/null +++ b/test/fuzz/transformation_outline_function_test.cpp @@ -0,0 +1,1989 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/fuzz/transformation_outline_function.h" +#include "test/fuzz/fuzz_test_util.h" + +namespace spvtools { +namespace fuzz { +namespace { + +TEST(TransformationOutlineFunctionTest, TrivialOutline) { + // This tests outlining of a single, empty basic block. + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(5, 5, /* not relevant */ 200, + 100, 101, 102, 103, + /* not relevant */ 201, {}, {}); + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %103 = OpFunctionCall %2 %101 + OpReturn + OpFunctionEnd + %101 = OpFunction %2 None %3 + %102 = OpLabel + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +TEST(TransformationOutlineFunctionTest, + DoNotOutlineIfRegionStartsWithOpVariable) { + // This checks that we do not outline the first block of a function if it + // contains OpVariable. + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %7 = OpTypeBool + %8 = OpTypePointer Function %7 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %6 = OpVariable %8 Function + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(5, 5, /* not relevant */ 200, + 100, 101, 102, 103, + /* not relevant */ 201, {}, {}); + ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); +} + +TEST(TransformationOutlineFunctionTest, OutlineInterestingControlFlowNoState) { + // This tests outlining of some non-trivial control flow, but such that the + // basic blocks in the control flow do not actually do anything. + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %20 = OpTypeBool + %21 = OpConstantTrue %20 + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpBranch %7 + %7 = OpLabel + OpSelectionMerge %9 None + OpBranchConditional %21 %8 %9 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpLoopMerge %12 %11 None + OpBranch %10 + %10 = OpLabel + OpBranchConditional %21 %11 %12 + %11 = OpLabel + OpBranch %9 + %12 = OpLabel + OpBranch %13 + %13 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(6, 13, /* not relevant */ + 200, 100, 101, 102, 103, + /* not relevant */ 201, {}, {}); + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %20 = OpTypeBool + %21 = OpConstantTrue %20 + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + %103 = OpFunctionCall %2 %101 + OpReturn + OpFunctionEnd + %101 = OpFunction %2 None %3 + %102 = OpLabel + OpBranch %7 + %7 = OpLabel + OpSelectionMerge %9 None + OpBranchConditional %21 %8 %9 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpLoopMerge %12 %11 None + OpBranch %10 + %10 = OpLabel + OpBranchConditional %21 %11 %12 + %11 = OpLabel + OpBranch %9 + %12 = OpLabel + OpBranch %13 + %13 = OpLabel + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +TEST(TransformationOutlineFunctionTest, OutlineCodeThatGeneratesUnusedIds) { + // This tests outlining of a single basic block that does some computation, + // but that does not use nor generate ids required outside of the outlined + // region. + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %20 = OpTypeInt 32 1 + %21 = OpConstant %20 5 + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + %7 = OpCopyObject %20 %21 + %8 = OpCopyObject %20 %21 + %9 = OpIAdd %20 %7 %8 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(6, 6, /* not relevant */ 200, + 100, 101, 102, 103, + /* not relevant */ 201, {}, {}); + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %20 = OpTypeInt 32 1 + %21 = OpConstant %20 5 + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + %103 = OpFunctionCall %2 %101 + OpReturn + OpFunctionEnd + %101 = OpFunction %2 None %3 + %102 = OpLabel + %7 = OpCopyObject %20 %21 + %8 = OpCopyObject %20 %21 + %9 = OpIAdd %20 %7 %8 + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +TEST(TransformationOutlineFunctionTest, OutlineCodeThatGeneratesSingleUsedId) { + // This tests outlining of a block that generates an id that is used in a + // later block. + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %20 = OpTypeInt 32 1 + %21 = OpConstant %20 5 + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + %7 = OpCopyObject %20 %21 + %8 = OpCopyObject %20 %21 + %9 = OpIAdd %20 %7 %8 + OpBranch %10 + %10 = OpLabel + %11 = OpCopyObject %20 %9 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(6, 6, 99, 100, 101, 102, 103, + 105, {}, {{9, 104}}); + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %20 = OpTypeInt 32 1 + %21 = OpConstant %20 5 + %3 = OpTypeFunction %2 + %99 = OpTypeStruct %20 + %100 = OpTypeFunction %99 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + %103 = OpFunctionCall %99 %101 + %9 = OpCompositeExtract %20 %103 0 + OpBranch %10 + %10 = OpLabel + %11 = OpCopyObject %20 %9 + OpReturn + OpFunctionEnd + %101 = OpFunction %99 None %100 + %102 = OpLabel + %7 = OpCopyObject %20 %21 + %8 = OpCopyObject %20 %21 + %104 = OpIAdd %20 %7 %8 + %105 = OpCompositeConstruct %99 %104 + OpReturnValue %105 + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +TEST(TransformationOutlineFunctionTest, OutlineDiamondThatGeneratesSeveralIds) { + // This tests outlining of several blocks that generate a number of ids that + // are used in later blocks. + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %20 = OpTypeInt 32 1 + %21 = OpConstant %20 5 + %22 = OpTypeBool + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + %7 = OpCopyObject %20 %21 + %8 = OpCopyObject %20 %21 + %9 = OpSLessThan %22 %7 %8 + OpSelectionMerge %12 None + OpBranchConditional %9 %10 %11 + %10 = OpLabel + %13 = OpIAdd %20 %7 %8 + OpBranch %12 + %11 = OpLabel + %14 = OpIAdd %20 %7 %7 + OpBranch %12 + %12 = OpLabel + %15 = OpPhi %20 %13 %10 %14 %11 + OpBranch %80 + %80 = OpLabel + OpBranch %16 + %16 = OpLabel + %17 = OpCopyObject %20 %15 + %18 = OpCopyObject %22 %9 + %19 = OpIAdd %20 %7 %8 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation( + 6, 80, 100, 101, 102, 103, 104, 105, {}, + {{15, 106}, {9, 107}, {7, 108}, {8, 109}}); + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %20 = OpTypeInt 32 1 + %21 = OpConstant %20 5 + %22 = OpTypeBool + %3 = OpTypeFunction %2 + %100 = OpTypeStruct %20 %20 %22 %20 + %101 = OpTypeFunction %100 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + %104 = OpFunctionCall %100 %102 + %7 = OpCompositeExtract %20 %104 0 + %8 = OpCompositeExtract %20 %104 1 + %9 = OpCompositeExtract %22 %104 2 + %15 = OpCompositeExtract %20 %104 3 + OpBranch %16 + %16 = OpLabel + %17 = OpCopyObject %20 %15 + %18 = OpCopyObject %22 %9 + %19 = OpIAdd %20 %7 %8 + OpReturn + OpFunctionEnd + %102 = OpFunction %100 None %101 + %103 = OpLabel + %108 = OpCopyObject %20 %21 + %109 = OpCopyObject %20 %21 + %107 = OpSLessThan %22 %108 %109 + OpSelectionMerge %12 None + OpBranchConditional %107 %10 %11 + %10 = OpLabel + %13 = OpIAdd %20 %108 %109 + OpBranch %12 + %11 = OpLabel + %14 = OpIAdd %20 %108 %108 + OpBranch %12 + %12 = OpLabel + %106 = OpPhi %20 %13 %10 %14 %11 + OpBranch %80 + %80 = OpLabel + %105 = OpCompositeConstruct %100 %108 %109 %107 %106 + OpReturnValue %105 + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +TEST(TransformationOutlineFunctionTest, OutlineCodeThatUsesASingleId) { + // This tests outlining of a block that uses an id defined earlier. + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %20 = OpTypeInt 32 1 + %21 = OpConstant %20 5 + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %7 = OpCopyObject %20 %21 + OpBranch %6 + %6 = OpLabel + %8 = OpCopyObject %20 %7 + OpBranch %10 + %10 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(6, 6, 100, 101, 102, 103, 104, + 105, {{7, 106}}, {}); + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %20 = OpTypeInt 32 1 + %21 = OpConstant %20 5 + %3 = OpTypeFunction %2 + %101 = OpTypeFunction %2 %20 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %7 = OpCopyObject %20 %21 + OpBranch %6 + %6 = OpLabel + %104 = OpFunctionCall %2 %102 %7 + OpBranch %10 + %10 = OpLabel + OpReturn + OpFunctionEnd + %102 = OpFunction %2 None %101 + %106 = OpFunctionParameter %20 + %103 = OpLabel + %8 = OpCopyObject %20 %106 + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +TEST(TransformationOutlineFunctionTest, OutlineCodeThatUsesAVariable) { + // This tests outlining of a block that uses a variable. + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %20 = OpTypeInt 32 1 + %21 = OpConstant %20 5 + %3 = OpTypeFunction %2 + %12 = OpTypePointer Function %20 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %13 = OpVariable %12 Function + OpBranch %6 + %6 = OpLabel + %8 = OpLoad %20 %13 + OpBranch %10 + %10 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(6, 6, 100, 101, 102, 103, 104, + 105, {{13, 106}}, {}); + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %20 = OpTypeInt 32 1 + %21 = OpConstant %20 5 + %3 = OpTypeFunction %2 + %12 = OpTypePointer Function %20 + %101 = OpTypeFunction %2 %12 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %13 = OpVariable %12 Function + OpBranch %6 + %6 = OpLabel + %104 = OpFunctionCall %2 %102 %13 + OpBranch %10 + %10 = OpLabel + OpReturn + OpFunctionEnd + %102 = OpFunction %2 None %101 + %106 = OpFunctionParameter %12 + %103 = OpLabel + %8 = OpLoad %20 %106 + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +TEST(TransformationOutlineFunctionTest, OutlineCodeThatUsesAParameter) { + // This tests outlining of a block that uses a function parameter. + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %10 "foo(i1;" + OpName %9 "x" + OpName %18 "param" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFunction %6 %7 + %13 = OpConstant %6 1 + %17 = OpConstant %6 3 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %18 = OpVariable %7 Function + OpStore %18 %17 + %19 = OpFunctionCall %6 %10 %18 + OpReturn + OpFunctionEnd + %10 = OpFunction %6 None %8 + %9 = OpFunctionParameter %7 + %11 = OpLabel + %12 = OpLoad %6 %9 + %14 = OpIAdd %6 %12 %13 + OpReturnValue %14 + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(11, 11, 100, 101, 102, 103, 104, + 105, {{9, 106}}, {{14, 107}}); + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %10 "foo(i1;" + OpName %9 "x" + OpName %18 "param" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFunction %6 %7 + %13 = OpConstant %6 1 + %17 = OpConstant %6 3 + %100 = OpTypeStruct %6 + %101 = OpTypeFunction %100 %7 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %18 = OpVariable %7 Function + OpStore %18 %17 + %19 = OpFunctionCall %6 %10 %18 + OpReturn + OpFunctionEnd + %10 = OpFunction %6 None %8 + %9 = OpFunctionParameter %7 + %11 = OpLabel + %104 = OpFunctionCall %100 %102 %9 + %14 = OpCompositeExtract %6 %104 0 + OpReturnValue %14 + OpFunctionEnd + %102 = OpFunction %100 None %101 + %106 = OpFunctionParameter %7 + %103 = OpLabel + %12 = OpLoad %6 %106 + %107 = OpIAdd %6 %12 %13 + %105 = OpCompositeConstruct %100 %107 + OpReturnValue %105 + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +TEST(TransformationOutlineFunctionTest, + DoNotOutlineIfLoopMergeIsOutsideRegion) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %9 = OpTypeBool + %10 = OpConstantTrue %9 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpLoopMerge %7 %8 None + OpBranch %8 + %8 = OpLabel + OpBranchConditional %10 %6 %7 + %7 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(6, 8, 100, 101, 102, 103, 104, + 105, {}, {}); + ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); +} + +TEST(TransformationOutlineFunctionTest, DoNotOutlineIfRegionInvolvesReturn) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %20 = OpTypeBool + %21 = OpConstantTrue %20 + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpBranch %7 + %7 = OpLabel + OpSelectionMerge %10 None + OpBranchConditional %21 %8 %9 + %8 = OpLabel + OpReturn + %9 = OpLabel + OpBranch %10 + %10 = OpLabel + OpBranch %11 + %11 = OpLabel + OpBranch %12 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(6, 11, /* not relevant */ 200, + 100, 101, 102, 103, + /* not relevant */ 201, {}, {}); + ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); +} + +TEST(TransformationOutlineFunctionTest, DoNotOutlineIfRegionInvolvesKill) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %20 = OpTypeBool + %21 = OpConstantTrue %20 + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpBranch %7 + %7 = OpLabel + OpSelectionMerge %10 None + OpBranchConditional %21 %8 %9 + %8 = OpLabel + OpKill + %9 = OpLabel + OpBranch %10 + %10 = OpLabel + OpBranch %11 + %11 = OpLabel + OpBranch %12 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(6, 11, /* not relevant */ 200, + 100, 101, 102, 103, + /* not relevant */ 201, {}, {}); + ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); +} + +TEST(TransformationOutlineFunctionTest, + DoNotOutlineIfRegionInvolvesUnreachable) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %20 = OpTypeBool + %21 = OpConstantTrue %20 + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpBranch %7 + %7 = OpLabel + OpSelectionMerge %10 None + OpBranchConditional %21 %8 %9 + %8 = OpLabel + OpBranch %10 + %9 = OpLabel + OpUnreachable + %10 = OpLabel + OpBranch %11 + %11 = OpLabel + OpBranch %12 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(6, 11, /* not relevant */ 200, + 100, 101, 102, 103, + /* not relevant */ 201, {}, {}); + ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); +} + +TEST(TransformationOutlineFunctionTest, + DoNotOutlineIfSelectionMergeIsOutsideRegion) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %9 = OpTypeBool + %10 = OpConstantTrue %9 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpSelectionMerge %7 None + OpBranchConditional %10 %8 %7 + %8 = OpLabel + OpBranch %7 + %7 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(6, 8, 100, 101, 102, 103, 104, + 105, {}, {}); + ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); +} + +TEST(TransformationOutlineFunctionTest, DoNotOutlineIfLoopHeadIsOutsideRegion) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %9 = OpTypeBool + %10 = OpConstantTrue %9 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpLoopMerge %8 %11 None + OpBranch %7 + %7 = OpLabel + OpBranchConditional %10 %11 %8 + %11 = OpLabel + OpBranch %6 + %8 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(7, 8, 100, 101, 102, 103, 104, + 105, {}, {}); + ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); +} + +TEST(TransformationOutlineFunctionTest, + DoNotOutlineIfLoopContinueIsOutsideRegion) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %9 = OpTypeBool + %10 = OpConstantTrue %9 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpLoopMerge %7 %8 None + OpBranch %7 + %8 = OpLabel + OpBranch %6 + %7 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(6, 7, 100, 101, 102, 103, 104, + 105, {}, {}); + ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); +} + +TEST(TransformationOutlineFunctionTest, + DoNotOutlineWithLoopCarriedPhiDependence) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %9 = OpTypeBool + %10 = OpConstantTrue %9 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + %12 = OpPhi %9 %10 %5 %13 %8 + OpLoopMerge %7 %8 None + OpBranch %8 + %8 = OpLabel + %13 = OpCopyObject %9 %10 + OpBranchConditional %10 %6 %7 + %7 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(6, 7, 100, 101, 102, 103, 104, + 105, {}, {}); + ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); +} + +TEST(TransformationOutlineFunctionTest, + DoNotOutlineSelectionHeaderNotInRegion) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %7 = OpConstantTrue %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpSelectionMerge %10 None + OpBranchConditional %7 %8 %8 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpBranch %10 + %10 = OpLabel + OpBranch %11 + %11 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation(8, 11, 100, 101, 102, 103, 104, + 105, {}, {}); + ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); +} + +TEST(TransformationOutlineFunctionTest, OutlineRegionEndingWithReturnVoid) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 1 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %22 = OpCopyObject %20 %21 + OpBranch %54 + %54 = OpLabel + OpBranch %57 + %57 = OpLabel + %23 = OpCopyObject %20 %22 + OpBranch %58 + %58 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation( + /*entry_block*/ 54, + /*exit_block*/ 58, + /*new_function_struct_return_type_id*/ 200, + /*new_function_type_id*/ 201, + /*new_function_id*/ 202, + /*new_function_region_entry_block*/ 203, + /*new_caller_result_id*/ 204, + /*new_callee_result_id*/ 205, + /*input_id_to_fresh_id*/ {{22, 206}}, + /*output_id_to_fresh_id*/ {}); + + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 1 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %201 = OpTypeFunction %2 %20 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %22 = OpCopyObject %20 %21 + OpBranch %54 + %54 = OpLabel + %204 = OpFunctionCall %2 %202 %22 + OpReturn + OpFunctionEnd + %202 = OpFunction %2 None %201 + %206 = OpFunctionParameter %20 + %203 = OpLabel + OpBranch %57 + %57 = OpLabel + %23 = OpCopyObject %20 %206 + OpBranch %58 + %58 = OpLabel + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +TEST(TransformationOutlineFunctionTest, OutlineRegionEndingWithReturnValue) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 1 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %30 = OpTypeFunction %20 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %6 = OpFunctionCall %20 %100 + OpReturn + OpFunctionEnd + %100 = OpFunction %20 None %30 + %8 = OpLabel + %31 = OpCopyObject %20 %21 + OpBranch %9 + %9 = OpLabel + %32 = OpCopyObject %20 %31 + OpBranch %10 + %10 = OpLabel + OpReturnValue %32 + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation( + /*entry_block*/ 9, + /*exit_block*/ 10, + /*new_function_struct_return_type_id*/ 200, + /*new_function_type_id*/ 201, + /*new_function_id*/ 202, + /*new_function_region_entry_block*/ 203, + /*new_caller_result_id*/ 204, + /*new_callee_result_id*/ 205, + /*input_id_to_fresh_id*/ {{31, 206}}, + /*output_id_to_fresh_id*/ {{32, 207}}); + + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 1 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %30 = OpTypeFunction %20 + %200 = OpTypeStruct %20 + %201 = OpTypeFunction %200 %20 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %6 = OpFunctionCall %20 %100 + OpReturn + OpFunctionEnd + %100 = OpFunction %20 None %30 + %8 = OpLabel + %31 = OpCopyObject %20 %21 + OpBranch %9 + %9 = OpLabel + %204 = OpFunctionCall %200 %202 %31 + %32 = OpCompositeExtract %20 %204 0 + OpReturnValue %32 + OpFunctionEnd + %202 = OpFunction %200 None %201 + %206 = OpFunctionParameter %20 + %203 = OpLabel + %207 = OpCopyObject %20 %206 + OpBranch %10 + %10 = OpLabel + %205 = OpCompositeConstruct %200 %207 + OpReturnValue %205 + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +TEST(TransformationOutlineFunctionTest, + OutlineRegionEndingWithConditionalBranch) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %20 = OpTypeBool + %21 = OpConstantTrue %20 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %54 + %54 = OpLabel + %6 = OpCopyObject %20 %21 + OpSelectionMerge %8 None + OpBranchConditional %6 %7 %8 + %7 = OpLabel + OpBranch %8 + %8 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation( + /*entry_block*/ 54, + /*exit_block*/ 54, + /*new_function_struct_return_type_id*/ 200, + /*new_function_type_id*/ 201, + /*new_function_id*/ 202, + /*new_function_region_entry_block*/ 203, + /*new_caller_result_id*/ 204, + /*new_callee_result_id*/ 205, + /*input_id_to_fresh_id*/ {{}}, + /*output_id_to_fresh_id*/ {{6, 206}}); + + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %20 = OpTypeBool + %21 = OpConstantTrue %20 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %200 = OpTypeStruct %20 + %201 = OpTypeFunction %200 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %54 + %54 = OpLabel + %204 = OpFunctionCall %200 %202 + %6 = OpCompositeExtract %20 %204 0 + OpSelectionMerge %8 None + OpBranchConditional %6 %7 %8 + %7 = OpLabel + OpBranch %8 + %8 = OpLabel + OpReturn + OpFunctionEnd + %202 = OpFunction %200 None %201 + %203 = OpLabel + %206 = OpCopyObject %20 %21 + %205 = OpCompositeConstruct %200 %206 + OpReturnValue %205 + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +TEST(TransformationOutlineFunctionTest, + OutlineRegionEndingWithConditionalBranch2) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %20 = OpTypeBool + %21 = OpConstantTrue %20 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %6 = OpCopyObject %20 %21 + OpBranch %54 + %54 = OpLabel + OpSelectionMerge %8 None + OpBranchConditional %6 %7 %8 + %7 = OpLabel + OpBranch %8 + %8 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation( + /*entry_block*/ 54, + /*exit_block*/ 54, + /*new_function_struct_return_type_id*/ 200, + /*new_function_type_id*/ 201, + /*new_function_id*/ 202, + /*new_function_region_entry_block*/ 203, + /*new_caller_result_id*/ 204, + /*new_callee_result_id*/ 205, + /*input_id_to_fresh_id*/ {}, + /*output_id_to_fresh_id*/ {}); + + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %20 = OpTypeBool + %21 = OpConstantTrue %20 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %6 = OpCopyObject %20 %21 + OpBranch %54 + %54 = OpLabel + %204 = OpFunctionCall %2 %202 + OpSelectionMerge %8 None + OpBranchConditional %6 %7 %8 + %7 = OpLabel + OpBranch %8 + %8 = OpLabel + OpReturn + OpFunctionEnd + %202 = OpFunction %2 None %3 + %203 = OpLabel + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +TEST(TransformationOutlineFunctionTest, DoNotOutlineRegionThatStartsWithOpPhi) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %7 = OpConstantTrue %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %21 + %21 = OpLabel + %22 = OpPhi %6 %7 %5 + %23 = OpCopyObject %6 %22 + OpBranch %24 + %24 = OpLabel + %25 = OpCopyObject %6 %23 + %26 = OpCopyObject %6 %22 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation( + /*entry_block*/ 21, + /*exit_block*/ 21, + /*new_function_struct_return_type_id*/ 200, + /*new_function_type_id*/ 201, + /*new_function_id*/ 202, + /*new_function_region_entry_block*/ 204, + /*new_caller_result_id*/ 205, + /*new_callee_result_id*/ 206, + /*input_id_to_fresh_id*/ {{22, 207}}, + /*output_id_to_fresh_id*/ {{23, 208}}); + + ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); +} + +TEST(TransformationOutlineFunctionTest, + DoNotOutlineRegionThatStartsWithLoopHeader) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %7 = OpConstantTrue %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %21 + %21 = OpLabel + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + OpBranchConditional %7 %22 %23 + %23 = OpLabel + OpBranch %21 + %22 = OpLabel + OpBranch %25 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation( + /*entry_block*/ 21, + /*exit_block*/ 24, + /*new_function_struct_return_type_id*/ 200, + /*new_function_type_id*/ 201, + /*new_function_id*/ 202, + /*new_function_region_entry_block*/ 204, + /*new_caller_result_id*/ 205, + /*new_callee_result_id*/ 206, + /*input_id_to_fresh_id*/ {}, + /*output_id_to_fresh_id*/ {}); + + ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); +} + +TEST(TransformationOutlineFunctionTest, + DoNotOutlineRegionThatEndsWithLoopMerge) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %7 = OpConstantTrue %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %21 + %21 = OpLabel + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + OpBranchConditional %7 %22 %23 + %23 = OpLabel + OpBranch %21 + %22 = OpLabel + OpBranch %25 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation( + /*entry_block*/ 5, + /*exit_block*/ 22, + /*new_function_struct_return_type_id*/ 200, + /*new_function_type_id*/ 201, + /*new_function_id*/ 202, + /*new_function_region_entry_block*/ 204, + /*new_caller_result_id*/ 205, + /*new_callee_result_id*/ 206, + /*input_id_to_fresh_id*/ {}, + /*output_id_to_fresh_id*/ {}); + + ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); +} + +TEST(TransformationOutlineFunctionTest, Miscellaneous1) { + // This tests outlining of some non-trivial code. + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %85 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %28 "buf" + OpMemberName %28 0 "u1" + OpMemberName %28 1 "u2" + OpName %30 "" + OpName %85 "color" + OpMemberDecorate %28 0 Offset 0 + OpMemberDecorate %28 1 Offset 4 + OpDecorate %28 Block + OpDecorate %30 DescriptorSet 0 + OpDecorate %30 Binding 0 + OpDecorate %85 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %10 = OpConstant %6 1 + %11 = OpConstant %6 2 + %12 = OpConstant %6 3 + %13 = OpConstant %6 4 + %14 = OpConstantComposite %7 %10 %11 %12 %13 + %15 = OpTypeInt 32 1 + %18 = OpConstant %15 0 + %28 = OpTypeStruct %6 %6 + %29 = OpTypePointer Uniform %28 + %30 = OpVariable %29 Uniform + %31 = OpTypePointer Uniform %6 + %35 = OpTypeBool + %39 = OpConstant %15 1 + %84 = OpTypePointer Output %7 + %85 = OpVariable %84 Output + %114 = OpConstant %15 8 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %22 + %22 = OpLabel + %103 = OpPhi %15 %18 %5 %106 %43 + %102 = OpPhi %7 %14 %5 %107 %43 + %101 = OpPhi %15 %18 %5 %40 %43 + %32 = OpAccessChain %31 %30 %18 + %33 = OpLoad %6 %32 + %34 = OpConvertFToS %15 %33 + %36 = OpSLessThan %35 %101 %34 + OpLoopMerge %24 %43 None + OpBranchConditional %36 %23 %24 + %23 = OpLabel + %40 = OpIAdd %15 %101 %39 + OpBranch %150 + %150 = OpLabel + OpBranch %41 + %41 = OpLabel + %107 = OpPhi %7 %102 %150 %111 %65 + %106 = OpPhi %15 %103 %150 %110 %65 + %104 = OpPhi %15 %40 %150 %81 %65 + %47 = OpAccessChain %31 %30 %39 + %48 = OpLoad %6 %47 + %49 = OpConvertFToS %15 %48 + %50 = OpSLessThan %35 %104 %49 + OpLoopMerge %1000 %65 None + OpBranchConditional %50 %42 %1000 + %42 = OpLabel + %60 = OpIAdd %15 %106 %114 + %63 = OpSGreaterThan %35 %104 %60 + OpBranchConditional %63 %64 %65 + %64 = OpLabel + %71 = OpCompositeExtract %6 %107 0 + %72 = OpFAdd %6 %71 %11 + %97 = OpCompositeInsert %7 %72 %107 0 + %76 = OpCompositeExtract %6 %107 3 + %77 = OpConvertFToS %15 %76 + %79 = OpIAdd %15 %60 %77 + OpBranch %65 + %65 = OpLabel + %111 = OpPhi %7 %107 %42 %97 %64 + %110 = OpPhi %15 %60 %42 %79 %64 + %81 = OpIAdd %15 %104 %39 + OpBranch %41 + %1000 = OpLabel + OpBranch %1001 + %1001 = OpLabel + OpBranch %43 + %43 = OpLabel + OpBranch %22 + %24 = OpLabel + %87 = OpCompositeExtract %6 %102 0 + %91 = OpConvertSToF %6 %103 + %92 = OpCompositeConstruct %7 %87 %11 %91 %10 + OpStore %85 %92 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation( + /*entry_block*/ 150, + /*exit_block*/ 1001, + /*new_function_struct_return_type_id*/ 200, + /*new_function_type_id*/ 201, + /*new_function_id*/ 202, + /*new_function_region_entry_block*/ 203, + /*new_caller_result_id*/ 204, + /*new_callee_result_id*/ 205, + /*input_id_to_fresh_id*/ {{102, 300}, {103, 301}, {40, 302}}, + /*output_id_to_fresh_id*/ {{106, 400}, {107, 401}}); + + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %85 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %28 "buf" + OpMemberName %28 0 "u1" + OpMemberName %28 1 "u2" + OpName %30 "" + OpName %85 "color" + OpMemberDecorate %28 0 Offset 0 + OpMemberDecorate %28 1 Offset 4 + OpDecorate %28 Block + OpDecorate %30 DescriptorSet 0 + OpDecorate %30 Binding 0 + OpDecorate %85 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %10 = OpConstant %6 1 + %11 = OpConstant %6 2 + %12 = OpConstant %6 3 + %13 = OpConstant %6 4 + %14 = OpConstantComposite %7 %10 %11 %12 %13 + %15 = OpTypeInt 32 1 + %18 = OpConstant %15 0 + %28 = OpTypeStruct %6 %6 + %29 = OpTypePointer Uniform %28 + %30 = OpVariable %29 Uniform + %31 = OpTypePointer Uniform %6 + %35 = OpTypeBool + %39 = OpConstant %15 1 + %84 = OpTypePointer Output %7 + %85 = OpVariable %84 Output + %114 = OpConstant %15 8 + %200 = OpTypeStruct %7 %15 + %201 = OpTypeFunction %200 %15 %7 %15 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %22 + %22 = OpLabel + %103 = OpPhi %15 %18 %5 %106 %43 + %102 = OpPhi %7 %14 %5 %107 %43 + %101 = OpPhi %15 %18 %5 %40 %43 + %32 = OpAccessChain %31 %30 %18 + %33 = OpLoad %6 %32 + %34 = OpConvertFToS %15 %33 + %36 = OpSLessThan %35 %101 %34 + OpLoopMerge %24 %43 None + OpBranchConditional %36 %23 %24 + %23 = OpLabel + %40 = OpIAdd %15 %101 %39 + OpBranch %150 + %150 = OpLabel + %204 = OpFunctionCall %200 %202 %103 %102 %40 + %107 = OpCompositeExtract %7 %204 0 + %106 = OpCompositeExtract %15 %204 1 + OpBranch %43 + %43 = OpLabel + OpBranch %22 + %24 = OpLabel + %87 = OpCompositeExtract %6 %102 0 + %91 = OpConvertSToF %6 %103 + %92 = OpCompositeConstruct %7 %87 %11 %91 %10 + OpStore %85 %92 + OpReturn + OpFunctionEnd + %202 = OpFunction %200 None %201 + %301 = OpFunctionParameter %15 + %300 = OpFunctionParameter %7 + %302 = OpFunctionParameter %15 + %203 = OpLabel + OpBranch %41 + %41 = OpLabel + %401 = OpPhi %7 %300 %203 %111 %65 + %400 = OpPhi %15 %301 %203 %110 %65 + %104 = OpPhi %15 %302 %203 %81 %65 + %47 = OpAccessChain %31 %30 %39 + %48 = OpLoad %6 %47 + %49 = OpConvertFToS %15 %48 + %50 = OpSLessThan %35 %104 %49 + OpLoopMerge %1000 %65 None + OpBranchConditional %50 %42 %1000 + %42 = OpLabel + %60 = OpIAdd %15 %400 %114 + %63 = OpSGreaterThan %35 %104 %60 + OpBranchConditional %63 %64 %65 + %64 = OpLabel + %71 = OpCompositeExtract %6 %401 0 + %72 = OpFAdd %6 %71 %11 + %97 = OpCompositeInsert %7 %72 %401 0 + %76 = OpCompositeExtract %6 %401 3 + %77 = OpConvertFToS %15 %76 + %79 = OpIAdd %15 %60 %77 + OpBranch %65 + %65 = OpLabel + %111 = OpPhi %7 %401 %42 %97 %64 + %110 = OpPhi %15 %60 %42 %79 %64 + %81 = OpIAdd %15 %104 %39 + OpBranch %41 + %1000 = OpLabel + OpBranch %1001 + %1001 = OpLabel + %205 = OpCompositeConstruct %200 %401 %400 + OpReturnValue %205 + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +TEST(TransformationOutlineFunctionTest, Miscellaneous2) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %21 = OpTypeBool + %167 = OpConstantTrue %21 + %168 = OpConstantFalse %21 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %34 + %34 = OpLabel + OpLoopMerge %36 %37 None + OpBranchConditional %168 %37 %38 + %38 = OpLabel + OpBranchConditional %168 %37 %36 + %37 = OpLabel + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation( + /*entry_block*/ 38, + /*exit_block*/ 36, + /*new_function_struct_return_type_id*/ 200, + /*new_function_type_id*/ 201, + /*new_function_id*/ 202, + /*new_function_region_entry_block*/ 203, + /*new_caller_result_id*/ 204, + /*new_callee_result_id*/ 205, + /*input_id_to_fresh_id*/ {}, + /*output_id_to_fresh_id*/ {}); + + ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); +} + +TEST(TransformationOutlineFunctionTest, Miscellaneous3) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %6 "main" + OpExecutionMode %6 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %21 = OpTypeBool + %167 = OpConstantTrue %21 + %6 = OpFunction %2 None %3 + %7 = OpLabel + OpBranch %80 + %80 = OpLabel + OpBranch %14 + %14 = OpLabel + OpLoopMerge %16 %17 None + OpBranch %18 + %18 = OpLabel + OpBranchConditional %167 %15 %16 + %15 = OpLabel + OpBranch %17 + %16 = OpLabel + OpBranch %81 + %81 = OpLabel + OpReturn + %17 = OpLabel + OpBranch %14 + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation( + /*entry_block*/ 80, + /*exit_block*/ 81, + /*new_function_struct_return_type_id*/ 300, + /*new_function_type_id*/ 301, + /*new_function_id*/ 302, + /*new_function_region_entry_block*/ 304, + /*new_caller_result_id*/ 305, + /*new_callee_result_id*/ 306, + /*input_id_to_fresh_id*/ {}, + /*output_id_to_fresh_id*/ {}); + + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %6 "main" + OpExecutionMode %6 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %21 = OpTypeBool + %167 = OpConstantTrue %21 + %6 = OpFunction %2 None %3 + %7 = OpLabel + OpBranch %80 + %80 = OpLabel + %305 = OpFunctionCall %2 %302 + OpReturn + OpFunctionEnd + %302 = OpFunction %2 None %3 + %304 = OpLabel + OpBranch %14 + %14 = OpLabel + OpLoopMerge %16 %17 None + OpBranch %18 + %18 = OpLabel + OpBranchConditional %167 %15 %16 + %15 = OpLabel + OpBranch %17 + %16 = OpLabel + OpBranch %81 + %81 = OpLabel + OpReturn + %17 = OpLabel + OpBranch %14 + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + +} // namespace +} // namespace fuzz +} // namespace spvtools