mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-10-18 11:10:05 +00:00
Add folding rule for OpTranspose (#5241)
This commit is contained in:
parent
ec244c8598
commit
5ed21eb1e2
@ -341,6 +341,69 @@ ConstantFoldingRule FoldVectorTimesScalar() {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns to the constant that results from tranposing |matrix|. The result
|
||||||
|
// will have type |result_type|, and |matrix| must exist in |context|. The
|
||||||
|
// result constant will also exist in |context|.
|
||||||
|
const analysis::Constant* TransposeMatrix(const analysis::Constant* matrix,
|
||||||
|
analysis::Matrix* result_type,
|
||||||
|
IRContext* context) {
|
||||||
|
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||||
|
if (matrix->AsNullConstant() != nullptr) {
|
||||||
|
return const_mgr->GetNullCompositeConstant(result_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto& columns = matrix->AsMatrixConstant()->GetComponents();
|
||||||
|
uint32_t number_of_rows = columns[0]->type()->AsVector()->element_count();
|
||||||
|
|
||||||
|
// Collect the ids of the elements in their new positions.
|
||||||
|
std::vector<std::vector<uint32_t>> result_elements(number_of_rows);
|
||||||
|
for (const analysis::Constant* column : columns) {
|
||||||
|
if (column->AsNullConstant()) {
|
||||||
|
column = const_mgr->GetNullCompositeConstant(column->type());
|
||||||
|
}
|
||||||
|
const auto& column_components = column->AsVectorConstant()->GetComponents();
|
||||||
|
|
||||||
|
for (uint32_t row = 0; row < number_of_rows; ++row) {
|
||||||
|
result_elements[row].push_back(
|
||||||
|
const_mgr->GetDefiningInstruction(column_components[row])
|
||||||
|
->result_id());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the constant for each row in the result, and collect the ids.
|
||||||
|
std::vector<uint32_t> result_columns(number_of_rows);
|
||||||
|
for (uint32_t col = 0; col < number_of_rows; ++col) {
|
||||||
|
auto* element = const_mgr->GetConstant(result_type->element_type(),
|
||||||
|
result_elements[col]);
|
||||||
|
result_columns[col] =
|
||||||
|
const_mgr->GetDefiningInstruction(element)->result_id();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the matrix constant from the row ids, and return it.
|
||||||
|
return const_mgr->GetConstant(result_type, result_columns);
|
||||||
|
}
|
||||||
|
|
||||||
|
const analysis::Constant* FoldTranspose(
|
||||||
|
IRContext* context, Instruction* inst,
|
||||||
|
const std::vector<const analysis::Constant*>& constants) {
|
||||||
|
assert(inst->opcode() == spv::Op::OpTranspose);
|
||||||
|
|
||||||
|
analysis::TypeManager* type_mgr = context->get_type_mgr();
|
||||||
|
if (!inst->IsFloatingPointFoldingAllowed()) {
|
||||||
|
if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const analysis::Constant* matrix = constants[0];
|
||||||
|
if (matrix == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto* result_type = type_mgr->GetType(inst->type_id());
|
||||||
|
return TransposeMatrix(matrix, result_type->AsMatrix(), context);
|
||||||
|
}
|
||||||
|
|
||||||
ConstantFoldingRule FoldVectorTimesMatrix() {
|
ConstantFoldingRule FoldVectorTimesMatrix() {
|
||||||
return [](IRContext* context, Instruction* inst,
|
return [](IRContext* context, Instruction* inst,
|
||||||
const std::vector<const analysis::Constant*>& constants)
|
const std::vector<const analysis::Constant*>& constants)
|
||||||
@ -1566,6 +1629,7 @@ void ConstantFoldingRules::AddFoldingRules() {
|
|||||||
rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar());
|
rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar());
|
||||||
rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix());
|
rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix());
|
||||||
rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector());
|
rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector());
|
||||||
|
rules_[spv::Op::OpTranspose].push_back(FoldTranspose);
|
||||||
|
|
||||||
rules_[spv::Op::OpFNegate].push_back(FoldFNegate());
|
rules_[spv::Op::OpFNegate].push_back(FoldFNegate());
|
||||||
rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());
|
rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());
|
||||||
|
@ -156,6 +156,8 @@ OpName %main "main"
|
|||||||
%v2half = OpTypeVector %half 2
|
%v2half = OpTypeVector %half 2
|
||||||
%v2bool = OpTypeVector %bool 2
|
%v2bool = OpTypeVector %bool 2
|
||||||
%m2x2int = OpTypeMatrix %v2int 2
|
%m2x2int = OpTypeMatrix %v2int 2
|
||||||
|
%mat4v2float = OpTypeMatrix %v2float 4
|
||||||
|
%mat2v4float = OpTypeMatrix %v4float 2
|
||||||
%mat4v4float = OpTypeMatrix %v4float 4
|
%mat4v4float = OpTypeMatrix %v4float 4
|
||||||
%mat4v4double = OpTypeMatrix %v4double 4
|
%mat4v4double = OpTypeMatrix %v4double 4
|
||||||
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
|
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
|
||||||
@ -290,6 +292,7 @@ OpName %main "main"
|
|||||||
%v4float_1_1_1_1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
|
%v4float_1_1_1_1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
|
||||||
%v4float_1_2_3_4 = OpConstantComposite %v4float %float_1 %float_2 %float_3 %float_4
|
%v4float_1_2_3_4 = OpConstantComposite %v4float %float_1 %float_2 %float_3 %float_4
|
||||||
%v4float_null = OpConstantNull %v4float
|
%v4float_null = OpConstantNull %v4float
|
||||||
|
%mat2v4float_null = OpConstantNull %mat2v4float
|
||||||
%mat4v4float_null = OpConstantNull %mat4v4float
|
%mat4v4float_null = OpConstantNull %mat4v4float
|
||||||
%mat4v4float_1_2_3_4 = OpConstantComposite %mat4v4float %v4float_1_2_3_4 %v4float_1_2_3_4 %v4float_1_2_3_4 %v4float_1_2_3_4
|
%mat4v4float_1_2_3_4 = OpConstantComposite %mat4v4float %v4float_1_2_3_4 %v4float_1_2_3_4 %v4float_1_2_3_4 %v4float_1_2_3_4
|
||||||
%mat4v4float_1_2_3_4_null = OpConstantComposite %mat4v4float %v4float_1_2_3_4 %v4float_null %v4float_1_2_3_4 %v4float_null
|
%mat4v4float_1_2_3_4_null = OpConstantComposite %mat4v4float %v4float_1_2_3_4 %v4float_null %v4float_1_2_3_4 %v4float_null
|
||||||
@ -1239,6 +1242,84 @@ INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest,
|
|||||||
2, {4.0,8.0,12.0,16.0})
|
2, {4.0,8.0,12.0,16.0})
|
||||||
));
|
));
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
|
using FloatMatrixInstructionFoldingTest = ::testing::TestWithParam<
|
||||||
|
InstructionFoldingCase<std::vector<std::vector<float>>>>;
|
||||||
|
|
||||||
|
TEST_P(FloatMatrixInstructionFoldingTest, Case) {
|
||||||
|
const auto& tc = GetParam();
|
||||||
|
|
||||||
|
// Build module.
|
||||||
|
std::unique_ptr<IRContext> context =
|
||||||
|
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
|
||||||
|
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
|
||||||
|
ASSERT_NE(nullptr, context);
|
||||||
|
|
||||||
|
// Fold the instruction to test.
|
||||||
|
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
|
||||||
|
Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
|
||||||
|
bool succeeded = context->get_instruction_folder().FoldInstruction(inst);
|
||||||
|
|
||||||
|
// Make sure the instruction folded as expected.
|
||||||
|
EXPECT_TRUE(succeeded);
|
||||||
|
EXPECT_EQ(inst->opcode(), spv::Op::OpCopyObject);
|
||||||
|
|
||||||
|
if (inst->opcode() == spv::Op::OpCopyObject) {
|
||||||
|
inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
|
||||||
|
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||||
|
const analysis::Constant* result = const_mgr->GetConstantFromInst(inst);
|
||||||
|
EXPECT_NE(result, nullptr);
|
||||||
|
if (result != nullptr) {
|
||||||
|
std::vector<const analysis::Constant*> matrix =
|
||||||
|
result->AsMatrixConstant()->GetComponents();
|
||||||
|
EXPECT_EQ(matrix.size(), tc.expected_result.size());
|
||||||
|
for (size_t c = 0; c < matrix.size(); c++) {
|
||||||
|
if (matrix[c]->AsNullConstant() != nullptr) {
|
||||||
|
matrix[c] = const_mgr->GetNullCompositeConstant(matrix[c]->type());
|
||||||
|
}
|
||||||
|
const analysis::VectorConstant* column_const =
|
||||||
|
matrix[c]->AsVectorConstant();
|
||||||
|
ASSERT_NE(column_const, nullptr);
|
||||||
|
const std::vector<const analysis::Constant*>& column =
|
||||||
|
column_const->GetComponents();
|
||||||
|
EXPECT_EQ(column.size(), tc.expected_result[c].size());
|
||||||
|
for (size_t r = 0; r < column.size(); r++) {
|
||||||
|
EXPECT_EQ(tc.expected_result[c][r], column[r]->GetFloat());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
INSTANTIATE_TEST_SUITE_P(TestCase, FloatMatrixInstructionFoldingTest,
|
||||||
|
::testing::Values(
|
||||||
|
// Test case 0: OpTranspose square null matrix
|
||||||
|
InstructionFoldingCase<std::vector<std::vector<float>>>(
|
||||||
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%2 = OpTranspose %mat4v4float %mat4v4float_null\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
2, {{0.0f, 0.0f, 0.0f, 0.0f},{0.0f, 0.0f, 0.0f, 0.0f},{0.0f, 0.0f, 0.0f, 0.0f},{0.0f, 0.0f, 0.0f, 0.0f}}),
|
||||||
|
// Test case 1: OpTranspose rectangular null matrix
|
||||||
|
InstructionFoldingCase<std::vector<std::vector<float>>>(
|
||||||
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%2 = OpTranspose %mat4v2float %mat2v4float_null\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
2, {{0.0f, 0.0f},{0.0f, 0.0f},{0.0f, 0.0f},{0.0f, 0.0f}}),
|
||||||
|
InstructionFoldingCase<std::vector<std::vector<float>>>(
|
||||||
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%2 = OpTranspose %mat4v4float %mat4v4float_1_2_3_4\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
2, {{1.0f, 1.0f, 1.0f, 1.0f},{2.0f, 2.0f, 2.0f, 2.0f},{3.0f, 3.0f, 3.0f, 3.0f},{4.0f, 4.0f, 4.0f, 4.0f}})
|
||||||
|
));
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
using BooleanInstructionFoldingTest =
|
using BooleanInstructionFoldingTest =
|
||||||
::testing::TestWithParam<InstructionFoldingCase<bool>>;
|
::testing::TestWithParam<InstructionFoldingCase<bool>>;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user