Support constant-folding UConvert and SConvert (#2960)

This commit is contained in:
Jakub Kuderski 2019-10-16 16:29:55 -04:00 committed by GitHub
parent 8e89778531
commit e99b918221
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 471 additions and 17 deletions

View File

@ -56,6 +56,10 @@ uint32_t InstructionFolder::UnaryOperate(SpvOp opcode, uint32_t operand) const {
return ~operand; return ~operand;
case SpvOp::SpvOpLogicalNot: case SpvOp::SpvOpLogicalNot:
return !static_cast<bool>(operand); return !static_cast<bool>(operand);
case SpvOp::SpvOpUConvert:
return operand;
case SpvOp::SpvOpSConvert:
return operand;
default: default:
assert(false && assert(false &&
"Unsupported unary operation for OpSpecConstantOp instruction"); "Unsupported unary operation for OpSpecConstantOp instruction");
@ -596,6 +600,8 @@ bool InstructionFolder::IsFoldableOpcode(SpvOp opcode) const {
case SpvOp::SpvOpSMod: case SpvOp::SpvOpSMod:
case SpvOp::SpvOpSNegate: case SpvOp::SpvOpSNegate:
case SpvOp::SpvOpSRem: case SpvOp::SpvOpSRem:
case SpvOp::SpvOpSConvert:
case SpvOp::SpvOpUConvert:
case SpvOp::SpvOpUDiv: case SpvOp::SpvOpUDiv:
case SpvOp::SpvOpUGreaterThan: case SpvOp::SpvOpUGreaterThan:
case SpvOp::SpvOpUGreaterThanEqual: case SpvOp::SpvOpUGreaterThanEqual:

View File

@ -316,6 +316,59 @@ bool IsValidTypeForComponentWiseOperation(const analysis::Type* type) {
} }
return false; return false;
} }
// Encodes the integer |value| of in a word vector format appropriate for
// representing this value as a operands for a constant definition. Performs
// zero-extension/sign-extension/truncation when needed, based on the signess of
// the given target type.
//
// Note: type |type| argument must be either Integer or Bool.
utils::SmallVector<uint32_t, 2> EncodeIntegerAsWords(const analysis::Type& type,
uint32_t value) {
const uint32_t all_ones = ~0;
uint32_t bit_width = 0;
uint32_t pad_value = 0;
bool result_type_signed = false;
if (auto* int_ty = type.AsInteger()) {
bit_width = int_ty->width();
result_type_signed = int_ty->IsSigned();
if (result_type_signed && static_cast<int32_t>(value) < 0) {
pad_value = all_ones;
}
} else if (type.AsBool()) {
bit_width = 1;
} else {
assert(false && "type must be Integer or Bool");
}
assert(bit_width > 0);
uint32_t first_word = value;
const uint32_t bits_per_word = 32;
// Truncate first_word if the |type| has width less than uint32.
if (bit_width < bits_per_word) {
const uint32_t num_high_bits_to_mask = bits_per_word - bit_width;
const bool is_negative_after_truncation =
result_type_signed &&
utils::IsBitAtPositionSet(first_word, bit_width - 1);
if (is_negative_after_truncation) {
// Truncate and sign-extend |first_word|. No padding words will be
// added and |pad_value| can be left as-is.
first_word = utils::SetHighBits(first_word, num_high_bits_to_mask);
} else {
first_word = utils::ClearHighBits(first_word, num_high_bits_to_mask);
}
}
utils::SmallVector<uint32_t, 2> words = {first_word};
for (uint32_t current_bit = bits_per_word; current_bit < bit_width;
current_bit += bits_per_word) {
words.push_back(pad_value);
}
return words;
}
} // namespace } // namespace
Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
@ -345,10 +398,10 @@ Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
if (result_type->AsInteger() || result_type->AsBool()) { if (result_type->AsInteger() || result_type->AsBool()) {
// Scalar operation // Scalar operation
uint32_t result_val = const uint32_t result_val =
context()->get_instruction_folder().FoldScalars(spec_opcode, operands); context()->get_instruction_folder().FoldScalars(spec_opcode, operands);
auto result_const = auto result_const = context()->get_constant_mgr()->GetConstant(
context()->get_constant_mgr()->GetConstant(result_type, {result_val}); result_type, EncodeIntegerAsWords(*result_type, result_val));
return context()->get_constant_mgr()->BuildInstructionAndAddToModule( return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
result_const, pos); result_const, pos);
} else if (result_type->AsVector()) { } else if (result_type->AsVector()) {
@ -360,9 +413,9 @@ Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
context()->get_instruction_folder().FoldVectors(spec_opcode, num_dims, context()->get_instruction_folder().FoldVectors(spec_opcode, num_dims,
operands); operands);
std::vector<const analysis::Constant*> result_vector_components; std::vector<const analysis::Constant*> result_vector_components;
for (uint32_t r : result_vec) { for (const uint32_t r : result_vec) {
if (auto rc = if (auto rc = context()->get_constant_mgr()->GetConstant(
context()->get_constant_mgr()->GetConstant(element_type, {r})) { element_type, EncodeIntegerAsWords(*element_type, r))) {
result_vector_components.push_back(rc); result_vector_components.push_back(rc);
if (!context()->get_constant_mgr()->BuildInstructionAndAddToModule( if (!context()->get_constant_mgr()->BuildInstructionAndAddToModule(
rc, pos)) { rc, pos)) {

View File

@ -15,8 +15,10 @@
#ifndef SOURCE_UTIL_BITUTILS_H_ #ifndef SOURCE_UTIL_BITUTILS_H_
#define SOURCE_UTIL_BITUTILS_H_ #define SOURCE_UTIL_BITUTILS_H_
#include <cassert>
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
#include <type_traits>
namespace spvtools { namespace spvtools {
namespace utils { namespace utils {
@ -31,6 +33,14 @@ Dest BitwiseCast(Src source) {
return dest; return dest;
} }
// Calculates the bit width of the integer type |T|.
template <typename T>
struct IntegerBitWidth {
static_assert(std::is_integral<T>::value, "Integer type required");
static const size_t kBitsPerByte = 8;
static const size_t get = sizeof(T) * kBitsPerByte;
};
// SetBits<T, First, Num> returns an integer of type <T> with bits set // SetBits<T, First, Num> returns an integer of type <T> with bits set
// for position <First> through <First + Num - 1>, counting from the least // for position <First> through <First + Num - 1>, counting from the least
// significant bit. In particular when Num == 0, no positions are set to 1. // significant bit. In particular when Num == 0, no positions are set to 1.
@ -38,7 +48,7 @@ Dest BitwiseCast(Src source) {
// a bit that will not fit in the underlying type is set. // a bit that will not fit in the underlying type is set.
template <typename T, size_t First = 0, size_t Num = 0> template <typename T, size_t First = 0, size_t Num = 0>
struct SetBits { struct SetBits {
static_assert(First < sizeof(T) * 8, static_assert(First < IntegerBitWidth<T>::get,
"Tried to set a bit that is shifted too far."); "Tried to set a bit that is shifted too far.");
const static T get = (T(1) << First) | SetBits<T, First + 1, Num - 1>::get; const static T get = (T(1) << First) | SetBits<T, First + 1, Num - 1>::get;
}; };
@ -49,6 +59,11 @@ struct SetBits<T, Last, 0> {
}; };
// This is all compile-time so we can put our tests right here. // This is all compile-time so we can put our tests right here.
static_assert(IntegerBitWidth<uint32_t>::get == 32, "IntegerBitWidth mismatch");
static_assert(IntegerBitWidth<int32_t>::get == 32, "IntegerBitWidth mismatch");
static_assert(IntegerBitWidth<uint64_t>::get == 64, "IntegerBitWidth mismatch");
static_assert(IntegerBitWidth<uint8_t>::get == 8, "IntegerBitWidth mismatch");
static_assert(SetBits<uint32_t, 0, 0>::get == uint32_t(0x00000000), static_assert(SetBits<uint32_t, 0, 0>::get == uint32_t(0x00000000),
"SetBits failed"); "SetBits failed");
static_assert(SetBits<uint32_t, 0, 1>::get == uint32_t(0x00000001), static_assert(SetBits<uint32_t, 0, 1>::get == uint32_t(0x00000001),
@ -90,6 +105,82 @@ size_t CountSetBits(T word) {
return count; return count;
} }
// Checks if the bit at the |position| is set to '1'.
// Bits zero-indexed starting at the least significant bit.
// |position| must be within the bit width of |T|.
template <typename T>
bool IsBitAtPositionSet(T word, size_t position) {
static_assert(std::is_integral<T>::value, "Integer type required");
static_assert(std::is_unsigned<T>::value, "Unsigned type required");
assert(position < IntegerBitWidth<T>::get &&
"position must be less than the bit width");
return word & T(T(1) << position);
}
// Returns a value obtained by setting a range of adjacent bits of |word| to
// |value|. Affected bits are within the range:
// [first_position, first_position + num_bits_to_mutate),
// assuming zero-based indexing starting at the least
// significant bit. Bits to mutate must be within the bit width of |T|.
template <typename T>
T MutateBits(T word, size_t first_position, size_t num_bits_to_mutate,
bool value) {
static_assert(std::is_integral<T>::value, "Integer type required");
static_assert(std::is_unsigned<T>::value, "Unsigned type required");
static const size_t word_bit_width = IntegerBitWidth<T>::get;
assert(first_position < word_bit_width &&
"Mutated bits must be within bit width");
assert(first_position + num_bits_to_mutate <= word_bit_width &&
"Mutated bits must be within bit width");
if (num_bits_to_mutate == 0) {
return word;
}
const T all_ones = ~T(0);
const size_t num_unaffected_low_bits = first_position;
const T unaffected_low_mask =
T(T(all_ones >> num_unaffected_low_bits) << num_unaffected_low_bits);
const size_t num_unaffected_high_bits =
word_bit_width - (first_position + num_bits_to_mutate);
const T unaffected_high_mask =
T(T(all_ones << num_unaffected_high_bits) >> num_unaffected_high_bits);
const T mutation_mask = unaffected_low_mask & unaffected_high_mask;
if (value) {
return word | mutation_mask;
}
return word & T(~mutation_mask);
}
// Returns a value obtained by setting the |num_bits_to_set| highest bits to
// '1'. |num_bits_to_set| must be not be greater than the bit width of |T|.
template <typename T>
T SetHighBits(T word, size_t num_bits_to_set) {
if (num_bits_to_set == 0) {
return word;
}
const size_t word_bit_width = IntegerBitWidth<T>::get;
assert(num_bits_to_set <= word_bit_width &&
"Can't set more bits than bit width");
return MutateBits(word, word_bit_width - num_bits_to_set, num_bits_to_set,
true);
}
// Returns a value obtained by setting the |num_bits_to_set| highest bits to
// '0'. |num_bits_to_set| must be not be greater than the bit width of |T|.
template <typename T>
T ClearHighBits(T word, size_t num_bits_to_set) {
if (num_bits_to_set == 0) {
return word;
}
const size_t word_bit_width = IntegerBitWidth<T>::get;
assert(num_bits_to_set <= word_bit_width &&
"Can't clear more bits than bit width");
return MutateBits(word, word_bit_width - num_bits_to_set, num_bits_to_set,
false);
}
} // namespace utils } // namespace utils
} // namespace spvtools } // namespace spvtools

View File

@ -112,8 +112,12 @@ std::vector<std::string> CommonTypesAndConstants() {
// clang-format off // clang-format off
// scalar types // scalar types
"%bool = OpTypeBool", "%bool = OpTypeBool",
"%ushort = OpTypeInt 16 0",
"%short = OpTypeInt 16 1",
"%uint = OpTypeInt 32 0", "%uint = OpTypeInt 32 0",
"%int = OpTypeInt 32 1", "%int = OpTypeInt 32 1",
"%ulong = OpTypeInt 64 0",
"%long = OpTypeInt 64 1",
"%float = OpTypeFloat 32", "%float = OpTypeFloat 32",
"%double = OpTypeFloat 64", "%double = OpTypeFloat 64",
// vector types // vector types
@ -122,6 +126,8 @@ std::vector<std::string> CommonTypesAndConstants() {
"%v2int = OpTypeVector %int 2", "%v2int = OpTypeVector %int 2",
"%v3int = OpTypeVector %int 3", "%v3int = OpTypeVector %int 3",
"%v4int = OpTypeVector %int 4", "%v4int = OpTypeVector %int 4",
"%v2long = OpTypeVector %long 2",
"%v2ulong = OpTypeVector %ulong 2",
"%v2float = OpTypeVector %float 2", "%v2float = OpTypeVector %float 2",
"%v2double = OpTypeVector %double 2", "%v2double = OpTypeVector %double 2",
// variable pointer types // variable pointer types
@ -145,6 +151,8 @@ std::vector<std::string> CommonTypesAndConstants() {
"%bool_null = OpConstantNull %bool", "%bool_null = OpConstantNull %bool",
"%signed_zero = OpConstant %int 0", "%signed_zero = OpConstant %int 0",
"%unsigned_zero = OpConstant %uint 0", "%unsigned_zero = OpConstant %uint 0",
"%long_zero = OpConstant %long 0",
"%ulong_zero = OpConstant %ulong 0",
"%signed_one = OpConstant %int 1", "%signed_one = OpConstant %int 1",
"%unsigned_one = OpConstant %uint 1", "%unsigned_one = OpConstant %uint 1",
"%signed_two = OpConstant %int 2", "%signed_two = OpConstant %int 2",
@ -153,6 +161,7 @@ std::vector<std::string> CommonTypesAndConstants() {
"%unsigned_three = OpConstant %uint 3", "%unsigned_three = OpConstant %uint 3",
"%signed_null = OpConstantNull %int", "%signed_null = OpConstantNull %int",
"%unsigned_null = OpConstantNull %uint", "%unsigned_null = OpConstantNull %uint",
"%signed_minus_one = OpConstant %int -1",
// vector constants: // vector constants:
"%bool_true_vec = OpConstantComposite %v2bool %bool_true %bool_true", "%bool_true_vec = OpConstantComposite %v2bool %bool_true %bool_true",
"%bool_false_vec = OpConstantComposite %v2bool %bool_false %bool_false", "%bool_false_vec = OpConstantComposite %v2bool %bool_false %bool_false",
@ -167,6 +176,7 @@ std::vector<std::string> CommonTypesAndConstants() {
"%unsigned_three_vec = OpConstantComposite %v2uint %unsigned_three %unsigned_three", "%unsigned_three_vec = OpConstantComposite %v2uint %unsigned_three %unsigned_three",
"%signed_null_vec = OpConstantNull %v2int", "%signed_null_vec = OpConstantNull %v2int",
"%unsigned_null_vec = OpConstantNull %v2uint", "%unsigned_null_vec = OpConstantNull %v2uint",
"%signed_minus_one_vec = OpConstantComposite %v2int %signed_minus_one %signed_minus_one",
"%v4int_0_1_2_3 = OpConstantComposite %v4int %signed_zero %signed_one %signed_two %signed_three", "%v4int_0_1_2_3 = OpConstantComposite %v4int %signed_zero %signed_one %signed_two %signed_three",
// clang-format on // clang-format on
}; };
@ -345,9 +355,9 @@ INSTANTIATE_TEST_SUITE_P(
// Tests for operations that resulting in different types. // Tests for operations that resulting in different types.
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
Cast, FoldSpecConstantOpAndCompositePassTest, Cast, FoldSpecConstantOpAndCompositePassTest,
::testing::ValuesIn( ::testing::ValuesIn(std::vector<
std::vector<FoldSpecConstantOpAndCompositePassTestCase>({ FoldSpecConstantOpAndCompositePassTestCase>({
// clang-format off // clang-format off
// int -> bool scalar // int -> bool scalar
{ {
// original // original
@ -575,8 +585,108 @@ INSTANTIATE_TEST_SUITE_P(
"%spec_uint_from_null = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero", "%spec_uint_from_null = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero",
}, },
}, },
// clang-format on
}))); // UConvert scalar
{
// original
{
"%spec_uint_zero = OpSpecConstantOp %uint UConvert %bool_false",
"%spec_uint_one = OpSpecConstantOp %uint UConvert %bool_true",
"%spec_ulong_zero = OpSpecConstantOp %ulong UConvert %unsigned_zero",
"%spec_ulong_one = OpSpecConstantOp %ulong UConvert %unsigned_one",
"%spec_short_zero = OpSpecConstantOp %ushort UConvert %unsigned_zero",
"%spec_short_one = OpSpecConstantOp %ushort UConvert %unsigned_one",
"%uint_max = OpConstant %uint 4294967295",
"%spec_ushort_max = OpSpecConstantOp %ushort UConvert %uint_max",
"%uint_0xDDDDDDDD = OpConstant %uint 3722304989",
"%spec_ushort_0xDDDD = OpSpecConstantOp %ushort UConvert %uint_0xDDDDDDDD",
},
// expected
{
"%spec_uint_zero = OpConstant %uint 0",
"%spec_uint_one = OpConstant %uint 1",
"%spec_ulong_zero = OpConstant %ulong 0",
"%spec_ulong_one = OpConstant %ulong 1",
"%spec_short_zero = OpConstant %ushort 0",
"%spec_short_one = OpConstant %ushort 1",
"%uint_max = OpConstant %uint 4294967295",
"%spec_ushort_max = OpConstant %ushort 65535",
"%uint_0xDDDDDDDD = OpConstant %uint 3722304989",
"%spec_ushort_0xDDDD = OpConstant %ushort 56797",
},
},
// SConvert scalar
{
// original
{
"%spec_long_zero = OpSpecConstantOp %long SConvert %signed_zero",
"%spec_long_one = OpSpecConstantOp %long SConvert %signed_one",
"%spec_long_minus_one = OpSpecConstantOp %long SConvert %signed_minus_one",
"%spec_short_minus_one_trunc = OpSpecConstantOp %short SConvert %signed_minus_one",
"%int_2_to_17_minus_one = OpConstant %int 131071",
"%spec_short_minus_one_trunc2 = OpSpecConstantOp %short SConvert %int_2_to_17_minus_one",
},
// expected
{
"%spec_long_zero = OpConstant %long 0",
"%spec_long_one = OpConstant %long 1",
"%spec_long_minus_one = OpConstant %long -1",
"%spec_short_minus_one_trunc = OpConstant %short -1",
"%int_2_to_17_minus_one = OpConstant %int 131071",
"%spec_short_minus_one_trunc2 = OpConstant %short -1",
},
},
// UConvert vector
{
// original
{
"%spec_v2uint_zero = OpSpecConstantOp %v2uint UConvert %bool_false_vec",
"%spec_v2uint_one = OpSpecConstantOp %v2uint UConvert %bool_true_vec",
"%spec_v2ulong_zero = OpSpecConstantOp %v2ulong UConvert %unsigned_zero_vec",
"%spec_v2ulong_one = OpSpecConstantOp %v2ulong UConvert %unsigned_one_vec",
},
// expected
{
"%uint_0 = OpConstant %uint 0",
"%uint_0_0 = OpConstant %uint 0",
"%spec_v2uint_zero = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero",
"%uint_1 = OpConstant %uint 1",
"%uint_1_0 = OpConstant %uint 1",
"%spec_v2uint_one = OpConstantComposite %v2uint %unsigned_one %unsigned_one",
"%ulong_0 = OpConstant %ulong 0",
"%ulong_0_0 = OpConstant %ulong 0",
"%spec_v2ulong_zero = OpConstantComposite %v2ulong %ulong_zero %ulong_zero",
"%ulong_1 = OpConstant %ulong 1",
"%ulong_1_0 = OpConstant %ulong 1",
"%spec_v2ulong_one = OpConstantComposite %v2ulong %ulong_1 %ulong_1",
},
},
// SConvert vector
{
// original
{
"%spec_v2long_zero = OpSpecConstantOp %v2long SConvert %signed_zero_vec",
"%spec_v2long_one = OpSpecConstantOp %v2long SConvert %signed_one_vec",
"%spec_v2long_minus_one = OpSpecConstantOp %v2long SConvert %signed_minus_one_vec",
},
// expected
{
"%long_0 = OpConstant %long 0",
"%long_0_0 = OpConstant %long 0",
"%spec_v2long_zero = OpConstantComposite %v2long %long_zero %long_zero",
"%long_1 = OpConstant %long 1",
"%long_1_0 = OpConstant %long 1",
"%spec_v2long_one = OpConstantComposite %v2long %long_1 %long_1",
"%long_n1 = OpConstant %long -1",
"%long_n1_0 = OpConstant %long -1",
"%spec_v2long_minus_one = OpConstantComposite %v2long %long_n1 %long_n1",
},
},
// clang-format on
})));
// Tests about boolean scalar logical operations and comparison operations with // Tests about boolean scalar logical operations and comparison operations with
// scalar int/uint type. // scalar int/uint type.
@ -851,7 +961,7 @@ INSTANTIATE_TEST_SUITE_P(
{ {
"%int_n1 = OpConstant %int -1", "%int_n1 = OpConstant %int -1",
"%int_n1_0 = OpConstant %int -1", "%int_n1_0 = OpConstant %int -1",
"%v2int_minus_1 = OpConstantComposite %v2int %int_n1 %int_n1", "%v2int_minus_1 = OpConstantComposite %v2int %signed_minus_one %signed_minus_one",
"%int_n2 = OpConstant %int -2", "%int_n2 = OpConstant %int -2",
"%int_n2_0 = OpConstant %int -2", "%int_n2_0 = OpConstant %int -2",
"%v2int_minus_2 = OpConstantComposite %v2int %int_n2 %int_n2", "%v2int_minus_2 = OpConstantComposite %v2int %int_n2 %int_n2",
@ -956,13 +1066,13 @@ INSTANTIATE_TEST_SUITE_P(
"%7_srem_3 = OpConstantComposite %v2int %signed_one %signed_one", "%7_srem_3 = OpConstantComposite %v2int %signed_one %signed_one",
"%int_n1 = OpConstant %int -1", "%int_n1 = OpConstant %int -1",
"%int_n1_0 = OpConstant %int -1", "%int_n1_0 = OpConstant %int -1",
"%minus_7_srem_3 = OpConstantComposite %v2int %int_n1 %int_n1", "%minus_7_srem_3 = OpConstantComposite %v2int %signed_minus_one %signed_minus_one",
"%int_1_1 = OpConstant %int 1", "%int_1_1 = OpConstant %int 1",
"%int_1_2 = OpConstant %int 1", "%int_1_2 = OpConstant %int 1",
"%7_srem_minus_3 = OpConstantComposite %v2int %signed_one %signed_one", "%7_srem_minus_3 = OpConstantComposite %v2int %signed_one %signed_one",
"%int_n1_1 = OpConstant %int -1", "%int_n1_1 = OpConstant %int -1",
"%int_n1_2 = OpConstant %int -1", "%int_n1_2 = OpConstant %int -1",
"%minus_7_srem_minus_3 = OpConstantComposite %v2int %int_n1 %int_n1", "%minus_7_srem_minus_3 = OpConstantComposite %v2int %signed_minus_one %signed_minus_one",
// smod // smod
"%int_1_3 = OpConstant %int 1", "%int_1_3 = OpConstant %int 1",
"%int_1_4 = OpConstant %int 1", "%int_1_4 = OpConstant %int 1",
@ -975,7 +1085,7 @@ INSTANTIATE_TEST_SUITE_P(
"%7_smod_minus_3 = OpConstantComposite %v2int %int_n2 %int_n2", "%7_smod_minus_3 = OpConstantComposite %v2int %int_n2 %int_n2",
"%int_n1_3 = OpConstant %int -1", "%int_n1_3 = OpConstant %int -1",
"%int_n1_4 = OpConstant %int -1", "%int_n1_4 = OpConstant %int -1",
"%minus_7_smod_minus_3 = OpConstantComposite %v2int %int_n1 %int_n1", "%minus_7_smod_minus_3 = OpConstantComposite %v2int %signed_minus_one %signed_minus_one",
// umod // umod
"%uint_1 = OpConstant %uint 1", "%uint_1 = OpConstant %uint 1",
"%uint_1_0 = OpConstant %uint 1", "%uint_1_0 = OpConstant %uint 1",
@ -1018,7 +1128,7 @@ INSTANTIATE_TEST_SUITE_P(
"%unsigned_right_shift_logical = OpConstantComposite %v2uint %unsigned_one %unsigned_one", "%unsigned_right_shift_logical = OpConstantComposite %v2uint %unsigned_one %unsigned_one",
"%int_n1 = OpConstant %int -1", "%int_n1 = OpConstant %int -1",
"%int_n1_0 = OpConstant %int -1", "%int_n1_0 = OpConstant %int -1",
"%signed_right_shift_arithmetic = OpConstantComposite %v2int %int_n1 %int_n1", "%signed_right_shift_arithmetic = OpConstantComposite %v2int %signed_minus_one %signed_minus_one",
}, },
}, },
// Skip folding if any vector operands or components of the operands // Skip folding if any vector operands or components of the operands

View File

@ -15,6 +15,7 @@
add_spvtools_unittest(TARGET utils add_spvtools_unittest(TARGET utils
SRCS ilist_test.cpp SRCS ilist_test.cpp
bit_vector_test.cpp bit_vector_test.cpp
bitutils_test.cpp
small_vector_test.cpp small_vector_test.cpp
LIBS SPIRV-Tools-opt LIBS SPIRV-Tools-opt
) )

193
test/util/bitutils_test.cpp Normal file
View File

@ -0,0 +1,193 @@
// Copyright (c) 2019 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/util/bitutils.h"
#include "gmock/gmock.h"
namespace spvtools {
namespace utils {
namespace {
using BitUtilsTest = ::testing::Test;
TEST(BitUtilsTest, MutateBitsWholeWord) {
const uint32_t zero_u32 = 0;
const uint32_t max_u32 = ~0;
EXPECT_EQ(MutateBits(zero_u32, 0, 0, false), zero_u32);
EXPECT_EQ(MutateBits(max_u32, 0, 0, false), max_u32);
EXPECT_EQ(MutateBits(zero_u32, 0, 32, false), zero_u32);
EXPECT_EQ(MutateBits(zero_u32, 0, 32, true), max_u32);
EXPECT_EQ(MutateBits(max_u32, 0, 32, true), max_u32);
EXPECT_EQ(MutateBits(max_u32, 0, 32, false), zero_u32);
}
TEST(BitUtilsTest, MutateBitsLow) {
const uint32_t zero_u32 = 0;
const uint32_t one_u32 = 1;
const uint32_t max_u32 = ~0;
EXPECT_EQ(MutateBits(zero_u32, 0, 1, false), zero_u32);
EXPECT_EQ(MutateBits(zero_u32, 0, 1, true), one_u32);
EXPECT_EQ(MutateBits(max_u32, 0, 1, true), max_u32);
EXPECT_EQ(MutateBits(one_u32, 0, 32, false), zero_u32);
EXPECT_EQ(MutateBits(one_u32, 0, 1, true), one_u32);
EXPECT_EQ(MutateBits(one_u32, 0, 1, false), zero_u32);
EXPECT_EQ(MutateBits(zero_u32, 0, 3, true), uint32_t(7));
EXPECT_EQ(MutateBits(uint32_t(7), 0, 2, false), uint32_t(4));
}
TEST(BitUtilsTest, MutateBitsHigh) {
const uint8_t zero_u8 = 0;
const uint8_t one_u8 = 1;
const uint8_t max_u8 = 255;
EXPECT_EQ(MutateBits(zero_u8, 7, 0, true), zero_u8);
EXPECT_EQ(MutateBits(zero_u8, 7, 1, true), uint8_t(128));
EXPECT_EQ(MutateBits(one_u8, 7, 1, true), uint8_t(129));
EXPECT_EQ(MutateBits(max_u8, 7, 1, true), max_u8);
EXPECT_EQ(MutateBits(max_u8, 7, 1, false), uint8_t(127));
EXPECT_EQ(MutateBits(max_u8, 6, 2, true), max_u8);
EXPECT_EQ(MutateBits(max_u8, 6, 2, false), uint8_t(63));
}
TEST(BitUtilsTest, MutateBitsUint8Mid) {
const uint8_t zero_u8 = 0;
const uint8_t max_u8 = 255;
EXPECT_EQ(MutateBits(zero_u8, 1, 2, true), uint8_t(6));
EXPECT_EQ(MutateBits(max_u8, 1, 2, true), max_u8);
EXPECT_EQ(MutateBits(max_u8, 1, 2, false), uint8_t(0xF9));
EXPECT_EQ(MutateBits(zero_u8, 2, 3, true), uint8_t(0x1C));
}
TEST(BitUtilsTest, MutateBitsUint64Mid) {
const uint64_t zero_u64 = 0;
const uint64_t max_u64 = ~zero_u64;
EXPECT_EQ(MutateBits(zero_u64, 1, 2, true), uint64_t(6));
EXPECT_EQ(MutateBits(max_u64, 1, 2, true), max_u64);
EXPECT_EQ(MutateBits(max_u64, 1, 2, false), uint64_t(0xFFFFFFFFFFFFFFF9));
EXPECT_EQ(MutateBits(zero_u64, 2, 3, true), uint64_t(0x000000000000001C));
EXPECT_EQ(MutateBits(zero_u64, 2, 35, true), uint64_t(0x0000001FFFFFFFFC));
EXPECT_EQ(MutateBits(zero_u64, 36, 4, true), uint64_t(0x000000F000000000));
EXPECT_EQ(MutateBits(max_u64, 36, 4, false), uint64_t(0xFFFFFF0FFFFFFFFF));
}
TEST(BitUtilsTest, SetHighBitsUint32) {
const uint32_t zero_u32 = 0;
const uint32_t one_u32 = 1;
const uint32_t max_u32 = ~zero_u32;
EXPECT_EQ(SetHighBits(zero_u32, 0), zero_u32);
EXPECT_EQ(SetHighBits(zero_u32, 1), 0x80000000);
EXPECT_EQ(SetHighBits(one_u32, 1), 0x80000001);
EXPECT_EQ(SetHighBits(one_u32, 2), 0xC0000001);
EXPECT_EQ(SetHighBits(zero_u32, 31), 0xFFFFFFFE);
EXPECT_EQ(SetHighBits(zero_u32, 32), max_u32);
EXPECT_EQ(SetHighBits(max_u32, 32), max_u32);
}
TEST(BitUtilsTest, ClearHighBitsUint32) {
const uint32_t zero_u32 = 0;
const uint32_t one_u32 = 1;
const uint32_t max_u32 = ~zero_u32;
EXPECT_EQ(ClearHighBits(zero_u32, 0), zero_u32);
EXPECT_EQ(ClearHighBits(zero_u32, 1), zero_u32);
EXPECT_EQ(ClearHighBits(one_u32, 1), one_u32);
EXPECT_EQ(ClearHighBits(one_u32, 31), one_u32);
EXPECT_EQ(ClearHighBits(one_u32, 32), zero_u32);
EXPECT_EQ(ClearHighBits(max_u32, 0), max_u32);
EXPECT_EQ(ClearHighBits(max_u32, 1), 0x7FFFFFFF);
EXPECT_EQ(ClearHighBits(max_u32, 2), 0x3FFFFFFF);
EXPECT_EQ(ClearHighBits(max_u32, 31), one_u32);
EXPECT_EQ(ClearHighBits(max_u32, 32), zero_u32);
}
TEST(BitUtilsTest, IsBitSetAtPositionZero) {
const uint32_t zero_u32 = 0;
for (size_t i = 0; i != 32; ++i) {
EXPECT_FALSE(IsBitAtPositionSet(zero_u32, i));
}
const uint8_t zero_u8 = 0;
for (size_t i = 0; i != 8; ++i) {
EXPECT_FALSE(IsBitAtPositionSet(zero_u8, i));
}
const uint64_t zero_u64 = 0;
for (size_t i = 0; i != 64; ++i) {
EXPECT_FALSE(IsBitAtPositionSet(zero_u64, i));
}
}
TEST(BitUtilsTest, IsBitSetAtPositionOne) {
const uint32_t one_u32 = 1;
for (size_t i = 0; i != 32; ++i) {
if (i == 0) {
EXPECT_TRUE(IsBitAtPositionSet(one_u32, i));
} else {
EXPECT_FALSE(IsBitAtPositionSet(one_u32, i));
}
}
const uint32_t two_to_17_u32 = 1 << 17;
for (size_t i = 0; i != 32; ++i) {
if (i == 17) {
EXPECT_TRUE(IsBitAtPositionSet(two_to_17_u32, i));
} else {
EXPECT_FALSE(IsBitAtPositionSet(two_to_17_u32, i));
}
}
const uint8_t two_to_4_u8 = 1 << 4;
for (size_t i = 0; i != 8; ++i) {
if (i == 4) {
EXPECT_TRUE(IsBitAtPositionSet(two_to_4_u8, i));
} else {
EXPECT_FALSE(IsBitAtPositionSet(two_to_4_u8, i));
}
}
const uint64_t two_to_55_u64 = uint64_t(1) << 55;
for (size_t i = 0; i != 64; ++i) {
if (i == 55) {
EXPECT_TRUE(IsBitAtPositionSet(two_to_55_u64, i));
} else {
EXPECT_FALSE(IsBitAtPositionSet(two_to_55_u64, i));
}
}
}
TEST(BitUtilsTest, IsBitSetAtPositionAll) {
const uint32_t max_u32 = ~0;
for (size_t i = 0; i != 32; ++i) {
EXPECT_TRUE(IsBitAtPositionSet(max_u32, i));
}
const uint32_t max_u8 = ~uint8_t(0);
for (size_t i = 0; i != 8; ++i) {
EXPECT_TRUE(IsBitAtPositionSet(max_u8, i));
}
const uint64_t max_u64 = ~uint64_t(0);
for (size_t i = 0; i != 64; ++i) {
EXPECT_TRUE(IsBitAtPositionSet(max_u64, i));
}
}
} // namespace
} // namespace utils
} // namespace spvtools