mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2025-01-07 07:20:05 +00:00
3b5960174f
Fixes #1952 * Prevent scalarization of arrays that are sized by a specialization constant
818 lines
28 KiB
C++
818 lines
28 KiB
C++
// Copyright (c) 2017 Google Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "source/opt/scalar_replacement_pass.h"
|
|
|
|
#include <algorithm>
|
|
#include <queue>
|
|
#include <tuple>
|
|
#include <utility>
|
|
|
|
#include "source/enum_string_mapping.h"
|
|
#include "source/extensions.h"
|
|
#include "source/opt/reflect.h"
|
|
#include "source/opt/types.h"
|
|
#include "source/util/make_unique.h"
|
|
|
|
namespace spvtools {
|
|
namespace opt {
|
|
|
|
// Runs scalar replacement on every function of the module.
//
// Stops and returns Failure as soon as one function fails. Otherwise returns
// SuccessWithChange if at least one function reported a change, and
// SuccessWithoutChange if none did.
Pass::Status ScalarReplacementPass::Process() {
  Status overall = Status::SuccessWithoutChange;
  for (auto& func : *get_module()) {
    switch (ProcessFunction(&func)) {
      case Status::Failure:
        return Status::Failure;
      case Status::SuccessWithChange:
        overall = Status::SuccessWithChange;
        break;
      default:
        // Nothing changed in this function; keep the status gathered so far.
        break;
    }
  }

  return overall;
}
|
|
|
|
// Scalarizes every eligible variable declared at the top of |function|'s entry
// block.
//
// Candidates are collected from the leading OpVariable instructions of the
// entry block and processed through a worklist, because ReplaceVariable may
// push newly created variables that can themselves be replaced.
//
// Returns Failure as soon as one replacement fails, SuccessWithChange if at
// least one variable was replaced, and SuccessWithoutChange otherwise
// (including when |function| has no basic blocks).
Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
  // A function without any basic block has no entry block to inspect;
  // dereferencing begin() would be invalid in that case.
  if (function->begin() == function->end()) {
    return Status::SuccessWithoutChange;
  }

  std::queue<Instruction*> worklist;
  BasicBlock& entry = *function->begin();
  for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
    // Function storage class OpVariables must appear as the first instructions
    // of the entry block.
    if (iter->opcode() != SpvOpVariable) break;

    Instruction* varInst = &*iter;
    if (CanReplaceVariable(varInst)) {
      worklist.push(varInst);
    }
  }

  Status status = Status::SuccessWithoutChange;
  while (!worklist.empty()) {
    Instruction* varInst = worklist.front();
    worklist.pop();

    if (!ReplaceVariable(varInst, &worklist)) {
      return Status::Failure;
    }
    status = Status::SuccessWithChange;
  }

  return status;
}
|
|
|
|
// Replaces the composite variable |inst| with one replacement per element and
// rewrites every user of |inst| in terms of those replacements.
//
// Loads, stores and access chains on |inst| are rewritten and then killed
// together with |inst|. Replacement variables that end up unused are killed;
// those that can be scalarized further are pushed onto |worklist|.
//
// Returns false if a user could not be rewritten.
bool ScalarReplacementPass::ReplaceVariable(
    Instruction* inst, std::queue<Instruction*>* worklist) {
  std::vector<Instruction*> replacements;
  CreateReplacementVariables(inst, &replacements);

  // Instructions made obsolete by the rewrite. The variable is recorded first
  // and the list is consumed back to front, so it is killed last.
  std::vector<Instruction*> obsolete;
  obsolete.push_back(inst);

  const auto rewriteUser = [this, &replacements,
                            &obsolete](Instruction* user) {
    const SpvOp op = user->opcode();
    // Annotation instructions are not rewritten here.
    if (IsAnnotationInst(op)) return true;

    if (op == SpvOpLoad) {
      ReplaceWholeLoad(user, replacements);
      obsolete.push_back(user);
    } else if (op == SpvOpStore) {
      ReplaceWholeStore(user, replacements);
      obsolete.push_back(user);
    } else if (op == SpvOpAccessChain || op == SpvOpInBoundsAccessChain) {
      if (!ReplaceAccessChain(user, replacements)) return false;
      obsolete.push_back(user);
    } else if (op == SpvOpName || op == SpvOpMemberName) {
      // Debug names need no rewriting.
    } else {
      assert(false && "Unexpected opcode");
    }
    return true;
  };

  if (!get_def_use_mgr()->WhileEachUser(inst, rewriteUser)) return false;

  // Clean up some dead code.
  for (auto it = obsolete.rbegin(); it != obsolete.rend(); ++it) {
    context()->KillInst(*it);
  }

  // Attempt to further scalarize.
  for (Instruction* candidate : replacements) {
    if (candidate->opcode() != SpvOpVariable) continue;

    if (get_def_use_mgr()->NumUsers(candidate) == 0) {
      context()->KillInst(candidate);
    } else if (CanReplaceVariable(candidate)) {
      worklist->push(candidate);
    }
  }

  return true;
}
|
|
|
|
// Replaces |load|, a load of the whole composite, with one load per replacement
// variable followed by an OpCompositeConstruct that reassembles the value.
//
// Entries of |replacements| that are not OpVariable are used directly as the
// corresponding constituent. All uses of |load|'s result are redirected to the
// new composite; |load| itself is left in place for the caller to remove.
void ScalarReplacementPass::ReplaceWholeLoad(
    Instruction* load, const std::vector<Instruction*>& replacements) {
  BasicBlock* const parent = context()->get_instr_block(load);
  BasicBlock::iterator insertPos(load);

  std::vector<Instruction*> constituents;
  constituents.reserve(replacements.size());

  for (Instruction* replacement : replacements) {
    if (replacement->opcode() == SpvOpVariable) {
      // Create a load of this replacement variable.
      Instruction* elemType = GetStorageType(replacement);
      const uint32_t elemLoadId = TakeNextId();
      std::unique_ptr<Instruction> elemLoad(new Instruction(
          context(), SpvOpLoad, elemType->result_id(), elemLoadId,
          std::initializer_list<Operand>{
              {SPV_OPERAND_TYPE_ID, {replacement->result_id()}}}));

      // In-operand 0 of the original load is the pointer; everything after it
      // is a memory access attribute and is copied verbatim.
      for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
        Operand attribute(load->GetInOperand(i));
        elemLoad->AddOperand(std::move(attribute));
      }

      // The iterator is moved onto the new instruction, so the next element
      // load is inserted in front of this one.
      insertPos = insertPos.InsertBefore(std::move(elemLoad));
      Instruction* inserted = &*insertPos;
      get_def_use_mgr()->AnalyzeInstDefUse(inserted);
      context()->set_instr_block(inserted, parent);
      constituents.push_back(inserted);
    } else {
      constituents.push_back(replacement);
    }
  }

  // Construct a new composite immediately before the original load.
  const uint32_t compositeId = TakeNextId();
  std::unique_ptr<Instruction> composite(new Instruction(
      context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
  for (Instruction* constituent : constituents) {
    Operand op(SPV_OPERAND_TYPE_ID,
               std::initializer_list<uint32_t>{constituent->result_id()});
    composite->AddOperand(std::move(op));
  }

  insertPos = BasicBlock::iterator(load);
  insertPos = insertPos.InsertBefore(std::move(composite));
  get_def_use_mgr()->AnalyzeInstDefUse(&*insertPos);
  context()->set_instr_block(&*insertPos, parent);
  context()->ReplaceAllUsesWith(load->result_id(), compositeId);
}
|
|
|
|
// Replaces |store|, a store to the whole composite, with an
// OpCompositeExtract plus OpStore pair for each replacement variable.
//
// Entries of |replacements| that are not OpVariable receive no store, but
// still occupy their element position. The new instructions are inserted
// before |store|, which is left in place for the caller to remove.
void ScalarReplacementPass::ReplaceWholeStore(
    Instruction* store, const std::vector<Instruction*>& replacements) {
  // In-operand 0 is the pointer and in-operand 1 is the stored value.
  const uint32_t storedValueId = store->GetSingleWordInOperand(1u);
  BasicBlock* const parent = context()->get_instr_block(store);
  BasicBlock::iterator storePos(store);

  // The position within |replacements| is the element index to extract.
  for (uint32_t elementIndex = 0; elementIndex < replacements.size();
       ++elementIndex) {
    Instruction* replacement = replacements[elementIndex];
    if (replacement->opcode() != SpvOpVariable) continue;

    // Create the extract.
    Instruction* elemType = GetStorageType(replacement);
    const uint32_t extractId = TakeNextId();
    std::unique_ptr<Instruction> extract(new Instruction(
        context(), SpvOpCompositeExtract, elemType->result_id(), extractId,
        std::initializer_list<Operand>{
            {SPV_OPERAND_TYPE_ID, {storedValueId}},
            {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex}}}));
    auto inserted = storePos.InsertBefore(std::move(extract));
    get_def_use_mgr()->AnalyzeInstDefUse(&*inserted);
    context()->set_instr_block(&*inserted, parent);

    // Create the store.
    std::unique_ptr<Instruction> elemStore(
        new Instruction(context(), SpvOpStore, 0, 0,
                        std::initializer_list<Operand>{
                            {SPV_OPERAND_TYPE_ID, {replacement->result_id()}},
                            {SPV_OPERAND_TYPE_ID, {extractId}}}));
    // Memory access attributes start at in-operand 2 and are copied verbatim.
    for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
      Operand attribute(store->GetInOperand(i));
      elemStore->AddOperand(std::move(attribute));
    }
    inserted = storePos.InsertBefore(std::move(elemStore));
    get_def_use_mgr()->AnalyzeInstDefUse(&*inserted);
    context()->set_instr_block(&*inserted, parent);
  }
}
|
|
|
|
// Rewrites |chain|, an access chain based on the variable being scalarized, so
// that it refers to the matching entry of |replacements| instead.
//
// If |chain| has more than one index, a new access chain (with one fewer
// index) on the replacement is inserted before it; otherwise uses of |chain|
// are redirected straight to the replacement. |chain| itself is left in place
// for the caller to remove.
//
// Returns false, without modifying anything, if the first index is out of
// bounds for |replacements|.
bool ScalarReplacementPass::ReplaceAccessChain(
    Instruction* chain, const std::vector<Instruction*>& replacements) {
  // In-operand 0 is the base pointer; in-operand 1 is the first index, which
  // selects the element of the original variable.
  uint32_t indexId = chain->GetSingleWordInOperand(1u);
  const Instruction* index = get_def_use_mgr()->GetDef(indexId);
  size_t indexValue = GetConstantInteger(index);
  if (indexValue >= replacements.size()) {
    // Out of bounds access, this is illegal IR. Indexing is zero-based, so an
    // index equal to the number of replacements must be rejected as well.
    return false;
  }

  const Instruction* var = replacements[indexValue];
  if (chain->NumInOperands() > 2) {
    // Replace input access chain with another access chain.
    BasicBlock::iterator chainIter(chain);
    uint32_t replacementId = TakeNextId();
    std::unique_ptr<Instruction> replacementChain(new Instruction(
        context(), chain->opcode(), chain->type_id(), replacementId,
        std::initializer_list<Operand>{
            {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
    // Add the remaining indexes.
    for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
      Operand copy(chain->GetInOperand(i));
      replacementChain->AddOperand(std::move(copy));
    }
    auto iter = chainIter.InsertBefore(std::move(replacementChain));
    get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
    context()->set_instr_block(&*iter, context()->get_instr_block(chain));
    context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
  } else {
    // Replace with a use of the variable.
    context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
  }

  return true;
}
|
|
|
|
// Fills |replacements| with one entry per element of the composite stored in
// the variable |inst|, in element order.
//
// For structs and arrays, an element that GetUsedComponents reports as unused
// gets a null constant instead of a new variable; if the used set is unknown
// (null), every element gets a variable. Vectors and matrices always get a
// variable per element. Finally the relevant decorations of |inst| are
// transferred to the replacements.
void ScalarReplacementPass::CreateReplacementVariables(
    Instruction* inst, std::vector<Instruction*>* replacements) {
  Instruction* type = GetStorageType(inst);

  std::unique_ptr<std::unordered_set<uint64_t>> components_used =
      GetUsedComponents(inst);

  uint32_t elem = 0;
  switch (type->opcode()) {
    case SpvOpTypeStruct:
      // Each in-operand of the struct type is the type id of one member.
      type->ForEachInOperand(
          [this, inst, &elem, replacements, &components_used](uint32_t* id) {
            if (!components_used || components_used->count(elem)) {
              CreateVariable(*id, inst, elem, replacements);
            } else {
              replacements->push_back(CreateNullConstant(*id));
            }
            elem++;
          });
      break;
    case SpvOpTypeArray: {
      // The element type and length do not change while variables are being
      // created, so look them up once instead of on every iteration.
      const uint32_t elementTypeId = type->GetSingleWordInOperand(0u);
      const size_t length = GetArrayLength(type);
      for (uint32_t i = 0; i != length; ++i) {
        if (!components_used || components_used->count(i)) {
          CreateVariable(elementTypeId, inst, i, replacements);
        } else {
          replacements->push_back(CreateNullConstant(elementTypeId));
        }
      }
      break;
    }

    case SpvOpTypeMatrix:
    case SpvOpTypeVector: {
      const uint32_t elementTypeId = type->GetSingleWordInOperand(0u);
      const size_t count = GetNumElements(type);
      for (uint32_t i = 0; i != count; ++i) {
        CreateVariable(elementTypeId, inst, i, replacements);
      }
      break;
    }

    default:
      assert(false && "Unexpected type.");
      break;
  }

  TransferAnnotations(inst, replacements);
}
|
|
|
|
void ScalarReplacementPass::TransferAnnotations(
|
|
const Instruction* source, std::vector<Instruction*>* replacements) {
|
|
// Only transfer invariant and restrict decorations on the variable. There are
|
|
// no type or member decorations that are necessary to transfer.
|
|
for (auto inst :
|
|
get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
|
|
assert(inst->opcode() == SpvOpDecorate);
|
|
uint32_t decoration = inst->GetSingleWordInOperand(1u);
|
|
if (decoration == SpvDecorationInvariant ||
|
|
decoration == SpvDecorationRestrict) {
|
|
for (auto var : *replacements) {
|
|
std::unique_ptr<Instruction> annotation(
|
|
new Instruction(context(), SpvOpDecorate, 0, 0,
|
|
std::initializer_list<Operand>{
|
|
{SPV_OPERAND_TYPE_ID, {var->result_id()}},
|
|
{SPV_OPERAND_TYPE_DECORATION, {decoration}}}));
|
|
for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
|
|
Operand copy(inst->GetInOperand(i));
|
|
annotation->AddOperand(std::move(copy));
|
|
}
|
|
context()->AddAnnotationInst(std::move(annotation));
|
|
get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void ScalarReplacementPass::CreateVariable(
|
|
uint32_t typeId, Instruction* varInst, uint32_t index,
|
|
std::vector<Instruction*>* replacements) {
|
|
uint32_t ptrId = GetOrCreatePointerType(typeId);
|
|
uint32_t id = TakeNextId();
|
|
std::unique_ptr<Instruction> variable(new Instruction(
|
|
context(), SpvOpVariable, ptrId, id,
|
|
std::initializer_list<Operand>{
|
|
{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}));
|
|
|
|
BasicBlock* block = context()->get_instr_block(varInst);
|
|
block->begin().InsertBefore(std::move(variable));
|
|
Instruction* inst = &*block->begin();
|
|
|
|
// If varInst was initialized, make sure to initialize its replacement.
|
|
GetOrCreateInitialValue(varInst, index, inst);
|
|
get_def_use_mgr()->AnalyzeInstDefUse(inst);
|
|
context()->set_instr_block(inst, block);
|
|
|
|
replacements->push_back(inst);
|
|
}
|
|
|
|
uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
|
|
auto iter = pointee_to_pointer_.find(id);
|
|
if (iter != pointee_to_pointer_.end()) return iter->second;
|
|
|
|
analysis::Type* pointeeTy;
|
|
std::unique_ptr<analysis::Pointer> pointerTy;
|
|
std::tie(pointeeTy, pointerTy) =
|
|
context()->get_type_mgr()->GetTypeAndPointerType(id,
|
|
SpvStorageClassFunction);
|
|
uint32_t ptrId = 0;
|
|
if (id == context()->get_type_mgr()->GetId(pointeeTy)) {
|
|
// Non-ambiguous type, just ask the type manager for an id.
|
|
ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
|
|
pointee_to_pointer_[id] = ptrId;
|
|
return ptrId;
|
|
}
|
|
|
|
// Ambiguous type. We must perform a linear search to try and find the right
|
|
// type.
|
|
for (auto global : context()->types_values()) {
|
|
if (global.opcode() == SpvOpTypePointer &&
|
|
global.GetSingleWordInOperand(0u) == SpvStorageClassFunction &&
|
|
global.GetSingleWordInOperand(1u) == id) {
|
|
if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
|
|
// Only reuse a decoration-less pointer of the correct type.
|
|
ptrId = global.result_id();
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
if (ptrId != 0) {
|
|
pointee_to_pointer_[id] = ptrId;
|
|
return ptrId;
|
|
}
|
|
|
|
ptrId = TakeNextId();
|
|
context()->AddType(MakeUnique<Instruction>(
|
|
context(), SpvOpTypePointer, 0, ptrId,
|
|
std::initializer_list<Operand>{
|
|
{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
|
|
{SPV_OPERAND_TYPE_ID, {id}}}));
|
|
Instruction* ptr = &*--context()->types_values_end();
|
|
get_def_use_mgr()->AnalyzeInstDefUse(ptr);
|
|
pointee_to_pointer_[id] = ptrId;
|
|
// Register with the type manager if necessary.
|
|
context()->get_type_mgr()->RegisterType(ptrId, *pointerTy);
|
|
|
|
return ptrId;
|
|
}
|
|
|
|
// Gives |newVar|, the replacement for element |index| of |source|, an
// initializer derived from |source|'s initializer. Does nothing if |source|
// has no initializer.
//
// - OpConstantNull: uses a null constant of the element type (cached in
//   |type_to_null_|, created on first use).
// - Spec constant: creates an OpSpecConstantOp CompositeExtract of the element.
// - OpConstantComposite: reuses the constituent at |index|, unless it is an
//   OpUndef, in which case no initializer is added.
void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
                                                    uint32_t index,
                                                    Instruction* newVar) {
  assert(source->opcode() == SpvOpVariable);
  // In-operand 1, when present, is the initializer of the variable.
  if (source->NumInOperands() < 2) return;

  const uint32_t sourceInitId = source->GetSingleWordInOperand(1u);
  const uint32_t elemTypeId = GetStorageType(newVar)->result_id();
  Instruction* sourceInit = get_def_use_mgr()->GetDef(sourceInitId);
  const SpvOp initOp = sourceInit->opcode();

  uint32_t elemInitId = 0;
  // TODO(dnovillo): Refactor this with constant propagation.
  if (initOp == SpvOpConstantNull) {
    // Initialize to appropriate NULL.
    const auto cached = type_to_null_.find(elemTypeId);
    if (cached != type_to_null_.end()) {
      elemInitId = cached->second;
    } else {
      elemInitId = TakeNextId();
      type_to_null_[elemTypeId] = elemInitId;
      context()->AddGlobalValue(
          MakeUnique<Instruction>(context(), SpvOpConstantNull, elemTypeId,
                                  elemInitId, std::initializer_list<Operand>{}));
      // The instruction just added is the last global value.
      Instruction* addedNull = &*--context()->types_values_end();
      get_def_use_mgr()->AnalyzeInstDefUse(addedNull);
    }
  } else if (IsSpecConstantInst(initOp)) {
    // Create a new constant extract.
    elemInitId = TakeNextId();
    context()->AddGlobalValue(MakeUnique<Instruction>(
        context(), SpvOpSpecConstantOp, elemTypeId, elemInitId,
        std::initializer_list<Operand>{
            {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}},
            {SPV_OPERAND_TYPE_ID, {sourceInit->result_id()}},
            {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
    Instruction* addedExtract = &*--context()->types_values_end();
    get_def_use_mgr()->AnalyzeInstDefUse(addedExtract);
  } else if (initOp == SpvOpConstantComposite) {
    // Get the appropriate index constant.
    elemInitId = sourceInit->GetSingleWordInOperand(index);
    const Instruction* constituent = get_def_use_mgr()->GetDef(elemInitId);
    if (constituent->opcode() == SpvOpUndef) {
      // Undef is not a valid initializer for a variable.
      elemInitId = 0;
    }
  } else {
    assert(false);
  }

  if (elemInitId == 0) return;
  newVar->AddOperand({SPV_OPERAND_TYPE_ID, {elemInitId}});
}
|
|
|
|
// Returns the unsigned integer encoded by the words of |op|.
//
// Word 0 holds the low 32 bits and word 1, if present, the high 32 bits. At
// most two words are supported. The value is assembled in 64 bits and then
// converted to size_t, so it is truncated where size_t is narrower than 64
// bits.
size_t ScalarReplacementPass::GetIntegerLiteral(const Operand& op) const {
  assert(op.words.size() <= 2);
  uint64_t value = 0;
  for (size_t i = 0; i != op.words.size(); ++i) {
    // Widen before shifting: shifting a 32-bit word by 32 bits is undefined
    // behaviour and would lose the high word.
    value |= static_cast<uint64_t>(op.words[i]) << (32u * i);
  }
  return static_cast<size_t>(value);
}
|
|
|
|
size_t ScalarReplacementPass::GetConstantInteger(
|
|
const Instruction* constant) const {
|
|
assert(get_def_use_mgr()->GetDef(constant->type_id())->opcode() ==
|
|
SpvOpTypeInt);
|
|
assert(constant->opcode() == SpvOpConstant ||
|
|
constant->opcode() == SpvOpConstantNull);
|
|
if (constant->opcode() == SpvOpConstantNull) {
|
|
return 0;
|
|
}
|
|
|
|
const Operand& op = constant->GetInOperand(0u);
|
|
return GetIntegerLiteral(op);
|
|
}
|
|
|
|
size_t ScalarReplacementPass::GetArrayLength(
|
|
const Instruction* arrayType) const {
|
|
assert(arrayType->opcode() == SpvOpTypeArray);
|
|
const Instruction* length =
|
|
get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
|
|
return GetConstantInteger(length);
|
|
}
|
|
|
|
size_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
|
|
assert(type->opcode() == SpvOpTypeVector ||
|
|
type->opcode() == SpvOpTypeMatrix);
|
|
const Operand& op = type->GetInOperand(1u);
|
|
assert(op.words.size() <= 2);
|
|
size_t len = 0;
|
|
for (uint32_t i = 0; i != op.words.size(); ++i) {
|
|
len |= (op.words[i] << (32 * i));
|
|
}
|
|
return len;
|
|
}
|
|
|
|
bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
|
|
const Instruction* inst = get_def_use_mgr()->GetDef(id);
|
|
assert(inst);
|
|
return spvOpcodeIsSpecConstant(inst->opcode());
|
|
}
|
|
|
|
// Returns the type instruction of the object the variable |inst| holds, i.e.
// the type referenced by in-operand 1 of the variable's pointer type.
Instruction* ScalarReplacementPass::GetStorageType(
    const Instruction* inst) const {
  assert(inst->opcode() == SpvOpVariable);

  analysis::DefUseManager* defUse = get_def_use_mgr();
  const Instruction* pointerType = defUse->GetDef(inst->type_id());
  const uint32_t pointeeTypeId = pointerType->GetSingleWordInOperand(1u);
  return defUse->GetDef(pointeeTypeId);
}
|
|
|
|
bool ScalarReplacementPass::CanReplaceVariable(
|
|
const Instruction* varInst) const {
|
|
assert(varInst->opcode() == SpvOpVariable);
|
|
|
|
// Can only replace function scope variables.
|
|
if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction)
|
|
return false;
|
|
|
|
if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id())))
|
|
return false;
|
|
|
|
const Instruction* typeInst = GetStorageType(varInst);
|
|
return CheckType(typeInst) && CheckAnnotations(varInst) && CheckUses(varInst);
|
|
}
|
|
|
|
bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
|
|
if (!CheckTypeAnnotations(typeInst)) return false;
|
|
|
|
switch (typeInst->opcode()) {
|
|
case SpvOpTypeStruct:
|
|
// Don't bother with empty structs or very large structs.
|
|
if (typeInst->NumInOperands() == 0 ||
|
|
IsLargerThanSizeLimit(typeInst->NumInOperands()))
|
|
return false;
|
|
return true;
|
|
case SpvOpTypeArray:
|
|
if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
|
|
return false;
|
|
}
|
|
if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
|
|
return false;
|
|
}
|
|
return true;
|
|
// TODO(alanbaker): Develop some heuristics for when this should be
|
|
// re-enabled.
|
|
//// Specifically including matrix and vector in an attempt to reduce the
|
|
//// number of vector registers required.
|
|
// case SpvOpTypeMatrix:
|
|
// case SpvOpTypeVector:
|
|
// if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
|
|
// return true;
|
|
|
|
case SpvOpTypeRuntimeArray:
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
bool ScalarReplacementPass::CheckTypeAnnotations(
|
|
const Instruction* typeInst) const {
|
|
for (auto inst :
|
|
get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
|
|
uint32_t decoration;
|
|
if (inst->opcode() == SpvOpDecorate) {
|
|
decoration = inst->GetSingleWordInOperand(1u);
|
|
} else {
|
|
assert(inst->opcode() == SpvOpMemberDecorate);
|
|
decoration = inst->GetSingleWordInOperand(2u);
|
|
}
|
|
|
|
switch (decoration) {
|
|
case SpvDecorationRowMajor:
|
|
case SpvDecorationColMajor:
|
|
case SpvDecorationArrayStride:
|
|
case SpvDecorationMatrixStride:
|
|
case SpvDecorationCPacked:
|
|
case SpvDecorationInvariant:
|
|
case SpvDecorationRestrict:
|
|
case SpvDecorationOffset:
|
|
case SpvDecorationAlignment:
|
|
case SpvDecorationAlignmentId:
|
|
case SpvDecorationMaxByteOffset:
|
|
break;
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
|
|
for (auto inst :
|
|
get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
|
|
assert(inst->opcode() == SpvOpDecorate);
|
|
uint32_t decoration = inst->GetSingleWordInOperand(1u);
|
|
switch (decoration) {
|
|
case SpvDecorationInvariant:
|
|
case SpvDecorationRestrict:
|
|
case SpvDecorationAlignment:
|
|
case SpvDecorationAlignmentId:
|
|
case SpvDecorationMaxByteOffset:
|
|
break;
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
|
|
VariableStats stats = {0, 0};
|
|
bool ok = CheckUses(inst, &stats);
|
|
|
|
// TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
|
|
// SRoA is costly, such as when the structure has many (unaccessed?)
|
|
// members.
|
|
|
|
return ok;
|
|
}
|
|
|
|
bool ScalarReplacementPass::CheckUses(const Instruction* inst,
|
|
VariableStats* stats) const {
|
|
bool ok = true;
|
|
get_def_use_mgr()->ForEachUse(
|
|
inst, [this, stats, &ok](const Instruction* user, uint32_t index) {
|
|
// Annotations are check as a group separately.
|
|
if (!IsAnnotationInst(user->opcode())) {
|
|
switch (user->opcode()) {
|
|
case SpvOpAccessChain:
|
|
case SpvOpInBoundsAccessChain:
|
|
if (index == 2u) {
|
|
uint32_t id = user->GetSingleWordOperand(3u);
|
|
const Instruction* opInst = get_def_use_mgr()->GetDef(id);
|
|
if (!IsCompileTimeConstantInst(opInst->opcode())) {
|
|
ok = false;
|
|
} else {
|
|
if (!CheckUsesRelaxed(user)) ok = false;
|
|
}
|
|
stats->num_partial_accesses++;
|
|
} else {
|
|
ok = false;
|
|
}
|
|
break;
|
|
case SpvOpLoad:
|
|
if (!CheckLoad(user, index)) ok = false;
|
|
stats->num_full_accesses++;
|
|
break;
|
|
case SpvOpStore:
|
|
if (!CheckStore(user, index)) ok = false;
|
|
stats->num_full_accesses++;
|
|
break;
|
|
case SpvOpName:
|
|
case SpvOpMemberName:
|
|
break;
|
|
default:
|
|
ok = false;
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
|
|
return ok;
|
|
}
|
|
|
|
// Returns whether every use of |inst| (an access chain reached from a
// candidate variable) is acceptable: a further access chain using it as base
// whose uses are acceptable in turn, or a load/store accepted by CheckLoad /
// CheckStore. Unlike CheckUses, the indices of nested access chains are not
// inspected. All uses are visited even after one has been rejected.
bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
  bool allUsesOk = true;
  const auto inspectUse = [this, &allUsesOk](const Instruction* user,
                                             uint32_t index) {
    const SpvOp op = user->opcode();
    bool acceptable = false;
    if (op == SpvOpAccessChain || op == SpvOpInBoundsAccessChain) {
      // |inst| must be the base pointer (operand 2) of the nested chain.
      acceptable = (index == 2u) && CheckUsesRelaxed(user);
    } else if (op == SpvOpLoad) {
      acceptable = CheckLoad(user, index);
    } else if (op == SpvOpStore) {
      acceptable = CheckStore(user, index);
    }
    if (!acceptable) allUsesOk = false;
  };
  get_def_use_mgr()->ForEachUse(inst, inspectUse);

  return allUsesOk;
}
|
|
|
|
bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
|
|
uint32_t index) const {
|
|
if (index != 2u) return false;
|
|
if (inst->NumInOperands() >= 2 &&
|
|
inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask)
|
|
return false;
|
|
return true;
|
|
}
|
|
|
|
bool ScalarReplacementPass::CheckStore(const Instruction* inst,
|
|
uint32_t index) const {
|
|
if (index != 0u) return false;
|
|
if (inst->NumInOperands() >= 3 &&
|
|
inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask)
|
|
return false;
|
|
return true;
|
|
}
|
|
bool ScalarReplacementPass::IsLargerThanSizeLimit(size_t length) const {
|
|
if (max_num_elements_ == 0) {
|
|
return false;
|
|
}
|
|
return length > max_num_elements_;
|
|
}
|
|
|
|
std::unique_ptr<std::unordered_set<uint64_t>>
|
|
ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
|
|
std::unique_ptr<std::unordered_set<uint64_t>> result(
|
|
new std::unordered_set<uint64_t>());
|
|
|
|
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
|
|
|
|
def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
|
|
this](Instruction* use) {
|
|
switch (use->opcode()) {
|
|
case SpvOpLoad: {
|
|
// Look for extract from the load.
|
|
std::vector<uint32_t> t;
|
|
if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
|
|
if (use2->opcode() != SpvOpCompositeExtract) {
|
|
return false;
|
|
}
|
|
t.push_back(use2->GetSingleWordInOperand(1));
|
|
return true;
|
|
})) {
|
|
result->insert(t.begin(), t.end());
|
|
return true;
|
|
} else {
|
|
result.reset(nullptr);
|
|
return false;
|
|
}
|
|
}
|
|
case SpvOpStore:
|
|
// No components are used. Things are just stored to.
|
|
return true;
|
|
case SpvOpAccessChain:
|
|
case SpvOpInBoundsAccessChain: {
|
|
// Add the first index it if is a constant.
|
|
// TODO: Could be improved by checking if the address is used in a load.
|
|
analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
|
|
uint32_t index_id = use->GetSingleWordInOperand(1);
|
|
const analysis::Constant* index_const =
|
|
const_mgr->FindDeclaredConstant(index_id);
|
|
if (index_const) {
|
|
const analysis::Integer* index_type =
|
|
index_const->type()->AsInteger();
|
|
assert(index_type);
|
|
if (index_type->width() == 32) {
|
|
result->insert(index_const->GetU32());
|
|
return true;
|
|
} else if (index_type->width() == 64) {
|
|
result->insert(index_const->GetU64());
|
|
return true;
|
|
}
|
|
result.reset(nullptr);
|
|
return false;
|
|
} else {
|
|
// Could be any element. Assuming all are used.
|
|
result.reset(nullptr);
|
|
return false;
|
|
}
|
|
}
|
|
case SpvOpCopyObject: {
|
|
// Follow the copy to see which components are used.
|
|
auto t = GetUsedComponents(use);
|
|
if (!t) {
|
|
result.reset(nullptr);
|
|
return false;
|
|
}
|
|
result->insert(t->begin(), t->end());
|
|
return true;
|
|
}
|
|
default:
|
|
// We do not know what is happening. Have to assume the worst.
|
|
result.reset(nullptr);
|
|
return false;
|
|
}
|
|
});
|
|
|
|
return result;
|
|
}
|
|
|
|
// Returns the instruction defining the constant that the constant manager
// produces for type |type_id| with no constituents, after refreshing its
// def-use information.
Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) {
  analysis::TypeManager* typeMgr = context()->get_type_mgr();
  analysis::ConstantManager* constMgr = context()->get_constant_mgr();

  const analysis::Type* valueType = typeMgr->GetType(type_id);
  const analysis::Constant* nullValue = constMgr->GetConstant(valueType, {});
  Instruction* definingInst =
      constMgr->GetDefiningInstruction(nullValue, type_id);
  context()->UpdateDefUse(definingInst);
  return definingInst;
}
|
|
|
|
} // namespace opt
|
|
} // namespace spvtools
|