Implement constant folding for many transcendentals (#3166)

* Implement constant folding for many transcendentals

This change adds support for folding of sin/cos/tan/asin/acos/atan,
exp/log/exp2/log2, sqrt, atan2 and pow.

The mechanism allows to use any C function to implement folding in the
future; for now I limited the actual additions to the most commonly used
intrinsics in the shaders.

Unary folder had to be tweaked to work with extended instructions - for
extended instructions, constants.size() == 2 and constants[0] ==
nullptr. This adjustment is similar to the one binary folder already
performs.

Fixes #1390.

* Fix Android build

On old versions of Android NDK, we don't get std::exp2/std::log2
because of partial C++11 support.

We do get ::exp2, but not ::log2 so we need to emulate that.
This commit is contained in:
Arseny Kapoulkine 2020-02-03 06:20:47 -08:00 committed by GitHub
parent 7a2d408dea
commit 0265a9d4de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 230 additions and 12 deletions

View File

@ -265,7 +265,10 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
return nullptr;
}
if (constants[0] == nullptr) {
const analysis::Constant* arg =
(inst->opcode() == SpvOpExtInst) ? constants[1] : constants[0];
if (arg == nullptr) {
return nullptr;
}
@ -273,7 +276,7 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
std::vector<const analysis::Constant*> a_components;
std::vector<const analysis::Constant*> results_components;
a_components = constants[0]->GetVectorComponents(const_mgr);
a_components = arg->GetVectorComponents(const_mgr);
// Fold each component of the vector.
for (uint32_t i = 0; i < a_components.size(); ++i) {
@ -291,7 +294,7 @@ ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
}
return const_mgr->GetConstant(vector_type, ids);
} else {
return scalar_rule(result_type, constants[0], const_mgr);
return scalar_rule(result_type, arg, const_mgr);
}
};
}
@ -1070,6 +1073,60 @@ const analysis::Constant* FoldClamp3(
return nullptr;
}
UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) {
return
[fp](const analysis::Type* result_type, const analysis::Constant* a,
analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
assert(result_type != nullptr && a != nullptr);
const analysis::Float* float_type = a->type()->AsFloat();
assert(float_type != nullptr);
assert(float_type == result_type->AsFloat());
if (float_type->width() == 32) {
float fa = a->GetFloat();
float res = static_cast<float>(fp(fa));
utils::FloatProxy<float> result(res);
std::vector<uint32_t> words = result.GetWords();
return const_mgr->GetConstant(result_type, words);
} else if (float_type->width() == 64) {
double fa = a->GetDouble();
double res = fp(fa);
utils::FloatProxy<double> result(res);
std::vector<uint32_t> words = result.GetWords();
return const_mgr->GetConstant(result_type, words);
}
return nullptr;
};
}
BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
double)) {
return
[fp](const analysis::Type* result_type, const analysis::Constant* a,
const analysis::Constant* b,
analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
assert(result_type != nullptr && a != nullptr);
const analysis::Float* float_type = a->type()->AsFloat();
assert(float_type != nullptr);
assert(float_type == result_type->AsFloat());
assert(float_type == b->type()->AsFloat());
if (float_type->width() == 32) {
float fa = a->GetFloat();
float fb = b->GetFloat();
float res = static_cast<float>(fp(fa, fb));
utils::FloatProxy<float> result(res);
std::vector<uint32_t> words = result.GetWords();
return const_mgr->GetConstant(result_type, words);
} else if (float_type->width() == 64) {
double fa = a->GetDouble();
double fb = b->GetDouble();
double res = fp(fa, fb);
utils::FloatProxy<double> result(res);
std::vector<uint32_t> words = result.GetWords();
return const_mgr->GetConstant(result_type, words);
}
return nullptr;
};
}
} // namespace
void ConstantFoldingRules::AddFoldingRules() {
@ -1175,6 +1232,45 @@ void ConstantFoldingRules::AddFoldingRules() {
FoldClamp2);
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
FoldClamp3);
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back(
FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin)));
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back(
FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos)));
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back(
FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan)));
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back(
FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin)));
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back(
FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos)));
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back(
FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan)));
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back(
FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp)));
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back(
FoldFPUnaryOp(FoldFTranscendentalUnary(std::log)));
#ifdef __ANDROID__
// Android NDK r15c tageting ABI 15 doesn't have full support for C++11
// (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't
// available up until ABI 18 so we use a shim
auto log2_shim = [](double v) -> double { return log(v) / log(2.0); };
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2)));
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim)));
#else
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2)));
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2)));
#endif
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back(
FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt)));
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back(
FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2)));
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back(
FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow)));
}
}
} // namespace opt

View File

@ -1693,7 +1693,7 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, 0.2f),
// Test case 21: FMax 1.0 4.0
// Test case 23: FMax 1.0 4.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@ -1701,7 +1701,7 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, 4.0f),
// Test case 22: FMax 1.0 0.2
// Test case 24: FMax 1.0 0.2
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@ -1709,7 +1709,7 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, 1.0f),
// Test case 23: FClamp 1.0 0.2 4.0
// Test case 25: FClamp 1.0 0.2 4.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@ -1717,7 +1717,7 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, 1.0f),
// Test case 24: FClamp 0.2 2.0 4.0
// Test case 26: FClamp 0.2 2.0 4.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@ -1725,7 +1725,7 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, 2.0f),
// Test case 25: FClamp 2049.0 2.0 4.0
// Test case 27: FClamp 2049.0 2.0 4.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@ -1733,7 +1733,7 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, 4.0f),
// Test case 26: FClamp 1.0 2.0 x
// Test case 28: FClamp 1.0 2.0 x
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@ -1742,7 +1742,7 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, 2.0),
// Test case 27: FClamp 1.0 x 0.5
// Test case 29: FClamp 1.0 x 0.5
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
@ -1750,7 +1750,111 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
"%2 = OpExtInst %float %1 FClamp %float_1 %undef %float_0p5\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0.5)
2, 0.5),
// Test case 30: Sin 0.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 Sin %float_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0.0),
// Test case 31: Cos 0.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 Cos %float_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1.0),
// Test case 32: Tan 0.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 Tan %float_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0.0),
// Test case 33: Asin 0.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 Asin %float_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0.0),
// Test case 34: Acos 1.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 Acos %float_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0.0),
// Test case 35: Atan 0.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 Atan %float_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0.0),
// Test case 36: Exp 0.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 Exp %float_0\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1.0),
// Test case 37: Log 1.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 Log %float_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0.0),
// Test case 38: Exp2 2.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 Exp2 %float_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 4.0),
// Test case 39: Log2 4.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 Log2 %float_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 2.0),
// Test case 40: Sqrt 4.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 Sqrt %float_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 2.0),
// Test case 41: Atan2 0.0 1.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 Atan2 %float_0 %float_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0.0),
// Test case 42: Pow 2.0 3.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 Pow %float_2 %float_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 8.0)
));
// clang-format on
@ -1967,7 +2071,25 @@ INSTANTIATE_TEST_SUITE_P(DoubleConstantFoldingTest, DoubleInstructionFoldingTest
"%2 = OpExtInst %double %1 FClamp %double_1 %undef %double_0p5\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0.5)
2, 0.5),
// Test case 21: Sqrt 4.0
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%undef = OpUndef %double\n" +
"%2 = OpExtInst %double %1 Sqrt %double_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 2.0),
// Test case 22: Pow 2.0 3.0
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%undef = OpUndef %double\n" +
"%2 = OpExtInst %double %1 Pow %double_2 %double_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 8.0)
));
// clang-format on