[sksl][wgsl] Support WGSL in SkSLMemoryLayout

- SkSLMemoryLayout now handles WGSL uniform and storage address space
  memory alignment and size calculations.
- MemoryLayout's type-support check (previously the static LayoutIsSupported)
  is now an instance method, isSupported, which checks for type support based
  on the specified language standard.
- Unit tests have been added for WGSL memory layout types.

Bug: skia:13092
Change-Id: I2fe161998a0b551ca7c3c1d118b6a8cc1ae95a5b
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/545398
Reviewed-by: John Stiles <johnstiles@google.com>
Commit-Queue: Arman Uguray <armansito@google.com>
Arman Uguray 2022-05-31 20:56:16 -07:00 committed by SkCQ
parent 81737585ff
commit 02031e67d0
5 changed files with 688 additions and 49 deletions
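As a quick orientation before the diff, the sketch below shows how the new WGSL standards are intended to be queried. It is an editorial illustration only, modeled on the unit tests further down; the function name, include paths, and the pre-built SkSL::Context are assumptions, not part of this change.

#include <cstddef>
#include "src/sksl/SkSLContext.h"       // assumed include paths
#include "src/sksl/SkSLMemoryLayout.h"

// Given an SkSL::Context set up as in the tests below, query sizes/alignments
// under the two new WGSL standards.
static void sketchWGSLLayoutQueries(const SkSL::Context& context) {
    SkSL::MemoryLayout uniforms(SkSL::MemoryLayout::Standard::kWGSLUniform);
    SkSL::MemoryLayout storage(SkSL::MemoryLayout::Standard::kWGSLStorage);

    // Per the WGSL alignment/size table: vec3<f32> occupies 12 bytes, aligned to 16.
    size_t v3Size  = uniforms.size(*context.fTypes.fFloat3);       // 12
    size_t v3Align = uniforms.alignment(*context.fTypes.fFloat3);  // 16

    // bool is not host-shareable in WGSL, so both layouts reject it.
    bool boolOk = storage.isSupported(*context.fTypes.fBool);      // false

    (void)v3Size; (void)v3Align; (void)boolOk;
}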

View File

@@ -16,32 +16,54 @@ namespace SkSL {
class MemoryLayout {
public:
enum Standard {
k140_Standard,
k430_Standard,
kMetal_Standard
enum class Standard {
// GLSL std140 layout as described in OpenGL Spec v4.5, 7.6.2.2.
k140,
// GLSL std430 layout. This layout is like std140 but with optimizations. This layout can
// ONLY be used with shader storage blocks.
k430,
// MSL memory layout.
kMetal,
// WebGPU Shading Language buffer layout constraints for the uniform address space.
kWGSLUniform,
// WebGPU Shading Language buffer layout constraints for the storage address space.
kWGSLStorage,
};
MemoryLayout(Standard std)
: fStd(std) {}
static size_t vector_alignment(size_t componentSize, int columns) {
return componentSize * (columns + columns % 2);
}
bool isWGSL() const { return fStd == Standard::kWGSLUniform || fStd == Standard::kWGSLStorage; }
bool isMetal() const { return fStd == Standard::kMetal; }
/**
* WGSL and std140 require various types of variables (structs, arrays, and matrices) in the
* uniform address space to be rounded up to the nearest multiple of 16. This function performs
* the rounding depending on the given `type` and the current memory layout standard.
*
* (For WGSL, see https://www.w3.org/TR/WGSL/#address-space-layout-constraints).
*/
size_t roundUpIfNeeded(size_t raw, Type::TypeKind type) const {
if (fStd == Standard::k140) {
return roundUp16(raw);
}
// WGSL uniform matrix layout is simply the alignment of the matrix columns and
// doesn't have a 16-byte multiple alignment constraint.
if (fStd == Standard::kWGSLUniform && type != Type::TypeKind::kMatrix) {
return roundUp16(raw);
}
return raw;
}
/**
* Rounds up to the nearest multiple of 16 if in std140, otherwise returns the parameter
* unchanged (std140 requires various things to be rounded up to the nearest multiple of 16,
* std430 does not).
* Rounds up the integer `n` to the smallest multiple of 16 greater than or equal to `n`.
*/
size_t roundUpIfNeeded(size_t raw) const {
switch (fStd) {
case k140_Standard: return (raw + 15) & ~15;
case k430_Standard: return raw;
case kMetal_Standard: return raw;
}
SkUNREACHABLE;
}
size_t roundUp16(size_t n) const { return (n + 15) & ~15; }
/**
* Returns a type's required alignment when used as a standalone variable.
@@ -52,12 +74,14 @@ public:
case Type::TypeKind::kScalar:
return this->size(type);
case Type::TypeKind::kVector:
return vector_alignment(this->size(type.componentType()), type.columns());
return GetVectorAlignment(this->size(type.componentType()), type.columns());
case Type::TypeKind::kMatrix:
return this->roundUpIfNeeded(vector_alignment(this->size(type.componentType()),
type.rows()));
return this->roundUpIfNeeded(
GetVectorAlignment(this->size(type.componentType()), type.rows()),
type.typeKind());
case Type::TypeKind::kArray:
return this->roundUpIfNeeded(this->alignment(type.componentType()));
return this->roundUpIfNeeded(this->alignment(type.componentType()),
type.typeKind());
case Type::TypeKind::kStruct: {
size_t result = 0;
for (const auto& f : type.fields()) {
@@ -66,10 +90,10 @@ public:
result = alignment;
}
}
return this->roundUpIfNeeded(result);
return this->roundUpIfNeeded(result, type.typeKind());
}
default:
SK_ABORT("cannot determine size of type %s", type.displayName().c_str());
SK_ABORT("cannot determine alignment of type %s", type.displayName().c_str());
}
}
@@ -79,17 +103,15 @@ public:
*/
size_t stride(const Type& type) const {
switch (type.typeKind()) {
case Type::TypeKind::kMatrix: {
size_t base = vector_alignment(this->size(type.componentType()), type.rows());
return this->roundUpIfNeeded(base);
}
case Type::TypeKind::kMatrix:
return this->alignment(type);
case Type::TypeKind::kArray: {
int stride = this->size(type.componentType());
if (stride > 0) {
int align = this->alignment(type.componentType());
stride += align - 1;
stride -= stride % align;
stride = this->roundUpIfNeeded(stride);
stride = this->roundUpIfNeeded(stride, type.typeKind());
}
return stride;
}
@@ -99,20 +121,24 @@ public:
}
/**
* Returns the size of a type in bytes.
* Returns the size of a type in bytes. Returns 0 if the given type is not supported.
*/
size_t size(const Type& type) const {
switch (type.typeKind()) {
case Type::TypeKind::kScalar:
if (type.isBoolean()) {
if (this->isWGSL()) {
return 0;
}
return 1;
}
if (fStd == kMetal_Standard && !type.highPrecision() && type.isNumber()) {
if ((this->isMetal() || this->isWGSL()) && !type.highPrecision() &&
type.isNumber()) {
return 2;
}
return 4;
case Type::TypeKind::kVector:
if (fStd == kMetal_Standard && type.columns() == 3) {
if (this->isMetal() && type.columns() == 3) {
return 4 * this->size(type.componentType());
}
return type.columns() * this->size(type.componentType());
@@ -142,26 +168,34 @@ public:
/**
* Not all types are compatible with memory layout.
*/
static size_t LayoutIsSupported(const Type& type) {
size_t isSupported(const Type& type) const {
switch (type.typeKind()) {
case Type::TypeKind::kScalar:
// bool and short are not host-shareable in WGSL.
return !this->isWGSL() ||
(!type.isBoolean() && (type.isFloat() || type.highPrecision()));
case Type::TypeKind::kVector:
case Type::TypeKind::kMatrix:
return true;
case Type::TypeKind::kArray:
return LayoutIsSupported(type.componentType());
return this->isSupported(type.componentType());
case Type::TypeKind::kStruct:
return std::all_of(
type.fields().begin(), type.fields().end(),
[](const Type::Field& f) { return LayoutIsSupported(*f.fType); });
type.fields().begin(), type.fields().end(), [this](const Type::Field& f) {
return this->isSupported(*f.fType);
});
default:
return false;
}
}
private:
static size_t GetVectorAlignment(size_t componentSize, int columns) {
return componentSize * (columns + columns % 2);
}
const Standard fStd;
};
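To make the rounding rule documented in roundUpIfNeeded() above concrete, here is a small standalone sketch of the same arithmetic. The helper re-states roundUp16() for illustration and is not part of the change; the expected values come from the WGSL tests further down.

#include <cstddef>

// Same arithmetic as MemoryLayout::roundUp16() above, re-stated for illustration.
constexpr size_t roundUp16(size_t n) { return (n + 15) & ~static_cast<size_t>(15); }

// std140 and the WGSL uniform address space round array strides up to 16 bytes:
static_assert(roundUp16(4) == 16, "array<f32> stride: 4 in storage, 16 in uniform");
static_assert(roundUp16(2) == 16, "array<f16> stride: 2 in storage, 16 in uniform");
static_assert(roundUp16(16) == 16, "array<vec4<f32>> stride is already 16 and stays 16");
// WGSL uniform matrices are the exception: roundUpIfNeeded() skips kMatrix, so
// mat3x3<f16> keeps its column alignment of 8 instead of being bumped to 16.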

View File

@@ -2086,12 +2086,12 @@ void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
void MetalCodeGenerator::writeFields(const std::vector<Type::Field>& fields, Position parentPos,
const InterfaceBlock* parentIntf) {
MemoryLayout memoryLayout(MemoryLayout::kMetal_Standard);
MemoryLayout memoryLayout(MemoryLayout::Standard::kMetal);
int currentOffset = 0;
for (const Type::Field& field : fields) {
int fieldOffset = field.fModifiers.fLayout.fOffset;
const Type* fieldType = field.fType;
if (!MemoryLayout::LayoutIsSupported(*fieldType)) {
if (!memoryLayout.isSupported(*fieldType)) {
fContext.fErrors->error(parentPos, "type '" + std::string(fieldType->name()) +
"' is not permitted here");
return;

View File

@@ -967,7 +967,7 @@ SpvId SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memo
size_t offset = 0;
for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
const Type::Field& field = type.fields()[i];
if (!MemoryLayout::LayoutIsSupported(*field.fType)) {
if (!memoryLayout.isSupported(*field.fType)) {
fContext.fErrors->error(type.fPosition, "type '" + field.fType->displayName() +
"' is not permitted here");
return resultId;
@@ -1070,7 +1070,7 @@ SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layou
fConstantBuffer);
}
case Type::TypeKind::kArray: {
if (!MemoryLayout::LayoutIsSupported(*type)) {
if (!layout.isSupported(*type)) {
fContext.fErrors->error(type->fPosition, "type '" + type->displayName() +
"' is not permitted here");
return NA;
@@ -3325,7 +3325,7 @@ void SPIRVCodeGenerator::writeFieldLayout(const Layout& layout, SpvId target, in
MemoryLayout SPIRVCodeGenerator::memoryLayoutForVariable(const Variable& v) const {
bool pushConstant = ((v.modifiers().fLayout.fFlags & Layout::kPushConstant_Flag) != 0);
return pushConstant ? MemoryLayout(MemoryLayout::k430_Standard) : fDefaultLayout;
return pushConstant ? MemoryLayout(MemoryLayout::Standard::k430) : fDefaultLayout;
}
SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip) {
@@ -3333,7 +3333,7 @@ SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf, bool a
SpvId result = this->nextId(nullptr);
const Variable& intfVar = intf.variable();
const Type& type = intfVar.type();
if (!MemoryLayout::LayoutIsSupported(type)) {
if (!memoryLayout.isSupported(type)) {
fContext.fErrors->error(type.fPosition, "type '" + type.displayName() +
"' is not permitted here");
return this->nextId(nullptr);

View File

@@ -102,11 +102,9 @@ public:
virtual void store(SpvId value, OutputStream& out) = 0;
};
SPIRVCodeGenerator(const Context* context,
const Program* program,
OutputStream* out)
SPIRVCodeGenerator(const Context* context, const Program* program, OutputStream* out)
: INHERITED(context, program, out)
, fDefaultLayout(MemoryLayout::k140_Standard)
, fDefaultLayout(MemoryLayout::Standard::k140)
, fCapabilities(0)
, fIdCount(1)
, fCurrentBlock(0)

View File

@@ -26,7 +26,7 @@ DEF_TEST(SkSLMemoryLayout140Test, r) {
SkSL::ShaderCaps caps;
SkSL::Mangler mangler;
SkSL::Context context(errors, caps, mangler);
SkSL::MemoryLayout layout(SkSL::MemoryLayout::k140_Standard);
SkSL::MemoryLayout layout(SkSL::MemoryLayout::Standard::k140);
// basic types
REPORTER_ASSERT(r, 4 == layout.size(*context.fTypes.fFloat));
@@ -122,7 +122,7 @@ DEF_TEST(SkSLMemoryLayout430Test, r) {
SkSL::ShaderCaps caps;
SkSL::Mangler mangler;
SkSL::Context context(errors, caps, mangler);
SkSL::MemoryLayout layout(SkSL::MemoryLayout::k430_Standard);
SkSL::MemoryLayout layout(SkSL::MemoryLayout::Standard::k430);
// basic types
REPORTER_ASSERT(r, 4 == layout.size(*context.fTypes.fFloat));
@@ -212,3 +212,610 @@ DEF_TEST(SkSLMemoryLayout430Test, r) {
REPORTER_ASSERT(r, 16 == layout.alignment(*array2));
REPORTER_ASSERT(r, 16 == layout.stride(*array2));
}
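The per-type expectations in the WGSL tests that follow all derive from the vector-alignment rule in the header diff above: GetVectorAlignment() multiplies the component size by the column count, padding 3-column vectors to 4. A small sketch of that arithmetic, restated here for illustration only:

#include <cstddef>

// Mirrors MemoryLayout::GetVectorAlignment() from the header diff above.
constexpr size_t vectorAlignment(size_t componentSize, int columns) {
    return componentSize * (columns + columns % 2);
}

// f32 (4-byte component) vectors: vec2 -> 8, vec3 -> 16 (padded to 4 columns), vec4 -> 16.
static_assert(vectorAlignment(4, 2) == 8,  "");
static_assert(vectorAlignment(4, 3) == 16, "");
static_assert(vectorAlignment(4, 4) == 16, "");
// f16 (2-byte component) vectors: vec2 -> 4, vec3 -> 8, vec4 -> 8.
static_assert(vectorAlignment(2, 3) == 8,  "");
// A matrix aligns (and strides) like its column vector, e.g. mat3x3<f32> -> 16.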
DEF_TEST(SkSLMemoryLayoutWGSLUniformTest, r) {
SkSL::TestingOnly_AbortErrorReporter errors;
SkSL::ShaderCaps caps;
SkSL::Mangler mangler;
SkSL::Context context(errors, caps, mangler);
SkSL::MemoryLayout layout(SkSL::MemoryLayout::Standard::kWGSLUniform);
// The values here are taken from https://www.w3.org/TR/WGSL/#alignment-and-size, table titled
// "Alignment and size for host-shareable types".
// scalars (i32, u32, f32, f16)
REPORTER_ASSERT(r, 4 == layout.size(*context.fTypes.fInt));
REPORTER_ASSERT(r, 4 == layout.size(*context.fTypes.fUInt));
REPORTER_ASSERT(r, 4 == layout.size(*context.fTypes.fFloat));
REPORTER_ASSERT(r, 2 == layout.size(*context.fTypes.fHalf));
REPORTER_ASSERT(r, 4 == layout.alignment(*context.fTypes.fInt));
REPORTER_ASSERT(r, 4 == layout.alignment(*context.fTypes.fUInt));
REPORTER_ASSERT(r, 4 == layout.alignment(*context.fTypes.fFloat));
REPORTER_ASSERT(r, 2 == layout.alignment(*context.fTypes.fHalf));
// vec2<T>, T: i32, u32, f32, f16
REPORTER_ASSERT(r, 8 == layout.size(*context.fTypes.fInt2));
REPORTER_ASSERT(r, 8 == layout.size(*context.fTypes.fUInt2));
REPORTER_ASSERT(r, 8 == layout.size(*context.fTypes.fFloat2));
REPORTER_ASSERT(r, 4 == layout.size(*context.fTypes.fHalf2));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fInt2));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fUInt2));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fFloat2));
REPORTER_ASSERT(r, 4 == layout.alignment(*context.fTypes.fHalf2));
// vec3<T>, T: i32, u32, f32, f16
REPORTER_ASSERT(r, 12 == layout.size(*context.fTypes.fInt3));
REPORTER_ASSERT(r, 12 == layout.size(*context.fTypes.fUInt3));
REPORTER_ASSERT(r, 12 == layout.size(*context.fTypes.fFloat3));
REPORTER_ASSERT(r, 6 == layout.size(*context.fTypes.fHalf3));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fInt3));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fUInt3));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat3));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf3));
// vec4<T>, T: i32, u32, f32, f16
REPORTER_ASSERT(r, 16 == layout.size(*context.fTypes.fInt4));
REPORTER_ASSERT(r, 16 == layout.size(*context.fTypes.fUInt4));
REPORTER_ASSERT(r, 16 == layout.size(*context.fTypes.fFloat4));
REPORTER_ASSERT(r, 8 == layout.size(*context.fTypes.fHalf4));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fInt4));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fUInt4));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat4));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf4));
// mat2x2<f32>, mat2x2<f16>
REPORTER_ASSERT(r, 16 == layout.size(*context.fTypes.fFloat2x2));
REPORTER_ASSERT(r, 8 == layout.size(*context.fTypes.fHalf2x2));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fFloat2x2));
REPORTER_ASSERT(r, 4 == layout.alignment(*context.fTypes.fHalf2x2));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fFloat2x2));
REPORTER_ASSERT(r, 4 == layout.stride(*context.fTypes.fHalf2x2));
// mat3x2<f32>, mat3x2<f16>
REPORTER_ASSERT(r, 24 == layout.size(*context.fTypes.fFloat3x2));
REPORTER_ASSERT(r, 12 == layout.size(*context.fTypes.fHalf3x2));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fFloat3x2));
REPORTER_ASSERT(r, 4 == layout.alignment(*context.fTypes.fHalf3x2));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fFloat3x2));
REPORTER_ASSERT(r, 4 == layout.stride(*context.fTypes.fHalf3x2));
// mat4x2<f32>, mat4x2<f16>
REPORTER_ASSERT(r, 32 == layout.size(*context.fTypes.fFloat4x2));
REPORTER_ASSERT(r, 16 == layout.size(*context.fTypes.fHalf4x2));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fFloat4x2));
REPORTER_ASSERT(r, 4 == layout.alignment(*context.fTypes.fHalf4x2));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fFloat4x2));
REPORTER_ASSERT(r, 4 == layout.stride(*context.fTypes.fHalf4x2));
// mat2x3<f32>, mat2x3<f16>
REPORTER_ASSERT(r, 32 == layout.size(*context.fTypes.fFloat2x3));
REPORTER_ASSERT(r, 16 == layout.size(*context.fTypes.fHalf2x3));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat2x3));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf2x3));
REPORTER_ASSERT(r, 16 == layout.stride(*context.fTypes.fFloat2x3));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fHalf2x3));
// mat3x3<f32>, mat3x3<f16>
REPORTER_ASSERT(r, 48 == layout.size(*context.fTypes.fFloat3x3));
REPORTER_ASSERT(r, 24 == layout.size(*context.fTypes.fHalf3x3));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat3x3));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf3x3));
REPORTER_ASSERT(r, 16 == layout.stride(*context.fTypes.fFloat3x3));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fHalf3x3));
// mat4x3<f32>, mat4x3<f16>
REPORTER_ASSERT(r, 64 == layout.size(*context.fTypes.fFloat4x3));
REPORTER_ASSERT(r, 32 == layout.size(*context.fTypes.fHalf4x3));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat4x3));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf4x3));
REPORTER_ASSERT(r, 16 == layout.stride(*context.fTypes.fFloat4x3));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fHalf4x3));
// mat2x4<f32>, mat2x4<f16>
REPORTER_ASSERT(r, 32 == layout.size(*context.fTypes.fFloat2x4));
REPORTER_ASSERT(r, 16 == layout.size(*context.fTypes.fHalf2x4));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat2x4));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf2x4));
REPORTER_ASSERT(r, 16 == layout.stride(*context.fTypes.fFloat2x4));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fHalf2x4));
// mat3x4<f32>, mat3x4<f16>
REPORTER_ASSERT(r, 48 == layout.size(*context.fTypes.fFloat3x4));
REPORTER_ASSERT(r, 24 == layout.size(*context.fTypes.fHalf3x4));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat3x4));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf3x4));
REPORTER_ASSERT(r, 16 == layout.stride(*context.fTypes.fFloat3x4));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fHalf3x4));
// mat4x4<f32>, mat4x4<f16>
REPORTER_ASSERT(r, 64 == layout.size(*context.fTypes.fFloat4x4));
REPORTER_ASSERT(r, 32 == layout.size(*context.fTypes.fHalf4x4));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat4x4));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf4x4));
REPORTER_ASSERT(r, 16 == layout.stride(*context.fTypes.fFloat4x4));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fHalf4x4));
// bool is not a host-shareable type and returns 0 for WGSL.
REPORTER_ASSERT(r, 0 == layout.size(*context.fTypes.fBool));
REPORTER_ASSERT(r, 0 == layout.size(*context.fTypes.fBool2));
REPORTER_ASSERT(r, 0 == layout.size(*context.fTypes.fBool3));
REPORTER_ASSERT(r, 0 == layout.size(*context.fTypes.fBool4));
// Arrays
// array<f32, 4>
{
auto array = SkSL::Type::MakeArrayType("float[4]", *context.fTypes.fFloat, 4);
REPORTER_ASSERT(r, 64 == layout.size(*array));
REPORTER_ASSERT(r, 16 == layout.alignment(*array));
REPORTER_ASSERT(r, 16 == layout.stride(*array));
}
// array<f16, 4>
{
auto array = SkSL::Type::MakeArrayType("half[4]", *context.fTypes.fHalf, 4);
REPORTER_ASSERT(r, 64 == layout.size(*array));
REPORTER_ASSERT(r, 16 == layout.alignment(*array));
REPORTER_ASSERT(r, 16 == layout.stride(*array));
}
// array<vec2<f32>, 4>
{
auto array = SkSL::Type::MakeArrayType("float2[4]", *context.fTypes.fFloat2, 4);
REPORTER_ASSERT(r, 64 == layout.size(*array));
REPORTER_ASSERT(r, 16 == layout.alignment(*array));
REPORTER_ASSERT(r, 16 == layout.stride(*array));
}
// array<vec3<f32>, 4>
{
auto array = SkSL::Type::MakeArrayType("float3[4]", *context.fTypes.fFloat3, 4);
REPORTER_ASSERT(r, 64 == layout.size(*array));
REPORTER_ASSERT(r, 16 == layout.alignment(*array));
REPORTER_ASSERT(r, 16 == layout.stride(*array));
}
// array<vec4<f32>, 4>
{
auto array = SkSL::Type::MakeArrayType("float4[4]", *context.fTypes.fFloat4, 4);
REPORTER_ASSERT(r, 64 == layout.size(*array));
REPORTER_ASSERT(r, 16 == layout.alignment(*array));
REPORTER_ASSERT(r, 16 == layout.stride(*array));
}
// array<mat3x3<f32>, 4>
{
auto array = SkSL::Type::MakeArrayType("mat3[4]", *context.fTypes.fFloat3x3, 4);
REPORTER_ASSERT(r, 192 == layout.size(*array));
REPORTER_ASSERT(r, 16 == layout.alignment(*array));
REPORTER_ASSERT(r, 48 == layout.stride(*array));
}
// Structs A and B from example in https://www.w3.org/TR/WGSL/#structure-member-layout, with
// offsets adjusted for uniform address space constraints.
//
// struct A { // align(roundUp(16, 8)) size(roundUp(16, 24))
// u: f32, // offset(0) align(4) size(4)
// v: f32, // offset(4) align(4) size(4)
// w: vec2<f32>, // offset(8) align(8) size(8)
// x: f32 // offset(16) align(4) size(4)
// // padding // offset(20) size(12)
// }
std::vector<SkSL::Type::Field> fields;
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("u"),
context.fTypes.fFloat.get());
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("v"),
context.fTypes.fFloat.get());
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("v"),
context.fTypes.fFloat2.get());
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("w"),
context.fTypes.fFloat.get());
std::unique_ptr<SkSL::Type> structA =
SkSL::Type::MakeStructType(SkSL::Position(), std::string_view("A"), std::move(fields));
REPORTER_ASSERT(r, 32 == layout.size(*structA));
REPORTER_ASSERT(r, 16 == layout.alignment(*structA));
fields = {};
// struct B { // align(16) size(208)
// a: vec2<f32>, // offset(0) align(8) size(8)
// // padding // offset(8) size(8)
// b: vec3<f32>, // offset(16) align(16) size(12)
// c: f32, // offset(28) align(4) size(4)
// d: f32, // offset(32) align(4) size(4)
// // padding // offset(36) size(12)
// e: A, // offset(48) align(16) size(32)
// f: vec3<f32>, // offset(80) align(16) size(12)
// // padding // offset(92) size(4)
// g: array<A, 3>, // offset(96) align(16) size(96)
// h: i32 // offset(192) align(4) size(4)
// // padding // offset(196) size(12)
// }
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("a"),
context.fTypes.fFloat2.get());
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("b"),
context.fTypes.fFloat3.get());
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("c"),
context.fTypes.fFloat.get());
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("d"),
context.fTypes.fFloat.get());
fields.emplace_back(SkSL::Position(), SkSL::Modifiers(), std::string_view("e"), structA.get());
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("f"),
context.fTypes.fFloat3.get());
auto array = SkSL::Type::MakeArrayType("A[3]", *structA, 3);
fields.emplace_back(SkSL::Position(), SkSL::Modifiers(), std::string_view("g"), array.get());
fields.emplace_back(
SkSL::Position(), SkSL::Modifiers(), std::string_view("h"), context.fTypes.fInt.get());
std::unique_ptr<SkSL::Type> structB =
SkSL::Type::MakeStructType(SkSL::Position(), std::string_view("B"), std::move(fields));
REPORTER_ASSERT(r, 208 == layout.size(*structB));
REPORTER_ASSERT(r, 16 == layout.alignment(*structB));
}
DEF_TEST(SkSLMemoryLayoutWGSLStorageTest, r) {
SkSL::TestingOnly_AbortErrorReporter errors;
SkSL::ShaderCaps caps;
SkSL::Mangler mangler;
SkSL::Context context(errors, caps, mangler);
SkSL::MemoryLayout layout(SkSL::MemoryLayout::Standard::kWGSLStorage);
// The values here are taken from https://www.w3.org/TR/WGSL/#alignment-and-size, table titled
// "Alignment and size for host-shareable types".
// scalars (i32, u32, f32, f16)
REPORTER_ASSERT(r, 4 == layout.size(*context.fTypes.fInt));
REPORTER_ASSERT(r, 4 == layout.size(*context.fTypes.fUInt));
REPORTER_ASSERT(r, 4 == layout.size(*context.fTypes.fFloat));
REPORTER_ASSERT(r, 2 == layout.size(*context.fTypes.fHalf));
REPORTER_ASSERT(r, 4 == layout.alignment(*context.fTypes.fInt));
REPORTER_ASSERT(r, 4 == layout.alignment(*context.fTypes.fUInt));
REPORTER_ASSERT(r, 4 == layout.alignment(*context.fTypes.fFloat));
REPORTER_ASSERT(r, 2 == layout.alignment(*context.fTypes.fHalf));
// vec2<T>, T: i32, u32, f32, f16
REPORTER_ASSERT(r, 8 == layout.size(*context.fTypes.fInt2));
REPORTER_ASSERT(r, 8 == layout.size(*context.fTypes.fUInt2));
REPORTER_ASSERT(r, 8 == layout.size(*context.fTypes.fFloat2));
REPORTER_ASSERT(r, 4 == layout.size(*context.fTypes.fHalf2));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fInt2));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fUInt2));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fFloat2));
REPORTER_ASSERT(r, 4 == layout.alignment(*context.fTypes.fHalf2));
// vec3<T>, T: i32, u32, f32, f16
REPORTER_ASSERT(r, 12 == layout.size(*context.fTypes.fInt3));
REPORTER_ASSERT(r, 12 == layout.size(*context.fTypes.fUInt3));
REPORTER_ASSERT(r, 12 == layout.size(*context.fTypes.fFloat3));
REPORTER_ASSERT(r, 6 == layout.size(*context.fTypes.fHalf3));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fInt3));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fUInt3));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat3));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf3));
// vec4<T>, T: i32, u32, f32, f16
REPORTER_ASSERT(r, 16 == layout.size(*context.fTypes.fInt4));
REPORTER_ASSERT(r, 16 == layout.size(*context.fTypes.fUInt4));
REPORTER_ASSERT(r, 16 == layout.size(*context.fTypes.fFloat4));
REPORTER_ASSERT(r, 8 == layout.size(*context.fTypes.fHalf4));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fInt4));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fUInt4));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat4));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf4));
// mat2x2<f32>, mat2x2<f16>
REPORTER_ASSERT(r, 16 == layout.size(*context.fTypes.fFloat2x2));
REPORTER_ASSERT(r, 8 == layout.size(*context.fTypes.fHalf2x2));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fFloat2x2));
REPORTER_ASSERT(r, 4 == layout.alignment(*context.fTypes.fHalf2x2));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fFloat2x2));
REPORTER_ASSERT(r, 4 == layout.stride(*context.fTypes.fHalf2x2));
// mat3x2<f32>, mat3x2<f16>
REPORTER_ASSERT(r, 24 == layout.size(*context.fTypes.fFloat3x2));
REPORTER_ASSERT(r, 12 == layout.size(*context.fTypes.fHalf3x2));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fFloat3x2));
REPORTER_ASSERT(r, 4 == layout.alignment(*context.fTypes.fHalf3x2));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fFloat3x2));
REPORTER_ASSERT(r, 4 == layout.stride(*context.fTypes.fHalf3x2));
// mat4x2<f32>, mat4x2<f16>
REPORTER_ASSERT(r, 32 == layout.size(*context.fTypes.fFloat4x2));
REPORTER_ASSERT(r, 16 == layout.size(*context.fTypes.fHalf4x2));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fFloat4x2));
REPORTER_ASSERT(r, 4 == layout.alignment(*context.fTypes.fHalf4x2));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fFloat4x2));
REPORTER_ASSERT(r, 4 == layout.stride(*context.fTypes.fHalf4x2));
// mat2x3<f32>, mat2x3<f16>
REPORTER_ASSERT(r, 32 == layout.size(*context.fTypes.fFloat2x3));
REPORTER_ASSERT(r, 16 == layout.size(*context.fTypes.fHalf2x3));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat2x3));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf2x3));
REPORTER_ASSERT(r, 16 == layout.stride(*context.fTypes.fFloat2x3));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fHalf2x3));
// mat3x3<f32>, mat3x3<f16>
REPORTER_ASSERT(r, 48 == layout.size(*context.fTypes.fFloat3x3));
REPORTER_ASSERT(r, 24 == layout.size(*context.fTypes.fHalf3x3));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat3x3));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf3x3));
REPORTER_ASSERT(r, 16 == layout.stride(*context.fTypes.fFloat3x3));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fHalf3x3));
// mat4x3<f32>, mat4x3<f16>
REPORTER_ASSERT(r, 64 == layout.size(*context.fTypes.fFloat4x3));
REPORTER_ASSERT(r, 32 == layout.size(*context.fTypes.fHalf4x3));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat4x3));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf4x3));
REPORTER_ASSERT(r, 16 == layout.stride(*context.fTypes.fFloat4x3));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fHalf4x3));
// mat2x4<f32>, mat2x4<f16>
REPORTER_ASSERT(r, 32 == layout.size(*context.fTypes.fFloat2x4));
REPORTER_ASSERT(r, 16 == layout.size(*context.fTypes.fHalf2x4));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat2x4));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf2x4));
REPORTER_ASSERT(r, 16 == layout.stride(*context.fTypes.fFloat2x4));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fHalf2x4));
// mat3x4<f32>, mat3x4<f16>
REPORTER_ASSERT(r, 48 == layout.size(*context.fTypes.fFloat3x4));
REPORTER_ASSERT(r, 24 == layout.size(*context.fTypes.fHalf3x4));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat3x4));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf3x4));
REPORTER_ASSERT(r, 16 == layout.stride(*context.fTypes.fFloat3x4));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fHalf3x4));
// mat4x4<f32>, mat4x4<f16>
REPORTER_ASSERT(r, 64 == layout.size(*context.fTypes.fFloat4x4));
REPORTER_ASSERT(r, 32 == layout.size(*context.fTypes.fHalf4x4));
REPORTER_ASSERT(r, 16 == layout.alignment(*context.fTypes.fFloat4x4));
REPORTER_ASSERT(r, 8 == layout.alignment(*context.fTypes.fHalf4x4));
REPORTER_ASSERT(r, 16 == layout.stride(*context.fTypes.fFloat4x4));
REPORTER_ASSERT(r, 8 == layout.stride(*context.fTypes.fHalf4x4));
// bool is not a host-shareable type and returns 0 for WGSL.
REPORTER_ASSERT(r, 0 == layout.size(*context.fTypes.fBool));
REPORTER_ASSERT(r, 0 == layout.size(*context.fTypes.fBool2));
REPORTER_ASSERT(r, 0 == layout.size(*context.fTypes.fBool3));
REPORTER_ASSERT(r, 0 == layout.size(*context.fTypes.fBool4));
// Arrays
// array<f32, 4>
{
auto array = SkSL::Type::MakeArrayType("float[4]", *context.fTypes.fFloat, 4);
REPORTER_ASSERT(r, 16 == layout.size(*array));
REPORTER_ASSERT(r, 4 == layout.alignment(*array));
REPORTER_ASSERT(r, 4 == layout.stride(*array));
}
// array<f16, 4>
{
auto array = SkSL::Type::MakeArrayType("half[4]", *context.fTypes.fHalf, 4);
REPORTER_ASSERT(r, 8 == layout.size(*array));
REPORTER_ASSERT(r, 2 == layout.alignment(*array));
REPORTER_ASSERT(r, 2 == layout.stride(*array));
}
// array<vec2<f32>, 4>
{
auto array = SkSL::Type::MakeArrayType("float2[4]", *context.fTypes.fFloat2, 4);
REPORTER_ASSERT(r, 32 == layout.size(*array));
REPORTER_ASSERT(r, 8 == layout.alignment(*array));
REPORTER_ASSERT(r, 8 == layout.stride(*array));
}
// array<vec3<f32>, 4>
{
auto array = SkSL::Type::MakeArrayType("float3[4]", *context.fTypes.fFloat3, 4);
REPORTER_ASSERT(r, 64 == layout.size(*array));
REPORTER_ASSERT(r, 16 == layout.alignment(*array));
REPORTER_ASSERT(r, 16 == layout.stride(*array));
}
// array<vec4<f32>, 4>
{
auto array = SkSL::Type::MakeArrayType("float4[4]", *context.fTypes.fFloat4, 4);
REPORTER_ASSERT(r, 64 == layout.size(*array));
REPORTER_ASSERT(r, 16 == layout.alignment(*array));
REPORTER_ASSERT(r, 16 == layout.stride(*array));
}
// array<mat3x3<f32>, 4>
{
auto array = SkSL::Type::MakeArrayType("mat3[4]", *context.fTypes.fFloat3x3, 4);
REPORTER_ASSERT(r, 192 == layout.size(*array));
REPORTER_ASSERT(r, 16 == layout.alignment(*array));
REPORTER_ASSERT(r, 48 == layout.stride(*array));
}
// Structs A and B from example in https://www.w3.org/TR/WGSL/#structure-member-layout
//
// struct A { // align(8) size(24)
// u: f32, // offset(0) align(4) size(4)
// v: f32, // offset(4) align(4) size(4)
// w: vec2<f32>, // offset(8) align(8) size(8)
// x: f32 // offset(16) align(4) size(4)
// // padding // offset(20) size(4)
// }
std::vector<SkSL::Type::Field> fields;
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("u"),
context.fTypes.fFloat.get());
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("v"),
context.fTypes.fFloat.get());
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("v"),
context.fTypes.fFloat2.get());
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("w"),
context.fTypes.fFloat.get());
std::unique_ptr<SkSL::Type> structA =
SkSL::Type::MakeStructType(SkSL::Position(), std::string_view("A"), std::move(fields));
REPORTER_ASSERT(r, 24 == layout.size(*structA));
REPORTER_ASSERT(r, 8 == layout.alignment(*structA));
fields = {};
// struct B { // align(16) size(160)
// a: vec2<f32>, // offset(0) align(8) size(8)
// // padding // offset(8) size(8)
// b: vec3<f32>, // offset(16) align(16) size(12)
// c: f32, // offset(28) align(4) size(4)
// d: f32, // offset(32) align(4) size(4)
// // padding // offset(36) size(4)
// e: A, // offset(40) align(8) size(24)
// f: vec3<f32>, // offset(64) align(16) size(12)
// // padding // offset(76) size(4)
// g: array<A, 3>, // offset(80) align(16) size(72)
// h: i32 // offset(152) align(4) size(4)
// // padding // offset(156) size(4)
// }
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("a"),
context.fTypes.fFloat2.get());
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("b"),
context.fTypes.fFloat3.get());
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("c"),
context.fTypes.fFloat.get());
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("d"),
context.fTypes.fFloat.get());
fields.emplace_back(SkSL::Position(), SkSL::Modifiers(), std::string_view("e"), structA.get());
fields.emplace_back(SkSL::Position(),
SkSL::Modifiers(),
std::string_view("f"),
context.fTypes.fFloat3.get());
auto array = SkSL::Type::MakeArrayType("A[3]", *structA, 3);
fields.emplace_back(SkSL::Position(), SkSL::Modifiers(), std::string_view("g"), array.get());
fields.emplace_back(
SkSL::Position(), SkSL::Modifiers(), std::string_view("h"), context.fTypes.fInt.get());
std::unique_ptr<SkSL::Type> structB =
SkSL::Type::MakeStructType(SkSL::Position(), std::string_view("B"), std::move(fields));
REPORTER_ASSERT(r, 160 == layout.size(*structB));
REPORTER_ASSERT(r, 16 == layout.alignment(*structB));
}
DEF_TEST(SkSLMemoryLayoutWGSLUnsupportedTypesTest, r) {
SkSL::TestingOnly_AbortErrorReporter errors;
SkSL::ShaderCaps caps;
SkSL::Mangler mangler;
SkSL::Context context(errors, caps, mangler);
auto testArray = SkSL::Type::MakeArrayType("bool[3]", *context.fTypes.fBool, 3);
std::vector<SkSL::Type::Field> fields;
fields.emplace_back(
SkSL::Position(), SkSL::Modifiers(), std::string_view("foo"), testArray.get());
auto testStruct = SkSL::Type::MakeStructType(
SkSL::Position(), std::string_view("Test"), std::move(fields));
SkSL::MemoryLayout layout(SkSL::MemoryLayout::Standard::kWGSLUniform);
REPORTER_ASSERT(r, !layout.isSupported(*context.fTypes.fBool));
REPORTER_ASSERT(r, !layout.isSupported(*context.fTypes.fBool2));
REPORTER_ASSERT(r, !layout.isSupported(*context.fTypes.fBool3));
REPORTER_ASSERT(r, !layout.isSupported(*context.fTypes.fBool4));
REPORTER_ASSERT(r, !layout.isSupported(*context.fTypes.fShort));
REPORTER_ASSERT(r, !layout.isSupported(*context.fTypes.fShort2));
REPORTER_ASSERT(r, !layout.isSupported(*context.fTypes.fShort3));
REPORTER_ASSERT(r, !layout.isSupported(*context.fTypes.fShort4));
REPORTER_ASSERT(r, !layout.isSupported(*testArray));
REPORTER_ASSERT(r, !layout.isSupported(*testStruct));
}
DEF_TEST(SkSLMemoryLayoutWGSLSupportedTypesTest, r) {
SkSL::TestingOnly_AbortErrorReporter errors;
SkSL::ShaderCaps caps;
SkSL::Mangler mangler;
SkSL::Context context(errors, caps, mangler);
auto testArray = SkSL::Type::MakeArrayType("float[3]", *context.fTypes.fFloat, 3);
std::vector<SkSL::Type::Field> fields;
fields.emplace_back(
SkSL::Position(), SkSL::Modifiers(), std::string_view("foo"), testArray.get());
auto testStruct = SkSL::Type::MakeStructType(
SkSL::Position(), std::string_view("Test"), std::move(fields));
SkSL::MemoryLayout layout(SkSL::MemoryLayout::Standard::kWGSLUniform);
// scalars (i32, u32, f32, f16)
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fInt));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fUInt));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fFloat));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fHalf));
// vec2<T>, T: i32, u32, f32, f16
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fInt2));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fUInt2));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fFloat2));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fHalf2));
// vec3<T>, T: i32, u32, f32, f16
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fInt3));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fUInt3));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fFloat3));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fHalf3));
// vec4<T>, T: i32, u32, f32, f16
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fInt4));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fUInt4));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fFloat4));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fHalf4));
// mat2x2<f32>, mat2x2<f16>
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fFloat2x2));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fHalf2x2));
// mat3x2<f32>, mat3x2<f16>
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fFloat3x2));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fHalf3x2));
// mat4x2<f32>, mat4x2<f16>
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fFloat4x2));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fHalf4x2));
// mat2x3<f32>, mat2x3<f16>
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fFloat2x3));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fHalf2x3));
// mat3x3<f32>, mat3x3<f16>
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fFloat3x3));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fHalf3x3));
// mat4x3<f32>, mat4x3<f16>
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fFloat4x3));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fHalf4x3));
// mat2x4<f32>, mat2x4<f16>
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fFloat2x4));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fHalf2x4));
// mat3x4<f32>, mat3x4<f16>
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fFloat3x4));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fHalf3x4));
// mat4x4<f32>, mat4x4<f16>
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fFloat4x4));
REPORTER_ASSERT(r, layout.isSupported(*context.fTypes.fHalf4x4));
// arrays and structs
REPORTER_ASSERT(r, layout.isSupported(*testArray));
REPORTER_ASSERT(r, layout.isSupported(*testStruct));
}