mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-10-18 11:10:05 +00:00
Implement constant folding for many transcendentals (#3166)
* Implement constant folding for many transcendentals This change adds support for folding of sin/cos/tan/asin/acos/atan, exp/log/exp2/log2, sqrt, atan2 and pow. The mechanism allows to use any C function to implement folding in the future; for now I limited the actual additions to the most commonly used intrinsics in the shaders. Unary folder had to be tweaked to work with extended instructions - for extended instructions, constants.size() == 2 and constants[0] == nullptr. This adjustment is similar to the one binary folder already performs. Fixes #1390. * Fix Android build On old versions of Android NDK, we don't get std::exp2/std::log2 because of partial C++11 support. We do get ::exp2, but not ::log2 so we need to emulate that.
This commit is contained in:
parent
7a2d408dea
commit
0265a9d4de
@ -265,7 +265,10 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (constants[0] == nullptr) {
|
||||
const analysis::Constant* arg =
|
||||
(inst->opcode() == SpvOpExtInst) ? constants[1] : constants[0];
|
||||
|
||||
if (arg == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -273,7 +276,7 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
|
||||
std::vector<const analysis::Constant*> a_components;
|
||||
std::vector<const analysis::Constant*> results_components;
|
||||
|
||||
a_components = constants[0]->GetVectorComponents(const_mgr);
|
||||
a_components = arg->GetVectorComponents(const_mgr);
|
||||
|
||||
// Fold each component of the vector.
|
||||
for (uint32_t i = 0; i < a_components.size(); ++i) {
|
||||
@ -291,7 +294,7 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
|
||||
}
|
||||
return const_mgr->GetConstant(vector_type, ids);
|
||||
} else {
|
||||
return scalar_rule(result_type, constants[0], const_mgr);
|
||||
return scalar_rule(result_type, arg, const_mgr);
|
||||
}
|
||||
};
|
||||
}
|
||||
@ -1070,6 +1073,60 @@ const analysis::Constant* FoldClamp3(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) {
|
||||
return
|
||||
[fp](const analysis::Type* result_type, const analysis::Constant* a,
|
||||
analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
|
||||
assert(result_type != nullptr && a != nullptr);
|
||||
const analysis::Float* float_type = a->type()->AsFloat();
|
||||
assert(float_type != nullptr);
|
||||
assert(float_type == result_type->AsFloat());
|
||||
if (float_type->width() == 32) {
|
||||
float fa = a->GetFloat();
|
||||
float res = static_cast<float>(fp(fa));
|
||||
utils::FloatProxy<float> result(res);
|
||||
std::vector<uint32_t> words = result.GetWords();
|
||||
return const_mgr->GetConstant(result_type, words);
|
||||
} else if (float_type->width() == 64) {
|
||||
double fa = a->GetDouble();
|
||||
double res = fp(fa);
|
||||
utils::FloatProxy<double> result(res);
|
||||
std::vector<uint32_t> words = result.GetWords();
|
||||
return const_mgr->GetConstant(result_type, words);
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
}
|
||||
|
||||
BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
|
||||
double)) {
|
||||
return
|
||||
[fp](const analysis::Type* result_type, const analysis::Constant* a,
|
||||
const analysis::Constant* b,
|
||||
analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
|
||||
assert(result_type != nullptr && a != nullptr);
|
||||
const analysis::Float* float_type = a->type()->AsFloat();
|
||||
assert(float_type != nullptr);
|
||||
assert(float_type == result_type->AsFloat());
|
||||
assert(float_type == b->type()->AsFloat());
|
||||
if (float_type->width() == 32) {
|
||||
float fa = a->GetFloat();
|
||||
float fb = b->GetFloat();
|
||||
float res = static_cast<float>(fp(fa, fb));
|
||||
utils::FloatProxy<float> result(res);
|
||||
std::vector<uint32_t> words = result.GetWords();
|
||||
return const_mgr->GetConstant(result_type, words);
|
||||
} else if (float_type->width() == 64) {
|
||||
double fa = a->GetDouble();
|
||||
double fb = b->GetDouble();
|
||||
double res = fp(fa, fb);
|
||||
utils::FloatProxy<double> result(res);
|
||||
std::vector<uint32_t> words = result.GetWords();
|
||||
return const_mgr->GetConstant(result_type, words);
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void ConstantFoldingRules::AddFoldingRules() {
|
||||
@ -1175,6 +1232,45 @@ void ConstantFoldingRules::AddFoldingRules() {
|
||||
FoldClamp2);
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
|
||||
FoldClamp3);
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back(
|
||||
FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin)));
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back(
|
||||
FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos)));
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back(
|
||||
FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan)));
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back(
|
||||
FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin)));
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back(
|
||||
FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos)));
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back(
|
||||
FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan)));
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back(
|
||||
FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp)));
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back(
|
||||
FoldFPUnaryOp(FoldFTranscendentalUnary(std::log)));
|
||||
|
||||
#ifdef __ANDROID__
|
||||
// Android NDK r15c tageting ABI 15 doesn't have full support for C++11
|
||||
// (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't
|
||||
// available up until ABI 18 so we use a shim
|
||||
auto log2_shim = [](double v) -> double { return log(v) / log(2.0); };
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
|
||||
FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2)));
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
|
||||
FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim)));
|
||||
#else
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
|
||||
FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2)));
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
|
||||
FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2)));
|
||||
#endif
|
||||
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back(
|
||||
FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt)));
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back(
|
||||
FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2)));
|
||||
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back(
|
||||
FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow)));
|
||||
}
|
||||
}
|
||||
} // namespace opt
|
||||
|
@ -1693,7 +1693,7 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0.2f),
|
||||
// Test case 21: FMax 1.0 4.0
|
||||
// Test case 23: FMax 1.0 4.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
@ -1701,7 +1701,7 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 4.0f),
|
||||
// Test case 22: FMax 1.0 0.2
|
||||
// Test case 24: FMax 1.0 0.2
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
@ -1709,7 +1709,7 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 1.0f),
|
||||
// Test case 23: FClamp 1.0 0.2 4.0
|
||||
// Test case 25: FClamp 1.0 0.2 4.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
@ -1717,7 +1717,7 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 1.0f),
|
||||
// Test case 24: FClamp 0.2 2.0 4.0
|
||||
// Test case 26: FClamp 0.2 2.0 4.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
@ -1725,7 +1725,7 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 2.0f),
|
||||
// Test case 25: FClamp 2049.0 2.0 4.0
|
||||
// Test case 27: FClamp 2049.0 2.0 4.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
@ -1733,7 +1733,7 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 4.0f),
|
||||
// Test case 26: FClamp 1.0 2.0 x
|
||||
// Test case 28: FClamp 1.0 2.0 x
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
@ -1742,7 +1742,7 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 2.0),
|
||||
// Test case 27: FClamp 1.0 x 0.5
|
||||
// Test case 29: FClamp 1.0 x 0.5
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
@ -1750,7 +1750,111 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
|
||||
"%2 = OpExtInst %float %1 FClamp %float_1 %undef %float_0p5\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0.5)
|
||||
2, 0.5),
|
||||
// Test case 30: Sin 0.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpExtInst %float %1 Sin %float_0\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0.0),
|
||||
// Test case 31: Cos 0.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpExtInst %float %1 Cos %float_0\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 1.0),
|
||||
// Test case 32: Tan 0.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpExtInst %float %1 Tan %float_0\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0.0),
|
||||
// Test case 33: Asin 0.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpExtInst %float %1 Asin %float_0\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0.0),
|
||||
// Test case 34: Acos 1.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpExtInst %float %1 Acos %float_1\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0.0),
|
||||
// Test case 35: Atan 0.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpExtInst %float %1 Atan %float_0\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0.0),
|
||||
// Test case 36: Exp 0.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpExtInst %float %1 Exp %float_0\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 1.0),
|
||||
// Test case 37: Log 1.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpExtInst %float %1 Log %float_1\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0.0),
|
||||
// Test case 38: Exp2 2.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpExtInst %float %1 Exp2 %float_2\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 4.0),
|
||||
// Test case 39: Log2 4.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpExtInst %float %1 Log2 %float_4\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 2.0),
|
||||
// Test case 40: Sqrt 4.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpExtInst %float %1 Sqrt %float_4\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 2.0),
|
||||
// Test case 41: Atan2 0.0 1.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpExtInst %float %1 Atan2 %float_0 %float_1\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0.0),
|
||||
// Test case 42: Pow 2.0 3.0
|
||||
InstructionFoldingCase<float>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%2 = OpExtInst %float %1 Pow %float_2 %float_3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 8.0)
|
||||
));
|
||||
// clang-format on
|
||||
|
||||
@ -1967,7 +2071,25 @@ INSTANTIATE_TEST_SUITE_P(DoubleConstantFoldingTest, DoubleInstructionFoldingTest
|
||||
"%2 = OpExtInst %double %1 FClamp %double_1 %undef %double_0p5\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 0.5)
|
||||
2, 0.5),
|
||||
// Test case 21: Sqrt 4.0
|
||||
InstructionFoldingCase<double>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%undef = OpUndef %double\n" +
|
||||
"%2 = OpExtInst %double %1 Sqrt %double_4\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 2.0),
|
||||
// Test case 22: Pow 2.0 3.0
|
||||
InstructionFoldingCase<double>(
|
||||
Header() + "%main = OpFunction %void None %void_func\n" +
|
||||
"%main_lab = OpLabel\n" +
|
||||
"%undef = OpUndef %double\n" +
|
||||
"%2 = OpExtInst %double %1 Pow %double_2 %double_3\n" +
|
||||
"OpReturn\n" +
|
||||
"OpFunctionEnd",
|
||||
2, 8.0)
|
||||
));
|
||||
// clang-format on
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user