SPIRV-Tools/source/opt/scalar_analysis.cpp
dan sinclair c7da51a085
Cleanup extraneous namespace qualifiers in source/opt. (#1716)
This CL follows up on the opt namespacing CLs by removing the
unnecessary opt:: and opt::analysis:: namespace prefixes.
2018-07-12 15:14:43 -04:00

989 lines
33 KiB
C++

// Copyright (c) 2018 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "opt/scalar_analysis.h"
#include <algorithm>
#include <functional>
#include <limits>
#include <string>
#include <utility>
#include "opt/ir_context.h"
// Transforms a given scalar operation instruction into a DAG representation.
//
// 1. Take an instruction and traverse its operands until we reach a
// constant node or an instruction whose value we do not know how to compute,
// such as a load.
//
// 2. Create a new node for each instruction traversed and build the nodes for
// the in operands of that instruction as well.
//
// 3. Add the operand nodes as children of the node created for the instruction
// and hash that node. Use the hash to see if the node is already in the cache.
// We ensure the children are always in sorted order so that two nodes with the
// same children but inserted in a different order have the same hash and so
// that the overloaded operator== will return true. If the node is already in
// the cache return the cached version instead.
//
// 4. The created DAG can then be simplified by
// ScalarEvolutionAnalysis::SimplifyExpression, implemented in
// scalar_analysis_simplification.cpp. See that file for further information on
// the simplification process.
//
namespace spvtools {
namespace opt {
// Out-of-class definition of the static SENode::NumberOfNodes counter,
// zero-initialized. NOTE(review): presumably incremented by the SENode
// constructor declared in the header to give nodes unique ids — confirm.
uint32_t SENode::NumberOfNodes = 0;
ScalarEvolutionAnalysis::ScalarEvolutionAnalysis(IRContext* context)
    : context_(context), pretend_equal_{} {
  // Build the one CantCompute node up front and remember the cached instance,
  // so CreateCantComputeNode() can simply hand out this pointer.
  std::unique_ptr<SECantCompute> cant_compute_node{new SECantCompute(this)};
  cached_cant_compute_ = GetCachedOrAdd(std::move(cant_compute_node));
}
// Returns the (cached) node representing -|operand|. Constants are folded
// directly and a CantCompute operand yields CantCompute.
SENode* ScalarEvolutionAnalysis::CreateNegation(SENode* operand) {
  // An uncomputable operand poisons the whole expression.
  if (operand->IsCantCompute()) {
    return CreateCantComputeNode();
  }

  // Negating a constant is folded immediately instead of building a node.
  if (operand->GetType() == SENode::Constant) {
    const int64_t folded = operand->AsSEConstantNode()->FoldToSingleValue();
    return CreateConstant(-folded);
  }

  std::unique_ptr<SENode> negated(new SENegative(this));
  negated->AddChild(operand);
  return GetCachedOrAdd(std::move(negated));
}
// Returns the cached constant node holding |integer|, creating it if needed.
SENode* ScalarEvolutionAnalysis::CreateConstant(int64_t integer) {
  std::unique_ptr<SENode> constant_node(new SEConstantNode(this, integer));
  return GetCachedOrAdd(std::move(constant_node));
}
// Creates (or returns the cached) recurrent expression attached to |loop| with
// the given |offset| and |coefficient|. If |loop| has been registered in
// |pretend_equal_| as equivalent to another loop, that loop is used instead.
// Returns the CantCompute node if either operand is CantCompute.
SENode* ScalarEvolutionAnalysis::CreateRecurrentExpression(
    const Loop* loop, SENode* offset, SENode* coefficient) {
  assert(loop && "Recurrent add expressions must have a valid loop.");

  // If operands are can't compute then the whole graph is can't compute.
  if (offset->IsCantCompute() || coefficient->IsCantCompute())
    return CreateCantComputeNode();

  // Look |loop| up once with find(). Using operator[] (twice) would also
  // insert a null entry into |pretend_equal_| for every loop without a
  // registered substitute.
  const Loop* loop_to_use = loop;
  auto pretend_itr = pretend_equal_.find(loop);
  if (pretend_itr != pretend_equal_.end() && pretend_itr->second) {
    loop_to_use = pretend_itr->second;
  }

  std::unique_ptr<SERecurrentNode> phi_node{
      new SERecurrentNode(this, loop_to_use)};
  phi_node->AddOffset(offset);
  phi_node->AddCoefficient(coefficient);
  return GetCachedOrAdd(std::move(phi_node));
}
// Builds the node for an OpIMul by analyzing both in-operands and combining
// them with CreateMultiplyNode.
SENode* ScalarEvolutionAnalysis::AnalyzeMultiplyOp(
    const Instruction* multiply) {
  assert(multiply->opcode() == SpvOp::SpvOpIMul &&
         "Multiply node did not come from a multiply instruction");

  analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
  // Resolves in-operand |index| to its defining instruction and analyzes it.
  auto analyze_operand = [this, def_use_mgr, multiply](uint32_t index) {
    return AnalyzeInstruction(
        def_use_mgr->GetDef(multiply->GetSingleWordInOperand(index)));
  };

  SENode* lhs = analyze_operand(0);
  SENode* rhs = analyze_operand(1);
  return CreateMultiplyNode(lhs, rhs);
}
// Returns the (cached) node for |operand_1| * |operand_2|. CantCompute
// operands yield CantCompute; two constants are folded into one constant.
SENode* ScalarEvolutionAnalysis::CreateMultiplyNode(SENode* operand_1,
                                                    SENode* operand_2) {
  const bool any_cant_compute =
      operand_1->IsCantCompute() || operand_2->IsCantCompute();
  if (any_cant_compute) {
    return CreateCantComputeNode();
  }

  const bool both_constant = operand_1->GetType() == SENode::Constant &&
                             operand_2->GetType() == SENode::Constant;
  if (both_constant) {
    const int64_t lhs_value =
        operand_1->AsSEConstantNode()->FoldToSingleValue();
    const int64_t rhs_value =
        operand_2->AsSEConstantNode()->FoldToSingleValue();
    return CreateConstant(lhs_value * rhs_value);
  }

  std::unique_ptr<SENode> product(new SEMultiplyNode(this));
  product->AddChild(operand_1);
  product->AddChild(operand_2);
  return GetCachedOrAdd(std::move(product));
}
// Returns a node for |operand_1| - |operand_2|. Two constants are folded;
// otherwise the subtraction is expressed as |operand_1| + (-|operand_2|).
SENode* ScalarEvolutionAnalysis::CreateSubtraction(SENode* operand_1,
                                                   SENode* operand_2) {
  if (operand_1->GetType() != SENode::Constant ||
      operand_2->GetType() != SENode::Constant) {
    SENode* negated = CreateNegation(operand_2);
    return CreateAddNode(operand_1, negated);
  }
  const int64_t minuend = operand_1->AsSEConstantNode()->FoldToSingleValue();
  const int64_t subtrahend =
      operand_2->AsSEConstantNode()->FoldToSingleValue();
  return CreateConstant(minuend - subtrahend);
}
// Returns the (cached) node for |operand_1| + |operand_2|. Two constants are
// folded into a single constant; a CantCompute operand yields CantCompute.
SENode* ScalarEvolutionAnalysis::CreateAddNode(SENode* operand_1,
                                               SENode* operand_2) {
  if (operand_1->GetType() == SENode::Constant &&
      operand_2->GetType() == SENode::Constant) {
    const int64_t sum = operand_1->AsSEConstantNode()->FoldToSingleValue() +
                        operand_2->AsSEConstantNode()->FoldToSingleValue();
    return CreateConstant(sum);
  }

  if (operand_1->IsCantCompute() || operand_2->IsCantCompute()) {
    return CreateCantComputeNode();
  }

  std::unique_ptr<SENode> sum_node(new SEAddNode(this));
  sum_node->AddChild(operand_1);
  sum_node->AddChild(operand_2);
  return GetCachedOrAdd(std::move(sum_node));
}
// Builds the SENode DAG for |inst| by dispatching on its opcode. Instructions
// that are not understood become value-unknown nodes.
SENode* ScalarEvolutionAnalysis::AnalyzeInstruction(const Instruction* inst) {
  // Phis already being (or already) analyzed are recorded in
  // |recurrent_node_map_|; returning the stored node breaks the recursion
  // through loop-carried values.
  auto memoized = recurrent_node_map_.find(inst);
  if (memoized != recurrent_node_map_.end()) return memoized->second;

  switch (inst->opcode()) {
    case SpvOp::SpvOpPhi:
      return AnalyzePhiInstruction(inst);
    case SpvOp::SpvOpConstant:
    case SpvOp::SpvOpConstantNull:
      return AnalyzeConstant(inst);
    case SpvOp::SpvOpISub:
    case SpvOp::SpvOpIAdd:
      return AnalyzeAddOp(inst);
    case SpvOp::SpvOpIMul:
      return AnalyzeMultiplyOp(inst);
    default:
      return CreateValueUnknownNode(inst);
  }
}
// Converts an OpConstant/OpConstantNull into a constant node. Only single-word
// (32-bit) integer constants are supported; anything else is CantCompute.
SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const Instruction* inst) {
  // OpConstantNull is treated as the integer 0.
  if (inst->opcode() == SpvOp::SpvOpConstantNull) return CreateConstant(0);

  assert(inst->opcode() == SpvOp::SpvOpConstant);
  assert(inst->NumInOperands() == 1);

  const analysis::Constant* constant =
      context_->get_constant_mgr()->FindDeclaredConstant(inst->result_id());
  if (constant == nullptr) return CreateCantComputeNode();

  // Bail out on non-integers and on multi-word (e.g. 64-bit) integers.
  const analysis::IntConstant* int_constant = constant->AsIntConstant();
  if (int_constant == nullptr || int_constant->words().size() != 1) {
    return CreateCantComputeNode();
  }

  // Widen with the right signedness; the explicit casts keep the conditional
  // from converting the signed value to unsigned.
  const bool is_signed = int_constant->type()->AsInteger()->IsSigned();
  const int64_t value =
      is_signed ? static_cast<int64_t>(int_constant->GetS32BitValue())
                : static_cast<int64_t>(int_constant->GetU32BitValue());
  return CreateConstant(value);
}
// Handles both OpIAdd and OpISub. An OpISub is modelled as op1 + (-op2) by
// wrapping the second operand in a negation node; an OpIAdd is op1 + op2.
SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const Instruction* inst) {
  const auto opcode = inst->opcode();
  assert((opcode == SpvOp::SpvOpIAdd || opcode == SpvOp::SpvOpISub) &&
         "Add node must be created from a OpIAdd or OpISub instruction");

  analysis::DefUseManager* def_use = context_->get_def_use_mgr();
  SENode* lhs =
      AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(0)));
  SENode* rhs =
      AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(1)));

  SENode* addend = (opcode == SpvOp::SpvOpISub) ? CreateNegation(rhs) : rhs;
  return CreateAddNode(lhs, addend);
}
// Builds a recurrent expression for |phi|. Only two-input phis located in the
// header of a loop that has both a preheader and a latch are handled. The
// value from the preheader becomes the offset. The value from the latch must
// have the form step + phi, with a loop-invariant step, which becomes the
// coefficient. Anything else yields the CantCompute node.
SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction(const Instruction* phi) {
  // The phi should only have two incoming value pairs.
  if (phi->NumInOperands() != 4) {
    return CreateCantComputeNode();
  }

  analysis::DefUseManager* def_use = context_->get_def_use_mgr();

  // Get the basic block this instruction belongs to.
  BasicBlock* basic_block =
      context_->get_instr_block(const_cast<Instruction*>(phi));

  // And then the function that the basic block belongs to.
  Function* function = basic_block->GetParent();

  // Use the function to get the loop descriptor.
  LoopDescriptor* loop_descriptor = context_->GetLoopDescriptor(function);

  // We only handle phis in loops at the moment.
  if (!loop_descriptor) return CreateCantComputeNode();

  // Get the innermost loop which this block belongs to.
  Loop* loop = (*loop_descriptor)[basic_block->id()];

  // Exit out if:
  // - the loop doesn't exist,
  // - it has no preheader or latch block, or
  // - the phi is not in the loop header.
  if (!loop || !loop->GetLatchBlock() || !loop->GetPreHeaderBlock() ||
      loop->GetHeaderBlock() != basic_block)
    return recurrent_node_map_[phi] = CreateCantComputeNode();

  // If |loop| has been registered as equivalent to another loop, attach the
  // recurrence to that loop instead. find() is used rather than operator[] so
  // that the query does not insert a null entry for every loop analyzed.
  const Loop* loop_to_use = loop;
  auto pretend_itr = pretend_equal_.find(loop);
  if (pretend_itr != pretend_equal_.end() && pretend_itr->second) {
    loop_to_use = pretend_itr->second;
  }

  std::unique_ptr<SERecurrentNode> phi_node{
      new SERecurrentNode(this, loop_to_use)};

  // We add the node to this map to allow it to be returned before the node is
  // fully built. This is needed as the subsequent call to AnalyzeInstruction
  // could lead back to this |phi| instruction so we return the pointer
  // immediately in AnalyzeInstruction to break the recursion.
  recurrent_node_map_[phi] = phi_node.get();

  // Traverse the operands of the instruction and create new nodes for each one.
  for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
    uint32_t value_id = phi->GetSingleWordInOperand(i);
    uint32_t incoming_label_id = phi->GetSingleWordInOperand(i + 1);

    Instruction* value_inst = def_use->GetDef(value_id);
    SENode* value_node = AnalyzeInstruction(value_inst);

    // If any operand is CantCompute then the whole graph is CantCompute.
    if (value_node->IsCantCompute())
      return recurrent_node_map_[phi] = CreateCantComputeNode();

    // If the value is coming from the preheader block then the value is the
    // initial value of the phi.
    if (incoming_label_id == loop->GetPreHeaderBlock()->id()) {
      phi_node->AddOffset(value_node);
    } else if (incoming_label_id == loop->GetLatchBlock()->id()) {
      // Assumed to be in the form of step + phi.
      if (value_node->GetType() != SENode::Add)
        return recurrent_node_map_[phi] = CreateCantComputeNode();

      SENode* step_node = nullptr;
      SENode* phi_operand = nullptr;
      SENode* operand_1 = value_node->GetChild(0);
      SENode* operand_2 = value_node->GetChild(1);

      // Find which node is the step term.
      if (!operand_1->AsSERecurrentNode())
        step_node = operand_1;
      else if (!operand_2->AsSERecurrentNode())
        step_node = operand_2;

      // Find which node is the recurrent expression.
      if (operand_1->AsSERecurrentNode())
        phi_operand = operand_1;
      else if (operand_2->AsSERecurrentNode())
        phi_operand = operand_2;

      // If it is not in the form step + phi exit out.
      if (!(step_node && phi_operand))
        return recurrent_node_map_[phi] = CreateCantComputeNode();

      // If the phi operand is not the same phi node exit out.
      if (phi_operand != phi_node.get())
        return recurrent_node_map_[phi] = CreateCantComputeNode();

      if (!IsLoopInvariant(loop, step_node))
        return recurrent_node_map_[phi] = CreateCantComputeNode();

      phi_node->AddCoefficient(step_node);
    }
  }

  // Once the node is fully built we update the map with the version from the
  // cache (if it has already been added to the cache).
  return recurrent_node_map_[phi] = GetCachedOrAdd(std::move(phi_node));
}
// Returns the (cached) opaque node standing for the value produced by |inst|.
// Such nodes are distinguished by the result id of |inst|.
SENode* ScalarEvolutionAnalysis::CreateValueUnknownNode(
    const Instruction* inst) {
  const uint32_t result_id = inst->result_id();
  return GetCachedOrAdd(
      std::unique_ptr<SENode>(new SEValueUnknown(this, result_id)));
}
// Returns the shared CantCompute node that the constructor created and cached.
SENode* ScalarEvolutionAnalysis::CreateCantComputeNode() {
  SENode* shared_cant_compute = this->cached_cant_compute_;
  return shared_cant_compute;
}
// Returns the canonical cached node equal to |prospective_node|. If no equal
// node is cached yet, |prospective_node| is inserted and becomes the canonical
// one; otherwise the cached node is returned and |prospective_node| is
// discarded.
SENode* ScalarEvolutionAnalysis::GetCachedOrAdd(
    std::unique_ptr<SENode> prospective_node) {
  auto existing = node_cache_.find(prospective_node);
  if (existing == node_cache_.end()) {
    SENode* inserted = prospective_node.get();
    node_cache_.insert(std::move(prospective_node));
    return inserted;
  }
  return existing->get();
}
bool ScalarEvolutionAnalysis::IsLoopInvariant(const Loop* loop,
const SENode* node) const {
for (auto itr = node->graph_cbegin(); itr != node->graph_cend(); ++itr) {
if (const SERecurrentNode* rec = itr->AsSERecurrentNode()) {
const BasicBlock* header = rec->GetLoop()->GetHeaderBlock();
// If the loop which the recurrent expression belongs to is either |loop
// or a nested loop inside |loop| then we assume it is variant.
if (loop->IsInsideLoop(header)) {
return false;
}
} else if (const SEValueUnknown* unknown = itr->AsSEValueUnknown()) {
// If the instruction is inside the loop we conservatively assume it is
// loop variant.
if (loop->IsInsideLoop(unknown->ResultId())) return false;
}
}
return true;
}
// Returns the coefficient of the recurrent term belonging to |loop| found in
// the DAG rooted at |node|, or the constant 0 if there is no such term.
SENode* ScalarEvolutionAnalysis::GetCoefficientFromRecurrentTerm(
    SENode* node, const Loop* loop) {
  // GetRecurrentTerm performs exactly the same DAG search for |loop|.
  SERecurrentNode* recurrent_term = GetRecurrentTerm(node, loop);
  if (recurrent_term == nullptr) return CreateConstant(0);
  return recurrent_term->GetCoefficient();
}
// Rebuilds the add node |parent| with every occurrence of |old_child| replaced
// by |new_child|, then returns the simplified, cached result. Non-add parents
// are returned unchanged.
SENode* ScalarEvolutionAnalysis::UpdateChildNode(SENode* parent,
                                                 SENode* old_child,
                                                 SENode* new_child) {
  if (parent->GetType() != SENode::Add) return parent;

  std::unique_ptr<SENode> rebuilt(new SEAddNode(this));
  for (SENode* child : *parent) {
    rebuilt->AddChild(child == old_child ? new_child : child);
  }
  return SimplifyExpression(GetCachedOrAdd(std::move(rebuilt)));
}
// Rebuilds |node| with the recurrent term belonging to |loop| (if present)
// replaced by that term's offset.
SENode* ScalarEvolutionAnalysis::BuildGraphWithoutRecurrentTerm(
    SENode* node, const Loop* loop) {
  // A recurrent node collapses to its offset when it belongs to |loop| and is
  // otherwise returned as-is.
  if (SERecurrentNode* self_recurrence = node->AsSERecurrentNode()) {
    if (self_recurrence->GetLoop() != loop) return node;
    return self_recurrence->GetOffset();
  }

  // Otherwise substitute each direct child that is a recurrence of |loop|
  // with its offset. NOTE(review): the result is always rebuilt as an add
  // node, regardless of |node|'s own type.
  std::unique_ptr<SENode> rebuilt(new SEAddNode(this));
  for (SENode* child : *node) {
    SERecurrentNode* child_recurrence = child->AsSERecurrentNode();
    const bool belongs_to_loop =
        child_recurrence != nullptr && child_recurrence->GetLoop() == loop;
    rebuilt->AddChild(belongs_to_loop ? child_recurrence->GetOffset() : child);
  }
  return SimplifyExpression(GetCachedOrAdd(std::move(rebuilt)));
}
// Searches the DAG rooted at |node| and returns the first recurrent term
// belonging to |loop|, or nullptr if there is none.
SERecurrentNode* ScalarEvolutionAnalysis::GetRecurrentTerm(SENode* node,
                                                           const Loop* loop) {
  auto itr = node->graph_begin();
  while (itr != node->graph_end()) {
    SERecurrentNode* candidate = itr->AsSERecurrentNode();
    if (candidate != nullptr && candidate->GetLoop() == loop) {
      return candidate;
    }
    ++itr;
  }
  return nullptr;
}
// Returns a human-readable name for this node's type ("NULL" for an
// unrecognized type value).
std::string SENode::AsString() const {
  const char* type_name = "NULL";
  switch (GetType()) {
    case Constant:
      type_name = "Constant";
      break;
    case RecurrentAddExpr:
      type_name = "RecurrentAddExpr";
      break;
    case Add:
      type_name = "Add";
      break;
    case Negative:
      type_name = "Negative";
      break;
    case Multiply:
      type_name = "Multiply";
      break;
    case ValueUnknown:
      type_name = "Value Unknown";
      break;
    case CanNotCompute:
      type_name = "Can not compute";
      break;
  }
  return type_name;
}
// Structural equality used by the node cache. Nodes must share a type and
// arity. Recurrent nodes compare offset, coefficient and loop. Other nodes
// compare their child pointers element-wise. Value-unknown nodes must also
// share a result id, and constants must share a value.
bool SENode::operator==(const SENode& other) const {
  if (GetType() != other.GetType() ||
      children_.size() != other.GetChildren().size()) {
    return false;
  }

  if (const SERecurrentNode* lhs_rec = AsSERecurrentNode()) {
    // The children of a recurrent node are kept sorted, which loses which one
    // is the offset and which the coefficient, so compare those explicitly.
    const SERecurrentNode* rhs_rec = other.AsSERecurrentNode();
    // Types already matched, so |other| must be recurrent too.
    assert(rhs_rec);
    if (lhs_rec->GetCoefficient() != rhs_rec->GetCoefficient() ||
        lhs_rec->GetOffset() != rhs_rec->GetOffset() ||
        lhs_rec->GetLoop() != rhs_rec->GetLoop()) {
      return false;
    }
  } else if (!std::equal(children_.begin(), children_.end(),
                         other.GetChildren().begin())) {
    return false;
  }

  // Value-unknown nodes must come from the same instruction.
  if (GetType() == SENode::ValueUnknown &&
      AsSEValueUnknown()->ResultId() != other.AsSEValueUnknown()->ResultId()) {
    return false;
  }

  // Constants must hold the same value.
  const SEConstantNode* lhs_constant = AsSEConstantNode();
  return lhs_constant == nullptr ||
         lhs_constant->FoldToSingleValue() ==
             other.AsSEConstantNode()->FoldToSingleValue();
}
// Logical negation of operator==.
bool SENode::operator!=(const SENode& other) const {
  const bool equal = operator==(other);
  return !equal;
}
namespace {
// Helpers that append an integral value to a UTF-32 string used as hash
// input. PushToString dispatches on sizeof(T):
// - 8-byte values are appended as two 32-bit units, high word first.
// - 4-byte values are appended as a single unit.
// The primary template is intentionally left undefined, so other sizes fail
// to compile. Pointers are hashed by first reinterpreting them as uintptr_t.
template <typename T, size_t size_of_t>
struct PushToStringImpl;

template <typename T>
struct PushToStringImpl<T, 8> {
  void operator()(T id, std::u32string* str) const {
    const uint32_t high_word = static_cast<uint32_t>(id >> 32);
    const uint32_t low_word = static_cast<uint32_t>(id);
    str->push_back(high_word);
    str->push_back(low_word);
  }
};

template <typename T>
struct PushToStringImpl<T, 4> {
  void operator()(T id, std::u32string* str) const {
    const uint32_t only_word = static_cast<uint32_t>(id);
    str->push_back(only_word);
  }
};

template <typename T>
void PushToString(T id, std::u32string* str) {
  const PushToStringImpl<T, sizeof(T)> pusher{};
  pusher(id, str);
}
}  // namespace
// Hashes an SENode by serializing its identifying parts into a UTF-32 string
// and hashing that string.
// - Every node contributes its type name and, if constant, its value.
// - Recurrent nodes add their loop, coefficient and offset pointers.
// - All other nodes add their result id (value-unknown only) and the pointers
//   of their children.
size_t SENodeHash::operator()(const SENode* node) const {
  std::u32string key;

  // The type is spelled out as text, which is less likely to collide with
  // constant values than the raw enum would be.
  const std::string type_name = node->AsString();
  key.assign(type_name.begin(), type_name.end());

  // Literal values only matter for constants.
  if (node->GetType() == SENode::Constant) {
    PushToString(node->AsSEConstantNode()->FoldToSingleValue(), &key);
  }

  if (const SERecurrentNode* recurrent = node->AsSERecurrentNode()) {
    // Hash the loop so that nested inductions such as i=0,i++ and j=0,j++ map
    // to different nodes. Coefficient and offset are hashed in a fixed order
    // because the (sorted) child list does not preserve which is which.
    PushToString(reinterpret_cast<uintptr_t>(recurrent->GetLoop()), &key);
    PushToString(reinterpret_cast<uintptr_t>(recurrent->GetCoefficient()),
                 &key);
    PushToString(reinterpret_cast<uintptr_t>(recurrent->GetOffset()), &key);
  } else {
    // Value-unknown nodes are identified by their originating result id.
    if (node->GetType() == SENode::ValueUnknown) {
      PushToString(node->AsSEValueUnknown()->ResultId(), &key);
    }
    // Each cached SENode has a unique address, so child pointers identify
    // the children.
    for (const SENode* child : node->GetChildren()) {
      PushToString(reinterpret_cast<uintptr_t>(child), &key);
    }
  }

  return std::hash<std::u32string>{}(key);
}
// Overload used by the |node_cache_| set, whose elements are owning pointers;
// it forwards to the raw-pointer overload.
size_t SENodeHash::operator()(const std::unique_ptr<SENode>& node) const {
  const SENode* raw_node = node.get();
  return (*this)(raw_node);
}
void SENode::DumpDot(std::ostream& out, bool recurse) const {
size_t unique_id = std::hash<const SENode*>{}(this);
out << unique_id << " [label=\"" << AsString() << " ";
if (GetType() == SENode::Constant) {
out << "\nwith value: " << this->AsSEConstantNode()->FoldToSingleValue();
}
out << "\"]\n";
for (const SENode* child : children_) {
size_t child_unique_id = std::hash<const SENode*>{}(child);
out << unique_id << " -> " << child_unique_id << " \n";
if (recurse) child->DumpDot(out, true);
}
}
namespace {
class IsGreaterThanZero {
public:
explicit IsGreaterThanZero(IRContext* context) : context_(context) {}
// Determine if the value of |node| is always strictly greater than zero if
// |or_equal_zero| is false or greater or equal to zero if |or_equal_zero| is
// true. It returns true is the evaluation was able to conclude something, in
// which case the result is stored in |result|.
// The algorithm work by going through all the nodes and determine the
// sign of each of them.
bool Eval(const SENode* node, bool or_equal_zero, bool* result) {
*result = false;
switch (Visit(node)) {
case Signedness::kPositiveOrNegative: {
return false;
}
case Signedness::kStrictlyNegative: {
*result = false;
break;
}
case Signedness::kNegative: {
if (!or_equal_zero) {
return false;
}
*result = false;
break;
}
case Signedness::kStrictlyPositive: {
*result = true;
break;
}
case Signedness::kPositive: {
if (!or_equal_zero) {
return false;
}
*result = true;
break;
}
}
return true;
}
private:
enum class Signedness {
kPositiveOrNegative, // Yield a value positive or negative.
kStrictlyNegative, // Yield a value strictly less than 0.
kNegative, // Yield a value less or equal to 0.
kStrictlyPositive, // Yield a value strictly greater than 0.
kPositive // Yield a value greater or equal to 0.
};
// Combine the signedness according to arithmetic rules of a given operator.
using Combiner = std::function<Signedness(Signedness, Signedness)>;
// Returns a functor to interpret the signedness of 2 expressions as if they
// were added.
Combiner GetAddCombiner() const {
return [](Signedness lhs, Signedness rhs) {
switch (lhs) {
case Signedness::kPositiveOrNegative:
break;
case Signedness::kStrictlyNegative:
if (rhs == Signedness::kStrictlyNegative ||
rhs == Signedness::kNegative)
return lhs;
break;
case Signedness::kNegative: {
if (rhs == Signedness::kStrictlyNegative)
return Signedness::kStrictlyNegative;
if (rhs == Signedness::kNegative) return Signedness::kNegative;
break;
}
case Signedness::kStrictlyPositive: {
if (rhs == Signedness::kStrictlyPositive ||
rhs == Signedness::kPositive) {
return Signedness::kStrictlyPositive;
}
break;
}
case Signedness::kPositive: {
if (rhs == Signedness::kStrictlyPositive)
return Signedness::kStrictlyPositive;
if (rhs == Signedness::kPositive) return Signedness::kPositive;
break;
}
}
return Signedness::kPositiveOrNegative;
};
}
// Returns a functor to interpret the signedness of 2 expressions as if they
// were multiplied.
Combiner GetMulCombiner() const {
return [](Signedness lhs, Signedness rhs) {
switch (lhs) {
case Signedness::kPositiveOrNegative:
break;
case Signedness::kStrictlyNegative: {
switch (rhs) {
case Signedness::kPositiveOrNegative: {
break;
}
case Signedness::kStrictlyNegative: {
return Signedness::kStrictlyPositive;
}
case Signedness::kNegative: {
return Signedness::kPositive;
}
case Signedness::kStrictlyPositive: {
return Signedness::kStrictlyNegative;
}
case Signedness::kPositive: {
return Signedness::kNegative;
}
}
break;
}
case Signedness::kNegative: {
switch (rhs) {
case Signedness::kPositiveOrNegative: {
break;
}
case Signedness::kStrictlyNegative:
case Signedness::kNegative: {
return Signedness::kPositive;
}
case Signedness::kStrictlyPositive:
case Signedness::kPositive: {
return Signedness::kNegative;
}
}
break;
}
case Signedness::kStrictlyPositive: {
return rhs;
}
case Signedness::kPositive: {
switch (rhs) {
case Signedness::kPositiveOrNegative: {
break;
}
case Signedness::kStrictlyNegative:
case Signedness::kNegative: {
return Signedness::kNegative;
}
case Signedness::kStrictlyPositive:
case Signedness::kPositive: {
return Signedness::kPositive;
}
}
break;
}
}
return Signedness::kPositiveOrNegative;
};
}
Signedness Visit(const SENode* node) {
switch (node->GetType()) {
case SENode::Constant:
return Visit(node->AsSEConstantNode());
break;
case SENode::RecurrentAddExpr:
return Visit(node->AsSERecurrentNode());
break;
case SENode::Negative:
return Visit(node->AsSENegative());
break;
case SENode::CanNotCompute:
return Visit(node->AsSECantCompute());
break;
case SENode::ValueUnknown:
return Visit(node->AsSEValueUnknown());
break;
case SENode::Add:
return VisitExpr(node, GetAddCombiner());
break;
case SENode::Multiply:
return VisitExpr(node, GetMulCombiner());
break;
}
return Signedness::kPositiveOrNegative;
}
// Returns the signedness of a constant |node|.
Signedness Visit(const SEConstantNode* node) {
if (0 == node->FoldToSingleValue()) return Signedness::kPositive;
if (0 < node->FoldToSingleValue()) return Signedness::kStrictlyPositive;
if (0 > node->FoldToSingleValue()) return Signedness::kStrictlyNegative;
return Signedness::kPositiveOrNegative;
}
// Returns the signedness of an unknown |node| based on its type.
Signedness Visit(const SEValueUnknown* node) {
Instruction* insn = context_->get_def_use_mgr()->GetDef(node->ResultId());
analysis::Type* type = context_->get_type_mgr()->GetType(insn->type_id());
assert(type && "Can't retrieve a type for the instruction");
analysis::Integer* int_type = type->AsInteger();
assert(type && "Can't retrieve an integer type for the instruction");
return int_type->IsSigned() ? Signedness::kPositiveOrNegative
: Signedness::kPositive;
}
// Returns the signedness of a recurring expression.
Signedness Visit(const SERecurrentNode* node) {
Signedness coeff_sign = Visit(node->GetCoefficient());
// SERecurrentNode represent an affine expression in the range [0,
// loop_bound], so the result cannot be strictly positive or negative.
switch (coeff_sign) {
default:
break;
case Signedness::kStrictlyNegative:
coeff_sign = Signedness::kNegative;
break;
case Signedness::kStrictlyPositive:
coeff_sign = Signedness::kPositive;
break;
}
return GetAddCombiner()(coeff_sign, Visit(node->GetOffset()));
}
// Returns the signedness of a negation |node|.
Signedness Visit(const SENegative* node) {
switch (Visit(*node->begin())) {
case Signedness::kPositiveOrNegative: {
return Signedness::kPositiveOrNegative;
}
case Signedness::kStrictlyNegative: {
return Signedness::kStrictlyPositive;
}
case Signedness::kNegative: {
return Signedness::kPositive;
}
case Signedness::kStrictlyPositive: {
return Signedness::kStrictlyNegative;
}
case Signedness::kPositive: {
return Signedness::kNegative;
}
}
return Signedness::kPositiveOrNegative;
}
Signedness Visit(const SECantCompute*) {
return Signedness::kPositiveOrNegative;
}
// Returns the signedness of a binary expression by using the combiner
// |reduce|.
Signedness VisitExpr(
const SENode* node,
std::function<Signedness(Signedness, Signedness)> reduce) {
Signedness result = Visit(*node->begin());
for (const SENode* operand : make_range(++node->begin(), node->end())) {
if (result == Signedness::kPositiveOrNegative) {
return Signedness::kPositiveOrNegative;
}
result = reduce(result, Visit(operand));
}
return result;
}
IRContext* context_;
};
} // namespace
// Asks whether |node| is always strictly greater than zero. Returns true if a
// conclusion was reached, with the answer written to |is_gt_zero|.
bool ScalarEvolutionAnalysis::IsAlwaysGreaterThanZero(SENode* node,
                                                      bool* is_gt_zero) const {
  const bool allow_zero = false;
  IsGreaterThanZero evaluator(context_);
  return evaluator.Eval(node, allow_zero, is_gt_zero);
}
// Asks whether |node| is always greater than or equal to zero. Returns true if
// a conclusion was reached, with the answer written to |is_ge_zero|.
bool ScalarEvolutionAnalysis::IsAlwaysGreaterOrEqualToZero(
    SENode* node, bool* is_ge_zero) const {
  const bool allow_zero = true;
  IsGreaterThanZero evaluator(context_);
  return evaluator.Eval(node, allow_zero, is_ge_zero);
}
namespace {
// Removes one occurrence of |node| from the multiplication chain rooted at
// |mul| (of the form A * ... * |node| * ... * Z) and returns the resulting
// product. If |node| is not in the chain, |mul| itself is returned, so callers
// detect "not found" by comparing the result against |mul|.
// Assumes every multiply node in the chain has exactly two children.
SENode* RemoveOneNodeFromMultiplyChain(SEMultiplyNode* mul,
                                       const SENode* node) {
  SENode* lhs = mul->GetChildren()[0];
  SENode* rhs = mul->GetChildren()[1];
  if (lhs == node) {
    return rhs;
  }
  if (rhs == node) {
    return lhs;
  }
  if (SEMultiplyNode* lhs_mul = lhs->AsSEMultiplyNode()) {
    SENode* reduced = RemoveOneNodeFromMultiplyChain(lhs_mul, node);
    // Removed from the left factor: the product is reduced * rhs.
    if (reduced != lhs)
      return mul->GetParentAnalysis()->CreateMultiplyNode(reduced, rhs);
  }
  if (SEMultiplyNode* rhs_mul = rhs->AsSEMultiplyNode()) {
    SENode* reduced = RemoveOneNodeFromMultiplyChain(rhs_mul, node);
    // Removed from the right factor: keep the untouched left factor, so the
    // product is lhs * reduced (previously built reduced * rhs, dropping lhs).
    if (reduced != rhs)
      return mul->GetParentAnalysis()->CreateMultiplyNode(lhs, reduced);
  }
  return mul;
}
}  // namespace
// Divides this expression by |rhs_wrapper|. Returns the quotient together with
// the remainder. The remainder is only computed when both operands are
// constants; it is 0 otherwise. The quotient is CantCompute in these cases:
// - division by a constant zero;
// - a constant division whose result is not representable (INT64_MIN / -1);
// - no supported pattern matches.
std::pair<SExpression, int64_t> SExpression::operator/(
    SExpression rhs_wrapper) const {
  SENode* lhs = node_;
  SENode* rhs = rhs_wrapper.node_;
  // Check for division by 0.
  if (rhs->AsSEConstantNode() &&
      !rhs->AsSEConstantNode()->FoldToSingleValue()) {
    return {scev_->CreateCantComputeNode(), 0};
  }

  // Trivial case.
  if (lhs->AsSEConstantNode() && rhs->AsSEConstantNode()) {
    int64_t lhs_value = lhs->AsSEConstantNode()->FoldToSingleValue();
    int64_t rhs_value = rhs->AsSEConstantNode()->FoldToSingleValue();
    // INT64_MIN / -1 (and INT64_MIN % -1) overflows int64_t, which is
    // undefined behavior; report it as not computable instead.
    if (rhs_value == -1 && lhs_value == std::numeric_limits<int64_t>::min()) {
      return {scev_->CreateCantComputeNode(), 0};
    }
    return {scev_->CreateConstant(lhs_value / rhs_value),
            lhs_value % rhs_value};
  }

  // Look for a "c U / U" pattern.
  if (lhs->AsSEMultiplyNode()) {
    assert(lhs->GetChildren().size() == 2 &&
           "More than 2 operand for a multiply node.");
    SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), rhs);
    if (res != lhs) {
      return {res, 0};
    }
  }

  return {scev_->CreateCantComputeNode(), 0};
}
} // namespace opt
} // namespace spvtools