fixed SPIR-V matrix operations
Bug: skia: Change-Id: I23be824cdd7d00ffd0c54516a168c07e77bb4f49 Reviewed-on: https://skia-review.googlesource.com/140182 Reviewed-by: Greg Daniel <egdaniel@google.com> Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
This commit is contained in:
parent
ed1205ae20
commit
0df21136e3
@ -143,7 +143,7 @@ void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
|
||||
}
|
||||
|
||||
static bool is_float(const Context& context, const Type& type) {
|
||||
if (type.kind() == Type::kVector_Kind) {
|
||||
if (type.columns() > 1) {
|
||||
return is_float(context, type.componentType());
|
||||
}
|
||||
return type == *context.fFloat_Type || type == *context.fHalf_Type ||
|
||||
@ -1822,38 +1822,67 @@ SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op
|
||||
|
||||
SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
|
||||
SpvOp_ floatOperator, SpvOp_ intOperator,
|
||||
SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
|
||||
OutputStream& out) {
|
||||
SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
|
||||
SkASSERT(operandType.kind() == Type::kMatrix_Kind);
|
||||
SpvId rowType = this->getType(operandType.componentType().toCompound(fContext,
|
||||
operandType.columns(),
|
||||
1));
|
||||
SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
|
||||
operandType.rows(),
|
||||
1));
|
||||
SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext,
|
||||
operandType.columns(),
|
||||
operandType.rows(),
|
||||
1));
|
||||
SpvId boolType = this->getType(*fContext.fBool_Type);
|
||||
SpvId result = 0;
|
||||
for (int i = 0; i < operandType.rows(); i++) {
|
||||
SpvId rowL = this->nextId();
|
||||
this->writeInstruction(SpvOpCompositeExtract, rowType, rowL, lhs, 0, out);
|
||||
SpvId rowR = this->nextId();
|
||||
this->writeInstruction(SpvOpCompositeExtract, rowType, rowR, rhs, 0, out);
|
||||
for (int i = 0; i < operandType.columns(); i++) {
|
||||
SpvId columnL = this->nextId();
|
||||
this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
|
||||
SpvId columnR = this->nextId();
|
||||
this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
|
||||
SpvId compare = this->nextId();
|
||||
this->writeInstruction(compareOp, bvecType, compare, rowL, rowR, out);
|
||||
SpvId all = this->nextId();
|
||||
this->writeInstruction(SpvOpAll, boolType, all, compare, out);
|
||||
this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
|
||||
SpvId merge = this->nextId();
|
||||
this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
|
||||
if (result != 0) {
|
||||
SpvId next = this->nextId();
|
||||
this->writeInstruction(SpvOpLogicalAnd, boolType, next, result, all, out);
|
||||
this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
|
||||
result = next;
|
||||
}
|
||||
else {
|
||||
result = all;
|
||||
result = merge;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
|
||||
SpvId rhs, SpvOp_ floatOperator,
|
||||
SpvOp_ intOperator,
|
||||
OutputStream& out) {
|
||||
SpvOp_ op = is_float(fContext, operandType) ? floatOperator : intOperator;
|
||||
SkASSERT(operandType.kind() == Type::kMatrix_Kind);
|
||||
SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
|
||||
operandType.rows(),
|
||||
1));
|
||||
SpvId columns[4];
|
||||
for (int i = 0; i < operandType.columns(); i++) {
|
||||
SpvId columnL = this->nextId();
|
||||
this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
|
||||
SpvId columnR = this->nextId();
|
||||
this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
|
||||
columns[i] = this->nextId();
|
||||
this->writeInstruction(op, columnType, columns[i], columnL, columnR, out);
|
||||
}
|
||||
SpvId result = this->nextId();
|
||||
this->writeOpCode(SpvOpCompositeConstruct, 3 + operandType.columns(), out);
|
||||
this->writeWord(this->getType(operandType), out);
|
||||
this->writeWord(result, out);
|
||||
for (int i = 0; i < operandType.columns(); i++) {
|
||||
this->writeWord(columns[i], out);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
|
||||
// handle cases where we don't necessarily evaluate both LHS and RHS
|
||||
switch (b.fOperator) {
|
||||
@ -1964,7 +1993,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
|
||||
case Token::EQEQ: {
|
||||
if (operandType->kind() == Type::kMatrix_Kind) {
|
||||
return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
|
||||
SpvOpIEqual, out);
|
||||
SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
|
||||
}
|
||||
SkASSERT(resultType == *fContext.fBool_Type);
|
||||
const Type* tmpType;
|
||||
@ -1983,7 +2012,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
|
||||
case Token::NEQ:
|
||||
if (operandType->kind() == Type::kMatrix_Kind) {
|
||||
return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
|
||||
SpvOpINotEqual, out);
|
||||
SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
|
||||
}
|
||||
SkASSERT(resultType == *fContext.fBool_Type);
|
||||
const Type* tmpType;
|
||||
@ -2019,9 +2048,21 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
|
||||
SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
|
||||
SpvOpULessThanEqual, SpvOpUndef, out);
|
||||
case Token::PLUS:
|
||||
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
|
||||
b.fRight->fType.kind() == Type::kMatrix_Kind) {
|
||||
SkASSERT(b.fLeft->fType == b.fRight->fType);
|
||||
return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
|
||||
SpvOpFAdd, SpvOpIAdd, out);
|
||||
}
|
||||
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
|
||||
SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
|
||||
case Token::MINUS:
|
||||
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
|
||||
b.fRight->fType.kind() == Type::kMatrix_Kind) {
|
||||
SkASSERT(b.fLeft->fType == b.fRight->fType);
|
||||
return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
|
||||
SpvOpFSub, SpvOpISub, out);
|
||||
}
|
||||
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
|
||||
SpvOpISub, SpvOpISub, SpvOpUndef, out);
|
||||
case Token::STAR:
|
||||
@ -2059,15 +2100,33 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
|
||||
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
|
||||
SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
|
||||
case Token::PLUSEQ: {
|
||||
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
|
||||
SpvId result;
|
||||
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
|
||||
b.fRight->fType.kind() == Type::kMatrix_Kind) {
|
||||
SkASSERT(b.fLeft->fType == b.fRight->fType);
|
||||
result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
|
||||
SpvOpFAdd, SpvOpIAdd, out);
|
||||
}
|
||||
else {
|
||||
result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
|
||||
SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
|
||||
}
|
||||
SkASSERT(lvalue);
|
||||
lvalue->store(result, out);
|
||||
return result;
|
||||
}
|
||||
case Token::MINUSEQ: {
|
||||
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
|
||||
SpvId result;
|
||||
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
|
||||
b.fRight->fType.kind() == Type::kMatrix_Kind) {
|
||||
SkASSERT(b.fLeft->fType == b.fRight->fType);
|
||||
result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
|
||||
SpvOpFSub, SpvOpISub, out);
|
||||
}
|
||||
else {
|
||||
result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
|
||||
SpvOpISub, SpvOpISub, SpvOpUndef, out);
|
||||
}
|
||||
SkASSERT(lvalue);
|
||||
lvalue->store(result, out);
|
||||
return result;
|
||||
|
@ -211,7 +211,12 @@ private:
|
||||
SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out);
|
||||
|
||||
SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator,
|
||||
SpvOp_ intOperator, OutputStream& out);
|
||||
SpvOp_ intOperator, SpvOp_ vectorMergeOperator,
|
||||
SpvOp_ mergeOperator, OutputStream& out);
|
||||
|
||||
SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs,
|
||||
SpvOp_ floatOperator, SpvOp_ intOperator,
|
||||
OutputStream& out);
|
||||
|
||||
SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
|
||||
SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
|
||||
|
Loading…
Reference in New Issue
Block a user