// Copyright (c) 2020 André Perez Maselco // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "source/fuzz/transformation_replace_linear_algebra_instruction.h" #include "source/fuzz/fuzzer_util.h" #include "source/fuzz/instruction_descriptor.h" namespace spvtools { namespace fuzz { TransformationReplaceLinearAlgebraInstruction:: TransformationReplaceLinearAlgebraInstruction( const spvtools::fuzz::protobufs:: TransformationReplaceLinearAlgebraInstruction& message) : message_(message) {} TransformationReplaceLinearAlgebraInstruction:: TransformationReplaceLinearAlgebraInstruction( const std::vector& fresh_ids, const protobufs::InstructionDescriptor& instruction_descriptor) { for (auto fresh_id : fresh_ids) { message_.add_fresh_ids(fresh_id); } *message_.mutable_instruction_descriptor() = instruction_descriptor; } bool TransformationReplaceLinearAlgebraInstruction::IsApplicable( opt::IRContext* ir_context, const TransformationContext& /*unused*/) const { auto instruction = FindInstruction(message_.instruction_descriptor(), ir_context); // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354): // Right now we only support certain operations. When this issue is addressed // the following conditional can use the function |spvOpcodeIsLinearAlgebra|. // It must be a supported linear algebra instruction. if (instruction->opcode() != SpvOpVectorTimesScalar && instruction->opcode() != SpvOpMatrixTimesScalar && instruction->opcode() != SpvOpVectorTimesMatrix && instruction->opcode() != SpvOpMatrixTimesVector && instruction->opcode() != SpvOpMatrixTimesMatrix && instruction->opcode() != SpvOpDot) { return false; } // |message_.fresh_ids.size| must be the exact number of fresh ids needed to // apply the transformation. if (static_cast(message_.fresh_ids().size()) != GetRequiredFreshIdCount(ir_context, instruction)) { return false; } // All ids in |message_.fresh_ids| must be fresh. for (uint32_t fresh_id : message_.fresh_ids()) { if (!fuzzerutil::IsFreshId(ir_context, fresh_id)) { return false; } } return true; } void TransformationReplaceLinearAlgebraInstruction::Apply( opt::IRContext* ir_context, TransformationContext* /*unused*/) const { auto linear_algebra_instruction = FindInstruction(message_.instruction_descriptor(), ir_context); switch (linear_algebra_instruction->opcode()) { case SpvOpVectorTimesScalar: ReplaceOpVectorTimesScalar(ir_context, linear_algebra_instruction); break; case SpvOpMatrixTimesScalar: ReplaceOpMatrixTimesScalar(ir_context, linear_algebra_instruction); break; case SpvOpVectorTimesMatrix: ReplaceOpVectorTimesMatrix(ir_context, linear_algebra_instruction); break; case SpvOpMatrixTimesVector: ReplaceOpMatrixTimesVector(ir_context, linear_algebra_instruction); break; case SpvOpMatrixTimesMatrix: ReplaceOpMatrixTimesMatrix(ir_context, linear_algebra_instruction); break; case SpvOpDot: ReplaceOpDot(ir_context, linear_algebra_instruction); break; default: assert(false && "Should be unreachable."); break; } ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); } protobufs::Transformation TransformationReplaceLinearAlgebraInstruction::ToMessage() const { protobufs::Transformation result; *result.mutable_replace_linear_algebra_instruction() = message_; return result; } uint32_t TransformationReplaceLinearAlgebraInstruction::GetRequiredFreshIdCount( opt::IRContext* ir_context, opt::Instruction* instruction) { // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354): // Right now we only support certain operations. switch (instruction->opcode()) { case SpvOpVectorTimesScalar: // For each vector component, 1 OpCompositeExtract and 1 OpFMul will be // inserted. return 2 * ir_context->get_type_mgr() ->GetType(ir_context->get_def_use_mgr() ->GetDef(instruction->GetSingleWordInOperand(0)) ->type_id()) ->AsVector() ->element_count(); case SpvOpMatrixTimesScalar: { // For each matrix column, |1 + column.size| OpCompositeExtract, // |column.size| OpFMul and 1 OpCompositeConstruct instructions will be // inserted. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef( instruction->GetSingleWordInOperand(0)); auto matrix_type = ir_context->get_type_mgr()->GetType(matrix_instruction->type_id()); return 2 * matrix_type->AsMatrix()->element_count() * (1 + matrix_type->AsMatrix() ->element_type() ->AsVector() ->element_count()); } case SpvOpVectorTimesMatrix: { // For each vector component, 1 OpCompositeExtract instruction will be // inserted. For each matrix column, |1 + vector_component_count| // OpCompositeExtract, |vector_component_count| OpFMul and // |vector_component_count - 1| OpFAdd instructions will be inserted. auto vector_instruction = ir_context->get_def_use_mgr()->GetDef( instruction->GetSingleWordInOperand(0)); auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef( instruction->GetSingleWordInOperand(1)); uint32_t vector_component_count = ir_context->get_type_mgr() ->GetType(vector_instruction->type_id()) ->AsVector() ->element_count(); uint32_t matrix_column_count = ir_context->get_type_mgr() ->GetType(matrix_instruction->type_id()) ->AsMatrix() ->element_count(); return vector_component_count * (3 * matrix_column_count + 1); } case SpvOpMatrixTimesVector: { // For each matrix column, |1 + matrix_row_count| OpCompositeExtract // will be inserted. For each matrix row, |matrix_column_count| OpFMul and // |matrix_column_count - 1| OpFAdd instructions will be inserted. For // each vector component, 1 OpCompositeExtract instruction will be // inserted. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef( instruction->GetSingleWordInOperand(0)); uint32_t matrix_column_count = ir_context->get_type_mgr() ->GetType(matrix_instruction->type_id()) ->AsMatrix() ->element_count(); uint32_t matrix_row_count = ir_context->get_type_mgr() ->GetType(matrix_instruction->type_id()) ->AsMatrix() ->element_type() ->AsVector() ->element_count(); return 3 * matrix_column_count * matrix_row_count + 2 * matrix_column_count - matrix_row_count; } case SpvOpMatrixTimesMatrix: { // For each matrix 2 column, 1 OpCompositeExtract, 1 OpCompositeConstruct, // |3 * matrix_1_row_count * matrix_1_column_count| OpCompositeExtract, // |matrix_1_row_count * matrix_1_column_count| OpFMul, // |matrix_1_row_count * (matrix_1_column_count - 1)| OpFAdd instructions // will be inserted. auto matrix_1_instruction = ir_context->get_def_use_mgr()->GetDef( instruction->GetSingleWordInOperand(0)); uint32_t matrix_1_column_count = ir_context->get_type_mgr() ->GetType(matrix_1_instruction->type_id()) ->AsMatrix() ->element_count(); uint32_t matrix_1_row_count = ir_context->get_type_mgr() ->GetType(matrix_1_instruction->type_id()) ->AsMatrix() ->element_type() ->AsVector() ->element_count(); auto matrix_2_instruction = ir_context->get_def_use_mgr()->GetDef( instruction->GetSingleWordInOperand(1)); uint32_t matrix_2_column_count = ir_context->get_type_mgr() ->GetType(matrix_2_instruction->type_id()) ->AsMatrix() ->element_count(); return matrix_2_column_count * (2 + matrix_1_row_count * (5 * matrix_1_column_count - 1)); } case SpvOpDot: // For each pair of vector components, 2 OpCompositeExtract and 1 OpFMul // will be inserted. The first two OpFMul instructions will result the // first OpFAdd instruction to be inserted. For each remaining OpFMul, 1 // OpFAdd will be inserted. The last OpFAdd instruction is got by changing // the OpDot instruction. return 4 * ir_context->get_type_mgr() ->GetType( ir_context->get_def_use_mgr() ->GetDef(instruction->GetSingleWordInOperand(0)) ->type_id()) ->AsVector() ->element_count() - 2; default: assert(false && "Unsupported linear algebra instruction."); return 0; } } void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesScalar( opt::IRContext* ir_context, opt::Instruction* linear_algebra_instruction) const { // Gets OpVectorTimesScalar in operands. auto vector = ir_context->get_def_use_mgr()->GetDef( linear_algebra_instruction->GetSingleWordInOperand(0)); auto scalar = ir_context->get_def_use_mgr()->GetDef( linear_algebra_instruction->GetSingleWordInOperand(1)); uint32_t vector_component_count = ir_context->get_type_mgr() ->GetType(vector->type_id()) ->AsVector() ->element_count(); std::vector float_multiplication_ids(vector_component_count); uint32_t fresh_id_index = 0; for (uint32_t i = 0; i < vector_component_count; i++) { // Extracts |vector| component. uint32_t vector_extract_id = message_.fresh_ids(fresh_id_index++); fuzzerutil::UpdateModuleIdBound(ir_context, vector_extract_id); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeExtract, scalar->type_id(), vector_extract_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {vector->result_id()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}}))); // Multiplies the |vector| component with the |scalar|. uint32_t float_multiplication_id = message_.fresh_ids(fresh_id_index++); float_multiplication_ids[i] = float_multiplication_id; fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_id); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpFMul, scalar->type_id(), float_multiplication_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {vector_extract_id}}, {SPV_OPERAND_TYPE_ID, {scalar->result_id()}}}))); } // The OpVectorTimesScalar instruction is changed to an OpCompositeConstruct // instruction. linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct); linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]}); linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]}); for (uint32_t i = 2; i < float_multiplication_ids.size(); i++) { linear_algebra_instruction->AddOperand( {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}}); } } void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesScalar( opt::IRContext* ir_context, opt::Instruction* linear_algebra_instruction) const { // Gets OpMatrixTimesScalar in operands. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef( linear_algebra_instruction->GetSingleWordInOperand(0)); auto scalar_instruction = ir_context->get_def_use_mgr()->GetDef( linear_algebra_instruction->GetSingleWordInOperand(1)); // Gets matrix information. uint32_t matrix_column_count = ir_context->get_type_mgr() ->GetType(matrix_instruction->type_id()) ->AsMatrix() ->element_count(); auto matrix_column_type = ir_context->get_type_mgr() ->GetType(matrix_instruction->type_id()) ->AsMatrix() ->element_type(); uint32_t matrix_column_size = matrix_column_type->AsVector()->element_count(); std::vector composite_construct_ids(matrix_column_count); uint32_t fresh_id_index = 0; for (uint32_t i = 0; i < matrix_column_count; i++) { // Extracts |matrix| column. uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++); fuzzerutil::UpdateModuleIdBound(ir_context, matrix_extract_id); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeExtract, ir_context->get_type_mgr()->GetId(matrix_column_type), matrix_extract_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}}))); std::vector float_multiplication_ids(matrix_column_size); for (uint32_t j = 0; j < matrix_column_size; j++) { // Extracts |column| component. uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++); fuzzerutil::UpdateModuleIdBound(ir_context, column_extract_id); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeExtract, scalar_instruction->type_id(), column_extract_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}}))); // Multiplies the |column| component with the |scalar|. float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++); fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[j]); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpFMul, scalar_instruction->type_id(), float_multiplication_ids[j], opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {column_extract_id}}, {SPV_OPERAND_TYPE_ID, {scalar_instruction->result_id()}}}))); } // Constructs a new column multiplied by |scalar|. opt::Instruction::OperandList composite_construct_in_operands; for (uint32_t& float_multiplication_id : float_multiplication_ids) { composite_construct_in_operands.push_back( {SPV_OPERAND_TYPE_ID, {float_multiplication_id}}); } composite_construct_ids[i] = message_.fresh_ids(fresh_id_index++); fuzzerutil::UpdateModuleIdBound(ir_context, composite_construct_ids[i]); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeConstruct, ir_context->get_type_mgr()->GetId(matrix_column_type), composite_construct_ids[i], composite_construct_in_operands)); } // The OpMatrixTimesScalar instruction is changed to an OpCompositeConstruct // instruction. linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct); linear_algebra_instruction->SetInOperand(0, {composite_construct_ids[0]}); linear_algebra_instruction->SetInOperand(1, {composite_construct_ids[1]}); for (uint32_t i = 2; i < composite_construct_ids.size(); i++) { linear_algebra_instruction->AddOperand( {SPV_OPERAND_TYPE_ID, {composite_construct_ids[i]}}); } } void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesMatrix( opt::IRContext* ir_context, opt::Instruction* linear_algebra_instruction) const { // Gets vector information. auto vector_instruction = ir_context->get_def_use_mgr()->GetDef( linear_algebra_instruction->GetSingleWordInOperand(0)); uint32_t vector_component_count = ir_context->get_type_mgr() ->GetType(vector_instruction->type_id()) ->AsVector() ->element_count(); auto vector_component_type = ir_context->get_type_mgr() ->GetType(vector_instruction->type_id()) ->AsVector() ->element_type(); // Extracts vector components. uint32_t fresh_id_index = 0; std::vector vector_component_ids(vector_component_count); for (uint32_t i = 0; i < vector_component_count; i++) { vector_component_ids[i] = message_.fresh_ids(fresh_id_index++); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeExtract, ir_context->get_type_mgr()->GetId(vector_component_type), vector_component_ids[i], opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}}))); } // Gets matrix information. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef( linear_algebra_instruction->GetSingleWordInOperand(1)); uint32_t matrix_column_count = ir_context->get_type_mgr() ->GetType(matrix_instruction->type_id()) ->AsMatrix() ->element_count(); auto matrix_column_type = ir_context->get_type_mgr() ->GetType(matrix_instruction->type_id()) ->AsMatrix() ->element_type(); std::vector result_component_ids(matrix_column_count); for (uint32_t i = 0; i < matrix_column_count; i++) { // Extracts matrix column. uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeExtract, ir_context->get_type_mgr()->GetId(matrix_column_type), matrix_extract_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}}))); std::vector float_multiplication_ids(vector_component_count); for (uint32_t j = 0; j < vector_component_count; j++) { // Extracts column component. uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeExtract, ir_context->get_type_mgr()->GetId(vector_component_type), column_extract_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}}))); // Multiplies corresponding vector and column components. float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpFMul, ir_context->get_type_mgr()->GetId(vector_component_type), float_multiplication_ids[j], opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}}, {SPV_OPERAND_TYPE_ID, {column_extract_id}}}))); } // Adds the multiplication results. std::vector float_add_ids; uint32_t float_add_id = message_.fresh_ids(fresh_id_index++); float_add_ids.push_back(float_add_id); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpFAdd, ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}}, {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}}))); for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) { float_add_id = message_.fresh_ids(fresh_id_index++); float_add_ids.push_back(float_add_id); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpFAdd, ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}}, {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}}))); } result_component_ids[i] = float_add_ids.back(); } // The OpVectorTimesMatrix instruction is changed to an OpCompositeConstruct // instruction. linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct); linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]}); linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]}); for (uint32_t i = 2; i < result_component_ids.size(); i++) { linear_algebra_instruction->AddOperand( {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}}); } fuzzerutil::UpdateModuleIdBound( ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1)); } void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesVector( opt::IRContext* ir_context, opt::Instruction* linear_algebra_instruction) const { // Gets matrix information. auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef( linear_algebra_instruction->GetSingleWordInOperand(0)); uint32_t matrix_column_count = ir_context->get_type_mgr() ->GetType(matrix_instruction->type_id()) ->AsMatrix() ->element_count(); auto matrix_column_type = ir_context->get_type_mgr() ->GetType(matrix_instruction->type_id()) ->AsMatrix() ->element_type(); uint32_t matrix_row_count = matrix_column_type->AsVector()->element_count(); // Extracts matrix columns. uint32_t fresh_id_index = 0; std::vector matrix_column_ids(matrix_column_count); for (uint32_t i = 0; i < matrix_column_count; i++) { matrix_column_ids[i] = message_.fresh_ids(fresh_id_index++); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeExtract, ir_context->get_type_mgr()->GetId(matrix_column_type), matrix_column_ids[i], opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}}))); } // Gets vector information. auto vector_instruction = ir_context->get_def_use_mgr()->GetDef( linear_algebra_instruction->GetSingleWordInOperand(1)); auto vector_component_type = ir_context->get_type_mgr() ->GetType(vector_instruction->type_id()) ->AsVector() ->element_type(); // Extracts vector components. std::vector vector_component_ids(matrix_column_count); for (uint32_t i = 0; i < matrix_column_count; i++) { vector_component_ids[i] = message_.fresh_ids(fresh_id_index++); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeExtract, ir_context->get_type_mgr()->GetId(vector_component_type), vector_component_ids[i], opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}}))); } std::vector result_component_ids(matrix_row_count); for (uint32_t i = 0; i < matrix_row_count; i++) { std::vector float_multiplication_ids(matrix_column_count); for (uint32_t j = 0; j < matrix_column_count; j++) { // Extracts column component. uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeExtract, ir_context->get_type_mgr()->GetId(vector_component_type), column_extract_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {matrix_column_ids[j]}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}}))); // Multiplies corresponding vector and column components. float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpFMul, ir_context->get_type_mgr()->GetId(vector_component_type), float_multiplication_ids[j], opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {column_extract_id}}, {SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}}}))); } // Adds the multiplication results. std::vector float_add_ids; uint32_t float_add_id = message_.fresh_ids(fresh_id_index++); float_add_ids.push_back(float_add_id); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpFAdd, ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}}, {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}}))); for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) { float_add_id = message_.fresh_ids(fresh_id_index++); float_add_ids.push_back(float_add_id); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpFAdd, ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}}, {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}}))); } result_component_ids[i] = float_add_ids.back(); } // The OpMatrixTimesVector instruction is changed to an OpCompositeConstruct // instruction. linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct); linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]}); linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]}); for (uint32_t i = 2; i < result_component_ids.size(); i++) { linear_algebra_instruction->AddOperand( {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}}); } fuzzerutil::UpdateModuleIdBound( ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1)); } void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesMatrix( opt::IRContext* ir_context, opt::Instruction* linear_algebra_instruction) const { // Gets matrix 1 information. auto matrix_1_instruction = ir_context->get_def_use_mgr()->GetDef( linear_algebra_instruction->GetSingleWordInOperand(0)); uint32_t matrix_1_column_count = ir_context->get_type_mgr() ->GetType(matrix_1_instruction->type_id()) ->AsMatrix() ->element_count(); auto matrix_1_column_type = ir_context->get_type_mgr() ->GetType(matrix_1_instruction->type_id()) ->AsMatrix() ->element_type(); auto matrix_1_column_component_type = matrix_1_column_type->AsVector()->element_type(); uint32_t matrix_1_row_count = matrix_1_column_type->AsVector()->element_count(); // Gets matrix 2 information. auto matrix_2_instruction = ir_context->get_def_use_mgr()->GetDef( linear_algebra_instruction->GetSingleWordInOperand(1)); uint32_t matrix_2_column_count = ir_context->get_type_mgr() ->GetType(matrix_2_instruction->type_id()) ->AsMatrix() ->element_count(); auto matrix_2_column_type = ir_context->get_type_mgr() ->GetType(matrix_2_instruction->type_id()) ->AsMatrix() ->element_type(); uint32_t fresh_id_index = 0; std::vector result_column_ids(matrix_2_column_count); for (uint32_t i = 0; i < matrix_2_column_count; i++) { // Extracts matrix 2 column. uint32_t matrix_2_column_id = message_.fresh_ids(fresh_id_index++); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeExtract, ir_context->get_type_mgr()->GetId(matrix_2_column_type), matrix_2_column_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {matrix_2_instruction->result_id()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}}))); std::vector column_component_ids(matrix_1_row_count); for (uint32_t j = 0; j < matrix_1_row_count; j++) { std::vector float_multiplication_ids(matrix_1_column_count); for (uint32_t k = 0; k < matrix_1_column_count; k++) { // Extracts matrix 1 column. uint32_t matrix_1_column_id = message_.fresh_ids(fresh_id_index++); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeExtract, ir_context->get_type_mgr()->GetId(matrix_1_column_type), matrix_1_column_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {matrix_1_instruction->result_id()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {k}}}))); // Extracts matrix 1 column component. uint32_t matrix_1_column_component_id = message_.fresh_ids(fresh_id_index++); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeExtract, ir_context->get_type_mgr()->GetId(matrix_1_column_component_type), matrix_1_column_component_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {matrix_1_column_id}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}}))); // Extracts matrix 2 column component. uint32_t matrix_2_column_component_id = message_.fresh_ids(fresh_id_index++); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeExtract, ir_context->get_type_mgr()->GetId(matrix_1_column_component_type), matrix_2_column_component_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {matrix_2_column_id}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {k}}}))); // Multiplies corresponding matrix 1 and matrix 2 column components. float_multiplication_ids[k] = message_.fresh_ids(fresh_id_index++); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpFMul, ir_context->get_type_mgr()->GetId(matrix_1_column_component_type), float_multiplication_ids[k], opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {matrix_1_column_component_id}}, {SPV_OPERAND_TYPE_ID, {matrix_2_column_component_id}}}))); } // Adds the multiplication results. std::vector float_add_ids; uint32_t float_add_id = message_.fresh_ids(fresh_id_index++); float_add_ids.push_back(float_add_id); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpFAdd, ir_context->get_type_mgr()->GetId(matrix_1_column_component_type), float_add_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}}, {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}}))); for (uint32_t k = 2; k < float_multiplication_ids.size(); k++) { float_add_id = message_.fresh_ids(fresh_id_index++); float_add_ids.push_back(float_add_id); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpFAdd, ir_context->get_type_mgr()->GetId(matrix_1_column_component_type), float_add_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[k]}}, {SPV_OPERAND_TYPE_ID, {float_add_ids[k - 2]}}}))); } column_component_ids[j] = float_add_ids.back(); } // Inserts the resulting matrix column. opt::Instruction::OperandList in_operands; for (auto& column_component_id : column_component_ids) { in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}}); } result_column_ids[i] = message_.fresh_ids(fresh_id_index++); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeConstruct, ir_context->get_type_mgr()->GetId(matrix_1_column_type), result_column_ids[i], opt::Instruction::OperandList(in_operands))); } // The OpMatrixTimesMatrix instruction is changed to an OpCompositeConstruct // instruction. linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct); linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]}); linear_algebra_instruction->SetInOperand(1, {result_column_ids[1]}); for (uint32_t i = 2; i < result_column_ids.size(); i++) { linear_algebra_instruction->AddOperand( {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}}); } fuzzerutil::UpdateModuleIdBound( ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1)); } void TransformationReplaceLinearAlgebraInstruction::ReplaceOpDot( opt::IRContext* ir_context, opt::Instruction* linear_algebra_instruction) const { // Gets OpDot in operands. auto vector_1 = ir_context->get_def_use_mgr()->GetDef( linear_algebra_instruction->GetSingleWordInOperand(0)); auto vector_2 = ir_context->get_def_use_mgr()->GetDef( linear_algebra_instruction->GetSingleWordInOperand(1)); uint32_t vectors_component_count = ir_context->get_type_mgr() ->GetType(vector_1->type_id()) ->AsVector() ->element_count(); std::vector float_multiplication_ids(vectors_component_count); uint32_t fresh_id_index = 0; for (uint32_t i = 0; i < vectors_component_count; i++) { // Extracts |vector_1| component. uint32_t vector_1_extract_id = message_.fresh_ids(fresh_id_index++); fuzzerutil::UpdateModuleIdBound(ir_context, vector_1_extract_id); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeExtract, linear_algebra_instruction->type_id(), vector_1_extract_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {vector_1->result_id()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}}))); // Extracts |vector_2| component. uint32_t vector_2_extract_id = message_.fresh_ids(fresh_id_index++); fuzzerutil::UpdateModuleIdBound(ir_context, vector_2_extract_id); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpCompositeExtract, linear_algebra_instruction->type_id(), vector_2_extract_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {vector_2->result_id()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}}))); // Multiplies the pair of components. float_multiplication_ids[i] = message_.fresh_ids(fresh_id_index++); fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[i]); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpFMul, linear_algebra_instruction->type_id(), float_multiplication_ids[i], opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {vector_1_extract_id}}, {SPV_OPERAND_TYPE_ID, {vector_2_extract_id}}}))); } // If the vector has 2 components, then there will be 2 float multiplication // instructions. if (vectors_component_count == 2) { linear_algebra_instruction->SetOpcode(SpvOpFAdd); linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]}); linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]}); } else { // The first OpFAdd instruction has as operands the first two OpFMul // instructions. std::vector float_add_ids; uint32_t float_add_id = message_.fresh_ids(fresh_id_index++); float_add_ids.push_back(float_add_id); fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpFAdd, linear_algebra_instruction->type_id(), float_add_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}}, {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}}))); // The remaining OpFAdd instructions has as operands an OpFMul and an OpFAdd // instruction. for (uint32_t i = 2; i < float_multiplication_ids.size() - 1; i++) { float_add_id = message_.fresh_ids(fresh_id_index++); fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id); float_add_ids.push_back(float_add_id); linear_algebra_instruction->InsertBefore(MakeUnique( ir_context, SpvOpFAdd, linear_algebra_instruction->type_id(), float_add_id, opt::Instruction::OperandList( {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}}, {SPV_OPERAND_TYPE_ID, {float_add_ids[i - 2]}}}))); } // The last OpFAdd instruction is got by changing some of the OpDot // instruction attributes. linear_algebra_instruction->SetOpcode(SpvOpFAdd); linear_algebra_instruction->SetInOperand( 0, {float_multiplication_ids[float_multiplication_ids.size() - 1]}); linear_algebra_instruction->SetInOperand( 1, {float_add_ids[float_add_ids.size() - 1]}); } } } // namespace fuzz } // namespace spvtools