Re-factor generic constant folding code out of FoldSpecConstantOpAndCompositePass

There are no functional changes in this patch.  The generic folding
routines in FoldSpecConstantOpAndCompositePass are now inside opt/fold.{cpp,h}.
This code will be used by the upcoming constant propagation pass.  In
time, we'll add more expression folding and simplification into these
two files.
This commit is contained in:
Diego Novillo 2017-10-17 19:41:37 -04:00
parent 1a9061a2be
commit 332a1f1422
4 changed files with 303 additions and 236 deletions

View File

@ -23,13 +23,15 @@ add_library(SPIRV-Tools-opt
decoration_manager.h decoration_manager.h
def_use_manager.h def_use_manager.h
eliminate_dead_constant_pass.h eliminate_dead_constant_pass.h
eliminate_dead_functions_pass.h
flatten_decoration_pass.h flatten_decoration_pass.h
function.h fold.h
fold_spec_constant_op_and_composite_pass.h fold_spec_constant_op_and_composite_pass.h
freeze_spec_constant_value_pass.h freeze_spec_constant_value_pass.h
inline_pass.h function.h
inline_exhaustive_pass.h inline_exhaustive_pass.h
inline_opaque_pass.h inline_opaque_pass.h
inline_pass.h
insert_extract_elim.h insert_extract_elim.h
instruction.h instruction.h
ir_loader.h ir_loader.h
@ -38,20 +40,19 @@ add_library(SPIRV-Tools-opt
local_single_store_elim_pass.h local_single_store_elim_pass.h
local_ssa_elim_pass.h local_ssa_elim_pass.h
log.h log.h
mem_pass.h
module.h module.h
null_pass.h null_pass.h
reflect.h
mem_pass.h
pass.h
passes.h passes.h
pass.h
pass_manager.h pass_manager.h
eliminate_dead_functions_pass.h reflect.h
remove_duplicates_pass.h remove_duplicates_pass.h
set_spec_constant_default_value_pass.h set_spec_constant_default_value_pass.h
strength_reduction_pass.h strength_reduction_pass.h
strip_debug_info_pass.h strip_debug_info_pass.h
types.h
type_manager.h type_manager.h
types.h
unify_const_pass.h unify_const_pass.h
aggressive_dead_code_elim_pass.cpp aggressive_dead_code_elim_pass.cpp
@ -60,17 +61,19 @@ add_library(SPIRV-Tools-opt
build_module.cpp build_module.cpp
common_uniform_elim_pass.cpp common_uniform_elim_pass.cpp
compact_ids_pass.cpp compact_ids_pass.cpp
dead_branch_elim_pass.cpp
decoration_manager.cpp decoration_manager.cpp
def_use_manager.cpp def_use_manager.cpp
dead_branch_elim_pass.cpp
eliminate_dead_constant_pass.cpp eliminate_dead_constant_pass.cpp
eliminate_dead_functions_pass.cpp
flatten_decoration_pass.cpp flatten_decoration_pass.cpp
function.cpp fold.cpp
fold_spec_constant_op_and_composite_pass.cpp fold_spec_constant_op_and_composite_pass.cpp
freeze_spec_constant_value_pass.cpp freeze_spec_constant_value_pass.cpp
inline_pass.cpp function.cpp
inline_exhaustive_pass.cpp inline_exhaustive_pass.cpp
inline_opaque_pass.cpp inline_opaque_pass.cpp
inline_pass.cpp
insert_extract_elim.cpp insert_extract_elim.cpp
instruction.cpp instruction.cpp
ir_loader.cpp ir_loader.cpp
@ -78,18 +81,17 @@ add_library(SPIRV-Tools-opt
local_single_block_elim_pass.cpp local_single_block_elim_pass.cpp
local_single_store_elim_pass.cpp local_single_store_elim_pass.cpp
local_ssa_elim_pass.cpp local_ssa_elim_pass.cpp
module.cpp
eliminate_dead_functions_pass.cpp
remove_duplicates_pass.cpp
set_spec_constant_default_value_pass.cpp
optimizer.cpp
mem_pass.cpp mem_pass.cpp
module.cpp
optimizer.cpp
pass.cpp pass.cpp
pass_manager.cpp pass_manager.cpp
remove_duplicates_pass.cpp
set_spec_constant_default_value_pass.cpp
strength_reduction_pass.cpp strength_reduction_pass.cpp
strip_debug_info_pass.cpp strip_debug_info_pass.cpp
types.cpp
type_manager.cpp type_manager.cpp
types.cpp
unify_const_pass.cpp unify_const_pass.cpp
) )

244
source/opt/fold.cpp Normal file
View File

@ -0,0 +1,244 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fold.h"
#include "def_use_manager.h"
#include <cassert>
#include <vector>
namespace spvtools {
namespace opt {
namespace {
// Returns the single-word result from performing the given unary operation on
// the operand value which is passed in as a 32-bit word.
uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) {
switch (opcode) {
// Arthimetics
case SpvOp::SpvOpSNegate:
return -static_cast<int32_t>(operand);
case SpvOp::SpvOpNot:
return ~operand;
case SpvOp::SpvOpLogicalNot:
return !static_cast<bool>(operand);
default:
assert(false &&
"Unsupported unary operation for OpSpecConstantOp instruction");
return 0u;
}
}
// Returns the single-word result from performing the given binary operation on
// the operand values which are passed in as two 32-bit word.
uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) {
switch (opcode) {
// Arthimetics
case SpvOp::SpvOpIAdd:
return a + b;
case SpvOp::SpvOpISub:
return a - b;
case SpvOp::SpvOpIMul:
return a * b;
case SpvOp::SpvOpUDiv:
assert(b != 0);
return a / b;
case SpvOp::SpvOpSDiv:
assert(b != 0u);
return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
case SpvOp::SpvOpSRem: {
// The sign of non-zero result comes from the first operand: a. This is
// guaranteed by C++11 rules for integer division operator. The division
// result is rounded toward zero, so the result of '%' has the sign of
// the first operand.
assert(b != 0u);
return static_cast<int32_t>(a) % static_cast<int32_t>(b);
}
case SpvOp::SpvOpSMod: {
// The sign of non-zero result comes from the second operand: b
assert(b != 0u);
int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b);
int32_t b_prim = static_cast<int32_t>(b);
return (rem + b_prim) % b_prim;
}
case SpvOp::SpvOpUMod:
assert(b != 0u);
return (a % b);
// Shifting
case SpvOp::SpvOpShiftRightLogical: {
return a >> b;
}
case SpvOp::SpvOpShiftRightArithmetic:
return (static_cast<int32_t>(a)) >> b;
case SpvOp::SpvOpShiftLeftLogical:
return a << b;
// Bitwise operations
case SpvOp::SpvOpBitwiseOr:
return a | b;
case SpvOp::SpvOpBitwiseAnd:
return a & b;
case SpvOp::SpvOpBitwiseXor:
return a ^ b;
// Logical
case SpvOp::SpvOpLogicalEqual:
return (static_cast<bool>(a)) == (static_cast<bool>(b));
case SpvOp::SpvOpLogicalNotEqual:
return (static_cast<bool>(a)) != (static_cast<bool>(b));
case SpvOp::SpvOpLogicalOr:
return (static_cast<bool>(a)) || (static_cast<bool>(b));
case SpvOp::SpvOpLogicalAnd:
return (static_cast<bool>(a)) && (static_cast<bool>(b));
// Comparison
case SpvOp::SpvOpIEqual:
return a == b;
case SpvOp::SpvOpINotEqual:
return a != b;
case SpvOp::SpvOpULessThan:
return a < b;
case SpvOp::SpvOpSLessThan:
return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
case SpvOp::SpvOpUGreaterThan:
return a > b;
case SpvOp::SpvOpSGreaterThan:
return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
case SpvOp::SpvOpULessThanEqual:
return a <= b;
case SpvOp::SpvOpSLessThanEqual:
return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
case SpvOp::SpvOpUGreaterThanEqual:
return a >= b;
case SpvOp::SpvOpSGreaterThanEqual:
return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
default:
assert(false &&
"Unsupported binary operation for OpSpecConstantOp instruction");
return 0u;
}
}
// Returns the single-word result from performing the given ternary operation
// on the operand values which are passed in as three 32-bit word.
uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) {
switch (opcode) {
case SpvOp::SpvOpSelect:
return (static_cast<bool>(a)) ? b : c;
default:
assert(false &&
"Unsupported ternary operation for OpSpecConstantOp instruction");
return 0u;
}
}
// Returns the single-word result from performing the given operation on the
// operand words. This only works with 32-bit operations and uses boolean
// convention that 0u is false, and anything else is boolean true.
// TODO(qining): Support operands other than 32-bit wide.
uint32_t OperateWords(SpvOp opcode,
const std::vector<uint32_t>& operand_words) {
switch (operand_words.size()) {
case 1:
return UnaryOperate(opcode, operand_words.front());
case 2:
return BinaryOperate(opcode, operand_words.front(), operand_words.back());
case 3:
return TernaryOperate(opcode, operand_words[0], operand_words[1],
operand_words[2]);
default:
assert(false && "Invalid number of operands");
return 0;
}
}
} // namespace
// Returns the result of performing an operation on scalar constant operands.
// This function extracts the operand values as 32 bit words and returns the
// result in 32 bit word. Scalar constants with longer than 32-bit width are
// not accepted in this function.
uint32_t FoldScalars(SpvOp opcode,
const std::vector<analysis::Constant*>& operands) {
std::vector<uint32_t> operand_values_in_raw_words;
for (analysis::Constant* operand : operands) {
if (analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
const auto& scalar_words = scalar->words();
assert(scalar_words.size() == 1 &&
"Scalar constants with longer than 32-bit width are not allowed "
"in FoldScalars()");
operand_values_in_raw_words.push_back(scalar_words.front());
} else if (operand->AsNullConstant()) {
operand_values_in_raw_words.push_back(0u);
} else {
assert(false &&
"FoldScalars() only accepts ScalarConst or NullConst type of "
"constant");
}
}
return OperateWords(opcode, operand_values_in_raw_words);
}
// Returns the result of performing an operation over constant vectors. This
// function iterates through the given vector type constant operands and
// calculates the result for each element of the result vector to return.
// Vectors with longer than 32-bit scalar components are not accepted in this
// function.
std::vector<uint32_t> FoldVectors(
SpvOp opcode, uint32_t num_dims,
const std::vector<analysis::Constant*>& operands) {
std::vector<uint32_t> result;
for (uint32_t d = 0; d < num_dims; d++) {
std::vector<uint32_t> operand_values_for_one_dimension;
for (analysis::Constant* operand : operands) {
if (analysis::VectorConstant* vector_operand =
operand->AsVectorConstant()) {
// Extract the raw value of the scalar component constants
// in 32-bit words here. The reason of not using FoldScalars() here
// is that we do not create temporary null constants as components
// when the vector operand is a NullConstant because Constant creation
// may need extra checks for the validity and that is not manageed in
// here.
if (const analysis::ScalarConstant* scalar_component =
vector_operand->GetComponents().at(d)->AsScalarConstant()) {
const auto& scalar_words = scalar_component->words();
assert(
scalar_words.size() == 1 &&
"Vector components with longer than 32-bit width are not allowed "
"in FoldVectors()");
operand_values_for_one_dimension.push_back(scalar_words.front());
} else if (operand->AsNullConstant()) {
operand_values_for_one_dimension.push_back(0u);
} else {
assert(false &&
"VectorConst should only has ScalarConst or NullConst as "
"components");
}
} else if (operand->AsNullConstant()) {
operand_values_for_one_dimension.push_back(0u);
} else {
assert(false &&
"FoldVectors() only accepts VectorConst or NullConst type of "
"constant");
}
}
result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
}
return result;
}
} // namespace opt
} // namespace spvtools

37
source/opt/fold.h Normal file
View File

@ -0,0 +1,37 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef LIBSPIRV_UTIL_FOLD_H_
#define LIBSPIRV_UTIL_FOLD_H_
#include "def_use_manager.h"
#include "constants.h"
#include <cstdint>
#include <vector>
namespace spvtools {
namespace opt {
uint32_t FoldScalars(SpvOp opcode,
const std::vector<analysis::Constant*>& operands);
std::vector<uint32_t> FoldVectors(
SpvOp opcode, uint32_t num_dims,
const std::vector<analysis::Constant*>& operands);
} // namespace opt
} // namespace spvtools
#endif // LIBSPIRV_UTIL_FOLD_H_

View File

@ -20,227 +20,11 @@
#include "constants.h" #include "constants.h"
#include "make_unique.h" #include "make_unique.h"
#include "fold.h"
namespace spvtools { namespace spvtools {
namespace opt { namespace opt {
namespace {
// Returns the single-word result from performing the given unary operation on
// the operand value which is passed in as a 32-bit word.
uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) {
switch (opcode) {
// Arthimetics
case SpvOp::SpvOpSNegate:
return -static_cast<int32_t>(operand);
case SpvOp::SpvOpNot:
return ~operand;
case SpvOp::SpvOpLogicalNot:
return !static_cast<bool>(operand);
default:
assert(false &&
"Unsupported unary operation for OpSpecConstantOp instruction");
return 0u;
}
}
// Returns the single-word result from performing the given binary operation on
// the operand values which are passed in as two 32-bit word.
uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) {
switch (opcode) {
// Arthimetics
case SpvOp::SpvOpIAdd:
return a + b;
case SpvOp::SpvOpISub:
return a - b;
case SpvOp::SpvOpIMul:
return a * b;
case SpvOp::SpvOpUDiv:
assert(b != 0);
return a / b;
case SpvOp::SpvOpSDiv:
assert(b != 0u);
return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
case SpvOp::SpvOpSRem: {
// The sign of non-zero result comes from the first operand: a. This is
// guaranteed by C++11 rules for integer division operator. The division
// result is rounded toward zero, so the result of '%' has the sign of
// the first operand.
assert(b != 0u);
return static_cast<int32_t>(a) % static_cast<int32_t>(b);
}
case SpvOp::SpvOpSMod: {
// The sign of non-zero result comes from the second operand: b
assert(b != 0u);
int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b);
int32_t b_prim = static_cast<int32_t>(b);
return (rem + b_prim) % b_prim;
}
case SpvOp::SpvOpUMod:
assert(b != 0u);
return (a % b);
// Shifting
case SpvOp::SpvOpShiftRightLogical: {
return a >> b;
}
case SpvOp::SpvOpShiftRightArithmetic:
return (static_cast<int32_t>(a)) >> b;
case SpvOp::SpvOpShiftLeftLogical:
return a << b;
// Bitwise operations
case SpvOp::SpvOpBitwiseOr:
return a | b;
case SpvOp::SpvOpBitwiseAnd:
return a & b;
case SpvOp::SpvOpBitwiseXor:
return a ^ b;
// Logical
case SpvOp::SpvOpLogicalEqual:
return (static_cast<bool>(a)) == (static_cast<bool>(b));
case SpvOp::SpvOpLogicalNotEqual:
return (static_cast<bool>(a)) != (static_cast<bool>(b));
case SpvOp::SpvOpLogicalOr:
return (static_cast<bool>(a)) || (static_cast<bool>(b));
case SpvOp::SpvOpLogicalAnd:
return (static_cast<bool>(a)) && (static_cast<bool>(b));
// Comparison
case SpvOp::SpvOpIEqual:
return a == b;
case SpvOp::SpvOpINotEqual:
return a != b;
case SpvOp::SpvOpULessThan:
return a < b;
case SpvOp::SpvOpSLessThan:
return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
case SpvOp::SpvOpUGreaterThan:
return a > b;
case SpvOp::SpvOpSGreaterThan:
return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
case SpvOp::SpvOpULessThanEqual:
return a <= b;
case SpvOp::SpvOpSLessThanEqual:
return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
case SpvOp::SpvOpUGreaterThanEqual:
return a >= b;
case SpvOp::SpvOpSGreaterThanEqual:
return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
default:
assert(false &&
"Unsupported binary operation for OpSpecConstantOp instruction");
return 0u;
}
}
// Returns the single-word result from performing the given ternary operation
// on the operand values which are passed in as three 32-bit word.
uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) {
switch (opcode) {
case SpvOp::SpvOpSelect:
return (static_cast<bool>(a)) ? b : c;
default:
assert(false &&
"Unsupported ternary operation for OpSpecConstantOp instruction");
return 0u;
}
}
// Returns the single-word result from performing the given operation on the
// operand words. This only works with 32-bit operations and uses boolean
// convention that 0u is false, and anything else is boolean true.
// TODO(qining): Support operands other than 32-bit wide.
uint32_t OperateWords(SpvOp opcode,
const std::vector<uint32_t>& operand_words) {
switch (operand_words.size()) {
case 1:
return UnaryOperate(opcode, operand_words.front());
case 2:
return BinaryOperate(opcode, operand_words.front(), operand_words.back());
case 3:
return TernaryOperate(opcode, operand_words[0], operand_words[1],
operand_words[2]);
default:
assert(false && "Invalid number of operands");
return 0;
}
}
// Returns the result of performing an operation on scalar constant operands.
// This function extracts the operand values as 32 bit words and returns the
// result in 32 bit word. Scalar constants with longer than 32-bit width are
// not accepted in this function.
uint32_t OperateScalars(SpvOp opcode,
const std::vector<analysis::Constant*>& operands) {
std::vector<uint32_t> operand_values_in_raw_words;
for (analysis::Constant* operand : operands) {
if (analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
const auto& scalar_words = scalar->words();
assert(scalar_words.size() == 1 &&
"Scalar constants with longer than 32-bit width are not allowed "
"in OperateScalars()");
operand_values_in_raw_words.push_back(scalar_words.front());
} else if (operand->AsNullConstant()) {
operand_values_in_raw_words.push_back(0u);
} else {
assert(false &&
"OperateScalars() only accepts ScalarConst or NullConst type of "
"constant");
}
}
return OperateWords(opcode, operand_values_in_raw_words);
}
// Returns the result of performing an operation over constant vectors. This
// function iterates through the given vector type constant operands and
// calculates the result for each element of the result vector to return.
// Vectors with longer than 32-bit scalar components are not accepted in this
// function.
std::vector<uint32_t> OperateVectors(
SpvOp opcode, uint32_t num_dims,
const std::vector<analysis::Constant*>& operands) {
std::vector<uint32_t> result;
for (uint32_t d = 0; d < num_dims; d++) {
std::vector<uint32_t> operand_values_for_one_dimension;
for (analysis::Constant* operand : operands) {
if (analysis::VectorConstant* vector_operand =
operand->AsVectorConstant()) {
// Extract the raw value of the scalar component constants
// in 32-bit words here. The reason of not using OperateScalars() here
// is that we do not create temporary null constants as components
// when the vector operand is a NullConstant because Constant creation
// may need extra checks for the validity and that is not manageed in
// here.
if (const analysis::ScalarConstant* scalar_component =
vector_operand->GetComponents().at(d)->AsScalarConstant()) {
const auto& scalar_words = scalar_component->words();
assert(
scalar_words.size() == 1 &&
"Vector components with longer than 32-bit width are not allowed "
"in OperateVectors()");
operand_values_for_one_dimension.push_back(scalar_words.front());
} else if (operand->AsNullConstant()) {
operand_values_for_one_dimension.push_back(0u);
} else {
assert(false &&
"VectorConst should only has ScalarConst or NullConst as "
"components");
}
} else if (operand->AsNullConstant()) {
operand_values_for_one_dimension.push_back(0u);
} else {
assert(false &&
"OperateVectors() only accepts VectorConst or NullConst type of "
"constant");
}
}
result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
}
return result;
}
} // anonymous namespace
FoldSpecConstantOpAndCompositePass::FoldSpecConstantOpAndCompositePass() FoldSpecConstantOpAndCompositePass::FoldSpecConstantOpAndCompositePass()
: max_id_(0), : max_id_(0),
module_(nullptr), module_(nullptr),
@ -518,7 +302,7 @@ bool IsValidTypeForComponentWiseOperation(const analysis::Type* type) {
} }
return false; return false;
} }
} } // namespace
ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
ir::Module::inst_iterator* pos) { ir::Module::inst_iterator* pos) {
@ -546,7 +330,7 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
if (result_type->AsInteger() || result_type->AsBool()) { if (result_type->AsInteger() || result_type->AsBool()) {
// Scalar operation // Scalar operation
uint32_t result_val = OperateScalars(spec_opcode, operands); uint32_t result_val = FoldScalars(spec_opcode, operands);
auto result_const = CreateConst(result_type, {result_val}); auto result_const = CreateConst(result_type, {result_val});
return BuildInstructionAndAddToModule(std::move(result_const), pos); return BuildInstructionAndAddToModule(std::move(result_const), pos);
} else if (result_type->AsVector()) { } else if (result_type->AsVector()) {
@ -555,7 +339,7 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
result_type->AsVector()->element_type(); result_type->AsVector()->element_type();
uint32_t num_dims = result_type->AsVector()->element_count(); uint32_t num_dims = result_type->AsVector()->element_count();
std::vector<uint32_t> result_vec = std::vector<uint32_t> result_vec =
OperateVectors(spec_opcode, num_dims, operands); FoldVectors(spec_opcode, num_dims, operands);
std::vector<const analysis::Constant*> result_vector_components; std::vector<const analysis::Constant*> result_vector_components;
for (uint32_t r : result_vec) { for (uint32_t r : result_vec) {
if (auto rc = CreateConst(element_type, {r})) { if (auto rc = CreateConst(element_type, {r})) {