diff --git a/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp b/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp index 1bc0c2b4c..d8f50c348 100644 --- a/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp +++ b/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp @@ -72,6 +72,107 @@ void FuzzerPassFlattenConditionalBranches::Apply() { continue; } + uint32_t convergence_block_id = + TransformationFlattenConditionalBranch::FindConvergenceBlock( + GetIRContext(), *header); + + // If the SPIR-V version is restricted so that OpSelect can only work on + // scalar, pointer and vector types then we cannot apply this + // transformation to a header whose convergence block features OpPhi + // instructions on different types, as we cannot convert such instructions + // to OpSelect instructions. + if (TransformationFlattenConditionalBranch:: + OpSelectArgumentsAreRestricted(GetIRContext())) { + if (!GetIRContext() + ->cfg() + ->block(convergence_block_id) + ->WhileEachPhiInst( + [this](opt::Instruction* phi_instruction) -> bool { + switch (GetIRContext() + ->get_def_use_mgr() + ->GetDef(phi_instruction->type_id()) + ->opcode()) { + case SpvOpTypeBool: + case SpvOpTypeInt: + case SpvOpTypeFloat: + case SpvOpTypePointer: + case SpvOpTypeVector: + return true; + default: + return false; + } + })) { + // An OpPhi is performed on a type not supported by OpSelect; we + // cannot flatten this selection. + continue; + } + } + + // If the construct's convergence block features OpPhi instructions with + // vector result types then we may be *forced*, by the SPIR-V version, to + // turn these into component-wise OpSelect instructions, or we might wish + // to do so anyway. The following booleans capture whether we will opt + // to use a component-wise select even if we don't have to. + bool use_component_wise_2d_select_even_if_optional = + GetFuzzerContext()->ChooseEven(); + bool use_component_wise_3d_select_even_if_optional = + GetFuzzerContext()->ChooseEven(); + bool use_component_wise_4d_select_even_if_optional = + GetFuzzerContext()->ChooseEven(); + + // If we do need to perform any component-wise selections, we will need a + // fresh id for a boolean vector representing the selection's condition + // repeated N times, where N is the vector dimension. + uint32_t fresh_id_for_bvec2_selector = 0; + uint32_t fresh_id_for_bvec3_selector = 0; + uint32_t fresh_id_for_bvec4_selector = 0; + + GetIRContext() + ->cfg() + ->block(convergence_block_id) + ->ForEachPhiInst([this, &fresh_id_for_bvec2_selector, + &fresh_id_for_bvec3_selector, + &fresh_id_for_bvec4_selector, + use_component_wise_2d_select_even_if_optional, + use_component_wise_3d_select_even_if_optional, + use_component_wise_4d_select_even_if_optional]( + opt::Instruction* phi_instruction) { + opt::Instruction* type_instruction = + GetIRContext()->get_def_use_mgr()->GetDef( + phi_instruction->type_id()); + switch (type_instruction->opcode()) { + case SpvOpTypeVector: { + uint32_t dimension = + type_instruction->GetSingleWordInOperand(1); + switch (dimension) { + case 2: + PrepareForOpPhiOnVectors( + dimension, + use_component_wise_2d_select_even_if_optional, + &fresh_id_for_bvec2_selector); + break; + case 3: + PrepareForOpPhiOnVectors( + dimension, + use_component_wise_3d_select_even_if_optional, + &fresh_id_for_bvec3_selector); + break; + case 4: + PrepareForOpPhiOnVectors( + dimension, + use_component_wise_4d_select_even_if_optional, + &fresh_id_for_bvec4_selector); + break; + default: + assert(false && "Invalid vector dimension."); + } + break; + } + default: + break; + } + }); + // Some instructions will require to be enclosed inside conditionals // because they have side effects (for example, loads and stores). Some of // this have no result id, so we require instruction descriptors to @@ -116,10 +217,31 @@ void FuzzerPassFlattenConditionalBranches::Apply() { // Apply the transformation, evenly choosing whether to lay out the true // branch or the false branch first. ApplyTransformation(TransformationFlattenConditionalBranch( - header->id(), GetFuzzerContext()->ChooseEven(), wrappers_info)); + header->id(), GetFuzzerContext()->ChooseEven(), + fresh_id_for_bvec2_selector, fresh_id_for_bvec3_selector, + fresh_id_for_bvec4_selector, wrappers_info)); } } } +void FuzzerPassFlattenConditionalBranches::PrepareForOpPhiOnVectors( + uint32_t vector_dimension, bool use_vector_select_if_optional, + uint32_t* fresh_id_for_bvec_selector) { + if (*fresh_id_for_bvec_selector != 0) { + // We already have a fresh id for a component-wise OpSelect of this + // dimension + return; + } + if (TransformationFlattenConditionalBranch::OpSelectArgumentsAreRestricted( + GetIRContext()) || + use_vector_select_if_optional) { + // We either have to, or have chosen to, perform a component-wise select, so + // we ensure that the right boolean vector type is available, and grab a + // fresh id. + FindOrCreateVectorType(FindOrCreateBoolType(), vector_dimension); + *fresh_id_for_bvec_selector = GetFuzzerContext()->GetFreshId(); + } +} + } // namespace fuzz } // namespace spvtools diff --git a/source/fuzz/fuzzer_pass_flatten_conditional_branches.h b/source/fuzz/fuzzer_pass_flatten_conditional_branches.h index 715385ae1..76f7782cf 100644 --- a/source/fuzz/fuzzer_pass_flatten_conditional_branches.h +++ b/source/fuzz/fuzzer_pass_flatten_conditional_branches.h @@ -30,6 +30,16 @@ class FuzzerPassFlattenConditionalBranches : public FuzzerPass { ~FuzzerPassFlattenConditionalBranches() override; void Apply() override; + + private: + // If the SPIR-V version requires vector OpSelects to be component-wise, or + // if |use_vector_select_if_optional| holds, |fresh_id_for_bvec_selector| is + // populated with a fresh id if it is currently zero, and a + // |vector_dimension|-dimensional boolean vector type is added to the module + // if not already present. + void PrepareForOpPhiOnVectors(uint32_t vector_dimension, + bool use_vector_select_if_optional, + uint32_t* fresh_id_for_bvec_selector); }; } // namespace fuzz } // namespace spvtools diff --git a/source/fuzz/fuzzer_pass_replace_opselects_with_conditional_branches.cpp b/source/fuzz/fuzzer_pass_replace_opselects_with_conditional_branches.cpp index 0496268c5..c3db0ef14 100644 --- a/source/fuzz/fuzzer_pass_replace_opselects_with_conditional_branches.cpp +++ b/source/fuzz/fuzzer_pass_replace_opselects_with_conditional_branches.cpp @@ -65,6 +65,16 @@ void FuzzerPassReplaceOpSelectsWithConditionalBranches::Apply() { continue; } + // If the selector does not have scalar boolean type (i.e., it is a + // boolean vector) then ignore this OpSelect. + if (GetIRContext() + ->get_def_use_mgr() + ->GetDef(fuzzerutil::GetTypeId( + GetIRContext(), instruction.GetSingleWordInOperand(0))) + ->opcode() != SpvOpTypeBool) { + continue; + } + // If the block is a loop header and we need to split it, the // transformation cannot be applied because loop headers cannot be // split. We can break out of this loop because the transformation can diff --git a/source/fuzz/protobufs/spvtoolsfuzz.proto b/source/fuzz/protobufs/spvtoolsfuzz.proto index 7778165ae..20116c7dd 100644 --- a/source/fuzz/protobufs/spvtoolsfuzz.proto +++ b/source/fuzz/protobufs/spvtoolsfuzz.proto @@ -1468,10 +1468,23 @@ message TransformationFlattenConditionalBranch { // field is true. bool true_branch_first = 2; + // If the convergence block contains an OpPhi with bvec2 result type, it may + // be necessary to introduce a bvec2 with the selection construct's condition + // in both components in order to turn the OpPhi into an OpSelect. This + // this field provides a fresh id for an OpCompositeConstruct instruction for + // this purpose. It should be set to 0 if no such instruction is required. + uint32 fresh_id_for_bvec2_selector = 3; + + // The same as |fresh_id_for_bvec2_selector| but for the bvec3 case. + uint32 fresh_id_for_bvec3_selector = 4; + + // The same as |fresh_id_for_bvec2_selector| but for the bvec4 case. + uint32 fresh_id_for_bvec4_selector = 5; + // A list of instructions with side effects, which must be enclosed // inside smaller conditionals before flattening the main one, and // the corresponding fresh ids and module ids needed. - repeated SideEffectWrapperInfo side_effect_wrapper_info = 3; + repeated SideEffectWrapperInfo side_effect_wrapper_info = 6; } message TransformationFunctionCall { diff --git a/source/fuzz/transformation_flatten_conditional_branch.cpp b/source/fuzz/transformation_flatten_conditional_branch.cpp index 86a031393..ecb2fc58e 100644 --- a/source/fuzz/transformation_flatten_conditional_branch.cpp +++ b/source/fuzz/transformation_flatten_conditional_branch.cpp @@ -26,10 +26,15 @@ TransformationFlattenConditionalBranch::TransformationFlattenConditionalBranch( TransformationFlattenConditionalBranch::TransformationFlattenConditionalBranch( uint32_t header_block_id, bool true_branch_first, + uint32_t fresh_id_for_bvec2_selector, uint32_t fresh_id_for_bvec3_selector, + uint32_t fresh_id_for_bvec4_selector, const std::vector& side_effect_wrappers_info) { message_.set_header_block_id(header_block_id); message_.set_true_branch_first(true_branch_first); + message_.set_fresh_id_for_bvec2_selector(fresh_id_for_bvec2_selector); + message_.set_fresh_id_for_bvec3_selector(fresh_id_for_bvec3_selector); + message_.set_fresh_id_for_bvec4_selector(fresh_id_for_bvec4_selector); for (auto const& side_effect_wrapper_info : side_effect_wrappers_info) { *message_.add_side_effect_wrapper_info() = side_effect_wrapper_info; } @@ -38,8 +43,8 @@ TransformationFlattenConditionalBranch::TransformationFlattenConditionalBranch( bool TransformationFlattenConditionalBranch::IsApplicable( opt::IRContext* ir_context, const TransformationContext& transformation_context) const { - uint32_t header_block_id = message_.header_block_id(); - auto header_block = fuzzerutil::MaybeFindBlock(ir_context, header_block_id); + auto header_block = + fuzzerutil::MaybeFindBlock(ir_context, message_.header_block_id()); // The block must have been found and it must be a selection header. if (!header_block || !header_block->GetMergeInst() || @@ -52,6 +57,81 @@ bool TransformationFlattenConditionalBranch::IsApplicable( return false; } + std::set used_fresh_ids; + + // If ids have been provided to be used as vector guards for OpSelect + // instructions then they must be fresh. + for (uint32_t fresh_id_for_bvec_selector : + {message_.fresh_id_for_bvec2_selector(), + message_.fresh_id_for_bvec3_selector(), + message_.fresh_id_for_bvec4_selector()}) { + if (fresh_id_for_bvec_selector != 0) { + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + fresh_id_for_bvec_selector, ir_context, &used_fresh_ids)) { + return false; + } + } + } + + if (OpSelectArgumentsAreRestricted(ir_context)) { + // OpPhi instructions at the convergence block for the selection are handled + // by turning them into OpSelect instructions. As the SPIR-V version in use + // has restrictions on the arguments that OpSelect can take, we must check + // that any OpPhi instructions are compatible with these restrictions. + uint32_t convergence_block_id = + FindConvergenceBlock(ir_context, *header_block); + // Consider every OpPhi instruction at the convergence block. + if (!ir_context->cfg() + ->block(convergence_block_id) + ->WhileEachPhiInst([this, + ir_context](opt::Instruction* inst) -> bool { + // Decide whether the OpPhi can be handled based on its result + // type. + opt::Instruction* phi_result_type = + ir_context->get_def_use_mgr()->GetDef(inst->type_id()); + switch (phi_result_type->opcode()) { + case SpvOpTypeBool: + case SpvOpTypeInt: + case SpvOpTypeFloat: + case SpvOpTypePointer: + // Fine: OpSelect can work directly on scalar and pointer + // types. + return true; + case SpvOpTypeVector: { + // In its restricted form, OpSelect can only select between + // vectors if the condition of the select is a boolean + // boolean vector. We thus require the appropriate boolean + // vector type to be present. + uint32_t bool_type_id = + fuzzerutil::MaybeGetBoolType(ir_context); + uint32_t dimension = + phi_result_type->GetSingleWordInOperand(1); + if (fuzzerutil::MaybeGetVectorType(ir_context, bool_type_id, + dimension) == 0) { + // The required boolean vector type is not present. + return false; + } + // The transformation needs to be equipped with a fresh id + // in which to store the vectorized version of the selection + // construct's condition. + switch (dimension) { + case 2: + return message_.fresh_id_for_bvec2_selector() != 0; + case 3: + return message_.fresh_id_for_bvec3_selector() != 0; + default: + assert(dimension == 4 && "Invalid vector dimension."); + return message_.fresh_id_for_bvec4_selector() != 0; + } + } + default: + return false; + } + })) { + return false; + } + } + // Use a set to keep track of the instructions that require fresh ids. std::set instructions_that_need_ids; @@ -68,8 +148,6 @@ bool TransformationFlattenConditionalBranch::IsApplicable( auto insts_to_wrapper_info = GetInstructionsToWrapperInfo(ir_context); { - std::set used_fresh_ids; - // Check the ids in the map. for (const auto& inst_to_info : insts_to_wrapper_info) { // Check the fresh ids needed for all of the instructions that need to be @@ -125,28 +203,6 @@ bool TransformationFlattenConditionalBranch::IsApplicable( void TransformationFlattenConditionalBranch::Apply( opt::IRContext* ir_context, TransformationContext* transformation_context) const { - uint32_t header_block_id = message_.header_block_id(); - auto header_block = ir_context->cfg()->block(header_block_id); - - // Find the first block where flow converges (it is not necessarily the merge - // block) by walking the true branch until reaching a block that - // post-dominates the header. - // This is necessary because a potential common set of blocks at the end of - // the construct should not be duplicated. - uint32_t convergence_block_id = - header_block->terminator()->GetSingleWordInOperand(1); - auto postdominator_analysis = - ir_context->GetPostDominatorAnalysis(header_block->GetParent()); - while (!postdominator_analysis->Dominates(convergence_block_id, - header_block_id)) { - auto current_block = ir_context->get_instr_block(convergence_block_id); - // If the transformation is applicable, the terminator is OpBranch. - convergence_block_id = - current_block->terminator()->GetSingleWordInOperand(0); - } - - auto branch_instruction = header_block->terminator(); - // branch = 1 corresponds to the true branch, branch = 2 corresponds to the // false branch. If the true branch is to be laid out first, we need to visit // the false branch first, because each branch is moved to right after the @@ -158,58 +214,27 @@ void TransformationFlattenConditionalBranch::Apply( branches = {1, 2}; } + auto header_block = ir_context->cfg()->block(message_.header_block_id()); + // Get the ids of the starting blocks of the first and last branches to be // laid out. The first branch is the true branch iff // |message_.true_branch_first| is true. + auto branch_instruction = header_block->terminator(); uint32_t first_block_first_branch_id = branch_instruction->GetSingleWordInOperand(branches[1]); uint32_t first_block_last_branch_id = branch_instruction->GetSingleWordInOperand(branches[0]); + uint32_t convergence_block_id = + FindConvergenceBlock(ir_context, *header_block); + // If the OpBranchConditional instruction in the header branches to the same // block for both values of the condition, this is the convergence block (the // flow does not actually diverge) and the OpPhi instructions in it are still // valid, so we do not need to make any changes. if (first_block_first_branch_id != first_block_last_branch_id) { - // Replace all of the current OpPhi instructions in the convergence block - // with OpSelect. - ir_context->get_instr_block(convergence_block_id) - ->ForEachPhiInst([branch_instruction, header_block, - ir_context](opt::Instruction* phi_inst) { - assert(phi_inst->NumInOperands() == 4 && - "We are going to replace an OpPhi with an OpSelect. This " - "only makes sense if the block has two distinct " - "predecessors."); - // The OpPhi takes values from two distinct predecessors. One - // predecessor is associated with the "true" path of the conditional - // we are flattening, the other with the "false" path, but these - // predecessors can appear in either order as operands to the OpPhi - // instruction. - - std::vector operands; - operands.emplace_back(branch_instruction->GetInOperand(0)); - - uint32_t branch_instruction_true_block_id = - branch_instruction->GetSingleWordInOperand(1); - - if (ir_context->GetDominatorAnalysis(header_block->GetParent()) - ->Dominates(branch_instruction_true_block_id, - phi_inst->GetSingleWordInOperand(1))) { - // The "true" branch is handled first in the OpPhi's operands; we - // thus provide operands to OpSelect in the same order that they - // appear in the OpPhi. - operands.emplace_back(phi_inst->GetInOperand(0)); - operands.emplace_back(phi_inst->GetInOperand(2)); - } else { - // The "false" branch is handled first in the OpPhi's operands; we - // thus provide operands to OpSelect in reverse of the order that - // they appear in the OpPhi. - operands.emplace_back(phi_inst->GetInOperand(2)); - operands.emplace_back(phi_inst->GetInOperand(0)); - } - phi_inst->SetOpcode(SpvOpSelect); - phi_inst->SetInOperands(std::move(operands)); - }); + RewriteOpPhiInstructionsAtConvergenceBlock( + *header_block, convergence_block_id, ir_context); } // Get the mapping from instructions to fresh ids. @@ -732,7 +757,10 @@ bool TransformationFlattenConditionalBranch::InstructionCanBeHandled( std::unordered_set TransformationFlattenConditionalBranch::GetFreshIds() const { - std::unordered_set result; + std::unordered_set result = { + message_.fresh_id_for_bvec2_selector(), + message_.fresh_id_for_bvec3_selector(), + message_.fresh_id_for_bvec4_selector()}; for (auto& side_effect_wrapper_info : message_.side_effect_wrapper_info()) { result.insert(side_effect_wrapper_info.merge_block_id()); result.insert(side_effect_wrapper_info.execute_block_id()); @@ -743,5 +771,166 @@ TransformationFlattenConditionalBranch::GetFreshIds() const { return result; } +uint32_t TransformationFlattenConditionalBranch::FindConvergenceBlock( + opt::IRContext* ir_context, const opt::BasicBlock& header_block) { + uint32_t result = header_block.terminator()->GetSingleWordInOperand(1); + auto postdominator_analysis = + ir_context->GetPostDominatorAnalysis(header_block.GetParent()); + while (!postdominator_analysis->Dominates(result, header_block.id())) { + auto current_block = ir_context->get_instr_block(result); + // If the transformation is applicable, the terminator is OpBranch. + result = current_block->terminator()->GetSingleWordInOperand(0); + } + return result; +} + +bool TransformationFlattenConditionalBranch::OpSelectArgumentsAreRestricted( + opt::IRContext* ir_context) { + switch (ir_context->grammar().target_env()) { + case SPV_ENV_UNIVERSAL_1_0: + case SPV_ENV_UNIVERSAL_1_1: + case SPV_ENV_UNIVERSAL_1_2: + case SPV_ENV_UNIVERSAL_1_3: { + return true; + } + default: + return false; + } +} + +void TransformationFlattenConditionalBranch::AddBooleanVectorConstructorToBlock( + uint32_t fresh_id, uint32_t dimension, + const opt::Operand& branch_condition_operand, opt::IRContext* ir_context, + opt::BasicBlock* block) const { + opt::Instruction::OperandList in_operands; + for (uint32_t i = 0; i < dimension; i++) { + in_operands.emplace_back(branch_condition_operand); + } + block->begin()->InsertBefore(MakeUnique( + ir_context, SpvOpCompositeConstruct, + fuzzerutil::MaybeGetVectorType( + ir_context, fuzzerutil::MaybeGetBoolType(ir_context), dimension), + fresh_id, in_operands)); + fuzzerutil::UpdateModuleIdBound(ir_context, fresh_id); +} + +void TransformationFlattenConditionalBranch:: + RewriteOpPhiInstructionsAtConvergenceBlock( + const opt::BasicBlock& header_block, uint32_t convergence_block_id, + opt::IRContext* ir_context) const { + const opt::Instruction& branch_instruction = *header_block.terminator(); + + const opt::Operand& branch_condition_operand = + branch_instruction.GetInOperand(0); + + // If we encounter OpPhi instructions on vector types then we may need to + // introduce vector versions of the selection construct's condition to use + // in corresponding OpSelect instructions. These booleans track whether we + // need to introduce such boolean vectors. + bool require_2d_boolean_vector = false; + bool require_3d_boolean_vector = false; + bool require_4d_boolean_vector = false; + + // Consider every OpPhi instruction at the convergence block. + opt::BasicBlock* convergence_block = + ir_context->get_instr_block(convergence_block_id); + convergence_block->ForEachPhiInst( + [this, &branch_condition_operand, branch_instruction, &header_block, + ir_context, &require_2d_boolean_vector, &require_3d_boolean_vector, + &require_4d_boolean_vector](opt::Instruction* phi_inst) { + assert(phi_inst->NumInOperands() == 4 && + "We are going to replace an OpPhi with an OpSelect. This " + "only makes sense if the block has two distinct " + "predecessors."); + // We are going to replace the OpPhi with an OpSelect. By default, + // the condition for the OpSelect will be the branch condition's + // operand. However, if the OpPhi has vector result type we may need + // to use a boolean vector as the condition instead. + opt::Operand selector_operand = branch_condition_operand; + opt::Instruction* type_inst = + ir_context->get_def_use_mgr()->GetDef(phi_inst->type_id()); + if (type_inst->opcode() == SpvOpTypeVector) { + uint32_t dimension = type_inst->GetSingleWordInOperand(1); + switch (dimension) { + case 2: + // The OpPhi's result type is a 2D vector. If a fresh id for a + // bvec2 selector was provided then we should use it as the + // OpSelect's condition, and note the fact that we will need to + // add an instruction to bring this bvec2 into existence. + if (message_.fresh_id_for_bvec2_selector() != 0) { + selector_operand = {SPV_OPERAND_TYPE_TYPE_ID, + {message_.fresh_id_for_bvec2_selector()}}; + require_2d_boolean_vector = true; + } + break; + case 3: + // Similar to the 2D case. + if (message_.fresh_id_for_bvec3_selector() != 0) { + selector_operand = {SPV_OPERAND_TYPE_TYPE_ID, + {message_.fresh_id_for_bvec3_selector()}}; + require_3d_boolean_vector = true; + } + break; + default: + assert(dimension == 4 && "Invalid vector dimension."); + // Similar to the 2D case. + if (message_.fresh_id_for_bvec4_selector() != 0) { + selector_operand = {SPV_OPERAND_TYPE_TYPE_ID, + {message_.fresh_id_for_bvec4_selector()}}; + require_4d_boolean_vector = true; + } + break; + } + } + std::vector operands; + operands.emplace_back(selector_operand); + + uint32_t branch_instruction_true_block_id = + branch_instruction.GetSingleWordInOperand(1); + + // The OpPhi takes values from two distinct predecessors. One + // predecessor is associated with the "true" path of the conditional + // we are flattening, the other with the "false" path, but these + // predecessors can appear in either order as operands to the OpPhi + // instruction. We determine in which order the OpPhi inputs should + // appear as OpSelect arguments by checking dominance of the true and + // false immediate successors of the header block. + if (ir_context->GetDominatorAnalysis(header_block.GetParent()) + ->Dominates(branch_instruction_true_block_id, + phi_inst->GetSingleWordInOperand(1))) { + // The "true" branch is handled first in the OpPhi's operands; we + // thus provide operands to OpSelect in the same order that they + // appear in the OpPhi. + operands.emplace_back(phi_inst->GetInOperand(0)); + operands.emplace_back(phi_inst->GetInOperand(2)); + } else { + // The "false" branch is handled first in the OpPhi's operands; we + // thus provide operands to OpSelect in reverse of the order that + // they appear in the OpPhi. + operands.emplace_back(phi_inst->GetInOperand(2)); + operands.emplace_back(phi_inst->GetInOperand(0)); + } + phi_inst->SetOpcode(SpvOpSelect); + phi_inst->SetInOperands(std::move(operands)); + }); + + // Add boolean vector instructions to the start of the block as required. + if (require_2d_boolean_vector) { + AddBooleanVectorConstructorToBlock(message_.fresh_id_for_bvec2_selector(), + 2, branch_condition_operand, ir_context, + convergence_block); + } + if (require_3d_boolean_vector) { + AddBooleanVectorConstructorToBlock(message_.fresh_id_for_bvec3_selector(), + 3, branch_condition_operand, ir_context, + convergence_block); + } + if (require_4d_boolean_vector) { + AddBooleanVectorConstructorToBlock(message_.fresh_id_for_bvec4_selector(), + 4, branch_condition_operand, ir_context, + convergence_block); + } +} + } // namespace fuzz } // namespace spvtools diff --git a/source/fuzz/transformation_flatten_conditional_branch.h b/source/fuzz/transformation_flatten_conditional_branch.h index fd8c1f5d0..b8ba7e3bc 100644 --- a/source/fuzz/transformation_flatten_conditional_branch.h +++ b/source/fuzz/transformation_flatten_conditional_branch.h @@ -27,6 +27,9 @@ class TransformationFlattenConditionalBranch : public Transformation { TransformationFlattenConditionalBranch( uint32_t header_block_id, bool true_branch_first, + uint32_t fresh_id_for_bvec2_selector, + uint32_t fresh_id_for_bvec3_selector, + uint32_t fresh_id_for_bvec4_selector, const std::vector& side_effect_wrappers_info); @@ -81,6 +84,19 @@ class TransformationFlattenConditionalBranch : public Transformation { static bool InstructionNeedsPlaceholder(opt::IRContext* ir_context, const opt::Instruction& instruction); + // Returns true if and only if the SPIR-V version is such that the arguments + // to OpSelect are restricted to only scalars, pointers (if the appropriate + // capability is enabled) and component-wise vectors. + static bool OpSelectArgumentsAreRestricted(opt::IRContext* ir_context); + + // Find the first block where flow converges (it is not necessarily the merge + // block) by walking the true branch until reaching a block that post- + // dominates the header. + // This is necessary because a potential common set of blocks at the end of + // the construct should not be duplicated. + static uint32_t FindConvergenceBlock(opt::IRContext* ir_context, + const opt::BasicBlock& header_block); + private: // Returns an unordered_map mapping instructions to the info required to // enclose them inside a conditional. It maps the instructions to the @@ -108,6 +124,22 @@ class TransformationFlattenConditionalBranch : public Transformation { std::vector* dead_blocks, std::vector* irrelevant_ids) const; + // Turns every OpPhi instruction of |convergence_block| -- the convergence + // block for |header_block| (both in |ir_context|) into an OpSelect + // instruction. + void RewriteOpPhiInstructionsAtConvergenceBlock( + const opt::BasicBlock& header_block, uint32_t convergence_block_id, + opt::IRContext* ir_context) const; + + // Adds an OpCompositeExtract instruction to the start of |block| in + // |ir_context|, with result id given by |fresh_id|. The instruction will + // make a |dimension|-dimensional boolean vector with + // |branch_condition_operand| at every component. + void AddBooleanVectorConstructorToBlock( + uint32_t fresh_id, uint32_t dimension, + const opt::Operand& branch_condition_operand, opt::IRContext* ir_context, + opt::BasicBlock* block) const; + // Returns true if the given instruction either has no side effects or it can // be handled by being enclosed in a conditional. static bool InstructionCanBeHandled(opt::IRContext* ir_context, diff --git a/test/fuzz/transformation_flatten_conditional_branch_test.cpp b/test/fuzz/transformation_flatten_conditional_branch_test.cpp index ba3493366..0930949f8 100644 --- a/test/fuzz/transformation_flatten_conditional_branch_test.cpp +++ b/test/fuzz/transformation_flatten_conditional_branch_test.cpp @@ -145,42 +145,42 @@ TEST(TransformationFlattenConditionalBranchTest, Inapplicable) { TransformationContext transformation_context( MakeUnique(context.get()), validator_options); // Block %15 does not end with OpBranchConditional. - ASSERT_FALSE(TransformationFlattenConditionalBranch(15, true, {}) + ASSERT_FALSE(TransformationFlattenConditionalBranch(15, true, 0, 0, 0, {}) .IsApplicable(context.get(), transformation_context)); // Block %17 is not a selection header. - ASSERT_FALSE(TransformationFlattenConditionalBranch(17, true, {}) + ASSERT_FALSE(TransformationFlattenConditionalBranch(17, true, 0, 0, 0, {}) .IsApplicable(context.get(), transformation_context)); // Block %16 is a loop header, not a selection header. - ASSERT_FALSE(TransformationFlattenConditionalBranch(16, true, {}) + ASSERT_FALSE(TransformationFlattenConditionalBranch(16, true, 0, 0, 0, {}) .IsApplicable(context.get(), transformation_context)); // Block %19 and the corresponding merge block do not describe a single-entry, // single-exit region, because there is a return instruction in %21. - ASSERT_FALSE(TransformationFlattenConditionalBranch(19, true, {}) + ASSERT_FALSE(TransformationFlattenConditionalBranch(19, true, 0, 0, 0, {}) .IsApplicable(context.get(), transformation_context)); // Block %20 is the header of a construct containing an inner selection // construct. - ASSERT_FALSE(TransformationFlattenConditionalBranch(20, true, {}) + ASSERT_FALSE(TransformationFlattenConditionalBranch(20, true, 0, 0, 0, {}) .IsApplicable(context.get(), transformation_context)); // Block %22 is the header of a construct containing an inner loop. - ASSERT_FALSE(TransformationFlattenConditionalBranch(22, true, {}) + ASSERT_FALSE(TransformationFlattenConditionalBranch(22, true, 0, 0, 0, {}) .IsApplicable(context.get(), transformation_context)); // Block %30 is the header of a construct containing a barrier instruction. - ASSERT_FALSE(TransformationFlattenConditionalBranch(30, true, {}) + ASSERT_FALSE(TransformationFlattenConditionalBranch(30, true, 0, 0, 0, {}) .IsApplicable(context.get(), transformation_context)); // %33 is not a block. - ASSERT_FALSE(TransformationFlattenConditionalBranch(33, true, {}) + ASSERT_FALSE(TransformationFlattenConditionalBranch(33, true, 0, 0, 0, {}) .IsApplicable(context.get(), transformation_context)); // Block %36 and the corresponding merge block do not describe a single-entry, // single-exit region, because block %37 breaks out of the outer loop. - ASSERT_FALSE(TransformationFlattenConditionalBranch(36, true, {}) + ASSERT_FALSE(TransformationFlattenConditionalBranch(36, true, 0, 0, 0, {}) .IsApplicable(context.get(), transformation_context)); } @@ -250,25 +250,29 @@ TEST(TransformationFlattenConditionalBranchTest, Simple) { kConsoleMessageConsumer)); TransformationContext transformation_context( MakeUnique(context.get()), validator_options); - auto transformation1 = TransformationFlattenConditionalBranch(7, true, {}); + auto transformation1 = + TransformationFlattenConditionalBranch(7, true, 0, 0, 0, {}); ASSERT_TRUE( transformation1.IsApplicable(context.get(), transformation_context)); ApplyAndCheckFreshIds(transformation1, context.get(), &transformation_context); - auto transformation2 = TransformationFlattenConditionalBranch(13, false, {}); + auto transformation2 = + TransformationFlattenConditionalBranch(13, false, 0, 0, 0, {}); ASSERT_TRUE( transformation2.IsApplicable(context.get(), transformation_context)); ApplyAndCheckFreshIds(transformation2, context.get(), &transformation_context); - auto transformation3 = TransformationFlattenConditionalBranch(15, true, {}); + auto transformation3 = + TransformationFlattenConditionalBranch(15, true, 0, 0, 0, {}); ASSERT_TRUE( transformation3.IsApplicable(context.get(), transformation_context)); ApplyAndCheckFreshIds(transformation3, context.get(), &transformation_context); - auto transformation4 = TransformationFlattenConditionalBranch(22, false, {}); + auto transformation4 = + TransformationFlattenConditionalBranch(22, false, 0, 0, 0, {}); ASSERT_TRUE( transformation4.IsApplicable(context.get(), transformation_context)); ApplyAndCheckFreshIds(transformation4, context.get(), @@ -430,12 +434,12 @@ TEST(TransformationFlattenConditionalBranchTest, LoadStoreFunctionCall) { // requiring fresh ids are not present in the map, and the transformation // context does not have a source overflow ids. - ASSERT_DEATH(TransformationFlattenConditionalBranch(31, true, {}) + ASSERT_DEATH(TransformationFlattenConditionalBranch(31, true, 0, 0, 0, {}) .IsApplicable(context.get(), transformation_context), "Bad attempt to query whether overflow ids are available."); ASSERT_DEATH(TransformationFlattenConditionalBranch( - 31, true, + 31, true, 0, 0, 0, {{MakeSideEffectWrapperInfo( MakeInstructionDescriptor(6, SpvOpLoad, 0), 100, 101, 102, 103, 104, 14)}}) @@ -445,7 +449,7 @@ TEST(TransformationFlattenConditionalBranchTest, LoadStoreFunctionCall) { // The map maps from an instruction to a list with not enough fresh ids. ASSERT_FALSE(TransformationFlattenConditionalBranch( - 31, true, + 31, true, 0, 0, 0, {{MakeSideEffectWrapperInfo( MakeInstructionDescriptor(6, SpvOpLoad, 0), 100, 101, 102, 103, 0, 0)}}) @@ -453,7 +457,7 @@ TEST(TransformationFlattenConditionalBranchTest, LoadStoreFunctionCall) { // Not all fresh ids given are distinct. ASSERT_FALSE(TransformationFlattenConditionalBranch( - 31, true, + 31, true, 0, 0, 0, {{MakeSideEffectWrapperInfo( MakeInstructionDescriptor(6, SpvOpLoad, 0), 100, 100, 102, 103, 104, 0)}}) @@ -461,7 +465,7 @@ TEST(TransformationFlattenConditionalBranchTest, LoadStoreFunctionCall) { // %48 heads a construct containing an OpSampledImage instruction. ASSERT_FALSE(TransformationFlattenConditionalBranch( - 48, true, + 48, true, 0, 0, 0, {{MakeSideEffectWrapperInfo( MakeInstructionDescriptor(53, SpvOpLoad, 0), 100, 101, 102, 103, 104, 0)}}) @@ -470,7 +474,7 @@ TEST(TransformationFlattenConditionalBranchTest, LoadStoreFunctionCall) { // %0 is not a valid id. ASSERT_FALSE( TransformationFlattenConditionalBranch( - 31, true, + 31, true, 0, 0, 0, {MakeSideEffectWrapperInfo(MakeInstructionDescriptor(6, SpvOpLoad, 0), 104, 100, 101, 102, 103, 0), MakeSideEffectWrapperInfo( @@ -480,7 +484,7 @@ TEST(TransformationFlattenConditionalBranchTest, LoadStoreFunctionCall) { // %17 is a float constant, while %6 has int type. ASSERT_FALSE( TransformationFlattenConditionalBranch( - 31, true, + 31, true, 0, 0, 0, {MakeSideEffectWrapperInfo(MakeInstructionDescriptor(6, SpvOpLoad, 0), 104, 100, 101, 102, 103, 17), MakeSideEffectWrapperInfo( @@ -488,7 +492,7 @@ TEST(TransformationFlattenConditionalBranchTest, LoadStoreFunctionCall) { .IsApplicable(context.get(), transformation_context)); auto transformation1 = TransformationFlattenConditionalBranch( - 31, true, + 31, true, 0, 0, 0, {MakeSideEffectWrapperInfo(MakeInstructionDescriptor(6, SpvOpLoad, 0), 104, 100, 101, 102, 103, 70), MakeSideEffectWrapperInfo(MakeInstructionDescriptor(6, SpvOpStore, 0), @@ -509,7 +513,7 @@ TEST(TransformationFlattenConditionalBranchTest, LoadStoreFunctionCall) { std::move(overflow_ids_unique_ptr)); auto transformation2 = TransformationFlattenConditionalBranch( - 36, false, + 36, false, 0, 0, 0, {MakeSideEffectWrapperInfo(MakeInstructionDescriptor(8, SpvOpStore, 0), 114, 113)}); ASSERT_TRUE( @@ -706,13 +710,13 @@ TEST(TransformationFlattenConditionalBranchTest, EdgeCases) { // The selection construct headed by %7 requires fresh ids because it contains // a function call. This causes an assertion failure because transformation // context does not have a source of overflow ids. - ASSERT_DEATH(TransformationFlattenConditionalBranch(7, true, {}) + ASSERT_DEATH(TransformationFlattenConditionalBranch(7, true, 0, 0, 0, {}) .IsApplicable(context.get(), transformation_context), "Bad attempt to query whether overflow ids are available."); #endif auto transformation1 = TransformationFlattenConditionalBranch( - 7, true, + 7, true, 0, 0, 0, {{MakeSideEffectWrapperInfo( MakeInstructionDescriptor(10, SpvOpFunctionCall, 0), 100, 101)}}); ASSERT_TRUE( @@ -724,16 +728,17 @@ TEST(TransformationFlattenConditionalBranchTest, EdgeCases) { // contains a function call returning void, whose result id is used. ASSERT_FALSE( TransformationFlattenConditionalBranch( - 7, true, + 7, true, 0, 0, 0, {{MakeSideEffectWrapperInfo( MakeInstructionDescriptor(14, SpvOpFunctionCall, 0), 102, 103)}}) .IsApplicable(context.get(), transformation_context)); // Block %16 is unreachable. - ASSERT_FALSE(TransformationFlattenConditionalBranch(16, true, {}) + ASSERT_FALSE(TransformationFlattenConditionalBranch(16, true, 0, 0, 0, {}) .IsApplicable(context.get(), transformation_context)); - auto transformation2 = TransformationFlattenConditionalBranch(20, false, {}); + auto transformation2 = + TransformationFlattenConditionalBranch(20, false, 0, 0, 0, {}); ASSERT_TRUE( transformation2.IsApplicable(context.get(), transformation_context)); ApplyAndCheckFreshIds(transformation2, context.get(), @@ -837,7 +842,8 @@ TEST(TransformationFlattenConditionalBranchTest, PhiToSelect1) { TransformationContext transformation_context( MakeUnique(context.get()), validator_options); - auto transformation = TransformationFlattenConditionalBranch(7, true, {}); + auto transformation = + TransformationFlattenConditionalBranch(7, true, 0, 0, 0, {}); ASSERT_TRUE( transformation.IsApplicable(context.get(), transformation_context)); ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context); @@ -903,7 +909,8 @@ TEST(TransformationFlattenConditionalBranchTest, PhiToSelect2) { TransformationContext transformation_context( MakeUnique(context.get()), validator_options); - auto transformation = TransformationFlattenConditionalBranch(7, true, {}); + auto transformation = + TransformationFlattenConditionalBranch(7, true, 0, 0, 0, {}); ASSERT_TRUE( transformation.IsApplicable(context.get(), transformation_context)); ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context); @@ -971,7 +978,8 @@ TEST(TransformationFlattenConditionalBranchTest, PhiToSelect3) { TransformationContext transformation_context( MakeUnique(context.get()), validator_options); - auto transformation = TransformationFlattenConditionalBranch(7, true, {}); + auto transformation = + TransformationFlattenConditionalBranch(7, true, 0, 0, 0, {}); ASSERT_TRUE( transformation.IsApplicable(context.get(), transformation_context)); ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context); @@ -1041,7 +1049,8 @@ TEST(TransformationFlattenConditionalBranchTest, PhiToSelect4) { TransformationContext transformation_context( MakeUnique(context.get()), validator_options); - auto transformation = TransformationFlattenConditionalBranch(7, true, {}); + auto transformation = + TransformationFlattenConditionalBranch(7, true, 0, 0, 0, {}); ASSERT_TRUE( transformation.IsApplicable(context.get(), transformation_context)); ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context); @@ -1118,7 +1127,7 @@ TEST(TransformationFlattenConditionalBranchTest, PhiToSelect5) { MakeUnique(context.get()), validator_options); auto transformation = TransformationFlattenConditionalBranch( - 7, true, + 7, true, 0, 0, 0, {MakeSideEffectWrapperInfo(MakeInstructionDescriptor(522, SpvOpLoad, 0), 200, 201, 202, 203, 204, 5), MakeSideEffectWrapperInfo(MakeInstructionDescriptor(466, SpvOpLoad, 0), @@ -1223,7 +1232,7 @@ TEST(TransformationFlattenConditionalBranchTest, MakeUnique(context.get()), validator_options); auto transformation = TransformationFlattenConditionalBranch( - 5, true, + 5, true, 0, 0, 0, {MakeSideEffectWrapperInfo(MakeInstructionDescriptor(20, SpvOpLoad, 0), 100, 101, 102, 103, 104, 21)}); ASSERT_TRUE( @@ -1290,13 +1299,553 @@ TEST(TransformationFlattenConditionalBranchTest, InapplicableSampledImageLoad) { MakeUnique(context.get()), validator_options); ASSERT_FALSE(TransformationFlattenConditionalBranch( - 28, true, + 28, true, 0, 0, 0, {MakeSideEffectWrapperInfo( MakeInstructionDescriptor(40, SpvOpLoad, 0), 100, 101, 102, 103, 104, 200)}) .IsApplicable(context.get(), transformation_context)); } +TEST(TransformationFlattenConditionalBranchTest, + InapplicablePhiToSelectVector) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %7 = OpConstantTrue %6 + %10 = OpTypeInt 32 1 + %11 = OpTypeVector %10 3 + %12 = OpUndef %11 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpSelectionMerge %20 None + OpBranchConditional %7 %8 %9 + %8 = OpLabel + OpBranch %20 + %9 = OpLabel + OpBranch %20 + %20 = OpLabel + %21 = OpPhi %11 %12 %8 %12 %9 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + spvtools::ValidatorOptions validator_options; + ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options, + kConsoleMessageConsumer)); + + TransformationContext transformation_context( + MakeUnique(context.get()), validator_options); + + auto transformation = + TransformationFlattenConditionalBranch(5, true, 0, 0, 0, {}); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); +} + +TEST(TransformationFlattenConditionalBranchTest, + InapplicablePhiToSelectVector2) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %30 = OpTypeVector %6 3 + %31 = OpTypeVector %6 2 + %7 = OpConstantTrue %6 + %10 = OpTypeInt 32 1 + %11 = OpTypeVector %10 3 + %40 = OpTypeFloat 32 + %41 = OpTypeVector %40 4 + %12 = OpUndef %11 + %60 = OpUndef %41 + %61 = OpConstantComposite %31 %7 %7 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpSelectionMerge %20 None + OpBranchConditional %7 %8 %9 + %8 = OpLabel + OpBranch %20 + %9 = OpLabel + OpBranch %20 + %20 = OpLabel + %21 = OpPhi %11 %12 %8 %12 %9 + %22 = OpPhi %11 %12 %8 %12 %9 + %23 = OpPhi %41 %60 %8 %60 %9 + %24 = OpPhi %31 %61 %8 %61 %9 + %25 = OpPhi %41 %60 %8 %60 %9 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + spvtools::ValidatorOptions validator_options; + ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options, + kConsoleMessageConsumer)); + + TransformationContext transformation_context( + MakeUnique(context.get()), validator_options); + + auto transformation = + TransformationFlattenConditionalBranch(5, true, 101, 102, 103, {}); + + // bvec4 is not present in the module. + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); + ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context); +} + +TEST(TransformationFlattenConditionalBranchTest, + InapplicablePhiToSelectMatrix) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %7 = OpConstantTrue %6 + %10 = OpTypeFloat 32 + %30 = OpTypeVector %10 3 + %11 = OpTypeMatrix %30 3 + %12 = OpUndef %11 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpSelectionMerge %20 None + OpBranchConditional %7 %8 %9 + %8 = OpLabel + OpBranch %20 + %9 = OpLabel + OpBranch %20 + %20 = OpLabel + %21 = OpPhi %11 %12 %8 %12 %9 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + spvtools::ValidatorOptions validator_options; + ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options, + kConsoleMessageConsumer)); + + TransformationContext transformation_context( + MakeUnique(context.get()), validator_options); + + auto transformation = + TransformationFlattenConditionalBranch(5, true, 0, 0, 0, {}); + ASSERT_FALSE( + transformation.IsApplicable(context.get(), transformation_context)); +} + +TEST(TransformationFlattenConditionalBranchTest, ApplicablePhiToSelectVector) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %7 = OpConstantTrue %6 + %10 = OpTypeInt 32 1 + %11 = OpTypeVector %10 3 + %12 = OpUndef %11 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpSelectionMerge %20 None + OpBranchConditional %7 %8 %9 + %8 = OpLabel + OpBranch %20 + %9 = OpLabel + OpBranch %20 + %20 = OpLabel + %21 = OpPhi %11 %12 %8 %12 %9 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + spvtools::ValidatorOptions validator_options; + ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options, + kConsoleMessageConsumer)); + + TransformationContext transformation_context( + MakeUnique(context.get()), validator_options); + + auto transformation = + TransformationFlattenConditionalBranch(5, true, 0, 0, 0, {}); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context); + ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options, + kConsoleMessageConsumer)); + + std::string expected_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %7 = OpConstantTrue %6 + %10 = OpTypeInt 32 1 + %11 = OpTypeVector %10 3 + %12 = OpUndef %11 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %8 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpBranch %20 + %20 = OpLabel + %21 = OpSelect %11 %7 %12 %12 + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, expected_shader, context.get())); +} + +TEST(TransformationFlattenConditionalBranchTest, ApplicablePhiToSelectVector2) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %30 = OpTypeVector %6 3 + %31 = OpTypeVector %6 2 + %32 = OpTypeVector %6 4 + %7 = OpConstantTrue %6 + %10 = OpTypeInt 32 1 + %11 = OpTypeVector %10 3 + %40 = OpTypeFloat 32 + %41 = OpTypeVector %40 4 + %12 = OpUndef %11 + %60 = OpUndef %41 + %61 = OpConstantComposite %31 %7 %7 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpSelectionMerge %20 None + OpBranchConditional %7 %8 %9 + %8 = OpLabel + OpBranch %20 + %9 = OpLabel + OpBranch %20 + %20 = OpLabel + %21 = OpPhi %11 %12 %8 %12 %9 + %22 = OpPhi %11 %12 %8 %12 %9 + %23 = OpPhi %41 %60 %8 %60 %9 + %24 = OpPhi %31 %61 %8 %61 %9 + %25 = OpPhi %41 %60 %8 %60 %9 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + spvtools::ValidatorOptions validator_options; + ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options, + kConsoleMessageConsumer)); + + TransformationContext transformation_context( + MakeUnique(context.get()), validator_options); + + // No id for the 2D vector case is provided. + ASSERT_FALSE(TransformationFlattenConditionalBranch(5, true, 0, 102, 103, {}) + .IsApplicable(context.get(), transformation_context)); + + // No id for the 3D vector case is provided. + ASSERT_FALSE(TransformationFlattenConditionalBranch(5, true, 101, 0, 103, {}) + .IsApplicable(context.get(), transformation_context)); + + // No id for the 4D vector case is provided. + ASSERT_FALSE(TransformationFlattenConditionalBranch(5, true, 101, 102, 0, {}) + .IsApplicable(context.get(), transformation_context)); + + // %10 is not fresh + ASSERT_FALSE(TransformationFlattenConditionalBranch(5, true, 10, 102, 103, {}) + .IsApplicable(context.get(), transformation_context)); + + // %10 is not fresh + ASSERT_FALSE(TransformationFlattenConditionalBranch(5, true, 101, 10, 103, {}) + .IsApplicable(context.get(), transformation_context)); + + // %10 is not fresh + ASSERT_FALSE(TransformationFlattenConditionalBranch(5, true, 101, 102, 10, {}) + .IsApplicable(context.get(), transformation_context)); + + // Duplicate "fresh" ids used for boolean vector constructors + ASSERT_FALSE( + TransformationFlattenConditionalBranch(5, true, 101, 102, 102, {}) + .IsApplicable(context.get(), transformation_context)); + + auto transformation = + TransformationFlattenConditionalBranch(5, true, 101, 102, 103, {}); + + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context); + ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options, + kConsoleMessageConsumer)); + + std::string expected_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %30 = OpTypeVector %6 3 + %31 = OpTypeVector %6 2 + %32 = OpTypeVector %6 4 + %7 = OpConstantTrue %6 + %10 = OpTypeInt 32 1 + %11 = OpTypeVector %10 3 + %40 = OpTypeFloat 32 + %41 = OpTypeVector %40 4 + %12 = OpUndef %11 + %60 = OpUndef %41 + %61 = OpConstantComposite %31 %7 %7 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %8 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpBranch %20 + %20 = OpLabel + %103 = OpCompositeConstruct %32 %7 %7 %7 %7 + %102 = OpCompositeConstruct %30 %7 %7 %7 + %101 = OpCompositeConstruct %31 %7 %7 + %21 = OpSelect %11 %102 %12 %12 + %22 = OpSelect %11 %102 %12 %12 + %23 = OpSelect %41 %103 %60 %60 + %24 = OpSelect %31 %101 %61 %61 + %25 = OpSelect %41 %103 %60 %60 + OpReturn + OpFunctionEnd + )"; + + ASSERT_TRUE(IsEqual(env, expected_shader, context.get())); +} + +TEST(TransformationFlattenConditionalBranchTest, ApplicablePhiToSelectVector3) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %30 = OpTypeVector %6 3 + %31 = OpTypeVector %6 2 + %32 = OpTypeVector %6 4 + %7 = OpConstantTrue %6 + %10 = OpTypeInt 32 1 + %11 = OpTypeVector %10 3 + %40 = OpTypeFloat 32 + %41 = OpTypeVector %40 4 + %12 = OpUndef %11 + %60 = OpUndef %41 + %61 = OpConstantComposite %31 %7 %7 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpSelectionMerge %20 None + OpBranchConditional %7 %8 %9 + %8 = OpLabel + OpBranch %20 + %9 = OpLabel + OpBranch %20 + %20 = OpLabel + %21 = OpPhi %11 %12 %8 %12 %9 + %22 = OpPhi %11 %12 %8 %12 %9 + %23 = OpPhi %41 %60 %8 %60 %9 + %24 = OpPhi %31 %61 %8 %61 %9 + %25 = OpPhi %41 %60 %8 %60 %9 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + spvtools::ValidatorOptions validator_options; + ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options, + kConsoleMessageConsumer)); + + TransformationContext transformation_context( + MakeUnique(context.get()), validator_options); + + auto transformation = + TransformationFlattenConditionalBranch(5, true, 101, 0, 103, {}); + + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context); + ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options, + kConsoleMessageConsumer)); + + std::string expected_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %30 = OpTypeVector %6 3 + %31 = OpTypeVector %6 2 + %32 = OpTypeVector %6 4 + %7 = OpConstantTrue %6 + %10 = OpTypeInt 32 1 + %11 = OpTypeVector %10 3 + %40 = OpTypeFloat 32 + %41 = OpTypeVector %40 4 + %12 = OpUndef %11 + %60 = OpUndef %41 + %61 = OpConstantComposite %31 %7 %7 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %8 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpBranch %20 + %20 = OpLabel + %103 = OpCompositeConstruct %32 %7 %7 %7 %7 + %101 = OpCompositeConstruct %31 %7 %7 + %21 = OpSelect %11 %7 %12 %12 + %22 = OpSelect %11 %7 %12 %12 + %23 = OpSelect %41 %103 %60 %60 + %24 = OpSelect %31 %101 %61 %61 + %25 = OpSelect %41 %103 %60 %60 + OpReturn + OpFunctionEnd + )"; + + ASSERT_TRUE(IsEqual(env, expected_shader, context.get())); +} + +TEST(TransformationFlattenConditionalBranchTest, ApplicablePhiToSelectMatrix) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %7 = OpConstantTrue %6 + %10 = OpTypeFloat 32 + %30 = OpTypeVector %10 3 + %11 = OpTypeMatrix %30 3 + %12 = OpUndef %11 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpSelectionMerge %20 None + OpBranchConditional %7 %8 %9 + %8 = OpLabel + OpBranch %20 + %9 = OpLabel + OpBranch %20 + %20 = OpLabel + %21 = OpPhi %11 %12 %8 %12 %9 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + spvtools::ValidatorOptions validator_options; + ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options, + kConsoleMessageConsumer)); + + TransformationContext transformation_context( + MakeUnique(context.get()), validator_options); + + auto transformation = + TransformationFlattenConditionalBranch(5, true, 0, 0, 0, {}); + ASSERT_TRUE( + transformation.IsApplicable(context.get(), transformation_context)); + ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context); + ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options, + kConsoleMessageConsumer)); + + std::string expected_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 320 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %7 = OpConstantTrue %6 + %10 = OpTypeFloat 32 + %30 = OpTypeVector %10 3 + %11 = OpTypeMatrix %30 3 + %12 = OpUndef %11 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %8 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpBranch %20 + %20 = OpLabel + %21 = OpSelect %11 %7 %12 %12 + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, expected_shader, context.get())); +} + } // namespace } // namespace fuzz } // namespace spvtools