spirv-fuzz: Use overflow ids when inlining functions (#3880)

Fixes #3751.
This commit is contained in:
Alastair Donaldson 2020-10-02 16:53:54 +01:00 committed by GitHub
parent 67f8e2eddc
commit 0e85530728
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 237 additions and 18 deletions

View File

@ -33,7 +33,8 @@ TransformationInlineFunction::TransformationInlineFunction(
}
bool TransformationInlineFunction::IsApplicable(
opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
opt::IRContext* ir_context,
const TransformationContext& transformation_context) const {
// The values in the |message_.result_id_map| must be all fresh and all
// distinct.
const auto result_id_map =
@ -71,7 +72,8 @@ bool TransformationInlineFunction::IsApplicable(
// Since the entry block label will not be inlined, only the remaining
// labels must have a corresponding value in the map.
if (&block != &*called_function->entry() &&
!result_id_map.count(block.GetLabel()->result_id())) {
!result_id_map.count(block.id()) &&
!transformation_context.GetOverflowIdSource()->HasOverflowIds()) {
return false;
}
@ -81,7 +83,8 @@ bool TransformationInlineFunction::IsApplicable(
// If |instruction| has result id, then it must have a mapped id in
// |result_id_map|.
if (instruction.HasResultId() &&
!result_id_map.count(instruction.result_id())) {
!result_id_map.count(instruction.result_id()) &&
!transformation_context.GetOverflowIdSource()->HasOverflowIds()) {
return false;
}
}
@ -100,15 +103,37 @@ bool TransformationInlineFunction::IsApplicable(
}
void TransformationInlineFunction::Apply(
opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
opt::IRContext* ir_context,
TransformationContext* transformation_context) const {
auto* function_call_instruction =
ir_context->get_def_use_mgr()->GetDef(message_.function_call_id());
auto* caller_function =
ir_context->get_instr_block(function_call_instruction)->GetParent();
auto* called_function = fuzzerutil::FindFunction(
ir_context, function_call_instruction->GetSingleWordInOperand(0));
const auto result_id_map =
std::map<uint32_t, uint32_t> result_id_map =
fuzzerutil::RepeatedUInt32PairToMap(message_.result_id_map());
// If there are gaps in the result id map, fill them using overflow ids.
for (auto& block : *called_function) {
if (&block != &*called_function->entry() &&
!result_id_map.count(block.id())) {
result_id_map.insert(
{block.id(),
transformation_context->GetOverflowIdSource()->GetNextOverflowId()});
}
for (auto& instruction : block) {
// If |instruction| has result id, then it must have a mapped id in
// |result_id_map|.
if (instruction.HasResultId() &&
!result_id_map.count(instruction.result_id())) {
result_id_map.insert({instruction.result_id(),
transformation_context->GetOverflowIdSource()
->GetNextOverflowId()});
}
}
}
auto* successor_block = ir_context->cfg()->block(
ir_context->get_instr_block(function_call_instruction)
->terminator()
@ -128,7 +153,7 @@ void TransformationInlineFunction::Apply(
MakeUnique<opt::Instruction>(entry_block_instruction));
}
AdaptInlinedInstruction(ir_context, inlined_instruction);
AdaptInlinedInstruction(result_id_map, ir_context, inlined_instruction);
}
// Inline the |called_function| non-entry blocks.
@ -141,13 +166,11 @@ void TransformationInlineFunction::Apply(
cloned_block = caller_function->InsertBasicBlockBefore(
std::unique_ptr<opt::BasicBlock>(cloned_block), successor_block);
cloned_block->SetParent(caller_function);
cloned_block->GetLabel()->SetResultId(
result_id_map.at(cloned_block->GetLabel()->result_id()));
fuzzerutil::UpdateModuleIdBound(ir_context,
cloned_block->GetLabel()->result_id());
cloned_block->GetLabel()->SetResultId(result_id_map.at(cloned_block->id()));
fuzzerutil::UpdateModuleIdBound(ir_context, cloned_block->id());
for (auto& inlined_instruction : *cloned_block) {
AdaptInlinedInstruction(ir_context, &inlined_instruction);
AdaptInlinedInstruction(result_id_map, ir_context, &inlined_instruction);
}
}
@ -202,14 +225,13 @@ bool TransformationInlineFunction::IsSuitableForInlining(
}
void TransformationInlineFunction::AdaptInlinedInstruction(
const std::map<uint32_t, uint32_t>& result_id_map,
opt::IRContext* ir_context,
opt::Instruction* instruction_to_be_inlined) const {
auto* function_call_instruction =
ir_context->get_def_use_mgr()->GetDef(message_.function_call_id());
auto* called_function = fuzzerutil::FindFunction(
ir_context, function_call_instruction->GetSingleWordInOperand(0));
const auto result_id_map =
fuzzerutil::RepeatedUInt32PairToMap(message_.result_id_map());
const auto* function_call_block =
ir_context->get_instr_block(function_call_instruction);

View File

@ -33,7 +33,7 @@ class TransformationInlineFunction : public Transformation {
const std::map<uint32_t, uint32_t>& result_id_map);
// - |message_.result_id_map| must map the instructions of the called function
// to fresh ids.
// to fresh ids, unless overflow ids are available.
// - |message_.function_call_id| must be an OpFunctionCall instruction.
// It must not have an early return and must not use OpUnreachable or
// OpKill. This is to guard against making the module invalid when the
@ -67,8 +67,9 @@ class TransformationInlineFunction : public Transformation {
// Inline |instruction_to_be_inlined| by setting its ids to the corresponding
// ids in |result_id_map|.
void AdaptInlinedInstruction(opt::IRContext* ir_context,
opt::Instruction* instruction) const;
void AdaptInlinedInstruction(
const std::map<uint32_t, uint32_t>& result_id_map,
opt::IRContext* ir_context, opt::Instruction* instruction) const;
};
} // namespace fuzz

View File

@ -14,6 +14,7 @@
#include "source/fuzz/transformation_inline_function.h"
#include "source/fuzz/counter_overflow_id_source.h"
#include "source/fuzz/instruction_descriptor.h"
#include "test/fuzz/fuzz_test_util.h"
@ -533,6 +534,7 @@ TEST(TransformationInlineFunctionTest, ApplyToMultipleFunctions) {
ASSERT_FALSE(
transformation.IsApplicable(context.get(), transformation_context));
#ifndef NDEBUG
// Tests the id of the returned value not included in the id map.
transformation = TransformationInlineFunction(25, {{56, 69},
{57, 70},
@ -544,8 +546,10 @@ TEST(TransformationInlineFunctionTest, ApplyToMultipleFunctions) {
{64, 76},
{65, 77},
{66, 78}});
ASSERT_FALSE(
transformation.IsApplicable(context.get(), transformation_context));
ASSERT_DEATH(
transformation.IsApplicable(context.get(), transformation_context),
"Bad attempt to query whether overflow ids are available.");
#endif
transformation = TransformationInlineFunction(25, {{57, 69},
{58, 70},
@ -819,6 +823,198 @@ TEST(TransformationInlineFunctionTest, HandlesOpPhisInTheSecondBlock) {
ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
}
TEST(TransformationInlineFunctionTest, OverflowIds) {
std::string reference_shader = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %39 "main"
; Types
%2 = OpTypeFloat 32
%3 = OpTypeVector %2 4
%4 = OpTypePointer Function %3
%5 = OpTypeVoid
%6 = OpTypeFunction %5
%7 = OpTypeFunction %2 %4 %4
; Constant scalars
%8 = OpConstant %2 1
%9 = OpConstant %2 2
%10 = OpConstant %2 3
%11 = OpConstant %2 4
%12 = OpConstant %2 5
%13 = OpConstant %2 6
%14 = OpConstant %2 7
%15 = OpConstant %2 8
; Constant vectors
%16 = OpConstantComposite %3 %8 %9 %10 %11
%17 = OpConstantComposite %3 %12 %13 %14 %15
; dot product function
%18 = OpFunction %2 None %7
%19 = OpFunctionParameter %4
%20 = OpFunctionParameter %4
%21 = OpLabel
%22 = OpLoad %3 %19
%23 = OpLoad %3 %20
%24 = OpCompositeExtract %2 %22 0
%25 = OpCompositeExtract %2 %23 0
%26 = OpFMul %2 %24 %25
%27 = OpCompositeExtract %2 %22 1
%28 = OpCompositeExtract %2 %23 1
%29 = OpFMul %2 %27 %28
OpBranch %100
%100 = OpLabel
%30 = OpCompositeExtract %2 %22 2
%31 = OpCompositeExtract %2 %23 2
%32 = OpFMul %2 %30 %31
%33 = OpCompositeExtract %2 %22 3
%34 = OpCompositeExtract %2 %23 3
%35 = OpFMul %2 %33 %34
%36 = OpFAdd %2 %26 %29
%37 = OpFAdd %2 %32 %36
%38 = OpFAdd %2 %35 %37
OpReturnValue %38
OpFunctionEnd
; main function
%39 = OpFunction %5 None %6
%40 = OpLabel
%41 = OpVariable %4 Function
%42 = OpVariable %4 Function
OpStore %41 %16
OpStore %42 %17
%43 = OpFunctionCall %2 %18 %41 %42 ; dot product function call
OpBranch %44
%44 = OpLabel
OpReturn
OpFunctionEnd
)";
const auto env = SPV_ENV_UNIVERSAL_1_5;
const auto consumer = nullptr;
const auto context =
BuildModule(env, consumer, reference_shader, kFuzzAssembleOption);
ASSERT_TRUE(IsValid(env, context.get()));
spvtools::ValidatorOptions validator_options;
auto overflow_ids_unique_ptr = MakeUnique<CounterOverflowIdSource>(1000);
auto overflow_ids_ptr = overflow_ids_unique_ptr.get();
TransformationContext transformation_context(
MakeUnique<FactManager>(context.get()), validator_options,
std::move(overflow_ids_unique_ptr));
auto transformation = TransformationInlineFunction(43, {{22, 45},
{23, 46},
{24, 47},
{25, 48},
{26, 49},
{27, 50},
{28, 51},
{29, 52}});
// The following ids are left un-mapped; overflow ids will be required for
// them: 30, 31, 32, 33, 34, 35, 36, 37, 38, 100
ASSERT_TRUE(
transformation.IsApplicable(context.get(), transformation_context));
ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context,
overflow_ids_ptr->GetIssuedOverflowIds());
std::string variant_shader = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %39 "main"
; Types
%2 = OpTypeFloat 32
%3 = OpTypeVector %2 4
%4 = OpTypePointer Function %3
%5 = OpTypeVoid
%6 = OpTypeFunction %5
%7 = OpTypeFunction %2 %4 %4
; Constant scalars
%8 = OpConstant %2 1
%9 = OpConstant %2 2
%10 = OpConstant %2 3
%11 = OpConstant %2 4
%12 = OpConstant %2 5
%13 = OpConstant %2 6
%14 = OpConstant %2 7
%15 = OpConstant %2 8
; Constant vectors
%16 = OpConstantComposite %3 %8 %9 %10 %11
%17 = OpConstantComposite %3 %12 %13 %14 %15
; dot product function
%18 = OpFunction %2 None %7
%19 = OpFunctionParameter %4
%20 = OpFunctionParameter %4
%21 = OpLabel
%22 = OpLoad %3 %19
%23 = OpLoad %3 %20
%24 = OpCompositeExtract %2 %22 0
%25 = OpCompositeExtract %2 %23 0
%26 = OpFMul %2 %24 %25
%27 = OpCompositeExtract %2 %22 1
%28 = OpCompositeExtract %2 %23 1
%29 = OpFMul %2 %27 %28
OpBranch %100
%100 = OpLabel
%30 = OpCompositeExtract %2 %22 2
%31 = OpCompositeExtract %2 %23 2
%32 = OpFMul %2 %30 %31
%33 = OpCompositeExtract %2 %22 3
%34 = OpCompositeExtract %2 %23 3
%35 = OpFMul %2 %33 %34
%36 = OpFAdd %2 %26 %29
%37 = OpFAdd %2 %32 %36
%38 = OpFAdd %2 %35 %37
OpReturnValue %38
OpFunctionEnd
; main function
%39 = OpFunction %5 None %6
%40 = OpLabel
%41 = OpVariable %4 Function
%42 = OpVariable %4 Function
OpStore %41 %16
OpStore %42 %17
%45 = OpLoad %3 %41
%46 = OpLoad %3 %42
%47 = OpCompositeExtract %2 %45 0
%48 = OpCompositeExtract %2 %46 0
%49 = OpFMul %2 %47 %48
%50 = OpCompositeExtract %2 %45 1
%51 = OpCompositeExtract %2 %46 1
%52 = OpFMul %2 %50 %51
OpBranch %1000
%1000 = OpLabel
%1001 = OpCompositeExtract %2 %45 2
%1002 = OpCompositeExtract %2 %46 2
%1003 = OpFMul %2 %1001 %1002
%1004 = OpCompositeExtract %2 %45 3
%1005 = OpCompositeExtract %2 %46 3
%1006 = OpFMul %2 %1004 %1005
%1007 = OpFAdd %2 %49 %52
%1008 = OpFAdd %2 %1003 %1007
%1009 = OpFAdd %2 %1006 %1008
%43 = OpCopyObject %2 %1009
OpBranch %44
%44 = OpLabel
OpReturn
OpFunctionEnd
)";
ASSERT_TRUE(IsValid(env, context.get()));
ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
}
} // namespace
} // namespace fuzz
} // namespace spvtools