Fix finding constant with particular type. (#1724)

With current implementation, the constant manager does not keep around
two constant with the same value but different types when the types
hash to the same value. So when you start looking for that constant you
will get a constant with the wrong type back.

I've made a few changes to the constant manager to fix this.  First off,
I have changed the map from constant to ids to be an std::multimap.
This way a single constant can be mapped to mutiple ids each
representing a different type.

Then when asking for an id of a constant, we can search all of the ids
associated with that constant in order to find the one with the correct
type.
This commit is contained in:
Steven Perron 2018-07-16 12:36:53 -04:00 committed by Lei Zhang
parent 95b4d47e34
commit 208921efe8
6 changed files with 128 additions and 62 deletions

View File

@ -197,13 +197,25 @@ ConstantFoldingRule FoldCompositeWithConstants() {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
Instruction* type_inst =
context->get_def_use_mgr()->GetDef(inst->type_id());
std::vector<uint32_t> ids;
for (const analysis::Constant* element_const : constants) {
for (uint32_t i = 0; i < constants.size(); ++i) {
const analysis::Constant* element_const = constants[i];
if (element_const == nullptr) {
return nullptr;
}
uint32_t element_id = const_mgr->FindDeclaredConstant(element_const);
uint32_t component_type_id = 0;
if (type_inst->opcode() == SpvOpTypeStruct) {
component_type_id = type_inst->GetSingleWordInOperand(i);
} else if (type_inst->opcode() == SpvOpTypeArray) {
component_type_id = type_inst->GetSingleWordInOperand(0);
}
uint32_t element_id =
const_mgr->FindDeclaredConstant(element_const, component_type_id);
if (element_id == 0) {
return nullptr;
}

View File

@ -131,6 +131,24 @@ std::vector<const Constant*> ConstantManager::GetOperandConstants(
return constants;
}
uint32_t ConstantManager::FindDeclaredConstant(const Constant* c,
uint32_t type_id) const {
c = FindConstant(c);
if (c == nullptr) {
return 0;
}
for (auto range = const_val_to_id_.equal_range(c);
range.first != range.second; ++range.first) {
Instruction* const_def =
context()->get_def_use_mgr()->GetDef(range.first->second);
if (type_id == 0 || const_def->type_id() == type_id) {
return range.first->second;
}
}
return 0;
}
std::vector<const Constant*> ConstantManager::GetConstantsFromIds(
const std::vector<uint32_t>& ids) const {
std::vector<const Constant*> constants;
@ -163,7 +181,7 @@ Instruction* ConstantManager::GetDefiningInstruction(
const Constant* c, uint32_t type_id, Module::inst_iterator* pos) {
assert(type_id == 0 ||
context()->get_type_mgr()->GetType(type_id) == c->type());
uint32_t decl_id = FindDeclaredConstant(c);
uint32_t decl_id = FindDeclaredConstant(c, type_id);
if (decl_id == 0) {
auto iter = context()->types_values_end();
if (pos == nullptr) pos = &iter;
@ -295,8 +313,17 @@ std::unique_ptr<Instruction> ConstantManager::CreateInstruction(
std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction(
uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const {
std::vector<Operand> operands;
Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id);
uint32_t component_index = 0;
for (const Constant* component_const : cc->GetComponents()) {
uint32_t id = FindDeclaredConstant(component_const);
uint32_t component_type_id = 0;
if (type_inst && type_inst->opcode() == SpvOpTypeStruct) {
component_type_id = type_inst->GetSingleWordInOperand(component_index);
} else if (type_inst && type_inst->opcode() == SpvOpTypeArray) {
component_type_id = type_inst->GetSingleWordInOperand(0);
}
uint32_t id = FindDeclaredConstant(component_const, component_type_id);
if (id == 0) {
// Cannot get the id of the component constant, while all components
// should have been added to the module prior to the composite constant.
@ -305,6 +332,7 @@ std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction(
}
operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
std::initializer_list<uint32_t>{id});
component_index++;
}
uint32_t type =
(type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id;

View File

@ -561,10 +561,7 @@ class ConstantManager {
// A helper function to get the id of a collected constant with the pointer
// to the Constant instance. Returns 0 in case the constant is not found.
uint32_t FindDeclaredConstant(const Constant* c) const {
auto iter = const_val_to_id_.find(c);
return (iter != const_val_to_id_.end()) ? iter->second : 0;
}
uint32_t FindDeclaredConstant(const Constant* c, uint32_t type_id) const;
// Returns the canonical constant that has the same structure and value as the
// given Constant |cst|. If none is found, it returns nullptr.
@ -616,8 +613,9 @@ class ConstantManager {
// Records a new mapping between |inst| and |const_value|. This updates the
// two mappings |id_to_const_val_| and |const_val_to_id_|.
void MapConstantToInst(const Constant* const_value, Instruction* inst) {
const_val_to_id_[const_value] = inst->result_id();
id_to_const_val_[inst->result_id()] = const_value;
if (id_to_const_val_.insert({inst->result_id(), const_value}).second) {
const_val_to_id_.insert({const_value, inst->result_id()});
}
}
private:
@ -676,7 +674,7 @@ class ConstantManager {
// result id in the module. This is a mirror map of |id_to_const_val_|. All
// Normal Constants that defining instructions in the module should have
// their Constant and their result id registered here.
std::unordered_map<const Constant*, uint32_t> const_val_to_id_;
std::multimap<const Constant*, uint32_t> const_val_to_id_;
// The constant pool. All created constants are registered here.
std::unordered_set<const Constant*, ConstantHash, ConstantEqual> const_pool_;

View File

@ -51,6 +51,35 @@ TEST_F(ConstantManagerTest, GetDefiningInstruction) {
EXPECT_EQ(const_inst_2->type_id(), 2);
}
TEST_F(ConstantManagerTest, GetDefiningInstruction2) {
const std::string text = R"(
%int = OpTypeInt 32 0
%1 = OpTypeStruct %int
%2 = OpTypeStruct %int
%3 = OpConstantNull %1
%4 = OpConstantNull %2
)";
std::unique_ptr<IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
ASSERT_NE(context, nullptr);
Type* struct_type_1 = context->get_type_mgr()->GetType(1);
NullConstant struct_const_1(struct_type_1->AsStruct());
Instruction* const_inst_1 =
context->get_constant_mgr()->GetDefiningInstruction(&struct_const_1, 1);
EXPECT_EQ(const_inst_1->type_id(), 1);
EXPECT_EQ(const_inst_1->result_id(), 3);
Type* struct_type_2 = context->get_type_mgr()->GetType(2);
NullConstant struct_const_2(struct_type_2->AsStruct());
Instruction* const_inst_2 =
context->get_constant_mgr()->GetDefiningInstruction(&struct_const_2, 2);
EXPECT_EQ(const_inst_2->type_id(), 2);
EXPECT_EQ(const_inst_2->result_id(), 4);
}
} // namespace
} // namespace analysis
} // namespace opt

View File

@ -442,13 +442,13 @@ INSTANTIATE_TEST_CASE_P(
{
"%true = OpConstantTrue %bool",
"%true_0 = OpConstantTrue %bool",
"%spec_bool_t_vec = OpConstantComposite %v2bool %true_0 %true_0",
"%spec_bool_t_vec = OpConstantComposite %v2bool %bool_true %bool_true",
"%false = OpConstantFalse %bool",
"%false_0 = OpConstantFalse %bool",
"%spec_bool_f_vec = OpConstantComposite %v2bool %false_0 %false_0",
"%spec_bool_f_vec = OpConstantComposite %v2bool %bool_false %bool_false",
"%false_1 = OpConstantFalse %bool",
"%false_2 = OpConstantFalse %bool",
"%spec_bool_from_null = OpConstantComposite %v2bool %false_2 %false_2",
"%spec_bool_from_null = OpConstantComposite %v2bool %bool_false %bool_false",
},
},
@ -464,13 +464,13 @@ INSTANTIATE_TEST_CASE_P(
{
"%true = OpConstantTrue %bool",
"%true_0 = OpConstantTrue %bool",
"%spec_bool_t_vec = OpConstantComposite %v2bool %true_0 %true_0",
"%spec_bool_t_vec = OpConstantComposite %v2bool %bool_true %bool_true",
"%false = OpConstantFalse %bool",
"%false_0 = OpConstantFalse %bool",
"%spec_bool_f_vec = OpConstantComposite %v2bool %false_0 %false_0",
"%spec_bool_f_vec = OpConstantComposite %v2bool %bool_false %bool_false",
"%false_1 = OpConstantFalse %bool",
"%false_2 = OpConstantFalse %bool",
"%spec_bool_from_null = OpConstantComposite %v2bool %false_2 %false_2",
"%spec_bool_from_null = OpConstantComposite %v2bool %bool_false %bool_false",
},
},
@ -486,13 +486,13 @@ INSTANTIATE_TEST_CASE_P(
{
"%int_1 = OpConstant %int 1",
"%int_1_0 = OpConstant %int 1",
"%spec_int_one_vec = OpConstantComposite %v2int %int_1_0 %int_1_0",
"%spec_int_one_vec = OpConstantComposite %v2int %signed_one %signed_one",
"%int_0 = OpConstant %int 0",
"%int_0_0 = OpConstant %int 0",
"%spec_int_zero_vec = OpConstantComposite %v2int %int_0_0 %int_0_0",
"%spec_int_zero_vec = OpConstantComposite %v2int %signed_zero %signed_zero",
"%int_0_1 = OpConstant %int 0",
"%int_0_2 = OpConstant %int 0",
"%spec_int_from_null = OpConstantComposite %v2int %int_0_2 %int_0_2",
"%spec_int_from_null = OpConstantComposite %v2int %signed_zero %signed_zero",
},
},
@ -508,13 +508,13 @@ INSTANTIATE_TEST_CASE_P(
{
"%int_1 = OpConstant %int 1",
"%int_1_0 = OpConstant %int 1",
"%spec_int_one_vec = OpConstantComposite %v2int %int_1_0 %int_1_0",
"%spec_int_one_vec = OpConstantComposite %v2int %signed_one %signed_one",
"%int_0 = OpConstant %int 0",
"%int_0_0 = OpConstant %int 0",
"%spec_int_zero_vec = OpConstantComposite %v2int %int_0_0 %int_0_0",
"%spec_int_zero_vec = OpConstantComposite %v2int %signed_zero %signed_zero",
"%int_0_1 = OpConstant %int 0",
"%int_0_2 = OpConstant %int 0",
"%spec_int_from_null = OpConstantComposite %v2int %int_0_2 %int_0_2",
"%spec_int_from_null = OpConstantComposite %v2int %signed_zero %signed_zero",
},
},
@ -530,13 +530,13 @@ INSTANTIATE_TEST_CASE_P(
{
"%uint_1 = OpConstant %uint 1",
"%uint_1_0 = OpConstant %uint 1",
"%spec_uint_one_vec = OpConstantComposite %v2uint %uint_1_0 %uint_1_0",
"%spec_uint_one_vec = OpConstantComposite %v2uint %unsigned_one %unsigned_one",
"%uint_0 = OpConstant %uint 0",
"%uint_0_0 = OpConstant %uint 0",
"%spec_uint_zero_vec = OpConstantComposite %v2uint %uint_0_0 %uint_0_0",
"%spec_uint_zero_vec = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero",
"%uint_0_1 = OpConstant %uint 0",
"%uint_0_2 = OpConstant %uint 0",
"%spec_uint_from_null = OpConstantComposite %v2uint %uint_0_2 %uint_0_2",
"%spec_uint_from_null = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero",
},
},
@ -552,13 +552,13 @@ INSTANTIATE_TEST_CASE_P(
{
"%uint_1 = OpConstant %uint 1",
"%uint_1_0 = OpConstant %uint 1",
"%spec_uint_one_vec = OpConstantComposite %v2uint %uint_1_0 %uint_1_0",
"%spec_uint_one_vec = OpConstantComposite %v2uint %unsigned_one %unsigned_one",
"%uint_0 = OpConstant %uint 0",
"%uint_0_0 = OpConstant %uint 0",
"%spec_uint_zero_vec = OpConstantComposite %v2uint %uint_0_0 %uint_0_0",
"%spec_uint_zero_vec = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero",
"%uint_0_1 = OpConstant %uint 0",
"%uint_0_2 = OpConstant %uint 0",
"%spec_uint_from_null = OpConstantComposite %v2uint %uint_0_2 %uint_0_2",
"%spec_uint_from_null = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero",
},
},
// clang-format on
@ -837,13 +837,13 @@ INSTANTIATE_TEST_CASE_P(
{
"%int_n1 = OpConstant %int -1",
"%int_n1_0 = OpConstant %int -1",
"%v2int_minus_1 = OpConstantComposite %v2int %int_n1_0 %int_n1_0",
"%v2int_minus_1 = OpConstantComposite %v2int %int_n1 %int_n1",
"%int_n2 = OpConstant %int -2",
"%int_n2_0 = OpConstant %int -2",
"%v2int_minus_2 = OpConstantComposite %v2int %int_n2_0 %int_n2_0",
"%v2int_minus_2 = OpConstantComposite %v2int %int_n2 %int_n2",
"%int_0 = OpConstant %int 0",
"%int_0_0 = OpConstant %int 0",
"%v2int_neg_null = OpConstantComposite %v2int %int_0_0 %int_0_0",
"%v2int_neg_null = OpConstantComposite %v2int %signed_zero %signed_zero",
},
},
// vector integer (including null vetors) add, sub, div, mul
@ -866,35 +866,35 @@ INSTANTIATE_TEST_CASE_P(
{
"%int_5 = OpConstant %int 5",
"%int_5_0 = OpConstant %int 5",
"%spec_v2int_iadd = OpConstantComposite %v2int %int_5_0 %int_5_0",
"%spec_v2int_iadd = OpConstantComposite %v2int %int_5 %int_5",
"%int_n4 = OpConstant %int -4",
"%int_n4_0 = OpConstant %int -4",
"%spec_v2int_isub = OpConstantComposite %v2int %int_n4_0 %int_n4_0",
"%spec_v2int_isub = OpConstantComposite %v2int %int_n4 %int_n4",
"%int_n2 = OpConstant %int -2",
"%int_n2_0 = OpConstant %int -2",
"%spec_v2int_sdiv = OpConstantComposite %v2int %int_n2_0 %int_n2_0",
"%spec_v2int_sdiv = OpConstantComposite %v2int %int_n2 %int_n2",
"%int_n6 = OpConstant %int -6",
"%int_n6_0 = OpConstant %int -6",
"%spec_v2int_imul = OpConstantComposite %v2int %int_n6_0 %int_n6_0",
"%spec_v2int_imul = OpConstantComposite %v2int %int_n6 %int_n6",
"%int_n6_1 = OpConstant %int -6",
"%int_n6_2 = OpConstant %int -6",
"%spec_v2int_iadd_null = OpConstantComposite %v2int %int_n6_2 %int_n6_2",
"%spec_v2int_iadd_null = OpConstantComposite %v2int %int_n6 %int_n6",
"%uint_5 = OpConstant %uint 5",
"%uint_5_0 = OpConstant %uint 5",
"%spec_v2uint_iadd = OpConstantComposite %v2uint %uint_5_0 %uint_5_0",
"%spec_v2uint_iadd = OpConstantComposite %v2uint %uint_5 %uint_5",
"%uint_4294967292 = OpConstant %uint 4294967292",
"%uint_4294967292_0 = OpConstant %uint 4294967292",
"%spec_v2uint_isub = OpConstantComposite %v2uint %uint_4294967292_0 %uint_4294967292_0",
"%spec_v2uint_isub = OpConstantComposite %v2uint %uint_4294967292 %uint_4294967292",
"%uint_1431655764 = OpConstant %uint 1431655764",
"%uint_1431655764_0 = OpConstant %uint 1431655764",
"%spec_v2uint_udiv = OpConstantComposite %v2uint %uint_1431655764_0 %uint_1431655764_0",
"%spec_v2uint_udiv = OpConstantComposite %v2uint %uint_1431655764 %uint_1431655764",
"%uint_2863311528 = OpConstant %uint 2863311528",
"%uint_2863311528_0 = OpConstant %uint 2863311528",
"%spec_v2uint_imul = OpConstantComposite %v2uint %uint_2863311528_0 %uint_2863311528_0",
"%spec_v2uint_imul = OpConstantComposite %v2uint %uint_2863311528 %uint_2863311528",
"%uint_2863311528_1 = OpConstant %uint 2863311528",
"%uint_2863311528_2 = OpConstant %uint 2863311528",
"%spec_v2uint_isub_null = OpConstantComposite %v2uint %uint_2863311528_2 %uint_2863311528_2",
"%spec_v2uint_isub_null = OpConstantComposite %v2uint %uint_2863311528 %uint_2863311528",
},
},
// vector integer rem, mod
@ -939,33 +939,33 @@ INSTANTIATE_TEST_CASE_P(
// srem
"%int_1 = OpConstant %int 1",
"%int_1_0 = OpConstant %int 1",
"%7_srem_3 = OpConstantComposite %v2int %int_1_0 %int_1_0",
"%7_srem_3 = OpConstantComposite %v2int %signed_one %signed_one",
"%int_n1 = OpConstant %int -1",
"%int_n1_0 = OpConstant %int -1",
"%minus_7_srem_3 = OpConstantComposite %v2int %int_n1_0 %int_n1_0",
"%minus_7_srem_3 = OpConstantComposite %v2int %int_n1 %int_n1",
"%int_1_1 = OpConstant %int 1",
"%int_1_2 = OpConstant %int 1",
"%7_srem_minus_3 = OpConstantComposite %v2int %int_1_2 %int_1_2",
"%7_srem_minus_3 = OpConstantComposite %v2int %signed_one %signed_one",
"%int_n1_1 = OpConstant %int -1",
"%int_n1_2 = OpConstant %int -1",
"%minus_7_srem_minus_3 = OpConstantComposite %v2int %int_n1_2 %int_n1_2",
"%minus_7_srem_minus_3 = OpConstantComposite %v2int %int_n1 %int_n1",
// smod
"%int_1_3 = OpConstant %int 1",
"%int_1_4 = OpConstant %int 1",
"%7_smod_3 = OpConstantComposite %v2int %int_1_4 %int_1_4",
"%7_smod_3 = OpConstantComposite %v2int %signed_one %signed_one",
"%int_2 = OpConstant %int 2",
"%int_2_0 = OpConstant %int 2",
"%minus_7_smod_3 = OpConstantComposite %v2int %int_2_0 %int_2_0",
"%minus_7_smod_3 = OpConstantComposite %v2int %signed_two %signed_two",
"%int_n2 = OpConstant %int -2",
"%int_n2_0 = OpConstant %int -2",
"%7_smod_minus_3 = OpConstantComposite %v2int %int_n2_0 %int_n2_0",
"%7_smod_minus_3 = OpConstantComposite %v2int %int_n2 %int_n2",
"%int_n1_3 = OpConstant %int -1",
"%int_n1_4 = OpConstant %int -1",
"%minus_7_smod_minus_3 = OpConstantComposite %v2int %int_n1_4 %int_n1_4",
"%minus_7_smod_minus_3 = OpConstantComposite %v2int %int_n1 %int_n1",
// umod
"%uint_1 = OpConstant %uint 1",
"%uint_1_0 = OpConstant %uint 1",
"%7_umod_3 = OpConstantComposite %v2uint %uint_1_0 %uint_1_0",
"%7_umod_3 = OpConstantComposite %v2uint %unsigned_one %unsigned_one",
},
},
// vector integer bitwise, shift
@ -986,25 +986,25 @@ INSTANTIATE_TEST_CASE_P(
{
"%int_2 = OpConstant %int 2",
"%int_2_0 = OpConstant %int 2",
"%xor_1_3 = OpConstantComposite %v2int %int_2_0 %int_2_0",
"%xor_1_3 = OpConstantComposite %v2int %signed_two %signed_two",
"%int_0 = OpConstant %int 0",
"%int_0_0 = OpConstant %int 0",
"%and_1_2 = OpConstantComposite %v2int %int_0_0 %int_0_0",
"%and_1_2 = OpConstantComposite %v2int %signed_zero %signed_zero",
"%int_3 = OpConstant %int 3",
"%int_3_0 = OpConstant %int 3",
"%or_1_2 = OpConstantComposite %v2int %int_3_0 %int_3_0",
"%or_1_2 = OpConstantComposite %v2int %signed_three %signed_three",
"%unsigned_31 = OpConstant %uint 31",
"%v2unsigned_31 = OpConstantComposite %v2uint %unsigned_31 %unsigned_31",
"%uint_2147483648 = OpConstant %uint 2147483648",
"%uint_2147483648_0 = OpConstant %uint 2147483648",
"%unsigned_left_shift_max = OpConstantComposite %v2uint %uint_2147483648_0 %uint_2147483648_0",
"%unsigned_left_shift_max = OpConstantComposite %v2uint %uint_2147483648 %uint_2147483648",
"%uint_1 = OpConstant %uint 1",
"%uint_1_0 = OpConstant %uint 1",
"%unsigned_right_shift_logical = OpConstantComposite %v2uint %uint_1_0 %uint_1_0",
"%unsigned_right_shift_logical = OpConstantComposite %v2uint %unsigned_one %unsigned_one",
"%int_n1 = OpConstant %int -1",
"%int_n1_0 = OpConstant %int -1",
"%signed_right_shift_arithmetic = OpConstantComposite %v2int %int_n1_0 %int_n1_0",
"%signed_right_shift_arithmetic = OpConstantComposite %v2int %int_n1 %int_n1",
},
},
// Skip folding if any vector operands or components of the operands
@ -1256,13 +1256,13 @@ INSTANTIATE_TEST_CASE_P(
// expected
{
"%60 = OpConstantNull %int",
"%a = OpConstantComposite %v2int %60 %60",
"%a = OpConstantComposite %v2int %signed_null %signed_null",
"%62 = OpConstantNull %int",
"%b = OpConstantComposite %v2int %signed_zero %signed_one",
"%64 = OpConstantNull %int",
"%c = OpConstantComposite %v2int %signed_three %64",
"%c = OpConstantComposite %v2int %signed_three %signed_null",
"%66 = OpConstantNull %int",
"%d = OpConstantComposite %v2int %66 %66",
"%d = OpConstantComposite %v2int %signed_null %signed_null",
}
},
// skip if any of the components of the vector operands do not have
@ -1378,7 +1378,7 @@ INSTANTIATE_TEST_CASE_P(
"%used_vec_a = OpConstantComposite %v2int %spec_int_18 %spec_int_19",
"%int_10201 = OpConstant %int 10201",
"%int_1 = OpConstant %int 1",
"%used_vec_b = OpConstantComposite %v2int %int_10201 %int_1",
"%used_vec_b = OpConstantComposite %v2int %int_10201 %signed_one",
"%spec_int_21 = OpConstant %int 10201",
"%array = OpConstantComposite %type_arr_int_4 %spec_int_20 %spec_int_20 %spec_int_21 %spec_int_21",
"%spec_int_22 = OpSpecConstant %int 123",

View File

@ -198,7 +198,6 @@ OpName %main "main"
%104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps.
%float_null = OpConstantNull %float
%float_0 = OpConstant %float 0
%float_half = OpConstant %float 0.5
%float_1 = OpConstant %float 1
%float_2 = OpConstant %float 2
%float_3 = OpConstant %float 3