diff --git a/source/fuzz/fuzzer_pass_replace_linear_algebra_instructions.cpp b/source/fuzz/fuzzer_pass_replace_linear_algebra_instructions.cpp index 1c7b8285b..7f9b84887 100644 --- a/source/fuzz/fuzzer_pass_replace_linear_algebra_instructions.cpp +++ b/source/fuzz/fuzzer_pass_replace_linear_algebra_instructions.cpp @@ -42,6 +42,7 @@ void FuzzerPassReplaceLinearAlgebraInstructions::Apply() { // addressed the following conditional can use the function // |spvOpcodeIsLinearAlgebra|. if (instruction->opcode() != SpvOpVectorTimesScalar && + instruction->opcode() != SpvOpMatrixTimesScalar && instruction->opcode() != SpvOpDot) { return; } diff --git a/source/fuzz/protobufs/spvtoolsfuzz.proto b/source/fuzz/protobufs/spvtoolsfuzz.proto index b8d07cf8d..7b8adc08d 100644 --- a/source/fuzz/protobufs/spvtoolsfuzz.proto +++ b/source/fuzz/protobufs/spvtoolsfuzz.proto @@ -1134,13 +1134,13 @@ message TransformationReplaceLinearAlgebraInstruction { // This transformation is only applicable if the described instruction has one of the following opcodes. // Supported: // OpVectorTimesScalar + // OpMatrixTimesScalar // OpDot // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354): // Right now we only support certain operations. When this issue is addressed // the supporting comments can be removed. // To be supported in the future: // OpTranspose - // OpMatrixTimesScalar // OpVectorTimesMatrix // OpMatrixTimesVector // OpMatrixTimesMatrix diff --git a/source/fuzz/transformation_replace_linear_algebra_instruction.cpp b/source/fuzz/transformation_replace_linear_algebra_instruction.cpp index 14ed502ff..1c7d0c991 100644 --- a/source/fuzz/transformation_replace_linear_algebra_instruction.cpp +++ b/source/fuzz/transformation_replace_linear_algebra_instruction.cpp @@ -46,6 +46,7 @@ bool TransformationReplaceLinearAlgebraInstruction::IsApplicable( // the following conditional can use the function |spvOpcodeIsLinearAlgebra|. // It must be a supported linear algebra instruction. if (instruction->opcode() != SpvOpVectorTimesScalar && + instruction->opcode() != SpvOpMatrixTimesScalar && instruction->opcode() != SpvOpDot) { return false; } @@ -77,6 +78,9 @@ void TransformationReplaceLinearAlgebraInstruction::Apply( case SpvOpVectorTimesScalar: ReplaceOpVectorTimesScalar(ir_context, linear_algebra_instruction); break; + case SpvOpMatrixTimesScalar: + ReplaceOpMatrixTimesScalar(ir_context, linear_algebra_instruction); + break; case SpvOpDot: ReplaceOpDot(ir_context, linear_algebra_instruction); break; @@ -110,7 +114,21 @@ uint32_t TransformationReplaceLinearAlgebraInstruction::GetRequiredFreshIdCount( ->type_id()) ->AsVector() ->element_count(); - case SpvOpDot: { + case SpvOpMatrixTimesScalar: { + // For each matrix column, |1 + column.size| OpCompositeExtract, + // |column.size| OpFMul and 1 OpCompositeConstruct instructions will be + // inserted. + auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef( + instruction->GetSingleWordInOperand(0)); + auto matrix_type = + ir_context->get_type_mgr()->GetType(matrix_instruction->type_id()); + return 2 * matrix_type->AsMatrix()->element_count() * + (1 + matrix_type->AsMatrix() + ->element_type() + ->AsVector() + ->element_count()); + } + case SpvOpDot: // For each pair of vector components, 2 OpCompositeExtract and 1 OpFMul // will be inserted. The first two OpFMul instructions will result the // first OpFAdd instruction to be inserted. For each remaining OpFMul, 1 @@ -124,7 +142,6 @@ uint32_t TransformationReplaceLinearAlgebraInstruction::GetRequiredFreshIdCount( ->AsVector() ->element_count() - 2; - } default: assert(false && "Unsupported linear algebra instruction."); return 0; @@ -179,6 +196,90 @@ void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesScalar( } } +void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesScalar( + opt::IRContext* ir_context, + opt::Instruction* linear_algebra_instruction) const { + // Gets OpMatrixTimesScalar in operands. + auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef( + linear_algebra_instruction->GetSingleWordInOperand(0)); + auto scalar_instruction = ir_context->get_def_use_mgr()->GetDef( + linear_algebra_instruction->GetSingleWordInOperand(1)); + + // Gets matrix information. + uint32_t matrix_column_count = ir_context->get_type_mgr() + ->GetType(matrix_instruction->type_id()) + ->AsMatrix() + ->element_count(); + auto matrix_column_type = ir_context->get_type_mgr() + ->GetType(matrix_instruction->type_id()) + ->AsMatrix() + ->element_type(); + uint32_t matrix_column_size = matrix_column_type->AsVector()->element_count(); + + std::vector composite_construct_ids(matrix_column_count); + uint32_t fresh_id_index = 0; + + for (uint32_t i = 0; i < matrix_column_count; i++) { + // Extracts |matrix| column. + uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++); + fuzzerutil::UpdateModuleIdBound(ir_context, matrix_extract_id); + linear_algebra_instruction->InsertBefore(MakeUnique( + ir_context, SpvOpCompositeExtract, + ir_context->get_type_mgr()->GetId(matrix_column_type), + matrix_extract_id, + opt::Instruction::OperandList( + {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}}))); + + std::vector float_multiplication_ids(matrix_column_size); + + for (uint32_t j = 0; j < matrix_column_size; j++) { + // Extracts |column| component. + uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++); + fuzzerutil::UpdateModuleIdBound(ir_context, column_extract_id); + linear_algebra_instruction->InsertBefore(MakeUnique( + ir_context, SpvOpCompositeExtract, scalar_instruction->type_id(), + column_extract_id, + opt::Instruction::OperandList( + {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}}))); + + // Multiplies the |column| component with the |scalar|. + float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++); + fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[j]); + linear_algebra_instruction->InsertBefore(MakeUnique( + ir_context, SpvOpFMul, scalar_instruction->type_id(), + float_multiplication_ids[j], + opt::Instruction::OperandList( + {{SPV_OPERAND_TYPE_ID, {column_extract_id}}, + {SPV_OPERAND_TYPE_ID, {scalar_instruction->result_id()}}}))); + } + + // Constructs a new column multiplied by |scalar|. + opt::Instruction::OperandList composite_construct_in_operands; + for (uint32_t& float_multiplication_id : float_multiplication_ids) { + composite_construct_in_operands.push_back( + {SPV_OPERAND_TYPE_ID, {float_multiplication_id}}); + } + composite_construct_ids[i] = message_.fresh_ids(fresh_id_index++); + fuzzerutil::UpdateModuleIdBound(ir_context, composite_construct_ids[i]); + linear_algebra_instruction->InsertBefore(MakeUnique( + ir_context, SpvOpCompositeConstruct, + ir_context->get_type_mgr()->GetId(matrix_column_type), + composite_construct_ids[i], composite_construct_in_operands)); + } + + // The OpMatrixTimesScalar instruction is changed to an OpCompositeConstruct + // instruction. + linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct); + linear_algebra_instruction->SetInOperand(0, {composite_construct_ids[0]}); + linear_algebra_instruction->SetInOperand(1, {composite_construct_ids[1]}); + for (uint32_t i = 2; i < composite_construct_ids.size(); i++) { + linear_algebra_instruction->AddOperand( + {SPV_OPERAND_TYPE_ID, {composite_construct_ids[i]}}); + } +} + void TransformationReplaceLinearAlgebraInstruction::ReplaceOpDot( opt::IRContext* ir_context, opt::Instruction* linear_algebra_instruction) const { diff --git a/source/fuzz/transformation_replace_linear_algebra_instruction.h b/source/fuzz/transformation_replace_linear_algebra_instruction.h index de280e40f..45b12626d 100644 --- a/source/fuzz/transformation_replace_linear_algebra_instruction.h +++ b/source/fuzz/transformation_replace_linear_algebra_instruction.h @@ -56,6 +56,10 @@ class TransformationReplaceLinearAlgebraInstruction : public Transformation { void ReplaceOpVectorTimesScalar(opt::IRContext* ir_context, opt::Instruction* instruction) const; + // Replaces an OpMatrixTimesScalar instruction. + void ReplaceOpMatrixTimesScalar(opt::IRContext* ir_context, + opt::Instruction* instruction) const; + // Replaces an OpDot instruction. void ReplaceOpDot(opt::IRContext* ir_context, opt::Instruction* instruction) const; diff --git a/test/fuzz/transformation_replace_linear_algebra_instruction_test.cpp b/test/fuzz/transformation_replace_linear_algebra_instruction_test.cpp index 42906d067..c9a1aee13 100644 --- a/test/fuzz/transformation_replace_linear_algebra_instruction_test.cpp +++ b/test/fuzz/transformation_replace_linear_algebra_instruction_test.cpp @@ -250,6 +250,280 @@ TEST(TransformationReplaceLinearAlgebraInstructionTest, ASSERT_TRUE(IsEqual(env, variant_shader, context.get())); } +TEST(TransformationReplaceLinearAlgebraInstructionTest, + ReplaceOpMatrixTimesScalar) { + std::string reference_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %54 "main" + OpExecutionMode %54 OriginUpperLeft + +; Types + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpTypeFloat 32 + %5 = OpTypeVector %4 2 + %6 = OpTypeVector %4 3 + %7 = OpTypeVector %4 4 + %8 = OpTypeMatrix %5 2 + %9 = OpTypeMatrix %5 3 + %10 = OpTypeMatrix %5 4 + %11 = OpTypeMatrix %6 2 + %12 = OpTypeMatrix %6 3 + %13 = OpTypeMatrix %6 4 + %14 = OpTypeMatrix %7 2 + %15 = OpTypeMatrix %7 3 + %16 = OpTypeMatrix %7 4 + +; Constant scalars + %17 = OpConstant %4 1 + %18 = OpConstant %4 2 + %19 = OpConstant %4 3 + %20 = OpConstant %4 4 + %21 = OpConstant %4 5 + %22 = OpConstant %4 6 + %23 = OpConstant %4 7 + %24 = OpConstant %4 8 + %25 = OpConstant %4 9 + %26 = OpConstant %4 10 + %27 = OpConstant %4 11 + %28 = OpConstant %4 12 + %29 = OpConstant %4 13 + %30 = OpConstant %4 14 + %31 = OpConstant %4 15 + %32 = OpConstant %4 16 + +; Constant vectors + %33 = OpConstantComposite %5 %17 %18 + %34 = OpConstantComposite %5 %19 %20 + %35 = OpConstantComposite %5 %21 %22 + %36 = OpConstantComposite %5 %23 %24 + %37 = OpConstantComposite %6 %17 %18 %19 + %38 = OpConstantComposite %6 %20 %21 %22 + %39 = OpConstantComposite %6 %23 %24 %25 + %40 = OpConstantComposite %6 %26 %27 %28 + %41 = OpConstantComposite %7 %17 %18 %19 %20 + %42 = OpConstantComposite %7 %21 %22 %23 %24 + %43 = OpConstantComposite %7 %25 %26 %27 %28 + %44 = OpConstantComposite %7 %29 %30 %31 %32 + +; Constant matrices + %45 = OpConstantComposite %8 %33 %34 + %46 = OpConstantComposite %9 %33 %34 %35 + %47 = OpConstantComposite %10 %33 %34 %35 %36 + %48 = OpConstantComposite %11 %37 %38 + %49 = OpConstantComposite %12 %37 %38 %39 + %50 = OpConstantComposite %13 %37 %38 %39 %40 + %51 = OpConstantComposite %14 %41 %42 + %52 = OpConstantComposite %15 %41 %42 %43 + %53 = OpConstantComposite %16 %41 %42 %43 %44 + +; main function + %54 = OpFunction %2 None %3 + %55 = OpLabel + +; Multiplying 2-row matrices by scalar + %56 = OpMatrixTimesScalar %8 %45 %17 + %57 = OpMatrixTimesScalar %9 %46 %18 + %58 = OpMatrixTimesScalar %10 %47 %19 + +; Multiplying 3-row matrices by scalar + %59 = OpMatrixTimesScalar %11 %48 %21 + %60 = OpMatrixTimesScalar %12 %49 %22 + %61 = OpMatrixTimesScalar %13 %50 %23 + +; Multiplying 4-row matrices by scalar + %62 = OpMatrixTimesScalar %14 %51 %24 + %63 = OpMatrixTimesScalar %15 %52 %25 + %64 = OpMatrixTimesScalar %16 %53 %26 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = + BuildModule(env, consumer, reference_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + auto instruction_descriptor = + MakeInstructionDescriptor(56, SpvOpMatrixTimesScalar, 0); + auto transformation = TransformationReplaceLinearAlgebraInstruction( + {65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76}, instruction_descriptor); + transformation.Apply(context.get(), &transformation_context); + + instruction_descriptor = + MakeInstructionDescriptor(57, SpvOpMatrixTimesScalar, 0); + transformation = TransformationReplaceLinearAlgebraInstruction( + {77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94}, + instruction_descriptor); + transformation.Apply(context.get(), &transformation_context); + + instruction_descriptor = + MakeInstructionDescriptor(58, SpvOpMatrixTimesScalar, 0); + transformation = TransformationReplaceLinearAlgebraInstruction( + {95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118}, + instruction_descriptor); + transformation.Apply(context.get(), &transformation_context); + + std::string variant_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %54 "main" + OpExecutionMode %54 OriginUpperLeft + +; Types + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpTypeFloat 32 + %5 = OpTypeVector %4 2 + %6 = OpTypeVector %4 3 + %7 = OpTypeVector %4 4 + %8 = OpTypeMatrix %5 2 + %9 = OpTypeMatrix %5 3 + %10 = OpTypeMatrix %5 4 + %11 = OpTypeMatrix %6 2 + %12 = OpTypeMatrix %6 3 + %13 = OpTypeMatrix %6 4 + %14 = OpTypeMatrix %7 2 + %15 = OpTypeMatrix %7 3 + %16 = OpTypeMatrix %7 4 + +; Constant scalars + %17 = OpConstant %4 1 + %18 = OpConstant %4 2 + %19 = OpConstant %4 3 + %20 = OpConstant %4 4 + %21 = OpConstant %4 5 + %22 = OpConstant %4 6 + %23 = OpConstant %4 7 + %24 = OpConstant %4 8 + %25 = OpConstant %4 9 + %26 = OpConstant %4 10 + %27 = OpConstant %4 11 + %28 = OpConstant %4 12 + %29 = OpConstant %4 13 + %30 = OpConstant %4 14 + %31 = OpConstant %4 15 + %32 = OpConstant %4 16 + +; Constant vectors + %33 = OpConstantComposite %5 %17 %18 + %34 = OpConstantComposite %5 %19 %20 + %35 = OpConstantComposite %5 %21 %22 + %36 = OpConstantComposite %5 %23 %24 + %37 = OpConstantComposite %6 %17 %18 %19 + %38 = OpConstantComposite %6 %20 %21 %22 + %39 = OpConstantComposite %6 %23 %24 %25 + %40 = OpConstantComposite %6 %26 %27 %28 + %41 = OpConstantComposite %7 %17 %18 %19 %20 + %42 = OpConstantComposite %7 %21 %22 %23 %24 + %43 = OpConstantComposite %7 %25 %26 %27 %28 + %44 = OpConstantComposite %7 %29 %30 %31 %32 + +; Constant matrices + %45 = OpConstantComposite %8 %33 %34 + %46 = OpConstantComposite %9 %33 %34 %35 + %47 = OpConstantComposite %10 %33 %34 %35 %36 + %48 = OpConstantComposite %11 %37 %38 + %49 = OpConstantComposite %12 %37 %38 %39 + %50 = OpConstantComposite %13 %37 %38 %39 %40 + %51 = OpConstantComposite %14 %41 %42 + %52 = OpConstantComposite %15 %41 %42 %43 + %53 = OpConstantComposite %16 %41 %42 %43 %44 + +; main function + %54 = OpFunction %2 None %3 + %55 = OpLabel + +; Multiplying 2x2 matrix by scalar + %65 = OpCompositeExtract %5 %45 0 + %66 = OpCompositeExtract %4 %65 0 + %67 = OpFMul %4 %66 %17 + %68 = OpCompositeExtract %4 %65 1 + %69 = OpFMul %4 %68 %17 + %70 = OpCompositeConstruct %5 %67 %69 + %71 = OpCompositeExtract %5 %45 1 + %72 = OpCompositeExtract %4 %71 0 + %73 = OpFMul %4 %72 %17 + %74 = OpCompositeExtract %4 %71 1 + %75 = OpFMul %4 %74 %17 + %76 = OpCompositeConstruct %5 %73 %75 + %56 = OpCompositeConstruct %8 %70 %76 + +; Multiplying 2x3 matrix by scalar + %77 = OpCompositeExtract %5 %46 0 + %78 = OpCompositeExtract %4 %77 0 + %79 = OpFMul %4 %78 %18 + %80 = OpCompositeExtract %4 %77 1 + %81 = OpFMul %4 %80 %18 + %82 = OpCompositeConstruct %5 %79 %81 + %83 = OpCompositeExtract %5 %46 1 + %84 = OpCompositeExtract %4 %83 0 + %85 = OpFMul %4 %84 %18 + %86 = OpCompositeExtract %4 %83 1 + %87 = OpFMul %4 %86 %18 + %88 = OpCompositeConstruct %5 %85 %87 + %89 = OpCompositeExtract %5 %46 2 + %90 = OpCompositeExtract %4 %89 0 + %91 = OpFMul %4 %90 %18 + %92 = OpCompositeExtract %4 %89 1 + %93 = OpFMul %4 %92 %18 + %94 = OpCompositeConstruct %5 %91 %93 + %57 = OpCompositeConstruct %9 %82 %88 %94 + +; Multiplying 2x4 matrix by scalar + %95 = OpCompositeExtract %5 %47 0 + %96 = OpCompositeExtract %4 %95 0 + %97 = OpFMul %4 %96 %19 + %98 = OpCompositeExtract %4 %95 1 + %99 = OpFMul %4 %98 %19 + %100 = OpCompositeConstruct %5 %97 %99 + %101 = OpCompositeExtract %5 %47 1 + %102 = OpCompositeExtract %4 %101 0 + %103 = OpFMul %4 %102 %19 + %104 = OpCompositeExtract %4 %101 1 + %105 = OpFMul %4 %104 %19 + %106 = OpCompositeConstruct %5 %103 %105 + %107 = OpCompositeExtract %5 %47 2 + %108 = OpCompositeExtract %4 %107 0 + %109 = OpFMul %4 %108 %19 + %110 = OpCompositeExtract %4 %107 1 + %111 = OpFMul %4 %110 %19 + %112 = OpCompositeConstruct %5 %109 %111 + %113 = OpCompositeExtract %5 %47 3 + %114 = OpCompositeExtract %4 %113 0 + %115 = OpFMul %4 %114 %19 + %116 = OpCompositeExtract %4 %113 1 + %117 = OpFMul %4 %116 %19 + %118 = OpCompositeConstruct %5 %115 %117 + %58 = OpCompositeConstruct %10 %100 %106 %112 %118 + +; Multiplying 3-row matrices by scalar + %59 = OpMatrixTimesScalar %11 %48 %21 + %60 = OpMatrixTimesScalar %12 %49 %22 + %61 = OpMatrixTimesScalar %13 %50 %23 + +; Multiplying 4-row matrices by scalar + %62 = OpMatrixTimesScalar %14 %51 %24 + %63 = OpMatrixTimesScalar %15 %52 %25 + %64 = OpMatrixTimesScalar %16 %53 %26 + OpReturn + OpFunctionEnd + )"; + + ASSERT_TRUE(IsValid(env, context.get())); + ASSERT_TRUE(IsEqual(env, variant_shader, context.get())); +} + TEST(TransformationReplaceLinearAlgebraInstructionTest, ReplaceOpDot) { std::string reference_shader = R"( OpCapability Shader