Honor component type in Metal matrix helper functions.

Right now, Metal forces types to full precision. The matrix helper
functions previously baked in that assumption by hard-coding "floatX".
Now, they honor the component type; if this->typeName() started
returning "half", our helper functions would be named with "halfX". This
would allow half-precision and full-precision helpers to coexist.

Change-Id: I1679e6e76d2cf3c27fd69c42a92fb24bff6b69ec
Bug: skia:12339
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/439396
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
This commit is contained in:
John Stiles 2021-08-13 17:53:49 -04:00 committed by SkCQ
parent e56d31fd56
commit c18ee4e55a
4 changed files with 40 additions and 27 deletions

View File

@ -825,10 +825,12 @@ void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int
SkASSERT(rows <= 4);
SkASSERT(columns <= 4);
const char* columnSeparator = "";
std::string matrixType = this->typeName(sourceMatrix.componentType());
const char* separator = "";
for (int c = 0; c < columns; ++c) {
fExtraFunctions.printf("%sfloat%d(", columnSeparator, rows);
columnSeparator = "), ";
fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
separator = "), ";
// Determine how many values to take from the source matrix for this row.
int swizzleLength = 0;
@ -861,14 +863,18 @@ void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int
// `x1`, etc. An error is written if the expression list don't contain exactly C*R scalars.
void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& ctor,
int columns, int rows) {
SkASSERT(rows <= 4);
SkASSERT(columns <= 4);
std::string matrixType = this->typeName(ctor.type().componentType());
size_t argIndex = 0;
int argPosition = 0;
auto args = ctor.argumentSpan();
const char* rowSeparator = "";
const char* separator = "";
for (int r = 0; r < rows; ++r) {
fExtraFunctions.printf("%sfloat%d(", rowSeparator, rows);
rowSeparator = "), ";
fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
separator = "), ";
const char* columnSeparator = "";
for (int c = 0; c < columns; ++c) {
@ -924,14 +930,15 @@ void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& cto
// constructor for any given permutation of input argument types. Returns the name of the
// generated constructor method.
String MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
const Type& matrix = c.type();
int columns = matrix.columns();
int rows = matrix.rows();
const Type& type = c.type();
int columns = type.columns();
int rows = type.rows();
auto args = c.argumentSpan();
String typeName = this->typeName(type);
// Create the helper-method name and use it as our lookup key.
String name;
name.appendf("float%dx%d_from", columns, rows);
name.appendf("%s_from", typeName.c_str());
for (const std::unique_ptr<Expression>& expr : args) {
name.appendf("_%s", this->typeName(expr->type()).c_str());
}
@ -945,7 +952,7 @@ String MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
// Unlike GLSL, Metal requires that matrices are initialized with exactly R vectors of C
// components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot
// supply a mixture of scalars and vectors.)
fExtraFunctions.printf("float%dx%d %s(", columns, rows, name.c_str());
fExtraFunctions.printf("%s %s(", typeName.c_str(), name.c_str());
size_t argIndex = 0;
const char* argSeparator = "";
@ -955,7 +962,7 @@ String MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
argSeparator = ", ";
}
fExtraFunctions.printf(") {\n return float%dx%d(", columns, rows);
fExtraFunctions.printf(") {\n return %s(", typeName.c_str());
if (args.size() == 1 && args.front()->type().isMatrix()) {
this->assembleMatrixFromMatrix(args.front()->type(), rows, columns);
@ -1042,18 +1049,24 @@ void MetalCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
}
}
void MetalCodeGenerator::writeVectorFromMat2x2ConstructorHelper() {
static constexpr char kCode[] =
R"(float4 float4_from_float2x2(float2x2 x) {
return float4(x[0].xy, x[1].xy);
}
)";
String MetalCodeGenerator::getVectorFromMat2x2ConstructorHelper(const Type& matrixType) {
SkASSERT(matrixType.isMatrix());
SkASSERT(matrixType.rows() == 2);
SkASSERT(matrixType.columns() == 2);
String name = "matrixCompMult";
if (fHelpers.find("float4_from_float2x2") == fHelpers.end()) {
fHelpers.insert("float4_from_float2x2");
fExtraFunctions.writeText(kCode);
String baseType = this->typeName(matrixType.componentType());
String name = String::printf("%s4_from_%s2x2", baseType.c_str(), baseType.c_str());
if (fHelpers.find(name) == fHelpers.end()) {
fHelpers.insert(name);
fExtraFunctions.printf(R"(
%s4 %s(%s2x2 x) {
return %s4(x[0].xy, x[1].xy);
}
)", baseType.c_str(), name.c_str(), baseType.c_str(), baseType.c_str());
}
return name;
}
void MetalCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
@ -1065,10 +1078,8 @@ void MetalCodeGenerator::writeConstructorCompoundVector(const ConstructorCompoun
if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
const Expression& expr = *c.argumentSpan().front();
if (expr.type().isMatrix()) {
SkASSERT(expr.type().rows() == 2);
SkASSERT(expr.type().columns() == 2);
this->writeVectorFromMat2x2ConstructorHelper();
this->write("float4_from_float2x2(");
this->write(this->getVectorFromMat2x2ConstructorHelper(expr.type()));
this->write("(");
this->writeExpression(expr, Precedence::kSequence);
this->write(")");
return;

View File

@ -178,7 +178,7 @@ protected:
void writeMatrixEqualityHelpers(const Type& left, const Type& right);
void writeVectorFromMat2x2ConstructorHelper();
String getVectorFromMat2x2ConstructorHelper(const Type& matrixType);
void writeArrayEqualityHelpers(const Type& type);

View File

@ -11,6 +11,7 @@ struct Inputs {
struct Outputs {
float4 sk_FragColor [[color(0)]];
};
float4 float4_from_float2x2(float2x2 x) {
return float4(x[0].xy, x[1].xy);
}

View File

@ -6,6 +6,7 @@ struct Inputs {
struct Outputs {
float4 sk_FragColor [[color(0)]];
};
float4 float4_from_float2x2(float2x2 x) {
return float4(x[0].xy, x[1].xy);
}