Implement the OpMatrixTimesScalar linear algebra case (#3450)

This PR implements the OpMatrixTimesScalar case for the
replace linear algebra instruction transformation.
This commit is contained in:
André Perez 2020-06-26 11:54:33 -03:00 committed by GitHub
parent efaae24d00
commit c3680adbd5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 383 additions and 3 deletions

View File

@ -42,6 +42,7 @@ void FuzzerPassReplaceLinearAlgebraInstructions::Apply() {
// addressed the following conditional can use the function
// |spvOpcodeIsLinearAlgebra|.
if (instruction->opcode() != SpvOpVectorTimesScalar &&
instruction->opcode() != SpvOpMatrixTimesScalar &&
instruction->opcode() != SpvOpDot) {
return;
}

View File

@ -1134,13 +1134,13 @@ message TransformationReplaceLinearAlgebraInstruction {
// This transformation is only applicable if the described instruction has one of the following opcodes.
// Supported:
// OpVectorTimesScalar
// OpMatrixTimesScalar
// OpDot
// TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354):
// Right now we only support certain operations. When this issue is addressed
// the supporting comments can be removed.
// To be supported in the future:
// OpTranspose
// OpMatrixTimesScalar
// OpVectorTimesMatrix
// OpMatrixTimesVector
// OpMatrixTimesMatrix

View File

@ -46,6 +46,7 @@ bool TransformationReplaceLinearAlgebraInstruction::IsApplicable(
// the following conditional can use the function |spvOpcodeIsLinearAlgebra|.
// It must be a supported linear algebra instruction.
if (instruction->opcode() != SpvOpVectorTimesScalar &&
instruction->opcode() != SpvOpMatrixTimesScalar &&
instruction->opcode() != SpvOpDot) {
return false;
}
@ -77,6 +78,9 @@ void TransformationReplaceLinearAlgebraInstruction::Apply(
case SpvOpVectorTimesScalar:
ReplaceOpVectorTimesScalar(ir_context, linear_algebra_instruction);
break;
case SpvOpMatrixTimesScalar:
ReplaceOpMatrixTimesScalar(ir_context, linear_algebra_instruction);
break;
case SpvOpDot:
ReplaceOpDot(ir_context, linear_algebra_instruction);
break;
@ -110,7 +114,21 @@ uint32_t TransformationReplaceLinearAlgebraInstruction::GetRequiredFreshIdCount(
->type_id())
->AsVector()
->element_count();
case SpvOpDot: {
case SpvOpMatrixTimesScalar: {
// For each matrix column, |1 + column.size| OpCompositeExtract,
// |column.size| OpFMul and 1 OpCompositeConstruct instructions will be
// inserted.
auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
instruction->GetSingleWordInOperand(0));
auto matrix_type =
ir_context->get_type_mgr()->GetType(matrix_instruction->type_id());
return 2 * matrix_type->AsMatrix()->element_count() *
(1 + matrix_type->AsMatrix()
->element_type()
->AsVector()
->element_count());
}
case SpvOpDot:
// For each pair of vector components, 2 OpCompositeExtract and 1 OpFMul
// will be inserted. The first two OpFMul instructions will result the
// first OpFAdd instruction to be inserted. For each remaining OpFMul, 1
@ -124,7 +142,6 @@ uint32_t TransformationReplaceLinearAlgebraInstruction::GetRequiredFreshIdCount(
->AsVector()
->element_count() -
2;
}
default:
assert(false && "Unsupported linear algebra instruction.");
return 0;
@ -179,6 +196,90 @@ void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesScalar(
}
}
void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesScalar(
opt::IRContext* ir_context,
opt::Instruction* linear_algebra_instruction) const {
// Gets OpMatrixTimesScalar in operands.
auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
linear_algebra_instruction->GetSingleWordInOperand(0));
auto scalar_instruction = ir_context->get_def_use_mgr()->GetDef(
linear_algebra_instruction->GetSingleWordInOperand(1));
// Gets matrix information.
uint32_t matrix_column_count = ir_context->get_type_mgr()
->GetType(matrix_instruction->type_id())
->AsMatrix()
->element_count();
auto matrix_column_type = ir_context->get_type_mgr()
->GetType(matrix_instruction->type_id())
->AsMatrix()
->element_type();
uint32_t matrix_column_size = matrix_column_type->AsVector()->element_count();
std::vector<uint32_t> composite_construct_ids(matrix_column_count);
uint32_t fresh_id_index = 0;
for (uint32_t i = 0; i < matrix_column_count; i++) {
// Extracts |matrix| column.
uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++);
fuzzerutil::UpdateModuleIdBound(ir_context, matrix_extract_id);
linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
ir_context, SpvOpCompositeExtract,
ir_context->get_type_mgr()->GetId(matrix_column_type),
matrix_extract_id,
opt::Instruction::OperandList(
{{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
std::vector<uint32_t> float_multiplication_ids(matrix_column_size);
for (uint32_t j = 0; j < matrix_column_size; j++) {
// Extracts |column| component.
uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
fuzzerutil::UpdateModuleIdBound(ir_context, column_extract_id);
linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
ir_context, SpvOpCompositeExtract, scalar_instruction->type_id(),
column_extract_id,
opt::Instruction::OperandList(
{{SPV_OPERAND_TYPE_ID, {matrix_extract_id}},
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
// Multiplies the |column| component with the |scalar|.
float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[j]);
linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
ir_context, SpvOpFMul, scalar_instruction->type_id(),
float_multiplication_ids[j],
opt::Instruction::OperandList(
{{SPV_OPERAND_TYPE_ID, {column_extract_id}},
{SPV_OPERAND_TYPE_ID, {scalar_instruction->result_id()}}})));
}
// Constructs a new column multiplied by |scalar|.
opt::Instruction::OperandList composite_construct_in_operands;
for (uint32_t& float_multiplication_id : float_multiplication_ids) {
composite_construct_in_operands.push_back(
{SPV_OPERAND_TYPE_ID, {float_multiplication_id}});
}
composite_construct_ids[i] = message_.fresh_ids(fresh_id_index++);
fuzzerutil::UpdateModuleIdBound(ir_context, composite_construct_ids[i]);
linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
ir_context, SpvOpCompositeConstruct,
ir_context->get_type_mgr()->GetId(matrix_column_type),
composite_construct_ids[i], composite_construct_in_operands));
}
// The OpMatrixTimesScalar instruction is changed to an OpCompositeConstruct
// instruction.
linear_algebra_instruction->SetOpcode(SpvOpCompositeConstruct);
linear_algebra_instruction->SetInOperand(0, {composite_construct_ids[0]});
linear_algebra_instruction->SetInOperand(1, {composite_construct_ids[1]});
for (uint32_t i = 2; i < composite_construct_ids.size(); i++) {
linear_algebra_instruction->AddOperand(
{SPV_OPERAND_TYPE_ID, {composite_construct_ids[i]}});
}
}
void TransformationReplaceLinearAlgebraInstruction::ReplaceOpDot(
opt::IRContext* ir_context,
opt::Instruction* linear_algebra_instruction) const {

View File

@ -56,6 +56,10 @@ class TransformationReplaceLinearAlgebraInstruction : public Transformation {
void ReplaceOpVectorTimesScalar(opt::IRContext* ir_context,
opt::Instruction* instruction) const;
// Replaces an OpMatrixTimesScalar instruction.
void ReplaceOpMatrixTimesScalar(opt::IRContext* ir_context,
opt::Instruction* instruction) const;
// Replaces an OpDot instruction.
void ReplaceOpDot(opt::IRContext* ir_context,
opt::Instruction* instruction) const;

View File

@ -250,6 +250,280 @@ TEST(TransformationReplaceLinearAlgebraInstructionTest,
ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
}
TEST(TransformationReplaceLinearAlgebraInstructionTest,
ReplaceOpMatrixTimesScalar) {
std::string reference_shader = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %54 "main"
OpExecutionMode %54 OriginUpperLeft
; Types
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%4 = OpTypeFloat 32
%5 = OpTypeVector %4 2
%6 = OpTypeVector %4 3
%7 = OpTypeVector %4 4
%8 = OpTypeMatrix %5 2
%9 = OpTypeMatrix %5 3
%10 = OpTypeMatrix %5 4
%11 = OpTypeMatrix %6 2
%12 = OpTypeMatrix %6 3
%13 = OpTypeMatrix %6 4
%14 = OpTypeMatrix %7 2
%15 = OpTypeMatrix %7 3
%16 = OpTypeMatrix %7 4
; Constant scalars
%17 = OpConstant %4 1
%18 = OpConstant %4 2
%19 = OpConstant %4 3
%20 = OpConstant %4 4
%21 = OpConstant %4 5
%22 = OpConstant %4 6
%23 = OpConstant %4 7
%24 = OpConstant %4 8
%25 = OpConstant %4 9
%26 = OpConstant %4 10
%27 = OpConstant %4 11
%28 = OpConstant %4 12
%29 = OpConstant %4 13
%30 = OpConstant %4 14
%31 = OpConstant %4 15
%32 = OpConstant %4 16
; Constant vectors
%33 = OpConstantComposite %5 %17 %18
%34 = OpConstantComposite %5 %19 %20
%35 = OpConstantComposite %5 %21 %22
%36 = OpConstantComposite %5 %23 %24
%37 = OpConstantComposite %6 %17 %18 %19
%38 = OpConstantComposite %6 %20 %21 %22
%39 = OpConstantComposite %6 %23 %24 %25
%40 = OpConstantComposite %6 %26 %27 %28
%41 = OpConstantComposite %7 %17 %18 %19 %20
%42 = OpConstantComposite %7 %21 %22 %23 %24
%43 = OpConstantComposite %7 %25 %26 %27 %28
%44 = OpConstantComposite %7 %29 %30 %31 %32
; Constant matrices
%45 = OpConstantComposite %8 %33 %34
%46 = OpConstantComposite %9 %33 %34 %35
%47 = OpConstantComposite %10 %33 %34 %35 %36
%48 = OpConstantComposite %11 %37 %38
%49 = OpConstantComposite %12 %37 %38 %39
%50 = OpConstantComposite %13 %37 %38 %39 %40
%51 = OpConstantComposite %14 %41 %42
%52 = OpConstantComposite %15 %41 %42 %43
%53 = OpConstantComposite %16 %41 %42 %43 %44
; main function
%54 = OpFunction %2 None %3
%55 = OpLabel
; Multiplying 2-row matrices by scalar
%56 = OpMatrixTimesScalar %8 %45 %17
%57 = OpMatrixTimesScalar %9 %46 %18
%58 = OpMatrixTimesScalar %10 %47 %19
; Multiplying 3-row matrices by scalar
%59 = OpMatrixTimesScalar %11 %48 %21
%60 = OpMatrixTimesScalar %12 %49 %22
%61 = OpMatrixTimesScalar %13 %50 %23
; Multiplying 4-row matrices by scalar
%62 = OpMatrixTimesScalar %14 %51 %24
%63 = OpMatrixTimesScalar %15 %52 %25
%64 = OpMatrixTimesScalar %16 %53 %26
OpReturn
OpFunctionEnd
)";
const auto env = SPV_ENV_UNIVERSAL_1_5;
const auto consumer = nullptr;
const auto context =
BuildModule(env, consumer, reference_shader, kFuzzAssembleOption);
ASSERT_TRUE(IsValid(env, context.get()));
FactManager fact_manager;
spvtools::ValidatorOptions validator_options;
TransformationContext transformation_context(&fact_manager,
validator_options);
auto instruction_descriptor =
MakeInstructionDescriptor(56, SpvOpMatrixTimesScalar, 0);
auto transformation = TransformationReplaceLinearAlgebraInstruction(
{65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76}, instruction_descriptor);
transformation.Apply(context.get(), &transformation_context);
instruction_descriptor =
MakeInstructionDescriptor(57, SpvOpMatrixTimesScalar, 0);
transformation = TransformationReplaceLinearAlgebraInstruction(
{77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94},
instruction_descriptor);
transformation.Apply(context.get(), &transformation_context);
instruction_descriptor =
MakeInstructionDescriptor(58, SpvOpMatrixTimesScalar, 0);
transformation = TransformationReplaceLinearAlgebraInstruction(
{95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118},
instruction_descriptor);
transformation.Apply(context.get(), &transformation_context);
std::string variant_shader = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %54 "main"
OpExecutionMode %54 OriginUpperLeft
; Types
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%4 = OpTypeFloat 32
%5 = OpTypeVector %4 2
%6 = OpTypeVector %4 3
%7 = OpTypeVector %4 4
%8 = OpTypeMatrix %5 2
%9 = OpTypeMatrix %5 3
%10 = OpTypeMatrix %5 4
%11 = OpTypeMatrix %6 2
%12 = OpTypeMatrix %6 3
%13 = OpTypeMatrix %6 4
%14 = OpTypeMatrix %7 2
%15 = OpTypeMatrix %7 3
%16 = OpTypeMatrix %7 4
; Constant scalars
%17 = OpConstant %4 1
%18 = OpConstant %4 2
%19 = OpConstant %4 3
%20 = OpConstant %4 4
%21 = OpConstant %4 5
%22 = OpConstant %4 6
%23 = OpConstant %4 7
%24 = OpConstant %4 8
%25 = OpConstant %4 9
%26 = OpConstant %4 10
%27 = OpConstant %4 11
%28 = OpConstant %4 12
%29 = OpConstant %4 13
%30 = OpConstant %4 14
%31 = OpConstant %4 15
%32 = OpConstant %4 16
; Constant vectors
%33 = OpConstantComposite %5 %17 %18
%34 = OpConstantComposite %5 %19 %20
%35 = OpConstantComposite %5 %21 %22
%36 = OpConstantComposite %5 %23 %24
%37 = OpConstantComposite %6 %17 %18 %19
%38 = OpConstantComposite %6 %20 %21 %22
%39 = OpConstantComposite %6 %23 %24 %25
%40 = OpConstantComposite %6 %26 %27 %28
%41 = OpConstantComposite %7 %17 %18 %19 %20
%42 = OpConstantComposite %7 %21 %22 %23 %24
%43 = OpConstantComposite %7 %25 %26 %27 %28
%44 = OpConstantComposite %7 %29 %30 %31 %32
; Constant matrices
%45 = OpConstantComposite %8 %33 %34
%46 = OpConstantComposite %9 %33 %34 %35
%47 = OpConstantComposite %10 %33 %34 %35 %36
%48 = OpConstantComposite %11 %37 %38
%49 = OpConstantComposite %12 %37 %38 %39
%50 = OpConstantComposite %13 %37 %38 %39 %40
%51 = OpConstantComposite %14 %41 %42
%52 = OpConstantComposite %15 %41 %42 %43
%53 = OpConstantComposite %16 %41 %42 %43 %44
; main function
%54 = OpFunction %2 None %3
%55 = OpLabel
; Multiplying 2x2 matrix by scalar
%65 = OpCompositeExtract %5 %45 0
%66 = OpCompositeExtract %4 %65 0
%67 = OpFMul %4 %66 %17
%68 = OpCompositeExtract %4 %65 1
%69 = OpFMul %4 %68 %17
%70 = OpCompositeConstruct %5 %67 %69
%71 = OpCompositeExtract %5 %45 1
%72 = OpCompositeExtract %4 %71 0
%73 = OpFMul %4 %72 %17
%74 = OpCompositeExtract %4 %71 1
%75 = OpFMul %4 %74 %17
%76 = OpCompositeConstruct %5 %73 %75
%56 = OpCompositeConstruct %8 %70 %76
; Multiplying 2x3 matrix by scalar
%77 = OpCompositeExtract %5 %46 0
%78 = OpCompositeExtract %4 %77 0
%79 = OpFMul %4 %78 %18
%80 = OpCompositeExtract %4 %77 1
%81 = OpFMul %4 %80 %18
%82 = OpCompositeConstruct %5 %79 %81
%83 = OpCompositeExtract %5 %46 1
%84 = OpCompositeExtract %4 %83 0
%85 = OpFMul %4 %84 %18
%86 = OpCompositeExtract %4 %83 1
%87 = OpFMul %4 %86 %18
%88 = OpCompositeConstruct %5 %85 %87
%89 = OpCompositeExtract %5 %46 2
%90 = OpCompositeExtract %4 %89 0
%91 = OpFMul %4 %90 %18
%92 = OpCompositeExtract %4 %89 1
%93 = OpFMul %4 %92 %18
%94 = OpCompositeConstruct %5 %91 %93
%57 = OpCompositeConstruct %9 %82 %88 %94
; Multiplying 2x4 matrix by scalar
%95 = OpCompositeExtract %5 %47 0
%96 = OpCompositeExtract %4 %95 0
%97 = OpFMul %4 %96 %19
%98 = OpCompositeExtract %4 %95 1
%99 = OpFMul %4 %98 %19
%100 = OpCompositeConstruct %5 %97 %99
%101 = OpCompositeExtract %5 %47 1
%102 = OpCompositeExtract %4 %101 0
%103 = OpFMul %4 %102 %19
%104 = OpCompositeExtract %4 %101 1
%105 = OpFMul %4 %104 %19
%106 = OpCompositeConstruct %5 %103 %105
%107 = OpCompositeExtract %5 %47 2
%108 = OpCompositeExtract %4 %107 0
%109 = OpFMul %4 %108 %19
%110 = OpCompositeExtract %4 %107 1
%111 = OpFMul %4 %110 %19
%112 = OpCompositeConstruct %5 %109 %111
%113 = OpCompositeExtract %5 %47 3
%114 = OpCompositeExtract %4 %113 0
%115 = OpFMul %4 %114 %19
%116 = OpCompositeExtract %4 %113 1
%117 = OpFMul %4 %116 %19
%118 = OpCompositeConstruct %5 %115 %117
%58 = OpCompositeConstruct %10 %100 %106 %112 %118
; Multiplying 3-row matrices by scalar
%59 = OpMatrixTimesScalar %11 %48 %21
%60 = OpMatrixTimesScalar %12 %49 %22
%61 = OpMatrixTimesScalar %13 %50 %23
; Multiplying 4-row matrices by scalar
%62 = OpMatrixTimesScalar %14 %51 %24
%63 = OpMatrixTimesScalar %15 %52 %25
%64 = OpMatrixTimesScalar %16 %53 %26
OpReturn
OpFunctionEnd
)";
ASSERT_TRUE(IsValid(env, context.get()));
ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
}
TEST(TransformationReplaceLinearAlgebraInstructionTest, ReplaceOpDot) {
std::string reference_shader = R"(
OpCapability Shader