From 332a1f142243c9a4cc4ba8ee24449f1a182c7889 Mon Sep 17 00:00:00 2001 From: Diego Novillo Date: Tue, 17 Oct 2017 19:41:37 -0400 Subject: [PATCH] Re-factor generic constant folding code out of FoldSpecConstantOpAndCompositePass There are no functional changes in this patch. The generic folding routines in FoldSpecConstantOpAndCompositePass are now inside opt/fold.{cpp,h}. This code will be used by the upcoming constant propagation pass. In time, we'll add more expression folding and simplification into these two files. --- source/opt/CMakeLists.txt | 34 +-- source/opt/fold.cpp | 244 ++++++++++++++++++ source/opt/fold.h | 37 +++ ...ld_spec_constant_op_and_composite_pass.cpp | 224 +--------------- 4 files changed, 303 insertions(+), 236 deletions(-) create mode 100644 source/opt/fold.cpp create mode 100644 source/opt/fold.h diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt index 7a2eb8540..33d1e1380 100644 --- a/source/opt/CMakeLists.txt +++ b/source/opt/CMakeLists.txt @@ -23,13 +23,15 @@ add_library(SPIRV-Tools-opt decoration_manager.h def_use_manager.h eliminate_dead_constant_pass.h + eliminate_dead_functions_pass.h flatten_decoration_pass.h - function.h + fold.h fold_spec_constant_op_and_composite_pass.h freeze_spec_constant_value_pass.h - inline_pass.h + function.h inline_exhaustive_pass.h inline_opaque_pass.h + inline_pass.h insert_extract_elim.h instruction.h ir_loader.h @@ -38,20 +40,19 @@ add_library(SPIRV-Tools-opt local_single_store_elim_pass.h local_ssa_elim_pass.h log.h + mem_pass.h module.h null_pass.h - reflect.h - mem_pass.h - pass.h passes.h + pass.h pass_manager.h - eliminate_dead_functions_pass.h + reflect.h remove_duplicates_pass.h set_spec_constant_default_value_pass.h strength_reduction_pass.h strip_debug_info_pass.h - types.h type_manager.h + types.h unify_const_pass.h aggressive_dead_code_elim_pass.cpp @@ -60,17 +61,19 @@ add_library(SPIRV-Tools-opt build_module.cpp common_uniform_elim_pass.cpp compact_ids_pass.cpp + dead_branch_elim_pass.cpp decoration_manager.cpp def_use_manager.cpp - dead_branch_elim_pass.cpp eliminate_dead_constant_pass.cpp + eliminate_dead_functions_pass.cpp flatten_decoration_pass.cpp - function.cpp + fold.cpp fold_spec_constant_op_and_composite_pass.cpp freeze_spec_constant_value_pass.cpp - inline_pass.cpp + function.cpp inline_exhaustive_pass.cpp inline_opaque_pass.cpp + inline_pass.cpp insert_extract_elim.cpp instruction.cpp ir_loader.cpp @@ -78,18 +81,17 @@ add_library(SPIRV-Tools-opt local_single_block_elim_pass.cpp local_single_store_elim_pass.cpp local_ssa_elim_pass.cpp - module.cpp - eliminate_dead_functions_pass.cpp - remove_duplicates_pass.cpp - set_spec_constant_default_value_pass.cpp - optimizer.cpp mem_pass.cpp + module.cpp + optimizer.cpp pass.cpp pass_manager.cpp + remove_duplicates_pass.cpp + set_spec_constant_default_value_pass.cpp strength_reduction_pass.cpp strip_debug_info_pass.cpp - types.cpp type_manager.cpp + types.cpp unify_const_pass.cpp ) diff --git a/source/opt/fold.cpp b/source/opt/fold.cpp new file mode 100644 index 000000000..005cb76ed --- /dev/null +++ b/source/opt/fold.cpp @@ -0,0 +1,244 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fold.h" +#include "def_use_manager.h" + +#include +#include + +namespace spvtools { +namespace opt { + +namespace { + +// Returns the single-word result from performing the given unary operation on +// the operand value which is passed in as a 32-bit word. +uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) { + switch (opcode) { + // Arthimetics + case SpvOp::SpvOpSNegate: + return -static_cast(operand); + case SpvOp::SpvOpNot: + return ~operand; + case SpvOp::SpvOpLogicalNot: + return !static_cast(operand); + default: + assert(false && + "Unsupported unary operation for OpSpecConstantOp instruction"); + return 0u; + } +} + +// Returns the single-word result from performing the given binary operation on +// the operand values which are passed in as two 32-bit word. +uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) { + switch (opcode) { + // Arthimetics + case SpvOp::SpvOpIAdd: + return a + b; + case SpvOp::SpvOpISub: + return a - b; + case SpvOp::SpvOpIMul: + return a * b; + case SpvOp::SpvOpUDiv: + assert(b != 0); + return a / b; + case SpvOp::SpvOpSDiv: + assert(b != 0u); + return (static_cast(a)) / (static_cast(b)); + case SpvOp::SpvOpSRem: { + // The sign of non-zero result comes from the first operand: a. This is + // guaranteed by C++11 rules for integer division operator. The division + // result is rounded toward zero, so the result of '%' has the sign of + // the first operand. + assert(b != 0u); + return static_cast(a) % static_cast(b); + } + case SpvOp::SpvOpSMod: { + // The sign of non-zero result comes from the second operand: b + assert(b != 0u); + int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b); + int32_t b_prim = static_cast(b); + return (rem + b_prim) % b_prim; + } + case SpvOp::SpvOpUMod: + assert(b != 0u); + return (a % b); + + // Shifting + case SpvOp::SpvOpShiftRightLogical: { + return a >> b; + } + case SpvOp::SpvOpShiftRightArithmetic: + return (static_cast(a)) >> b; + case SpvOp::SpvOpShiftLeftLogical: + return a << b; + + // Bitwise operations + case SpvOp::SpvOpBitwiseOr: + return a | b; + case SpvOp::SpvOpBitwiseAnd: + return a & b; + case SpvOp::SpvOpBitwiseXor: + return a ^ b; + + // Logical + case SpvOp::SpvOpLogicalEqual: + return (static_cast(a)) == (static_cast(b)); + case SpvOp::SpvOpLogicalNotEqual: + return (static_cast(a)) != (static_cast(b)); + case SpvOp::SpvOpLogicalOr: + return (static_cast(a)) || (static_cast(b)); + case SpvOp::SpvOpLogicalAnd: + return (static_cast(a)) && (static_cast(b)); + + // Comparison + case SpvOp::SpvOpIEqual: + return a == b; + case SpvOp::SpvOpINotEqual: + return a != b; + case SpvOp::SpvOpULessThan: + return a < b; + case SpvOp::SpvOpSLessThan: + return (static_cast(a)) < (static_cast(b)); + case SpvOp::SpvOpUGreaterThan: + return a > b; + case SpvOp::SpvOpSGreaterThan: + return (static_cast(a)) > (static_cast(b)); + case SpvOp::SpvOpULessThanEqual: + return a <= b; + case SpvOp::SpvOpSLessThanEqual: + return (static_cast(a)) <= (static_cast(b)); + case SpvOp::SpvOpUGreaterThanEqual: + return a >= b; + case SpvOp::SpvOpSGreaterThanEqual: + return (static_cast(a)) >= (static_cast(b)); + default: + assert(false && + "Unsupported binary operation for OpSpecConstantOp instruction"); + return 0u; + } +} + +// Returns the single-word result from performing the given ternary operation +// on the operand values which are passed in as three 32-bit word. +uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) { + switch (opcode) { + case SpvOp::SpvOpSelect: + return (static_cast(a)) ? b : c; + default: + assert(false && + "Unsupported ternary operation for OpSpecConstantOp instruction"); + return 0u; + } +} + +// Returns the single-word result from performing the given operation on the +// operand words. This only works with 32-bit operations and uses boolean +// convention that 0u is false, and anything else is boolean true. +// TODO(qining): Support operands other than 32-bit wide. +uint32_t OperateWords(SpvOp opcode, + const std::vector& operand_words) { + switch (operand_words.size()) { + case 1: + return UnaryOperate(opcode, operand_words.front()); + case 2: + return BinaryOperate(opcode, operand_words.front(), operand_words.back()); + case 3: + return TernaryOperate(opcode, operand_words[0], operand_words[1], + operand_words[2]); + default: + assert(false && "Invalid number of operands"); + return 0; + } +} + +} // namespace + +// Returns the result of performing an operation on scalar constant operands. +// This function extracts the operand values as 32 bit words and returns the +// result in 32 bit word. Scalar constants with longer than 32-bit width are +// not accepted in this function. +uint32_t FoldScalars(SpvOp opcode, + const std::vector& operands) { + std::vector operand_values_in_raw_words; + for (analysis::Constant* operand : operands) { + if (analysis::ScalarConstant* scalar = operand->AsScalarConstant()) { + const auto& scalar_words = scalar->words(); + assert(scalar_words.size() == 1 && + "Scalar constants with longer than 32-bit width are not allowed " + "in FoldScalars()"); + operand_values_in_raw_words.push_back(scalar_words.front()); + } else if (operand->AsNullConstant()) { + operand_values_in_raw_words.push_back(0u); + } else { + assert(false && + "FoldScalars() only accepts ScalarConst or NullConst type of " + "constant"); + } + } + return OperateWords(opcode, operand_values_in_raw_words); +} + +// Returns the result of performing an operation over constant vectors. This +// function iterates through the given vector type constant operands and +// calculates the result for each element of the result vector to return. +// Vectors with longer than 32-bit scalar components are not accepted in this +// function. +std::vector FoldVectors( + SpvOp opcode, uint32_t num_dims, + const std::vector& operands) { + std::vector result; + for (uint32_t d = 0; d < num_dims; d++) { + std::vector operand_values_for_one_dimension; + for (analysis::Constant* operand : operands) { + if (analysis::VectorConstant* vector_operand = + operand->AsVectorConstant()) { + // Extract the raw value of the scalar component constants + // in 32-bit words here. The reason of not using FoldScalars() here + // is that we do not create temporary null constants as components + // when the vector operand is a NullConstant because Constant creation + // may need extra checks for the validity and that is not manageed in + // here. + if (const analysis::ScalarConstant* scalar_component = + vector_operand->GetComponents().at(d)->AsScalarConstant()) { + const auto& scalar_words = scalar_component->words(); + assert( + scalar_words.size() == 1 && + "Vector components with longer than 32-bit width are not allowed " + "in FoldVectors()"); + operand_values_for_one_dimension.push_back(scalar_words.front()); + } else if (operand->AsNullConstant()) { + operand_values_for_one_dimension.push_back(0u); + } else { + assert(false && + "VectorConst should only has ScalarConst or NullConst as " + "components"); + } + } else if (operand->AsNullConstant()) { + operand_values_for_one_dimension.push_back(0u); + } else { + assert(false && + "FoldVectors() only accepts VectorConst or NullConst type of " + "constant"); + } + } + result.push_back(OperateWords(opcode, operand_values_for_one_dimension)); + } + return result; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/fold.h b/source/opt/fold.h new file mode 100644 index 000000000..c6e61d68a --- /dev/null +++ b/source/opt/fold.h @@ -0,0 +1,37 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIBSPIRV_UTIL_FOLD_H_ +#define LIBSPIRV_UTIL_FOLD_H_ + +#include "def_use_manager.h" +#include "constants.h" + +#include +#include + +namespace spvtools { +namespace opt { + +uint32_t FoldScalars(SpvOp opcode, + const std::vector& operands); + +std::vector FoldVectors( + SpvOp opcode, uint32_t num_dims, + const std::vector& operands); + +} // namespace opt +} // namespace spvtools + +#endif // LIBSPIRV_UTIL_FOLD_H_ diff --git a/source/opt/fold_spec_constant_op_and_composite_pass.cpp b/source/opt/fold_spec_constant_op_and_composite_pass.cpp index 84248146a..9d4d79075 100644 --- a/source/opt/fold_spec_constant_op_and_composite_pass.cpp +++ b/source/opt/fold_spec_constant_op_and_composite_pass.cpp @@ -20,227 +20,11 @@ #include "constants.h" #include "make_unique.h" +#include "fold.h" namespace spvtools { namespace opt { -namespace { -// Returns the single-word result from performing the given unary operation on -// the operand value which is passed in as a 32-bit word. -uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) { - switch (opcode) { - // Arthimetics - case SpvOp::SpvOpSNegate: - return -static_cast(operand); - case SpvOp::SpvOpNot: - return ~operand; - case SpvOp::SpvOpLogicalNot: - return !static_cast(operand); - default: - assert(false && - "Unsupported unary operation for OpSpecConstantOp instruction"); - return 0u; - } -} - -// Returns the single-word result from performing the given binary operation on -// the operand values which are passed in as two 32-bit word. -uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) { - switch (opcode) { - // Arthimetics - case SpvOp::SpvOpIAdd: - return a + b; - case SpvOp::SpvOpISub: - return a - b; - case SpvOp::SpvOpIMul: - return a * b; - case SpvOp::SpvOpUDiv: - assert(b != 0); - return a / b; - case SpvOp::SpvOpSDiv: - assert(b != 0u); - return (static_cast(a)) / (static_cast(b)); - case SpvOp::SpvOpSRem: { - // The sign of non-zero result comes from the first operand: a. This is - // guaranteed by C++11 rules for integer division operator. The division - // result is rounded toward zero, so the result of '%' has the sign of - // the first operand. - assert(b != 0u); - return static_cast(a) % static_cast(b); - } - case SpvOp::SpvOpSMod: { - // The sign of non-zero result comes from the second operand: b - assert(b != 0u); - int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b); - int32_t b_prim = static_cast(b); - return (rem + b_prim) % b_prim; - } - case SpvOp::SpvOpUMod: - assert(b != 0u); - return (a % b); - - // Shifting - case SpvOp::SpvOpShiftRightLogical: { - return a >> b; - } - case SpvOp::SpvOpShiftRightArithmetic: - return (static_cast(a)) >> b; - case SpvOp::SpvOpShiftLeftLogical: - return a << b; - - // Bitwise operations - case SpvOp::SpvOpBitwiseOr: - return a | b; - case SpvOp::SpvOpBitwiseAnd: - return a & b; - case SpvOp::SpvOpBitwiseXor: - return a ^ b; - - // Logical - case SpvOp::SpvOpLogicalEqual: - return (static_cast(a)) == (static_cast(b)); - case SpvOp::SpvOpLogicalNotEqual: - return (static_cast(a)) != (static_cast(b)); - case SpvOp::SpvOpLogicalOr: - return (static_cast(a)) || (static_cast(b)); - case SpvOp::SpvOpLogicalAnd: - return (static_cast(a)) && (static_cast(b)); - - // Comparison - case SpvOp::SpvOpIEqual: - return a == b; - case SpvOp::SpvOpINotEqual: - return a != b; - case SpvOp::SpvOpULessThan: - return a < b; - case SpvOp::SpvOpSLessThan: - return (static_cast(a)) < (static_cast(b)); - case SpvOp::SpvOpUGreaterThan: - return a > b; - case SpvOp::SpvOpSGreaterThan: - return (static_cast(a)) > (static_cast(b)); - case SpvOp::SpvOpULessThanEqual: - return a <= b; - case SpvOp::SpvOpSLessThanEqual: - return (static_cast(a)) <= (static_cast(b)); - case SpvOp::SpvOpUGreaterThanEqual: - return a >= b; - case SpvOp::SpvOpSGreaterThanEqual: - return (static_cast(a)) >= (static_cast(b)); - default: - assert(false && - "Unsupported binary operation for OpSpecConstantOp instruction"); - return 0u; - } -} - -// Returns the single-word result from performing the given ternary operation -// on the operand values which are passed in as three 32-bit word. -uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) { - switch (opcode) { - case SpvOp::SpvOpSelect: - return (static_cast(a)) ? b : c; - default: - assert(false && - "Unsupported ternary operation for OpSpecConstantOp instruction"); - return 0u; - } -} - -// Returns the single-word result from performing the given operation on the -// operand words. This only works with 32-bit operations and uses boolean -// convention that 0u is false, and anything else is boolean true. -// TODO(qining): Support operands other than 32-bit wide. -uint32_t OperateWords(SpvOp opcode, - const std::vector& operand_words) { - switch (operand_words.size()) { - case 1: - return UnaryOperate(opcode, operand_words.front()); - case 2: - return BinaryOperate(opcode, operand_words.front(), operand_words.back()); - case 3: - return TernaryOperate(opcode, operand_words[0], operand_words[1], - operand_words[2]); - default: - assert(false && "Invalid number of operands"); - return 0; - } -} - -// Returns the result of performing an operation on scalar constant operands. -// This function extracts the operand values as 32 bit words and returns the -// result in 32 bit word. Scalar constants with longer than 32-bit width are -// not accepted in this function. -uint32_t OperateScalars(SpvOp opcode, - const std::vector& operands) { - std::vector operand_values_in_raw_words; - for (analysis::Constant* operand : operands) { - if (analysis::ScalarConstant* scalar = operand->AsScalarConstant()) { - const auto& scalar_words = scalar->words(); - assert(scalar_words.size() == 1 && - "Scalar constants with longer than 32-bit width are not allowed " - "in OperateScalars()"); - operand_values_in_raw_words.push_back(scalar_words.front()); - } else if (operand->AsNullConstant()) { - operand_values_in_raw_words.push_back(0u); - } else { - assert(false && - "OperateScalars() only accepts ScalarConst or NullConst type of " - "constant"); - } - } - return OperateWords(opcode, operand_values_in_raw_words); -} - -// Returns the result of performing an operation over constant vectors. This -// function iterates through the given vector type constant operands and -// calculates the result for each element of the result vector to return. -// Vectors with longer than 32-bit scalar components are not accepted in this -// function. -std::vector OperateVectors( - SpvOp opcode, uint32_t num_dims, - const std::vector& operands) { - std::vector result; - for (uint32_t d = 0; d < num_dims; d++) { - std::vector operand_values_for_one_dimension; - for (analysis::Constant* operand : operands) { - if (analysis::VectorConstant* vector_operand = - operand->AsVectorConstant()) { - // Extract the raw value of the scalar component constants - // in 32-bit words here. The reason of not using OperateScalars() here - // is that we do not create temporary null constants as components - // when the vector operand is a NullConstant because Constant creation - // may need extra checks for the validity and that is not manageed in - // here. - if (const analysis::ScalarConstant* scalar_component = - vector_operand->GetComponents().at(d)->AsScalarConstant()) { - const auto& scalar_words = scalar_component->words(); - assert( - scalar_words.size() == 1 && - "Vector components with longer than 32-bit width are not allowed " - "in OperateVectors()"); - operand_values_for_one_dimension.push_back(scalar_words.front()); - } else if (operand->AsNullConstant()) { - operand_values_for_one_dimension.push_back(0u); - } else { - assert(false && - "VectorConst should only has ScalarConst or NullConst as " - "components"); - } - } else if (operand->AsNullConstant()) { - operand_values_for_one_dimension.push_back(0u); - } else { - assert(false && - "OperateVectors() only accepts VectorConst or NullConst type of " - "constant"); - } - } - result.push_back(OperateWords(opcode, operand_values_for_one_dimension)); - } - return result; -} -} // anonymous namespace - FoldSpecConstantOpAndCompositePass::FoldSpecConstantOpAndCompositePass() : max_id_(0), module_(nullptr), @@ -518,7 +302,7 @@ bool IsValidTypeForComponentWiseOperation(const analysis::Type* type) { } return false; } -} +} // namespace ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( ir::Module::inst_iterator* pos) { @@ -546,7 +330,7 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( if (result_type->AsInteger() || result_type->AsBool()) { // Scalar operation - uint32_t result_val = OperateScalars(spec_opcode, operands); + uint32_t result_val = FoldScalars(spec_opcode, operands); auto result_const = CreateConst(result_type, {result_val}); return BuildInstructionAndAddToModule(std::move(result_const), pos); } else if (result_type->AsVector()) { @@ -555,7 +339,7 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( result_type->AsVector()->element_type(); uint32_t num_dims = result_type->AsVector()->element_count(); std::vector result_vec = - OperateVectors(spec_opcode, num_dims, operands); + FoldVectors(spec_opcode, num_dims, operands); std::vector result_vector_components; for (uint32_t r : result_vec) { if (auto rc = CreateConst(element_type, {r})) {