SPIRV-Tools/source/opt/const_folding_rules.cpp
Steven Perron 208921efe8 Fix finding constant with particular type. (#1724)
With the current implementation, the constant manager does not keep around
two constants with the same value but different types when the types
hash to the same value. So when you look up that constant, you may
get back a constant with the wrong type.

I've made a few changes to the constant manager to fix this. First,
I have changed the map from constants to ids to be a std::multimap.
This way a single constant can be mapped to multiple ids, each
representing a different type.

Then, when asking for the id of a constant, we can search all of the ids
associated with that constant to find the one with the correct
type.
2018-07-16 12:36:53 -04:00

846 lines
33 KiB
C++

// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "const_folding_rules.h"

#include <string>

#include "ir_context.h"
namespace spvtools {
namespace opt {
namespace {
// Position of the composite operand among the in-operands of
// OpCompositeExtract; the literal element indices follow it.
const uint32_t kExtractCompositeIdInIdx = 0;
// Reports whether |type| is a floating-point scalar, or a vector whose
// element type is a floating-point scalar.
bool HasFloatingPoint(const analysis::Type* type) {
  const analysis::Type* scalar = type;
  if (const analysis::Vector* vec_type = type->AsVector()) {
    scalar = vec_type->element_type();
  }
  return scalar->AsFloat() != nullptr;
}
// Folds an OpCompositeExtract whose composite operand is a constant.
// The literal indices are followed one level at a time through the
// constant's components.  If a null constant is reached before the indices
// are exhausted, the result is the null constant of the instruction's result
// type.  Returns nullptr when the composite operand is not a constant.
ConstantFoldingRule FoldExtractWithConstants() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants)
             -> const analysis::Constant* {
    const analysis::Constant* current = constants[kExtractCompositeIdInIdx];
    if (!current) return nullptr;

    const uint32_t num_in_operands = inst->NumInOperands();
    for (uint32_t operand = 1; operand < num_in_operands; ++operand) {
      const uint32_t index = inst->GetSingleWordInOperand(operand);
      if (current->AsNullConstant() != nullptr) {
        // Any element of a null composite is itself null.
        analysis::ConstantManager* const_mgr = context->get_constant_mgr();
        analysis::TypeManager* type_mgr = context->get_type_mgr();
        return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
      }
      const analysis::CompositeConstant* composite =
          current->AsCompositeConstant();
      assert(composite != nullptr);
      current = composite->GetComponents()[index];
    }
    return current;
  };
}
// Folds an OpVectorShuffle whose two vector operands are both constants (or
// null constants) into a new vector constant.  Returns nullptr, leaving the
// instruction unfolded, in these cases:
//   - either operand is not a constant;
//   - a component selector is the undefined literal 0xFFFFFFFF;
//   - a component selector is otherwise out of range.
ConstantFoldingRule FoldVectorShuffleWithConstants() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants)
             -> const analysis::Constant* {
    assert(inst->opcode() == SpvOpVectorShuffle);
    // Selector value SPIR-V uses to mark a result component as undefined.  It
    // selects from neither operand, so using it as an index would read past
    // the end of the component lists below.
    const uint32_t kUndefinedComponent = 0xFFFFFFFF;

    const analysis::Constant* c1 = constants[0];
    const analysis::Constant* c2 = constants[1];
    if (c1 == nullptr || c2 == nullptr) {
      return nullptr;
    }
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    const analysis::Type* element_type = c1->type()->AsVector()->element_type();

    // Expand each operand into its scalar components.  A null vector
    // contributes null scalars of the element type.
    std::vector<const analysis::Constant*> c1_components;
    if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
      c1_components = vec_const->GetComponents();
    } else {
      assert(c1->AsNullConstant());
      const analysis::Constant* element =
          const_mgr->GetConstant(element_type, {});
      c1_components.resize(c1->type()->AsVector()->element_count(), element);
    }
    std::vector<const analysis::Constant*> c2_components;
    if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
      c2_components = vec_const->GetComponents();
    } else {
      assert(c2->AsNullConstant());
      const analysis::Constant* element =
          const_mgr->GetConstant(element_type, {});
      c2_components.resize(c2->type()->AsVector()->element_count(), element);
    }

    // Pick every result component first.  That way, an invalid selector
    // bails out before any GetDefiningInstruction call, which may add
    // declarations.
    const size_t total_components = c1_components.size() + c2_components.size();
    std::vector<const analysis::Constant*> selected;
    for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
      uint32_t index = inst->GetSingleWordInOperand(i);
      if (index == kUndefinedComponent || index >= total_components) {
        return nullptr;
      }
      if (index < c1_components.size()) {
        selected.push_back(c1_components[index]);
      } else {
        selected.push_back(c2_components[index - c1_components.size()]);
      }
    }

    std::vector<uint32_t> ids;
    for (const analysis::Constant* component : selected) {
      Instruction* member_inst = const_mgr->GetDefiningInstruction(component);
      ids.push_back(member_inst->result_id());
    }
    analysis::TypeManager* type_mgr = context->get_type_mgr();
    return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
  };
}
// Folds OpVectorTimesScalar.
//   - If the vector operand is a constant zero, that operand is returned.
//   - If the scalar operand is a constant zero, the null constant of the result
//     type is returned.
//   - If both operands are constant, each element is multiplied by the scalar
//     (32- and 64-bit floats only).
// Floating-point results are not folded when the instruction forbids
// floating-point folding.
ConstantFoldingRule FoldVectorTimesScalar() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants)
             -> const analysis::Constant* {
    assert(inst->opcode() == SpvOpVectorTimesScalar);
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    analysis::TypeManager* type_mgr = context->get_type_mgr();
    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());

    if (!inst->IsFloatingPointFoldingAllowed() &&
        HasFloatingPoint(result_type)) {
      return nullptr;
    }

    const analysis::Constant* vec_const = constants[0];
    const analysis::Constant* scalar_const = constants[1];
    if (vec_const != nullptr && vec_const->IsZero()) {
      return vec_const;
    }
    if (scalar_const != nullptr && scalar_const->IsZero()) {
      // Get or create the NullConstant for the result type.
      return const_mgr->GetConstant(result_type, {});
    }
    if (vec_const == nullptr || scalar_const == nullptr) {
      return nullptr;
    }

    // Sanity-check the result type and the operand types.
    const analysis::Vector* vector_type = result_type->AsVector();
    assert(vector_type != nullptr);
    const analysis::Type* element_type = vector_type->element_type();
    assert(element_type != nullptr);
    const analysis::Float* float_type = element_type->AsFloat();
    assert(float_type != nullptr);
    assert(vec_const->type()->AsVector() == vector_type);
    assert(vec_const->type()->AsVector()->element_type() == element_type &&
           scalar_const->type() == element_type);

    const std::vector<const analysis::Constant*> elements =
        vec_const->GetVectorComponents(const_mgr);
    std::vector<uint32_t> ids;
    if (float_type->width() == 32) {
      const float factor = scalar_const->GetFloat();
      for (const analysis::Constant* element : elements) {
        utils::FloatProxy<float> product(element->GetFloat() * factor);
        std::vector<uint32_t> words = product.GetWords();
        const analysis::Constant* scaled =
            const_mgr->GetConstant(float_type, words);
        ids.push_back(const_mgr->GetDefiningInstruction(scaled)->result_id());
      }
    } else if (float_type->width() == 64) {
      const double factor = scalar_const->GetDouble();
      for (const analysis::Constant* element : elements) {
        utils::FloatProxy<double> product(element->GetDouble() * factor);
        std::vector<uint32_t> words = product.GetWords();
        const analysis::Constant* scaled =
            const_mgr->GetConstant(float_type, words);
        ids.push_back(const_mgr->GetDefiningInstruction(scaled)->result_id());
      }
    } else {
      return nullptr;
    }
    return const_mgr->GetConstant(vector_type, ids);
  };
}
// Folds an OpCompositeConstruct where all of the inputs are constants to a
// constant. A new constant is created if necessary.
//
// Each scalar operand is looked up as a declared constant of the component
// type the result expects: the struct member type, the array element type, or
// the vector element type.  That way, a value shared by constants of several
// types maps to the id with the right type.
//
// When the result is a vector, SPIR-V also allows operands that are themselves
// vectors.  Their components are spliced in order, so the folded vector
// constant is built only from scalars.
//
// Returns nullptr if any operand is not a constant or no declared constant
// with the expected type can be found.
ConstantFoldingRule FoldCompositeWithConstants() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants)
             -> const analysis::Constant* {
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    analysis::TypeManager* type_mgr = context->get_type_mgr();
    const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
    Instruction* type_inst =
        context->get_def_use_mgr()->GetDef(inst->type_id());
    const bool is_vector_result = type_inst->opcode() == SpvOpTypeVector;

    std::vector<uint32_t> ids;
    for (uint32_t i = 0; i < constants.size(); ++i) {
      const analysis::Constant* element_const = constants[i];
      if (element_const == nullptr) {
        return nullptr;
      }

      if (is_vector_result && element_const->type()->AsVector() != nullptr) {
        // Passing the operand vector's own id would build a vector constant
        // whose components are vectors, which is not a valid constant.
        // Splice in its scalar components instead; in-operand 0 of
        // OpTypeVector is the scalar element type they must have.
        const uint32_t scalar_type_id = type_inst->GetSingleWordInOperand(0);
        for (const analysis::Constant* component :
             element_const->GetVectorComponents(const_mgr)) {
          uint32_t component_id =
              const_mgr->FindDeclaredConstant(component, scalar_type_id);
          if (component_id == 0) {
            return nullptr;
          }
          ids.push_back(component_id);
        }
        continue;
      }

      uint32_t component_type_id = 0;
      if (type_inst->opcode() == SpvOpTypeStruct) {
        component_type_id = type_inst->GetSingleWordInOperand(i);
      } else if (type_inst->opcode() == SpvOpTypeArray || is_vector_result) {
        // In-operand 0 of OpTypeArray and OpTypeVector is the element type.
        component_type_id = type_inst->GetSingleWordInOperand(0);
      }

      uint32_t element_id =
          const_mgr->FindDeclaredConstant(element_const, component_type_id);
      if (element_id == 0) {
        return nullptr;
      }
      ids.push_back(element_id);
    }
    return const_mgr->GetConstant(new_type, ids);
  };
}
// The interface for a function that returns the result of applying a scalar
// unary operation to |a|. The returned constant has type |result_type|, which
// may differ from the type of |a| (for example, when converting between
// floating point and integer). A nullptr return means the operation could not
// be folded.
using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
const analysis::Type* result_type, const analysis::Constant* a,
analysis::ConstantManager*)>;
// The interface for a function that returns the result of applying a scalar
// binary operation to |a| and |b|. The returned constant has type
// |result_type|.
//   - Arithmetic rules: |a| and |b| also have type |result_type|.
//   - Comparison rules: |result_type| is a boolean type, and |a| and |b| share
//     a floating-point type.
// A nullptr return means the operation could not be folded.
using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
const analysis::Type* result_type, const analysis::Constant* a,
const analysis::Constant* b, analysis::ConstantManager*)>;
// Builds a |ConstantFoldingRule| around the scalar rule |scalar_rule|.
//   - A scalar result is produced by calling |scalar_rule| directly.
//   - A vector result is produced by applying |scalar_rule| to each element of
//     the operand and reassembling the results into a vector constant.
// Only |constants[0]| is read.  When it is not nullptr, its type is assumed to
// be |Float| or |Integer|, or a |Vector| of either.  Nothing is folded if the
// instruction disallows floating-point folding, the operand is not constant,
// or any element fails to fold.
ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
  return [scalar_rule](IRContext* context, Instruction* inst,
                       const std::vector<const analysis::Constant*>& constants)
             -> const analysis::Constant* {
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    analysis::TypeManager* type_mgr = context->get_type_mgr();
    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
    const analysis::Vector* vector_type = result_type->AsVector();

    const analysis::Constant* operand = constants[0];
    if (!inst->IsFloatingPointFoldingAllowed()) return nullptr;
    if (operand == nullptr) return nullptr;

    if (vector_type == nullptr) {
      return scalar_rule(result_type, operand, const_mgr);
    }

    // First fold every element, giving up as soon as one cannot be folded...
    const analysis::Type* element_type = vector_type->element_type();
    std::vector<const analysis::Constant*> folded;
    for (const analysis::Constant* element :
         operand->GetVectorComponents(const_mgr)) {
      const analysis::Constant* value =
          scalar_rule(element_type, element, const_mgr);
      if (value == nullptr) return nullptr;
      folded.push_back(value);
    }

    // ...then gather the ids of the folded elements and build the vector.
    std::vector<uint32_t> ids;
    for (const analysis::Constant* value : folded) {
      ids.push_back(const_mgr->GetDefiningInstruction(value)->result_id());
    }
    return const_mgr->GetConstant(vector_type, ids);
  };
}
// Builds a |ConstantFoldingRule| around the scalar rule |scalar_rule|.
//   - Scalar operands are handed to |scalar_rule| directly.
//   - Vector operands are combined element by element, and the per-element
//     results are reassembled into a vector constant of the result type.
// The rule reads |constants[0]| and |constants[1]|.  When they are not
// nullptr, their type is |Float| or a |Vector| of |Float|.  Nothing is folded
// if the instruction disallows floating-point folding, either operand is not
// constant, or any element fails to fold.
ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
  return [scalar_rule](IRContext* context, Instruction* inst,
                       const std::vector<const analysis::Constant*>& constants)
             -> const analysis::Constant* {
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    analysis::TypeManager* type_mgr = context->get_type_mgr();
    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
    const analysis::Vector* vector_type = result_type->AsVector();

    const analysis::Constant* lhs = constants[0];
    const analysis::Constant* rhs = constants[1];
    if (!inst->IsFloatingPointFoldingAllowed()) return nullptr;
    if (lhs == nullptr || rhs == nullptr) return nullptr;

    if (vector_type == nullptr) {
      return scalar_rule(result_type, lhs, rhs, const_mgr);
    }

    const std::vector<const analysis::Constant*> lhs_elements =
        lhs->GetVectorComponents(const_mgr);
    const std::vector<const analysis::Constant*> rhs_elements =
        rhs->GetVectorComponents(const_mgr);
    const analysis::Type* element_type = vector_type->element_type();

    // Fold all elements before materializing any of them.
    std::vector<const analysis::Constant*> folded;
    for (uint32_t idx = 0; idx < lhs_elements.size(); ++idx) {
      const analysis::Constant* value = scalar_rule(
          element_type, lhs_elements[idx], rhs_elements[idx], const_mgr);
      if (value == nullptr) return nullptr;
      folded.push_back(value);
    }

    std::vector<uint32_t> ids;
    for (const analysis::Constant* value : folded) {
      ids.push_back(const_mgr->GetDefiningInstruction(value)->result_id());
    }
    return const_mgr->GetConstant(vector_type, ids);
  };
}
// Truncates |value| toward zero to a signed (|is_signed|) or unsigned 32-bit
// integer and stores the resulting bit pattern in |*bits|.
// Returns false without writing |*bits| when |value| is NaN, or when the
// truncated value is not representable in the destination type.  C++ leaves
// such floating-to-integral conversions undefined, so they must not be folded.
bool FloatToInt32Bits(double value, bool is_signed, uint32_t* bits) {
  if (is_signed) {
    // Values in (-2^31 - 1, 2^31) truncate into [-2^31, 2^31 - 1].  The
    // negated comparison also rejects NaN.
    if (!(value > -2147483649.0 && value < 2147483648.0)) {
      return false;
    }
    *bits = static_cast<uint32_t>(static_cast<int32_t>(value));
  } else {
    // Values in (-1, 2^32) truncate into [0, 2^32 - 1].
    if (!(value > -1.0 && value < 4294967296.0)) {
      return false;
    }
    *bits = static_cast<uint32_t>(value);
  }
  return true;
}

// This function defines a |UnaryScalarFoldingRule| that performs float to
// integer conversion, truncating toward zero.  Only 32-bit integer results
// from 32- or 64-bit floats are folded.  Inputs whose converted value is not
// representable (NaN, infinities, out-of-range values) are left unfolded.
// TODO(greg-lunarg): Support for 64-bit integer types.
UnaryScalarFoldingRule FoldFToIOp() {
  return [](const analysis::Type* result_type, const analysis::Constant* a,
            analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
    assert(result_type != nullptr && a != nullptr);
    const analysis::Integer* integer_type = result_type->AsInteger();
    const analysis::Float* float_type = a->type()->AsFloat();
    assert(float_type != nullptr);
    assert(integer_type != nullptr);
    if (integer_type->width() != 32) return nullptr;

    // Widening a float to double is exact, so truncating the double yields
    // the same integer as truncating the original float.
    double value = 0.0;
    if (float_type->width() == 32) {
      value = a->GetFloat();
    } else if (float_type->width() == 64) {
      value = a->GetDouble();
    } else {
      return nullptr;
    }

    uint32_t result = 0;
    if (!FloatToInt32Bits(value, integer_type->IsSigned(), &result)) {
      return nullptr;
    }
    std::vector<uint32_t> words = {result};
    return const_mgr->GetConstant(result_type, words);
  };
}
// This function defines a |UnaryScalarFoldingRule| that converts a 32-bit
// integer constant (signed or unsigned, per its type) to a 32- or 64-bit
// floating-point constant.  Other widths are not folded.
// TODO(greg-lunarg): Support for 64-bit integer types.
UnaryScalarFoldingRule FoldIToFOp() {
  return [](const analysis::Type* result_type, const analysis::Constant* a,
            analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
    assert(result_type != nullptr && a != nullptr);
    const analysis::Integer* integer_type = a->type()->AsInteger();
    const analysis::Float* float_type = result_type->AsFloat();
    assert(float_type != nullptr);
    assert(integer_type != nullptr);
    if (integer_type->width() != 32) return nullptr;

    const uint32_t bits = a->GetU32();
    const bool is_signed = integer_type->IsSigned();
    switch (float_type->width()) {
      case 32: {
        float converted;
        if (is_signed) {
          converted = static_cast<float>(static_cast<int32_t>(bits));
        } else {
          converted = static_cast<float>(bits);
        }
        utils::FloatProxy<float> proxy(converted);
        std::vector<uint32_t> words = {proxy.data()};
        return const_mgr->GetConstant(result_type, words);
      }
      case 64: {
        double converted;
        if (is_signed) {
          converted = static_cast<double>(static_cast<int32_t>(bits));
        } else {
          converted = static_cast<double>(bits);
        }
        utils::FloatProxy<double> proxy(converted);
        std::vector<uint32_t> words = proxy.GetWords();
        return const_mgr->GetConstant(result_type, words);
      }
      default:
        return nullptr;
    }
  };
}
// This macro defines a |BinaryScalarFoldingRule| that applies |op|. The
// operator |op| must work for both float and double, and use syntax "f1 op f2".
// The generated lambda asserts that |a|, |b| and |result_type| all share one
// floating-point type.  32-bit values are computed in float and 64-bit values
// in double; any other width yields nullptr.  Locals carry an "_in_macro"
// suffix, presumably to avoid colliding with names at the expansion site.
#define FOLD_FPARITH_OP(op) \
[](const analysis::Type* result_type, const analysis::Constant* a, \
const analysis::Constant* b, \
analysis::ConstantManager* const_mgr_in_macro) \
-> const analysis::Constant* { \
assert(result_type != nullptr && a != nullptr && b != nullptr); \
assert(result_type == a->type() && result_type == b->type()); \
const analysis::Float* float_type_in_macro = result_type->AsFloat(); \
assert(float_type_in_macro != nullptr); \
if (float_type_in_macro->width() == 32) { \
float fa = a->GetFloat(); \
float fb = b->GetFloat(); \
utils::FloatProxy<float> result_in_macro(fa op fb); \
std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
} else if (float_type_in_macro->width() == 64) { \
double fa = a->GetDouble(); \
double fb = b->GetDouble(); \
utils::FloatProxy<double> result_in_macro(fa op fb); \
std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
} \
return nullptr; \
}
// Folding rules for conversions between floating point and integer.  Both go
// through the generic unary driver, which folds scalars directly and vectors
// element by element.
ConstantFoldingRule FoldFToI() {
  UnaryScalarFoldingRule float_to_int = FoldFToIOp();
  return FoldFPUnaryOp(float_to_int);
}
ConstantFoldingRule FoldIToF() {
  UnaryScalarFoldingRule int_to_float = FoldIToFOp();
  return FoldFPUnaryOp(int_to_float);
}
// Folding rules for floating-point subtraction, addition, multiplication and
// division.  Each wraps the corresponding scalar arithmetic rule in the
// generic binary driver, which also handles vectors element-wise.
ConstantFoldingRule FoldFSub() {
  BinaryScalarFoldingRule subtract = FOLD_FPARITH_OP(-);
  return FoldFPBinaryOp(subtract);
}
ConstantFoldingRule FoldFAdd() {
  BinaryScalarFoldingRule add = FOLD_FPARITH_OP(+);
  return FoldFPBinaryOp(add);
}
ConstantFoldingRule FoldFMul() {
  BinaryScalarFoldingRule multiply = FOLD_FPARITH_OP(*);
  return FoldFPBinaryOp(multiply);
}
ConstantFoldingRule FoldFDiv() {
  BinaryScalarFoldingRule divide = FOLD_FPARITH_OP(/);
  return FoldFPBinaryOp(divide);
}
// Combines the raw result of a C++ floating-point comparison with the NaN
// semantics of the SPIR-V comparison being folded.
//   |op_result|    - value of "operand1 op operand2" as evaluated in C++.
//   |op_unordered| - true when either operand is NaN.
//   |need_ordered| - true for the FOrd* opcodes, false for the FUnord* ones.
// An ordered comparison holds only for ordered operands satisfying |op|.  An
// unordered comparison holds if the operands are unordered or satisfy |op|.
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  return need_ordered ? (op_result && !op_unordered)
                      : (op_result || op_unordered);
}
// This macro defines a |BinaryScalarFoldingRule| that compares |a| and |b|
// with |op|. The operator |op| must work for both float and double, and use
// syntax "f1 op f2".
//
// |ord| selects the SPIR-V NaN semantics, applied through
// CompareFloatingPoint:
//   - true: an ordered comparison, which is false whenever either operand is
//     NaN;
//   - false: an unordered comparison, which is true whenever either operand is
//     NaN.
//
// |result_type| must be a boolean type, and |a| and |b| must share a
// floating-point type.  32- and 64-bit floats are supported; other widths
// yield nullptr.  The result constant is built from a single word holding
// 0 or 1.
#define FOLD_FPCMP_OP(op, ord) \
[](const analysis::Type* result_type, const analysis::Constant* a, \
const analysis::Constant* b, \
analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
assert(result_type != nullptr && a != nullptr && b != nullptr); \
assert(result_type->AsBool()); \
assert(a->type() == b->type()); \
const analysis::Float* float_type = a->type()->AsFloat(); \
assert(float_type != nullptr); \
if (float_type->width() == 32) { \
float fa = a->GetFloat(); \
float fb = b->GetFloat(); \
bool result = CompareFloatingPoint( \
fa op fb, std::isnan(fa) || std::isnan(fb), ord); \
std::vector<uint32_t> words = {uint32_t(result)}; \
return const_mgr->GetConstant(result_type, words); \
} else if (float_type->width() == 64) { \
double fa = a->GetDouble(); \
double fb = b->GetDouble(); \
bool result = CompareFloatingPoint( \
fa op fb, std::isnan(fa) || std::isnan(fb), ord); \
std::vector<uint32_t> words = {uint32_t(result)}; \
return const_mgr->GetConstant(result_type, words); \
} \
return nullptr; \
}
// Folding rules for ordered and unordered floating-point comparisons.  The
// second argument of FOLD_FPCMP_OP chooses ordered (true) or unordered (false)
// NaN semantics.  Vector operands are compared element-wise by the generic
// binary driver.
ConstantFoldingRule FoldFOrdEqual() {
  BinaryScalarFoldingRule compare = FOLD_FPCMP_OP(==, true);
  return FoldFPBinaryOp(compare);
}
ConstantFoldingRule FoldFUnordEqual() {
  BinaryScalarFoldingRule compare = FOLD_FPCMP_OP(==, false);
  return FoldFPBinaryOp(compare);
}
ConstantFoldingRule FoldFOrdNotEqual() {
  BinaryScalarFoldingRule compare = FOLD_FPCMP_OP(!=, true);
  return FoldFPBinaryOp(compare);
}
ConstantFoldingRule FoldFUnordNotEqual() {
  BinaryScalarFoldingRule compare = FOLD_FPCMP_OP(!=, false);
  return FoldFPBinaryOp(compare);
}
ConstantFoldingRule FoldFOrdLessThan() {
  BinaryScalarFoldingRule compare = FOLD_FPCMP_OP(<, true);
  return FoldFPBinaryOp(compare);
}
ConstantFoldingRule FoldFUnordLessThan() {
  BinaryScalarFoldingRule compare = FOLD_FPCMP_OP(<, false);
  return FoldFPBinaryOp(compare);
}
ConstantFoldingRule FoldFOrdGreaterThan() {
  BinaryScalarFoldingRule compare = FOLD_FPCMP_OP(>, true);
  return FoldFPBinaryOp(compare);
}
ConstantFoldingRule FoldFUnordGreaterThan() {
  BinaryScalarFoldingRule compare = FOLD_FPCMP_OP(>, false);
  return FoldFPBinaryOp(compare);
}
ConstantFoldingRule FoldFOrdLessThanEqual() {
  BinaryScalarFoldingRule compare = FOLD_FPCMP_OP(<=, true);
  return FoldFPBinaryOp(compare);
}
ConstantFoldingRule FoldFUnordLessThanEqual() {
  BinaryScalarFoldingRule compare = FOLD_FPCMP_OP(<=, false);
  return FoldFPBinaryOp(compare);
}
ConstantFoldingRule FoldFOrdGreaterThanEqual() {
  BinaryScalarFoldingRule compare = FOLD_FPCMP_OP(>=, true);
  return FoldFPBinaryOp(compare);
}
ConstantFoldingRule FoldFUnordGreaterThanEqual() {
  BinaryScalarFoldingRule compare = FOLD_FPCMP_OP(>=, false);
  return FoldFPBinaryOp(compare);
}
// Folds an OpDot where all of the inputs are constants to a
// constant. A new constant is created if necessary.
//
// If either operand is a constant zero vector (or a null constant), the result
// folds to 0.0 even when the other operand is not constant.  Only 32- and
// 64-bit float results are folded.  Nothing is folded when the instruction
// disallows floating-point folding.
ConstantFoldingRule FoldOpDotWithConstants() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants)
             -> const analysis::Constant* {
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    analysis::TypeManager* type_mgr = context->get_type_mgr();
    const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
    assert(new_type->AsFloat() && "OpDot should have a float return type.");
    const analysis::Float* float_type = new_type->AsFloat();

    if (!inst->IsFloatingPointFoldingAllowed()) {
      return nullptr;
    }

    // Encoding of 0.0 for the result width: one word for 32-bit floats and two
    // for 64-bit floats.  The word count must match the type, or the constant
    // manager would intern a malformed constant.
    std::vector<uint32_t> zero_words;
    if (float_type->width() == 32) {
      utils::FloatProxy<float> zero_value(0.0f);
      zero_words = zero_value.GetWords();
    } else if (float_type->width() == 64) {
      utils::FloatProxy<double> zero_value(0.0);
      zero_words = zero_value.GetWords();
    } else {
      return nullptr;
    }

    // If one of the operands is 0, then the result is 0.
    bool has_zero_operand = false;
    for (int i = 0; i < 2; ++i) {
      const analysis::Constant* operand = constants[i];
      if (operand == nullptr) {
        continue;
      }
      const analysis::VectorConstant* vec_const = operand->AsVectorConstant();
      if (operand->AsNullConstant() ||
          (vec_const != nullptr && vec_const->IsZero())) {
        has_zero_operand = true;
        break;
      }
    }

    if (!has_zero_operand &&
        (constants[0] == nullptr || constants[1] == nullptr)) {
      return nullptr;
    }

    const analysis::Constant* zero =
        const_mgr->GetConstant(float_type, zero_words);
    if (has_zero_operand) {
      return zero;
    }

    // Accumulate sum(a[i] * b[i]) starting from a correctly encoded zero.
    std::vector<const analysis::Constant*> a_components =
        constants[0]->GetVectorComponents(const_mgr);
    std::vector<const analysis::Constant*> b_components =
        constants[1]->GetVectorComponents(const_mgr);
    const analysis::Constant* result_const = zero;
    for (uint32_t i = 0; i < a_components.size(); ++i) {
      if (a_components[i] == nullptr || b_components[i] == nullptr) {
        return nullptr;
      }
      const analysis::Constant* component = FOLD_FPARITH_OP(*)(
          new_type, a_components[i], b_components[i], const_mgr);
      if (component == nullptr) {
        return nullptr;
      }
      result_const =
          FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
      if (result_const == nullptr) {
        return nullptr;
      }
    }
    return result_const;
  };
}
// This function defines a |UnaryScalarFoldingRule| that negates a 32- or
// 64-bit floating-point constant (computing -a rather than 0 - a).  Other
// widths are not folded.
UnaryScalarFoldingRule FoldFNegateOp() {
  return [](const analysis::Type* result_type, const analysis::Constant* a,
            analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
    assert(result_type != nullptr && a != nullptr);
    assert(result_type == a->type());
    const analysis::Float* float_type = result_type->AsFloat();
    assert(float_type != nullptr);

    std::vector<uint32_t> words;
    if (float_type->width() == 32) {
      utils::FloatProxy<float> negated(-a->GetFloat());
      words = negated.GetWords();
    } else if (float_type->width() == 64) {
      utils::FloatProxy<double> negated(-a->GetDouble());
      words = negated.GetWords();
    } else {
      return nullptr;
    }
    return const_mgr->GetConstant(result_type, words);
  };
}

// Folding rule for OpFNegate on scalar and vector floating-point constants.
ConstantFoldingRule FoldFNegate() {
  UnaryScalarFoldingRule negate = FoldFNegateOp();
  return FoldFPUnaryOp(negate);
}
// Returns a rule that folds a floating-point comparison of opcode
// |cmp_opcode| in this situation:
//   - one operand is a constant;
//   - the other operand is a GLSL.std.450 FClamp;
//   - the clamp's constant bounds already decide the outcome.
// For example, "c < clamp(x, min, max)" is true whenever c < min.  Returns
// nullptr when the outcome cannot be decided, or when the instruction
// disallows floating-point folding.
ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) {
  return [cmp_opcode](IRContext* context, Instruction* inst,
                      const std::vector<const analysis::Constant*>& constants)
             -> const analysis::Constant* {
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();

    if (!inst->IsFloatingPointFoldingAllowed()) {
      return nullptr;
    }

    // The clamp bounds can only decide the comparison against a known
    // constant.
    if (constants[1] == nullptr && constants[0] == nullptr) {
      return nullptr;
    }

    // The clamp candidate is the operand that is not a known constant.
    uint32_t non_const_idx = (constants[0] ? 1 : 0);
    uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
    Instruction* operand_inst = def_use_mgr->GetDef(operand_id);

    analysis::TypeManager* type_mgr = context->get_type_mgr();
    const analysis::Type* operand_type =
        type_mgr->GetType(operand_inst->type_id());

    // The bounds are compared as doubles, so only 32- and 64-bit floats are
    // handled.
    if (!operand_type->AsFloat()) {
      return nullptr;
    }
    if (operand_type->AsFloat()->width() != 32 &&
        operand_type->AsFloat()->width() != 64) {
      return nullptr;
    }

    if (operand_inst->opcode() != SpvOpExtInst) {
      return nullptr;
    }

    // Extended-instruction numbers are only meaningful within their own set;
    // the value of GLSLstd450FClamp may name an unrelated instruction in
    // another set.  In-operand 0 of OpExtInst is the id of the
    // OpExtInstImport naming the set, so confirm it is GLSL.std.450.
    Instruction* ext_set_inst =
        def_use_mgr->GetDef(operand_inst->GetSingleWordInOperand(0));
    if (ext_set_inst == nullptr ||
        ext_set_inst->opcode() != SpvOpExtInstImport) {
      return nullptr;
    }
    const std::string ext_set_name(reinterpret_cast<const char*>(
        &ext_set_inst->GetInOperand(0).words[0]));
    if (ext_set_name != "GLSL.std.450") {
      return nullptr;
    }

    if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
      return nullptr;
    }

    // In-operands 3 and 4 of the clamp are its lower and upper bounds.  Either
    // may be non-constant, in which case it cannot be used.
    uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
    const analysis::Constant* max_const =
        const_mgr->FindDeclaredConstant(max_id);

    uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
    const analysis::Constant* min_const =
        const_mgr->FindDeclaredConstant(min_id);

    bool found_result = false;
    bool result = false;

    switch (cmp_opcode) {
      case SpvOpFOrdLessThan:
      case SpvOpFUnordLessThan:
      case SpvOpFOrdGreaterThanEqual:
      case SpvOpFUnordGreaterThanEqual:
        if (constants[0]) {
          if (min_const) {
            if (constants[0]->GetValueAsDouble() <
                min_const->GetValueAsDouble()) {
              found_result = true;
              result = (cmp_opcode == SpvOpFOrdLessThan ||
                        cmp_opcode == SpvOpFUnordLessThan);
            }
          }
          if (max_const) {
            if (constants[0]->GetValueAsDouble() >=
                max_const->GetValueAsDouble()) {
              found_result = true;
              result = !(cmp_opcode == SpvOpFOrdLessThan ||
                         cmp_opcode == SpvOpFUnordLessThan);
            }
          }
        }

        if (constants[1]) {
          if (max_const) {
            if (max_const->GetValueAsDouble() <
                constants[1]->GetValueAsDouble()) {
              found_result = true;
              result = (cmp_opcode == SpvOpFOrdLessThan ||
                        cmp_opcode == SpvOpFUnordLessThan);
            }
          }

          if (min_const) {
            if (min_const->GetValueAsDouble() >=
                constants[1]->GetValueAsDouble()) {
              found_result = true;
              result = !(cmp_opcode == SpvOpFOrdLessThan ||
                         cmp_opcode == SpvOpFUnordLessThan);
            }
          }
        }
        break;
      case SpvOpFOrdGreaterThan:
      case SpvOpFUnordGreaterThan:
      case SpvOpFOrdLessThanEqual:
      case SpvOpFUnordLessThanEqual:
        if (constants[0]) {
          if (min_const) {
            if (constants[0]->GetValueAsDouble() <=
                min_const->GetValueAsDouble()) {
              found_result = true;
              result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
                        cmp_opcode == SpvOpFUnordLessThanEqual);
            }
          }
          if (max_const) {
            if (constants[0]->GetValueAsDouble() >
                max_const->GetValueAsDouble()) {
              found_result = true;
              result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
                         cmp_opcode == SpvOpFUnordLessThanEqual);
            }
          }
        }

        if (constants[1]) {
          if (max_const) {
            if (max_const->GetValueAsDouble() <=
                constants[1]->GetValueAsDouble()) {
              found_result = true;
              result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
                        cmp_opcode == SpvOpFUnordLessThanEqual);
            }
          }

          if (min_const) {
            if (min_const->GetValueAsDouble() >
                constants[1]->GetValueAsDouble()) {
              found_result = true;
              result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
                         cmp_opcode == SpvOpFUnordLessThanEqual);
            }
          }
        }
        break;
      default:
        return nullptr;
    }

    if (!found_result) {
      return nullptr;
    }

    const analysis::Type* bool_type =
        context->get_type_mgr()->GetType(inst->type_id());
    const analysis::Constant* result_const =
        const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
    assert(result_const);
    return result_const;
  };
}
} // namespace
// Populates |rules_|, which maps each opcode to its list of constant folding
// rules.
ConstantFoldingRules::ConstantFoldingRules() {
// Add all folding rules to the list for the opcodes to which they apply.
// Note that the order in which rules are added to the list matters. If a rule
// applies to the instruction, the rest of the rules will not be attempted.
// Take that into consideration.
rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
// Float <-> integer conversions; signedness is taken from the integer type.
rules_[SpvOpConvertFToS].push_back(FoldFToI());
rules_[SpvOpConvertFToU].push_back(FoldFToI());
rules_[SpvOpConvertSToF].push_back(FoldIToF());
rules_[SpvOpConvertUToF].push_back(FoldIToF());
rules_[SpvOpDot].push_back(FoldOpDotWithConstants());
rules_[SpvOpFAdd].push_back(FoldFAdd());
rules_[SpvOpFDiv].push_back(FoldFDiv());
rules_[SpvOpFMul].push_back(FoldFMul());
rules_[SpvOpFSub].push_back(FoldFSub());
rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual());
rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual());
rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual());
rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual());
// For each relational comparison, the generic rule (which needs both operands
// to be constant) is registered first.  The FClamp rule (which needs just one
// constant operand plus an FClamp feeding the other) is the fallback.
rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan());
rules_[SpvOpFOrdLessThan].push_back(
FoldFClampFeedingCompare(SpvOpFOrdLessThan));
rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan());
rules_[SpvOpFUnordLessThan].push_back(
FoldFClampFeedingCompare(SpvOpFUnordLessThan));
rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
rules_[SpvOpFOrdGreaterThan].push_back(
FoldFClampFeedingCompare(SpvOpFOrdGreaterThan));
rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
rules_[SpvOpFUnordGreaterThan].push_back(
FoldFClampFeedingCompare(SpvOpFUnordGreaterThan));
rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
rules_[SpvOpFOrdLessThanEqual].push_back(
FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual));
rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
rules_[SpvOpFUnordLessThanEqual].push_back(
FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual));
rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
rules_[SpvOpFOrdGreaterThanEqual].push_back(
FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual));
rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
rules_[SpvOpFUnordGreaterThanEqual].push_back(
FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual));
rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar());
rules_[SpvOpFNegate].push_back(FoldFNegate());
}
} // namespace opt
} // namespace spvtools