SPIRV-Tools/source/opt/fold.cpp
Steven Perron bc1ec9418b Add general folding infrastructure.
Create the folding engine that will

1) attempt to fold an instruction.
2) iterate on the folding so that small folding rules can easily be combined.
3) insert new instructions when needed.

I've added the minimum number of rules needed to test the features above.
2018-02-02 12:24:11 -05:00

690 lines
22 KiB
C++

// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fold.h"
#include "def_use_manager.h"
#include "folding_rules.h"
#include "ir_builder.h"
#include "ir_context.h"
#include <cassert>
#include <cstdint>
#include <vector>
namespace spvtools {
namespace opt {
namespace {
// Fallback definitions for toolchains whose <cstdint> does not provide the
// fixed-width limit macros.
#ifndef INT32_MIN
// Spelled as (-2147483647 - 1): the literal 2147483648 does not fit in an int,
// so (-2147483648) would have type long (or long long) instead of int.
#define INT32_MIN (-2147483647 - 1)
#endif
#ifndef INT32_MAX
#define INT32_MAX 2147483647
#endif
#ifndef UINT32_MAX
#define UINT32_MAX 0xffffffffu /* 4294967295U */
#endif
// Returns the single-word result from performing the given unary operation on
// the operand value which is passed in as a 32-bit word. Boolean operands use
// the convention that 0 is false and anything else is true; boolean results
// are 0 or 1.
uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) {
  switch (opcode) {
    // Arithmetic
    case SpvOp::SpvOpSNegate:
      // Negate in unsigned arithmetic. This yields the two's complement
      // negation and, unlike -static_cast<int32_t>(operand), is well defined
      // for INT32_MIN (which negates to itself) instead of overflowing.
      return 0u - operand;
    case SpvOp::SpvOpNot:
      return ~operand;
    case SpvOp::SpvOpLogicalNot:
      return operand == 0u ? 1u : 0u;
    default:
      assert(false &&
             "Unsupported unary operation for OpSpecConstantOp instruction");
      return 0u;
  }
}
// Returns the single-word result from performing the given binary operation on
// the operand values which are passed in as two 32-bit words. Boolean operands
// treat 0 as false and anything else as true; boolean results are 0 or 1.
//
// Signed operations work on the two's complement bit patterns, and are written
// so that no C++ signed overflow or over-wide shift (both undefined behaviour)
// is ever executed. Where SPIR-V leaves the result undefined, a fixed value is
// produced instead:
//  * division, remainder or modulo by zero yields 0, the same convention that
//    FoldBinaryIntegerOpToConstant uses;
//  * logical shifts by 32 or more yield 0, and an arithmetic shift by 32 or
//    more fills every bit with the sign bit.
// INT32_MIN / -1 wraps to INT32_MIN, and INT32_MIN rem/mod -1 is 0.
uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) {
  switch (opcode) {
    // Arithmetic
    case SpvOp::SpvOpIAdd:
      return a + b;
    case SpvOp::SpvOpISub:
      return a - b;
    case SpvOp::SpvOpIMul:
      return a * b;
    case SpvOp::SpvOpUDiv:
      if (b == 0u) return 0u;
      return a / b;
    case SpvOp::SpvOpSDiv: {
      if (b == 0u) return 0u;
      const int32_t divisor = static_cast<int32_t>(b);
      // x / -1 is -x. Negating in unsigned arithmetic avoids the overflow of
      // INT32_MIN / -1.
      if (divisor == -1) return 0u - a;
      return static_cast<uint32_t>(static_cast<int32_t>(a) / divisor);
    }
    case SpvOp::SpvOpSRem: {
      if (b == 0u) return 0u;
      const int32_t divisor = static_cast<int32_t>(b);
      // x rem -1 is always 0. Returning early avoids INT32_MIN % -1.
      if (divisor == -1) return 0u;
      // The sign of a non-zero result comes from the first operand: a. This is
      // guaranteed by C++11 rules for integer division, which round toward
      // zero, so the result of '%' has the sign of the first operand.
      return static_cast<uint32_t>(static_cast<int32_t>(a) % divisor);
    }
    case SpvOp::SpvOpSMod: {
      // The sign of a non-zero result comes from the second operand: b.
      if (b == 0u) return 0u;
      const int32_t divisor = static_cast<int32_t>(b);
      int32_t rem =
          static_cast<int32_t>(BinaryOperate(SpvOp::SpvOpSRem, a, b));
      // A non-zero remainder whose sign differs from the divisor's is moved
      // into the divisor's sign by adding the divisor once. The two values
      // have opposite signs, so the addition cannot overflow.
      if (rem != 0 && ((rem < 0) != (divisor < 0))) {
        rem += divisor;
      }
      return static_cast<uint32_t>(rem);
    }
    case SpvOp::SpvOpUMod:
      if (b == 0u) return 0u;
      return a % b;
    // Shifting
    case SpvOp::SpvOpShiftRightLogical:
      // A shift by 32 or more is undefined in C++; fold it to 0.
      return b >= 32u ? 0u : a >> b;
    case SpvOp::SpvOpShiftRightArithmetic:
      // Clamping the shift amount to 31 replicates the sign bit into every
      // position instead of executing an undefined C++ shift.
      return static_cast<uint32_t>(static_cast<int32_t>(a) >>
                                   (b >= 32u ? 31u : b));
    case SpvOp::SpvOpShiftLeftLogical:
      // A shift by 32 or more is undefined in C++; fold it to 0.
      return b >= 32u ? 0u : a << b;
    // Bitwise operations
    case SpvOp::SpvOpBitwiseOr:
      return a | b;
    case SpvOp::SpvOpBitwiseAnd:
      return a & b;
    case SpvOp::SpvOpBitwiseXor:
      return a ^ b;
    // Logical
    case SpvOp::SpvOpLogicalEqual:
      return (static_cast<bool>(a)) == (static_cast<bool>(b));
    case SpvOp::SpvOpLogicalNotEqual:
      return (static_cast<bool>(a)) != (static_cast<bool>(b));
    case SpvOp::SpvOpLogicalOr:
      return (static_cast<bool>(a)) || (static_cast<bool>(b));
    case SpvOp::SpvOpLogicalAnd:
      return (static_cast<bool>(a)) && (static_cast<bool>(b));
    // Comparison
    case SpvOp::SpvOpIEqual:
      return a == b;
    case SpvOp::SpvOpINotEqual:
      return a != b;
    case SpvOp::SpvOpULessThan:
      return a < b;
    case SpvOp::SpvOpSLessThan:
      return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
    case SpvOp::SpvOpUGreaterThan:
      return a > b;
    case SpvOp::SpvOpSGreaterThan:
      return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
    case SpvOp::SpvOpULessThanEqual:
      return a <= b;
    case SpvOp::SpvOpSLessThanEqual:
      return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
    case SpvOp::SpvOpUGreaterThanEqual:
      return a >= b;
    case SpvOp::SpvOpSGreaterThanEqual:
      return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
    default:
      assert(false &&
             "Unsupported binary operation for OpSpecConstantOp instruction");
      return 0u;
  }
}
// Evaluates a three-operand operation on 32-bit words. OpSelect is the only
// supported opcode: a non-zero condition |a| picks |b|, and zero picks |c|.
uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) {
  if (opcode == SpvOp::SpvOpSelect) {
    return a != 0u ? b : c;
  }
  assert(false &&
         "Unsupported ternary operation for OpSpecConstantOp instruction");
  return 0u;
}
// Evaluates |opcode| on the given operand words and returns the single-word
// result. The number of words selects the unary, binary or ternary evaluator.
// Only 32-bit operations are supported; booleans are encoded with 0u as false
// and any other value as true.
// TODO(qining): Support operands other than 32-bit wide.
uint32_t OperateWords(SpvOp opcode,
                      const std::vector<uint32_t>& operand_words) {
  const auto arity = operand_words.size();
  if (arity == 1) {
    return UnaryOperate(opcode, operand_words[0]);
  }
  if (arity == 2) {
    return BinaryOperate(opcode, operand_words[0], operand_words[1]);
  }
  if (arity == 3) {
    return TernaryOperate(opcode, operand_words[0], operand_words[1],
                          operand_words[2]);
  }
  assert(false && "Invalid number of operands");
  return 0;
}
// Applies one round of folding to |inst| in place, after mapping its id
// operands through |id_map|. Returns true if |inst| was changed.
//
// If the whole instruction folds to a constant, |inst| is rewritten as an
// OpCopyObject of that constant's result id. Otherwise every id in-operand of
// |inst| is replaced by its mapped id, and the folding rules registered for
// the opcode are tried in order until one of them reports a change.
bool FoldInstructionInternal(ir::Instruction* inst,
                             std::function<uint32_t(uint32_t)> id_map) {
  ir::Instruction* folded_inst = FoldInstructionToConstant(inst, id_map);
  if (folded_inst != nullptr) {
    inst->SetOpcode(SpvOpCopyObject);
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}});
    return true;
  }

  SpvOp opcode = inst->opcode();
  analysis::ConstantManager* const_mgr = inst->context()->get_constant_mgr();

  // One entry per in-operand: the declared constant for an id operand, or
  // nullptr for a non-id operand or an id that is not a declared constant.
  std::vector<const analysis::Constant*> constants;
  constants.reserve(inst->NumInOperands());
  for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
    const ir::Operand* operand = &inst->GetInOperand(i);
    if (operand->type != SPV_OPERAND_TYPE_ID) {
      constants.push_back(nullptr);
    } else {
      uint32_t id = id_map(operand->words[0]);
      // Commit the mapping so the rules below see the substituted ids.
      inst->SetInOperand(i, {id});
      constants.push_back(const_mgr->FindDeclaredConstant(id));
    }
  }

  // Built on first use and intentionally never destroyed.
  static FoldingRules* rules = new FoldingRules();
  // Bind each rule by reference; copying the callable on every iteration is
  // unnecessary work.
  for (const auto& rule : rules->GetRulesForOpcode(opcode)) {
    if (rule(inst, constants)) {
      return true;
    }
  }
  return false;
}
} // namespace
// Evaluates |opcode| on scalar constant operands and returns the 32-bit word
// result. Each operand must be either a single-word scalar constant, which
// contributes its word, or a null constant, which contributes 0. Scalars wider
// than 32 bits are rejected by assertion.
uint32_t FoldScalars(SpvOp opcode,
                     const std::vector<const analysis::Constant*>& operands) {
  assert(IsFoldableOpcode(opcode) &&
         "Unhandled instruction opcode in FoldScalars");
  std::vector<uint32_t> raw_words;
  raw_words.reserve(operands.size());
  for (const analysis::Constant* cst : operands) {
    const analysis::ScalarConstant* scalar = cst->AsScalarConstant();
    if (scalar == nullptr) {
      if (cst->AsNullConstant() != nullptr) {
        // A null scalar reads as all-zero bits.
        raw_words.push_back(0u);
      } else {
        assert(false &&
               "FoldScalars() only accepts ScalarConst or NullConst type of "
               "constant");
      }
      continue;
    }
    const auto& words = scalar->words();
    assert(words.size() == 1 &&
           "Scalar constants with longer than 32-bit width are not allowed "
           "in FoldScalars()");
    raw_words.push_back(words.front());
  }
  return OperateWords(opcode, raw_words);
}
// Tries to fold the two-operand integer instruction |inst| to a constant,
// using identities that need only one operand to be known. Ids are first
// mapped through |id_map|. On success the folded value is written to
// |*result| and true is returned; otherwise |*result| is left untouched.
// Valid result types are integers (signed or unsigned) of 32 bits or less,
// and booleans. For a boolean result the stored value is 0 or 1.
bool FoldBinaryIntegerOpToConstant(ir::Instruction* inst,
                                   std::function<uint32_t(uint32_t)> id_map,
                                   uint32_t* result) {
  analysis::ConstantManager* const_mgr = inst->context()->get_constant_mgr();

  // Resolve both operands through |id_map|. A non-id operand prevents any
  // folding; an id that is not an integer constant is recorded as nullptr.
  const analysis::IntConstant* constants[2] = {nullptr, nullptr};
  for (uint32_t idx = 0; idx < 2; ++idx) {
    const ir::Operand& operand = inst->GetInOperand(idx);
    if (operand.type != SPV_OPERAND_TYPE_ID) return false;
    const analysis::Constant* cst =
        const_mgr->FindDeclaredConstant(id_map(operand.words[0]));
    constants[idx] = cst == nullptr ? nullptr : cst->AsIntConstant();
  }

  // Predicates on operand |idx|; an unknown operand never matches.
  auto is_zero = [&constants](uint32_t idx) {
    return constants[idx] != nullptr && constants[idx]->IsZero();
  };
  auto is_u32 = [&constants](uint32_t idx, uint32_t value) {
    return constants[idx] != nullptr &&
           constants[idx]->GetU32BitValue() == value;
  };
  auto is_s32 = [&constants](uint32_t idx, int32_t value) {
    return constants[idx] != nullptr &&
           constants[idx]->GetS32BitValue() == value;
  };
  // Records |value| as the folded result and reports success.
  auto fold_to = [result](uint32_t value) {
    *result = value;
    return true;
  };

  switch (inst->opcode()) {
    // A zero operand makes the whole result zero: x * 0, x & 0, 0 / x,
    // 0 rem x. This changes undefined behaviour (ie divide by 0) into a 0.
    case SpvOp::SpvOpIMul:
    case SpvOp::SpvOpBitwiseAnd:
    case SpvOp::SpvOpUDiv:
    case SpvOp::SpvOpSDiv:
    case SpvOp::SpvOpSRem:
    case SpvOp::SpvOpSMod:
    case SpvOp::SpvOpUMod:
      if (is_zero(0) || is_zero(1)) return fold_to(0u);
      break;

    // Shifting by the result width or more is undefined; fold it to 0.
    case SpvOp::SpvOpShiftRightLogical:
    case SpvOp::SpvOpShiftLeftLogical:
      if (constants[1] != nullptr && constants[1]->GetU32BitValue() >= 32) {
        return fold_to(0u);
      }
      break;

    // x | 0xFFFFFFFF is all ones.
    // TODO: Change the mask against a value based on the bit width of the
    // instruction result type. This way we can handle say 16-bit values
    // as well.
    case SpvOp::SpvOpBitwiseOr:
      if (is_u32(0, 0xFFFFFFFF) || is_u32(1, 0xFFFFFFFF)) {
        return fold_to(0xFFFFFFFF);
      }
      break;

    // Comparisons against an extreme value of their domain.
    case SpvOp::SpvOpULessThan:  // MAX < x, x < 0
      if (is_u32(0, UINT32_MAX) || is_u32(1, 0)) return fold_to(false);
      break;
    case SpvOp::SpvOpSLessThan:  // MAX < x, x < MIN
      if (is_s32(0, INT32_MAX) || is_s32(1, INT32_MIN)) return fold_to(false);
      break;
    case SpvOp::SpvOpUGreaterThan:  // 0 > x, x > MAX
      if (is_zero(0) || is_u32(1, UINT32_MAX)) return fold_to(false);
      break;
    case SpvOp::SpvOpSGreaterThan:  // MIN > x, x > MAX
      if (is_s32(0, INT32_MIN) || is_s32(1, INT32_MAX)) return fold_to(false);
      break;
    case SpvOp::SpvOpULessThanEqual:  // 0 <= x, x <= MAX
      if (is_zero(0) || is_u32(1, UINT32_MAX)) return fold_to(true);
      break;
    case SpvOp::SpvOpSLessThanEqual:  // MIN <= x, x <= MAX
      if (is_s32(0, INT32_MIN) || is_s32(1, INT32_MAX)) return fold_to(true);
      break;
    case SpvOp::SpvOpUGreaterThanEqual:  // MAX >= x, x >= 0
      if (is_u32(0, UINT32_MAX) || is_u32(1, 0)) return fold_to(true);
      break;
    case SpvOp::SpvOpSGreaterThanEqual:  // MAX >= x, x >= MIN
      if (is_s32(0, INT32_MAX) || is_s32(1, INT32_MIN)) return fold_to(true);
      break;
    default:
      break;
  }
  return false;
}
// Tries to fold the two-operand boolean instruction |inst| to a constant once
// its ids have been mapped through |id_map|. Handles OpLogicalOr with a known
// true operand and OpLogicalAnd with a known false operand. On success the
// folded value (0 or 1) is written to |*result| and true is returned.
bool FoldBinaryBooleanOpToConstant(ir::Instruction* inst,
                                   std::function<uint32_t(uint32_t)> id_map,
                                   uint32_t* result) {
  analysis::ConstantManager* const_mgr = inst->context()->get_constant_mgr();

  // Resolve both operands. A non-id operand prevents folding; an id that is
  // not a boolean constant is recorded as nullptr.
  const analysis::BoolConstant* constants[2] = {nullptr, nullptr};
  for (uint32_t idx = 0; idx < 2; ++idx) {
    const ir::Operand& operand = inst->GetInOperand(idx);
    if (operand.type != SPV_OPERAND_TYPE_ID) return false;
    const analysis::Constant* cst =
        const_mgr->FindDeclaredConstant(id_map(operand.words[0]));
    constants[idx] = cst == nullptr ? nullptr : cst->AsBoolConstant();
  }

  // The operand value that decides the result on its own: true for OR,
  // false for AND.
  bool absorbing = false;
  switch (inst->opcode()) {
    case SpvOp::SpvOpLogicalOr:
      absorbing = true;
      break;
    case SpvOp::SpvOpLogicalAnd:
      absorbing = false;
      break;
    default:
      return false;
  }

  for (const analysis::BoolConstant* cst : constants) {
    if (cst != nullptr && cst->value() == absorbing) {
      *result = absorbing;
      return true;
    }
  }
  return false;
}
// Tries to fold |inst| to a constant when only some of its operands (after
// substitution with |id_map|) are constants; fully-constant instructions are
// handled by |FoldScalars|. On success the value is stored in |*result| and
// true is returned; otherwise |*result| is unchanged. Only two-operand
// instructions are considered so far.
bool FoldIntegerOpToConstant(ir::Instruction* inst,
                             std::function<uint32_t(uint32_t)> id_map,
                             uint32_t* result) {
  assert(IsFoldableOpcode(inst->opcode()) &&
         "Unhandled instruction opcode in FoldScalars");
  if (inst->NumInOperands() != 2) {
    return false;
  }
  // Integer identities first, then the boolean ones.
  if (FoldBinaryIntegerOpToConstant(inst, id_map, result)) {
    return true;
  }
  return FoldBinaryBooleanOpToConstant(inst, id_map, result);
}
// Evaluates |opcode| component-wise over |num_dims| vector components and
// returns one 32-bit word per component.
//
// Each operand must be a VectorConstant or a NullConstant. A vector operand
// contributes, for component d, either the single word of its scalar
// component or 0 if that component is a null constant. A null operand
// contributes 0 for every component.
std::vector<uint32_t> FoldVectors(
    SpvOp opcode, uint32_t num_dims,
    const std::vector<const analysis::Constant*>& operands) {
  assert(IsFoldableOpcode(opcode) &&
         "Unhandled instruction opcode in FoldVectors");
  std::vector<uint32_t> result;
  result.reserve(num_dims);
  for (uint32_t d = 0; d < num_dims; d++) {
    std::vector<uint32_t> operand_values_for_one_dimension;
    operand_values_for_one_dimension.reserve(operands.size());
    for (const auto& operand : operands) {
      if (const analysis::VectorConstant* vector_operand =
              operand->AsVectorConstant()) {
        // Extract the raw value of the scalar component constants as 32-bit
        // words here. FoldScalars() is not used because we do not create
        // temporary null constants as components when a component is null:
        // Constant creation may need extra validity checks that are not
        // managed here.
        const analysis::Constant* component =
            vector_operand->GetComponents().at(d);
        if (const analysis::ScalarConstant* scalar_component =
                component->AsScalarConstant()) {
          const auto& scalar_words = scalar_component->words();
          assert(
              scalar_words.size() == 1 &&
              "Vector components with longer than 32-bit width are not allowed "
              "in FoldVectors()");
          operand_values_for_one_dimension.push_back(scalar_words.front());
        } else if (component->AsNullConstant()) {
          // Check the component, not |operand|: |operand| is a vector
          // constant here, so it can never be a null constant itself.
          operand_values_for_one_dimension.push_back(0u);
        } else {
          assert(false &&
                 "VectorConst should only has ScalarConst or NullConst as "
                 "components");
        }
      } else if (operand->AsNullConstant()) {
        operand_values_for_one_dimension.push_back(0u);
      } else {
        assert(false &&
               "FoldVectors() only accepts VectorConst or NullConst type of "
               "constant");
      }
    }
    result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
  }
  return result;
}
// Returns true if |opcode| is one of the operations evaluated by
// UnaryOperate, BinaryOperate or TernaryOperate in this file.
// NOTE: Extend to more opcodes as new cases are handled in the folder
// functions.
bool IsFoldableOpcode(SpvOp opcode) {
  switch (opcode) {
    // Integer arithmetic.
    case SpvOp::SpvOpSNegate:
    case SpvOp::SpvOpIAdd:
    case SpvOp::SpvOpISub:
    case SpvOp::SpvOpIMul:
    case SpvOp::SpvOpUDiv:
    case SpvOp::SpvOpSDiv:
    case SpvOp::SpvOpUMod:
    case SpvOp::SpvOpSRem:
    case SpvOp::SpvOpSMod:
    // Shifts.
    case SpvOp::SpvOpShiftRightLogical:
    case SpvOp::SpvOpShiftRightArithmetic:
    case SpvOp::SpvOpShiftLeftLogical:
    // Bitwise operations.
    case SpvOp::SpvOpNot:
    case SpvOp::SpvOpBitwiseOr:
    case SpvOp::SpvOpBitwiseAnd:
    case SpvOp::SpvOpBitwiseXor:
    // Logical operations.
    case SpvOp::SpvOpLogicalNot:
    case SpvOp::SpvOpLogicalEqual:
    case SpvOp::SpvOpLogicalNotEqual:
    case SpvOp::SpvOpLogicalOr:
    case SpvOp::SpvOpLogicalAnd:
    // Comparisons.
    case SpvOp::SpvOpIEqual:
    case SpvOp::SpvOpINotEqual:
    case SpvOp::SpvOpULessThan:
    case SpvOp::SpvOpSLessThan:
    case SpvOp::SpvOpUGreaterThan:
    case SpvOp::SpvOpSGreaterThan:
    case SpvOp::SpvOpULessThanEqual:
    case SpvOp::SpvOpSLessThanEqual:
    case SpvOp::SpvOpUGreaterThanEqual:
    case SpvOp::SpvOpSGreaterThanEqual:
    // Selection.
    case SpvOp::SpvOpSelect:
      return true;
    default:
      return false;
  }
}
// Returns true if |cst| can be consumed by the word-based folders: a scalar
// constant stored in a single word, or a null constant.
bool IsFoldableConstant(const analysis::Constant* cst) {
  const analysis::ScalarConstant* scalar = cst->AsScalarConstant();
  return scalar != nullptr ? scalar->words().size() == 1
                           : cst->AsNullConstant() != nullptr;
}
// Tries to fold |inst|, with its input ids mapped through |id_map|, to a
// constant. Returns the instruction defining that constant, or nullptr if
// |inst| is not foldable or cannot be reduced to a constant.
//
// When every input is a foldable constant, the operation is evaluated
// directly with FoldScalars. Otherwise FoldIntegerOpToConstant tries the
// identities that need only some of the inputs to be known.
ir::Instruction* FoldInstructionToConstant(
    ir::Instruction* inst, std::function<uint32_t(uint32_t)> id_map) {
  if (!inst->IsFoldable()) return nullptr;

  analysis::ConstantManager* const_mgr = inst->context()->get_constant_mgr();

  // Map every input id and look up its constant. Inputs without a foldable
  // constant are recorded as nullptr and clear |all_constant|.
  std::vector<const analysis::Constant*> operand_constants;
  bool all_constant = true;
  inst->ForEachInId([&](uint32_t* op_id) {
    const analysis::Constant* cst =
        const_mgr->FindDeclaredConstant(id_map(*op_id));
    const bool usable = cst != nullptr && IsFoldableConstant(cst);
    operand_constants.push_back(usable ? cst : nullptr);
    if (!usable) all_constant = false;
  });

  uint32_t folded_value = 0;
  if (all_constant) {
    folded_value = FoldScalars(inst->opcode(), operand_constants);
  } else if (!FoldIntegerOpToConstant(inst, id_map, &folded_value)) {
    return nullptr;
  }

  const analysis::Constant* folded_const =
      const_mgr->GetConstant(const_mgr->GetType(inst), {folded_value});
  return const_mgr->GetDefiningInstruction(folded_const);
}
// Returns true if values of the type declared by |type_inst| can be folded:
// 32-bit integers (the width is in-operand 0; signedness is not checked) and
// booleans. Nothing else is supported yet.
bool IsFoldableType(ir::Instruction* type_inst) {
  switch (type_inst->opcode()) {
    case SpvOpTypeInt:
      return type_inst->GetSingleWordInOperand(0) == 32;
    case SpvOpTypeBool:
      return true;
    default:
      return false;
  }
}
// Repeatedly folds a clone of |inst|, with ids mapped through |id_map|, until
// no further folding applies. The folding itself is done on the clone, not on
// |inst|.
//
// Returns nullptr if nothing could be folded. If the folded clone is an
// OpCopyObject, returns the instruction that defines the copied id. Otherwise
// the clone gets a fresh result id and is added through an InstructionBuilder
// anchored at |inst| (built with the def-use and instruction-to-block analysis
// flags), and the added instruction is returned.
ir::Instruction* FoldInstruction(ir::Instruction* inst,
                                 std::function<uint32_t(uint32_t)> id_map) {
  ir::IRContext* context = inst->context();
  std::unique_ptr<ir::Instruction> candidate(inst->Clone(context));

  bool changed = false;
  while (FoldInstructionInternal(candidate.get(), id_map)) {
    changed = true;
  }
  if (!changed) {
    return nullptr;
  }

  if (candidate->opcode() == SpvOpCopyObject) {
    analysis::DefUseManager* def_uses = context->get_def_use_mgr();
    return def_uses->GetDef(candidate->GetSingleWordInOperand(0));
  }

  InstructionBuilder builder(
      context, inst,
      ir::IRContext::kAnalysisDefUse |
          ir::IRContext::kAnalysisInstrToBlockMapping);
  candidate->SetResultId(context->TakeNextId());
  return builder.AddInstruction(std::move(candidate));
}
} // namespace opt
} // namespace spvtools