diff --git a/src/sksl/ir/SkSLFunctionCall.cpp b/src/sksl/ir/SkSLFunctionCall.cpp index 0e7e8edf15..9ec5baf43c 100644 --- a/src/sksl/ir/SkSLFunctionCall.cpp +++ b/src/sksl/ir/SkSLFunctionCall.cpp @@ -20,8 +20,12 @@ #include "include/sksl/DSLCore.h" #include "src/core/SkMatrixInvert.h" +#include + namespace SkSL { +using IntrinsicArguments = std::array; + static bool has_compile_time_constant_arguments(const ExpressionArray& arguments) { for (const std::unique_ptr& arg : arguments) { const Expression* expr = ConstantFolder::GetConstantValueForVariable(*arg); @@ -104,19 +108,11 @@ static std::unique_ptr coalesce_n_way_vector(const Expression* arg0, int offset = arg0->fOffset; - arg0 = ConstantFolder::GetConstantValueForVariable(*arg0); - SkASSERT(arg0); - const Type& vecType = arg0->type().isVector() ? arg0->type() : (arg1 && arg1->type().isVector()) ? arg1->type() : arg0->type(); - SkASSERT(arg0->type().componentType() == vecType.componentType()); - - if (arg1) { - arg1 = ConstantFolder::GetConstantValueForVariable(*arg1); - SkASSERT(arg1); - SkASSERT(arg1->type().componentType() == vecType.componentType()); - } + SkASSERT( arg0->type().componentType() == vecType.componentType()); + SkASSERT(!arg1 || arg1->type().componentType() == vecType.componentType()); double value = startingState; int arg0Index = 0; @@ -149,42 +145,45 @@ static std::unique_ptr coalesce_n_way_vector(const Expression* arg0, } template -static std::unique_ptr coalesce_vector(const ExpressionArray& arguments, +static std::unique_ptr coalesce_vector(const IntrinsicArguments& arguments, double startingState, const Type& returnType, CoalesceFn coalesce, FinalizeFn finalize) { - SkASSERT(arguments.size() == 1); + SkASSERT(arguments[0]); + SkASSERT(!arguments[1]); type_check_expression(*arguments[0]); - return coalesce_n_way_vector(arguments[0].get(), /*arg1=*/nullptr, + return coalesce_n_way_vector(arguments[0], /*arg1=*/nullptr, (double)startingState, returnType, coalesce, finalize); } template -static std::unique_ptr coalesce_pairwise_vectors(const ExpressionArray& arguments, +static std::unique_ptr coalesce_pairwise_vectors(const IntrinsicArguments& arguments, double startingState, const Type& returnType, CoalesceFn coalesce, FinalizeFn finalize) { - SkASSERT(arguments.size() == 2); + SkASSERT(arguments[0]); + SkASSERT(arguments[1]); + SkASSERT(!arguments[2]); type_check_expression(*arguments[0]); type_check_expression(*arguments[1]); - return coalesce_n_way_vector(arguments[0].get(), arguments[1].get(), + return coalesce_n_way_vector(arguments[0], arguments[1], (double)startingState, returnType, coalesce, finalize); } using CompareFn = bool (*)(double, double); static std::unique_ptr optimize_comparison(const Context& context, - const ExpressionArray& arguments, + const IntrinsicArguments& arguments, CompareFn compare) { - SkASSERT(arguments.size() == 2); - const Expression* left = ConstantFolder::GetConstantValueForVariable(*arguments[0]); - const Expression* right = ConstantFolder::GetConstantValueForVariable(*arguments[1]); + const Expression* left = arguments[0]; + const Expression* right = arguments[1]; SkASSERT(left); SkASSERT(right); + SkASSERT(!arguments[2]); const Type& type = left->type(); SkASSERT(type.isVector()); @@ -225,18 +224,6 @@ static std::unique_ptr evaluate_n_way_intrinsic(const Context& conte // If an argument is null, zero is passed to the evaluation function. If the arguments are a mix // of scalars and compounds, scalars are interpreted as a compound containing the same value for // every component. - arg0 = ConstantFolder::GetConstantValueForVariable(*arg0); - SkASSERT(arg0); - - if (arg1) { - arg1 = ConstantFolder::GetConstantValueForVariable(*arg1); - SkASSERT(arg1); - } - - if (arg2) { - arg2 = ConstantFolder::GetConstantValueForVariable(*arg2); - SkASSERT(arg2); - } int slots = returnType.slotCount(); ExpressionArray array; @@ -279,21 +266,23 @@ static std::unique_ptr evaluate_n_way_intrinsic(const Context& conte template static std::unique_ptr evaluate_intrinsic(const Context& context, - const ExpressionArray& arguments, + const IntrinsicArguments& arguments, const Type& returnType, EvaluateFn eval) { - SkASSERT(arguments.size() == 1); + SkASSERT(arguments[0]); + SkASSERT(!arguments[1]); type_check_expression(*arguments[0]); - return evaluate_n_way_intrinsic(context, arguments[0].get(), /*arg1=*/nullptr, /*arg2=*/nullptr, + return evaluate_n_way_intrinsic(context, arguments[0], /*arg1=*/nullptr, /*arg2=*/nullptr, returnType, eval); } static std::unique_ptr evaluate_intrinsic_numeric(const Context& context, - const ExpressionArray& arguments, + const IntrinsicArguments& arguments, const Type& returnType, EvaluateFn eval) { - SkASSERT(arguments.size() == 1); + SkASSERT(arguments[0]); + SkASSERT(!arguments[1]); const Type& type = arguments[0]->type().componentType(); if (type.isFloat()) { @@ -308,10 +297,12 @@ static std::unique_ptr evaluate_intrinsic_numeric(const Context& con } static std::unique_ptr evaluate_pairwise_intrinsic(const Context& context, - const ExpressionArray& arguments, + const IntrinsicArguments& arguments, const Type& returnType, EvaluateFn eval) { - SkASSERT(arguments.size() == 2); + SkASSERT(arguments[0]); + SkASSERT(arguments[1]); + SkASSERT(!arguments[2]); const Type& type = arguments[0]->type().componentType(); if (type.isFloat()) { @@ -325,15 +316,17 @@ static std::unique_ptr evaluate_pairwise_intrinsic(const Context& co return nullptr; } - return evaluate_n_way_intrinsic(context, arguments[0].get(), arguments[1].get(), - /*arg2=*/nullptr, returnType, eval); + return evaluate_n_way_intrinsic(context, arguments[0], arguments[1], /*arg2=*/nullptr, + returnType, eval); } static std::unique_ptr evaluate_3_way_intrinsic(const Context& context, - const ExpressionArray& arguments, + const IntrinsicArguments& arguments, const Type& returnType, EvaluateFn eval) { - SkASSERT(arguments.size() == 3); + SkASSERT(arguments[0]); + SkASSERT(arguments[1]); + SkASSERT(arguments[2]); const Type& type = arguments[0]->type().componentType(); if (type.isFloat()) { @@ -349,8 +342,8 @@ static std::unique_ptr evaluate_3_way_intrinsic(const Context& conte return nullptr; } - return evaluate_n_way_intrinsic(context, arguments[0].get(), arguments[1].get(), - arguments[2].get(), returnType, eval); + return evaluate_n_way_intrinsic(context, arguments[0], arguments[1], arguments[2], + returnType, eval); } template @@ -458,14 +451,14 @@ static void extract_matrix(const Expression* expr, float mat[16]) { static std::unique_ptr optimize_intrinsic_call(const Context& context, IntrinsicKind intrinsic, - const ExpressionArray& arguments, + const ExpressionArray& argArray, const Type& returnType) { - // Helper function for accessing a matrix argument by column and row. - const Expression* matrix = nullptr; - auto M = [&](int c, int r) -> float { - int index = (matrix->type().rows() * c) + r; - return matrix->getConstantSubexpression(index)->as().value(); - }; + // Replace constant variables with their literal values. + IntrinsicArguments arguments = {}; + SkASSERT(argArray.size() <= arguments.size()); + for (int index = 0; index < argArray.count(); ++index) { + arguments[index] = ConstantFolder::GetConstantValueForVariable(*argArray[index]); + } using namespace SkSL::dsl; switch (intrinsic) { @@ -501,7 +494,7 @@ static std::unique_ptr optimize_intrinsic_call(const Context& contex return evaluate_intrinsic(context, arguments, returnType, Intrinsics::evaluate_acos); case k_atan_IntrinsicKind: - if (arguments.size() == 1) { + if (argArray.size() == 1) { return evaluate_intrinsic(context, arguments, returnType, Intrinsics::evaluate_atan); } else { @@ -588,9 +581,8 @@ static std::unique_ptr optimize_intrinsic_call(const Context& contex SkDEBUGFAILF("unsupported type %s", numericType.description().c_str()); return nullptr; } - return evaluate_n_way_intrinsic(context, arguments[0].get(), arguments[1].get(), - arguments[2].get(), returnType, - Intrinsics::evaluate_mix); + return evaluate_n_way_intrinsic(context, arguments[0], arguments[1], arguments[2], + returnType, Intrinsics::evaluate_mix); } else { return evaluate_3_way_intrinsic(context, arguments, returnType, Intrinsics::evaluate_mix); @@ -683,24 +675,27 @@ static std::unique_ptr optimize_intrinsic_call(const Context& contex return evaluate_pairwise_intrinsic(context, arguments, returnType, Intrinsics::evaluate_matrixCompMult); case k_transpose_IntrinsicKind: { - matrix = ConstantFolder::GetConstantValueForVariable(*arguments[0]); + auto M = [&](int c, int r) -> float { + int index = (arguments[0]->type().rows() * c) + r; + return arguments[0]->getConstantSubexpression(index)->as().value(); + }; + ExpressionArray array; array.reserve_back(returnType.slotCount()); for (int c = 0; c < returnType.columns(); ++c) { for (int r = 0; r < returnType.rows(); ++r) { - array.push_back(FloatLiteral::Make(matrix->fOffset, M(r, c), + array.push_back(FloatLiteral::Make(arguments[0]->fOffset, M(r, c), &returnType.componentType())); } } - return ConstructorCompound::Make(context, matrix->fOffset, returnType, + return ConstructorCompound::Make(context, arguments[0]->fOffset, returnType, std::move(array)); } case k_determinant_IntrinsicKind: { - matrix = ConstantFolder::GetConstantValueForVariable(*arguments[0]); float m[16]; - extract_matrix(matrix, m); + extract_matrix(arguments[0], m); float determinant; - switch (matrix->type().slotCount()) { + switch (arguments[0]->type().slotCount()) { case 4: determinant = SkInvert2x2Matrix(m, /*outMatrix=*/nullptr); break; @@ -711,28 +706,27 @@ static std::unique_ptr optimize_intrinsic_call(const Context& contex determinant = SkInvert4x4Matrix(m, /*outMatrix=*/nullptr); break; default: - SkDEBUGFAILF("unsupported type %s", matrix->type().description().c_str()); + SkDEBUGFAILF("unsupported type %s", arguments[0]->type().description().c_str()); return nullptr; } - return FloatLiteral::Make(matrix->fOffset, determinant, &returnType); + return FloatLiteral::Make(arguments[0]->fOffset, determinant, &returnType); } case k_inverse_IntrinsicKind: { - matrix = ConstantFolder::GetConstantValueForVariable(*arguments[0]); float m[16]; - extract_matrix(matrix, m); - switch (matrix->type().slotCount()) { + extract_matrix(arguments[0], m); + switch (arguments[0]->type().slotCount()) { case 4: if (SkInvert2x2Matrix(m, m) == 0.0f) { return nullptr; } - return DSLType::Construct(&matrix->type(), + return DSLType::Construct(&arguments[0]->type(), m[0], m[1], m[2], m[3]).release(); case 9: if (SkInvert3x3Matrix(m, m) == 0.0f) { return nullptr; } - return DSLType::Construct(&matrix->type(), + return DSLType::Construct(&arguments[0]->type(), m[0], m[1], m[2], m[3], m[4], m[5], m[6], m[7], m[8]).release(); @@ -741,13 +735,13 @@ static std::unique_ptr optimize_intrinsic_call(const Context& contex if (SkInvert4x4Matrix(m, m) == 0.0f) { return nullptr; } - return DSLType::Construct(&matrix->type(), + return DSLType::Construct(&arguments[0]->type(), m[0], m[1], m[2], m[3], m[4], m[5], m[6], m[7], m[8], m[9], m[10], m[11], m[12], m[13], m[14], m[15]).release(); } - SkDEBUGFAILF("unsupported type %s", matrix->type().description().c_str()); + SkDEBUGFAILF("unsupported type %s", arguments[0]->type().description().c_str()); return nullptr; } // 8.7 : Vector Relational Functions