fixed SPIR-V matrix operations

Bug: skia:
Change-Id: I23be824cdd7d00ffd0c54516a168c07e77bb4f49
Reviewed-on: https://skia-review.googlesource.com/140182
Reviewed-by: Greg Daniel <egdaniel@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
This commit is contained in:
Ethan Nicholas 2018-07-10 09:37:51 -04:00 committed by Skia Commit-Bot
parent ed1205ae20
commit 0df21136e3
2 changed files with 84 additions and 20 deletions

View File

@ -143,7 +143,7 @@ void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
}
static bool is_float(const Context& context, const Type& type) {
if (type.kind() == Type::kVector_Kind) {
if (type.columns() > 1) {
return is_float(context, type.componentType());
}
return type == *context.fFloat_Type || type == *context.fHalf_Type ||
@ -1822,38 +1822,67 @@ SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op
SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
SpvOp_ floatOperator, SpvOp_ intOperator,
SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
OutputStream& out) {
SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
SkASSERT(operandType.kind() == Type::kMatrix_Kind);
SpvId rowType = this->getType(operandType.componentType().toCompound(fContext,
operandType.columns(),
1));
SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
operandType.rows(),
1));
SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext,
operandType.columns(),
operandType.rows(),
1));
SpvId boolType = this->getType(*fContext.fBool_Type);
SpvId result = 0;
for (int i = 0; i < operandType.rows(); i++) {
SpvId rowL = this->nextId();
this->writeInstruction(SpvOpCompositeExtract, rowType, rowL, lhs, 0, out);
SpvId rowR = this->nextId();
this->writeInstruction(SpvOpCompositeExtract, rowType, rowR, rhs, 0, out);
for (int i = 0; i < operandType.columns(); i++) {
SpvId columnL = this->nextId();
this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
SpvId columnR = this->nextId();
this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
SpvId compare = this->nextId();
this->writeInstruction(compareOp, bvecType, compare, rowL, rowR, out);
SpvId all = this->nextId();
this->writeInstruction(SpvOpAll, boolType, all, compare, out);
this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
SpvId merge = this->nextId();
this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
if (result != 0) {
SpvId next = this->nextId();
this->writeInstruction(SpvOpLogicalAnd, boolType, next, result, all, out);
this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
result = next;
}
else {
result = all;
result = merge;
}
}
return result;
}
SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
SpvId rhs, SpvOp_ floatOperator,
SpvOp_ intOperator,
OutputStream& out) {
SpvOp_ op = is_float(fContext, operandType) ? floatOperator : intOperator;
SkASSERT(operandType.kind() == Type::kMatrix_Kind);
SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
operandType.rows(),
1));
SpvId columns[4];
for (int i = 0; i < operandType.columns(); i++) {
SpvId columnL = this->nextId();
this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
SpvId columnR = this->nextId();
this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
columns[i] = this->nextId();
this->writeInstruction(op, columnType, columns[i], columnL, columnR, out);
}
SpvId result = this->nextId();
this->writeOpCode(SpvOpCompositeConstruct, 3 + operandType.columns(), out);
this->writeWord(this->getType(operandType), out);
this->writeWord(result, out);
for (int i = 0; i < operandType.columns(); i++) {
this->writeWord(columns[i], out);
}
return result;
}
SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
// handle cases where we don't necessarily evaluate both LHS and RHS
switch (b.fOperator) {
@ -1964,7 +1993,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
case Token::EQEQ: {
if (operandType->kind() == Type::kMatrix_Kind) {
return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
SpvOpIEqual, out);
SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
}
SkASSERT(resultType == *fContext.fBool_Type);
const Type* tmpType;
@ -1983,7 +2012,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
case Token::NEQ:
if (operandType->kind() == Type::kMatrix_Kind) {
return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
SpvOpINotEqual, out);
SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
}
SkASSERT(resultType == *fContext.fBool_Type);
const Type* tmpType;
@ -2019,9 +2048,21 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
SpvOpULessThanEqual, SpvOpUndef, out);
case Token::PLUS:
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
b.fRight->fType.kind() == Type::kMatrix_Kind) {
SkASSERT(b.fLeft->fType == b.fRight->fType);
return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
SpvOpFAdd, SpvOpIAdd, out);
}
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
case Token::MINUS:
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
b.fRight->fType.kind() == Type::kMatrix_Kind) {
SkASSERT(b.fLeft->fType == b.fRight->fType);
return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
SpvOpFSub, SpvOpISub, out);
}
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
SpvOpISub, SpvOpISub, SpvOpUndef, out);
case Token::STAR:
@ -2059,15 +2100,33 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
case Token::PLUSEQ: {
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
SpvId result;
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
b.fRight->fType.kind() == Type::kMatrix_Kind) {
SkASSERT(b.fLeft->fType == b.fRight->fType);
result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
SpvOpFAdd, SpvOpIAdd, out);
}
else {
result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
}
SkASSERT(lvalue);
lvalue->store(result, out);
return result;
}
case Token::MINUSEQ: {
SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
SpvId result;
if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
b.fRight->fType.kind() == Type::kMatrix_Kind) {
SkASSERT(b.fLeft->fType == b.fRight->fType);
result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
SpvOpFSub, SpvOpISub, out);
}
else {
result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
SpvOpISub, SpvOpISub, SpvOpUndef, out);
}
SkASSERT(lvalue);
lvalue->store(result, out);
return result;

View File

@ -211,7 +211,12 @@ private:
SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out);
SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator,
SpvOp_ intOperator, OutputStream& out);
SpvOp_ intOperator, SpvOp_ vectorMergeOperator,
SpvOp_ mergeOperator, OutputStream& out);
SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs,
SpvOp_ floatOperator, SpvOp_ intOperator,
OutputStream& out);
SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,