reworked SPIR-V binary operations and added support for VectorTimesScalar

Bug: skia:
Change-Id: I03b8a1ed3cf78060c5b9a5ede8d0371998116744
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/208677
Reviewed-by: Greg Daniel <egdaniel@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
This commit is contained in:
Ethan Nicholas 2019-04-17 12:22:21 -04:00 committed by Skia Commit-Bot
parent 7bb47f2a2e
commit 49465b41d1
3 changed files with 110 additions and 194 deletions

View File

@ -2012,48 +2012,43 @@ SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType
return result;
}
SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
// handle cases where we don't necessarily evaluate both LHS and RHS
switch (b.fOperator) {
case Token::EQ: {
SpvId rhs = this->writeExpression(*b.fRight, out);
this->getLValue(*b.fLeft, out)->store(rhs, out);
return rhs;
}
case Token::LOGICALAND:
return this->writeLogicalAnd(b, out);
case Token::LOGICALOR:
return this->writeLogicalOr(b, out);
default:
break;
std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
if (type.isInteger()) {
return std::unique_ptr<Expression>(new IntLiteral(-1, 1, &type));
}
// "normal" operators
const Type& resultType = b.fType;
std::unique_ptr<LValue> lvalue;
SpvId lhs;
if (is_assignment(b.fOperator)) {
lvalue = this->getLValue(*b.fLeft, out);
lhs = lvalue->load(out);
else if (type.isFloat()) {
return std::unique_ptr<Expression>(new FloatLiteral(-1, 1.0, &type));
} else {
lvalue = nullptr;
lhs = this->writeExpression(*b.fLeft, out);
}
SpvId rhs = this->writeExpression(*b.fRight, out);
if (b.fOperator == Token::COMMA) {
return rhs;
ABORT("math is unsupported on type '%s'", type.name().c_str());
}
}
SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op,
const Type& rightType, SpvId rhs,
const Type& resultType, OutputStream& out) {
Type tmp("<invalid>");
// overall type we are operating on: float2, int, uint4...
const Type* operandType;
// IR allows mismatched types in expressions (e.g. float2 * float), but they need special
// handling in SPIR-V
if (this->getActualType(b.fLeft->fType) != this->getActualType(b.fRight->fType)) {
if (b.fLeft->fType.kind() == Type::kVector_Kind &&
b.fRight->fType.isNumber()) {
if (this->getActualType(leftType) != this->getActualType(rightType)) {
if (leftType.kind() == Type::kVector_Kind && rightType.isNumber()) {
if (op == Token::SLASH) {
SpvId one = this->writeExpression(*create_literal_1(fContext, rightType), out);
SpvId inverse = this->nextId();
this->writeInstruction(SpvOpFDiv, this->getType(rightType), inverse, one, rhs, out);
rhs = inverse;
op = Token::STAR;
}
if (op == Token::STAR) {
SpvId result = this->nextId();
this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
result, lhs, rhs, out);
return result;
}
// promote number to vector
SpvId vec = this->nextId();
const Type& vecType = b.fLeft->fType;
const Type& vecType = leftType;
this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
this->writeWord(this->getType(vecType), out);
this->writeWord(vec, out);
@ -2061,12 +2056,17 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
this->writeWord(rhs, out);
}
rhs = vec;
operandType = &b.fLeft->fType;
} else if (b.fRight->fType.kind() == Type::kVector_Kind &&
b.fLeft->fType.isNumber()) {
operandType = &leftType;
} else if (rightType.kind() == Type::kVector_Kind && leftType.isNumber()) {
if (op == Token::STAR) {
SpvId result = this->nextId();
this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
result, rhs, lhs, out);
return result;
}
// promote number to vector
SpvId vec = this->nextId();
const Type& vecType = b.fRight->fType;
const Type& vecType = rightType;
this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
this->writeWord(this->getType(vecType), out);
this->writeWord(vec, out);
@ -2074,52 +2074,41 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
this->writeWord(lhs, out);
}
lhs = vec;
SkASSERT(!lvalue);
operandType = &b.fRight->fType;
} else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
SpvOp_ op;
if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
op = SpvOpMatrixTimesMatrix;
} else if (b.fRight->fType.kind() == Type::kVector_Kind) {
op = SpvOpMatrixTimesVector;
operandType = &rightType;
} else if (leftType.kind() == Type::kMatrix_Kind) {
SpvOp_ spvop;
if (rightType.kind() == Type::kMatrix_Kind) {
spvop = SpvOpMatrixTimesMatrix;
} else if (rightType.kind() == Type::kVector_Kind) {
spvop = SpvOpMatrixTimesVector;
} else {
SkASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
op = SpvOpMatrixTimesScalar;
SkASSERT(rightType.kind() == Type::kScalar_Kind);
spvop = SpvOpMatrixTimesScalar;
}
SpvId result = this->nextId();
this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
if (b.fOperator == Token::STAREQ) {
lvalue->store(result, out);
} else {
SkASSERT(b.fOperator == Token::STAR);
}
this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
return result;
} else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
} else if (rightType.kind() == Type::kMatrix_Kind) {
SpvId result = this->nextId();
if (b.fLeft->fType.kind() == Type::kVector_Kind) {
this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
if (leftType.kind() == Type::kVector_Kind) {
this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType), result,
lhs, rhs, out);
} else {
SkASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
lhs, out);
}
if (b.fOperator == Token::STAREQ) {
lvalue->store(result, out);
} else {
SkASSERT(b.fOperator == Token::STAR);
SkASSERT(leftType.kind() == Type::kScalar_Kind);
this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType), result,
rhs, lhs, out);
}
return result;
} else {
ABORT("unsupported binary expression: %s (%s, %s)", b.description().c_str(),
b.fLeft->fType.description().c_str(), b.fRight->fType.description().c_str());
SkASSERT(false);
return -1;
}
} else {
tmp = this->getActualType(b.fLeft->fType);
tmp = this->getActualType(leftType);
operandType = &tmp;
SkASSERT(*operandType == this->getActualType(b.fRight->fType));
SkASSERT(*operandType == this->getActualType(rightType));
}
switch (b.fOperator) {
switch (op) {
case Token::EQEQ: {
if (operandType->kind() == Type::kMatrix_Kind) {
return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
@ -2178,26 +2167,26 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
SpvOpULessThanEqual, SpvOpUndef, out);
case Token::PLUS:
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
b.fRight->fType.kind() == Type::kMatrix_Kind) {
SkASSERT(b.fLeft->fType == b.fRight->fType);
return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
if (leftType.kind() == Type::kMatrix_Kind &&
rightType.kind() == Type::kMatrix_Kind) {
SkASSERT(leftType == rightType);
return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
SpvOpFAdd, SpvOpIAdd, out);
}
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
case Token::MINUS:
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
b.fRight->fType.kind() == Type::kMatrix_Kind) {
SkASSERT(b.fLeft->fType == b.fRight->fType);
return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
if (leftType.kind() == Type::kMatrix_Kind &&
rightType.kind() == Type::kMatrix_Kind) {
SkASSERT(leftType == rightType);
return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
SpvOpFSub, SpvOpISub, out);
}
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
SpvOpISub, SpvOpISub, SpvOpUndef, out);
case Token::STAR:
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
b.fRight->fType.kind() == Type::kMatrix_Kind) {
if (leftType.kind() == Type::kMatrix_Kind &&
rightType.kind() == Type::kMatrix_Kind) {
// matrix multiply
SpvId result = this->nextId();
this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
@ -2229,114 +2218,48 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
case Token::BITWISEXOR:
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
case Token::PLUSEQ: {
SpvId result;
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
b.fRight->fType.kind() == Type::kMatrix_Kind) {
SkASSERT(b.fLeft->fType == b.fRight->fType);
result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
SpvOpFAdd, SpvOpIAdd, out);
}
else {
result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
}
SkASSERT(lvalue);
lvalue->store(result, out);
return result;
}
case Token::MINUSEQ: {
SpvId result;
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
b.fRight->fType.kind() == Type::kMatrix_Kind) {
SkASSERT(b.fLeft->fType == b.fRight->fType);
result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
SpvOpFSub, SpvOpISub, out);
}
else {
result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
SpvOpISub, SpvOpISub, SpvOpUndef, out);
}
SkASSERT(lvalue);
lvalue->store(result, out);
return result;
}
case Token::STAREQ: {
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
b.fRight->fType.kind() == Type::kMatrix_Kind) {
// matrix multiply
SpvId result = this->nextId();
this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
lhs, rhs, out);
SkASSERT(lvalue);
lvalue->store(result, out);
return result;
}
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
SkASSERT(lvalue);
lvalue->store(result, out);
return result;
}
case Token::SLASHEQ: {
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
SkASSERT(lvalue);
lvalue->store(result, out);
return result;
}
case Token::PERCENTEQ: {
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
SkASSERT(lvalue);
lvalue->store(result, out);
return result;
}
case Token::SHLEQ: {
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
SpvOpUndef, SpvOpShiftLeftLogical,
SpvOpShiftLeftLogical, SpvOpUndef, out);
SkASSERT(lvalue);
lvalue->store(result, out);
return result;
}
case Token::SHREQ: {
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
SpvOpUndef, SpvOpShiftRightArithmetic,
SpvOpShiftRightLogical, SpvOpUndef, out);
SkASSERT(lvalue);
lvalue->store(result, out);
return result;
}
case Token::BITWISEANDEQ: {
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
SpvOpUndef, SpvOpBitwiseAnd, SpvOpBitwiseAnd,
SpvOpUndef, out);
SkASSERT(lvalue);
lvalue->store(result, out);
return result;
}
case Token::BITWISEOREQ: {
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
SpvOpUndef, SpvOpBitwiseOr, SpvOpBitwiseOr,
SpvOpUndef, out);
SkASSERT(lvalue);
lvalue->store(result, out);
return result;
}
case Token::BITWISEXOREQ: {
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor,
SpvOpUndef, out);
SkASSERT(lvalue);
lvalue->store(result, out);
return result;
}
case Token::COMMA:
return rhs;
default:
ABORT("unsupported binary expression: %s", b.description().c_str());
SkASSERT(false);
return -1;
}
}
SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
// handle cases where we don't necessarily evaluate both LHS and RHS
switch (b.fOperator) {
case Token::EQ: {
SpvId rhs = this->writeExpression(*b.fRight, out);
this->getLValue(*b.fLeft, out)->store(rhs, out);
return rhs;
}
case Token::LOGICALAND:
return this->writeLogicalAnd(b, out);
case Token::LOGICALOR:
return this->writeLogicalOr(b, out);
default:
break;
}
std::unique_ptr<LValue> lvalue;
SpvId lhs;
if (is_assignment(b.fOperator)) {
lvalue = this->getLValue(*b.fLeft, out);
lhs = lvalue->load(out);
} else {
lvalue = nullptr;
lhs = this->writeExpression(*b.fLeft, out);
}
SpvId rhs = this->writeExpression(*b.fRight, out);
SpvId result = this->writeBinaryExpression(b.fLeft->fType, lhs, remove_assignment(b.fOperator),
b.fRight->fType, rhs, b.fType, out);
if (lvalue) {
lvalue->store(result, out);
}
return result;
}
SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
SkASSERT(a.fOperator == Token::LOGICALAND);
BoolLiteral falseLiteral(fContext, -1, false);
@ -2413,17 +2336,6 @@ SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, Out
return result;
}
std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
if (type.isInteger()) {
return std::unique_ptr<Expression>(new IntLiteral(-1, 1, &type));
}
else if (type.isFloat()) {
return std::unique_ptr<Expression>(new FloatLiteral(-1, 1.0, &type));
} else {
ABORT("math is unsupported on type '%s'", type.name().c_str());
}
}
SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
if (p.fOperator == Token::MINUS) {
SpvId result = this->nextId();

View File

@ -274,6 +274,10 @@ private:
SpvId writeBinaryOperation(const BinaryExpression& expr, SpvOp_ ifFloat, SpvOp_ ifInt,
SpvOp_ ifUInt, OutputStream& out);
SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op,
const Type& rightType, SpvId rhs, const Type& resultType,
OutputStream& out);
SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out);
SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out);

View File

@ -67,7 +67,7 @@ Token::Kind remove_assignment(Token::Kind op) {
case Token::LOGICALOREQ: return Token::LOGICALOR;
case Token::LOGICALXOREQ: return Token::LOGICALXOR;
case Token::LOGICALANDEQ: return Token::LOGICALAND;
default: return Token::INVALID;
default: return op;
}
}