reworked SPIR-V binary operations and added support for VectorTimesScalar
Bug: skia: Change-Id: I03b8a1ed3cf78060c5b9a5ede8d0371998116744 Reviewed-on: https://skia-review.googlesource.com/c/skia/+/208677 Reviewed-by: Greg Daniel <egdaniel@google.com> Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
This commit is contained in:
parent
7bb47f2a2e
commit
49465b41d1
@ -2012,48 +2012,43 @@ SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType
|
||||
return result;
|
||||
}
|
||||
|
||||
SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
|
||||
// handle cases where we don't necessarily evaluate both LHS and RHS
|
||||
switch (b.fOperator) {
|
||||
case Token::EQ: {
|
||||
SpvId rhs = this->writeExpression(*b.fRight, out);
|
||||
this->getLValue(*b.fLeft, out)->store(rhs, out);
|
||||
return rhs;
|
||||
}
|
||||
case Token::LOGICALAND:
|
||||
return this->writeLogicalAnd(b, out);
|
||||
case Token::LOGICALOR:
|
||||
return this->writeLogicalOr(b, out);
|
||||
default:
|
||||
break;
|
||||
std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
|
||||
if (type.isInteger()) {
|
||||
return std::unique_ptr<Expression>(new IntLiteral(-1, 1, &type));
|
||||
}
|
||||
|
||||
// "normal" operators
|
||||
const Type& resultType = b.fType;
|
||||
std::unique_ptr<LValue> lvalue;
|
||||
SpvId lhs;
|
||||
if (is_assignment(b.fOperator)) {
|
||||
lvalue = this->getLValue(*b.fLeft, out);
|
||||
lhs = lvalue->load(out);
|
||||
else if (type.isFloat()) {
|
||||
return std::unique_ptr<Expression>(new FloatLiteral(-1, 1.0, &type));
|
||||
} else {
|
||||
lvalue = nullptr;
|
||||
lhs = this->writeExpression(*b.fLeft, out);
|
||||
}
|
||||
SpvId rhs = this->writeExpression(*b.fRight, out);
|
||||
if (b.fOperator == Token::COMMA) {
|
||||
return rhs;
|
||||
ABORT("math is unsupported on type '%s'", type.name().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op,
|
||||
const Type& rightType, SpvId rhs,
|
||||
const Type& resultType, OutputStream& out) {
|
||||
Type tmp("<invalid>");
|
||||
// overall type we are operating on: float2, int, uint4...
|
||||
const Type* operandType;
|
||||
// IR allows mismatched types in expressions (e.g. float2 * float), but they need special
|
||||
// handling in SPIR-V
|
||||
if (this->getActualType(b.fLeft->fType) != this->getActualType(b.fRight->fType)) {
|
||||
if (b.fLeft->fType.kind() == Type::kVector_Kind &&
|
||||
b.fRight->fType.isNumber()) {
|
||||
if (this->getActualType(leftType) != this->getActualType(rightType)) {
|
||||
if (leftType.kind() == Type::kVector_Kind && rightType.isNumber()) {
|
||||
if (op == Token::SLASH) {
|
||||
SpvId one = this->writeExpression(*create_literal_1(fContext, rightType), out);
|
||||
SpvId inverse = this->nextId();
|
||||
this->writeInstruction(SpvOpFDiv, this->getType(rightType), inverse, one, rhs, out);
|
||||
rhs = inverse;
|
||||
op = Token::STAR;
|
||||
}
|
||||
if (op == Token::STAR) {
|
||||
SpvId result = this->nextId();
|
||||
this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
|
||||
result, lhs, rhs, out);
|
||||
return result;
|
||||
}
|
||||
// promote number to vector
|
||||
SpvId vec = this->nextId();
|
||||
const Type& vecType = b.fLeft->fType;
|
||||
const Type& vecType = leftType;
|
||||
this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
|
||||
this->writeWord(this->getType(vecType), out);
|
||||
this->writeWord(vec, out);
|
||||
@ -2061,12 +2056,17 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
|
||||
this->writeWord(rhs, out);
|
||||
}
|
||||
rhs = vec;
|
||||
operandType = &b.fLeft->fType;
|
||||
} else if (b.fRight->fType.kind() == Type::kVector_Kind &&
|
||||
b.fLeft->fType.isNumber()) {
|
||||
operandType = &leftType;
|
||||
} else if (rightType.kind() == Type::kVector_Kind && leftType.isNumber()) {
|
||||
if (op == Token::STAR) {
|
||||
SpvId result = this->nextId();
|
||||
this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
|
||||
result, rhs, lhs, out);
|
||||
return result;
|
||||
}
|
||||
// promote number to vector
|
||||
SpvId vec = this->nextId();
|
||||
const Type& vecType = b.fRight->fType;
|
||||
const Type& vecType = rightType;
|
||||
this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
|
||||
this->writeWord(this->getType(vecType), out);
|
||||
this->writeWord(vec, out);
|
||||
@ -2074,52 +2074,41 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
|
||||
this->writeWord(lhs, out);
|
||||
}
|
||||
lhs = vec;
|
||||
SkASSERT(!lvalue);
|
||||
operandType = &b.fRight->fType;
|
||||
} else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
|
||||
SpvOp_ op;
|
||||
if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
|
||||
op = SpvOpMatrixTimesMatrix;
|
||||
} else if (b.fRight->fType.kind() == Type::kVector_Kind) {
|
||||
op = SpvOpMatrixTimesVector;
|
||||
operandType = &rightType;
|
||||
} else if (leftType.kind() == Type::kMatrix_Kind) {
|
||||
SpvOp_ spvop;
|
||||
if (rightType.kind() == Type::kMatrix_Kind) {
|
||||
spvop = SpvOpMatrixTimesMatrix;
|
||||
} else if (rightType.kind() == Type::kVector_Kind) {
|
||||
spvop = SpvOpMatrixTimesVector;
|
||||
} else {
|
||||
SkASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
|
||||
op = SpvOpMatrixTimesScalar;
|
||||
SkASSERT(rightType.kind() == Type::kScalar_Kind);
|
||||
spvop = SpvOpMatrixTimesScalar;
|
||||
}
|
||||
SpvId result = this->nextId();
|
||||
this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
|
||||
if (b.fOperator == Token::STAREQ) {
|
||||
lvalue->store(result, out);
|
||||
} else {
|
||||
SkASSERT(b.fOperator == Token::STAR);
|
||||
}
|
||||
this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
|
||||
return result;
|
||||
} else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
|
||||
} else if (rightType.kind() == Type::kMatrix_Kind) {
|
||||
SpvId result = this->nextId();
|
||||
if (b.fLeft->fType.kind() == Type::kVector_Kind) {
|
||||
this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
|
||||
if (leftType.kind() == Type::kVector_Kind) {
|
||||
this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType), result,
|
||||
lhs, rhs, out);
|
||||
} else {
|
||||
SkASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
|
||||
this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
|
||||
lhs, out);
|
||||
}
|
||||
if (b.fOperator == Token::STAREQ) {
|
||||
lvalue->store(result, out);
|
||||
} else {
|
||||
SkASSERT(b.fOperator == Token::STAR);
|
||||
SkASSERT(leftType.kind() == Type::kScalar_Kind);
|
||||
this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType), result,
|
||||
rhs, lhs, out);
|
||||
}
|
||||
return result;
|
||||
} else {
|
||||
ABORT("unsupported binary expression: %s (%s, %s)", b.description().c_str(),
|
||||
b.fLeft->fType.description().c_str(), b.fRight->fType.description().c_str());
|
||||
SkASSERT(false);
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
tmp = this->getActualType(b.fLeft->fType);
|
||||
tmp = this->getActualType(leftType);
|
||||
operandType = &tmp;
|
||||
SkASSERT(*operandType == this->getActualType(b.fRight->fType));
|
||||
SkASSERT(*operandType == this->getActualType(rightType));
|
||||
}
|
||||
switch (b.fOperator) {
|
||||
switch (op) {
|
||||
case Token::EQEQ: {
|
||||
if (operandType->kind() == Type::kMatrix_Kind) {
|
||||
return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
|
||||
@ -2178,26 +2167,26 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
|
||||
SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
|
||||
SpvOpULessThanEqual, SpvOpUndef, out);
|
||||
case Token::PLUS:
|
||||
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
|
||||
b.fRight->fType.kind() == Type::kMatrix_Kind) {
|
||||
SkASSERT(b.fLeft->fType == b.fRight->fType);
|
||||
return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
|
||||
if (leftType.kind() == Type::kMatrix_Kind &&
|
||||
rightType.kind() == Type::kMatrix_Kind) {
|
||||
SkASSERT(leftType == rightType);
|
||||
return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
|
||||
SpvOpFAdd, SpvOpIAdd, out);
|
||||
}
|
||||
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
|
||||
SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
|
||||
case Token::MINUS:
|
||||
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
|
||||
b.fRight->fType.kind() == Type::kMatrix_Kind) {
|
||||
SkASSERT(b.fLeft->fType == b.fRight->fType);
|
||||
return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
|
||||
if (leftType.kind() == Type::kMatrix_Kind &&
|
||||
rightType.kind() == Type::kMatrix_Kind) {
|
||||
SkASSERT(leftType == rightType);
|
||||
return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
|
||||
SpvOpFSub, SpvOpISub, out);
|
||||
}
|
||||
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
|
||||
SpvOpISub, SpvOpISub, SpvOpUndef, out);
|
||||
case Token::STAR:
|
||||
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
|
||||
b.fRight->fType.kind() == Type::kMatrix_Kind) {
|
||||
if (leftType.kind() == Type::kMatrix_Kind &&
|
||||
rightType.kind() == Type::kMatrix_Kind) {
|
||||
// matrix multiply
|
||||
SpvId result = this->nextId();
|
||||
this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
|
||||
@ -2229,114 +2218,48 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
|
||||
case Token::BITWISEXOR:
|
||||
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
|
||||
SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
|
||||
case Token::PLUSEQ: {
|
||||
SpvId result;
|
||||
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
|
||||
b.fRight->fType.kind() == Type::kMatrix_Kind) {
|
||||
SkASSERT(b.fLeft->fType == b.fRight->fType);
|
||||
result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
|
||||
SpvOpFAdd, SpvOpIAdd, out);
|
||||
}
|
||||
else {
|
||||
result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
|
||||
SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
|
||||
}
|
||||
SkASSERT(lvalue);
|
||||
lvalue->store(result, out);
|
||||
return result;
|
||||
}
|
||||
case Token::MINUSEQ: {
|
||||
SpvId result;
|
||||
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
|
||||
b.fRight->fType.kind() == Type::kMatrix_Kind) {
|
||||
SkASSERT(b.fLeft->fType == b.fRight->fType);
|
||||
result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
|
||||
SpvOpFSub, SpvOpISub, out);
|
||||
}
|
||||
else {
|
||||
result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
|
||||
SpvOpISub, SpvOpISub, SpvOpUndef, out);
|
||||
}
|
||||
SkASSERT(lvalue);
|
||||
lvalue->store(result, out);
|
||||
return result;
|
||||
}
|
||||
case Token::STAREQ: {
|
||||
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
|
||||
b.fRight->fType.kind() == Type::kMatrix_Kind) {
|
||||
// matrix multiply
|
||||
SpvId result = this->nextId();
|
||||
this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
|
||||
lhs, rhs, out);
|
||||
SkASSERT(lvalue);
|
||||
lvalue->store(result, out);
|
||||
return result;
|
||||
}
|
||||
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
|
||||
SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
|
||||
SkASSERT(lvalue);
|
||||
lvalue->store(result, out);
|
||||
return result;
|
||||
}
|
||||
case Token::SLASHEQ: {
|
||||
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
|
||||
SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
|
||||
SkASSERT(lvalue);
|
||||
lvalue->store(result, out);
|
||||
return result;
|
||||
}
|
||||
case Token::PERCENTEQ: {
|
||||
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
|
||||
SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
|
||||
SkASSERT(lvalue);
|
||||
lvalue->store(result, out);
|
||||
return result;
|
||||
}
|
||||
case Token::SHLEQ: {
|
||||
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
|
||||
SpvOpUndef, SpvOpShiftLeftLogical,
|
||||
SpvOpShiftLeftLogical, SpvOpUndef, out);
|
||||
SkASSERT(lvalue);
|
||||
lvalue->store(result, out);
|
||||
return result;
|
||||
}
|
||||
case Token::SHREQ: {
|
||||
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
|
||||
SpvOpUndef, SpvOpShiftRightArithmetic,
|
||||
SpvOpShiftRightLogical, SpvOpUndef, out);
|
||||
SkASSERT(lvalue);
|
||||
lvalue->store(result, out);
|
||||
return result;
|
||||
}
|
||||
case Token::BITWISEANDEQ: {
|
||||
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
|
||||
SpvOpUndef, SpvOpBitwiseAnd, SpvOpBitwiseAnd,
|
||||
SpvOpUndef, out);
|
||||
SkASSERT(lvalue);
|
||||
lvalue->store(result, out);
|
||||
return result;
|
||||
}
|
||||
case Token::BITWISEOREQ: {
|
||||
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
|
||||
SpvOpUndef, SpvOpBitwiseOr, SpvOpBitwiseOr,
|
||||
SpvOpUndef, out);
|
||||
SkASSERT(lvalue);
|
||||
lvalue->store(result, out);
|
||||
return result;
|
||||
}
|
||||
case Token::BITWISEXOREQ: {
|
||||
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
|
||||
SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor,
|
||||
SpvOpUndef, out);
|
||||
SkASSERT(lvalue);
|
||||
lvalue->store(result, out);
|
||||
return result;
|
||||
}
|
||||
case Token::COMMA:
|
||||
return rhs;
|
||||
default:
|
||||
ABORT("unsupported binary expression: %s", b.description().c_str());
|
||||
SkASSERT(false);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
|
||||
// handle cases where we don't necessarily evaluate both LHS and RHS
|
||||
switch (b.fOperator) {
|
||||
case Token::EQ: {
|
||||
SpvId rhs = this->writeExpression(*b.fRight, out);
|
||||
this->getLValue(*b.fLeft, out)->store(rhs, out);
|
||||
return rhs;
|
||||
}
|
||||
case Token::LOGICALAND:
|
||||
return this->writeLogicalAnd(b, out);
|
||||
case Token::LOGICALOR:
|
||||
return this->writeLogicalOr(b, out);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
std::unique_ptr<LValue> lvalue;
|
||||
SpvId lhs;
|
||||
if (is_assignment(b.fOperator)) {
|
||||
lvalue = this->getLValue(*b.fLeft, out);
|
||||
lhs = lvalue->load(out);
|
||||
} else {
|
||||
lvalue = nullptr;
|
||||
lhs = this->writeExpression(*b.fLeft, out);
|
||||
}
|
||||
SpvId rhs = this->writeExpression(*b.fRight, out);
|
||||
SpvId result = this->writeBinaryExpression(b.fLeft->fType, lhs, remove_assignment(b.fOperator),
|
||||
b.fRight->fType, rhs, b.fType, out);
|
||||
if (lvalue) {
|
||||
lvalue->store(result, out);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
|
||||
SkASSERT(a.fOperator == Token::LOGICALAND);
|
||||
BoolLiteral falseLiteral(fContext, -1, false);
|
||||
@ -2413,17 +2336,6 @@ SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, Out
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
|
||||
if (type.isInteger()) {
|
||||
return std::unique_ptr<Expression>(new IntLiteral(-1, 1, &type));
|
||||
}
|
||||
else if (type.isFloat()) {
|
||||
return std::unique_ptr<Expression>(new FloatLiteral(-1, 1.0, &type));
|
||||
} else {
|
||||
ABORT("math is unsupported on type '%s'", type.name().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
|
||||
if (p.fOperator == Token::MINUS) {
|
||||
SpvId result = this->nextId();
|
||||
|
@ -274,6 +274,10 @@ private:
|
||||
SpvId writeBinaryOperation(const BinaryExpression& expr, SpvOp_ ifFloat, SpvOp_ ifInt,
|
||||
SpvOp_ ifUInt, OutputStream& out);
|
||||
|
||||
SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op,
|
||||
const Type& rightType, SpvId rhs, const Type& resultType,
|
||||
OutputStream& out);
|
||||
|
||||
SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out);
|
||||
|
||||
SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out);
|
||||
|
@ -67,7 +67,7 @@ Token::Kind remove_assignment(Token::Kind op) {
|
||||
case Token::LOGICALOREQ: return Token::LOGICALOR;
|
||||
case Token::LOGICALXOREQ: return Token::LOGICALXOR;
|
||||
case Token::LOGICALANDEQ: return Token::LOGICALAND;
|
||||
default: return Token::INVALID;
|
||||
default: return op;
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user