mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-12-27 10:20:14 +00:00
57b3723c5c
Fixes #3902.
317 lines
12 KiB
C++
317 lines
12 KiB
C++
// Copyright (c) 2020 Vasyl Teliman
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
#include "source/fuzz/transformation_replace_params_with_struct.h"
|
|
|
|
#include <vector>
|
|
|
|
#include "source/fuzz/fuzzer_util.h"
|
|
|
|
namespace spvtools {
|
|
namespace fuzz {
|
|
|
|
TransformationReplaceParamsWithStruct::TransformationReplaceParamsWithStruct(
|
|
const protobufs::TransformationReplaceParamsWithStruct& message)
|
|
: message_(message) {}
|
|
|
|
// Builds the protobuf message from its individual components:
//   - |parameter_id|: result ids of the parameters to be replaced;
//   - |fresh_function_type_id|: id stored as the fresh function type id;
//   - |fresh_parameter_id|: id stored as the fresh parameter id;
//   - |caller_id_to_fresh_composite_id|: map stored in the message as a
//     repeated field of pairs.
TransformationReplaceParamsWithStruct::TransformationReplaceParamsWithStruct(
    const std::vector<uint32_t>& parameter_id, uint32_t fresh_function_type_id,
    uint32_t fresh_parameter_id,
    const std::map<uint32_t, uint32_t>& caller_id_to_fresh_composite_id) {
  // Record the ids of the parameters that are being replaced, in order.
  for (const uint32_t replaced_id : parameter_id) {
    message_.add_parameter_id(replaced_id);
  }

  // Record the fresh ids.
  message_.set_fresh_parameter_id(fresh_parameter_id);
  message_.set_fresh_function_type_id(fresh_function_type_id);

  // Protobuf has no direct counterpart of std::map here, so the map is stored
  // as a repeated field of pairs.
  *message_.mutable_caller_id_to_fresh_composite_id() =
      fuzzerutil::MapToRepeatedUInt32Pair(caller_id_to_fresh_composite_id);
}
|
|
|
|
// Returns true iff all of the following hold:
//   - |message_.parameter_id| is non-empty and has no duplicates;
//   - every listed id is a parameter of one and the same function, which is
//     not an entry point, and has a supported type;
//   - a struct type with the parameters' types as members exists;
//   - every caller of the function has an entry in
//     |message_.caller_id_to_fresh_composite_id|;
//   - all fresh ids in the message are distinct and fresh.
bool TransformationReplaceParamsWithStruct::IsApplicable(
    opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
  const std::vector<uint32_t> ids_to_replace(message_.parameter_id().begin(),
                                             message_.parameter_id().end());

  // There must be something to replace...
  if (ids_to_replace.empty()) {
    return false;
  }

  // ...and no parameter may be listed twice.
  if (fuzzerutil::HasDuplicates(ids_to_replace)) {
    return false;
  }

  // The first id is used to locate the function that owns the parameters.
  // fuzzerutil::GetFunctionFromParameterId requires a valid id, so make sure
  // the id is defined first.
  const uint32_t first_id = ids_to_replace[0];
  if (ir_context->get_def_use_mgr()->GetDef(first_id) == nullptr) {
    return false;
  }

  const auto* function =
      fuzzerutil::GetFunctionFromParameterId(ir_context, first_id);
  if (function == nullptr) {
    return false;
  }

  // Entry-point functions are not eligible.
  if (fuzzerutil::FunctionIsEntryPoint(ir_context, function->result_id())) {
    return false;
  }

  // Gather the result ids of all parameters of |function|.
  std::unordered_set<uint32_t> function_param_ids;
  for (const auto* param_inst :
       fuzzerutil::GetParameters(ir_context, function->result_id())) {
    function_param_ids.insert(param_inst->result_id());
  }

  // Every id must (in this order of evaluation) be defined, be a parameter of
  // |function| and have a type that this transformation supports.
  for (const uint32_t id : ids_to_replace) {
    const bool is_replaceable =
        ir_context->get_def_use_mgr()->GetDef(id) != nullptr &&
        function_param_ids.count(id) != 0 &&
        IsParameterTypeSupported(ir_context,
                                 fuzzerutil::GetTypeId(ir_context, id));
    if (!is_replaceable) {
      return false;
    }
  }

  // At this point the function is known to have at least
  // |ids_to_replace.size()| parameters.

  // The struct type that will hold the replaced parameters must already exist
  // in the module.
  if (MaybeGetRequiredStructType(ir_context) == 0) {
    return false;
  }

  const auto caller_id_to_fresh_composite_id =
      fuzzerutil::RepeatedUInt32PairToMap(
          message_.caller_id_to_fresh_composite_id());

  // Check that |caller_id_to_fresh_composite_id| is valid: every caller of
  // |function| must be present in the map. It's fine if the map contains more
  // ids than there are callers (those ids will not be used).
  for (const auto* call_inst :
       fuzzerutil::GetCallers(ir_context, function->result_id())) {
    if (caller_id_to_fresh_composite_id.count(call_inst->result_id()) == 0) {
      return false;
    }
  }

  // Finally, all fresh ids must be pairwise distinct and fresh.
  std::vector<uint32_t> fresh_ids = {message_.fresh_function_type_id(),
                                     message_.fresh_parameter_id()};
  for (const auto& caller_and_composite : caller_id_to_fresh_composite_id) {
    fresh_ids.push_back(caller_and_composite.second);
  }

  if (fuzzerutil::HasDuplicates(fresh_ids)) {
    return false;
  }

  for (const uint32_t fresh_id : fresh_ids) {
    if (!fuzzerutil::IsFreshId(ir_context, fresh_id)) {
      return false;
    }
  }

  return true;
}
|
|
|
|
// Applies the transformation:
//   1. appends a new struct-typed parameter to the function;
//   2. at every call site, packs the replaced arguments into an
//      OpCompositeConstruct and passes it as the last argument;
//   3. in the function's first block, re-creates each replaced parameter's id
//      via OpCompositeExtract from the new parameter and removes the old
//      parameter;
//   4. updates the function's type;
//   5. invalidates all analyses.
void TransformationReplaceParamsWithStruct::Apply(
    opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
  auto* function = fuzzerutil::GetFunctionFromParameterId(
      ir_context, message_.parameter_id(0));
  assert(function &&
         "All parameters' ids should've been checked in the IsApplicable");

  // Type id of the OpTypeStruct that becomes the type of the new parameter.
  auto struct_type_id = MaybeGetRequiredStructType(ir_context);
  assert(struct_type_id &&
         "IsApplicable should've guaranteed that this value isn't equal to 0");

  // Step 1: add the new parameter to the function.
  const uint32_t fresh_parameter_id = message_.fresh_parameter_id();
  function->AddParameter(MakeUnique<opt::Instruction>(
      ir_context, SpvOpFunctionParameter, struct_type_id, fresh_parameter_id,
      opt::Instruction::OperandList()));

  fuzzerutil::UpdateModuleIdBound(ir_context, fresh_parameter_id);

  // Positions of the replaced parameters in the function's parameter list.
  // They are needed both to rewrite OpFunctionCall instructions and to build
  // the new function type.
  const auto indices_of_replaced_params =
      ComputeIndicesOfReplacedParameters(ir_context);

  const auto caller_id_to_fresh_composite_id =
      fuzzerutil::RepeatedUInt32PairToMap(
          message_.caller_id_to_fresh_composite_id());

  // Step 2: rewrite every call site.
  for (auto* call_inst :
       fuzzerutil::GetCallers(ir_context, function->result_id())) {
    // Collect the operands of the future OpCompositeConstruct. In operand 0
    // of OpFunctionCall is the callee's id, hence the "+ 1".
    opt::Instruction::OperandList composite_components;
    for (const uint32_t param_index : indices_of_replaced_params) {
      composite_components.emplace_back(
          std::move(call_inst->GetInOperand(param_index + 1)));
    }

    // Drop the replaced arguments from the call. This is a separate pass that
    // visits the indices in decreasing order, so that removing one operand
    // does not shift the position of an operand that is still to be removed.
    const std::set<uint32_t, std::greater<uint32_t>> descending_indices(
        indices_of_replaced_params.begin(), indices_of_replaced_params.end());
    for (const uint32_t param_index : descending_indices) {
      // "+ 1" for the same reason as above.
      call_inst->RemoveInOperand(param_index + 1);
    }

    // Build the struct value right before the call...
    const uint32_t fresh_composite_id =
        caller_id_to_fresh_composite_id.at(call_inst->result_id());
    call_inst->InsertBefore(MakeUnique<opt::Instruction>(
        ir_context, SpvOpCompositeConstruct, struct_type_id,
        fresh_composite_id, std::move(composite_components)));

    // ...and pass it as the new trailing argument.
    call_inst->AddOperand({SPV_OPERAND_TYPE_ID, {fresh_composite_id}});
    fuzzerutil::UpdateModuleIdBound(ir_context, fresh_composite_id);
  }

  // Step 3: unpack the struct in the function's first block and remove the
  // replaced parameters.
  for (int i = 0; i < message_.parameter_id_size(); ++i) {
    const auto* param_inst =
        ir_context->get_def_use_mgr()->GetDef(message_.parameter_id(i));
    assert(param_inst && "Parameter id is invalid");

    // Find the first instruction of the first block before which an
    // OpCompositeExtract may be inserted (this skips OpVariable
    // instructions).
    auto& first_block = *function->begin();
    auto iter = first_block.begin();
    for (; iter != first_block.end() &&
           !fuzzerutil::CanInsertOpcodeBeforeInstruction(SpvOpCompositeExtract,
                                                         iter);
         ++iter) {
    }

    assert(fuzzerutil::CanInsertOpcodeBeforeInstruction(SpvOpCompositeExtract,
                                                        iter) &&
           "Can't extract parameter's value from the structure");

    // The extracted value reuses the result id and type of the parameter it
    // replaces; member |i| of the struct holds the i'th replaced parameter.
    const uint32_t replaced_param_id = param_inst->result_id();
    const uint32_t replaced_param_type_id = param_inst->type_id();
    iter.InsertBefore(MakeUnique<opt::Instruction>(
        ir_context, SpvOpCompositeExtract, replaced_param_type_id,
        replaced_param_id,
        opt::Instruction::OperandList{
            {SPV_OPERAND_TYPE_ID, {fresh_parameter_id}},
            {SPV_OPERAND_TYPE_LITERAL_INTEGER, {static_cast<uint32_t>(i)}}}));

    fuzzerutil::RemoveParameter(ir_context, replaced_param_id);
  }

  // Step 4: update the function's type.
  {
    // A dedicated scope is used because |old_function_type| might become a
    // dangling pointer after the call to fuzzerutil::UpdateFunctionType.
    auto* old_function_type = fuzzerutil::GetFunctionType(ir_context, function);
    assert(old_function_type && "Function has invalid type");

    // In operand 0 of OpTypeFunction is the return type id; the parameter
    // with index |param_index| is therefore in operand |param_index + 1|.
    std::vector<uint32_t> parameter_type_ids;
    for (uint32_t param_index = 0;
         param_index + 1 < old_function_type->NumInOperands(); ++param_index) {
      const bool is_replaced =
          std::find(indices_of_replaced_params.begin(),
                    indices_of_replaced_params.end(),
                    param_index) != indices_of_replaced_params.end();
      if (is_replaced) {
        continue;
      }
      parameter_type_ids.push_back(
          old_function_type->GetSingleWordInOperand(param_index + 1));
    }

    // The struct parameter goes last.
    parameter_type_ids.push_back(struct_type_id);

    const uint32_t return_type_id =
        old_function_type->GetSingleWordInOperand(0);
    fuzzerutil::UpdateFunctionType(ir_context, function->result_id(),
                                   message_.fresh_function_type_id(),
                                   return_type_id, parameter_type_ids);
  }

  // Step 5: make sure our changes are analyzed.
  ir_context->InvalidateAnalysesExceptFor(
      opt::IRContext::Analysis::kAnalysisNone);
}
|
|
|
|
// Returns a generic protobufs::Transformation whose
// |replace_params_with_struct| field holds a copy of |message_|.
protobufs::Transformation TransformationReplaceParamsWithStruct::ToMessage()
    const {
  protobufs::Transformation wrapped_message;
  *wrapped_message.mutable_replace_params_with_struct() = message_;
  return wrapped_message;
}
|
|
|
|
// Returns true iff a parameter whose type id is |param_type_id| can be
// replaced by this transformation. Currently that is exactly the set of types
// for which fuzzerutil::CanCreateConstant returns true.
bool TransformationReplaceParamsWithStruct::IsParameterTypeSupported(
    opt::IRContext* ir_context, uint32_t param_type_id) {
  // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3403):
  //  Consider adding support for more types of parameters.
  const bool is_supported =
      fuzzerutil::CanCreateConstant(ir_context, param_type_id);
  return is_supported;
}
|
|
|
|
// Looks up a struct type whose members are the types of the parameters listed
// in |message_.parameter_id|, in the same order, and returns whatever
// fuzzerutil::MaybeGetStructType returns for that list of type ids (callers in
// this file treat 0 as "no such type").
uint32_t TransformationReplaceParamsWithStruct::MaybeGetRequiredStructType(
    opt::IRContext* ir_context) const {
  std::vector<uint32_t> member_type_ids;
  for (int i = 0; i < message_.parameter_id_size(); ++i) {
    const uint32_t param_id = message_.parameter_id(i);
    member_type_ids.push_back(fuzzerutil::GetTypeId(ir_context, param_id));
  }

  return fuzzerutil::MaybeGetStructType(ir_context, member_type_ids);
}
|
|
|
|
std::vector<uint32_t>
|
|
TransformationReplaceParamsWithStruct::ComputeIndicesOfReplacedParameters(
|
|
opt::IRContext* ir_context) const {
|
|
assert(!message_.parameter_id().empty() &&
|
|
"There must be at least one parameter to replace");
|
|
|
|
const auto* function = fuzzerutil::GetFunctionFromParameterId(
|
|
ir_context, message_.parameter_id(0));
|
|
assert(function && "|parameter_id|s are invalid");
|
|
|
|
std::vector<uint32_t> result;
|
|
|
|
auto params = fuzzerutil::GetParameters(ir_context, function->result_id());
|
|
for (auto id : message_.parameter_id()) {
|
|
auto it = std::find_if(params.begin(), params.end(),
|
|
[id](const opt::Instruction* param) {
|
|
return param->result_id() == id;
|
|
});
|
|
assert(it != params.end() && "Parameter's id is invalid");
|
|
result.push_back(static_cast<uint32_t>(it - params.begin()));
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
std::unordered_set<uint32_t>
|
|
TransformationReplaceParamsWithStruct::GetFreshIds() const {
|
|
std::unordered_set<uint32_t> result = {message_.fresh_function_type_id(),
|
|
message_.fresh_parameter_id()};
|
|
for (auto& pair : message_.caller_id_to_fresh_composite_id()) {
|
|
result.insert(pair.second());
|
|
}
|
|
return result;
|
|
}
|
|
|
|
} // namespace fuzz
|
|
} // namespace spvtools
|