Simplify literal creation in SPIR-V.

We had several ways of creating one, zero, true, false, and arbitrary
literals throughout the code. These have all been simplified to
`writeLiteral(double, Type)`.

Change-Id: I8094f142d24b5068e1baf7db8a6148a22e864682
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/471377
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Commit-Queue: John Stiles <johnstiles@google.com>
This commit is contained in:
John Stiles 2021-11-15 08:43:55 -05:00 committed by SkCQ
parent 7fab38d97a
commit a3d2f24692

View File

@ -621,9 +621,7 @@ SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layou
}
if (type->columns() > 0) {
SpvId typeId = this->getType(type->componentType(), layout);
Literal countLiteral(/*line=*/-1, type->columns(),
fContext.fTypes.fInt.get());
SpvId countId = this->writeLiteral(countLiteral);
SpvId countId = this->writeLiteral(type->columns(), *fContext.fTypes.fInt);
this->writeInstruction(SpvOpTypeArray, result, typeId, countId,
fConstantBuffer);
this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
@ -1075,11 +1073,9 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIn
} else {
SkASSERT(arguments.size() == 2);
if (fProgram.fConfig->fSettings.fSharpenTextures) {
Literal lodBias(/*line=*/-1, /*value=*/-0.5, fContext.fTypes.fFloat.get());
SpvId lodBias = this->writeLiteral(-0.5, *fContext.fTypes.fFloat);
this->writeInstruction(op, type, result, sampler, uv,
SpvImageOperandsBiasMask,
this->writeLiteral(lodBias),
out);
SpvImageOperandsBiasMask, lodBias, out);
} else {
this->writeInstruction(op, type, result, sampler, uv,
out);
@ -1358,10 +1354,8 @@ SpvId SPIRVCodeGenerator::castScalarToFloat(SpvId inputId, const Type& inputType
SpvId result = this->nextId(&outputType);
if (inputType.isBoolean()) {
// Use OpSelect to convert the boolean argument to a literal 1.0 or 0.0.
Literal one(/*line=*/-1, /*value=*/1, fContext.fTypes.fFloat.get());
const SpvId oneID = this->writeLiteral(one);
Literal zero(/*line=*/-1, /*value=*/0, fContext.fTypes.fFloat.get());
const SpvId zeroID = this->writeLiteral(zero);
const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fFloat);
const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
inputId, oneID, zeroID, out);
} else if (inputType.isSigned()) {
@ -1394,10 +1388,8 @@ SpvId SPIRVCodeGenerator::castScalarToSignedInt(SpvId inputId, const Type& input
SpvId result = this->nextId(&outputType);
if (inputType.isBoolean()) {
// Use OpSelect to convert the boolean argument to a literal 1 or 0.
Literal one(/*line=*/-1, /*value=*/1, fContext.fTypes.fInt.get());
const SpvId oneID = this->writeLiteral(one);
Literal zero(/*line=*/-1, /*value=*/0, fContext.fTypes.fInt.get());
const SpvId zeroID = this->writeLiteral(zero);
const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fInt);
const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
inputId, oneID, zeroID, out);
} else if (inputType.isFloat()) {
@ -1431,10 +1423,8 @@ SpvId SPIRVCodeGenerator::castScalarToUnsignedInt(SpvId inputId, const Type& inp
SpvId result = this->nextId(&outputType);
if (inputType.isBoolean()) {
// Use OpSelect to convert the boolean argument to a literal 1u or 0u.
Literal one(/*line=*/-1, /*value=*/1, fContext.fTypes.fUInt.get());
const SpvId oneID = this->writeLiteral(one);
Literal zero(/*line=*/-1, /*value=*/0, fContext.fTypes.fUInt.get());
const SpvId zeroID = this->writeLiteral(zero);
const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fUInt);
const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
inputId, oneID, zeroID, out);
} else if (inputType.isFloat()) {
@ -1468,20 +1458,17 @@ SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputTy
SpvId result = this->nextId(nullptr);
if (inputType.isSigned()) {
// Synthesize a boolean result by comparing the input against a signed zero literal.
Literal zero(/*line=*/-1, /*value=*/0, fContext.fTypes.fInt.get());
const SpvId zeroID = this->writeLiteral(zero);
const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
inputId, zeroID, out);
} else if (inputType.isUnsigned()) {
// Synthesize a boolean result by comparing the input against an unsigned zero literal.
Literal zero(/*line=*/-1, /*value=*/0, fContext.fTypes.fUInt.get());
const SpvId zeroID = this->writeLiteral(zero);
const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
inputId, zeroID, out);
} else if (inputType.isFloat()) {
// Synthesize a boolean result by comparing the input against a floating-point zero literal.
Literal zero(/*line=*/-1, /*value=*/0, fContext.fTypes.fFloat.get());
const SpvId zeroID = this->writeLiteral(zero);
const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
this->writeInstruction(SpvOpFUnordNotEqual, this->getType(outputType), result,
inputId, zeroID, out);
} else {
@ -1493,8 +1480,7 @@ SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputTy
void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
OutputStream& out) {
Literal zero(/*line=*/-1, /*value=*/0, fContext.fTypes.fFloat.get());
SpvId zeroId = this->writeLiteral(zero);
SpvId zeroId = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
std::vector<SpvId> columnIds;
columnIds.reserve(type.columns());
for (int column = 0; column < type.columns(); column++) {
@ -1532,10 +1518,8 @@ SpvId SPIRVCodeGenerator::writeMatrixCopy(SpvId src, const Type& srcType, const
dstType.rows(),
1));
SkASSERT(dstType.componentType().isFloat());
Literal zero(/*line=*/-1, /*value=*/0.0, &dstType.componentType());
const SpvId zeroId = this->writeLiteral(zero);
Literal one(/*line=*/-1, /*value=*/1.0, &dstType.componentType());
const SpvId oneId = this->writeLiteral(one);
const SpvId zeroId = this->writeLiteral(0.0, dstType.componentType());
const SpvId oneId = this->writeLiteral(1.0, dstType.componentType());
SpvId columns[4];
for (int i = 0; i < dstType.columns(); i++) {
@ -1889,8 +1873,7 @@ std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, Ou
case Expression::Kind::kFieldAccess: {
const FieldAccess& fieldExpr = expr.as<FieldAccess>();
chain = this->getAccessChain(*fieldExpr.base(), out);
Literal index(/*line=*/-1, fieldExpr.fieldIndex(), fContext.fTypes.fInt.get());
chain.push_back(this->writeLiteral(index));
chain.push_back(this->writeLiteral(fieldExpr.fieldIndex(), *fContext.fTypes.fInt));
break;
}
default: {
@ -2038,11 +2021,9 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const
const Variable& var = *expr.as<VariableReference>().variable();
int uniformIdx = this->findUniformFieldIndex(var);
if (uniformIdx >= 0) {
Literal uniformIdxLiteral{/*line=*/-1, (double)uniformIdx,
fContext.fTypes.fInt.get()};
SpvId memberId = this->nextId(nullptr);
SpvId typeId = this->getPointerType(type, SpvStorageClassUniform);
SpvId uniformIdxId = this->writeLiteral(uniformIdxLiteral);
SpvId uniformIdxId = this->writeLiteral((double)uniformIdx, *fContext.fTypes.fInt);
this->writeInstruction(SpvOpAccessChain, typeId, memberId, fUniformBufferId,
uniformIdxId, out);
return std::make_unique<PointerLValue>(*this, memberId,
@ -2082,8 +2063,7 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const
if (swizzle.components().size() == 1) {
SpvId member = this->nextId(nullptr);
SpvId typeId = this->getPointerType(type, get_storage_class(*swizzle.base()));
Literal index(/*line=*/-1, swizzle.components()[0], fContext.fTypes.fInt.get());
SpvId indexId = this->writeLiteral(index);
SpvId indexId = this->writeLiteral(swizzle.components()[0], *fContext.fTypes.fInt);
this->writeInstruction(SpvOpAccessChain, typeId, member, base, indexId, out);
return std::make_unique<PointerLValue>(*this,
member,
@ -2320,14 +2300,9 @@ SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType
return this->writeComposite(columns, operandType, out);
}
static std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
SkASSERT(type.isInteger() || type.isFloat());
return Literal::Make(/*line=*/-1, /*value=*/1.0, &type);
}
SpvId SPIRVCodeGenerator::writeReciprocal(const Type& type, SpvId value, OutputStream& out) {
SkASSERT(type.isFloat());
SpvId one = this->writeLiteral({/*line=*/-1, /*value=*/1, &type});
SpvId one = this->writeLiteral(1.0, type);
SpvId reciprocal = this->nextId(&type);
this->writeInstruction(SpvOpFDiv, this->getType(type), reciprocal, one, value, out);
return reciprocal;
@ -2732,8 +2707,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
// This converts `expr / 2` into `expr * 0.5`
// This improves codegen, especially for certain types of divides (e.g. vector/scalar).
op = Operator(Token::Kind::TK_STAR);
Literal reciprocal{right->fLine, 1.0f / rhsValue, &right->type()};
rhs = this->writeExpression(reciprocal, out);
rhs = this->writeLiteral(1.0 / rhsValue, right->type());
} else {
// Write the right-hand side expression normally.
rhs = this->writeExpression(*right, out);
@ -2749,8 +2723,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expression& right,
OutputStream& out) {
Literal falseLiteral(/*line=*/-1, /*value=*/false, fContext.fTypes.fBool.get());
SpvId falseConstant = this->writeLiteral(falseLiteral);
SpvId falseConstant = this->writeLiteral(0.0, *fContext.fTypes.fBool);
SpvId lhs = this->writeExpression(left, out);
SpvId rhsLabel = this->nextId(nullptr);
SpvId end = this->nextId(nullptr);
@ -2770,8 +2743,7 @@ SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expressi
SpvId SPIRVCodeGenerator::writeLogicalOr(const Expression& left, const Expression& right,
OutputStream& out) {
Literal trueLiteral(/*line=*/-1, /*value=*/true, fContext.fTypes.fBool.get());
SpvId trueConstant = this->writeLiteral(trueLiteral);
SpvId trueConstant = this->writeLiteral(1.0, *fContext.fTypes.fBool);
SpvId lhs = this->writeExpression(left, out);
SpvId rhsLabel = this->nextId(nullptr);
SpvId end = this->nextId(nullptr);
@ -2845,7 +2817,7 @@ SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, Outpu
return this->writeExpression(*p.operand(), out);
case Token::Kind::TK_PLUSPLUS: {
std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
SpvId one = this->writeExpression(*create_literal_1(fContext, type), out);
SpvId one = this->writeLiteral(1.0, type);
SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one,
SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
out);
@ -2854,7 +2826,7 @@ SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, Outpu
}
case Token::Kind::TK_MINUSMINUS: {
std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
SpvId one = this->writeExpression(*create_literal_1(fContext, type), out);
SpvId one = this->writeLiteral(1.0, type);
SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one, SpvOpFSub,
SpvOpISub, SpvOpISub, SpvOpUndef, out);
lv->store(result, out);
@ -2883,7 +2855,7 @@ SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, Out
const Type& type = p.type();
std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
SpvId result = lv->load(out);
SpvId one = this->writeExpression(*create_literal_1(fContext, type), out);
SpvId one = this->writeLiteral(1.0, type);
switch (p.getOperator().kind()) {
case Token::Kind::TK_PLUSPLUS: {
SpvId temp = this->writeBinaryOperation(type, type, result, one, SpvOpFAdd,