Add folding rule for OpTranspose (#5241)

This commit is contained in:
Steven Perron 2023-06-01 09:09:08 -07:00 committed by GitHub
parent ec244c8598
commit 5ed21eb1e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 145 additions and 0 deletions

View File

@ -341,6 +341,69 @@ ConstantFoldingRule FoldVectorTimesScalar() {
};
}
// Returns to the constant that results from tranposing |matrix|. The result
// will have type |result_type|, and |matrix| must exist in |context|. The
// result constant will also exist in |context|.
const analysis::Constant* TransposeMatrix(const analysis::Constant* matrix,
analysis::Matrix* result_type,
IRContext* context) {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
if (matrix->AsNullConstant() != nullptr) {
return const_mgr->GetNullCompositeConstant(result_type);
}
const auto& columns = matrix->AsMatrixConstant()->GetComponents();
uint32_t number_of_rows = columns[0]->type()->AsVector()->element_count();
// Collect the ids of the elements in their new positions.
std::vector<std::vector<uint32_t>> result_elements(number_of_rows);
for (const analysis::Constant* column : columns) {
if (column->AsNullConstant()) {
column = const_mgr->GetNullCompositeConstant(column->type());
}
const auto& column_components = column->AsVectorConstant()->GetComponents();
for (uint32_t row = 0; row < number_of_rows; ++row) {
result_elements[row].push_back(
const_mgr->GetDefiningInstruction(column_components[row])
->result_id());
}
}
// Create the constant for each row in the result, and collect the ids.
std::vector<uint32_t> result_columns(number_of_rows);
for (uint32_t col = 0; col < number_of_rows; ++col) {
auto* element = const_mgr->GetConstant(result_type->element_type(),
result_elements[col]);
result_columns[col] =
const_mgr->GetDefiningInstruction(element)->result_id();
}
// Create the matrix constant from the row ids, and return it.
return const_mgr->GetConstant(result_type, result_columns);
}
const analysis::Constant* FoldTranspose(
IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == spv::Op::OpTranspose);
analysis::TypeManager* type_mgr = context->get_type_mgr();
if (!inst->IsFloatingPointFoldingAllowed()) {
if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
return nullptr;
}
}
const analysis::Constant* matrix = constants[0];
if (matrix == nullptr) {
return nullptr;
}
auto* result_type = type_mgr->GetType(inst->type_id());
return TransposeMatrix(matrix, result_type->AsMatrix(), context);
}
ConstantFoldingRule FoldVectorTimesMatrix() {
return [](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
@ -1566,6 +1629,7 @@ void ConstantFoldingRules::AddFoldingRules() {
rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar());
rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix());
rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector());
rules_[spv::Op::OpTranspose].push_back(FoldTranspose);
rules_[spv::Op::OpFNegate].push_back(FoldFNegate());
rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());

View File

@ -156,6 +156,8 @@ OpName %main "main"
%v2half = OpTypeVector %half 2
%v2bool = OpTypeVector %bool 2
%m2x2int = OpTypeMatrix %v2int 2
%mat4v2float = OpTypeMatrix %v2float 4
%mat2v4float = OpTypeMatrix %v4float 2
%mat4v4float = OpTypeMatrix %v4float 4
%mat4v4double = OpTypeMatrix %v4double 4
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
@ -290,6 +292,7 @@ OpName %main "main"
%v4float_1_1_1_1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
%v4float_1_2_3_4 = OpConstantComposite %v4float %float_1 %float_2 %float_3 %float_4
%v4float_null = OpConstantNull %v4float
%mat2v4float_null = OpConstantNull %mat2v4float
%mat4v4float_null = OpConstantNull %mat4v4float
%mat4v4float_1_2_3_4 = OpConstantComposite %mat4v4float %v4float_1_2_3_4 %v4float_1_2_3_4 %v4float_1_2_3_4 %v4float_1_2_3_4
%mat4v4float_1_2_3_4_null = OpConstantComposite %mat4v4float %v4float_1_2_3_4 %v4float_null %v4float_1_2_3_4 %v4float_null
@ -1239,6 +1242,84 @@ INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest,
2, {4.0,8.0,12.0,16.0})
));
// clang-format on
using FloatMatrixInstructionFoldingTest = ::testing::TestWithParam<
InstructionFoldingCase<std::vector<std::vector<float>>>>;
TEST_P(FloatMatrixInstructionFoldingTest, Case) {
const auto& tc = GetParam();
// Build module.
std::unique_ptr<IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
ASSERT_NE(nullptr, context);
// Fold the instruction to test.
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
bool succeeded = context->get_instruction_folder().FoldInstruction(inst);
// Make sure the instruction folded as expected.
EXPECT_TRUE(succeeded);
EXPECT_EQ(inst->opcode(), spv::Op::OpCopyObject);
if (inst->opcode() == spv::Op::OpCopyObject) {
inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Constant* result = const_mgr->GetConstantFromInst(inst);
EXPECT_NE(result, nullptr);
if (result != nullptr) {
std::vector<const analysis::Constant*> matrix =
result->AsMatrixConstant()->GetComponents();
EXPECT_EQ(matrix.size(), tc.expected_result.size());
for (size_t c = 0; c < matrix.size(); c++) {
if (matrix[c]->AsNullConstant() != nullptr) {
matrix[c] = const_mgr->GetNullCompositeConstant(matrix[c]->type());
}
const analysis::VectorConstant* column_const =
matrix[c]->AsVectorConstant();
ASSERT_NE(column_const, nullptr);
const std::vector<const analysis::Constant*>& column =
column_const->GetComponents();
EXPECT_EQ(column.size(), tc.expected_result[c].size());
for (size_t r = 0; r < column.size(); r++) {
EXPECT_EQ(tc.expected_result[c][r], column[r]->GetFloat());
}
}
}
}
}
// clang-format off
INSTANTIATE_TEST_SUITE_P(TestCase, FloatMatrixInstructionFoldingTest,
::testing::Values(
// Test case 0: OpTranspose square null matrix
InstructionFoldingCase<std::vector<std::vector<float>>>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpTranspose %mat4v4float %mat4v4float_null\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {{0.0f, 0.0f, 0.0f, 0.0f},{0.0f, 0.0f, 0.0f, 0.0f},{0.0f, 0.0f, 0.0f, 0.0f},{0.0f, 0.0f, 0.0f, 0.0f}}),
// Test case 1: OpTranspose rectangular null matrix
InstructionFoldingCase<std::vector<std::vector<float>>>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpTranspose %mat4v2float %mat2v4float_null\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {{0.0f, 0.0f},{0.0f, 0.0f},{0.0f, 0.0f},{0.0f, 0.0f}}),
InstructionFoldingCase<std::vector<std::vector<float>>>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpTranspose %mat4v4float %mat4v4float_1_2_3_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {{1.0f, 1.0f, 1.0f, 1.0f},{2.0f, 2.0f, 2.0f, 2.0f},{3.0f, 3.0f, 3.0f, 3.0f},{4.0f, 4.0f, 4.0f, 4.0f}})
));
// clang-format on
using BooleanInstructionFoldingTest =
::testing::TestWithParam<InstructionFoldingCase<bool>>;