mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-10-19 03:20:14 +00:00
parent
2c5ed16ba9
commit
b54d950298
@ -296,24 +296,18 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
|
||||
};
|
||||
}
|
||||
|
||||
// Returns a |ConstantFoldingRule| that folds floating point scalars using
|
||||
// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
|
||||
// elements of the vector. The |ConstantFoldingRule| that is returned assumes
|
||||
// that |constants| contains 2 entries. If they are not |nullptr|, then their
|
||||
// type is either |Float| or a |Vector| whose element type is |Float|.
|
||||
ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
|
||||
return [scalar_rule](IRContext* context, Instruction* inst,
|
||||
const std::vector<const analysis::Constant*>& constants)
|
||||
-> const analysis::Constant* {
|
||||
// Returns the result of folding the constants in |constants| according the
|
||||
// |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied
|
||||
// per component.
|
||||
const analysis::Constant* FoldFPBinaryOp(
|
||||
BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
|
||||
const std::vector<const analysis::Constant*>& constants,
|
||||
IRContext* context) {
|
||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||
analysis::TypeManager* type_mgr = context->get_type_mgr();
|
||||
const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
|
||||
const analysis::Type* result_type = type_mgr->GetType(result_type_id);
|
||||
const analysis::Vector* vector_type = result_type->AsVector();
|
||||
|
||||
if (!inst->IsFloatingPointFoldingAllowed()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (constants[0] == nullptr || constants[1] == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -329,8 +323,8 @@ ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
|
||||
// Fold each component of the vector.
|
||||
for (uint32_t i = 0; i < a_components.size(); ++i) {
|
||||
results_components.push_back(scalar_rule(vector_type->element_type(),
|
||||
a_components[i],
|
||||
b_components[i], const_mgr));
|
||||
a_components[i], b_components[i],
|
||||
const_mgr));
|
||||
if (results_components[i] == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -345,6 +339,21 @@ ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
|
||||
} else {
|
||||
return scalar_rule(result_type, constants[0], constants[1], const_mgr);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a |ConstantFoldingRule| that folds floating point scalars using
|
||||
// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
|
||||
// elements of the vector. The |ConstantFoldingRule| that is returned assumes
|
||||
// that |constants| contains 2 entries. If they are not |nullptr|, then their
|
||||
// type is either |Float| or a |Vector| whose element type is |Float|.
|
||||
ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
|
||||
return [scalar_rule](IRContext* context, Instruction* inst,
|
||||
const std::vector<const analysis::Constant*>& constants)
|
||||
-> const analysis::Constant* {
|
||||
if (!inst->IsFloatingPointFoldingAllowed()) {
|
||||
return nullptr;
|
||||
}
|
||||
return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
|
||||
};
|
||||
}
|
||||
|
||||
@ -436,26 +445,30 @@ UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
|
||||
// This macro defines a |BinaryScalarFoldingRule| that applies |op|. The
|
||||
// operator |op| must work for both float and double, and use syntax "f1 op f2".
|
||||
#define FOLD_FPARITH_OP(op) \
|
||||
[](const analysis::Type* result_type, const analysis::Constant* a, \
|
||||
[](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \
|
||||
const analysis::Constant* b, \
|
||||
analysis::ConstantManager* const_mgr_in_macro) \
|
||||
-> const analysis::Constant* { \
|
||||
assert(result_type != nullptr && a != nullptr && b != nullptr); \
|
||||
assert(result_type == a->type() && result_type == b->type()); \
|
||||
const analysis::Float* float_type_in_macro = result_type->AsFloat(); \
|
||||
assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr); \
|
||||
assert(result_type_in_macro == a->type() && \
|
||||
result_type_in_macro == b->type()); \
|
||||
const analysis::Float* float_type_in_macro = \
|
||||
result_type_in_macro->AsFloat(); \
|
||||
assert(float_type_in_macro != nullptr); \
|
||||
if (float_type_in_macro->width() == 32) { \
|
||||
float fa = a->GetFloat(); \
|
||||
float fb = b->GetFloat(); \
|
||||
utils::FloatProxy<float> result_in_macro(fa op fb); \
|
||||
std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
|
||||
return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
|
||||
return const_mgr_in_macro->GetConstant(result_type_in_macro, \
|
||||
words_in_macro); \
|
||||
} else if (float_type_in_macro->width() == 64) { \
|
||||
double fa = a->GetDouble(); \
|
||||
double fb = b->GetDouble(); \
|
||||
utils::FloatProxy<double> result_in_macro(fa op fb); \
|
||||
std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
|
||||
return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
|
||||
return const_mgr_in_macro->GetConstant(result_type_in_macro, \
|
||||
words_in_macro); \
|
||||
} \
|
||||
return nullptr; \
|
||||
}
|
||||
@ -834,31 +847,49 @@ ConstantFoldingRule FoldFMix() {
|
||||
}
|
||||
|
||||
const analysis::Constant* one;
|
||||
if (constants[1]->type()->AsFloat()->width() == 32) {
|
||||
one = const_mgr->GetConstant(constants[1]->type(),
|
||||
bool is_vector = false;
|
||||
const analysis::Type* result_type = constants[1]->type();
|
||||
const analysis::Type* base_type = result_type;
|
||||
if (base_type->AsVector()) {
|
||||
is_vector = true;
|
||||
base_type = base_type->AsVector()->element_type();
|
||||
}
|
||||
assert(base_type->AsFloat() != nullptr &&
|
||||
"FMix is suppose to act on floats or vectors of floats.");
|
||||
|
||||
if (base_type->AsFloat()->width() == 32) {
|
||||
one = const_mgr->GetConstant(base_type,
|
||||
utils::FloatProxy<float>(1.0f).GetWords());
|
||||
} else {
|
||||
one = const_mgr->GetConstant(constants[1]->type(),
|
||||
one = const_mgr->GetConstant(base_type,
|
||||
utils::FloatProxy<double>(1.0).GetWords());
|
||||
}
|
||||
|
||||
const analysis::Constant* temp1 =
|
||||
FOLD_FPARITH_OP(-)(constants[1]->type(), one, constants[3], const_mgr);
|
||||
if (is_vector) {
|
||||
uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id();
|
||||
one =
|
||||
const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id));
|
||||
}
|
||||
|
||||
const analysis::Constant* temp1 = FoldFPBinaryOp(
|
||||
FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
|
||||
if (temp1 == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const analysis::Constant* temp2 = FOLD_FPARITH_OP(*)(
|
||||
constants[1]->type(), constants[1], temp1, const_mgr);
|
||||
const analysis::Constant* temp2 = FoldFPBinaryOp(
|
||||
FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
|
||||
if (temp2 == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
const analysis::Constant* temp3 = FOLD_FPARITH_OP(*)(
|
||||
constants[2]->type(), constants[2], constants[3], const_mgr);
|
||||
const analysis::Constant* temp3 =
|
||||
FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
|
||||
{constants[2], constants[3]}, context);
|
||||
if (temp3 == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return FOLD_FPARITH_OP(+)(temp2->type(), temp2, temp3, const_mgr);
|
||||
return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
|
||||
context);
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -222,6 +222,7 @@ OpName %main "main"
|
||||
%v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
|
||||
%v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4
|
||||
%v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5
|
||||
%v2float_0p2_0p5 = OpConstantComposite %v2float %float_0p2 %float_0p5
|
||||
%v2float_null = OpConstantNull %v2float
|
||||
%double_n1 = OpConstant %double -1
|
||||
%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
|
||||
@ -643,6 +644,58 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntVectorInstructionFoldingTest,
|
||||
));
|
||||
// clang-format on
|
||||
|
||||
using FloatVectorInstructionFoldingTest =
|
||||
::testing::TestWithParam<InstructionFoldingCase<std::vector<float>>>;
|
||||
|
||||
TEST_P(FloatVectorInstructionFoldingTest, Case) {
|
||||
const auto& tc = GetParam();
|
||||
|
||||
// Build module.
|
||||
std::unique_ptr<IRContext> context =
|
||||
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
|
||||
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
|
||||
ASSERT_NE(nullptr, context);
|
||||
|
||||
// Fold the instruction to test.
|
||||
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
|
||||
Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
|
||||
SpvOp original_opcode = inst->opcode();
|
||||
bool succeeded = context->get_instruction_folder().FoldInstruction(inst);
|
||||
|
||||
// Make sure the instruction folded as expected.
|
||||
EXPECT_EQ(succeeded, inst == nullptr || inst->opcode() != original_opcode);
|
||||
if (succeeded && inst != nullptr) {
|
||||
EXPECT_EQ(inst->opcode(), SpvOpCopyObject);
|
||||
inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
|
||||
std::vector<SpvOp> opcodes = {SpvOpConstantComposite};
|
||||
EXPECT_THAT(opcodes, Contains(inst->opcode()));
|
||||
analysis::ConstantManager* const_mrg = context->get_constant_mgr();
|
||||
const analysis::Constant* result = const_mrg->GetConstantFromInst(inst);
|
||||
EXPECT_NE(result, nullptr);
|
||||
if (result != nullptr) {
|
||||
const std::vector<const analysis::Constant*>& componenets =
|
||||
result->AsVectorConstant()->GetComponents();
|
||||
EXPECT_EQ(componenets.size(), tc.expected_result.size());
|
||||
for (size_t i = 0; i < componenets.size(); i++) {
|
||||
EXPECT_EQ(tc.expected_result[i], componenets[i]->GetFloat());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest,
|
||||
::testing::Values(
|
||||
// Test case 0: FMix {2.0, 2.0}, {2.0, 3.0} {0.2,0.5}
|
||||
InstructionFoldingCase<std::vector<float>>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpExtInst %v2float %1 FMix %v2float_2_3 %v2float_0_0 %v2float_0p2_0p5\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, {1.6f,1.5f})
|
||||
));
|
||||
// clang-format on
|
||||
using BooleanInstructionFoldingTest =
|
||||
::testing::TestWithParam<InstructionFoldingCase<bool>>;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user