spirv-fuzz: Fix memory management in the fact manager (#3313)

Fixes a bug where, while recursively adding id equation facts, a
reference to a set of id equations could be used after it had been
freed (due to equivalence classes of equations being merged).
This commit is contained in:
Alastair Donaldson 2020-04-27 14:24:11 +01:00 committed by GitHub
parent d158ffe540
commit b74199a22d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -430,6 +430,9 @@ class FactManager::DataSynonymAndIdEquationFacts {
uint32_t maximum_equivalence_class_size);
private:
using OperationSet =
std::unordered_set<Operation, OperationHash, OperationEquals>;
// Adds the synonym |dd1| = |dd2| to the set of managed facts, and recurses
// into sub-components of the data descriptors, if they are composites, to
// record that their components are pairwise-synonymous.
@ -448,6 +451,8 @@ class FactManager::DataSynonymAndIdEquationFacts {
opt::IRContext* context, const protobufs::DataDescriptor& dd1,
const protobufs::DataDescriptor& dd2) const;
OperationSet GetEquations(const protobufs::DataDescriptor* lhs) const;
// Requires that |lhs_dd| and every element of |rhs_dds| is present in the
// |synonymous_| equivalence relation, but is not necessarily its own
// representative. Records the fact that the equation
@ -480,9 +485,7 @@ class FactManager::DataSynonymAndIdEquationFacts {
// All data descriptors occurring in equations are required to be present in
// the |synonymous_| equivalence relation, and to be their own representatives
// in that relation.
std::unordered_map<
const protobufs::DataDescriptor*,
std::unordered_set<Operation, OperationHash, OperationEquals>>
std::unordered_map<const protobufs::DataDescriptor*, OperationSet>
id_equations_;
};
@ -520,6 +523,16 @@ void FactManager::DataSynonymAndIdEquationFacts::AddFact(
rhs_dd_ptrs, context);
}
FactManager::DataSynonymAndIdEquationFacts::OperationSet
FactManager::DataSynonymAndIdEquationFacts::GetEquations(
const protobufs::DataDescriptor* lhs) const {
auto existing = id_equations_.find(lhs);
if (existing == id_equations_.end()) {
return OperationSet();
}
return existing->second;
}
void FactManager::DataSynonymAndIdEquationFacts::AddEquationFactRecursive(
const protobufs::DataDescriptor& lhs_dd, SpvOp opcode,
const std::vector<const protobufs::DataDescriptor*>& rhs_dds,
@ -538,9 +551,7 @@ void FactManager::DataSynonymAndIdEquationFacts::AddEquationFactRecursive(
if (id_equations_.count(lhs_dd_representative) == 0) {
// We have not seen an equation with this LHS before, so associate the LHS
// with an initially empty set.
id_equations_.insert(
{lhs_dd_representative,
std::unordered_set<Operation, OperationHash, OperationEquals>()});
id_equations_.insert({lhs_dd_representative, OperationSet()});
}
{
@ -562,44 +573,29 @@ void FactManager::DataSynonymAndIdEquationFacts::AddEquationFactRecursive(
switch (opcode) {
case SpvOpIAdd: {
// Equation form: "a = b + c"
{
auto existing_first_operand_equations = id_equations_.find(rhs_dds[0]);
if (existing_first_operand_equations != id_equations_.end()) {
for (auto equation : existing_first_operand_equations->second) {
if (equation.opcode == SpvOpISub) {
// Equation form: "a = (d - e) + c"
if (synonymous_.IsEquivalent(*equation.operands[1],
*rhs_dds[1])) {
// Equation form: "a = (d - c) + c"
// We can thus infer "a = d"
AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0],
context);
}
if (synonymous_.IsEquivalent(*equation.operands[0],
*rhs_dds[1])) {
// Equation form: "a = (c - e) + c"
// We can thus infer "a = -e"
AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
{equation.operands[1]}, context);
}
}
for (auto equation : GetEquations(rhs_dds[0])) {
if (equation.opcode == SpvOpISub) {
// Equation form: "a = (d - e) + c"
if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[1])) {
// Equation form: "a = (d - c) + c"
// We can thus infer "a = d"
AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context);
}
if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[1])) {
// Equation form: "a = (c - e) + c"
// We can thus infer "a = -e"
AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
{equation.operands[1]}, context);
}
}
}
{
auto existing_second_operand_equations = id_equations_.find(rhs_dds[1]);
if (existing_second_operand_equations != id_equations_.end()) {
for (auto equation : existing_second_operand_equations->second) {
if (equation.opcode == SpvOpISub) {
// Equation form: "a = b + (d - e)"
if (synonymous_.IsEquivalent(*equation.operands[1],
*rhs_dds[0])) {
// Equation form: "a = b + (d - b)"
// We can thus infer "a = d"
AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0],
context);
}
}
for (auto equation : GetEquations(rhs_dds[1])) {
if (equation.opcode == SpvOpISub) {
// Equation form: "a = b + (d - e)"
if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[0])) {
// Equation form: "a = b + (d - b)"
// We can thus infer "a = d"
AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context);
}
}
}
@ -607,73 +603,54 @@ void FactManager::DataSynonymAndIdEquationFacts::AddEquationFactRecursive(
}
case SpvOpISub: {
// Equation form: "a = b - c"
{
auto existing_first_operand_equations = id_equations_.find(rhs_dds[0]);
if (existing_first_operand_equations != id_equations_.end()) {
for (auto equation : existing_first_operand_equations->second) {
if (equation.opcode == SpvOpIAdd) {
// Equation form: "a = (d + e) - c"
if (synonymous_.IsEquivalent(*equation.operands[0],
*rhs_dds[1])) {
// Equation form: "a = (c + e) - c"
// We can thus infer "a = e"
AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1],
context);
}
if (synonymous_.IsEquivalent(*equation.operands[1],
*rhs_dds[1])) {
// Equation form: "a = (d + c) - c"
// We can thus infer "a = d"
AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0],
context);
}
}
for (auto equation : GetEquations(rhs_dds[0])) {
if (equation.opcode == SpvOpIAdd) {
// Equation form: "a = (d + e) - c"
if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[1])) {
// Equation form: "a = (c + e) - c"
// We can thus infer "a = e"
AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1], context);
}
if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[1])) {
// Equation form: "a = (d + c) - c"
// We can thus infer "a = d"
AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context);
}
}
if (equation.opcode == SpvOpISub) {
// Equation form: "a = (d - e) - c"
if (synonymous_.IsEquivalent(*equation.operands[0],
*rhs_dds[1])) {
// Equation form: "a = (c - e) - c"
// We can thus infer "a = -e"
AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
{equation.operands[1]}, context);
}
}
if (equation.opcode == SpvOpISub) {
// Equation form: "a = (d - e) - c"
if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[1])) {
// Equation form: "a = (c - e) - c"
// We can thus infer "a = -e"
AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
{equation.operands[1]}, context);
}
}
}
{
auto existing_second_operand_equations = id_equations_.find(rhs_dds[1]);
if (existing_second_operand_equations != id_equations_.end()) {
for (auto equation : existing_second_operand_equations->second) {
if (equation.opcode == SpvOpIAdd) {
// Equation form: "a = b - (d + e)"
if (synonymous_.IsEquivalent(*equation.operands[0],
*rhs_dds[0])) {
// Equation form: "a = b - (b + e)"
// We can thus infer "a = -e"
AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
{equation.operands[1]}, context);
}
if (synonymous_.IsEquivalent(*equation.operands[1],
*rhs_dds[0])) {
// Equation form: "a = b - (d + b)"
// We can thus infer "a = -d"
AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
{equation.operands[0]}, context);
}
}
if (equation.opcode == SpvOpISub) {
// Equation form: "a = b - (d - e)"
if (synonymous_.IsEquivalent(*equation.operands[0],
*rhs_dds[0])) {
// Equation form: "a = b - (b - e)"
// We can thus infer "a = e"
AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1],
context);
}
}
for (auto equation : GetEquations(rhs_dds[1])) {
if (equation.opcode == SpvOpIAdd) {
// Equation form: "a = b - (d + e)"
if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[0])) {
// Equation form: "a = b - (b + e)"
// We can thus infer "a = -e"
AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
{equation.operands[1]}, context);
}
if (synonymous_.IsEquivalent(*equation.operands[1], *rhs_dds[0])) {
// Equation form: "a = b - (d + b)"
// We can thus infer "a = -d"
AddEquationFactRecursive(lhs_dd, SpvOpSNegate,
{equation.operands[0]}, context);
}
}
if (equation.opcode == SpvOpISub) {
// Equation form: "a = b - (d - e)"
if (synonymous_.IsEquivalent(*equation.operands[0], *rhs_dds[0])) {
// Equation form: "a = b - (b - e)"
// We can thus infer "a = e"
AddDataSynonymFactRecursive(lhs_dd, *equation.operands[1], context);
}
}
}
@ -682,14 +659,11 @@ void FactManager::DataSynonymAndIdEquationFacts::AddEquationFactRecursive(
case SpvOpLogicalNot:
case SpvOpSNegate: {
// Equation form: "a = !b" or "a = -b"
auto existing_equations = id_equations_.find(rhs_dds[0]);
if (existing_equations != id_equations_.end()) {
for (auto equation : existing_equations->second) {
if (equation.opcode == opcode) {
// Equation form: "a = !!b" or "a = -(-b)"
// We can thus infer "a = b"
AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context);
}
for (auto equation : GetEquations(rhs_dds[0])) {
if (equation.opcode == opcode) {
// Equation form: "a = !!b" or "a = -(-b)"
// We can thus infer "a = b"
AddDataSynonymFactRecursive(lhs_dd, *equation.operands[0], context);
}
}
break;
@ -1116,9 +1090,7 @@ void FactManager::DataSynonymAndIdEquationFacts::MakeEquivalent(
// equations about |still_representative|; create an empty set of equations
// if this is the case.
if (!id_equations_.count(still_representative)) {
id_equations_.insert(
{still_representative,
std::unordered_set<Operation, OperationHash, OperationEquals>()});
id_equations_.insert({still_representative, OperationSet()});
}
auto still_representative_id_equations =
id_equations_.find(still_representative);