opt: Fix null deref in OpMatrixTimesVector and OpVectorTimesMatrix (#5199)

When some (not all) of the matrix columns are OpConstantNull
This commit is contained in:
Ben Clayton 2023-04-18 19:58:12 +01:00 committed by GitHub
parent d5f69dba55
commit bec566a32b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 80 additions and 34 deletions

View File

@ -398,6 +398,7 @@ ConstantFoldingRule FoldVectorTimesMatrix() {
if (float_type->width() == 32) {
for (uint32_t i = 0; i < resultVectorSize; ++i) {
float result_scalar = 0.0f;
if (!c2_components[i]->AsNullConstant()) {
const analysis::VectorConstant* c2_vec =
c2_components[i]->AsVectorConstant();
for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
@ -405,6 +406,7 @@ ConstantFoldingRule FoldVectorTimesMatrix() {
float c2_scalar = c2_vec->GetComponents()[j]->GetFloat();
result_scalar += c1_scalar * c2_scalar;
}
}
utils::FloatProxy<float> result(result_scalar);
std::vector<uint32_t> words = result.GetWords();
const analysis::Constant* new_elem =
@ -415,6 +417,7 @@ ConstantFoldingRule FoldVectorTimesMatrix() {
} else if (float_type->width() == 64) {
for (uint32_t i = 0; i < c2_components.size(); ++i) {
double result_scalar = 0.0;
if (!c2_components[i]->AsNullConstant()) {
const analysis::VectorConstant* c2_vec =
c2_components[i]->AsVectorConstant();
for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
@ -422,6 +425,7 @@ ConstantFoldingRule FoldVectorTimesMatrix() {
double c2_scalar = c2_vec->GetComponents()[j]->GetDouble();
result_scalar += c1_scalar * c2_scalar;
}
}
utils::FloatProxy<double> result(result_scalar);
std::vector<uint32_t> words = result.GetWords();
const analysis::Constant* new_elem =
@ -491,6 +495,7 @@ ConstantFoldingRule FoldMatrixTimesVector() {
for (uint32_t i = 0; i < resultVectorSize; ++i) {
float result_scalar = 0.0f;
for (uint32_t j = 0; j < c1_components.size(); ++j) {
if (!c1_components[j]->AsNullConstant()) {
float c1_scalar = c1_components[j]
->AsVectorConstant()
->GetComponents()[i]
@ -498,6 +503,7 @@ ConstantFoldingRule FoldMatrixTimesVector() {
float c2_scalar = c2_components[j]->GetFloat();
result_scalar += c1_scalar * c2_scalar;
}
}
utils::FloatProxy<float> result(result_scalar);
std::vector<uint32_t> words = result.GetWords();
const analysis::Constant* new_elem =
@ -509,6 +515,7 @@ ConstantFoldingRule FoldMatrixTimesVector() {
for (uint32_t i = 0; i < resultVectorSize; ++i) {
double result_scalar = 0.0;
for (uint32_t j = 0; j < c1_components.size(); ++j) {
if (!c1_components[j]->AsNullConstant()) {
double c1_scalar = c1_components[j]
->AsVectorConstant()
->GetComponents()[i]
@ -516,6 +523,7 @@ ConstantFoldingRule FoldMatrixTimesVector() {
double c2_scalar = c2_components[j]->GetDouble();
result_scalar += c1_scalar * c2_scalar;
}
}
utils::FloatProxy<double> result(result_scalar);
std::vector<uint32_t> words = result.GetWords();
const analysis::Constant* new_elem =

View File

@ -292,6 +292,7 @@ OpName %main "main"
%v4float_null = OpConstantNull %v4float
%mat4v4float_null = OpConstantComposite %mat4v4float %v4float_null %v4float_null %v4float_null %v4float_null
%mat4v4float_1_2_3_4 = OpConstantComposite %mat4v4float %v4float_1_2_3_4 %v4float_1_2_3_4 %v4float_1_2_3_4 %v4float_1_2_3_4
%mat4v4float_1_2_3_4_null = OpConstantComposite %mat4v4float %v4float_1_2_3_4 %v4float_null %v4float_1_2_3_4 %v4float_null
%107 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0
%v4double_0_0_0_0 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0
%v4double_0_0_0_1 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_1
@ -302,6 +303,7 @@ OpName %main "main"
%v4double_null = OpConstantNull %v4double
%mat4v4double_null = OpConstantComposite %mat4v4double %v4double_null %v4double_null %v4double_null %v4double_null
%mat4v4double_1_2_3_4 = OpConstantComposite %mat4v4double %v4double_1_2_3_4 %v4double_1_2_3_4 %v4double_1_2_3_4 %v4double_1_2_3_4
%mat4v4double_1_2_3_4_null = OpConstantComposite %mat4v4double %v4double_1_2_3_4 %v4double_null %v4double_1_2_3_4 %v4double_null
%v4float_n1_2_1_3 = OpConstantComposite %v4float %float_n1 %float_2 %float_1 %float_3
%uint_0x3f800000 = OpConstant %uint 0x3f800000
%uint_0xbf800000 = OpConstant %uint 0xbf800000
@ -1049,7 +1051,16 @@ INSTANTIATE_TEST_SUITE_P(TestCase, DoubleVectorInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, {30.0,30.0,30.0,30.0}),
// Test case 4: OpMatrixTimesVector Zero Non-Zero {1.0, 2.0, 3.0, 4.0} {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}} {0.0, 0.0, 0.0, 0.0}
// Test case 4: OpVectorTimesMatrix Non-Zero Non-Zero {{1.0, 2.0, 3.0, 4.0}, Null, {1.0, 2.0, 3.0, 4.0}, Null} {1.0, 2.0, 3.0, 4.0} {30.0, 0.0, 30.0, 0.0}
InstructionFoldingCase<std::vector<double>>(
Header() +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpVectorTimesMatrix %v4double %v4double_1_2_3_4 %mat4v4double_1_2_3_4_null\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {30.0,0.0,30.0,0.0}),
// Test case 5: OpMatrixTimesVector Zero Non-Zero {1.0, 2.0, 3.0, 4.0} {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}} {0.0, 0.0, 0.0, 0.0}
InstructionFoldingCase<std::vector<double>>(
Header() +
"%main = OpFunction %void None %void_func\n" +
@ -1058,7 +1069,7 @@ INSTANTIATE_TEST_SUITE_P(TestCase, DoubleVectorInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, {0.0,0.0,0.0,0.0}),
// Test case 5: OpMatrixTimesVector Non-Zero Zero {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}} {0.0, 0.0, 0.0, 0.0} {0.0, 0.0, 0.0, 0.0}
// Test case 6: OpMatrixTimesVector Non-Zero Zero {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}} {0.0, 0.0, 0.0, 0.0} {0.0, 0.0, 0.0, 0.0}
InstructionFoldingCase<std::vector<double>>(
Header() +
"%main = OpFunction %void None %void_func\n" +
@ -1067,7 +1078,7 @@ INSTANTIATE_TEST_SUITE_P(TestCase, DoubleVectorInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, {0.0,0.0,0.0,0.0}),
// Test case 6: OpMatrixTimesVector Non-Zero Non-Zero {1.0, 2.0, 3.0, 4.0} {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}} {10.0, 20.0, 30.0, 40.0}
// Test case 7: OpMatrixTimesVector Non-Zero Non-Zero {1.0, 2.0, 3.0, 4.0} {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}} {10.0, 20.0, 30.0, 40.0}
InstructionFoldingCase<std::vector<double>>(
Header() +
"%main = OpFunction %void None %void_func\n" +
@ -1075,7 +1086,16 @@ INSTANTIATE_TEST_SUITE_P(TestCase, DoubleVectorInstructionFoldingTest,
"%2 = OpMatrixTimesVector %v4double %mat4v4double_1_2_3_4 %v4double_1_2_3_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {10.0,20.0,30.0,40.0})
2, {10.0,20.0,30.0,40.0}),
// Test case 8: OpMatrixTimesVector Non-Zero Non-Zero {1.0, 2.0, 3.0, 4.0} {{1.0, 2.0, 3.0, 4.0}, Null, {1.0, 2.0, 3.0, 4.0}, Null} {10.0, 20.0, 30.0, 40.0}
InstructionFoldingCase<std::vector<double>>(
Header() +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpMatrixTimesVector %v4double %mat4v4double_1_2_3_4_null %v4double_1_2_3_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {4.0,8.0,12.0,16.0})
));
using FloatVectorInstructionFoldingTest =
@ -1154,7 +1174,16 @@ INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, {0.0f,0.0f,0.0f,0.0f}),
// Test case 4: OpVectorTimesMatrix Zero Non-Zero {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}} {0.0, 0.0, 0.0, 0.0} {0.0, 0.0, 0.0, 0.0}
// Test case 4: OpVectorTimesMatrix Non-Zero Non-Zero {{1.0, 2.0, 3.0, 4.0}, Null, {1.0, 2.0, 3.0, 4.0}, Null} {1.0, 2.0, 3.0, 4.0} {30.0, 0.0, 30.0, 0.0}
InstructionFoldingCase<std::vector<float>>(
Header() +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpVectorTimesMatrix %v4float %v4float_1_2_3_4 %mat4v4float_1_2_3_4_null\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {30.0,0.0,30.0,0.0}),
// Test case 5: OpVectorTimesMatrix Zero Non-Zero {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}} {0.0, 0.0, 0.0, 0.0} {0.0, 0.0, 0.0, 0.0}
InstructionFoldingCase<std::vector<float>>(
Header() +
"%main = OpFunction %void None %void_func\n" +
@ -1163,7 +1192,7 @@ INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, {0.0f,0.0f,0.0f,0.0f}),
// Test case 5: OpVectorTimesMatrix Non-Zero Non-Zero {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}} {1.0, 2.0, 3.0, 4.0} {30.0, 30.0, 30.0, 30.0}
// Test case 6: OpVectorTimesMatrix Non-Zero Non-Zero {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}} {1.0, 2.0, 3.0, 4.0} {30.0, 30.0, 30.0, 30.0}
InstructionFoldingCase<std::vector<float>>(
Header() +
"%main = OpFunction %void None %void_func\n" +
@ -1172,7 +1201,7 @@ INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, {30.0f,30.0f,30.0f,30.0f}),
// Test case 6: OpMatrixTimesVector Zero Non-Zero {1.0, 2.0, 3.0, 4.0} {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}} {0.0, 0.0, 0.0, 0.0}
// Test case 7: OpMatrixTimesVector Zero Non-Zero {1.0, 2.0, 3.0, 4.0} {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}} {0.0, 0.0, 0.0, 0.0}
InstructionFoldingCase<std::vector<float>>(
Header() +
"%main = OpFunction %void None %void_func\n" +
@ -1181,7 +1210,7 @@ INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, {0.0f,0.0f,0.0f,0.0f}),
// Test case 7: OpMatrixTimesVector Non-Zero Zero {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}} {0.0, 0.0, 0.0, 0.0} {0.0, 0.0, 0.0, 0.0}
// Test case 8: OpMatrixTimesVector Non-Zero Zero {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}} {0.0, 0.0, 0.0, 0.0} {0.0, 0.0, 0.0, 0.0}
InstructionFoldingCase<std::vector<float>>(
Header() +
"%main = OpFunction %void None %void_func\n" +
@ -1190,7 +1219,7 @@ INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest,
"OpReturn\n" +
"OpFunctionEnd",
2, {0.0f,0.0f,0.0f,0.0f}),
// Test case 8: OpMatrixTimesVector Non-Zero Non-Zero {1.0, 2.0, 3.0, 4.0} {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}} {10.0, 20.0, 30.0, 40.0}
// Test case 9: OpMatrixTimesVector Non-Zero Non-Zero {1.0, 2.0, 3.0, 4.0} {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}} {10.0, 20.0, 30.0, 40.0}
InstructionFoldingCase<std::vector<float>>(
Header() +
"%main = OpFunction %void None %void_func\n" +
@ -1198,7 +1227,16 @@ INSTANTIATE_TEST_SUITE_P(TestCase, FloatVectorInstructionFoldingTest,
"%2 = OpMatrixTimesVector %v4float %mat4v4float_1_2_3_4 %v4float_1_2_3_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {10.0f,20.0f,30.0f,40.0f})
2, {10.0f,20.0f,30.0f,40.0f}),
// Test case 10: OpMatrixTimesVector Non-Zero Non-Zero {1.0, 2.0, 3.0, 4.0} {{1.0, 2.0, 3.0, 4.0}, Null, {1.0, 2.0, 3.0, 4.0}, Null} {10.0, 20.0, 30.0, 40.0}
InstructionFoldingCase<std::vector<float>>(
Header() +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpMatrixTimesVector %v4float %mat4v4float_1_2_3_4_null %v4float_1_2_3_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, {4.0,8.0,12.0,16.0})
));
// clang-format on
using BooleanInstructionFoldingTest =