mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2025-01-17 19:40:06 +00:00
Merge pull request #885 from dnovillo/const-prop
Re-factor generic constant folding code out of fold spec constants pass
This commit is contained in:
commit
4101cf4687
@ -23,13 +23,15 @@ add_library(SPIRV-Tools-opt
|
||||
decoration_manager.h
|
||||
def_use_manager.h
|
||||
eliminate_dead_constant_pass.h
|
||||
eliminate_dead_functions_pass.h
|
||||
flatten_decoration_pass.h
|
||||
function.h
|
||||
fold.h
|
||||
fold_spec_constant_op_and_composite_pass.h
|
||||
freeze_spec_constant_value_pass.h
|
||||
inline_pass.h
|
||||
function.h
|
||||
inline_exhaustive_pass.h
|
||||
inline_opaque_pass.h
|
||||
inline_pass.h
|
||||
insert_extract_elim.h
|
||||
instruction.h
|
||||
ir_loader.h
|
||||
@ -38,20 +40,19 @@ add_library(SPIRV-Tools-opt
|
||||
local_single_store_elim_pass.h
|
||||
local_ssa_elim_pass.h
|
||||
log.h
|
||||
mem_pass.h
|
||||
module.h
|
||||
null_pass.h
|
||||
reflect.h
|
||||
mem_pass.h
|
||||
pass.h
|
||||
passes.h
|
||||
pass.h
|
||||
pass_manager.h
|
||||
eliminate_dead_functions_pass.h
|
||||
reflect.h
|
||||
remove_duplicates_pass.h
|
||||
set_spec_constant_default_value_pass.h
|
||||
strength_reduction_pass.h
|
||||
strip_debug_info_pass.h
|
||||
types.h
|
||||
type_manager.h
|
||||
types.h
|
||||
unify_const_pass.h
|
||||
|
||||
aggressive_dead_code_elim_pass.cpp
|
||||
@ -60,17 +61,19 @@ add_library(SPIRV-Tools-opt
|
||||
build_module.cpp
|
||||
common_uniform_elim_pass.cpp
|
||||
compact_ids_pass.cpp
|
||||
dead_branch_elim_pass.cpp
|
||||
decoration_manager.cpp
|
||||
def_use_manager.cpp
|
||||
dead_branch_elim_pass.cpp
|
||||
eliminate_dead_constant_pass.cpp
|
||||
eliminate_dead_functions_pass.cpp
|
||||
flatten_decoration_pass.cpp
|
||||
function.cpp
|
||||
fold.cpp
|
||||
fold_spec_constant_op_and_composite_pass.cpp
|
||||
freeze_spec_constant_value_pass.cpp
|
||||
inline_pass.cpp
|
||||
function.cpp
|
||||
inline_exhaustive_pass.cpp
|
||||
inline_opaque_pass.cpp
|
||||
inline_pass.cpp
|
||||
insert_extract_elim.cpp
|
||||
instruction.cpp
|
||||
ir_loader.cpp
|
||||
@ -78,18 +81,17 @@ add_library(SPIRV-Tools-opt
|
||||
local_single_block_elim_pass.cpp
|
||||
local_single_store_elim_pass.cpp
|
||||
local_ssa_elim_pass.cpp
|
||||
module.cpp
|
||||
eliminate_dead_functions_pass.cpp
|
||||
remove_duplicates_pass.cpp
|
||||
set_spec_constant_default_value_pass.cpp
|
||||
optimizer.cpp
|
||||
mem_pass.cpp
|
||||
module.cpp
|
||||
optimizer.cpp
|
||||
pass.cpp
|
||||
pass_manager.cpp
|
||||
remove_duplicates_pass.cpp
|
||||
set_spec_constant_default_value_pass.cpp
|
||||
strength_reduction_pass.cpp
|
||||
strip_debug_info_pass.cpp
|
||||
types.cpp
|
||||
type_manager.cpp
|
||||
types.cpp
|
||||
unify_const_pass.cpp
|
||||
)
|
||||
|
||||
|
244
source/opt/fold.cpp
Normal file
244
source/opt/fold.cpp
Normal file
@ -0,0 +1,244 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fold.h"
|
||||
#include "def_use_manager.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
namespace spvtools {
|
||||
namespace opt {
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns the single-word result from performing the given unary operation on
|
||||
// the operand value which is passed in as a 32-bit word.
|
||||
uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) {
|
||||
switch (opcode) {
|
||||
// Arthimetics
|
||||
case SpvOp::SpvOpSNegate:
|
||||
return -static_cast<int32_t>(operand);
|
||||
case SpvOp::SpvOpNot:
|
||||
return ~operand;
|
||||
case SpvOp::SpvOpLogicalNot:
|
||||
return !static_cast<bool>(operand);
|
||||
default:
|
||||
assert(false &&
|
||||
"Unsupported unary operation for OpSpecConstantOp instruction");
|
||||
return 0u;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the single-word result from performing the given binary operation on
|
||||
// the operand values which are passed in as two 32-bit word.
|
||||
uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) {
|
||||
switch (opcode) {
|
||||
// Arthimetics
|
||||
case SpvOp::SpvOpIAdd:
|
||||
return a + b;
|
||||
case SpvOp::SpvOpISub:
|
||||
return a - b;
|
||||
case SpvOp::SpvOpIMul:
|
||||
return a * b;
|
||||
case SpvOp::SpvOpUDiv:
|
||||
assert(b != 0);
|
||||
return a / b;
|
||||
case SpvOp::SpvOpSDiv:
|
||||
assert(b != 0u);
|
||||
return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
|
||||
case SpvOp::SpvOpSRem: {
|
||||
// The sign of non-zero result comes from the first operand: a. This is
|
||||
// guaranteed by C++11 rules for integer division operator. The division
|
||||
// result is rounded toward zero, so the result of '%' has the sign of
|
||||
// the first operand.
|
||||
assert(b != 0u);
|
||||
return static_cast<int32_t>(a) % static_cast<int32_t>(b);
|
||||
}
|
||||
case SpvOp::SpvOpSMod: {
|
||||
// The sign of non-zero result comes from the second operand: b
|
||||
assert(b != 0u);
|
||||
int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b);
|
||||
int32_t b_prim = static_cast<int32_t>(b);
|
||||
return (rem + b_prim) % b_prim;
|
||||
}
|
||||
case SpvOp::SpvOpUMod:
|
||||
assert(b != 0u);
|
||||
return (a % b);
|
||||
|
||||
// Shifting
|
||||
case SpvOp::SpvOpShiftRightLogical: {
|
||||
return a >> b;
|
||||
}
|
||||
case SpvOp::SpvOpShiftRightArithmetic:
|
||||
return (static_cast<int32_t>(a)) >> b;
|
||||
case SpvOp::SpvOpShiftLeftLogical:
|
||||
return a << b;
|
||||
|
||||
// Bitwise operations
|
||||
case SpvOp::SpvOpBitwiseOr:
|
||||
return a | b;
|
||||
case SpvOp::SpvOpBitwiseAnd:
|
||||
return a & b;
|
||||
case SpvOp::SpvOpBitwiseXor:
|
||||
return a ^ b;
|
||||
|
||||
// Logical
|
||||
case SpvOp::SpvOpLogicalEqual:
|
||||
return (static_cast<bool>(a)) == (static_cast<bool>(b));
|
||||
case SpvOp::SpvOpLogicalNotEqual:
|
||||
return (static_cast<bool>(a)) != (static_cast<bool>(b));
|
||||
case SpvOp::SpvOpLogicalOr:
|
||||
return (static_cast<bool>(a)) || (static_cast<bool>(b));
|
||||
case SpvOp::SpvOpLogicalAnd:
|
||||
return (static_cast<bool>(a)) && (static_cast<bool>(b));
|
||||
|
||||
// Comparison
|
||||
case SpvOp::SpvOpIEqual:
|
||||
return a == b;
|
||||
case SpvOp::SpvOpINotEqual:
|
||||
return a != b;
|
||||
case SpvOp::SpvOpULessThan:
|
||||
return a < b;
|
||||
case SpvOp::SpvOpSLessThan:
|
||||
return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
|
||||
case SpvOp::SpvOpUGreaterThan:
|
||||
return a > b;
|
||||
case SpvOp::SpvOpSGreaterThan:
|
||||
return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
|
||||
case SpvOp::SpvOpULessThanEqual:
|
||||
return a <= b;
|
||||
case SpvOp::SpvOpSLessThanEqual:
|
||||
return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
|
||||
case SpvOp::SpvOpUGreaterThanEqual:
|
||||
return a >= b;
|
||||
case SpvOp::SpvOpSGreaterThanEqual:
|
||||
return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
|
||||
default:
|
||||
assert(false &&
|
||||
"Unsupported binary operation for OpSpecConstantOp instruction");
|
||||
return 0u;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the single-word result from performing the given ternary operation
|
||||
// on the operand values which are passed in as three 32-bit word.
|
||||
uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) {
|
||||
switch (opcode) {
|
||||
case SpvOp::SpvOpSelect:
|
||||
return (static_cast<bool>(a)) ? b : c;
|
||||
default:
|
||||
assert(false &&
|
||||
"Unsupported ternary operation for OpSpecConstantOp instruction");
|
||||
return 0u;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the single-word result from performing the given operation on the
|
||||
// operand words. This only works with 32-bit operations and uses boolean
|
||||
// convention that 0u is false, and anything else is boolean true.
|
||||
// TODO(qining): Support operands other than 32-bit wide.
|
||||
uint32_t OperateWords(SpvOp opcode,
|
||||
const std::vector<uint32_t>& operand_words) {
|
||||
switch (operand_words.size()) {
|
||||
case 1:
|
||||
return UnaryOperate(opcode, operand_words.front());
|
||||
case 2:
|
||||
return BinaryOperate(opcode, operand_words.front(), operand_words.back());
|
||||
case 3:
|
||||
return TernaryOperate(opcode, operand_words[0], operand_words[1],
|
||||
operand_words[2]);
|
||||
default:
|
||||
assert(false && "Invalid number of operands");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Returns the result of performing an operation on scalar constant operands.
|
||||
// This function extracts the operand values as 32 bit words and returns the
|
||||
// result in 32 bit word. Scalar constants with longer than 32-bit width are
|
||||
// not accepted in this function.
|
||||
uint32_t FoldScalars(SpvOp opcode,
|
||||
const std::vector<analysis::Constant*>& operands) {
|
||||
std::vector<uint32_t> operand_values_in_raw_words;
|
||||
for (analysis::Constant* operand : operands) {
|
||||
if (analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
|
||||
const auto& scalar_words = scalar->words();
|
||||
assert(scalar_words.size() == 1 &&
|
||||
"Scalar constants with longer than 32-bit width are not allowed "
|
||||
"in FoldScalars()");
|
||||
operand_values_in_raw_words.push_back(scalar_words.front());
|
||||
} else if (operand->AsNullConstant()) {
|
||||
operand_values_in_raw_words.push_back(0u);
|
||||
} else {
|
||||
assert(false &&
|
||||
"FoldScalars() only accepts ScalarConst or NullConst type of "
|
||||
"constant");
|
||||
}
|
||||
}
|
||||
return OperateWords(opcode, operand_values_in_raw_words);
|
||||
}
|
||||
|
||||
// Returns the result of performing an operation over constant vectors. This
|
||||
// function iterates through the given vector type constant operands and
|
||||
// calculates the result for each element of the result vector to return.
|
||||
// Vectors with longer than 32-bit scalar components are not accepted in this
|
||||
// function.
|
||||
std::vector<uint32_t> FoldVectors(
|
||||
SpvOp opcode, uint32_t num_dims,
|
||||
const std::vector<analysis::Constant*>& operands) {
|
||||
std::vector<uint32_t> result;
|
||||
for (uint32_t d = 0; d < num_dims; d++) {
|
||||
std::vector<uint32_t> operand_values_for_one_dimension;
|
||||
for (analysis::Constant* operand : operands) {
|
||||
if (analysis::VectorConstant* vector_operand =
|
||||
operand->AsVectorConstant()) {
|
||||
// Extract the raw value of the scalar component constants
|
||||
// in 32-bit words here. The reason of not using FoldScalars() here
|
||||
// is that we do not create temporary null constants as components
|
||||
// when the vector operand is a NullConstant because Constant creation
|
||||
// may need extra checks for the validity and that is not manageed in
|
||||
// here.
|
||||
if (const analysis::ScalarConstant* scalar_component =
|
||||
vector_operand->GetComponents().at(d)->AsScalarConstant()) {
|
||||
const auto& scalar_words = scalar_component->words();
|
||||
assert(
|
||||
scalar_words.size() == 1 &&
|
||||
"Vector components with longer than 32-bit width are not allowed "
|
||||
"in FoldVectors()");
|
||||
operand_values_for_one_dimension.push_back(scalar_words.front());
|
||||
} else if (operand->AsNullConstant()) {
|
||||
operand_values_for_one_dimension.push_back(0u);
|
||||
} else {
|
||||
assert(false &&
|
||||
"VectorConst should only has ScalarConst or NullConst as "
|
||||
"components");
|
||||
}
|
||||
} else if (operand->AsNullConstant()) {
|
||||
operand_values_for_one_dimension.push_back(0u);
|
||||
} else {
|
||||
assert(false &&
|
||||
"FoldVectors() only accepts VectorConst or NullConst type of "
|
||||
"constant");
|
||||
}
|
||||
}
|
||||
result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
37
source/opt/fold.h
Normal file
37
source/opt/fold.h
Normal file
@ -0,0 +1,37 @@
|
||||
// Copyright (c) 2017 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef LIBSPIRV_UTIL_FOLD_H_
|
||||
#define LIBSPIRV_UTIL_FOLD_H_
|
||||
|
||||
#include "def_use_manager.h"
|
||||
#include "constants.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
namespace spvtools {
|
||||
namespace opt {
|
||||
|
||||
uint32_t FoldScalars(SpvOp opcode,
|
||||
const std::vector<analysis::Constant*>& operands);
|
||||
|
||||
std::vector<uint32_t> FoldVectors(
|
||||
SpvOp opcode, uint32_t num_dims,
|
||||
const std::vector<analysis::Constant*>& operands);
|
||||
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // LIBSPIRV_UTIL_FOLD_H_
|
@ -20,227 +20,11 @@
|
||||
|
||||
#include "constants.h"
|
||||
#include "make_unique.h"
|
||||
#include "fold.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace opt {
|
||||
|
||||
namespace {
|
||||
// Returns the single-word result from performing the given unary operation on
|
||||
// the operand value which is passed in as a 32-bit word.
|
||||
uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) {
|
||||
switch (opcode) {
|
||||
// Arthimetics
|
||||
case SpvOp::SpvOpSNegate:
|
||||
return -static_cast<int32_t>(operand);
|
||||
case SpvOp::SpvOpNot:
|
||||
return ~operand;
|
||||
case SpvOp::SpvOpLogicalNot:
|
||||
return !static_cast<bool>(operand);
|
||||
default:
|
||||
assert(false &&
|
||||
"Unsupported unary operation for OpSpecConstantOp instruction");
|
||||
return 0u;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the single-word result from performing the given binary operation on
|
||||
// the operand values which are passed in as two 32-bit word.
|
||||
uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) {
|
||||
switch (opcode) {
|
||||
// Arthimetics
|
||||
case SpvOp::SpvOpIAdd:
|
||||
return a + b;
|
||||
case SpvOp::SpvOpISub:
|
||||
return a - b;
|
||||
case SpvOp::SpvOpIMul:
|
||||
return a * b;
|
||||
case SpvOp::SpvOpUDiv:
|
||||
assert(b != 0);
|
||||
return a / b;
|
||||
case SpvOp::SpvOpSDiv:
|
||||
assert(b != 0u);
|
||||
return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
|
||||
case SpvOp::SpvOpSRem: {
|
||||
// The sign of non-zero result comes from the first operand: a. This is
|
||||
// guaranteed by C++11 rules for integer division operator. The division
|
||||
// result is rounded toward zero, so the result of '%' has the sign of
|
||||
// the first operand.
|
||||
assert(b != 0u);
|
||||
return static_cast<int32_t>(a) % static_cast<int32_t>(b);
|
||||
}
|
||||
case SpvOp::SpvOpSMod: {
|
||||
// The sign of non-zero result comes from the second operand: b
|
||||
assert(b != 0u);
|
||||
int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b);
|
||||
int32_t b_prim = static_cast<int32_t>(b);
|
||||
return (rem + b_prim) % b_prim;
|
||||
}
|
||||
case SpvOp::SpvOpUMod:
|
||||
assert(b != 0u);
|
||||
return (a % b);
|
||||
|
||||
// Shifting
|
||||
case SpvOp::SpvOpShiftRightLogical: {
|
||||
return a >> b;
|
||||
}
|
||||
case SpvOp::SpvOpShiftRightArithmetic:
|
||||
return (static_cast<int32_t>(a)) >> b;
|
||||
case SpvOp::SpvOpShiftLeftLogical:
|
||||
return a << b;
|
||||
|
||||
// Bitwise operations
|
||||
case SpvOp::SpvOpBitwiseOr:
|
||||
return a | b;
|
||||
case SpvOp::SpvOpBitwiseAnd:
|
||||
return a & b;
|
||||
case SpvOp::SpvOpBitwiseXor:
|
||||
return a ^ b;
|
||||
|
||||
// Logical
|
||||
case SpvOp::SpvOpLogicalEqual:
|
||||
return (static_cast<bool>(a)) == (static_cast<bool>(b));
|
||||
case SpvOp::SpvOpLogicalNotEqual:
|
||||
return (static_cast<bool>(a)) != (static_cast<bool>(b));
|
||||
case SpvOp::SpvOpLogicalOr:
|
||||
return (static_cast<bool>(a)) || (static_cast<bool>(b));
|
||||
case SpvOp::SpvOpLogicalAnd:
|
||||
return (static_cast<bool>(a)) && (static_cast<bool>(b));
|
||||
|
||||
// Comparison
|
||||
case SpvOp::SpvOpIEqual:
|
||||
return a == b;
|
||||
case SpvOp::SpvOpINotEqual:
|
||||
return a != b;
|
||||
case SpvOp::SpvOpULessThan:
|
||||
return a < b;
|
||||
case SpvOp::SpvOpSLessThan:
|
||||
return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
|
||||
case SpvOp::SpvOpUGreaterThan:
|
||||
return a > b;
|
||||
case SpvOp::SpvOpSGreaterThan:
|
||||
return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
|
||||
case SpvOp::SpvOpULessThanEqual:
|
||||
return a <= b;
|
||||
case SpvOp::SpvOpSLessThanEqual:
|
||||
return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
|
||||
case SpvOp::SpvOpUGreaterThanEqual:
|
||||
return a >= b;
|
||||
case SpvOp::SpvOpSGreaterThanEqual:
|
||||
return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
|
||||
default:
|
||||
assert(false &&
|
||||
"Unsupported binary operation for OpSpecConstantOp instruction");
|
||||
return 0u;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the single-word result from performing the given ternary operation
|
||||
// on the operand values which are passed in as three 32-bit word.
|
||||
uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) {
|
||||
switch (opcode) {
|
||||
case SpvOp::SpvOpSelect:
|
||||
return (static_cast<bool>(a)) ? b : c;
|
||||
default:
|
||||
assert(false &&
|
||||
"Unsupported ternary operation for OpSpecConstantOp instruction");
|
||||
return 0u;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the single-word result from performing the given operation on the
|
||||
// operand words. This only works with 32-bit operations and uses boolean
|
||||
// convention that 0u is false, and anything else is boolean true.
|
||||
// TODO(qining): Support operands other than 32-bit wide.
|
||||
uint32_t OperateWords(SpvOp opcode,
|
||||
const std::vector<uint32_t>& operand_words) {
|
||||
switch (operand_words.size()) {
|
||||
case 1:
|
||||
return UnaryOperate(opcode, operand_words.front());
|
||||
case 2:
|
||||
return BinaryOperate(opcode, operand_words.front(), operand_words.back());
|
||||
case 3:
|
||||
return TernaryOperate(opcode, operand_words[0], operand_words[1],
|
||||
operand_words[2]);
|
||||
default:
|
||||
assert(false && "Invalid number of operands");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the result of performing an operation on scalar constant operands.
|
||||
// This function extracts the operand values as 32 bit words and returns the
|
||||
// result in 32 bit word. Scalar constants with longer than 32-bit width are
|
||||
// not accepted in this function.
|
||||
uint32_t OperateScalars(SpvOp opcode,
|
||||
const std::vector<analysis::Constant*>& operands) {
|
||||
std::vector<uint32_t> operand_values_in_raw_words;
|
||||
for (analysis::Constant* operand : operands) {
|
||||
if (analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
|
||||
const auto& scalar_words = scalar->words();
|
||||
assert(scalar_words.size() == 1 &&
|
||||
"Scalar constants with longer than 32-bit width are not allowed "
|
||||
"in OperateScalars()");
|
||||
operand_values_in_raw_words.push_back(scalar_words.front());
|
||||
} else if (operand->AsNullConstant()) {
|
||||
operand_values_in_raw_words.push_back(0u);
|
||||
} else {
|
||||
assert(false &&
|
||||
"OperateScalars() only accepts ScalarConst or NullConst type of "
|
||||
"constant");
|
||||
}
|
||||
}
|
||||
return OperateWords(opcode, operand_values_in_raw_words);
|
||||
}
|
||||
|
||||
// Returns the result of performing an operation over constant vectors. This
|
||||
// function iterates through the given vector type constant operands and
|
||||
// calculates the result for each element of the result vector to return.
|
||||
// Vectors with longer than 32-bit scalar components are not accepted in this
|
||||
// function.
|
||||
std::vector<uint32_t> OperateVectors(
|
||||
SpvOp opcode, uint32_t num_dims,
|
||||
const std::vector<analysis::Constant*>& operands) {
|
||||
std::vector<uint32_t> result;
|
||||
for (uint32_t d = 0; d < num_dims; d++) {
|
||||
std::vector<uint32_t> operand_values_for_one_dimension;
|
||||
for (analysis::Constant* operand : operands) {
|
||||
if (analysis::VectorConstant* vector_operand =
|
||||
operand->AsVectorConstant()) {
|
||||
// Extract the raw value of the scalar component constants
|
||||
// in 32-bit words here. The reason of not using OperateScalars() here
|
||||
// is that we do not create temporary null constants as components
|
||||
// when the vector operand is a NullConstant because Constant creation
|
||||
// may need extra checks for the validity and that is not manageed in
|
||||
// here.
|
||||
if (const analysis::ScalarConstant* scalar_component =
|
||||
vector_operand->GetComponents().at(d)->AsScalarConstant()) {
|
||||
const auto& scalar_words = scalar_component->words();
|
||||
assert(
|
||||
scalar_words.size() == 1 &&
|
||||
"Vector components with longer than 32-bit width are not allowed "
|
||||
"in OperateVectors()");
|
||||
operand_values_for_one_dimension.push_back(scalar_words.front());
|
||||
} else if (operand->AsNullConstant()) {
|
||||
operand_values_for_one_dimension.push_back(0u);
|
||||
} else {
|
||||
assert(false &&
|
||||
"VectorConst should only has ScalarConst or NullConst as "
|
||||
"components");
|
||||
}
|
||||
} else if (operand->AsNullConstant()) {
|
||||
operand_values_for_one_dimension.push_back(0u);
|
||||
} else {
|
||||
assert(false &&
|
||||
"OperateVectors() only accepts VectorConst or NullConst type of "
|
||||
"constant");
|
||||
}
|
||||
}
|
||||
result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
FoldSpecConstantOpAndCompositePass::FoldSpecConstantOpAndCompositePass()
|
||||
: max_id_(0),
|
||||
module_(nullptr),
|
||||
@ -518,7 +302,7 @@ bool IsValidTypeForComponentWiseOperation(const analysis::Type* type) {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
|
||||
ir::Module::inst_iterator* pos) {
|
||||
@ -546,7 +330,7 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
|
||||
|
||||
if (result_type->AsInteger() || result_type->AsBool()) {
|
||||
// Scalar operation
|
||||
uint32_t result_val = OperateScalars(spec_opcode, operands);
|
||||
uint32_t result_val = FoldScalars(spec_opcode, operands);
|
||||
auto result_const = CreateConst(result_type, {result_val});
|
||||
return BuildInstructionAndAddToModule(std::move(result_const), pos);
|
||||
} else if (result_type->AsVector()) {
|
||||
@ -555,7 +339,7 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
|
||||
result_type->AsVector()->element_type();
|
||||
uint32_t num_dims = result_type->AsVector()->element_count();
|
||||
std::vector<uint32_t> result_vec =
|
||||
OperateVectors(spec_opcode, num_dims, operands);
|
||||
FoldVectors(spec_opcode, num_dims, operands);
|
||||
std::vector<const analysis::Constant*> result_vector_components;
|
||||
for (uint32_t r : result_vec) {
|
||||
if (auto rc = CreateConst(element_type, {r})) {
|
||||
|
Loading…
Reference in New Issue
Block a user