spirv-fuzz: Handle OpPhi during constant obfuscation (#3640)

Fixes #3639.
This commit is contained in:
Vasyl Teliman 2020-08-05 21:17:27 +03:00 committed by GitHub
parent 28f32ca53e
commit a10e760596
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 190 additions and 3 deletions

View File

@ -1317,6 +1317,31 @@ MapToRepeatedUInt32Pair(const std::map<uint32_t, uint32_t>& data) {
return result;
}
opt::Instruction* GetLastInsertBeforeInstruction(opt::IRContext* ir_context,
uint32_t block_id,
SpvOp opcode) {
// CFG::block uses std::map::at which throws an exception when |block_id| is
// invalid. The error message is unhelpful, though. Thus, we test that
// |block_id| is valid here.
const auto* label_inst = ir_context->get_def_use_mgr()->GetDef(block_id);
(void)label_inst; // Make compilers happy in release mode.
assert(label_inst && label_inst->opcode() == SpvOpLabel &&
"|block_id| is invalid");
auto* block = ir_context->cfg()->block(block_id);
auto it = block->rbegin();
assert(it != block->rend() && "Basic block can't be empty");
if (block->GetMergeInst()) {
++it;
assert(it != block->rend() &&
"|block| must have at least two instructions:"
"terminator and a merge instruction");
}
return CanInsertOpcodeBeforeInstruction(opcode, &*it) ? &*it : nullptr;
}
} // namespace fuzzerutil
} // namespace fuzz

View File

@ -487,6 +487,12 @@ std::map<uint32_t, uint32_t> RepeatedUInt32PairToMap(
google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>
MapToRepeatedUInt32Pair(const std::map<uint32_t, uint32_t>& data);
// Returns the last instruction in |block_id| before which an instruction with
// opcode |opcode| can be inserted, or nullptr if there is no such instruction.
opt::Instruction* GetLastInsertBeforeInstruction(opt::IRContext* ir_context,
uint32_t block_id,
SpvOp opcode);
} // namespace fuzzerutil
} // namespace fuzz

View File

@ -90,6 +90,40 @@ TransformationReplaceConstantWithUniform::MakeLoadInstruction(
operands_for_load);
}
opt::Instruction*
TransformationReplaceConstantWithUniform::GetInsertBeforeInstruction(
opt::IRContext* ir_context) const {
auto* result =
FindInstructionContainingUse(message_.id_use_descriptor(), ir_context);
if (!result) {
return nullptr;
}
// The use might be in an OpPhi instruction.
if (result->opcode() == SpvOpPhi) {
// OpPhi instructions must be the first instructions in a block. Thus, we
// can't insert above the OpPhi instruction. Given the predecessor block
// that corresponds to the id use, get the last instruction in that block
// above which we can insert OpAccessChain and OpLoad.
return fuzzerutil::GetLastInsertBeforeInstruction(
ir_context,
result->GetSingleWordInOperand(
message_.id_use_descriptor().in_operand_index() + 1),
SpvOpLoad);
}
// The only operand that we could've replaced in the OpBranchConditional is
// the condition id. But that operand has a boolean type and uniform variables
// can't store booleans (see the spec on OpTypeBool). Thus, |result| can't be
// an OpBranchConditional.
assert(result->opcode() != SpvOpBranchConditional &&
"OpBranchConditional has no operands to replace");
assert(fuzzerutil::CanInsertOpcodeBeforeInstruction(SpvOpLoad, result) &&
"We should be able to insert OpLoad and OpAccessChain at this point");
return result;
}
bool TransformationReplaceConstantWithUniform::IsApplicable(
opt::IRContext* ir_context,
const TransformationContext& transformation_context) const {
@ -188,6 +222,12 @@ bool TransformationReplaceConstantWithUniform::IsApplicable(
}
}
// Once all checks are completed, we should be able to safely insert
// OpAccessChain and OpLoad into the module.
assert(GetInsertBeforeInstruction(ir_context) &&
"There must exist an instruction that we can use to insert "
"OpAccessChain and OpLoad above");
return true;
}
@ -195,7 +235,7 @@ void TransformationReplaceConstantWithUniform::Apply(
spvtools::opt::IRContext* ir_context,
TransformationContext* /*unused*/) const {
// Get the instruction that contains the id use we wish to replace.
auto instruction_containing_constant_use =
auto* instruction_containing_constant_use =
FindInstructionContainingUse(message_.id_use_descriptor(), ir_context);
assert(instruction_containing_constant_use &&
"Precondition requires that the id use can be found.");
@ -210,12 +250,17 @@ void TransformationReplaceConstantWithUniform::Apply(
->GetDef(message_.id_use_descriptor().id_of_interest())
->type_id();
// Get an instruction that will be used to insert OpAccessChain and OpLoad.
auto* insert_before_inst = GetInsertBeforeInstruction(ir_context);
assert(insert_before_inst &&
"There must exist an insertion point for OpAccessChain and OpLoad");
// Add an access chain instruction to target the uniform element.
instruction_containing_constant_use->InsertBefore(
insert_before_inst->InsertBefore(
MakeAccessChainInstruction(ir_context, constant_type_id));
// Add a load from this access chain.
instruction_containing_constant_use->InsertBefore(
insert_before_inst->InsertBefore(
MakeLoadInstruction(ir_context, constant_type_id));
// Adjust the instruction containing the usage of the constant so that this

View File

@ -84,6 +84,11 @@ class TransformationReplaceConstantWithUniform : public Transformation {
std::unique_ptr<opt::Instruction> MakeLoadInstruction(
spvtools::opt::IRContext* ir_context, uint32_t constant_type_id) const;
// OpAccessChain and OpLoad will be inserted above the instruction returned
// by this function. Returns nullptr if no such instruction is present.
opt::Instruction* GetInsertBeforeInstruction(
opt::IRContext* ir_context) const;
protobufs::TransformationReplaceConstantWithUniform message_;
};

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include "source/fuzz/transformation_replace_constant_with_uniform.h"
#include "source/fuzz/instruction_descriptor.h"
#include "source/fuzz/uniform_buffer_element_descriptor.h"
#include "test/fuzz/fuzz_test_util.h"
@ -1548,6 +1549,111 @@ TEST(TransformationReplaceConstantWithUniformTest,
.IsApplicable(context.get(), transformation_context));
}
TEST(TransformationReplaceConstantWithUniformTest, ReplaceOpPhiOperand) {
std::string shader = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %4 "main"
OpExecutionMode %4 OriginUpperLeft
OpSource ESSL 320
OpDecorate %32 DescriptorSet 0
OpDecorate %32 Binding 0
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%6 = OpTypeInt 32 1
%7 = OpConstant %6 2
%13 = OpConstant %6 4
%21 = OpConstant %6 1
%34 = OpConstant %6 0
%10 = OpTypeBool
%30 = OpTypeStruct %6
%31 = OpTypePointer Uniform %30
%32 = OpVariable %31 Uniform
%33 = OpTypePointer Uniform %6
%4 = OpFunction %2 None %3
%11 = OpLabel
OpBranch %5
%5 = OpLabel
%23 = OpPhi %6 %7 %11 %20 %15
%9 = OpSLessThan %10 %23 %13
OpLoopMerge %8 %15 None
OpBranchConditional %9 %15 %8
%15 = OpLabel
%20 = OpIAdd %6 %23 %21
OpBranch %5
%8 = OpLabel
OpReturn
OpFunctionEnd
)";
const auto env = SPV_ENV_UNIVERSAL_1_3;
const auto consumer = nullptr;
const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
ASSERT_TRUE(IsValid(env, context.get()));
FactManager fact_manager;
spvtools::ValidatorOptions validator_options;
TransformationContext transformation_context(&fact_manager,
validator_options);
auto int_descriptor = MakeUniformBufferElementDescriptor(0, 0, {0});
ASSERT_TRUE(
AddFactHelper(&transformation_context, context.get(), 2, int_descriptor));
{
TransformationReplaceConstantWithUniform transformation(
MakeIdUseDescriptor(7, MakeInstructionDescriptor(23, SpvOpPhi, 0), 0),
int_descriptor, 50, 51);
ASSERT_TRUE(
transformation.IsApplicable(context.get(), transformation_context));
transformation.Apply(context.get(), &transformation_context);
ASSERT_TRUE(IsValid(env, context.get()));
}
std::string after_transformation = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %4 "main"
OpExecutionMode %4 OriginUpperLeft
OpSource ESSL 320
OpDecorate %32 DescriptorSet 0
OpDecorate %32 Binding 0
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%6 = OpTypeInt 32 1
%7 = OpConstant %6 2
%13 = OpConstant %6 4
%21 = OpConstant %6 1
%34 = OpConstant %6 0
%10 = OpTypeBool
%30 = OpTypeStruct %6
%31 = OpTypePointer Uniform %30
%32 = OpVariable %31 Uniform
%33 = OpTypePointer Uniform %6
%4 = OpFunction %2 None %3
%11 = OpLabel
%50 = OpAccessChain %33 %32 %34
%51 = OpLoad %6 %50
OpBranch %5
%5 = OpLabel
%23 = OpPhi %6 %51 %11 %20 %15
%9 = OpSLessThan %10 %23 %13
OpLoopMerge %8 %15 None
OpBranchConditional %9 %15 %8
%15 = OpLabel
%20 = OpIAdd %6 %23 %21
OpBranch %5
%8 = OpLabel
OpReturn
OpFunctionEnd
)";
ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
}
} // namespace
} // namespace fuzz
} // namespace spvtools