spirv-fuzz: Add replace linear algebra instruction transformation (#3402)

This PR implements a transformation that replaces
a linear algebra instruction with its mathematical definition.
This commit is contained in:
André Perez 2020-06-16 07:20:51 -03:00 committed by GitHub
parent 52a5f074e9
commit 12a4fb3bc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 919 additions and 0 deletions

View File

@ -64,6 +64,7 @@ if(SPIRV_BUILD_FUZZER)
fuzzer_pass_permute_blocks.h
fuzzer_pass_permute_function_parameters.h
fuzzer_pass_push_ids_through_variables.h
fuzzer_pass_replace_linear_algebra_instructions.h
fuzzer_pass_split_blocks.h
fuzzer_pass_swap_commutable_operands.h
fuzzer_pass_toggle_access_chain_instruction.h
@ -117,6 +118,7 @@ if(SPIRV_BUILD_FUZZER)
transformation_replace_boolean_constant_with_constant_binary.h
transformation_replace_constant_with_uniform.h
transformation_replace_id_with_synonym.h
transformation_replace_linear_algebra_instruction.h
transformation_set_function_control.h
transformation_set_loop_control.h
transformation_set_memory_operands_mask.h
@ -163,6 +165,7 @@ if(SPIRV_BUILD_FUZZER)
fuzzer_pass_permute_blocks.cpp
fuzzer_pass_permute_function_parameters.cpp
fuzzer_pass_push_ids_through_variables.cpp
fuzzer_pass_replace_linear_algebra_instructions.cpp
fuzzer_pass_split_blocks.cpp
fuzzer_pass_swap_commutable_operands.cpp
fuzzer_pass_toggle_access_chain_instruction.cpp
@ -215,6 +218,7 @@ if(SPIRV_BUILD_FUZZER)
transformation_replace_boolean_constant_with_constant_binary.cpp
transformation_replace_constant_with_uniform.cpp
transformation_replace_id_with_synonym.cpp
transformation_replace_linear_algebra_instruction.cpp
transformation_set_function_control.cpp
transformation_set_loop_control.cpp
transformation_set_memory_operands_mask.cpp

View File

@ -47,6 +47,7 @@
#include "source/fuzz/fuzzer_pass_permute_blocks.h"
#include "source/fuzz/fuzzer_pass_permute_function_parameters.h"
#include "source/fuzz/fuzzer_pass_push_ids_through_variables.h"
#include "source/fuzz/fuzzer_pass_replace_linear_algebra_instructions.h"
#include "source/fuzz/fuzzer_pass_split_blocks.h"
#include "source/fuzz/fuzzer_pass_swap_commutable_operands.h"
#include "source/fuzz/fuzzer_pass_toggle_access_chain_instruction.h"
@ -253,6 +254,9 @@ Fuzzer::FuzzerResultStatus Fuzzer::Run(
MaybeAddPass<FuzzerPassPushIdsThroughVariables>(
&passes, ir_context.get(), &transformation_context, &fuzzer_context,
transformation_sequence_out);
MaybeAddPass<FuzzerPassReplaceLinearAlgebraInstructions>(
&passes, ir_context.get(), &transformation_context, &fuzzer_context,
transformation_sequence_out);
MaybeAddPass<FuzzerPassSplitBlocks>(
&passes, ir_context.get(), &transformation_context, &fuzzer_context,
transformation_sequence_out);

View File

@ -64,6 +64,8 @@ const std::pair<uint32_t, uint32_t> kChanceOfOutliningFunction = {10, 90};
const std::pair<uint32_t, uint32_t> kChanceOfPermutingParameters = {30, 90};
const std::pair<uint32_t, uint32_t> kChanceOfPushingIdThroughVariable = {5, 50};
const std::pair<uint32_t, uint32_t> kChanceOfReplacingIdWithSynonym = {10, 90};
const std::pair<uint32_t, uint32_t>
kChanceOfReplacingLinearAlgebraInstructions = {10, 90};
const std::pair<uint32_t, uint32_t> kChanceOfSplittingBlock = {40, 95};
const std::pair<uint32_t, uint32_t> kChanceOfTogglingAccessChainInstruction = {
20, 90};
@ -162,6 +164,8 @@ FuzzerContext::FuzzerContext(RandomGenerator* random_generator,
ChooseBetweenMinAndMax(kChanceOfPushingIdThroughVariable);
chance_of_replacing_id_with_synonym_ =
ChooseBetweenMinAndMax(kChanceOfReplacingIdWithSynonym);
chance_of_replacing_linear_algebra_instructions_ =
ChooseBetweenMinAndMax(kChanceOfReplacingLinearAlgebraInstructions);
chance_of_splitting_block_ = ChooseBetweenMinAndMax(kChanceOfSplittingBlock);
chance_of_toggling_access_chain_instruction_ =
ChooseBetweenMinAndMax(kChanceOfTogglingAccessChainInstruction);
@ -171,6 +175,16 @@ FuzzerContext::~FuzzerContext() = default;
uint32_t FuzzerContext::GetFreshId() { return next_fresh_id_++; }
std::vector<uint32_t> FuzzerContext::GetFreshIds(const uint32_t count) {
std::vector<uint32_t> fresh_ids(count);
for (uint32_t& fresh_id : fresh_ids) {
fresh_id = next_fresh_id_++;
}
return fresh_ids;
}
bool FuzzerContext::ChooseEven() { return random_generator_->RandomBool(); }
bool FuzzerContext::ChoosePercentage(uint32_t percentage_chance) {

View File

@ -100,6 +100,9 @@ class FuzzerContext {
// or to have been issued before.
uint32_t GetFreshId();
// Returns a vector of |count| fresh ids.
std::vector<uint32_t> GetFreshIds(const uint32_t count);
// Probabilities associated with applying various transformations.
// Keep them in alphabetical order.
uint32_t GetChanceOfAddingAccessChain() {
@ -185,6 +188,9 @@ class FuzzerContext {
uint32_t GetChanceOfReplacingIdWithSynonym() {
return chance_of_replacing_id_with_synonym_;
}
uint32_t GetChanceOfReplacingLinearAlgebraInstructions() {
return chance_of_replacing_linear_algebra_instructions_;
}
uint32_t GetChanceOfSplittingBlock() { return chance_of_splitting_block_; }
uint32_t GetChanceOfTogglingAccessChainInstruction() {
return chance_of_toggling_access_chain_instruction_;
@ -268,6 +274,7 @@ class FuzzerContext {
uint32_t chance_of_permuting_parameters_;
uint32_t chance_of_pushing_id_through_variable_;
uint32_t chance_of_replacing_id_with_synonym_;
uint32_t chance_of_replacing_linear_algebra_instructions_;
uint32_t chance_of_splitting_block_;
uint32_t chance_of_toggling_access_chain_instruction_;

View File

@ -0,0 +1,64 @@
// Copyright (c) 2020 André Perez Maselco
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/fuzzer_pass_replace_linear_algebra_instructions.h"
#include "source/fuzz/fuzzer_util.h"
#include "source/fuzz/instruction_descriptor.h"
#include "source/fuzz/transformation_replace_linear_algebra_instruction.h"
namespace spvtools {
namespace fuzz {
FuzzerPassReplaceLinearAlgebraInstructions::
FuzzerPassReplaceLinearAlgebraInstructions(
opt::IRContext* ir_context,
TransformationContext* transformation_context,
FuzzerContext* fuzzer_context,
protobufs::TransformationSequence* transformations)
: FuzzerPass(ir_context, transformation_context, fuzzer_context,
transformations) {}
FuzzerPassReplaceLinearAlgebraInstructions::
~FuzzerPassReplaceLinearAlgebraInstructions() = default;
void FuzzerPassReplaceLinearAlgebraInstructions::Apply() {
// For each instruction, checks whether it is a supported linear algebra
// instruction. In this case, the transformation is randomly applied.
GetIRContext()->module()->ForEachInst([this](opt::Instruction* instruction) {
// TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354):
// Right now we only support certain operations. When this issue is
// addressed the following conditional can use the function
// |spvOpcodeIsLinearAlgebra|.
if (instruction->opcode() != SpvOpVectorTimesScalar &&
instruction->opcode() != SpvOpDot) {
return;
}
if (!GetFuzzerContext()->ChoosePercentage(
GetFuzzerContext()
->GetChanceOfReplacingLinearAlgebraInstructions())) {
return;
}
ApplyTransformation(TransformationReplaceLinearAlgebraInstruction(
GetFuzzerContext()->GetFreshIds(
TransformationReplaceLinearAlgebraInstruction::
GetRequiredFreshIdCount(GetIRContext(), instruction)),
MakeInstructionDescriptor(GetIRContext(), instruction)));
});
}
} // namespace fuzz
} // namespace spvtools

View File

@ -0,0 +1,40 @@
// Copyright (c) 2020 André Perez Maselco
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_FUZZER_PASS_REPLACE_LINEAR_ALGEBRA_INSTRUCTIONS_H_
#define SOURCE_FUZZ_FUZZER_PASS_REPLACE_LINEAR_ALGEBRA_INSTRUCTIONS_H_
#include "source/fuzz/fuzzer_pass.h"
namespace spvtools {
namespace fuzz {
// This fuzzer pass replaces linear algebra instructions with its mathematical
// definition.
class FuzzerPassReplaceLinearAlgebraInstructions : public FuzzerPass {
public:
FuzzerPassReplaceLinearAlgebraInstructions(
opt::IRContext* ir_context, TransformationContext* transformation_context,
FuzzerContext* fuzzer_context,
protobufs::TransformationSequence* transformations);
~FuzzerPassReplaceLinearAlgebraInstructions();
void Apply() override;
};
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_FUZZER_PASS_REPLACE_LINEAR_ALGEBRA_INSTRUCTIONS_H_

View File

@ -377,6 +377,7 @@ message Transformation {
TransformationAdjustBranchWeights adjust_branch_weights = 46;
TransformationPushIdThroughVariable push_id_through_variable = 47;
TransformationAddSpecConstantOp add_spec_constant_op = 48;
TransformationReplaceLinearAlgebraInstruction replace_linear_algebra_instruction = 49;
// Add additional option using the next available number.
}
}
@ -1081,6 +1082,33 @@ message TransformationReplaceIdWithSynonym {
}
message TransformationReplaceLinearAlgebraInstruction {
// Replaces a linear algebra instruction with its
// mathematical definition.
// The fresh ids needed to apply the transformation.
repeated uint32 fresh_ids = 1;
// A descriptor for a linear algebra instruction.
// This transformation is only applicable if the described instruction has one of the following opcodes.
// Supported:
// OpVectorTimesScalar
// OpDot
// TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354):
// Right now we only support certain operations. When this issue is addressed
// the supporting comments can be removed.
// To be supported in the future:
// OpTranspose
// OpMatrixTimesScalar
// OpVectorTimesMatrix
// OpMatrixTimesVector
// OpMatrixTimesMatrix
// OpOuterProduct
InstructionDescriptor instruction_descriptor = 2;
}
message TransformationSetFunctionControl {
// A transformation that sets the function control operand of an OpFunction

View File

@ -56,6 +56,7 @@
#include "source/fuzz/transformation_replace_boolean_constant_with_constant_binary.h"
#include "source/fuzz/transformation_replace_constant_with_uniform.h"
#include "source/fuzz/transformation_replace_id_with_synonym.h"
#include "source/fuzz/transformation_replace_linear_algebra_instruction.h"
#include "source/fuzz/transformation_set_function_control.h"
#include "source/fuzz/transformation_set_loop_control.h"
#include "source/fuzz/transformation_set_memory_operands_mask.h"
@ -182,6 +183,10 @@ std::unique_ptr<Transformation> Transformation::FromMessage(
case protobufs::Transformation::TransformationCase::kReplaceIdWithSynonym:
return MakeUnique<TransformationReplaceIdWithSynonym>(
message.replace_id_with_synonym());
case protobufs::Transformation::TransformationCase::
kReplaceLinearAlgebraInstruction:
return MakeUnique<TransformationReplaceLinearAlgebraInstruction>(
message.replace_linear_algebra_instruction());
case protobufs::Transformation::TransformationCase::kSetFunctionControl:
return MakeUnique<TransformationSetFunctionControl>(
message.set_function_control());

View File

@ -0,0 +1,275 @@
// Copyright (c) 2020 André Perez Maselco
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/transformation_replace_linear_algebra_instruction.h"
#include "source/fuzz/fuzzer_util.h"
#include "source/fuzz/instruction_descriptor.h"
namespace spvtools {
namespace fuzz {
TransformationReplaceLinearAlgebraInstruction::
TransformationReplaceLinearAlgebraInstruction(
const spvtools::fuzz::protobufs::
TransformationReplaceLinearAlgebraInstruction& message)
: message_(message) {}
TransformationReplaceLinearAlgebraInstruction::
TransformationReplaceLinearAlgebraInstruction(
const std::vector<uint32_t>& fresh_ids,
const protobufs::InstructionDescriptor& instruction_descriptor) {
for (auto fresh_id : fresh_ids) {
message_.add_fresh_ids(fresh_id);
}
*message_.mutable_instruction_descriptor() = instruction_descriptor;
}
bool TransformationReplaceLinearAlgebraInstruction::IsApplicable(
opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
auto instruction =
FindInstruction(message_.instruction_descriptor(), ir_context);
// TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354):
// Right now we only support certain operations. When this issue is addressed
// the following conditional can use the function |spvOpcodeIsLinearAlgebra|.
// It must be a supported linear algebra instruction.
if (instruction->opcode() != SpvOpVectorTimesScalar &&
instruction->opcode() != SpvOpDot) {
return false;
}
// |message_.fresh_ids.size| must be the exact number of fresh ids needed to
// apply the transformation.
if (static_cast<uint32_t>(message_.fresh_ids().size()) !=
GetRequiredFreshIdCount(ir_context, instruction)) {
return false;
}
// All ids in |message_.fresh_ids| must be fresh.
for (uint32_t i = 0; i < static_cast<uint32_t>(message_.fresh_ids().size());
i++) {
if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_ids(i))) {
return false;
}
}
return true;
}
void TransformationReplaceLinearAlgebraInstruction::Apply(
opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
auto linear_algebra_instruction =
FindInstruction(message_.instruction_descriptor(), ir_context);
switch (linear_algebra_instruction->opcode()) {
case SpvOpVectorTimesScalar:
ReplaceOpVectorTimesScalar(ir_context, linear_algebra_instruction);
break;
case SpvOpDot:
ReplaceOpDot(ir_context, linear_algebra_instruction);
break;
default:
assert(false && "Should be unreachable.");
break;
}
ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
}
protobufs::Transformation
TransformationReplaceLinearAlgebraInstruction::ToMessage() const {
protobufs::Transformation result;
*result.mutable_replace_linear_algebra_instruction() = message_;
return result;
}
uint32_t TransformationReplaceLinearAlgebraInstruction::GetRequiredFreshIdCount(
opt::IRContext* ir_context, opt::Instruction* instruction) {
// TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354):
// Right now we only support certain operations.
switch (instruction->opcode()) {
case SpvOpVectorTimesScalar:
// For each vector component, 1 OpCompositeExtract and 1 OpFMul will be
// inserted.
return 2 *
ir_context->get_type_mgr()
->GetType(ir_context->get_def_use_mgr()
->GetDef(instruction->GetSingleWordInOperand(0))
->type_id())
->AsVector()
->element_count();
case SpvOpDot: {
// For each pair of vector components, 2 OpCompositeExtract and 1 OpFMul
// will be inserted. The first two OpFMul instructions will result the
// first OpFAdd instruction to be inserted. For each remaining OpFMul, 1
// OpFAdd will be inserted. The last OpFAdd instruction is got by changing
// the OpDot instruction.
return 4 * ir_context->get_type_mgr()
->GetType(
ir_context->get_def_use_mgr()
->GetDef(instruction->GetSingleWordInOperand(0))
->type_id())
->AsVector()
->element_count() -
2;
}
default:
assert(false && "Unsupported linear algebra instruction.");
return 0;
}
}
void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesScalar(
opt::IRContext* ir_context,
opt::Instruction* linear_algebra_instruction) const {
// Gets OpVectorTimesScalar in operands.
auto vector = ir_context->get_def_use_mgr()->GetDef(
linear_algebra_instruction->GetSingleWordInOperand(0));
auto scalar = ir_context->get_def_use_mgr()->GetDef(
linear_algebra_instruction->GetSingleWordInOperand(1));
uint32_t vector_component_count = ir_context->get_type_mgr()
->GetType(vector->type_id())
->AsVector()
->element_count();
std::vector<uint32_t> float_multiplication_ids(vector_component_count);
uint32_t fresh_id_index = 0;
for (uint32_t i = 0; i < vector_component_count; i++) {
// Extracts |vector| component.
uint32_t vector_extract_id = message_.fresh_ids(fresh_id_index++);
fuzzerutil::UpdateModuleIdBound(ir_context, vector_extract_id);
linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
ir_context, SpvOpCompositeExtract, scalar->type_id(), vector_extract_id,
opt::Instruction::OperandList(
{{SPV_OPERAND_TYPE_ID, {vector->result_id()}},
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
// Multiplies the |vector| component with the |scalar|.
uint32_t float_multiplication_id = message_.fresh_ids(fresh_id_index++);
float_multiplication_ids[i] = float_multiplication_id;
fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_id);
linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
ir_context, SpvOpFMul, scalar->type_id(), float_multiplication_id,
opt::Instruction::OperandList(
{{SPV_OPERAND_TYPE_ID, {vector_extract_id}},
{SPV_OPERAND_TYPE_ID, {scalar->result_id()}}})));
}
// The OpVectorTimesScalar instruction is changed to an OpCompositeConstruct
// instruction.
linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]});
linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]});
for (uint32_t i = 2; i < float_multiplication_ids.size(); i++) {
linear_algebra_instruction->AddOperand(
{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}});
}
}
void TransformationReplaceLinearAlgebraInstruction::ReplaceOpDot(
opt::IRContext* ir_context,
opt::Instruction* linear_algebra_instruction) const {
// Gets OpDot in operands.
auto vector_1 = ir_context->get_def_use_mgr()->GetDef(
linear_algebra_instruction->GetSingleWordInOperand(0));
auto vector_2 = ir_context->get_def_use_mgr()->GetDef(
linear_algebra_instruction->GetSingleWordInOperand(1));
uint32_t vectors_component_count = ir_context->get_type_mgr()
->GetType(vector_1->type_id())
->AsVector()
->element_count();
std::vector<uint32_t> float_multiplication_ids(vectors_component_count);
uint32_t fresh_id_index = 0;
for (uint32_t i = 0; i < vectors_component_count; i++) {
// Extracts |vector_1| component.
uint32_t vector_1_extract_id = message_.fresh_ids(fresh_id_index++);
fuzzerutil::UpdateModuleIdBound(ir_context, vector_1_extract_id);
linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
ir_context, SpvOpCompositeExtract,
linear_algebra_instruction->type_id(), vector_1_extract_id,
opt::Instruction::OperandList(
{{SPV_OPERAND_TYPE_ID, {vector_1->result_id()}},
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
// Extracts |vector_2| component.
uint32_t vector_2_extract_id = message_.fresh_ids(fresh_id_index++);
fuzzerutil::UpdateModuleIdBound(ir_context, vector_2_extract_id);
linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
ir_context, SpvOpCompositeExtract,
linear_algebra_instruction->type_id(), vector_2_extract_id,
opt::Instruction::OperandList(
{{SPV_OPERAND_TYPE_ID, {vector_2->result_id()}},
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
// Multiplies the pair of components.
float_multiplication_ids[i] = message_.fresh_ids(fresh_id_index++);
fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[i]);
linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
ir_context, SpvOpFMul, linear_algebra_instruction->type_id(),
float_multiplication_ids[i],
opt::Instruction::OperandList(
{{SPV_OPERAND_TYPE_ID, {vector_1_extract_id}},
{SPV_OPERAND_TYPE_ID, {vector_2_extract_id}}})));
}
// If the vector has 2 components, then there will be 2 float multiplication
// instructions.
if (vectors_component_count == 2) {
linear_algebra_instruction->SetOpcode(SpvOpFAdd);
linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]});
linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]});
} else {
// The first OpFAdd instruction has as operands the first two OpFMul
// instructions.
std::vector<uint32_t> float_add_ids;
uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
float_add_ids.push_back(float_add_id);
fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id);
linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
ir_context, SpvOpFAdd, linear_algebra_instruction->type_id(),
float_add_id,
opt::Instruction::OperandList(
{{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
// The remaining OpFAdd instructions has as operands an OpFMul and an OpFAdd
// instruction.
for (uint32_t i = 2; i < float_multiplication_ids.size() - 1; i++) {
float_add_id = message_.fresh_ids(fresh_id_index++);
fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id);
float_add_ids.push_back(float_add_id);
linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
ir_context, SpvOpFAdd, linear_algebra_instruction->type_id(),
float_add_id,
opt::Instruction::OperandList(
{{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}},
{SPV_OPERAND_TYPE_ID, {float_add_ids[i - 2]}}})));
}
// The last OpFAdd instruction is got by changing some of the OpDot
// instruction attributes.
linear_algebra_instruction->SetOpcode(SpvOpFAdd);
linear_algebra_instruction->SetInOperand(
0, {float_multiplication_ids[float_multiplication_ids.size() - 1]});
linear_algebra_instruction->SetInOperand(
1, {float_add_ids[float_add_ids.size() - 1]});
}
}
} // namespace fuzz
} // namespace spvtools

View File

@ -0,0 +1,67 @@
// Copyright (c) 2020 André Perez Maselco
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_TRANSFORMATION_REPLACE_LINEAR_ALGEBRA_INSTRUCTION_H_
#define SOURCE_FUZZ_TRANSFORMATION_REPLACE_LINEAR_ALGEBRA_INSTRUCTION_H_
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "source/fuzz/transformation.h"
#include "source/fuzz/transformation_context.h"
#include "source/opt/ir_context.h"
namespace spvtools {
namespace fuzz {
class TransformationReplaceLinearAlgebraInstruction : public Transformation {
public:
explicit TransformationReplaceLinearAlgebraInstruction(
const protobufs::TransformationReplaceLinearAlgebraInstruction& message);
TransformationReplaceLinearAlgebraInstruction(
const std::vector<uint32_t>& fresh_ids,
const protobufs::InstructionDescriptor& instruction_descriptor);
// - |message_.fresh_ids| must be fresh ids needed to apply the
// transformation.
// - |message_.instruction_descriptor| must be a linear algebra instruction
bool IsApplicable(
opt::IRContext* ir_context,
const TransformationContext& transformation_context) const override;
// Replaces a linear algebra instruction.
void Apply(opt::IRContext* ir_context,
TransformationContext* transformation_context) const override;
protobufs::Transformation ToMessage() const override;
// Returns the number of ids needed to apply the transformation.
static uint32_t GetRequiredFreshIdCount(opt::IRContext* ir_context,
opt::Instruction* instruction);
private:
protobufs::TransformationReplaceLinearAlgebraInstruction message_;
// Replaces an OpVectorTimesScalar instruction.
void ReplaceOpVectorTimesScalar(opt::IRContext* ir_context,
opt::Instruction* instruction) const;
// Replaces an OpDot instruction.
void ReplaceOpDot(opt::IRContext* ir_context,
opt::Instruction* instruction) const;
};
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_TRANSFORMATION_REPLACE_LINEAR_ALGEBRA_INSTRUCTION_H_

View File

@ -650,6 +650,22 @@ bool spvOpcodeIsCommutativeBinaryOperator(SpvOp opcode) {
}
}
bool spvOpcodeIsLinearAlgebra(SpvOp opcode) {
switch (opcode) {
case SpvOpTranspose:
case SpvOpVectorTimesScalar:
case SpvOpMatrixTimesScalar:
case SpvOpVectorTimesMatrix:
case SpvOpMatrixTimesVector:
case SpvOpMatrixTimesMatrix:
case SpvOpOuterProduct:
case SpvOpDot:
return true;
default:
return false;
}
}
std::vector<uint32_t> spvOpcodeMemorySemanticsOperandIndices(SpvOp opcode) {
switch (opcode) {
case SpvOpMemoryBarrier:

View File

@ -134,6 +134,9 @@ bool spvOpcodeIsDebug(SpvOp opcode);
// where the order of the operands is irrelevant.
bool spvOpcodeIsCommutativeBinaryOperator(SpvOp opcode);
// Returns true for opcodes that represents linear algebra instructions.
bool spvOpcodeIsLinearAlgebra(SpvOp opcode);
// Returns a vector containing the indices of the memory semantics <id>
// operands for |opcode|.
std::vector<uint32_t> spvOpcodeMemorySemanticsOperandIndices(SpvOp opcode);

View File

@ -62,6 +62,7 @@ if (${SPIRV_BUILD_FUZZER})
transformation_replace_boolean_constant_with_constant_binary_test.cpp
transformation_replace_constant_with_uniform_test.cpp
transformation_replace_id_with_synonym_test.cpp
transformation_replace_linear_algebra_instruction_test.cpp
transformation_set_function_control_test.cpp
transformation_set_loop_control_test.cpp
transformation_set_memory_operands_mask_test.cpp

View File

@ -0,0 +1,391 @@
// Copyright (c) 2020 André Perez Maselco
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/transformation_replace_linear_algebra_instruction.h"
#include "source/fuzz/instruction_descriptor.h"
#include "test/fuzz/fuzz_test_util.h"
namespace spvtools {
namespace fuzz {
namespace {
TEST(TransformationReplaceLinearAlgebraInstructionTest, IsApplicable) {
std::string shader = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %22 "main"
OpExecutionMode %22 OriginUpperLeft
OpSource ESSL 310
OpName %22 "main"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%4 = OpTypeFloat 32
%5 = OpTypeVector %4 2
%6 = OpTypeVector %4 3
%7 = OpTypeVector %4 4
%8 = OpConstant %4 1
%9 = OpConstant %4 2
%10 = OpConstant %4 3
%11 = OpConstant %4 4
%12 = OpConstant %4 5
%13 = OpConstant %4 6
%14 = OpConstant %4 7
%15 = OpConstant %4 8
%16 = OpConstantComposite %5 %8 %9
%17 = OpConstantComposite %5 %10 %11
%18 = OpConstantComposite %6 %8 %9 %10
%19 = OpConstantComposite %6 %11 %12 %13
%20 = OpConstantComposite %7 %8 %9 %10 %11
%21 = OpConstantComposite %7 %12 %13 %14 %15
%22 = OpFunction %2 None %3
%23 = OpLabel
%24 = OpDot %4 %16 %17
%25 = OpDot %4 %18 %19
%26 = OpDot %4 %20 %21
%27 = OpVectorTimesScalar %5 %16 %8
%28 = OpVectorTimesScalar %6 %18 %9
%29 = OpVectorTimesScalar %7 %20 %10
%30 = OpCopyObject %4 %24
%31 = OpFAdd %4 %8 %9
%32 = OpFMul %4 %10 %11
OpReturn
OpFunctionEnd
)";
const auto env = SPV_ENV_UNIVERSAL_1_5;
const auto consumer = nullptr;
const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
ASSERT_TRUE(IsValid(env, context.get()));
FactManager fact_manager;
spvtools::ValidatorOptions validator_options;
TransformationContext transformation_context(&fact_manager,
validator_options);
// Tests linear algebra instructions.
auto instruction_descriptor = MakeInstructionDescriptor(24, SpvOpDot, 0);
auto transformation = TransformationReplaceLinearAlgebraInstruction(
{33, 34, 35, 36, 37, 38}, instruction_descriptor);
ASSERT_TRUE(
transformation.IsApplicable(context.get(), transformation_context));
instruction_descriptor =
MakeInstructionDescriptor(27, SpvOpVectorTimesScalar, 0);
transformation = TransformationReplaceLinearAlgebraInstruction(
{33, 34, 35, 36}, instruction_descriptor);
ASSERT_TRUE(
transformation.IsApplicable(context.get(), transformation_context));
// Tests non-linear algebra instructions.
instruction_descriptor = MakeInstructionDescriptor(30, SpvOpCopyObject, 0);
transformation = TransformationReplaceLinearAlgebraInstruction(
{33, 34, 35, 36, 37, 38}, instruction_descriptor);
ASSERT_FALSE(
transformation.IsApplicable(context.get(), transformation_context));
instruction_descriptor = MakeInstructionDescriptor(31, SpvOpFAdd, 0);
transformation = TransformationReplaceLinearAlgebraInstruction(
{33, 34, 35, 36, 37}, instruction_descriptor);
ASSERT_FALSE(
transformation.IsApplicable(context.get(), transformation_context));
instruction_descriptor = MakeInstructionDescriptor(32, SpvOpFMul, 0);
transformation = TransformationReplaceLinearAlgebraInstruction(
{33, 34, 35, 36}, instruction_descriptor);
ASSERT_FALSE(
transformation.IsApplicable(context.get(), transformation_context));
// Tests number of fresh ids is different than necessary.
instruction_descriptor = MakeInstructionDescriptor(25, SpvOpDot, 0);
transformation = TransformationReplaceLinearAlgebraInstruction(
{33, 34, 35, 36}, instruction_descriptor);
ASSERT_FALSE(
transformation.IsApplicable(context.get(), transformation_context));
instruction_descriptor =
MakeInstructionDescriptor(28, SpvOpVectorTimesScalar, 0);
transformation = TransformationReplaceLinearAlgebraInstruction(
{33, 34, 35, 36, 37, 38, 39}, instruction_descriptor);
ASSERT_FALSE(
transformation.IsApplicable(context.get(), transformation_context));
// Tests non-fresh ids.
instruction_descriptor = MakeInstructionDescriptor(26, SpvOpDot, 0);
transformation = TransformationReplaceLinearAlgebraInstruction(
{33, 34, 5, 36, 37, 8, 39, 40, 1, 42, 3, 44, 45, 46},
instruction_descriptor);
ASSERT_FALSE(
transformation.IsApplicable(context.get(), transformation_context));
instruction_descriptor =
MakeInstructionDescriptor(29, SpvOpVectorTimesScalar, 0);
transformation = TransformationReplaceLinearAlgebraInstruction(
{33, 34, 35, 36, 7, 38, 9, 40}, instruction_descriptor);
ASSERT_FALSE(
transformation.IsApplicable(context.get(), transformation_context));
}
TEST(TransformationReplaceLinearAlgebraInstructionTest,
ReplaceOpVectorTimesScalar) {
std::string reference_shader = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %15 "main"
OpExecutionMode %15 OriginUpperLeft
OpSource ESSL 310
OpName %15 "main"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%4 = OpTypeFloat 32
%5 = OpTypeVector %4 2
%6 = OpTypeVector %4 3
%7 = OpTypeVector %4 4
%8 = OpConstant %4 1
%9 = OpConstant %4 2
%10 = OpConstant %4 3
%11 = OpConstant %4 4
%12 = OpConstantComposite %5 %8 %9
%13 = OpConstantComposite %6 %8 %9 %10
%14 = OpConstantComposite %7 %8 %9 %10 %11
%15 = OpFunction %2 None %3
%16 = OpLabel
%17 = OpVectorTimesScalar %5 %12 %8
%18 = OpVectorTimesScalar %6 %13 %9
%19 = OpVectorTimesScalar %7 %14 %10
OpReturn
OpFunctionEnd
)";
const auto env = SPV_ENV_UNIVERSAL_1_5;
const auto consumer = nullptr;
const auto context =
BuildModule(env, consumer, reference_shader, kFuzzAssembleOption);
ASSERT_TRUE(IsValid(env, context.get()));
FactManager fact_manager;
spvtools::ValidatorOptions validator_options;
TransformationContext transformation_context(&fact_manager,
validator_options);
auto instruction_descriptor =
MakeInstructionDescriptor(17, SpvOpVectorTimesScalar, 0);
auto transformation = TransformationReplaceLinearAlgebraInstruction(
{20, 21, 22, 23}, instruction_descriptor);
transformation.Apply(context.get(), &transformation_context);
instruction_descriptor =
MakeInstructionDescriptor(18, SpvOpVectorTimesScalar, 0);
transformation = TransformationReplaceLinearAlgebraInstruction(
{24, 25, 26, 27, 28, 29}, instruction_descriptor);
transformation.Apply(context.get(), &transformation_context);
instruction_descriptor =
MakeInstructionDescriptor(19, SpvOpVectorTimesScalar, 0);
transformation = TransformationReplaceLinearAlgebraInstruction(
{30, 31, 32, 33, 34, 35, 36, 37}, instruction_descriptor);
transformation.Apply(context.get(), &transformation_context);
std::string variant_shader = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %15 "main"
OpExecutionMode %15 OriginUpperLeft
OpSource ESSL 310
OpName %15 "main"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%4 = OpTypeFloat 32
%5 = OpTypeVector %4 2
%6 = OpTypeVector %4 3
%7 = OpTypeVector %4 4
%8 = OpConstant %4 1
%9 = OpConstant %4 2
%10 = OpConstant %4 3
%11 = OpConstant %4 4
%12 = OpConstantComposite %5 %8 %9
%13 = OpConstantComposite %6 %8 %9 %10
%14 = OpConstantComposite %7 %8 %9 %10 %11
%15 = OpFunction %2 None %3
%16 = OpLabel
%20 = OpCompositeExtract %4 %12 0
%21 = OpFMul %4 %20 %8
%22 = OpCompositeExtract %4 %12 1
%23 = OpFMul %4 %22 %8
%17 = OpCompositeConstruct %5 %21 %23
%24 = OpCompositeExtract %4 %13 0
%25 = OpFMul %4 %24 %9
%26 = OpCompositeExtract %4 %13 1
%27 = OpFMul %4 %26 %9
%28 = OpCompositeExtract %4 %13 2
%29 = OpFMul %4 %28 %9
%18 = OpCompositeConstruct %6 %25 %27 %29
%30 = OpCompositeExtract %4 %14 0
%31 = OpFMul %4 %30 %10
%32 = OpCompositeExtract %4 %14 1
%33 = OpFMul %4 %32 %10
%34 = OpCompositeExtract %4 %14 2
%35 = OpFMul %4 %34 %10
%36 = OpCompositeExtract %4 %14 3
%37 = OpFMul %4 %36 %10
%19 = OpCompositeConstruct %7 %31 %33 %35 %37
OpReturn
OpFunctionEnd
)";
ASSERT_TRUE(IsValid(env, context.get()));
ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
}
TEST(TransformationReplaceLinearAlgebraInstructionTest, ReplaceOpDot) {
std::string reference_shader = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %22 "main"
OpExecutionMode %22 OriginUpperLeft
OpSource ESSL 310
OpName %22 "main"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%4 = OpTypeFloat 32
%5 = OpTypeVector %4 2
%6 = OpTypeVector %4 3
%7 = OpTypeVector %4 4
%8 = OpConstant %4 1
%9 = OpConstant %4 2
%10 = OpConstant %4 3
%11 = OpConstant %4 4
%12 = OpConstant %4 5
%13 = OpConstant %4 6
%14 = OpConstant %4 7
%15 = OpConstant %4 8
%16 = OpConstantComposite %5 %8 %9
%17 = OpConstantComposite %5 %10 %11
%18 = OpConstantComposite %6 %8 %9 %10
%19 = OpConstantComposite %6 %11 %12 %13
%20 = OpConstantComposite %7 %8 %9 %10 %11
%21 = OpConstantComposite %7 %12 %13 %14 %15
%22 = OpFunction %2 None %3
%23 = OpLabel
%24 = OpDot %4 %16 %17
%25 = OpDot %4 %18 %19
%26 = OpDot %4 %20 %21
OpReturn
OpFunctionEnd
)";
const auto env = SPV_ENV_UNIVERSAL_1_5;
const auto consumer = nullptr;
const auto context =
BuildModule(env, consumer, reference_shader, kFuzzAssembleOption);
ASSERT_TRUE(IsValid(env, context.get()));
FactManager fact_manager;
spvtools::ValidatorOptions validator_options;
TransformationContext transformation_context(&fact_manager,
validator_options);
auto instruction_descriptor = MakeInstructionDescriptor(24, SpvOpDot, 0);
auto transformation = TransformationReplaceLinearAlgebraInstruction(
{27, 28, 29, 30, 31, 32}, instruction_descriptor);
transformation.Apply(context.get(), &transformation_context);
instruction_descriptor = MakeInstructionDescriptor(25, SpvOpDot, 0);
transformation = TransformationReplaceLinearAlgebraInstruction(
{33, 34, 35, 36, 37, 38, 39, 40, 41, 42}, instruction_descriptor);
transformation.Apply(context.get(), &transformation_context);
instruction_descriptor = MakeInstructionDescriptor(26, SpvOpDot, 0);
transformation = TransformationReplaceLinearAlgebraInstruction(
{43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56},
instruction_descriptor);
transformation.Apply(context.get(), &transformation_context);
std::string variant_shader = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %22 "main"
OpExecutionMode %22 OriginUpperLeft
OpSource ESSL 310
OpName %22 "main"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%4 = OpTypeFloat 32
%5 = OpTypeVector %4 2
%6 = OpTypeVector %4 3
%7 = OpTypeVector %4 4
%8 = OpConstant %4 1
%9 = OpConstant %4 2
%10 = OpConstant %4 3
%11 = OpConstant %4 4
%12 = OpConstant %4 5
%13 = OpConstant %4 6
%14 = OpConstant %4 7
%15 = OpConstant %4 8
%16 = OpConstantComposite %5 %8 %9
%17 = OpConstantComposite %5 %10 %11
%18 = OpConstantComposite %6 %8 %9 %10
%19 = OpConstantComposite %6 %11 %12 %13
%20 = OpConstantComposite %7 %8 %9 %10 %11
%21 = OpConstantComposite %7 %12 %13 %14 %15
%22 = OpFunction %2 None %3
%23 = OpLabel
%27 = OpCompositeExtract %4 %16 0
%28 = OpCompositeExtract %4 %17 0
%29 = OpFMul %4 %27 %28
%30 = OpCompositeExtract %4 %16 1
%31 = OpCompositeExtract %4 %17 1
%32 = OpFMul %4 %30 %31
%24 = OpFAdd %4 %29 %32
%33 = OpCompositeExtract %4 %18 0
%34 = OpCompositeExtract %4 %19 0
%35 = OpFMul %4 %33 %34
%36 = OpCompositeExtract %4 %18 1
%37 = OpCompositeExtract %4 %19 1
%38 = OpFMul %4 %36 %37
%39 = OpCompositeExtract %4 %18 2
%40 = OpCompositeExtract %4 %19 2
%41 = OpFMul %4 %39 %40
%42 = OpFAdd %4 %35 %38
%25 = OpFAdd %4 %41 %42
%43 = OpCompositeExtract %4 %20 0
%44 = OpCompositeExtract %4 %21 0
%45 = OpFMul %4 %43 %44
%46 = OpCompositeExtract %4 %20 1
%47 = OpCompositeExtract %4 %21 1
%48 = OpFMul %4 %46 %47
%49 = OpCompositeExtract %4 %20 2
%50 = OpCompositeExtract %4 %21 2
%51 = OpFMul %4 %49 %50
%52 = OpCompositeExtract %4 %20 3
%53 = OpCompositeExtract %4 %21 3
%54 = OpFMul %4 %52 %53
%55 = OpFAdd %4 %45 %48
%56 = OpFAdd %4 %51 %55
%26 = OpFAdd %4 %54 %56
OpReturn
OpFunctionEnd
)";
ASSERT_TRUE(IsValid(env, context.get()));
ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
}
} // namespace
} // namespace fuzz
} // namespace spvtools