mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-10-19 03:20:14 +00:00
parent
2c5ed16ba9
commit
b54d950298
@ -296,6 +296,51 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the result of folding the constants in |constants| according the
|
||||||
|
// |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied
|
||||||
|
// per component.
|
||||||
|
const analysis::Constant* FoldFPBinaryOp(
|
||||||
|
BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
|
||||||
|
const std::vector<const analysis::Constant*>& constants,
|
||||||
|
IRContext* context) {
|
||||||
|
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
||||||
|
analysis::TypeManager* type_mgr = context->get_type_mgr();
|
||||||
|
const analysis::Type* result_type = type_mgr->GetType(result_type_id);
|
||||||
|
const analysis::Vector* vector_type = result_type->AsVector();
|
||||||
|
|
||||||
|
if (constants[0] == nullptr || constants[1] == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vector_type != nullptr) {
|
||||||
|
std::vector<const analysis::Constant*> a_components;
|
||||||
|
std::vector<const analysis::Constant*> b_components;
|
||||||
|
std::vector<const analysis::Constant*> results_components;
|
||||||
|
|
||||||
|
a_components = constants[0]->GetVectorComponents(const_mgr);
|
||||||
|
b_components = constants[1]->GetVectorComponents(const_mgr);
|
||||||
|
|
||||||
|
// Fold each component of the vector.
|
||||||
|
for (uint32_t i = 0; i < a_components.size(); ++i) {
|
||||||
|
results_components.push_back(scalar_rule(vector_type->element_type(),
|
||||||
|
a_components[i], b_components[i],
|
||||||
|
const_mgr));
|
||||||
|
if (results_components[i] == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the constant object and return it.
|
||||||
|
std::vector<uint32_t> ids;
|
||||||
|
for (const analysis::Constant* member : results_components) {
|
||||||
|
ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
|
||||||
|
}
|
||||||
|
return const_mgr->GetConstant(vector_type, ids);
|
||||||
|
} else {
|
||||||
|
return scalar_rule(result_type, constants[0], constants[1], const_mgr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Returns a |ConstantFoldingRule| that folds floating point scalars using
|
// Returns a |ConstantFoldingRule| that folds floating point scalars using
|
||||||
// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
|
// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
|
||||||
// elements of the vector. The |ConstantFoldingRule| that is returned assumes
|
// elements of the vector. The |ConstantFoldingRule| that is returned assumes
|
||||||
@ -305,46 +350,10 @@ ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
|
|||||||
return [scalar_rule](IRContext* context, Instruction* inst,
|
return [scalar_rule](IRContext* context, Instruction* inst,
|
||||||
const std::vector<const analysis::Constant*>& constants)
|
const std::vector<const analysis::Constant*>& constants)
|
||||||
-> const analysis::Constant* {
|
-> const analysis::Constant* {
|
||||||
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
|
|
||||||
analysis::TypeManager* type_mgr = context->get_type_mgr();
|
|
||||||
const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
|
|
||||||
const analysis::Vector* vector_type = result_type->AsVector();
|
|
||||||
|
|
||||||
if (!inst->IsFloatingPointFoldingAllowed()) {
|
if (!inst->IsFloatingPointFoldingAllowed()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
|
||||||
if (constants[0] == nullptr || constants[1] == nullptr) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (vector_type != nullptr) {
|
|
||||||
std::vector<const analysis::Constant*> a_components;
|
|
||||||
std::vector<const analysis::Constant*> b_components;
|
|
||||||
std::vector<const analysis::Constant*> results_components;
|
|
||||||
|
|
||||||
a_components = constants[0]->GetVectorComponents(const_mgr);
|
|
||||||
b_components = constants[1]->GetVectorComponents(const_mgr);
|
|
||||||
|
|
||||||
// Fold each component of the vector.
|
|
||||||
for (uint32_t i = 0; i < a_components.size(); ++i) {
|
|
||||||
results_components.push_back(scalar_rule(vector_type->element_type(),
|
|
||||||
a_components[i],
|
|
||||||
b_components[i], const_mgr));
|
|
||||||
if (results_components[i] == nullptr) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build the constant object and return it.
|
|
||||||
std::vector<uint32_t> ids;
|
|
||||||
for (const analysis::Constant* member : results_components) {
|
|
||||||
ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
|
|
||||||
}
|
|
||||||
return const_mgr->GetConstant(vector_type, ids);
|
|
||||||
} else {
|
|
||||||
return scalar_rule(result_type, constants[0], constants[1], const_mgr);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -435,29 +444,33 @@ UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
|
|||||||
|
|
||||||
// This macro defines a |BinaryScalarFoldingRule| that applies |op|. The
|
// This macro defines a |BinaryScalarFoldingRule| that applies |op|. The
|
||||||
// operator |op| must work for both float and double, and use syntax "f1 op f2".
|
// operator |op| must work for both float and double, and use syntax "f1 op f2".
|
||||||
#define FOLD_FPARITH_OP(op) \
|
#define FOLD_FPARITH_OP(op) \
|
||||||
[](const analysis::Type* result_type, const analysis::Constant* a, \
|
[](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \
|
||||||
const analysis::Constant* b, \
|
const analysis::Constant* b, \
|
||||||
analysis::ConstantManager* const_mgr_in_macro) \
|
analysis::ConstantManager* const_mgr_in_macro) \
|
||||||
-> const analysis::Constant* { \
|
-> const analysis::Constant* { \
|
||||||
assert(result_type != nullptr && a != nullptr && b != nullptr); \
|
assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr); \
|
||||||
assert(result_type == a->type() && result_type == b->type()); \
|
assert(result_type_in_macro == a->type() && \
|
||||||
const analysis::Float* float_type_in_macro = result_type->AsFloat(); \
|
result_type_in_macro == b->type()); \
|
||||||
assert(float_type_in_macro != nullptr); \
|
const analysis::Float* float_type_in_macro = \
|
||||||
if (float_type_in_macro->width() == 32) { \
|
result_type_in_macro->AsFloat(); \
|
||||||
float fa = a->GetFloat(); \
|
assert(float_type_in_macro != nullptr); \
|
||||||
float fb = b->GetFloat(); \
|
if (float_type_in_macro->width() == 32) { \
|
||||||
utils::FloatProxy<float> result_in_macro(fa op fb); \
|
float fa = a->GetFloat(); \
|
||||||
std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
|
float fb = b->GetFloat(); \
|
||||||
return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
|
utils::FloatProxy<float> result_in_macro(fa op fb); \
|
||||||
} else if (float_type_in_macro->width() == 64) { \
|
std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
|
||||||
double fa = a->GetDouble(); \
|
return const_mgr_in_macro->GetConstant(result_type_in_macro, \
|
||||||
double fb = b->GetDouble(); \
|
words_in_macro); \
|
||||||
utils::FloatProxy<double> result_in_macro(fa op fb); \
|
} else if (float_type_in_macro->width() == 64) { \
|
||||||
std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
|
double fa = a->GetDouble(); \
|
||||||
return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
|
double fb = b->GetDouble(); \
|
||||||
} \
|
utils::FloatProxy<double> result_in_macro(fa op fb); \
|
||||||
return nullptr; \
|
std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
|
||||||
|
return const_mgr_in_macro->GetConstant(result_type_in_macro, \
|
||||||
|
words_in_macro); \
|
||||||
|
} \
|
||||||
|
return nullptr; \
|
||||||
}
|
}
|
||||||
|
|
||||||
// Define the folding rule for conversion between floating point and integer
|
// Define the folding rule for conversion between floating point and integer
|
||||||
@ -834,31 +847,49 @@ ConstantFoldingRule FoldFMix() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const analysis::Constant* one;
|
const analysis::Constant* one;
|
||||||
if (constants[1]->type()->AsFloat()->width() == 32) {
|
bool is_vector = false;
|
||||||
one = const_mgr->GetConstant(constants[1]->type(),
|
const analysis::Type* result_type = constants[1]->type();
|
||||||
|
const analysis::Type* base_type = result_type;
|
||||||
|
if (base_type->AsVector()) {
|
||||||
|
is_vector = true;
|
||||||
|
base_type = base_type->AsVector()->element_type();
|
||||||
|
}
|
||||||
|
assert(base_type->AsFloat() != nullptr &&
|
||||||
|
"FMix is suppose to act on floats or vectors of floats.");
|
||||||
|
|
||||||
|
if (base_type->AsFloat()->width() == 32) {
|
||||||
|
one = const_mgr->GetConstant(base_type,
|
||||||
utils::FloatProxy<float>(1.0f).GetWords());
|
utils::FloatProxy<float>(1.0f).GetWords());
|
||||||
} else {
|
} else {
|
||||||
one = const_mgr->GetConstant(constants[1]->type(),
|
one = const_mgr->GetConstant(base_type,
|
||||||
utils::FloatProxy<double>(1.0).GetWords());
|
utils::FloatProxy<double>(1.0).GetWords());
|
||||||
}
|
}
|
||||||
|
|
||||||
const analysis::Constant* temp1 =
|
if (is_vector) {
|
||||||
FOLD_FPARITH_OP(-)(constants[1]->type(), one, constants[3], const_mgr);
|
uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id();
|
||||||
|
one =
|
||||||
|
const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
const analysis::Constant* temp1 = FoldFPBinaryOp(
|
||||||
|
FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
|
||||||
if (temp1 == nullptr) {
|
if (temp1 == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
const analysis::Constant* temp2 = FOLD_FPARITH_OP(*)(
|
const analysis::Constant* temp2 = FoldFPBinaryOp(
|
||||||
constants[1]->type(), constants[1], temp1, const_mgr);
|
FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
|
||||||
if (temp2 == nullptr) {
|
if (temp2 == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
const analysis::Constant* temp3 = FOLD_FPARITH_OP(*)(
|
const analysis::Constant* temp3 =
|
||||||
constants[2]->type(), constants[2], constants[3], const_mgr);
|
FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
|
||||||
|
{constants[2], constants[3]}, context);
|
||||||
if (temp3 == nullptr) {
|
if (temp3 == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return FOLD_FPARITH_OP(+)(temp2->type(), temp2, temp3, const_mgr);
|
return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
|
||||||
|
context);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -222,6 +222,7 @@ OpName %main "main"
|
|||||||
%v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
|
%v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2
|
||||||
%v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4
|
%v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4
|
||||||
%v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5
|
%v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5
|
||||||
|
%v2float_0p2_0p5 = OpConstantComposite %v2float %float_0p2 %float_0p5
|
||||||
%v2float_null = OpConstantNull %v2float
|
%v2float_null = OpConstantNull %v2float
|
||||||
%double_n1 = OpConstant %double -1
|
%double_n1 = OpConstant %double -1
|
||||||
%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
|
%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps.
|
||||||
@ -643,6 +644,58 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntVectorInstructionFoldingTest,
|
|||||||
));
|
));
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
|
using FloatVectorInstructionFoldingTest =
|
||||||
|
::testing::TestWithParam<InstructionFoldingCase<std::vector<float>>>;
|
||||||
|
|
||||||
|
TEST_P(FloatVectorInstructionFoldingTest, Case) {
|
||||||
|
const auto& tc = GetParam();
|
||||||
|
|
||||||
|
// Build module.
|
||||||
|
std::unique_ptr<IRContext> context =
|
||||||
|
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
|
||||||
|
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
|
||||||
|
ASSERT_NE(nullptr, context);
|
||||||
|
|
||||||
|
// Fold the instruction to test.
|
||||||
|
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
|
||||||
|
Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold);
|
||||||
|
SpvOp original_opcode = inst->opcode();
|
||||||
|
bool succeeded = context->get_instruction_folder().FoldInstruction(inst);
|
||||||
|
|
||||||
|
// Make sure the instruction folded as expected.
|
||||||
|
EXPECT_EQ(succeeded, inst == nullptr || inst->opcode() != original_opcode);
|
||||||
|
if (succeeded && inst != nullptr) {
|
||||||
|
EXPECT_EQ(inst->opcode(), SpvOpCopyObject);
|
||||||
|
inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
|
||||||
|
std::vector<SpvOp> opcodes = {SpvOpConstantComposite};
|
||||||
|
EXPECT_THAT(opcodes, Contains(inst->opcode()));
|
||||||
|
analysis::ConstantManager* const_mrg = context->get_constant_mgr();
|
||||||
|
const analysis::Constant* result = const_mrg->GetConstantFromInst(inst);
|
||||||
|
EXPECT_NE(result, nullptr);
|
||||||
|
if (result != nullptr) {
|
||||||
|
const std::vector<const analysis::Constant*>& componenets =
|
||||||
|
result->AsVectorConstant()->GetComponents();
|
||||||
|
EXPECT_EQ(componenets.size(), tc.expected_result.size());
|
||||||
|
for (size_t i = 0; i < componenets.size(); i++) {
|
||||||
|
EXPECT_EQ(tc.expected_result[i], componenets[i]->GetFloat());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest,
|
||||||
|
::testing::Values(
|
||||||
|
// Test case 0: FMix {2.0, 2.0}, {2.0, 3.0} {0.2,0.5}
|
||||||
|
InstructionFoldingCase<std::vector<float>>(
|
||||||
|
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||||
|
"%main_lab = OpLabel\n" +
|
||||||
|
"%2 = OpExtInst %v2float %1 FMix %v2float_2_3 %v2float_0_0 %v2float_0p2_0p5\n" +
|
||||||
|
"OpReturn\n" +
|
||||||
|
"OpFunctionEnd",
|
||||||
|
2, {1.6f,1.5f})
|
||||||
|
));
|
||||||
|
// clang-format on
|
||||||
using BooleanInstructionFoldingTest =
|
using BooleanInstructionFoldingTest =
|
||||||
::testing::TestWithParam<InstructionFoldingCase<bool>>;
|
::testing::TestWithParam<InstructionFoldingCase<bool>>;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user