Support loading col/row-major matrices from ByteAddressBuffer.

Hans-Kristian Arntzen 2017-10-26 16:35:18 +02:00
parent 43d178f780
commit 551424ce43
3 changed files with 158 additions and 40 deletions

View File

@@ -638,7 +638,8 @@ struct SPIRAccessChain : IVariant
int32_t static_index;
uint32_t loaded_from = 0;
bool need_transpose = false;
uint32_t matrix_stride = 0;
bool row_major_matrix = false;
bool immutable = false;
};
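Carrying matrix_stride and row_major_matrix on the access chain lets the HLSL backend remember, across OpAccessChain and the later OpLoad, how the selected value is laid out in the buffer. A column selected from a row-major matrix has its components matrix_stride bytes apart, so the load must become a per-component gather. A rough, hypothetical sketch of the emitted HLSL (the buffer name _buf, the offsets, and the uint3/asfloat pairing are illustrative) for column 0 of a row-major float3x3 with a 16-byte matrix stride:

    // One scalar Load per component, stepping by the 16-byte matrix stride.
    float3 col0 = asfloat(uint3(_buf.Load(0), _buf.Load(16), _buf.Load(32)));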

View File

@@ -4226,8 +4226,8 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32
{
if (flattened_buffer_blocks.count(base))
{
uint32_t matrix_stride;
bool need_transpose;
uint32_t matrix_stride = 0;
bool need_transpose = false;
flattened_access_chain_offset(expression_type(base), indices, count, 0, 16, &need_transpose, &matrix_stride);
if (out_need_transpose)
@@ -4462,10 +4462,11 @@ std::pair<std::string, uint32_t> CompilerGLSL::flattened_access_chain_offset(con
assert(type->basetype == SPIRType::Struct);
uint32_t type_id = 0;
uint32_t matrix_stride = 0;
std::string expr;
bool row_major_matrix_needs_conversion = false;
// Inherit matrix information in case we are access chaining a vector which might have come from a row major layout.
bool row_major_matrix_needs_conversion = need_transpose ? *need_transpose : false;
uint32_t matrix_stride = out_matrix_stride ? *out_matrix_stride : 0;
for (uint32_t i = 0; i < count; i++)
{
@@ -4535,11 +4536,29 @@ std::pair<std::string, uint32_t> CompilerGLSL::flattened_access_chain_offset(con
// Matrix -> Vector
else if (type->columns > 1)
{
if (ids[index].get_type() != TypeConstant)
SPIRV_CROSS_THROW("Cannot flatten dynamic matrix indexing!");
auto *constant = maybe_get<SPIRConstant>(index);
if (constant)
{
index = get<SPIRConstant>(index).scalar();
offset += index * (row_major_matrix_needs_conversion ? (type->width / 8) : matrix_stride);
}
else
{
uint32_t indexing_stride = row_major_matrix_needs_conversion ? (type->width / 8) : matrix_stride;
// Dynamic array access.
if (indexing_stride % word_stride)
{
SPIRV_CROSS_THROW(
"Matrix stride for dynamic indexing must be divisible by the size of a 4-component vector. "
"Likely culprit here is a row-major matrix being accessed dynamically. "
"This cannot be flattened. Try using std140 layout instead.");
}
index = get<SPIRConstant>(index).scalar();
offset += index * (row_major_matrix_needs_conversion ? type->width / 8 : matrix_stride);
expr += to_enclosed_expression(index);
expr += " * ";
expr += convert_to_string(indexing_stride / word_stride);
expr += " + ";
}
uint32_t parent_type = type->parent_type;
type = &get<SPIRType>(type->parent_type);
@@ -4548,11 +4567,29 @@ std::pair<std::string, uint32_t> CompilerGLSL::flattened_access_chain_offset(con
// Vector -> Scalar
else if (type->vecsize > 1)
{
if (ids[index].get_type() != TypeConstant)
SPIRV_CROSS_THROW("Cannot flatten dynamic vector indexing!");
auto *constant = maybe_get<SPIRConstant>(index);
if (constant)
{
index = get<SPIRConstant>(index).scalar();
offset += index * (row_major_matrix_needs_conversion ? matrix_stride : (type->width / 8));
}
else
{
uint32_t indexing_stride = row_major_matrix_needs_conversion ? matrix_stride : (type->width / 8);
index = get<SPIRConstant>(index).scalar();
offset += index * (row_major_matrix_needs_conversion ? matrix_stride : type->width / 8);
// Dynamic array access.
if (indexing_stride % word_stride)
{
SPIRV_CROSS_THROW(
"Stride for dynamic vector indexing must be divisible by the size of a 4-component vector. "
"This cannot be flattened in legacy targets.");
}
expr += to_enclosed_expression(index);
expr += " * ";
expr += convert_to_string(indexing_stride / word_stride);
expr += " + ";
}
uint32_t parent_type = type->parent_type;
type = &get<SPIRType>(type->parent_type);
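A non-constant matrix or vector index is now folded into the textual offset expression as index * (indexing_stride / word_stride) + ..., provided the indexing stride is a multiple of the word stride of the target (16 bytes for the vec4-flattened path above, 1 byte for the ByteAddressBuffer path in the next file). A rough, hypothetical sketch of the resulting HLSL for a dynamic column index into a column-major float4x4 (the buffer name _buf, the index idx, and the 64-byte static offset are invented for illustration):

    // idx * 16 comes from the 16-byte matrix stride; 64 is the static part of the offset.
    float4 col = asfloat(_buf.Load4(idx * 16 + 64));

A dynamically indexed row-major matrix in the vec4-flattened path cannot be expressed this way (its 4-byte element stride is not a whole number of vec4 words), which is what the new divisibility check rejects.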

View File

@@ -2192,36 +2192,106 @@ string CompilerHLSL::read_access_chain(const SPIRAccessChain &chain)
target_type.vecsize = type.vecsize;
target_type.columns = type.columns;
// FIXME: Transposition?
if (type.columns != 1)
SPIRV_CROSS_THROW("Reading matrices from ByteAddressBuffer not yet supported.");
if (type.basetype == SPIRType::Struct)
SPIRV_CROSS_THROW("Reading structs from ByteAddressBuffer not yet supported.");
if (type.width != 32)
SPIRV_CROSS_THROW("Reading types other than 32-bit from ByteAddressBuffer not yet supported.");
const char *load_op = nullptr;
switch (type.vecsize)
string load_expr;
// Load a vector or scalar.
if (type.columns == 1 && !chain.row_major_matrix)
{
case 1:
load_op = "Load";
break;
case 2:
load_op = "Load2";
break;
case 3:
load_op = "Load3";
break;
case 4:
load_op = "Load4";
break;
default:
SPIRV_CROSS_THROW("Unknown vector size.");
const char *load_op = nullptr;
switch (type.vecsize)
{
case 1:
load_op = "Load";
break;
case 2:
load_op = "Load2";
break;
case 3:
load_op = "Load3";
break;
case 4:
load_op = "Load4";
break;
default:
SPIRV_CROSS_THROW("Unknown vector size.");
}
load_expr = join(chain.base, ".", load_op, "(", chain.dynamic_index, chain.static_index, ")");
}
else if (type.columns == 1)
{
// Strided load since we are loading a column from a row-major matrix.
load_expr = type_to_glsl(target_type);
load_expr += "(";
for (uint32_t r = 0; r < type.vecsize; r++)
{
load_expr += join(chain.base, ".Load(", chain.dynamic_index, chain.static_index + r * chain.matrix_stride, ")");
if (r + 1 < type.vecsize)
load_expr += ", ";
}
load_expr += ")";
}
else if (!chain.row_major_matrix)
{
// Load a matrix, column-major, the easy case.
const char *load_op = nullptr;
switch (type.vecsize)
{
case 1:
load_op = "Load";
break;
case 2:
load_op = "Load2";
break;
case 3:
load_op = "Load3";
break;
case 4:
load_op = "Load4";
break;
default:
SPIRV_CROSS_THROW("Unknown vector size.");
}
// Note, this loading style in HLSL is *actually* row-major, but we always treat matrices as transposed in this backend,
// so row-major is technically column-major ...
load_expr = type_to_glsl(target_type);
load_expr += "(";
for (uint32_t c = 0; c < type.columns; c++)
{
load_expr += join(chain.base, ".", load_op, "(", chain.dynamic_index, chain.static_index + c * chain.matrix_stride, ")");
if (c + 1 < type.columns)
load_expr += ", ";
}
load_expr += ")";
}
else
{
// Pick out elements one by one ... Hopefully compilers are smart enough to recognize this pattern
// considering HLSL is "row-major decl", but "column-major" memory layout (basically implicit transpose model, ugh) ...
load_expr = type_to_glsl(target_type);
load_expr += "(";
for (uint32_t r = 0; r < type.vecsize; r++)
{
for (uint32_t c = 0; c < type.columns; c++)
{
load_expr += join(chain.base, ".Load(", chain.dynamic_index,
chain.static_index + r * (type.width / 8) + c * chain.matrix_stride, ")");
if ((r + 1 < type.vecsize) || (c + 1 < type.columns))
load_expr += ", ";
}
}
load_expr += ")";
}
auto load_expr = join(chain.base, ".", load_op, "(", chain.dynamic_index, chain.static_index, ")");
auto bitcast_op = bitcast_glsl_op(type, target_type);
if (!bitcast_op.empty())
load_expr = join(bitcast_op, "(", load_expr, ")");
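For whole-matrix reads, the branches above produce roughly the following HLSL (hypothetical buffer _buf, 32-bit floats, offsets chosen for illustration): one LoadN per column when the matrix is column-major, and one scalar Load per element at static_index + r * (width / 8) + c * matrix_stride when it is row-major:

    // Column-major float2x2 at offset 16 with an 8-byte matrix stride: one Load2 per column.
    float2x2 m0 = asfloat(uint2x2(_buf.Load2(16), _buf.Load2(24)));
    // Row-major float2x2 at offset 16 with a 16-byte matrix stride: element (r, c) read from 16 + r * 4 + c * 16.
    float2x2 m1 = asfloat(uint2x2(_buf.Load(16), _buf.Load(32), _buf.Load(20), _buf.Load(36)));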
@@ -2244,7 +2314,7 @@ void CompilerHLSL::emit_load(const Instruction &instruction)
bool forward = should_forward(ptr) && forced_temporaries.find(id) == end(forced_temporaries);
auto &e = emit_op(result_type, id, load_expr, forward, true);
e.need_transpose = false; // TODO: Forward this somehow.
e.need_transpose = false;
register_read(id, ptr, forward);
}
else
@@ -2355,19 +2425,29 @@ void CompilerHLSL::emit_access_chain(const Instruction &instruction)
}
uint32_t matrix_stride = 0;
bool need_transpose = false;
bool row_major_matrix = false;
// Inherit matrix information.
if (chain)
{
matrix_stride = chain->matrix_stride;
row_major_matrix = chain->row_major_matrix;
}
auto offsets =
flattened_access_chain_offset(*basetype, &ops[3 + to_plain_buffer_length],
length - 3 - to_plain_buffer_length, 0, 1, &need_transpose, &matrix_stride);
length - 3 - to_plain_buffer_length, 0, 1, &row_major_matrix, &matrix_stride);
auto &e = set<SPIRAccessChain>(ops[1], ops[0], type.storage, base, offsets.first, offsets.second);
e.row_major_matrix = row_major_matrix;
e.matrix_stride = matrix_stride;
e.immutable = should_forward(ops[2]);
if (chain)
{
e.dynamic_index += chain->dynamic_index;
e.static_index += chain->static_index;
}
e.immutable = should_forward(ops[2]);
}
else
{