Added Float16, and HexFloat conversions

This commit is contained in:
Andrew Woloszyn 2015-12-03 16:30:21 -05:00
parent 4b6a98fe16
commit 4e5bc928c0
2 changed files with 861 additions and 11 deletions

View File

@ -39,19 +39,45 @@
namespace spvutils {
// Storage wrapper for a 16-bit floating point value. Only the raw 16-bit
// encoding is kept; no arithmetic is provided.
class Float16 {
 public:
  // Wraps the raw encoding |v|. The constructor is not explicit, so a
  // uint16_t converts to Float16 implicitly.
  Float16(uint16_t v) : val(v) {}
  Float16() = default;
  // Defaulted rather than hand-written: a user-provided copy constructor
  // makes the type non-trivially-copyable, which is undesirable for a value
  // that is moved around by bitwise casts, and also suppresses nothing we
  // need. The defaulted copy is member-wise and therefore equivalent.
  Float16(const Float16& other) = default;

  // Returns true if |val| encodes a NaN: every exponent bit (0x7C00) is set
  // and at least one fraction bit (0x3FF) is set.
  static bool isNan(const Float16 val) {
    return ((val.val & 0x7C00) == 0x7C00) && ((val.val & 0x3FF) != 0);
  }

  // Returns the raw 16-bit encoding.
  uint16_t get_value() const { return val; }

 private:
  uint16_t val;
};
// To specialize this type, you must override uint_type to define
// an unsigned integer that can fit your floating point type.
// You must also add a isNan function that returns true if
// a value is Nan.
// Primary template: an unspecialized type has no usable integer type.
// Each struct previously declared uint_type twice (typedef and alias), which
// is an ill-formed member redeclaration; only the alias form is kept.
template <typename T>
struct FloatProxyTraits {
  using uint_type = void;
};

// 32-bit float: stored in a uint32_t, NaN detection via std::isnan.
template <>
struct FloatProxyTraits<float> {
  using uint_type = uint32_t;
  static bool isNan(float f) { return std::isnan(f); }
};

// 64-bit double: stored in a uint64_t, NaN detection via std::isnan.
template <>
struct FloatProxyTraits<double> {
  using uint_type = uint64_t;
  static bool isNan(double f) { return std::isnan(f); }
};
// FloatProxyTraits specialization for Float16: the value is carried in a
// uint16_t and NaN detection is delegated to Float16::isNan.
template <>
struct FloatProxyTraits<Float16> {
using uint_type = uint16_t;
static bool isNan(Float16 f) { return Float16::isNan(f); }
};
// Since copying a floating point number (especially if it is NaN)
@ -86,7 +112,7 @@ class FloatProxy {
// Returns a copy of the stored data_ member.
uint_type data() const { return data_; }
// Returns true if the value represents any type of NaN.
// The check is routed through FloatProxyTraits<T> so that value types
// handled by a traits specialization (such as Float16) work as well as
// float and double. The older std::isnan-based definition was a duplicate
// of this member and has been removed.
bool isNan() { return FloatProxyTraits<T>::isNan(getAsFloat()); }
private:
uint_type data_;
@ -111,9 +137,13 @@ std::istream& operator>>(std::istream& is, FloatProxy<T>& value) {
template <typename T>
struct HexFloatTraits {
// Integer type that can store this hex-float.
typedef void uint_type;
using uint_type = void;
// Signed integer type that can store this hex-float.
typedef void int_type;
using int_type = void;
// The numerical type that this HexFloat represents.
using underlying_type = void;
// The type needed to construct the underlying type.
using native_type = void;
// The number of bits that are actually relevant in the uint_type.
// This allows us to deal with, for example, 24-bit values in a 32-bit
// integer.
@ -131,8 +161,10 @@ struct HexFloatTraits {
// 1 sign bit, 8 exponent bits, 23 fractional bits.
template <>
struct HexFloatTraits<FloatProxy<float>> {
typedef uint32_t uint_type;
typedef int32_t int_type;
using uint_type = uint32_t;
using int_type = int32_t;
using underlying_type = FloatProxy<float>;
using native_type = float;
static const uint_type num_used_bits = 32;
static const uint_type num_exponent_bits = 8;
static const uint_type num_fraction_bits = 23;
@ -143,14 +175,38 @@ struct HexFloatTraits<FloatProxy<float>> {
// 1 sign bit, 11 exponent bits, 52 fractional bits.
// HexFloatTraits for a 64-bit double held in a FloatProxy.
// The integer aliases were previously declared twice (typedef and alias),
// which is an ill-formed member redeclaration; only the alias form is kept.
template <>
struct HexFloatTraits<FloatProxy<double>> {
  // Unsigned / signed integers wide enough to hold the encoded value.
  using uint_type = uint64_t;
  using int_type = int64_t;
  // The type stored by HexFloat, and the native type used to construct it.
  using underlying_type = FloatProxy<double>;
  using native_type = double;
  // 1 sign bit + 11 exponent bits + 52 fraction bits = 64 used bits.
  static const uint_type num_used_bits = 64;
  static const uint_type num_exponent_bits = 11;
  static const uint_type num_fraction_bits = 52;
  static const uint_type exponent_bias = 1023;
};
// Traits for IEEE half.
// 1 sign bit, 5 exponent bits, 10 fractional bits.
template <>
struct HexFloatTraits<FloatProxy<Float16>> {
// Unsigned / signed 16-bit integers used to hold the encoded value.
using uint_type = uint16_t;
using int_type = int16_t;
// Both of these are the raw uint16_t here, whereas the float and double
// traits use FloatProxy<...> and the native floating point type.
using underlying_type = uint16_t;
using native_type = uint16_t;
// 1 sign bit + 5 exponent bits + 10 fraction bits = 16 used bits.
static const uint_type num_used_bits = 16;
static const uint_type num_exponent_bits = 5;
static const uint_type num_fraction_bits = 10;
static const uint_type exponent_bias = 15;
};
// Rounding mode used when a significand has to be narrowed (see
// HexFloat::getRoundedNormalizedSignificand and HexFloat::castTo).
enum class round_direction {
// Never round away from zero (truncate).
kToZero,
// Round to the nearest value; an exact tie goes to the even significand.
kToNearestEven,
// Round away from zero only for non-negative values.
kToPositiveInfinity,
// Round away from zero only for negative values.
kToNegativeInfinity,
// Alias of the last enumerator.
max = kToNegativeInfinity
};
// Template class that houses a floating pointer number.
// It exposes a number of constants based on the provided traits to
// assist in interpreting the bits of the value.
@ -159,6 +215,8 @@ class HexFloat {
public:
// Type aliases re-exported from Traits so that other HexFloat
// instantiations can name them (e.g. other_T::uint_type in castTo).
using uint_type = typename Traits::uint_type;
using int_type = typename Traits::int_type;
using underlying_type = typename Traits::underlying_type;
using native_type = typename Traits::native_type;
// Wraps the value f. Explicit, so a T never converts to a HexFloat
// implicitly.
explicit HexFloat(T f) : value_(f) {}
@ -190,10 +248,15 @@ class HexFloat {
spvutils::SetBits<uint_type, 0,
num_fraction_bits + num_overflow_bits>::get;
// The topmost bit in the nibble-aligned fraction, i.e. bit
// (num_fraction_bits + num_overflow_bits - 1).
static const uint_type fraction_top_bit =
uint_type(1) << (num_fraction_bits + num_overflow_bits - 1);
// The least significant bit in the exponent, which is also the bit
// immediately to the left of the significand. For a normalized number this
// is the position the implicit leading 1 would occupy.
static const uint_type first_exponent_bit = uint_type(1)
<< (num_fraction_bits);
// The mask for the encoded fraction. It does not include the
// implicit bit.
static const uint_type fraction_encode_mask =
@ -213,12 +276,334 @@ class HexFloat {
// Number of bits in uint_type that are not fraction bits.
static const uint32_t fraction_right_shift =
(sizeof(uint_type) * 8) - num_fraction_bits;
// The unbiased value of the largest exponent field that can be encoded
// (exponent_mask shifted down, minus the bias).
static const int_type max_exponent =
(exponent_mask >> num_fraction_bits) - exponent_bias;
// The unbiased value of an all-zero exponent field (-exponent_bias).
// NOTE: despite the name this is not the smallest exponent of a normalized
// number; the accessors below treat a value whose unbiased exponent equals
// min_exponent as denormal.
static const int_type min_exponent = -static_cast<int_type>(exponent_bias);
// Returns the raw bit pattern of the stored value as a uint_type.
uint_type getBits() const {
  const uint_type raw_bits = spvutils::BitwiseCast<uint_type>(value_);
  return raw_bits;
}
// Returns the bit pattern of the stored value with the sign bit cleared.
uint_type getUnsignedBits() const {
  const uint_type magnitude_mask = static_cast<uint_type>(~sign_mask);
  const uint_type raw_bits = spvutils::BitwiseCast<uint_type>(value_);
  return raw_bits & magnitude_mask;
}
// Returns the (still biased) exponent field, shifted down so that it starts
// at the least significant bit of uint_type.
const uint_type getExponentBits() const {
  const uint_type exponent_field =
      static_cast<uint_type>(getBits() & exponent_mask);
  return exponent_field >> num_fraction_bits;
}
// Returns the exponent with the bias removed, i.e. the human-friendly form.
// No normalization is applied, so denormals report min_exponent.
const int_type getUnbiasedExponent() const {
  const int_type biased_exponent = static_cast<int_type>(getExponentBits());
  return (biased_exponent - exponent_bias);
}
// Returns only the encoded fraction bits of the value (sign and exponent
// masked off; the implicit bit is not included).
const uint_type getSignificandBits() const {
  const uint_type all_bits = getBits();
  return all_bits & fraction_encode_mask;
}
// If the number is normalized, returns the unbiased exponent.
// If the number is denormal, returns the exponent it would have once the
// significand is shifted left until its top fraction bit is set.
// Returns 0 for positive and negative zero.
const int_type getUnbiasedNormalizedExponent() const {
  if (getUnsignedBits() == 0) {  // special case if everything is 0
    return 0;
  }
  int_type exp = getUnbiasedExponent();
  if (exp == min_exponent) {  // We are in denorm land.
    // The significand is non-zero here (the all-zero case returned above),
    // so this loop terminates. Each shift lowers the exponent by one.
    uint_type significand_bits = getSignificandBits();
    while ((significand_bits & (first_exponent_bit >> 1)) == 0) {
      significand_bits <<= 1;
      exp -= 1;
    }
    // The shifted significand itself is not needed; the former trailing
    // mask of significand_bits was a dead store and has been dropped.
  }
  return exp;
}
// Returns the significand after it has been normalized. For a denormal the
// fraction is shifted left once for every step between its normalized
// exponent and min_exponent (inclusive), which pushes the leading set bit
// out of the fraction field before the result is masked.
const uint_type getNormalizedSignificand() const {
  int_type remaining = getUnbiasedNormalizedExponent();
  uint_type normalized = getSignificandBits();
  while (remaining <= min_exponent) {
    normalized <<= 1;
    ++remaining;
  }
  normalized &= fraction_encode_mask;
  return normalized;
}
// Returns true if this number represents a negative value.
bool isNegative() const { return (getBits() & sign_mask) != 0; }
// Sets this HexFloat from the individual components.
// Note this assumes EVERY significand is normalized, and has an implicit
// leading one. This means that the only way that this method will set 0,
// is if you set a number so denormalized that it underflows.
// Do not use this method with raw bits extracted from a subnormal number,
// since subnormals do not have an implicit leading 1 in the significand.
// The significand is also expected to be in the
// lowest-most num_fraction_bits of the uint_type.
// The exponent is expected to be unbiased, meaning an exponent of
// 0 actually means 0.
// If round_denorm_up is set, then on underflow, if a number is non-0
// and would underflow to zero, we round up to the smallest denorm.
// Bits shifted out while denormalizing are otherwise discarded.
void setFromSignUnbiasedExponentAndNormalizedSignificand(
bool negative, int_type exponent, uint_type significand,
bool round_denorm_up) {
bool significand_is_zero = significand == 0;
if (exponent <= min_exponent) {
// If this was denormalized, then we have to shift the bit on, meaning
// the significand is not zero.
significand_is_zero = false;
significand |= first_exponent_bit;
significand >>= 1;
}
// Shift once more for every step the exponent is below min_exponent.
while (exponent < min_exponent) {
significand >>= 1;
++exponent;
}
if (exponent == min_exponent) {
// Everything was shifted out of a value that was not zero to begin with.
if (significand == 0 && !significand_is_zero && round_denorm_up) {
significand = 0x1;
}
}
uint_type value = 0;
if (negative) {
value |= sign_mask;
}
// Re-apply the bias; callers must not pass an exponent below the point
// where the loop above clamps it.
exponent += exponent_bias;
assert(exponent >= 0);
// put it all together
exponent = (exponent << exponent_left_shift) & exponent_mask;
significand &= fraction_encode_mask;
value |= exponent | significand;
value_ = BitwiseCast<T>(value);
}
// Adds to_increment to significand and reports overflow through *carry.
// When the sum reaches the implicit-bit position (first_exponent_bit),
// *carry is set to true, that bit is cleared and the result is shifted
// right by one so it fits the fraction field again. Otherwise *carry is
// set to false and the plain sum is returned.
// Both inputs are assumed to be within the bounds of a valid significand.
static uint_type incrementSignificand(uint_type significand,
                                      uint_type to_increment, bool* carry) {
  uint_type sum = significand;
  sum += to_increment;
  const bool spilled = (sum & first_exponent_bit) != 0;
  *carry = spilled;
  if (!spilled) {
    return sum;
  }
  // The implicit 1-bit carried: drop it and realign the fraction.
  sum &= ~first_exponent_bit;
  sum >>= 1;
  return sum;
}
// These exist because MSVC throws warnings on negative right-shifts
// even if they are not going to be executed. Eg:
// constant_number < 0? 0: constant_number
// These convert the negative left-shifts into right shifts.
// The shift amount N is a compile-time constant; the partial
// specialization is selected when N >= 0, the primary template otherwise.

// Left shift by N; this primary template handles N < 0 by shifting right
// by -N instead.
template <int_type N, typename enable = void>
struct negatable_left_shift {
static uint_type val(uint_type val) { return val >> -N; }
};
// Left shift by N for N >= 0.
template <int_type N>
struct negatable_left_shift<N, typename std::enable_if<N >= 0>::type> {
static uint_type val(uint_type val) { return val << N; }
};
// Right shift by N; this primary template handles N < 0 by shifting left
// by -N instead.
template <int_type N, typename enable = void>
struct negatable_right_shift {
static uint_type val(uint_type val) { return val << -N; }
};
// Right shift by N for N >= 0.
template <int_type N>
struct negatable_right_shift<N, typename std::enable_if<N >= 0>::type> {
static uint_type val(uint_type val) { return val >> N; }
};
// Returns the significand, rounded to fit in a significand in
// other_T. This is shifted so that the most significant
// bit of the rounded number lines up with the most significant bit
// of the returned significand.
// dir selects the rounding rule; it only matters when other_T has fewer
// fraction bits than this type. *carry_bit is set to true when rounding
// overflowed the fraction (see incrementSignificand) and false otherwise.
template <typename other_T>
typename other_T::uint_type getRoundedNormalizedSignificand(
round_direction dir, bool* carry_bit) {
using other_uint_type = typename other_T::uint_type;
// Number of low fraction bits that do not fit in other_T. Zero or negative
// when other_T has at least as many fraction bits as this type.
static const int_type num_throwaway_bits =
static_cast<int_type>(num_fraction_bits) -
static_cast<int_type>(other_T::num_fraction_bits);
// Lowest bit that survives the narrowing (one unit in the last place of
// the result, expressed in this type's fraction).
static const uint_type last_significant_bit =
(num_throwaway_bits < 0)
? 0
: negatable_left_shift<num_throwaway_bits>::val(1u);
// Highest bit that is thrown away (the "round" bit).
static const uint_type first_rounded_bit =
(num_throwaway_bits < 1)
? 0
: negatable_left_shift<num_throwaway_bits - 1>::val(1u);
static const uint_type throwaway_mask_bits =
num_throwaway_bits > 0 ? num_throwaway_bits : 0;
// Mask covering every bit that is thrown away.
static const uint_type throwaway_mask =
spvutils::SetBits<uint_type, 0, throwaway_mask_bits>::get;
*carry_bit = false;
other_uint_type out_val = 0;
uint_type significand = getNormalizedSignificand();
// If we are up-casting, then we just have to shift to the right location.
if (num_throwaway_bits <= 0) {
out_val = significand;
uint_type shift_amount = -num_throwaway_bits;
out_val <<= shift_amount;
return out_val;
}
// If every non-representable bit is 0, then we don't have any casting to
// do.
if ((significand & throwaway_mask) == 0) {
return static_cast<other_uint_type>(
negatable_right_shift<num_throwaway_bits>::val(significand));
}
bool round_away_from_zero = false;
// We actually have to narrow the significand here, so we have to follow the
// rounding rules.
switch (dir) {
case round_direction::kToZero:
break;
case round_direction::kToPositiveInfinity:
round_away_from_zero = !isNegative();
break;
case round_direction::kToNegativeInfinity:
round_away_from_zero = isNegative();
break;
case round_direction::kToNearestEven:
// Have to round down, round bit is 0
if ((first_rounded_bit & significand) == 0) {
break;
}
if (((significand & throwaway_mask) & ~first_rounded_bit) != 0) {
// If any subsequent bit of the rounded portion is non-0 then we round
// up.
round_away_from_zero = true;
break;
}
// We are exactly half-way between 2 numbers, pick even.
if ((significand & last_significant_bit) != 0) {
// 1 for our last bit, round up.
round_away_from_zero = true;
break;
}
break;
}
if (round_away_from_zero) {
// Add one unit in the last kept place, then drop the thrown-away bits.
return static_cast<other_uint_type>(
negatable_right_shift<num_throwaway_bits>::val(incrementSignificand(
significand, last_significant_bit, carry_bit)));
} else {
return static_cast<other_uint_type>(
negatable_right_shift<num_throwaway_bits>::val(significand));
}
// NOTE(review): both branches above return, so the lines below are
// unreachable.
// We really shouldn't get here.
assert(false && "We should not have ended up here");
return 0;
}
// Casts this value to another HexFloat. If the cast is widening,
// then round_dir is ignored. If the cast is narrowing, then
// the result is rounded in the direction specified.
// This number will retain Nan and Inf values.
// It will also saturate to Inf if the number overflows, and
// underflow to (0 or min depending on rounding) if the number underflows.
template <typename other_T>
void castTo(other_T& other, round_direction round_dir) {
other = other_T(static_cast<typename other_T::native_type>(0));
bool negate = isNegative();
if (getUnsignedBits() == 0) {
if (negate) {
other.set_value(-other.value());
}
return;
}
uint_type significand = getSignificandBits();
bool carried = false;
typename other_T::uint_type rounded_significand =
getRoundedNormalizedSignificand<other_T>(round_dir, &carried);
int_type exponent = getUnbiasedExponent();
if (exponent == min_exponent) {
// If we are denormal, normalize the exponent, so that we can encode
// easily.
exponent += 1;
for (uint_type check_bit = first_exponent_bit >> 1; check_bit != 0;
check_bit >>= 1) {
exponent -= 1;
if (check_bit & significand) break;
}
}
bool is_nan =
(getBits() & exponent_mask) == exponent_mask && significand != 0;
bool is_inf =
!is_nan &&
((exponent + carried) > static_cast<int_type>(other_T::exponent_bias) ||
(significand == 0 && (getBits() & exponent_mask) == exponent_mask));
// If we are Nan or Inf we should pass that through.
if (is_inf) {
other.set_value(BitwiseCast<typename other_T::underlying_type>(
static_cast<typename other_T::uint_type>(
(negate ? other_T::sign_mask : 0) | other_T::exponent_mask)));
return;
}
if (is_nan) {
typename other_T::uint_type shifted_significand;
shifted_significand = static_cast<typename other_T::uint_type>(
negatable_left_shift<other_T::num_fraction_bits -
num_fraction_bits>::val(significand));
// We are some sort of Nan. We try to keep the bit-pattern of the Nan
// as close as possible. If we had to shift off bits so we are 0, then we
// just set the last bit.
other.set_value(BitwiseCast<typename other_T::underlying_type>(
static_cast<typename other_T::uint_type>(
(negate ? other_T::sign_mask : 0) | other_T::exponent_mask |
(shifted_significand == 0 ? 0x1 : shifted_significand))));
return;
}
bool round_underflow_up =
isNegative() ? round_dir == round_direction::kToNegativeInfinity
: round_dir == round_direction::kToPositiveInfinity;
// setFromSignUnbiasedExponentAndNormalizedSignificand will
// zero out any underflowing value (but retain the sign).
other.setFromSignUnbiasedExponentAndNormalizedSignificand(
negate, exponent, rounded_significand, round_underflow_up);
return;
}
private:
// The wrapped value.
T value_;
// Compile-time sanity checks on the traits: sign + exponent + fraction bits
// must add up to the number of used bits, and T must be exactly as large as
// the integer type its bits are cast to.
static_assert(num_used_bits ==
Traits::num_exponent_bits + Traits::num_fraction_bits + 1,
"The number of bits do not fit");
static_assert(sizeof(T) == sizeof(uint_type), "The type sizes do not match");
};
// Returns 4 bits represented by the hex character.

View File

@ -528,5 +528,470 @@ INSTANTIATE_TEST_CASE_P(
})));
// double is used so that unbiased_exponent can be used with the output
// of ldexp directly.
int32_t unbiased_exponent(double f) {
return spvutils::HexFloat<spvutils::FloatProxy<float>>(
static_cast<float>(f)).getUnbiasedNormalizedExponent();
}
int16_t unbiased_half_exponent(uint16_t f) {
return spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>(f)
.getUnbiasedNormalizedExponent();
}
// Checks getUnbiasedNormalizedExponent for 32-bit and 16-bit inputs:
// normalized values, infinity, denormals, and a value too small to
// represent.
TEST(HexFloatOperationTest, UnbiasedExponent) {
// Float cases
EXPECT_EQ(0, unbiased_exponent(ldexp(1.0f, 0)));
EXPECT_EQ(-32, unbiased_exponent(ldexp(1.0f, -32)));
EXPECT_EQ(42, unbiased_exponent(ldexp(1.0f, 42)));
EXPECT_EQ(125, unbiased_exponent(ldexp(1.0f, 125)));
// Saturates to 128
EXPECT_EQ(128, unbiased_exponent(ldexp(1.0f, 256)));
EXPECT_EQ(-100, unbiased_exponent(ldexp(1.0f, -100)));
EXPECT_EQ(-127, unbiased_exponent(ldexp(1.0f, -127))); // First denorm
EXPECT_EQ(-128, unbiased_exponent(ldexp(1.0f, -128)));
EXPECT_EQ(-129, unbiased_exponent(ldexp(1.0f, -129)));
EXPECT_EQ(-140, unbiased_exponent(ldexp(1.0f, -140)));
// Smallest representable number
EXPECT_EQ(-126 - 23, unbiased_exponent(ldexp(1.0f, -126 - 23)));
// Should get rounded to 0 first.
EXPECT_EQ(0, unbiased_exponent(ldexp(1.0f, -127 - 23)));
// Float16 cases
// The exponent is represented in the bits 0x7C00
// The offset is -15
EXPECT_EQ(0, unbiased_half_exponent(0x3C00));
EXPECT_EQ(3, unbiased_half_exponent(0x4800));
EXPECT_EQ(-1, unbiased_half_exponent(0x3800));
EXPECT_EQ(-14, unbiased_half_exponent(0x0400));
// All exponent bits set (infinity).
EXPECT_EQ(16, unbiased_half_exponent(0x7C00));
EXPECT_EQ(10, unbiased_half_exponent(0x6400));
// Smallest representable number
EXPECT_EQ(-24, unbiased_half_exponent(0x0001));
}
// Builds a float equal to the sum of 2^(-fractions[i]) over every entry of
// fractions, accumulated in the order given. An empty list yields 0.
float float_fractions(const std::vector<uint32_t>& fractions) {
  float total = 0;
  for (auto it = fractions.begin(); it != fractions.end(); ++it) {
    // Converted to a signed value first so that it can be negated.
    const int32_t power = *it;
    total += ldexp(1.0f, -power);
  }
  return total;
}
// Returns the normalized significand of a HexFloat<FloatProxy<float>>
// that was created by calling float_fractions with the input fractions,
// raised to the power of exp.
uint32_t normalized_significand(const std::vector<uint32_t>& fractions, uint32_t exp) {
return spvutils::HexFloat<spvutils::FloatProxy<float>>(
static_cast<float>(ldexp(float_fractions(fractions), exp)))
.getNormalizedSignificand();
}
// Builds a 32-bit float significand from a list of bit positions counted
// from the most significant fraction bit downwards: position 0 selects
// 1u << 22, position 1 selects 1u << 21, and so on. An empty list yields 0.
uint32_t bits_set(const std::vector<uint32_t>& bits) {
  const uint32_t msb = 1u << 22u;
  uint32_t mask = 0;
  for (auto it = bits.begin(); it != bits.end(); ++it) {
    mask |= msb >> *it;
  }
  return mask;
}
// Same as bits_set, but for the 10-bit fraction of a Float16: position 0
// selects 1u << 9, position 9 selects the lowest fraction bit.
uint16_t half_bits_set(const std::vector<uint32_t>& bits) {
  const uint32_t msb = 1u << 9u;
  uint32_t mask = 0;
  for (auto it = bits.begin(); it != bits.end(); ++it) {
    mask |= msb >> *it;
  }
  return static_cast<uint16_t>(mask);
}
// Checks getNormalizedSignificand on 32-bit floats, for both normalized and
// denormal inputs. The second argument of normalized_significand is the
// power of two the value is scaled by.
TEST(HexFloatOperationTest, NormalizedSignificand) {
// For normalized numbers (the following) it should be a simple matter
// of getting rid of the top implicit bit
EXPECT_EQ(bits_set({}), normalized_significand({0}, 0));
EXPECT_EQ(bits_set({0}), normalized_significand({0, 1}, 0));
EXPECT_EQ(bits_set({0, 1}), normalized_significand({0, 1, 2}, 0));
EXPECT_EQ(bits_set({1}), normalized_significand({0, 2}, 0));
EXPECT_EQ(bits_set({1}), normalized_significand({0, 2}, 32));
EXPECT_EQ(bits_set({1}), normalized_significand({0, 2}, 126));
// For denormalized numbers we expect the normalized significand to
// shift as if it were normalized. This means, in practice that the
// top_most set bit will be cut off. Looks very similar to above (on purpose)
EXPECT_EQ(bits_set({}), normalized_significand({0}, -127));
EXPECT_EQ(bits_set({3}), normalized_significand({0, 4}, -128));
EXPECT_EQ(bits_set({3}), normalized_significand({0, 4}, -127));
EXPECT_EQ(bits_set({}), normalized_significand({22}, -127));
EXPECT_EQ(bits_set({0}), normalized_significand({21, 22}, -127));
}
// Returns the 32-bit floating point value created by
// calling setFromSignUnbiasedExponentAndNormalizedSignificand
// on a HexFloat<FloatProxy<float>>
float set_from_sign(bool negative, int32_t unbiased_exponent,
uint32_t significand, bool round_denorm_up) {
spvutils::HexFloat<spvutils::FloatProxy<float>> f(0.f);
f.setFromSignUnbiasedExponentAndNormalizedSignificand(
negative, unbiased_exponent, significand, round_denorm_up);
return f.value().getAsFloat();
}
// Checks setFromSignUnbiasedExponentAndNormalizedSignificand for normalized
// values, denormal results, underflow with and without round-up, and the
// negative counterparts.
TEST(HexFloatOperationTests,
SetFromSignUnbiasedExponentAndNormalizedSignificand) {
EXPECT_EQ(1.f, set_from_sign(false, 0, 0, false));
// Tests insertion of various denormalized numbers with and without round up.
EXPECT_EQ(static_cast<float>(ldexp(1.f, -149)), set_from_sign(false, -149, 0, false));
EXPECT_EQ(static_cast<float>(ldexp(1.f, -149)), set_from_sign(false, -149, 0, true));
// One step below the smallest denorm: zero unless rounding up is requested.
EXPECT_EQ(0.f, set_from_sign(false, -150, 1, false));
EXPECT_EQ(static_cast<float>(ldexp(1.f, -149)), set_from_sign(false, -150, 1, true));
EXPECT_EQ(ldexp(1.0f, -127), set_from_sign(false, -127, 0, false));
EXPECT_EQ(ldexp(1.0f, -128), set_from_sign(false, -128, 0, false));
EXPECT_EQ(float_fractions({0, 1, 2, 5}),
set_from_sign(false, 0, bits_set({0, 1, 4}), false));
EXPECT_EQ(ldexp(float_fractions({0, 1, 2, 5}), -32),
set_from_sign(false, -32, bits_set({0, 1, 4}), false));
EXPECT_EQ(ldexp(float_fractions({0, 1, 2, 5}), -128),
set_from_sign(false, -128, bits_set({0, 1, 4}), false));
// The negative cases from above.
EXPECT_EQ(-1.f, set_from_sign(true, 0, 0, false));
EXPECT_EQ(-ldexp(1.0, -127), set_from_sign(true, -127, 0, false));
EXPECT_EQ(-ldexp(1.0, -128), set_from_sign(true, -128, 0, false));
EXPECT_EQ(-float_fractions({0, 1, 2, 5}),
set_from_sign(true, 0, bits_set({0, 1, 4}), false));
EXPECT_EQ(-ldexp(float_fractions({0, 1, 2, 5}), -32),
set_from_sign(true, -32, bits_set({0, 1, 4}), false));
EXPECT_EQ(-ldexp(float_fractions({0, 1, 2, 5}), -128),
set_from_sign(true, -128, bits_set({0, 1, 4}), false));
}
// Checks getRoundedNormalizedSignificand when source and target are the
// same 32-bit type: no bits are thrown away, so every rounding direction
// must return the normalized significand unchanged and never set the carry.
TEST(HexFloatOperationTests, NonRounding) {
// Rounding from 32-bit hex-float to 32-bit hex-float should be trivial,
// except in the denorm case which is a bit more complex.
using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
bool carry_bit = false;
spvutils::round_direction rounding[] = {
spvutils::round_direction::kToZero,
spvutils::round_direction::kToNearestEven,
spvutils::round_direction::kToPositiveInfinity,
spvutils::round_direction::kToNegativeInfinity};
// Everything fits, so this should be straight-forward
for (spvutils::round_direction round : rounding) {
EXPECT_EQ(bits_set({}), HF(0.f).getRoundedNormalizedSignificand<HF>(
round, &carry_bit));
EXPECT_FALSE(carry_bit);
EXPECT_EQ(bits_set({0}),
HF(float_fractions({0, 1}))
.getRoundedNormalizedSignificand<HF>(round, &carry_bit));
EXPECT_FALSE(carry_bit);
EXPECT_EQ(bits_set({1, 3}),
HF(float_fractions({0, 2, 4}))
.getRoundedNormalizedSignificand<HF>(round, &carry_bit));
EXPECT_FALSE(carry_bit);
// Negative denormal input.
EXPECT_EQ(
bits_set({0, 1, 4}),
HF(static_cast<float>(-ldexp(float_fractions({0, 1, 2, 5}), -128)))
.getRoundedNormalizedSignificand<HF>(round, &carry_bit));
EXPECT_FALSE(carry_bit);
// Lowest fraction bit set.
EXPECT_EQ(
bits_set({0, 1, 4, 22}),
HF(static_cast<float>(float_fractions({0, 1, 2, 5, 23})))
.getRoundedNormalizedSignificand<HF>(round, &carry_bit));
EXPECT_FALSE(carry_bit);
}
}
// Shorthand for the rounding-mode enum used by the parameter tables below.
using RD = spvutils::round_direction;
// One case for the float -> half significand rounding test.
struct RoundSignificandCase {
// The 32-bit input value.
float source_float;
// Expected rounded significand (first) and expected carry flag (second).
// NOTE(review): the tables build the first member with half_bits_set,
// which returns uint16_t, so it is converted to int16_t here.
std::pair<int16_t, bool> expected_results;
// Rounding mode to apply.
spvutils::round_direction round;
};
// Parameterized fixture: each case rounds a 32-bit significand down to the
// width of a 16-bit float.
using HexFloatRoundTest =
::testing::TestWithParam<RoundSignificandCase>;
TEST_P(HexFloatRoundTest, RoundDownToFP16) {
using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
using HF16 = spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>;
HF input_value(GetParam().source_float);
bool carry_bit = false;
// Both the rounded significand and the reported carry must match.
EXPECT_EQ(GetParam().expected_results.first,
input_value.getRoundedNormalizedSignificand<HF16>(
GetParam().round, &carry_bit));
EXPECT_EQ(carry_bit, GetParam().expected_results.second);
}
// clang-format off
// Cases for HexFloatRoundTest. Each entry is
// {input float, {expected half significand, expected carry}, rounding mode}.
// In float_fractions, entry k contributes 2^-k; in half_bits_set, index 0 is
// the top fraction bit and index 9 the lowest, so float_fractions entry 10
// is the last bit a half can hold and entry 11 is the first one dropped.
INSTANTIATE_TEST_CASE_P(F32ToF16, HexFloatRoundTest,
::testing::ValuesIn(std::vector<RoundSignificandCase>(
{
{float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToZero},
{float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToNearestEven},
{float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToPositiveInfinity},
{float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToNegativeInfinity},
{float_fractions({0, 1}), std::make_pair(half_bits_set({0}), false), RD::kToZero},
{float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0}), false), RD::kToZero},
{float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity},
{float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity},
{float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0}), false), RD::kToNearestEven},
{float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 9}), false), RD::kToZero},
{float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 8}), false), RD::kToPositiveInfinity},
{float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNegativeInfinity},
{float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 8}), false), RD::kToNearestEven},
{float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), RD::kToZero},
{float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity},
{float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity},
{float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven},
{-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), RD::kToZero},
{-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), RD::kToPositiveInfinity},
{-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNegativeInfinity},
{-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven},
{float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0}), false), RD::kToZero},
{float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity},
{float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity},
{float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven},
// Carries
{float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), false), RD::kToZero},
{float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({}), true), RD::kToPositiveInfinity},
{float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), false), RD::kToNegativeInfinity},
{float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({}), true), RD::kToNearestEven},
// Cases where original number was denorm. Note: this should have no effect
// the number is pre-normalized.
{static_cast<float>(ldexp(float_fractions({0, 1, 11, 13}), -128)), std::make_pair(half_bits_set({0}), false), RD::kToZero},
{static_cast<float>(ldexp(float_fractions({0, 1, 11, 13}), -129)), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity},
{static_cast<float>(ldexp(float_fractions({0, 1, 11, 13}), -131)), std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity},
{static_cast<float>(ldexp(float_fractions({0, 1, 11, 13}), -130)), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven},
})));
// clang-format on
// One case for the half -> float significand widening test.
struct UpCastSignificandCase {
// Raw 16-bit encoding of the input.
uint16_t source_half;
// Expected significand in the 32-bit layout.
uint32_t expected_result;
};
// Parameterized fixture for the widening significand test.
using HexFloatRoundUpSignificandTest =
::testing::TestWithParam<UpCastSignificandCase>;
// Widening a half significand to float must give the same result for every
// rounding direction and must never report a carry.
TEST_P(HexFloatRoundUpSignificandTest, Widening) {
using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
using HF16 = spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>;
bool carry_bit = false;
spvutils::round_direction rounding[] = {
spvutils::round_direction::kToZero,
spvutils::round_direction::kToNearestEven,
spvutils::round_direction::kToPositiveInfinity,
spvutils::round_direction::kToNegativeInfinity};
// Everything fits, so everything should just be bit-shifts.
for (spvutils::round_direction round : rounding) {
carry_bit = false;
HF16 input_value(GetParam().source_half);
// The conversion is evaluated a second time below purely to print the
// actual value in hex as part of the failure message.
EXPECT_EQ(
GetParam().expected_result,
input_value.getRoundedNormalizedSignificand<HF>(round, &carry_bit))
<< std::hex << "0x"
<< input_value.getRoundedNormalizedSignificand<HF>(round, &carry_bit)
<< " 0x" << GetParam().expected_result;
EXPECT_FALSE(carry_bit);
}
}
// Cases for HexFloatRoundUpSignificandTest: {raw half, expected 32-bit
// significand}. The expected value is the low 10 bits of the half shifted
// left by 13.
INSTANTIATE_TEST_CASE_P(F16toF32, HexFloatRoundUpSignificandTest,
// The bits in 0xFC00 of the 16-bit source value cover the sign and the
// exponent. They are ignored for this test.
::testing::ValuesIn(std::vector<UpCastSignificandCase>(
{
{0x3F00, 0x600000},
{0x0F00, 0x600000},
{0x0F01, 0x602000},
{0x0FFF, 0x7FE000},
})));
// One case for the float -> half castTo test.
struct DownCastTest {
// The 32-bit input value.
float source_float;
// Expected raw 16-bit encoding of the result.
uint16_t expected_half;
// Every rounding direction that must produce expected_half.
std::vector<spvutils::round_direction> directions;
};
// Returns the fully qualified enumerator name of direction as text (the
// stringized form of the case label, e.g.
// "spvutils::round_direction::kToZero"), or an empty string if no case
// matches. Used to label test failures.
std::string get_round_text(spvutils::round_direction direction) {
// CASE(x) expands to `case x: return "x"`; it is undefined again right
// after the switch.
#define CASE(round_direction) \
case round_direction: \
return #round_direction
switch (direction) {
CASE(spvutils::round_direction::kToZero);
CASE(spvutils::round_direction::kToPositiveInfinity);
CASE(spvutils::round_direction::kToNegativeInfinity);
CASE(spvutils::round_direction::kToNearestEven);
}
#undef CASE
return "";
}
// Parameterized fixture: each case casts a 32-bit float to a 16-bit float
// with castTo under one or more rounding directions.
using HexFloatFP32To16Tests = ::testing::TestWithParam<DownCastTest>;
TEST_P(HexFloatFP32To16Tests, NarrowingCasts) {
using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
using HF16 = spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>;
HF f(GetParam().source_float);
for (auto round : GetParam().directions) {
// Fresh target for every direction so results cannot leak between them.
HF16 half(0);
f.castTo(half, round);
// On failure, print the rounding mode, the source bits and the result.
EXPECT_EQ(GetParam().expected_half, half.value().getAsFloat().get_value())
<< get_round_text(round) << " " << std::hex
<< spvutils::BitwiseCast<uint32_t>(GetParam().source_float)
<< " cast to: " << half.value().getAsFloat().get_value();
}
}
// Raw 16-bit encodings of +infinity and -infinity: all five exponent bits
// set, fraction zero, with the sign bit clear / set.
const uint16_t positive_infinity = 0x7C00;
const uint16_t negative_infinity = 0xFC00;
// Cases for HexFloatFP32To16Tests. Each entry is
// {input float, expected raw half, rounding directions that must give it}.
INSTANTIATE_TEST_CASE_P(F32ToF16, HexFloatFP32To16Tests,
::testing::ValuesIn(std::vector<DownCastTest>(
{
// Exactly representable as half.
{0.f, 0x0, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
{-0.f, 0x8000, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
{1.0f, 0x3C00, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
{-1.0f, 0xBC00, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
{float_fractions({0, 1, 10}) , 0x3E01, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
{-float_fractions({0, 1, 10}) , 0xBE01, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
{static_cast<float>(ldexp(float_fractions({0, 1, 10}), 3)), 0x4A01, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
{static_cast<float>(-ldexp(float_fractions({0, 1, 10}), 3)), 0xCA01, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
// Underflow
{static_cast<float>(ldexp(1.0f, -25)), 0x0, {RD::kToZero, RD::kToNegativeInfinity, RD::kToNearestEven}},
{static_cast<float>(ldexp(1.0f, -25)), 0x1, {RD::kToPositiveInfinity}},
{static_cast<float>(-ldexp(1.0f, -25)), 0x8000, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNearestEven}},
{static_cast<float>(-ldexp(1.0f, -25)), 0x8001, {RD::kToNegativeInfinity}},
{static_cast<float>(ldexp(1.0f, -24)), 0x1, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
// Overflow
{static_cast<float>(ldexp(1.0f, 16)), positive_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
{static_cast<float>(ldexp(1.0f, 18)), positive_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
{static_cast<float>(ldexp(1.3f, 16)), positive_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
{static_cast<float>(-ldexp(1.0f, 16)), negative_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
{static_cast<float>(-ldexp(1.0f, 18)), negative_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
{static_cast<float>(-ldexp(1.3f, 16)), negative_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
// Transfer of Infinities
{std::numeric_limits<float>::infinity(), positive_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
{-std::numeric_limits<float>::infinity(), negative_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}},
// Nans are below because we cannot test for equality.
})));
// One parameter set for the half -> float widening test. Instances are
// created by positional brace-initialization, so the field order matters.
struct UpCastCase{
// Raw 16-bit encoding of the half-precision input.
uint16_t source_half;
// Float value the input is expected to widen to.
float expected_float;
};
// Fixture for the half -> float widening test; each instance receives one
// UpCastCase through GetParam().
using HexFloatFP16To32Tests = ::testing::TestWithParam<UpCastCase>;
// Widens the half-precision input to float under every rounding direction and
// checks the result bit-for-bit against the expected float.
TEST_P(HexFloatFP16To32Tests, WideningCasts) {
  using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
  using HF16 = spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>;
  HF16 f(GetParam().source_half);

  // Compare encodings rather than float values: -0.f == 0.f under operator==,
  // so a value comparison could not notice a lost sign on a negative zero.
  const uint32_t expected_bits =
      spvutils::BitwiseCast<uint32_t>(GetParam().expected_float);

  const spvutils::round_direction rounding[] = {
      spvutils::round_direction::kToZero,
      spvutils::round_direction::kToNearestEven,
      spvutils::round_direction::kToPositiveInfinity,
      spvutils::round_direction::kToNegativeInfinity};

  // The same result is expected for every rounding direction: the widened
  // value fits, so the conversion should amount to bit-shifts.
  for (spvutils::round_direction round : rounding) {
    HF flt(0.f);  // Fresh destination for each direction.
    f.castTo(flt, round);
    EXPECT_EQ(expected_bits, flt.value().data())
        << get_round_text(round) << " " << std::hex << GetParam().source_half
        << " cast to: " << flt.value().getAsFloat() << " (bits 0x"
        << flt.value().data() << ")";
  }
}
// Parameters for the half -> float widening test: each entry pairs a raw
// 16-bit half encoding with the float it is expected to widen to.
INSTANTIATE_TEST_CASE_P(F16ToF32, HexFloatFP16To32Tests,
::testing::ValuesIn(std::vector<UpCastCase>(
{
// Zeros of both signs.
{0x0000, 0.f},
{0x8000, -0.f},
// Values with a non-zero exponent field. 0x3C00 encodes 1.0; the low ten
// bits add fractions, e.g. 0x0300 contributes 2^-1 + 2^-2 and 0x0001
// contributes 2^-10.
{0x3C00, 1.0f},
{0xBC00, -1.0f},
{0x3F00, float_fractions({0, 1, 2})},
{0xBF00, -float_fractions({0, 1, 2})},
{0x3F01, float_fractions({0, 1, 2, 10})},
{0xBF01, -float_fractions({0, 1, 2, 10})},
// Half denormals (exponent field zero): the significand counts multiples of
// 2^-24, so 0x0011 (= 17) is 2^-20 + 2^-24.
{0x0001, static_cast<float>(ldexp(1.0, -24))},
{0x0002, static_cast<float>(ldexp(1.0, -23))},
{0x8001, static_cast<float>(-ldexp(1.0, -24))},
{0x8011, static_cast<float>(-ldexp(1.0, -20) + -ldexp(1.0, -24))},
// Infinities of both signs.
{0x7C00, std::numeric_limits<float>::infinity()},
{0xFC00, -std::numeric_limits<float>::infinity()},
})));
// Checks that a NaN stays a NaN when converted between float and half, in
// both directions and under every rounding direction. Only NaN-ness is
// checked, not the resulting encoding.
TEST(HexFloatOperationTests, NanTests) {
  using HF = spvutils::HexFloat<spvutils::FloatProxy<float>>;
  using HF16 = spvutils::HexFloat<spvutils::FloatProxy<spvutils::Float16>>;
  const spvutils::round_direction rounding[] = {
      spvutils::round_direction::kToZero,
      spvutils::round_direction::kToNearestEven,
      spvutils::round_direction::kToPositiveInfinity,
      spvutils::round_direction::kToNegativeInfinity};
  // Half encodings with every exponent bit set and a non-zero significand,
  // covering both signs and several significand patterns.
  const uint16_t half_nans[] = {0x7C01, 0x7C11, 0xFC01, 0x7C10, 0xFF00};

  for (spvutils::round_direction round : rounding) {
    // Each cast writes into a freshly zeroed destination, so a NaN left over
    // from an earlier cast cannot hide a cast that failed to produce one.
    {
      HF16 f16(0);
      HF(std::numeric_limits<float>::quiet_NaN()).castTo(f16, round);
      EXPECT_TRUE(f16.value().isNan())
          << get_round_text(round) << " quiet NaN float cast to half";
    }
    {
      HF16 f16(0);
      HF(std::numeric_limits<float>::signaling_NaN()).castTo(f16, round);
      EXPECT_TRUE(f16.value().isNan())
          << get_round_text(round) << " signaling NaN float cast to half";
    }
    for (uint16_t half_bits : half_nans) {
      HF f(0.f);
      HF16(half_bits).castTo(f, round);
      EXPECT_TRUE(f.value().isNan())
          << get_round_text(round) << " half 0x" << std::hex << half_bits
          << " cast to float";
    }
  }
}
// TODO(awoloszyn): Add fp16 tests and HexFloatTraits.
}