SPIRV-Tools/source/opt/scalar_replacement_pass.cpp
Alan Baker 867451f49e Add scalar replacement
Adds a scalar replacement pass. The pass considers all function scope
variables of composite type. If there are accesses to individual
elements (and it is legal) the pass replaces the variable with a
variable for each composite element and updates all the uses.

Added the pass to -O
Added NumUses and NumUsers to DefUseManager
Added some helper methods for the inst to block mapping in context
Added some helper methods for specific constant types

No longer generate duplicate pointer types.

* Now searches for an existing pointer of the appropriate type instead
of failing validation
* Fixed spec constant extracts
* Addressed changes for review
* Changed RunSinglePassAndMatch to be able to run validation
 * current users do not enable it

Added handling of acceptable decorations.

* Decorations are also transferred where appropriate

Refactored extension checking into FeatureManager

* Context now owns a feature manager
 * consciously NOT an analysis
 * added some tests
* fixed some minor issues related to decorates
* added some decorate related tests for scalar replacement
2017-12-11 10:51:13 -05:00

681 lines
23 KiB
C++

// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "scalar_replacement_pass.h"
#include "enum_string_mapping.h"
#include "extensions.h"
#include "make_unique.h"
#include "reflect.h"
#include "types.h"
#include <queue>
namespace spvtools {
namespace opt {
// Heuristic upper bound on the number of top-level elements (struct members or
// array length) a composite may have for its variables to be scalarized.
constexpr uint32_t MAX_NUM_ELEMENTS = 100u;
// Entry point: runs scalar replacement over every function of the module
// owned by |c|. Stops and reports Failure as soon as one function fails;
// otherwise reports SuccessWithChange if any function was modified.
Pass::Status ScalarReplacementPass::Process(ir::IRContext* c) {
  InitializeProcessing(c);

  bool anyChange = false;
  for (auto& func : *get_module()) {
    const Status result = ProcessFunction(&func);
    if (result == Status::Failure) return Status::Failure;
    if (result == Status::SuccessWithChange) anyChange = true;
  }
  return anyChange ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
// Scalarizes the replaceable function-scope variables of |function|.
// Candidates are the leading OpVariables of the entry block. Replacing a
// variable may enqueue its newly created element variables for further
// scalarization. Returns Failure if any replacement fails, SuccessWithChange
// if at least one variable was replaced, and SuccessWithoutChange otherwise.
Pass::Status ScalarReplacementPass::ProcessFunction(ir::Function* function) {
  // A function declaration (no body) has no entry block; dereferencing
  // begin() of an empty block list would be invalid.
  if (function->begin() == function->end()) {
    return Status::SuccessWithoutChange;
  }

  std::queue<ir::Instruction*> worklist;
  ir::BasicBlock& entry = *function->begin();
  for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
    // Function storage class OpVariables must appear as the first
    // instructions of the entry block.
    if (iter->opcode() != SpvOpVariable) break;
    ir::Instruction* varInst = &*iter;
    if (CanReplaceVariable(varInst)) worklist.push(varInst);
  }

  Status status = Status::SuccessWithoutChange;
  while (!worklist.empty()) {
    ir::Instruction* varInst = worklist.front();
    worklist.pop();
    if (!ReplaceVariable(varInst, &worklist)) return Status::Failure;
    status = Status::SuccessWithChange;
  }
  return status;
}
// Replaces the composite variable |inst| with one variable per element and
// rewrites each non-annotation user:
// - whole loads and stores are expanded element-wise;
// - access chains are re-based on the matching element variable;
// - OpName/OpMemberName are left alone.
// The rewritten users and |inst| are then killed. Element variables left
// without users are killed too; the others are pushed onto |worklist| when
// they can themselves be scalarized. Returns false if an access chain could
// not be rewritten (nothing is killed in that case).
bool ScalarReplacementPass::ReplaceVariable(
    ir::Instruction* inst, std::queue<ir::Instruction*>* worklist) {
  std::vector<ir::Instruction*> replacements;
  CreateReplacementVariables(inst, &replacements);

  // Instructions to kill, in discovery order. They are killed in reverse, so
  // |inst| (recorded first) goes last, after all of its rewritten users.
  std::vector<ir::Instruction*> toKill{inst};
  bool allRewritten = true;
  get_def_use_mgr()->ForEachUser(
      inst,
      [this, &allRewritten, &replacements, &toKill](ir::Instruction* user) {
        const SpvOp op = user->opcode();
        // Annotations were vetted up front and are not rewritten here.
        if (ir::IsAnnotationInst(op)) return;
        switch (op) {
          case SpvOpLoad:
            ReplaceWholeLoad(user, replacements);
            toKill.push_back(user);
            return;
          case SpvOpStore:
            ReplaceWholeStore(user, replacements);
            toKill.push_back(user);
            return;
          case SpvOpAccessChain:
          case SpvOpInBoundsAccessChain:
            if (!ReplaceAccessChain(user, replacements)) allRewritten = false;
            toKill.push_back(user);
            return;
          case SpvOpName:
          case SpvOpMemberName:
            return;
          default:
            assert(false && "Unexpected opcode");
            return;
        }
      });

  // There was an illegal access.
  if (!allRewritten) return false;

  for (auto it = toKill.rbegin(); it != toKill.rend(); ++it) {
    context()->KillInst(*it);
  }

  // Drop unused element variables; queue the rest for further scalarization.
  for (ir::Instruction* var : replacements) {
    if (get_def_use_mgr()->NumUsers(var) == 0) {
      context()->KillInst(var);
    } else if (CanReplaceVariable(var)) {
      worklist->push(var);
    }
  }
  return true;
}
// Replaces |load|, which reads an entire composite from the variable being
// scalarized:
// - one OpLoad is issued per replacement variable, copying |load|'s memory
//   access operands (in-operands 1 onward);
// - an OpCompositeConstruct reassembles the loaded elements in replacement
//   order;
// - all uses of |load|'s result are redirected to the new composite.
// New instructions are placed before |load|; |load| itself is left for the
// caller to kill.
void ScalarReplacementPass::ReplaceWholeLoad(
    ir::Instruction* load, const std::vector<ir::Instruction*>& replacements) {
  ir::BasicBlock* const block = context()->get_instr_block(load);

  // Registers a freshly inserted instruction with the def-use manager and
  // the instruction-to-block map.
  auto registerInst = [this, block](ir::Instruction* newInst) {
    get_def_use_mgr()->AnalyzeInstDefUse(newInst);
    context()->set_instr_block(newInst, block);
  };

  std::vector<ir::Instruction*> elementLoads;
  elementLoads.reserve(replacements.size());
  // Each element load is inserted right before the previously inserted one,
  // so |insertPoint| always tracks the most recent insertion.
  ir::BasicBlock::iterator insertPoint(load);
  for (ir::Instruction* var : replacements) {
    const uint32_t elementTypeId = GetStorageType(var)->result_id();
    const uint32_t elementLoadId = TakeNextId();
    std::unique_ptr<ir::Instruction> elementLoad(
        new ir::Instruction(context(), SpvOpLoad, elementTypeId, elementLoadId,
                            std::initializer_list<ir::Operand>{
                                {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
    // In-operand 0 is the pointer; anything after it is memory access info.
    for (uint32_t operandIdx = 1; operandIdx < load->NumInOperands();
         ++operandIdx) {
      elementLoad->AddOperand(ir::Operand(load->GetInOperand(operandIdx)));
    }
    insertPoint = insertPoint.InsertBefore(std::move(elementLoad));
    registerInst(&*insertPoint);
    elementLoads.push_back(&*insertPoint);
  }

  const uint32_t compositeId = TakeNextId();
  std::unique_ptr<ir::Instruction> composite(new ir::Instruction(
      context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
  for (const ir::Instruction* elementLoad : elementLoads) {
    composite->AddOperand(
        ir::Operand(SPV_OPERAND_TYPE_ID,
                    std::initializer_list<uint32_t>{elementLoad->result_id()}));
  }
  ir::BasicBlock::iterator loadPos(load);
  auto compositePos = loadPos.InsertBefore(std::move(composite));
  registerInst(&*compositePos);
  context()->ReplaceAllUsesWith(load->result_id(), compositeId);
}
// Replaces |store|, which writes an entire composite to the variable being
// scalarized, with one OpCompositeExtract + OpStore pair per element. Element
// i of the stored object is extracted and stored into replacements[i].
// The pairs are inserted before |store| in element order. Memory access
// operands of |store| (in-operands 2 onward) are copied onto every new store.
// |store| itself is left for the caller to kill.
void ScalarReplacementPass::ReplaceWholeStore(
    ir::Instruction* store, const std::vector<ir::Instruction*>& replacements) {
  const uint32_t storedObjectId = store->GetSingleWordInOperand(1u);
  ir::BasicBlock* const block = context()->get_instr_block(store);
  ir::BasicBlock::iterator storePos(store);

  // Inserts |newInst| immediately before |store| and registers it with the
  // def-use manager and the instruction-to-block map.
  auto insertBeforeStore = [this, block,
                            &storePos](std::unique_ptr<ir::Instruction> newInst) {
    auto pos = storePos.InsertBefore(std::move(newInst));
    get_def_use_mgr()->AnalyzeInstDefUse(&*pos);
    context()->set_instr_block(&*pos, block);
  };

  uint32_t element = 0;
  for (ir::Instruction* var : replacements) {
    const uint32_t elementTypeId = GetStorageType(var)->result_id();
    const uint32_t extractId = TakeNextId();
    insertBeforeStore(std::unique_ptr<ir::Instruction>(new ir::Instruction(
        context(), SpvOpCompositeExtract, elementTypeId, extractId,
        std::initializer_list<ir::Operand>{
            {SPV_OPERAND_TYPE_ID, {storedObjectId}},
            {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element}}})));
    ++element;

    std::unique_ptr<ir::Instruction> elementStore(new ir::Instruction(
        context(), SpvOpStore, 0, 0,
        std::initializer_list<ir::Operand>{
            {SPV_OPERAND_TYPE_ID, {var->result_id()}},
            {SPV_OPERAND_TYPE_ID, {extractId}}}));
    // In-operands 0 and 1 are pointer and data; the rest is memory access.
    for (uint32_t operandIdx = 2; operandIdx < store->NumInOperands();
         ++operandIdx) {
      elementStore->AddOperand(ir::Operand(store->GetInOperand(operandIdx)));
    }
    insertBeforeStore(std::move(elementStore));
  }
}
// Rewrites |chain|, an access chain whose base is the variable being
// replaced, in terms of |replacements|. The first index selects a
// replacement variable:
// - if more indexes follow, a new access chain with the remaining indexes is
//   built on that variable and uses of |chain| are redirected to it;
// - otherwise uses of |chain| are redirected to the variable itself.
// Returns false without modifying anything when the first index does not
// name a replacement.
bool ScalarReplacementPass::ReplaceAccessChain(
    ir::Instruction* chain, const std::vector<ir::Instruction*>& replacements) {
  const uint32_t indexId = chain->GetSingleWordInOperand(1u);
  const ir::Instruction* index = get_def_use_mgr()->GetDef(indexId);
  const size_t indexValue = GetConstantInteger(index);
  // Valid indexes are [0, replacements.size()); anything else would read
  // past the end of |replacements|.
  if (indexValue >= replacements.size()) {
    // Out of bounds access; it cannot be mapped to a replacement variable.
    return false;
  }

  const ir::Instruction* var = replacements[indexValue];
  if (chain->NumInOperands() > 2) {
    // Replace input access chain with another access chain.
    ir::BasicBlock::iterator chainIter(chain);
    uint32_t replacementId = TakeNextId();
    std::unique_ptr<ir::Instruction> replacementChain(new ir::Instruction(
        context(), chain->opcode(), chain->type_id(), replacementId,
        std::initializer_list<ir::Operand>{
            {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
    // Add the remaining indexes.
    for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
      ir::Operand copy(chain->GetInOperand(i));
      replacementChain->AddOperand(std::move(copy));
    }
    auto iter = chainIter.InsertBefore(std::move(replacementChain));
    get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
    context()->set_instr_block(&*iter, context()->get_instr_block(chain));
    context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
  } else {
    // Replace with a use of the variable.
    context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
  }
  return true;
}
// Creates one replacement variable per top-level element of |inst|'s
// composite type, appending them to |replacements| in element order:
// - struct members use their own member types;
// - array, matrix and vector elements all use the single element type.
// Applicable decorations are then copied from |inst| onto the replacements.
void ScalarReplacementPass::CreateReplacementVariables(
    ir::Instruction* inst, std::vector<ir::Instruction*>* replacements) {
  ir::Instruction* type = GetStorageType(inst);
  const SpvOp typeOp = type->opcode();

  if (typeOp == SpvOpTypeStruct) {
    // Each in-operand of OpTypeStruct is one member type id.
    for (uint32_t member = 0; member < type->NumInOperands(); ++member) {
      CreateVariable(type->GetSingleWordInOperand(member), inst, member,
                     replacements);
    }
  } else if (typeOp == SpvOpTypeArray || typeOp == SpvOpTypeMatrix ||
             typeOp == SpvOpTypeVector) {
    const size_t count = (typeOp == SpvOpTypeArray) ? GetArrayLength(type)
                                                    : GetNumElements(type);
    const uint32_t elementTypeId = type->GetSingleWordInOperand(0u);
    for (uint32_t i = 0; i != count; ++i) {
      CreateVariable(elementTypeId, inst, i, replacements);
    }
  } else {
    assert(false && "Unexpected type.");
  }

  TransferAnnotations(inst, replacements);
}
// Copies the Invariant and Restrict decorations applied directly to |source|
// onto every variable in |replacements|, registering the uses of each new
// OpDecorate. There are no type or member decorations that need to be
// transferred.
void ScalarReplacementPass::TransferAnnotations(
    const ir::Instruction* source,
    std::vector<ir::Instruction*>* replacements) {
  for (auto decorateInst :
       get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
    assert(decorateInst->opcode() == SpvOpDecorate);
    const uint32_t decoration = decorateInst->GetSingleWordInOperand(1u);
    if (decoration != SpvDecorationInvariant &&
        decoration != SpvDecorationRestrict) {
      continue;
    }

    for (ir::Instruction* var : *replacements) {
      std::unique_ptr<ir::Instruction> copy(new ir::Instruction(
          context(), SpvOpDecorate, 0, 0,
          std::initializer_list<ir::Operand>{
              {SPV_OPERAND_TYPE_ID, {var->result_id()}},
              {SPV_OPERAND_TYPE_DECORATION, {decoration}}}));
      // Carry over any extra decoration operands (in-operands 2 onward).
      for (uint32_t operandIdx = 2; operandIdx < decorateInst->NumInOperands();
           ++operandIdx) {
        copy->AddOperand(ir::Operand(decorateInst->GetInOperand(operandIdx)));
      }
      context()->AddAnnotationInst(std::move(copy));
      auto added = context()->annotation_end();
      --added;
      get_def_use_mgr()->AnalyzeInstUse(&*added);
    }
  }
}
void ScalarReplacementPass::CreateVariable(
uint32_t typeId, ir::Instruction* varInst, uint32_t index,
std::vector<ir::Instruction*>* replacements) {
uint32_t ptrId = GetOrCreatePointerType(typeId);
uint32_t id = TakeNextId();
std::unique_ptr<ir::Instruction> variable(new ir::Instruction(
context(), SpvOpVariable, ptrId, id,
std::initializer_list<ir::Operand>{
{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}));
ir::BasicBlock* block = context()->get_instr_block(varInst);
block->begin().InsertBefore(std::move(variable));
ir::Instruction* inst = &*block->begin();
// If varInst was initialized, make sure to initialize its replacement.
GetOrCreateInitialValue(varInst, index, inst);
get_def_use_mgr()->AnalyzeInstDefUse(inst);
context()->set_instr_block(inst, block);
replacements->push_back(inst);
}
// Returns the id of a Function-storage OpTypePointer whose pointee is |id|.
// It looks in three places, in order:
// - the per-pass cache;
// - an existing declaration among the module's types/values (when reuse is
//   allowed);
// - otherwise, a new pointer type appended to the module.
// The result is cached in pointee_to_pointer_.
uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
  auto iter = pointee_to_pointer_.find(id);
  if (iter != pointee_to_pointer_.end()) return iter->second;

  // TODO(alanbaker): Make the type manager useful and then replace this code.
  uint32_t ptrId = 0;
  // Iterate by reference; copying every type/value instruction is wasteful.
  for (const auto& global : context()->types_values()) {
    if (global.opcode() != SpvOpTypePointer ||
        global.GetSingleWordInOperand(0u) != SpvStorageClassFunction ||
        global.GetSingleWordInOperand(1u) != id) {
      continue;
    }
    // An existing pointer is reused unless SPV_KHR_variable_pointers is
    // enabled and the pointee type |id| carries decorations.
    // NOTE(review): the decoration query is on the pointee |id|, not on
    // |global|; confirm that is the intended check.
    // The condition does not depend on |global|, so the first matching
    // pointer decides the outcome for every candidate and the scan stops.
    if (!context()->get_feature_mgr()->HasExtension(
            libspirv::Extension::kSPV_KHR_variable_pointers) ||
        get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
      ptrId = global.result_id();
    }
    break;
  }

  if (ptrId != 0) {
    pointee_to_pointer_[id] = ptrId;
    return ptrId;
  }

  ptrId = TakeNextId();
  context()->AddType(MakeUnique<ir::Instruction>(
      context(), SpvOpTypePointer, 0, ptrId,
      std::initializer_list<ir::Operand>{
          {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
          {SPV_OPERAND_TYPE_ID, {id}}}));
  ir::Instruction* ptr = &*--context()->types_values_end();
  get_def_use_mgr()->AnalyzeInstDefUse(ptr);
  pointee_to_pointer_[id] = ptrId;
  return ptrId;
}
// If |source| (an OpVariable) has an initializer, gives |newVar| the
// initializer for element |index|:
// - OpConstantNull: a null constant of |newVar|'s storage type, reused from
//   type_to_null_ or newly created and cached there;
// - specialization constant: a new OpSpecConstantOp CompositeExtract of
//   element |index|;
// - OpConstantComposite: constituent |index|, unless that constituent is an
//   OpUndef, in which case no initializer is added.
// Other initializer kinds hit an assertion.
void ScalarReplacementPass::GetOrCreateInitialValue(ir::Instruction* source,
                                                    uint32_t index,
                                                    ir::Instruction* newVar) {
  assert(source->opcode() == SpvOpVariable);
  // In-operand 0 is the storage class; an initializer, if any, is in-operand 1.
  if (source->NumInOperands() < 2) return;

  const uint32_t initId = source->GetSingleWordInOperand(1u);
  const uint32_t storageId = GetStorageType(newVar)->result_id();
  ir::Instruction* init = get_def_use_mgr()->GetDef(initId);
  const SpvOp initOp = init->opcode();

  uint32_t newInitId = 0;
  // TODO(dnovillo): Refactor this with constant propagation.
  if (initOp == SpvOpConstantNull) {
    auto cached = type_to_null_.find(storageId);
    if (cached != type_to_null_.end()) {
      newInitId = cached->second;
    } else {
      newInitId = TakeNextId();
      type_to_null_[storageId] = newInitId;
      context()->AddGlobalValue(MakeUnique<ir::Instruction>(
          context(), SpvOpConstantNull, storageId, newInitId,
          std::initializer_list<ir::Operand>{}));
      auto added = context()->types_values_end();
      --added;
      get_def_use_mgr()->AnalyzeInstDefUse(&*added);
    }
  } else if (ir::IsSpecConstantInst(initOp)) {
    newInitId = TakeNextId();
    context()->AddGlobalValue(MakeUnique<ir::Instruction>(
        context(), SpvOpSpecConstantOp, storageId, newInitId,
        std::initializer_list<ir::Operand>{
            {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}},
            {SPV_OPERAND_TYPE_ID, {init->result_id()}},
            {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
    auto added = context()->types_values_end();
    --added;
    get_def_use_mgr()->AnalyzeInstDefUse(&*added);
  } else if (initOp == SpvOpConstantComposite) {
    const uint32_t elementId = init->GetSingleWordInOperand(index);
    // Undef is not a valid initializer for a variable.
    if (get_def_use_mgr()->GetDef(elementId)->opcode() != SpvOpUndef) {
      newInitId = elementId;
    }
  } else {
    assert(false);
  }

  if (newInitId != 0) newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
}
// Decodes a one- or two-word SPIR-V literal integer operand. Multi-word
// literals store the low-order word first.
size_t ScalarReplacementPass::GetIntegerLiteral(const ir::Operand& op) const {
  assert(op.words.size() <= 2);
  // Widen each word to 64 bits before shifting: shifting a 32-bit word by 32
  // is undefined behavior and would discard the high-order word.
  uint64_t value = 0;
  for (size_t i = 0; i != op.words.size(); ++i) {
    value |= static_cast<uint64_t>(op.words[i]) << (32 * i);
  }
  return static_cast<size_t>(value);
}
// Returns the value of an integer-typed OpConstant or OpConstantNull.
// OpConstantNull denotes zero; otherwise in-operand 0 holds the literal.
size_t ScalarReplacementPass::GetConstantInteger(
    const ir::Instruction* constant) const {
  assert(get_def_use_mgr()->GetDef(constant->type_id())->opcode() ==
         SpvOpTypeInt);
  const SpvOp op = constant->opcode();
  assert(op == SpvOpConstant || op == SpvOpConstantNull);
  return op == SpvOpConstantNull
             ? 0
             : GetIntegerLiteral(constant->GetInOperand(0u));
}
// Returns the element count of OpTypeArray |arrayType| by decoding its
// length constant (in-operand 1).
size_t ScalarReplacementPass::GetArrayLength(
    const ir::Instruction* arrayType) const {
  assert(arrayType->opcode() == SpvOpTypeArray);
  const uint32_t lengthId = arrayType->GetSingleWordInOperand(1u);
  return GetConstantInteger(get_def_use_mgr()->GetDef(lengthId));
}
// Returns the component count of an OpTypeVector or the column count of an
// OpTypeMatrix. Both are stored as a single 32-bit literal in in-operand 1,
// so no multi-word decoding (and no 32-bit shift) is needed.
size_t ScalarReplacementPass::GetNumElements(
    const ir::Instruction* type) const {
  assert(type->opcode() == SpvOpTypeVector ||
         type->opcode() == SpvOpTypeMatrix);
  return static_cast<size_t>(type->GetSingleWordInOperand(1u));
}
// Returns the type instruction of the object stored in OpVariable |inst|.
// The variable's result type is an OpTypePointer whose in-operand 1 is the
// pointee type.
ir::Instruction* ScalarReplacementPass::GetStorageType(
    const ir::Instruction* inst) const {
  assert(inst->opcode() == SpvOpVariable);
  const ir::Instruction* pointerType =
      get_def_use_mgr()->GetDef(inst->type_id());
  const uint32_t pointeeId = pointerType->GetSingleWordInOperand(1u);
  return get_def_use_mgr()->GetDef(pointeeId);
}
// Returns true if OpVariable |varInst| is a candidate for scalar
// replacement. All of the following must hold, checked in this order:
// - it has Function storage;
// - its pointer type carries only accepted decorations;
// - its stored type is a supported composite;
// - its own decorations are accepted;
// - its uses are all replaceable.
bool ScalarReplacementPass::CanReplaceVariable(
    const ir::Instruction* varInst) const {
  assert(varInst->opcode() == SpvOpVariable);
  const bool isFunctionScope =
      varInst->GetSingleWordInOperand(0u) == SpvStorageClassFunction;
  if (!isFunctionScope) return false;

  const ir::Instruction* pointerType =
      get_def_use_mgr()->GetDef(varInst->type_id());
  if (!CheckTypeAnnotations(pointerType)) return false;
  if (!CheckType(GetStorageType(varInst))) return false;
  if (!CheckAnnotations(varInst)) return false;
  return CheckUses(varInst);
}
// Returns true if |typeInst| is a composite the pass will split. This means
// it has only accepted decorations and is either:
// - a non-empty struct with at most MAX_NUM_ELEMENTS members, or
// - an array with a constant (non-specialization) length of at most
//   MAX_NUM_ELEMENTS.
bool ScalarReplacementPass::CheckType(const ir::Instruction* typeInst) const {
  if (!CheckTypeAnnotations(typeInst)) return false;
  switch (typeInst->opcode()) {
    case SpvOpTypeStruct:
      // Don't bother with empty structs or very large structs.
      if (typeInst->NumInOperands() == 0 ||
          typeInst->NumInOperands() > MAX_NUM_ELEMENTS)
        return false;
      return true;
    case SpvOpTypeArray: {
      // An array sized by a specialization constant has no fixed element
      // count here, and GetArrayLength only decodes OpConstant/OpConstantNull
      // lengths, so such arrays are not split.
      const ir::Instruction* length =
          get_def_use_mgr()->GetDef(typeInst->GetSingleWordInOperand(1u));
      if (ir::IsSpecConstantInst(length->opcode())) return false;
      if (GetArrayLength(typeInst) > MAX_NUM_ELEMENTS) return false;
      return true;
    }
    // TODO(alanbaker): Develop some heuristics for when this should be
    // re-enabled.
    //// Specifically including matrix and vector in an attempt to reduce the
    //// number of vector registers required.
    // case SpvOpTypeMatrix:
    // case SpvOpTypeVector:
    //   if (GetNumElements(typeInst) > MAX_NUM_ELEMENTS) return false;
    //   return true;
    case SpvOpTypeRuntimeArray:
    default:
      return false;
  }
}
// Returns true if every OpDecorate / OpMemberDecorate on |typeInst| uses one
// of the accepted decorations: RowMajor, ColMajor, ArrayStride, MatrixStride,
// CPacked, Invariant, Restrict, Offset, Alignment, AlignmentId or
// MaxByteOffset. Any other decoration disqualifies the type.
bool ScalarReplacementPass::CheckTypeAnnotations(
    const ir::Instruction* typeInst) const {
  for (auto decorateInst :
       get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
    // OpDecorate keeps the decoration in in-operand 1. OpMemberDecorate has
    // the member index there and the decoration in in-operand 2.
    const bool isMemberDecoration = decorateInst->opcode() != SpvOpDecorate;
    assert(!isMemberDecoration ||
           decorateInst->opcode() == SpvOpMemberDecorate);
    const uint32_t decoration =
        decorateInst->GetSingleWordInOperand(isMemberDecoration ? 2u : 1u);
    switch (decoration) {
      case SpvDecorationRowMajor:
      case SpvDecorationColMajor:
      case SpvDecorationArrayStride:
      case SpvDecorationMatrixStride:
      case SpvDecorationCPacked:
      case SpvDecorationInvariant:
      case SpvDecorationRestrict:
      case SpvDecorationOffset:
      case SpvDecorationAlignment:
      case SpvDecorationAlignmentId:
      case SpvDecorationMaxByteOffset:
        continue;
      default:
        return false;
    }
  }
  return true;
}
// Returns true if every decoration applied directly to OpVariable |varInst|
// is one of Invariant, Restrict, Alignment, AlignmentId or MaxByteOffset.
bool ScalarReplacementPass::CheckAnnotations(
    const ir::Instruction* varInst) const {
  for (auto decorateInst :
       get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
    assert(decorateInst->opcode() == SpvOpDecorate);
    const uint32_t decoration = decorateInst->GetSingleWordInOperand(1u);
    const bool accepted = decoration == SpvDecorationInvariant ||
                          decoration == SpvDecorationRestrict ||
                          decoration == SpvDecorationAlignment ||
                          decoration == SpvDecorationAlignmentId ||
                          decoration == SpvDecorationMaxByteOffset;
    if (!accepted) return false;
  }
  return true;
}
// Returns true if all uses of |inst| are replaceable and at least one of
// them accesses an individual element (an access chain); variables that are
// only ever loaded or stored whole are not worth splitting.
bool ScalarReplacementPass::CheckUses(const ir::Instruction* inst) const {
  VariableStats stats = {0, 0};
  const bool usesLegal = CheckUses(inst, &stats);
  // TODO(alanbaker): Extend this to some meaningful heuristics about when
  // SRoA is valuable.
  return usesLegal && stats.num_partial_accesses != 0;
}
// Checks every use of OpVariable |inst| and tallies full and partial
// accesses into |stats|. Returns true only if every non-annotation use is
// one of:
// - a non-volatile whole load or store;
// - an access chain based on |inst| whose first index is an integer constant
//   within the variable's element count, and whose own uses pass
//   CheckUsesRelaxed;
// - OpName / OpMemberName.
bool ScalarReplacementPass::CheckUses(const ir::Instruction* inst,
                                      VariableStats* stats) const {
  // Number of top-level elements of the stored composite. A constant first
  // index outside [0, numElements) cannot be mapped onto a replacement
  // variable. Rejecting it here avoids a later failure after partial
  // rewriting.
  const ir::Instruction* typeInst = GetStorageType(inst);
  size_t numElements = 0;
  switch (typeInst->opcode()) {
    case SpvOpTypeStruct:
      numElements = typeInst->NumInOperands();
      break;
    case SpvOpTypeArray:
      numElements = GetArrayLength(typeInst);
      break;
    case SpvOpTypeMatrix:
    case SpvOpTypeVector:
      numElements = GetNumElements(typeInst);
      break;
    default:
      // Leaves numElements at 0, so every partial access is rejected.
      break;
  }

  bool ok = true;
  get_def_use_mgr()->ForEachUse(
      inst, [this, stats, numElements, &ok](const ir::Instruction* user,
                                            uint32_t index) {
        // Annotations are checked as a group separately.
        if (ir::IsAnnotationInst(user->opcode())) return;
        switch (user->opcode()) {
          case SpvOpAccessChain:
          case SpvOpInBoundsAccessChain:
            // |inst| must be the base (operand 2), and the chain must have at
            // least one index (in-operand 1, i.e. operand 3).
            if (index == 2u && user->NumInOperands() > 1) {
              uint32_t id = user->GetSingleWordOperand(3u);
              const ir::Instruction* opInst = get_def_use_mgr()->GetDef(id);
              const SpvOp indexOp = opInst->opcode();
              // Only integer OpConstant / OpConstantNull indexes can be
              // decoded into an element number.
              if (indexOp != SpvOpConstant && indexOp != SpvOpConstantNull) {
                ok = false;
              } else if (GetConstantInteger(opInst) >= numElements) {
                ok = false;
              } else if (!CheckUsesRelaxed(user)) {
                ok = false;
              }
              stats->num_partial_accesses++;
            } else {
              ok = false;
            }
            break;
          case SpvOpLoad:
            if (!CheckLoad(user, index)) ok = false;
            stats->num_full_accesses++;
            break;
          case SpvOpStore:
            if (!CheckStore(user, index)) ok = false;
            stats->num_full_accesses++;
            break;
          case SpvOpName:
          case SpvOpMemberName:
            break;
          default:
            ok = false;
            break;
        }
      });
  return ok;
}
// Returns true if every use of |inst| (an access chain derived from a
// candidate variable) is one of:
// - a non-volatile load through it;
// - a non-volatile store to it;
// - a further access chain based on it that itself passes this check.
// Any other use, including annotations, rejects it.
bool ScalarReplacementPass::CheckUsesRelaxed(
    const ir::Instruction* inst) const {
  bool allAllowed = true;
  get_def_use_mgr()->ForEachUse(
      inst, [this, &allAllowed](const ir::Instruction* user, uint32_t index) {
        bool allowed = false;
        switch (user->opcode()) {
          case SpvOpAccessChain:
          case SpvOpInBoundsAccessChain:
            // Must be the base operand; only then recurse into its uses.
            allowed = index == 2u && CheckUsesRelaxed(user);
            break;
          case SpvOpLoad:
            allowed = CheckLoad(user, index);
            break;
          case SpvOpStore:
            allowed = CheckStore(user, index);
            break;
          default:
            allowed = false;
            break;
        }
        if (!allowed) allAllowed = false;
      });
  return allAllowed;
}
// Returns true if OpLoad |inst| uses the checked pointer as its pointer
// operand (operand |index| == 2) and is not a volatile load.
bool ScalarReplacementPass::CheckLoad(const ir::Instruction* inst,
                                      uint32_t index) const {
  if (index != 2u) return false;
  // In-operand 1, when present, is the memory access mask.
  const bool isVolatile =
      inst->NumInOperands() >= 2 &&
      (inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask) != 0;
  return !isVolatile;
}
// Returns true if OpStore |inst| writes through the checked pointer (operand
// |index| == 0, not the stored object) and is not a volatile store.
bool ScalarReplacementPass::CheckStore(const ir::Instruction* inst,
                                       uint32_t index) const {
  if (index != 0u) return false;
  // In-operand 2, when present, is the memory access mask.
  const bool isVolatile =
      inst->NumInOperands() >= 3 &&
      (inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask) != 0;
  return !isVolatile;
}
} // namespace opt
} // namespace spvtools