Check if const is zero before getting components. (#5217)

* Check if const is zero before getting components.

Two folding rules try to cast a constant to a MatrixConstant before
checking if it is a Null constant. This leads to the null pointer being
dereferneced. The solution is to move the check for zero earlier.

Fixes https://github.com/microsoft/DirectXShaderCompiler/issues/5063
This commit is contained in:
Steven Perron 2023-05-25 09:07:22 -04:00 committed by GitHub
parent 2358001827
commit af27ece750
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 14 deletions

View File

@ -376,13 +376,7 @@ ConstantFoldingRule FoldVectorTimesMatrix() {
assert(c1->type()->AsVector()->element_type() == element_type &&
c2->type()->AsMatrix()->element_type() == vector_type);
// Get a float vector that is the result of vector-times-matrix.
std::vector<const analysis::Constant*> c1_components =
c1->GetVectorComponents(const_mgr);
std::vector<const analysis::Constant*> c2_components =
c2->AsMatrixConstant()->GetComponents();
uint32_t resultVectorSize = result_type->AsVector()->element_count();
std::vector<uint32_t> ids;
if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
@ -395,6 +389,12 @@ ConstantFoldingRule FoldVectorTimesMatrix() {
return const_mgr->GetConstant(vector_type, ids);
}
// Get a float vector that is the result of vector-times-matrix.
std::vector<const analysis::Constant*> c1_components =
c1->GetVectorComponents(const_mgr);
std::vector<const analysis::Constant*> c2_components =
c2->AsMatrixConstant()->GetComponents();
if (float_type->width() == 32) {
for (uint32_t i = 0; i < resultVectorSize; ++i) {
float result_scalar = 0.0f;
@ -472,13 +472,7 @@ ConstantFoldingRule FoldMatrixTimesVector() {
assert(c1->type()->AsMatrix()->element_type() == vector_type);
assert(c2->type()->AsVector()->element_type() == element_type);
// Get a float vector that is the result of matrix-times-vector.
std::vector<const analysis::Constant*> c1_components =
c1->AsMatrixConstant()->GetComponents();
std::vector<const analysis::Constant*> c2_components =
c2->GetVectorComponents(const_mgr);
uint32_t resultVectorSize = result_type->AsVector()->element_count();
std::vector<uint32_t> ids;
if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
@ -491,6 +485,12 @@ ConstantFoldingRule FoldMatrixTimesVector() {
return const_mgr->GetConstant(vector_type, ids);
}
// Get a float vector that is the result of matrix-times-vector.
std::vector<const analysis::Constant*> c1_components =
c1->AsMatrixConstant()->GetComponents();
std::vector<const analysis::Constant*> c2_components =
c2->GetVectorComponents(const_mgr);
if (float_type->width() == 32) {
for (uint32_t i = 0; i < resultVectorSize; ++i) {
float result_scalar = 0.0f;

View File

@ -290,7 +290,7 @@ OpName %main "main"
%v4float_1_1_1_1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
%v4float_1_2_3_4 = OpConstantComposite %v4float %float_1 %float_2 %float_3 %float_4
%v4float_null = OpConstantNull %v4float
%mat4v4float_null = OpConstantComposite %mat4v4float %v4float_null %v4float_null %v4float_null %v4float_null
%mat4v4float_null = OpConstantNull %mat4v4float
%mat4v4float_1_2_3_4 = OpConstantComposite %mat4v4float %v4float_1_2_3_4 %v4float_1_2_3_4 %v4float_1_2_3_4 %v4float_1_2_3_4
%mat4v4float_1_2_3_4_null = OpConstantComposite %mat4v4float %v4float_1_2_3_4 %v4float_null %v4float_1_2_3_4 %v4float_null
%107 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0
@ -301,7 +301,7 @@ OpName %main "main"
%v4double_1_2_3_4 = OpConstantComposite %v4double %double_1 %double_2 %double_3 %double_4
%v4double_1_1_1_0p5 = OpConstantComposite %v4double %double_1 %double_1 %double_1 %double_0p5
%v4double_null = OpConstantNull %v4double
%mat4v4double_null = OpConstantComposite %mat4v4double %v4double_null %v4double_null %v4double_null %v4double_null
%mat4v4double_null = OpConstantNull %mat4v4double
%mat4v4double_1_2_3_4 = OpConstantComposite %mat4v4double %v4double_1_2_3_4 %v4double_1_2_3_4 %v4double_1_2_3_4 %v4double_1_2_3_4
%mat4v4double_1_2_3_4_null = OpConstantComposite %mat4v4double %v4double_1_2_3_4 %v4double_null %v4double_1_2_3_4 %v4double_null
%v4float_n1_2_1_3 = OpConstantComposite %v4float %float_n1 %float_2 %float_1 %float_3