From a3d2f2469239564d98e730adb4b699bb3f7ec791 Mon Sep 17 00:00:00 2001 From: John Stiles Date: Mon, 15 Nov 2021 08:43:55 -0500 Subject: [PATCH] Simplify literal creation in SPIR-V. We had several ways of creating one, zero, true, false, and arbitrary literals throughout the code. These have all been simplified to `writeLiteral(double, Type)`. Change-Id: I8094f142d24b5068e1baf7db8a6148a22e864682 Reviewed-on: https://skia-review.googlesource.com/c/skia/+/471377 Auto-Submit: John Stiles Reviewed-by: Brian Osman Commit-Queue: John Stiles --- src/sksl/codegen/SkSLSPIRVCodeGenerator.cpp | 78 +++++++-------------- 1 file changed, 25 insertions(+), 53 deletions(-) diff --git a/src/sksl/codegen/SkSLSPIRVCodeGenerator.cpp b/src/sksl/codegen/SkSLSPIRVCodeGenerator.cpp index c40bab4a69..2af44bac5d 100644 --- a/src/sksl/codegen/SkSLSPIRVCodeGenerator.cpp +++ b/src/sksl/codegen/SkSLSPIRVCodeGenerator.cpp @@ -621,9 +621,7 @@ SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layou } if (type->columns() > 0) { SpvId typeId = this->getType(type->componentType(), layout); - Literal countLiteral(/*line=*/-1, type->columns(), - fContext.fTypes.fInt.get()); - SpvId countId = this->writeLiteral(countLiteral); + SpvId countId = this->writeLiteral(type->columns(), *fContext.fTypes.fInt); this->writeInstruction(SpvOpTypeArray, result, typeId, countId, fConstantBuffer); this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride, @@ -1075,11 +1073,9 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIn } else { SkASSERT(arguments.size() == 2); if (fProgram.fConfig->fSettings.fSharpenTextures) { - Literal lodBias(/*line=*/-1, /*value=*/-0.5, fContext.fTypes.fFloat.get()); + SpvId lodBias = this->writeLiteral(-0.5, *fContext.fTypes.fFloat); this->writeInstruction(op, type, result, sampler, uv, - SpvImageOperandsBiasMask, - this->writeLiteral(lodBias), - out); + SpvImageOperandsBiasMask, lodBias, out); } else { this->writeInstruction(op, type, result, sampler, uv, out); @@ -1358,10 +1354,8 @@ SpvId SPIRVCodeGenerator::castScalarToFloat(SpvId inputId, const Type& inputType SpvId result = this->nextId(&outputType); if (inputType.isBoolean()) { // Use OpSelect to convert the boolean argument to a literal 1.0 or 0.0. - Literal one(/*line=*/-1, /*value=*/1, fContext.fTypes.fFloat.get()); - const SpvId oneID = this->writeLiteral(one); - Literal zero(/*line=*/-1, /*value=*/0, fContext.fTypes.fFloat.get()); - const SpvId zeroID = this->writeLiteral(zero); + const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fFloat); + const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat); this->writeInstruction(SpvOpSelect, this->getType(outputType), result, inputId, oneID, zeroID, out); } else if (inputType.isSigned()) { @@ -1394,10 +1388,8 @@ SpvId SPIRVCodeGenerator::castScalarToSignedInt(SpvId inputId, const Type& input SpvId result = this->nextId(&outputType); if (inputType.isBoolean()) { // Use OpSelect to convert the boolean argument to a literal 1 or 0. - Literal one(/*line=*/-1, /*value=*/1, fContext.fTypes.fInt.get()); - const SpvId oneID = this->writeLiteral(one); - Literal zero(/*line=*/-1, /*value=*/0, fContext.fTypes.fInt.get()); - const SpvId zeroID = this->writeLiteral(zero); + const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fInt); + const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt); this->writeInstruction(SpvOpSelect, this->getType(outputType), result, inputId, oneID, zeroID, out); } else if (inputType.isFloat()) { @@ -1431,10 +1423,8 @@ SpvId SPIRVCodeGenerator::castScalarToUnsignedInt(SpvId inputId, const Type& inp SpvId result = this->nextId(&outputType); if (inputType.isBoolean()) { // Use OpSelect to convert the boolean argument to a literal 1u or 0u. - Literal one(/*line=*/-1, /*value=*/1, fContext.fTypes.fUInt.get()); - const SpvId oneID = this->writeLiteral(one); - Literal zero(/*line=*/-1, /*value=*/0, fContext.fTypes.fUInt.get()); - const SpvId zeroID = this->writeLiteral(zero); + const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fUInt); + const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt); this->writeInstruction(SpvOpSelect, this->getType(outputType), result, inputId, oneID, zeroID, out); } else if (inputType.isFloat()) { @@ -1468,20 +1458,17 @@ SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputTy SpvId result = this->nextId(nullptr); if (inputType.isSigned()) { // Synthesize a boolean result by comparing the input against a signed zero literal. - Literal zero(/*line=*/-1, /*value=*/0, fContext.fTypes.fInt.get()); - const SpvId zeroID = this->writeLiteral(zero); + const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt); this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result, inputId, zeroID, out); } else if (inputType.isUnsigned()) { // Synthesize a boolean result by comparing the input against an unsigned zero literal. - Literal zero(/*line=*/-1, /*value=*/0, fContext.fTypes.fUInt.get()); - const SpvId zeroID = this->writeLiteral(zero); + const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt); this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result, inputId, zeroID, out); } else if (inputType.isFloat()) { // Synthesize a boolean result by comparing the input against a floating-point zero literal. - Literal zero(/*line=*/-1, /*value=*/0, fContext.fTypes.fFloat.get()); - const SpvId zeroID = this->writeLiteral(zero); + const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat); this->writeInstruction(SpvOpFUnordNotEqual, this->getType(outputType), result, inputId, zeroID, out); } else { @@ -1493,8 +1480,7 @@ SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputTy void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type, OutputStream& out) { - Literal zero(/*line=*/-1, /*value=*/0, fContext.fTypes.fFloat.get()); - SpvId zeroId = this->writeLiteral(zero); + SpvId zeroId = this->writeLiteral(0.0, *fContext.fTypes.fFloat); std::vector columnIds; columnIds.reserve(type.columns()); for (int column = 0; column < type.columns(); column++) { @@ -1532,10 +1518,8 @@ SpvId SPIRVCodeGenerator::writeMatrixCopy(SpvId src, const Type& srcType, const dstType.rows(), 1)); SkASSERT(dstType.componentType().isFloat()); - Literal zero(/*line=*/-1, /*value=*/0.0, &dstType.componentType()); - const SpvId zeroId = this->writeLiteral(zero); - Literal one(/*line=*/-1, /*value=*/1.0, &dstType.componentType()); - const SpvId oneId = this->writeLiteral(one); + const SpvId zeroId = this->writeLiteral(0.0, dstType.componentType()); + const SpvId oneId = this->writeLiteral(1.0, dstType.componentType()); SpvId columns[4]; for (int i = 0; i < dstType.columns(); i++) { @@ -1889,8 +1873,7 @@ std::vector SPIRVCodeGenerator::getAccessChain(const Expression& expr, Ou case Expression::Kind::kFieldAccess: { const FieldAccess& fieldExpr = expr.as(); chain = this->getAccessChain(*fieldExpr.base(), out); - Literal index(/*line=*/-1, fieldExpr.fieldIndex(), fContext.fTypes.fInt.get()); - chain.push_back(this->writeLiteral(index)); + chain.push_back(this->writeLiteral(fieldExpr.fieldIndex(), *fContext.fTypes.fInt)); break; } default: { @@ -2038,11 +2021,9 @@ std::unique_ptr SPIRVCodeGenerator::getLValue(const const Variable& var = *expr.as().variable(); int uniformIdx = this->findUniformFieldIndex(var); if (uniformIdx >= 0) { - Literal uniformIdxLiteral{/*line=*/-1, (double)uniformIdx, - fContext.fTypes.fInt.get()}; SpvId memberId = this->nextId(nullptr); SpvId typeId = this->getPointerType(type, SpvStorageClassUniform); - SpvId uniformIdxId = this->writeLiteral(uniformIdxLiteral); + SpvId uniformIdxId = this->writeLiteral((double)uniformIdx, *fContext.fTypes.fInt); this->writeInstruction(SpvOpAccessChain, typeId, memberId, fUniformBufferId, uniformIdxId, out); return std::make_unique(*this, memberId, @@ -2082,8 +2063,7 @@ std::unique_ptr SPIRVCodeGenerator::getLValue(const if (swizzle.components().size() == 1) { SpvId member = this->nextId(nullptr); SpvId typeId = this->getPointerType(type, get_storage_class(*swizzle.base())); - Literal index(/*line=*/-1, swizzle.components()[0], fContext.fTypes.fInt.get()); - SpvId indexId = this->writeLiteral(index); + SpvId indexId = this->writeLiteral(swizzle.components()[0], *fContext.fTypes.fInt); this->writeInstruction(SpvOpAccessChain, typeId, member, base, indexId, out); return std::make_unique(*this, member, @@ -2320,14 +2300,9 @@ SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType return this->writeComposite(columns, operandType, out); } -static std::unique_ptr create_literal_1(const Context& context, const Type& type) { - SkASSERT(type.isInteger() || type.isFloat()); - return Literal::Make(/*line=*/-1, /*value=*/1.0, &type); -} - SpvId SPIRVCodeGenerator::writeReciprocal(const Type& type, SpvId value, OutputStream& out) { SkASSERT(type.isFloat()); - SpvId one = this->writeLiteral({/*line=*/-1, /*value=*/1, &type}); + SpvId one = this->writeLiteral(1.0, type); SpvId reciprocal = this->nextId(&type); this->writeInstruction(SpvOpFDiv, this->getType(type), reciprocal, one, value, out); return reciprocal; @@ -2732,8 +2707,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu // This converts `expr / 2` into `expr * 0.5` // This improves codegen, especially for certain types of divides (e.g. vector/scalar). op = Operator(Token::Kind::TK_STAR); - Literal reciprocal{right->fLine, 1.0f / rhsValue, &right->type()}; - rhs = this->writeExpression(reciprocal, out); + rhs = this->writeLiteral(1.0 / rhsValue, right->type()); } else { // Write the right-hand side expression normally. rhs = this->writeExpression(*right, out); @@ -2749,8 +2723,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expression& right, OutputStream& out) { - Literal falseLiteral(/*line=*/-1, /*value=*/false, fContext.fTypes.fBool.get()); - SpvId falseConstant = this->writeLiteral(falseLiteral); + SpvId falseConstant = this->writeLiteral(0.0, *fContext.fTypes.fBool); SpvId lhs = this->writeExpression(left, out); SpvId rhsLabel = this->nextId(nullptr); SpvId end = this->nextId(nullptr); @@ -2770,8 +2743,7 @@ SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expressi SpvId SPIRVCodeGenerator::writeLogicalOr(const Expression& left, const Expression& right, OutputStream& out) { - Literal trueLiteral(/*line=*/-1, /*value=*/true, fContext.fTypes.fBool.get()); - SpvId trueConstant = this->writeLiteral(trueLiteral); + SpvId trueConstant = this->writeLiteral(1.0, *fContext.fTypes.fBool); SpvId lhs = this->writeExpression(left, out); SpvId rhsLabel = this->nextId(nullptr); SpvId end = this->nextId(nullptr); @@ -2845,7 +2817,7 @@ SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, Outpu return this->writeExpression(*p.operand(), out); case Token::Kind::TK_PLUSPLUS: { std::unique_ptr lv = this->getLValue(*p.operand(), out); - SpvId one = this->writeExpression(*create_literal_1(fContext, type), out); + SpvId one = this->writeLiteral(1.0, type); SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one, SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); @@ -2854,7 +2826,7 @@ SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, Outpu } case Token::Kind::TK_MINUSMINUS: { std::unique_ptr lv = this->getLValue(*p.operand(), out); - SpvId one = this->writeExpression(*create_literal_1(fContext, type), out); + SpvId one = this->writeLiteral(1.0, type); SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one, SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef, out); lv->store(result, out); @@ -2883,7 +2855,7 @@ SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, Out const Type& type = p.type(); std::unique_ptr lv = this->getLValue(*p.operand(), out); SpvId result = lv->load(out); - SpvId one = this->writeExpression(*create_literal_1(fContext, type), out); + SpvId one = this->writeLiteral(1.0, type); switch (p.getOperator().kind()) { case Token::Kind::TK_PLUSPLUS: { SpvId temp = this->writeBinaryOperation(type, type, result, one, SpvOpFAdd,