spirv-diff: Refactor instruction grouping and matching (#4760)

In preparation for supporting OpTypeForwardPointer, which adds more
usages like this.  This change refactors common code used to group
instructions and match the groups.
This commit is contained in:
Shahbaz Youssefi 2022-03-24 14:04:48 -04:00 committed by GitHub
parent 90728d2dff
commit 48c8363f0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -40,8 +40,8 @@ using FunctionInstMap = std::map<uint32_t, InstructionList>;
// A list of ids with some similar property, for example functions with the same
// name.
using IdGroup = std::vector<uint32_t>;
// A map of function names to function ids with the same name. This is an
// ordered map so different implementations produce identical results.
// A map of names to ids with the same name. This is an ordered map so
// different implementations produce identical results.
using IdGroupMapByName = std::map<std::string, IdGroup>;
using IdGroupMapByTypeId = std::map<uint32_t, IdGroup>;
@ -268,7 +268,7 @@ class Differ {
// Helper functions that match ids between src and dst
void PoolPotentialIds(
opt::IteratorRange<opt::Module::const_inst_iterator> section,
std::vector<uint32_t>& ids,
std::vector<uint32_t>& ids, bool is_src,
std::function<bool(const opt::Instruction&)> filter,
std::function<uint32_t(const opt::Instruction&)> get_id);
void MatchIds(
@ -292,6 +292,42 @@ class Differ {
opt::IteratorRange<opt::Module::const_inst_iterator> src_insts,
opt::IteratorRange<opt::Module::const_inst_iterator> dst_insts);
// Get various properties from an id. These Helper functions are passed to
// `GroupIds` and `GroupIdsAndMatch` below (as the `get_group` argument).
uint32_t GroupIdsHelperGetTypeId(const IdInstructions& id_to, uint32_t id);
// Given a list of ids, groups them based on some value. The `get_group`
// function extracts a piece of information corresponding to each id, and the
// ids are bucketed based on that (and output in `groups`). This is useful to
// attempt to match ids between src and dst only when said property is
// identical.
template <typename T>
void GroupIds(const IdGroup& ids, bool is_src, std::map<T, IdGroup>* groups,
T (Differ::*get_group)(const IdInstructions&, uint32_t));
// Calls GroupIds to bucket ids in src and dst based on a property returned by
// `get_group`. This function then calls `match_group` for each bucket (i.e.
// "group") with identical values for said property.
//
// For example, say src and dst ids have the following properties
// correspondingly:
//
// - src ids' properties: {id0: A, id1: A, id2: B, id3: C, id4: B}
// - dst ids' properties: {id0': B, id1': C, id2': B, id3': D, id4': B}
//
// Then `match_group` is called 2 times:
//
// - Once with: ([id2, id4], [id0', id2', id4']) corresponding to B
// - Once with: ([id3], [id2']) corresponding to C
//
// Ids corresponding to A and D cannot match based on this property.
template <typename T>
void GroupIdsAndMatch(
const IdGroup& src_ids, const IdGroup& dst_ids, T invalid_group_key,
T (Differ::*get_group)(const IdInstructions&, uint32_t),
std::function<void(const IdGroup& src_group, const IdGroup& dst_group)>
match_group);
// Helper functions that determine if two instructions match
bool DoIdsMatch(uint32_t src_id, uint32_t dst_id);
bool DoesOperandMatch(const opt::Operand& src_operand,
@ -335,14 +371,6 @@ class Differ {
FunctionInstMap* function_insts);
void GetFunctionHeaderInstructions(const opt::Module* module,
FunctionInstMap* function_insts);
void GroupIdsByName(const IdGroup& functions, bool is_src,
IdGroupMapByName* groups);
void GroupIdsByTypeId(const IdGroup& functions, bool is_src,
IdGroupMapByTypeId* groups);
template <typename T>
void GroupIds(const IdGroup& functions, bool is_src,
std::map<T, IdGroup>* groups,
std::function<T(const IdInstructions, uint32_t)> get_group);
void BestEffortMatchFunctions(const IdGroup& src_func_ids,
const IdGroup& dst_func_ids,
const FunctionInstMap& src_func_insts,
@ -374,14 +402,17 @@ class Differ {
uint32_t GetConstantUint(const IdInstructions& id_to, uint32_t constant_id);
SpvExecutionModel GetExecutionModel(const opt::Module* module,
uint32_t entry_point_id);
// Get the OpName associated with an id
std::string GetName(const IdInstructions& id_to, uint32_t id, bool* has_name);
std::string GetFunctionName(const IdInstructions& id_to, uint32_t id);
// Get the OpName associated with an id, with argument types stripped for
// functions. Some tools don't encode function argument types in the OpName
// string, and this improves diff between SPIR-V from those tools and others.
std::string GetSanitizedName(const IdInstructions& id_to, uint32_t id);
uint32_t GetVarTypeId(const IdInstructions& id_to, uint32_t var_id,
SpvStorageClass* storage_class);
bool GetDecorationValue(const IdInstructions& id_to, uint32_t id,
SpvDecoration decoration, uint32_t* decoration_value);
bool IsIntType(const IdInstructions& id_to, uint32_t type_id);
// bool IsUintType(const IdInstructions& id_to, uint32_t type_id);
bool IsFloatType(const IdInstructions& id_to, uint32_t type_id);
bool IsConstantUint(const IdInstructions& id_to, uint32_t id);
bool IsVariable(const IdInstructions& id_to, uint32_t pointer_id);
@ -548,18 +579,27 @@ void IdInstructions::MapIdsToInfos(
void Differ::PoolPotentialIds(
opt::IteratorRange<opt::Module::const_inst_iterator> section,
std::vector<uint32_t>& ids,
std::vector<uint32_t>& ids, bool is_src,
std::function<bool(const opt::Instruction&)> filter,
std::function<uint32_t(const opt::Instruction&)> get_id) {
for (const opt::Instruction& inst : section) {
if (!filter(inst)) {
continue;
}
uint32_t result_id = get_id(inst);
assert(result_id != 0);
assert(std::find(ids.begin(), ids.end(), result_id) == ids.end());
// Don't include ids that are already matched, for example through
// OpTypeForwardPointer.
const bool is_matched = is_src ? id_map_.IsSrcMapped(result_id)
: id_map_.IsDstMapped(result_id);
if (is_matched) {
continue;
}
ids.push_back(result_id);
}
}
@ -748,6 +788,62 @@ void Differ::MatchDebugAndAnnotationInstructions(
}
}
uint32_t Differ::GroupIdsHelperGetTypeId(const IdInstructions& id_to,
uint32_t id) {
return GetInst(id_to, id)->type_id();
}
template <typename T>
void Differ::GroupIds(const IdGroup& ids, bool is_src,
std::map<T, IdGroup>* groups,
T (Differ::*get_group)(const IdInstructions&, uint32_t)) {
assert(groups->empty());
const IdInstructions& id_to = is_src ? src_id_to_ : dst_id_to_;
for (const uint32_t id : ids) {
// Don't include ids that are already matched, for example through
// OpEntryPoint.
const bool is_matched =
is_src ? id_map_.IsSrcMapped(id) : id_map_.IsDstMapped(id);
if (is_matched) {
continue;
}
T group = (this->*get_group)(id_to, id);
(*groups)[group].push_back(id);
}
}
template <typename T>
void Differ::GroupIdsAndMatch(
const IdGroup& src_ids, const IdGroup& dst_ids, T invalid_group_key,
T (Differ::*get_group)(const IdInstructions&, uint32_t),
std::function<void(const IdGroup& src_group, const IdGroup& dst_group)>
match_group) {
// Group the ids based on a key (get_group)
std::map<T, IdGroup> src_groups;
std::map<T, IdGroup> dst_groups;
GroupIds<T>(src_ids, true, &src_groups, get_group);
GroupIds<T>(dst_ids, false, &dst_groups, get_group);
// Iterate over the groups, and match those with identical keys
for (const auto& iter : src_groups) {
const T& key = iter.first;
const IdGroup& src_group = iter.second;
if (key == invalid_group_key) {
continue;
}
const IdGroup& dst_group = dst_groups[key];
// Let the caller match the groups as appropriate.
match_group(src_group, dst_group);
}
}
bool Differ::DoIdsMatch(uint32_t src_id, uint32_t dst_id) {
assert(dst_id != 0);
return id_map_.MappedDstId(src_id) == dst_id;
@ -1319,28 +1415,6 @@ void Differ::GetFunctionHeaderInstructions(const opt::Module* module,
}
}
template <typename T>
void Differ::GroupIds(
const IdGroup& functions, bool is_src, std::map<T, IdGroup>* groups,
std::function<T(const IdInstructions, uint32_t)> get_group) {
assert(groups->empty());
const IdInstructions& id_to = is_src ? src_id_to_ : dst_id_to_;
for (const uint32_t func_id : functions) {
// Don't include functions that are already matched, for example through
// OpEntryPoint.
const bool is_matched =
is_src ? id_map_.IsSrcMapped(func_id) : id_map_.IsDstMapped(func_id);
if (is_matched) {
continue;
}
T group = get_group(id_to, func_id);
(*groups)[group].push_back(func_id);
}
}
void Differ::BestEffortMatchFunctions(const IdGroup& src_func_ids,
const IdGroup& dst_func_ids,
const FunctionInstMap& src_func_insts,
@ -1361,7 +1435,7 @@ void Differ::BestEffortMatchFunctions(const IdGroup& src_func_ids,
if (id_map_.IsSrcMapped(src_func_id)) {
continue;
}
const std::string src_name = GetFunctionName(src_id_to_, src_func_id);
const std::string src_name = GetSanitizedName(src_id_to_, src_func_id);
for (const uint32_t dst_func_id : dst_func_ids) {
if (id_map_.IsDstMapped(dst_func_id)) {
@ -1369,7 +1443,7 @@ void Differ::BestEffortMatchFunctions(const IdGroup& src_func_ids,
}
// Don't match functions that are named, but the names are different.
const std::string dst_name = GetFunctionName(dst_id_to_, dst_func_id);
const std::string dst_name = GetSanitizedName(dst_id_to_, dst_func_id);
if (src_name != "" && dst_name != "" && src_name != dst_name) {
continue;
}
@ -1406,22 +1480,6 @@ void Differ::BestEffortMatchFunctions(const IdGroup& src_func_ids,
}
}
void Differ::GroupIdsByName(const IdGroup& functions, bool is_src,
IdGroupMapByName* groups) {
GroupIds<std::string>(functions, is_src, groups,
[this](const IdInstructions& id_to, uint32_t func_id) {
return GetFunctionName(id_to, func_id);
});
}
void Differ::GroupIdsByTypeId(const IdGroup& functions, bool is_src,
IdGroupMapByTypeId* groups) {
GroupIds<uint32_t>(functions, is_src, groups,
[this](const IdInstructions& id_to, uint32_t func_id) {
return GetInst(id_to, func_id)->type_id();
});
}
void Differ::MatchFunctionParamIds(const opt::Function* src_func,
const opt::Function* dst_func) {
IdGroup src_params;
@ -1437,52 +1495,33 @@ void Differ::MatchFunctionParamIds(const opt::Function* src_func,
},
false);
IdGroupMapByName src_param_groups;
IdGroupMapByName dst_param_groups;
GroupIdsAndMatch<std::string>(
src_params, dst_params, "", &Differ::GetSanitizedName,
[this](const IdGroup& src_group, const IdGroup& dst_group) {
GroupIdsByName(src_params, true, &src_param_groups);
GroupIdsByName(dst_params, false, &dst_param_groups);
// Match parameters with identical names.
for (const auto& src_param_group : src_param_groups) {
const std::string& name = src_param_group.first;
const IdGroup& src_group = src_param_group.second;
if (name == "") {
continue;
}
const IdGroup& dst_group = dst_param_groups[name];
// There shouldn't be two parameters with the same name, so the ids should
// match. There is nothing restricting the SPIR-V however to have two
// parameters with the same name, so be resilient against that.
if (src_group.size() == 1 && dst_group.size() == 1) {
id_map_.MapIds(src_group[0], dst_group[0]);
}
}
// There shouldn't be two parameters with the same name, so the ids
// should match. There is nothing restricting the SPIR-V however to have
// two parameters with the same name, so be resilient against that.
if (src_group.size() == 1 && dst_group.size() == 1) {
id_map_.MapIds(src_group[0], dst_group[0]);
}
});
// Then match the parameters by their type. If there are multiple of them,
// match them by their order.
IdGroupMapByTypeId src_param_groups_by_type_id;
IdGroupMapByTypeId dst_param_groups_by_type_id;
GroupIdsAndMatch<uint32_t>(
src_params, dst_params, 0, &Differ::GroupIdsHelperGetTypeId,
[this](const IdGroup& src_group_by_type_id,
const IdGroup& dst_group_by_type_id) {
GroupIdsByTypeId(src_params, true, &src_param_groups_by_type_id);
GroupIdsByTypeId(dst_params, false, &dst_param_groups_by_type_id);
const size_t shared_param_count =
std::min(src_group_by_type_id.size(), dst_group_by_type_id.size());
for (const auto& src_param_group_by_type_id : src_param_groups_by_type_id) {
const uint32_t type_id = src_param_group_by_type_id.first;
const IdGroup& src_group_by_type_id = src_param_group_by_type_id.second;
const IdGroup& dst_group_by_type_id = dst_param_groups_by_type_id[type_id];
const size_t shared_param_count =
std::min(src_group_by_type_id.size(), dst_group_by_type_id.size());
for (size_t param_index = 0; param_index < shared_param_count;
++param_index) {
id_map_.MapIds(src_group_by_type_id[0], dst_group_by_type_id[0]);
}
}
for (size_t param_index = 0; param_index < shared_param_count;
++param_index) {
id_map_.MapIds(src_group_by_type_id[0], dst_group_by_type_id[0]);
}
});
}
float Differ::MatchFunctionBodies(const InstructionList& src_body,
@ -1626,7 +1665,7 @@ std::string Differ::GetName(const IdInstructions& id_to, uint32_t id,
return "";
}
std::string Differ::GetFunctionName(const IdInstructions& id_to, uint32_t id) {
std::string Differ::GetSanitizedName(const IdInstructions& id_to, uint32_t id) {
bool has_name = false;
std::string name = GetName(id_to, id, &has_name);
@ -1634,7 +1673,7 @@ std::string Differ::GetFunctionName(const IdInstructions& id_to, uint32_t id) {
return "";
}
// Remove args from the name
// Remove args from the name, in case this is a function name
return name.substr(0, name.find('('));
}
@ -1672,19 +1711,8 @@ bool Differ::GetDecorationValue(const IdInstructions& id_to, uint32_t id,
bool Differ::IsIntType(const IdInstructions& id_to, uint32_t type_id) {
return IsOp(id_to, type_id, SpvOpTypeInt);
#if 0
const opt::Instruction *type_inst = GetInst(id_to, type_id);
return type_inst->opcode() == SpvOpTypeInt && type_inst->GetInOperand(1).words[0] != 0;
#endif
}
#if 0
bool Differ::IsUintType(const IdInstructions& id_to, uint32_t type_id) {
const opt::Instruction *type_inst = GetInst(id_to, type_id);
return type_inst->opcode() == SpvOpTypeInt && type_inst->GetInOperand(1).words[0] == 0;
}
#endif
bool Differ::IsFloatType(const IdInstructions& id_to, uint32_t type_id) {
return IsOp(id_to, type_id, SpvOpTypeFloat);
}
@ -1853,9 +1881,9 @@ void Differ::MatchExtInstImportIds() {
};
auto accept_all = [](const opt::Instruction&) { return true; };
PoolPotentialIds(src_->ext_inst_imports(), potential_id_map.src_ids,
PoolPotentialIds(src_->ext_inst_imports(), potential_id_map.src_ids, true,
accept_all, get_result_id);
PoolPotentialIds(dst_->ext_inst_imports(), potential_id_map.dst_ids,
PoolPotentialIds(dst_->ext_inst_imports(), potential_id_map.dst_ids, false,
accept_all, get_result_id);
// Then match the ids.
@ -1949,9 +1977,9 @@ void Differ::MatchTypeIds() {
return spvOpcodeGeneratesType(inst.opcode());
};
PoolPotentialIds(src_->types_values(), potential_id_map.src_ids,
PoolPotentialIds(src_->types_values(), potential_id_map.src_ids, true,
accept_type_ops, get_result_id);
PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids,
PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids, false,
accept_type_ops, get_result_id);
// Then match the ids. Start with exact matches, then match the leftover with
@ -2036,9 +2064,9 @@ void Differ::MatchConstants() {
return spvOpcodeIsConstant(inst.opcode());
};
PoolPotentialIds(src_->types_values(), potential_id_map.src_ids,
PoolPotentialIds(src_->types_values(), potential_id_map.src_ids, true,
accept_type_ops, get_result_id);
PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids,
PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids, false,
accept_type_ops, get_result_id);
// Then match the ids. Constants are matched exactly, except for float types
@ -2115,9 +2143,9 @@ void Differ::MatchVariableIds() {
return inst.opcode() == SpvOpVariable;
};
PoolPotentialIds(src_->types_values(), potential_id_map.src_ids,
PoolPotentialIds(src_->types_values(), potential_id_map.src_ids, true,
accept_type_ops, get_result_id);
PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids,
PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids, false,
accept_type_ops, get_result_id);
// Then match the ids. Start with exact matches, then match the leftover with
@ -2148,49 +2176,31 @@ void Differ::MatchFunctions() {
}
// Base the matching of functions on debug info when available.
IdGroupMapByName src_func_groups;
IdGroupMapByName dst_func_groups;
GroupIdsAndMatch<std::string>(
src_func_ids, dst_func_ids, "", &Differ::GetSanitizedName,
[this](const IdGroup& src_group, const IdGroup& dst_group) {
GroupIdsByName(src_func_ids, true, &src_func_groups);
GroupIdsByName(dst_func_ids, false, &dst_func_groups);
// If there is a single function with this name in src and dst, it's a
// definite match.
if (src_group.size() == 1 && dst_group.size() == 1) {
id_map_.MapIds(src_group[0], dst_group[0]);
return;
}
// Match functions with identical names.
for (const auto& src_func_group : src_func_groups) {
const std::string& name = src_func_group.first;
const IdGroup& src_group = src_func_group.second;
// If there are multiple functions with the same name, group them by
// type, and match only if the types match (and are unique).
GroupIdsAndMatch<uint32_t>(src_group, dst_group, 0,
&Differ::GroupIdsHelperGetTypeId,
[this](const IdGroup& src_group_by_type_id,
const IdGroup& dst_group_by_type_id) {
if (name == "") {
continue;
}
const IdGroup& dst_group = dst_func_groups[name];
// If there is a single function with this name in src and dst, it's a
// definite match.
if (src_group.size() == 1 && dst_group.size() == 1) {
id_map_.MapIds(src_group[0], dst_group[0]);
continue;
}
// If there are multiple functions with the same name, group them by type,
// and match only if the types match (and are unique).
IdGroupMapByTypeId src_func_groups_by_type_id;
IdGroupMapByTypeId dst_func_groups_by_type_id;
GroupIdsByTypeId(src_group, true, &src_func_groups_by_type_id);
GroupIdsByTypeId(dst_group, false, &dst_func_groups_by_type_id);
for (const auto& src_func_group_by_type_id : src_func_groups_by_type_id) {
const uint32_t type_id = src_func_group_by_type_id.first;
const IdGroup& src_group_by_type_id = src_func_group_by_type_id.second;
const IdGroup& dst_group_by_type_id = dst_func_groups_by_type_id[type_id];
if (src_group_by_type_id.size() == 1 &&
dst_group_by_type_id.size() == 1) {
id_map_.MapIds(src_group_by_type_id[0], dst_group_by_type_id[0]);
}
}
}
if (src_group_by_type_id.size() == 1 &&
dst_group_by_type_id.size() == 1) {
id_map_.MapIds(src_group_by_type_id[0],
dst_group_by_type_id[0]);
}
});
});
// Any functions that are left are pooled together and matched as if unnamed,
// with the only exception that two functions with mismatching names are not
@ -2224,20 +2234,14 @@ void Differ::MatchFunctions() {
}
// Best effort match functions with matching type.
IdGroupMapByTypeId src_func_groups_by_type_id;
IdGroupMapByTypeId dst_func_groups_by_type_id;
GroupIdsAndMatch<uint32_t>(
src_func_ids, dst_func_ids, 0, &Differ::GroupIdsHelperGetTypeId,
[this](const IdGroup& src_group_by_type_id,
const IdGroup& dst_group_by_type_id) {
GroupIdsByTypeId(src_func_ids, true, &src_func_groups_by_type_id);
GroupIdsByTypeId(dst_func_ids, false, &dst_func_groups_by_type_id);
for (const auto& src_func_group_by_type_id : src_func_groups_by_type_id) {
const uint32_t type_id = src_func_group_by_type_id.first;
const IdGroup& src_group_by_type_id = src_func_group_by_type_id.second;
const IdGroup& dst_group_by_type_id = dst_func_groups_by_type_id[type_id];
BestEffortMatchFunctions(src_group_by_type_id, dst_group_by_type_id,
src_func_insts_, dst_func_insts_);
}
BestEffortMatchFunctions(src_group_by_type_id, dst_group_by_type_id,
src_func_insts_, dst_func_insts_);
});
// Any function that's left, best effort match them.
BestEffortMatchFunctions(src_func_ids, dst_func_ids, src_func_insts_,