diff --git a/include/util/hex_float.h b/include/util/hex_float.h index a3bdc317c..6f4b7c0f9 100644 --- a/include/util/hex_float.h +++ b/include/util/hex_float.h @@ -39,19 +39,45 @@ namespace spvutils { +class Float16 { + public: + Float16(uint16_t v) : val(v) {} + Float16() = default; + static bool isNan(const Float16 val) { + return ((val.val & 0x7C00) == 0x7C00) && ((val.val & 0x3FF) != 0); + } + Float16(const Float16& other) { val = other.val; } + uint16_t get_value() const { return val; } + + private: + uint16_t val; +}; + +// To specialize this type, you must override uint_type to define +// an unsigned integer that can fit your floating point type. +// You must also add a isNan function that returns true if +// a value is Nan. template struct FloatProxyTraits { - typedef void uint_type; + using uint_type = void; }; template <> struct FloatProxyTraits { - typedef uint32_t uint_type; + using uint_type = uint32_t; + static bool isNan(float f) { return std::isnan(f); } }; template <> struct FloatProxyTraits { - typedef uint64_t uint_type; + using uint_type = uint64_t; + static bool isNan(double f) { return std::isnan(f); } +}; + +template <> +struct FloatProxyTraits { + using uint_type = uint16_t; + static bool isNan(Float16 f) { return Float16::isNan(f); } }; // Since copying a floating point number (especially if it is NaN) @@ -86,7 +112,7 @@ class FloatProxy { uint_type data() const { return data_; } // Returns true if the value represents any type of NaN. - bool isNan() { return std::isnan(getAsFloat()); } + bool isNan() { return FloatProxyTraits::isNan(getAsFloat()); } private: uint_type data_; @@ -111,9 +137,13 @@ std::istream& operator>>(std::istream& is, FloatProxy& value) { template struct HexFloatTraits { // Integer type that can store this hex-float. - typedef void uint_type; + using uint_type = void; // Signed integer type that can store this hex-float. - typedef void int_type; + using int_type = void; + // The numerical type that this HexFloat represents. + using underlying_type = void; + // The type needed to construct the underlying type. + using native_type = void; // The number of bits that are actually relevant in the uint_type. // This allows us to deal with, for example, 24-bit values in a 32-bit // integer. @@ -131,8 +161,10 @@ struct HexFloatTraits { // 1 sign bit, 8 exponent bits, 23 fractional bits. template <> struct HexFloatTraits> { - typedef uint32_t uint_type; - typedef int32_t int_type; + using uint_type = uint32_t; + using int_type = int32_t; + using underlying_type = FloatProxy; + using native_type = float; static const uint_type num_used_bits = 32; static const uint_type num_exponent_bits = 8; static const uint_type num_fraction_bits = 23; @@ -143,14 +175,38 @@ struct HexFloatTraits> { // 1 sign bit, 11 exponent bits, 52 fractional bits. template <> struct HexFloatTraits> { - typedef uint64_t uint_type; - typedef int64_t int_type; + using uint_type = uint64_t; + using int_type = int64_t; + using underlying_type = FloatProxy; + using native_type = double; static const uint_type num_used_bits = 64; static const uint_type num_exponent_bits = 11; static const uint_type num_fraction_bits = 52; static const uint_type exponent_bias = 1023; }; +// Traits for IEEE half. +// 1 sign bit, 5 exponent bits, 10 fractional bits. +template <> +struct HexFloatTraits> { + using uint_type = uint16_t; + using int_type = int16_t; + using underlying_type = uint16_t; + using native_type = uint16_t; + static const uint_type num_used_bits = 16; + static const uint_type num_exponent_bits = 5; + static const uint_type num_fraction_bits = 10; + static const uint_type exponent_bias = 15; +}; + +enum class round_direction { + kToZero, + kToNearestEven, + kToPositiveInfinity, + kToNegativeInfinity, + max = kToNegativeInfinity +}; + // Template class that houses a floating pointer number. // It exposes a number of constants based on the provided traits to // assist in interpreting the bits of the value. @@ -159,6 +215,8 @@ class HexFloat { public: using uint_type = typename Traits::uint_type; using int_type = typename Traits::int_type; + using underlying_type = typename Traits::underlying_type; + using native_type = typename Traits::native_type; explicit HexFloat(T f) : value_(f) {} @@ -190,10 +248,15 @@ class HexFloat { spvutils::SetBits::get; - // The topmost bit in the fraction. (The first non-implicit bit). + // The topmost bit in the nibble-aligned fraction. static const uint_type fraction_top_bit = uint_type(1) << (num_fraction_bits + num_overflow_bits - 1); + // The least significant bit in the exponent, which is also the bit + // immediately to the left of the significand. + static const uint_type first_exponent_bit = uint_type(1) + << (num_fraction_bits); + // The mask for the encoded fraction. It does not include the // implicit bit. static const uint_type fraction_encode_mask = @@ -213,12 +276,334 @@ class HexFloat { static const uint32_t fraction_right_shift = (sizeof(uint_type) * 8) - num_fraction_bits; + // The maximum representable unbiased exponent. + static const int_type max_exponent = + (exponent_mask >> num_fraction_bits) - exponent_bias; + // The minimum representable exponent for normalized numbers. + static const int_type min_exponent = -static_cast(exponent_bias); + + // Returns the bits associated with the value. + uint_type getBits() const { return spvutils::BitwiseCast(value_); } + + // Returns the bits associated with the value, without the leading sign bit. + uint_type getUnsignedBits() const { + return spvutils::BitwiseCast(value_) & ~sign_mask; + } + + // Returns the bits associated with the exponent, shifted to start at the + // lsb of the type. + const uint_type getExponentBits() const { + return (getBits() & exponent_mask) >> num_fraction_bits; + } + + // Returns the exponent in unbiased form. This is the exponent in the + // human-friendly form. + const int_type getUnbiasedExponent() const { + return (static_cast(getExponentBits()) - exponent_bias); + } + + // Returns just the significand bits from the value. + const uint_type getSignificandBits() const { + return getBits() & fraction_encode_mask; + } + + // If the number was normalized, returns the unbiased exponent. + // If the number was denormal, normalize the exponent first. + const int_type getUnbiasedNormalizedExponent() const { + if ((getBits() & ~sign_mask) == 0) { // special case if everything is 0 + return 0; + } + int_type exp = getUnbiasedExponent(); + if (exp == min_exponent) { // We are in denorm land. + uint_type significand_bits = getSignificandBits(); + while ((significand_bits & (first_exponent_bit >> 1)) == 0) { + significand_bits <<= 1; + exp -= 1; + } + significand_bits &= fraction_encode_mask; + } + return exp; + } + + // Returns the signficand after it has been normalized. + const uint_type getNormalizedSignificand() const { + int_type unbiased_exponent = getUnbiasedNormalizedExponent(); + uint_type significand = getSignificandBits(); + for (int_type i = unbiased_exponent; i <= min_exponent; ++i) { + significand <<= 1; + } + significand &= fraction_encode_mask; + return significand; + } + + // Returns true if this number represents a negative value. + bool isNegative() const { return (getBits() & sign_mask) != 0; } + + // Sets this HexFloat from the individual components. + // Note this assumes EVERY significand is normalized, and has an implicit + // leading one. This means that the only way that this method will set 0, + // is if you set a number so denormalized that it underflows. + // Do not use this method with raw bits extracted from a subnormal number, + // since subnormals do not have an implicit leading 1 in the significand. + // The significand is also expected to be in the + // lowest-most num_fraction_bits of the uint_type. + // The exponent is expected to be unbiased, meaning an exponent of + // 0 actually means 0. + // If underflow_round_up is set, then on underflow, if a number is non-0 + // and would underflow, we round up to the smallest denorm. + void setFromSignUnbiasedExponentAndNormalizedSignificand( + bool negative, int_type exponent, uint_type significand, + bool round_denorm_up) { + bool significand_is_zero = significand == 0; + + if (exponent <= min_exponent) { + // If this was denormalized, then we have to shift the bit on, meaning + // the significand is not zero. + significand_is_zero = false; + significand |= first_exponent_bit; + significand >>= 1; + } + + while (exponent < min_exponent) { + significand >>= 1; + ++exponent; + } + + if (exponent == min_exponent) { + if (significand == 0 && !significand_is_zero && round_denorm_up) { + significand = 0x1; + } + } + + uint_type value = 0; + if (negative) { + value |= sign_mask; + } + exponent += exponent_bias; + assert(exponent >= 0); + + // put it all together + exponent = (exponent << exponent_left_shift) & exponent_mask; + significand &= fraction_encode_mask; + value |= exponent | significand; + value_ = BitwiseCast(value); + } + + // Increments the significand of this number by the given amount. + // If this would spill the significand into the implicit bit, + // carry is set to true and the significand is shifted to fit into + // the correct location, otherwise carry is set to false. + // All significands and to_increment are assumed to be within the bounds + // for a valid significand. + static uint_type incrementSignificand(uint_type significand, + uint_type to_increment, bool* carry) { + significand += to_increment; + *carry = false; + if (significand & first_exponent_bit) { + *carry = true; + // The implicit 1-bit will have carried, so we should zero-out the + // top bit and shift back. + significand &= ~first_exponent_bit; + significand >>= 1; + } + return significand; + } + + // These exist because MSVC throws warnings on negative right-shifts + // even if they are not going to be executed. Eg: + // constant_number < 0? 0: constant_number + // These convert the negative left-shifts into right shifts. + + template + struct negatable_left_shift { + static uint_type val(uint_type val) { return val >> -N; } + }; + + template + struct negatable_left_shift= 0>::type> { + static uint_type val(uint_type val) { return val << N; } + }; + + template + struct negatable_right_shift { + static uint_type val(uint_type val) { return val << -N; } + }; + + template + struct negatable_right_shift= 0>::type> { + static uint_type val(uint_type val) { return val >> N; } + }; + + // Returns the significand, rounded to fit in a significand in + // other_T. This is shifted so that the most significant + // bit of the rounded number lines up with the most significant bit + // of the returned significand. + template + typename other_T::uint_type getRoundedNormalizedSignificand( + round_direction dir, bool* carry_bit) { + using other_uint_type = typename other_T::uint_type; + static const int_type num_throwaway_bits = + static_cast(num_fraction_bits) - + static_cast(other_T::num_fraction_bits); + + static const uint_type last_significant_bit = + (num_throwaway_bits < 0) + ? 0 + : negatable_left_shift::val(1u); + static const uint_type first_rounded_bit = + (num_throwaway_bits < 1) + ? 0 + : negatable_left_shift::val(1u); + + static const uint_type throwaway_mask_bits = + num_throwaway_bits > 0 ? num_throwaway_bits : 0; + static const uint_type throwaway_mask = + spvutils::SetBits::get; + + *carry_bit = false; + other_uint_type out_val = 0; + uint_type significand = getNormalizedSignificand(); + // If we are up-casting, then we just have to shift to the right location. + if (num_throwaway_bits <= 0) { + out_val = significand; + uint_type shift_amount = -num_throwaway_bits; + out_val <<= shift_amount; + return out_val; + } + + // If every non-representable bit is 0, then we don't have any casting to + // do. + if ((significand & throwaway_mask) == 0) { + return static_cast( + negatable_right_shift::val(significand)); + } + + bool round_away_from_zero = false; + // We actually have to narrow the significand here, so we have to follow the + // rounding rules. + switch (dir) { + case round_direction::kToZero: + break; + case round_direction::kToPositiveInfinity: + round_away_from_zero = !isNegative(); + break; + case round_direction::kToNegativeInfinity: + round_away_from_zero = isNegative(); + break; + case round_direction::kToNearestEven: + // Have to round down, round bit is 0 + if ((first_rounded_bit & significand) == 0) { + break; + } + if (((significand & throwaway_mask) & ~first_rounded_bit) != 0) { + // If any subsequent bit of the rounded portion is non-0 then we round + // up. + round_away_from_zero = true; + break; + } + // We are exactly half-way between 2 numbers, pick even. + if ((significand & last_significant_bit) != 0) { + // 1 for our last bit, round up. + round_away_from_zero = true; + break; + } + break; + } + + if (round_away_from_zero) { + return static_cast( + negatable_right_shift::val(incrementSignificand( + significand, last_significant_bit, carry_bit))); + } else { + return static_cast( + negatable_right_shift::val(significand)); + } + // We really shouldn't get here. + assert(false && "We should not have ended up here"); + return 0; + } + + // Casts this value to another HexFloat. If the cast is widening, + // then round_dir is ignored. If the cast is narrowing, then + // the result is rounded in the direction specified. + // This number will retain Nan and Inf values. + // It will also saturate to Inf if the number overflows, and + // underflow to (0 or min depending on rounding) if the number underflows. + template + void castTo(other_T& other, round_direction round_dir) { + other = other_T(static_cast(0)); + bool negate = isNegative(); + if (getUnsignedBits() == 0) { + if (negate) { + other.set_value(-other.value()); + } + return; + } + uint_type significand = getSignificandBits(); + bool carried = false; + typename other_T::uint_type rounded_significand = + getRoundedNormalizedSignificand(round_dir, &carried); + + int_type exponent = getUnbiasedExponent(); + if (exponent == min_exponent) { + // If we are denormal, normalize the exponent, so that we can encode + // easily. + exponent += 1; + for (uint_type check_bit = first_exponent_bit >> 1; check_bit != 0; + check_bit >>= 1) { + exponent -= 1; + if (check_bit & significand) break; + } + } + + bool is_nan = + (getBits() & exponent_mask) == exponent_mask && significand != 0; + bool is_inf = + !is_nan && + ((exponent + carried) > static_cast(other_T::exponent_bias) || + (significand == 0 && (getBits() & exponent_mask) == exponent_mask)); + + // If we are Nan or Inf we should pass that through. + if (is_inf) { + other.set_value(BitwiseCast( + static_cast( + (negate ? other_T::sign_mask : 0) | other_T::exponent_mask))); + return; + } + if (is_nan) { + typename other_T::uint_type shifted_significand; + shifted_significand = static_cast( + negatable_left_shift::val(significand)); + + // We are some sort of Nan. We try to keep the bit-pattern of the Nan + // as close as possible. If we had to shift off bits so we are 0, then we + // just set the last bit. + other.set_value(BitwiseCast( + static_cast( + (negate ? other_T::sign_mask : 0) | other_T::exponent_mask | + (shifted_significand == 0 ? 0x1 : shifted_significand)))); + return; + } + + bool round_underflow_up = + isNegative() ? round_dir == round_direction::kToNegativeInfinity + : round_dir == round_direction::kToPositiveInfinity; + + // setFromSignUnbiasedExponentAndNormalizedSignificand will + // zero out any underflowing value (but retain the sign). + other.setFromSignUnbiasedExponentAndNormalizedSignificand( + negate, exponent, rounded_significand, round_underflow_up); + return; + } + private: T value_; static_assert(num_used_bits == Traits::num_exponent_bits + Traits::num_fraction_bits + 1, "The number of bits do not fit"); + static_assert(sizeof(T) == sizeof(uint_type), "The type sizes do not match"); }; // Returns 4 bits represented by the hex character. diff --git a/test/HexFloat.cpp b/test/HexFloat.cpp index 3fdea66a5..8f3a36d3f 100644 --- a/test/HexFloat.cpp +++ b/test/HexFloat.cpp @@ -528,5 +528,470 @@ INSTANTIATE_TEST_CASE_P( }))); +// double is used so that unbiased_exponent can be used with the output +// of ldexp directly. +int32_t unbiased_exponent(double f) { + return spvutils::HexFloat>( + static_cast(f)).getUnbiasedNormalizedExponent(); +} + +int16_t unbiased_half_exponent(uint16_t f) { + return spvutils::HexFloat>(f) + .getUnbiasedNormalizedExponent(); +} + +TEST(HexFloatOperationTest, UnbiasedExponent) { + // Float cases + EXPECT_EQ(0, unbiased_exponent(ldexp(1.0f, 0))); + EXPECT_EQ(-32, unbiased_exponent(ldexp(1.0f, -32))); + EXPECT_EQ(42, unbiased_exponent(ldexp(1.0f, 42))); + EXPECT_EQ(125, unbiased_exponent(ldexp(1.0f, 125))); + // Saturates to 128 + EXPECT_EQ(128, unbiased_exponent(ldexp(1.0f, 256))); + + EXPECT_EQ(-100, unbiased_exponent(ldexp(1.0f, -100))); + EXPECT_EQ(-127, unbiased_exponent(ldexp(1.0f, -127))); // First denorm + EXPECT_EQ(-128, unbiased_exponent(ldexp(1.0f, -128))); + EXPECT_EQ(-129, unbiased_exponent(ldexp(1.0f, -129))); + EXPECT_EQ(-140, unbiased_exponent(ldexp(1.0f, -140))); + // Smallest representable number + EXPECT_EQ(-126 - 23, unbiased_exponent(ldexp(1.0f, -126 - 23))); + // Should get rounded to 0 first. + EXPECT_EQ(0, unbiased_exponent(ldexp(1.0f, -127 - 23))); + + // Float16 cases + // The exponent is represented in the bits 0x7C00 + // The offset is -15 + EXPECT_EQ(0, unbiased_half_exponent(0x3C00)); + EXPECT_EQ(3, unbiased_half_exponent(0x4800)); + EXPECT_EQ(-1, unbiased_half_exponent(0x3800)); + EXPECT_EQ(-14, unbiased_half_exponent(0x0400)); + EXPECT_EQ(16, unbiased_half_exponent(0x7C00)); + EXPECT_EQ(10, unbiased_half_exponent(0x6400)); + + // Smallest representable number + EXPECT_EQ(-24, unbiased_half_exponent(0x0001)); +} + +// Creates a float that is the sum of 1/(2 ^ fractions[i]) for i in factions +float float_fractions(const std::vector& fractions) { + float f = 0; + for(int32_t i: fractions) { + f += ldexp(1.0f, -i); + } + return f; +} + +// Returns the normalized significand of a HexFloat> +// that was created by calling float_fractions with the input fractions, +// raised to the power of exp. +uint32_t normalized_significand(const std::vector& fractions, uint32_t exp) { + return spvutils::HexFloat>( + static_cast(ldexp(float_fractions(fractions), exp))) + .getNormalizedSignificand(); +} + +// Sets the bits from MSB to LSB of the significand part of a float. +// For example 0 would set the bit 23 (counting from LSB to MSB), +// and 1 would set the 22nd bit. +uint32_t bits_set(const std::vector& bits) { + const uint32_t top_bit = 1u << 22u; + uint32_t val= 0; + for(uint32_t i: bits) { + val |= top_bit >> i; + } + return val; +} + +// The same as bits_set but for a Float16 value instead of 32-bit floating +// point. +uint16_t half_bits_set(const std::vector& bits) { + const uint32_t top_bit = 1u << 9u; + uint32_t val= 0; + for(uint32_t i: bits) { + val |= top_bit >> i; + } + return val; +} + +TEST(HexFloatOperationTest, NormalizedSignificand) { + // For normalized numbers (the following) it should be a simple matter + // of getting rid of the top implicit bit + EXPECT_EQ(bits_set({}), normalized_significand({0}, 0)); + EXPECT_EQ(bits_set({0}), normalized_significand({0, 1}, 0)); + EXPECT_EQ(bits_set({0, 1}), normalized_significand({0, 1, 2}, 0)); + EXPECT_EQ(bits_set({1}), normalized_significand({0, 2}, 0)); + EXPECT_EQ(bits_set({1}), normalized_significand({0, 2}, 32)); + EXPECT_EQ(bits_set({1}), normalized_significand({0, 2}, 126)); + + // For denormalized numbers we expect the normalized significand to + // shift as if it were normalized. This means, in practice that the + // top_most set bit will be cut off. Looks very similar to above (on purpose) + EXPECT_EQ(bits_set({}), normalized_significand({0}, -127)); + EXPECT_EQ(bits_set({3}), normalized_significand({0, 4}, -128)); + EXPECT_EQ(bits_set({3}), normalized_significand({0, 4}, -127)); + EXPECT_EQ(bits_set({}), normalized_significand({22}, -127)); + EXPECT_EQ(bits_set({0}), normalized_significand({21, 22}, -127)); +} + +// Returns the 32-bit floating point value created by +// calling setFromSignUnbiasedExponentAndNormalizedSignificand +// on a HexFloat> +float set_from_sign(bool negative, int32_t unbiased_exponent, + uint32_t significand, bool round_denorm_up) { + spvutils::HexFloat> f(0.f); + f.setFromSignUnbiasedExponentAndNormalizedSignificand( + negative, unbiased_exponent, significand, round_denorm_up); + return f.value().getAsFloat(); +} + +TEST(HexFloatOperationTests, + SetFromSignUnbiasedExponentAndNormalizedSignificand) { + + EXPECT_EQ(1.f, set_from_sign(false, 0, 0, false)); + + // Tests insertion of various denormalized numbers with and without round up. + EXPECT_EQ(static_cast(ldexp(1.f, -149)), set_from_sign(false, -149, 0, false)); + EXPECT_EQ(static_cast(ldexp(1.f, -149)), set_from_sign(false, -149, 0, true)); + EXPECT_EQ(0.f, set_from_sign(false, -150, 1, false)); + EXPECT_EQ(static_cast(ldexp(1.f, -149)), set_from_sign(false, -150, 1, true)); + + EXPECT_EQ(ldexp(1.0f, -127), set_from_sign(false, -127, 0, false)); + EXPECT_EQ(ldexp(1.0f, -128), set_from_sign(false, -128, 0, false)); + EXPECT_EQ(float_fractions({0, 1, 2, 5}), + set_from_sign(false, 0, bits_set({0, 1, 4}), false)); + EXPECT_EQ(ldexp(float_fractions({0, 1, 2, 5}), -32), + set_from_sign(false, -32, bits_set({0, 1, 4}), false)); + EXPECT_EQ(ldexp(float_fractions({0, 1, 2, 5}), -128), + set_from_sign(false, -128, bits_set({0, 1, 4}), false)); + + // The negative cases from above. + EXPECT_EQ(-1.f, set_from_sign(true, 0, 0, false)); + EXPECT_EQ(-ldexp(1.0, -127), set_from_sign(true, -127, 0, false)); + EXPECT_EQ(-ldexp(1.0, -128), set_from_sign(true, -128, 0, false)); + EXPECT_EQ(-float_fractions({0, 1, 2, 5}), + set_from_sign(true, 0, bits_set({0, 1, 4}), false)); + EXPECT_EQ(-ldexp(float_fractions({0, 1, 2, 5}), -32), + set_from_sign(true, -32, bits_set({0, 1, 4}), false)); + EXPECT_EQ(-ldexp(float_fractions({0, 1, 2, 5}), -128), + set_from_sign(true, -128, bits_set({0, 1, 4}), false)); +} + +TEST(HexFloatOperationTests, NonRounding) { + // Rounding from 32-bit hex-float to 32-bit hex-float should be trivial, + // except in the denorm case which is a bit more complex. + using HF = spvutils::HexFloat>; + bool carry_bit = false; + + spvutils::round_direction rounding[] = { + spvutils::round_direction::kToZero, + spvutils::round_direction::kToNearestEven, + spvutils::round_direction::kToPositiveInfinity, + spvutils::round_direction::kToNegativeInfinity}; + + // Everything fits, so this should be straight-forward + for (spvutils::round_direction round : rounding) { + EXPECT_EQ(bits_set({}), HF(0.f).getRoundedNormalizedSignificand( + round, &carry_bit)); + EXPECT_FALSE(carry_bit); + + EXPECT_EQ(bits_set({0}), + HF(float_fractions({0, 1})) + .getRoundedNormalizedSignificand(round, &carry_bit)); + EXPECT_FALSE(carry_bit); + + EXPECT_EQ(bits_set({1, 3}), + HF(float_fractions({0, 2, 4})) + .getRoundedNormalizedSignificand(round, &carry_bit)); + EXPECT_FALSE(carry_bit); + + EXPECT_EQ( + bits_set({0, 1, 4}), + HF(static_cast(-ldexp(float_fractions({0, 1, 2, 5}), -128))) + .getRoundedNormalizedSignificand(round, &carry_bit)); + EXPECT_FALSE(carry_bit); + + EXPECT_EQ( + bits_set({0, 1, 4, 22}), + HF(static_cast(float_fractions({0, 1, 2, 5, 23}))) + .getRoundedNormalizedSignificand(round, &carry_bit)); + EXPECT_FALSE(carry_bit); + } +} + +using RD = spvutils::round_direction; +struct RoundSignificandCase { + float source_float; + std::pair expected_results; + spvutils::round_direction round; +}; + +using HexFloatRoundTest = + ::testing::TestWithParam; + +TEST_P(HexFloatRoundTest, RoundDownToFP16) { + using HF = spvutils::HexFloat>; + using HF16 = spvutils::HexFloat>; + + HF input_value(GetParam().source_float); + bool carry_bit = false; + EXPECT_EQ(GetParam().expected_results.first, + input_value.getRoundedNormalizedSignificand( + GetParam().round, &carry_bit)); + EXPECT_EQ(carry_bit, GetParam().expected_results.second); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P(F32ToF16, HexFloatRoundTest, + ::testing::ValuesIn(std::vector( + { + {float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToZero}, + {float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToNearestEven}, + {float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToPositiveInfinity}, + {float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToNegativeInfinity}, + {float_fractions({0, 1}), std::make_pair(half_bits_set({0}), false), RD::kToZero}, + + {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0}), false), RD::kToZero}, + {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity}, + {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity}, + {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0}), false), RD::kToNearestEven}, + + {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 9}), false), RD::kToZero}, + {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 8}), false), RD::kToPositiveInfinity}, + {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNegativeInfinity}, + {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 8}), false), RD::kToNearestEven}, + + {float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), RD::kToZero}, + {float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity}, + {float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity}, + {float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven}, + + {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), RD::kToZero}, + {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), RD::kToPositiveInfinity}, + {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNegativeInfinity}, + {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven}, + + {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0}), false), RD::kToZero}, + {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity}, + {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity}, + {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven}, + + // Carries + {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), false), RD::kToZero}, + {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({}), true), RD::kToPositiveInfinity}, + {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), false), RD::kToNegativeInfinity}, + {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({}), true), RD::kToNearestEven}, + + // Cases where original number was denorm. Note: this should have no effect + // the number is pre-normalized. + {static_cast(ldexp(float_fractions({0, 1, 11, 13}), -128)), std::make_pair(half_bits_set({0}), false), RD::kToZero}, + {static_cast(ldexp(float_fractions({0, 1, 11, 13}), -129)), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity}, + {static_cast(ldexp(float_fractions({0, 1, 11, 13}), -131)), std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity}, + {static_cast(ldexp(float_fractions({0, 1, 11, 13}), -130)), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven}, + }))); +// clang-format on + +struct UpCastSignificandCase { + uint16_t source_half; + uint32_t expected_result; +}; + +using HexFloatRoundUpSignificandTest = + ::testing::TestWithParam; +TEST_P(HexFloatRoundUpSignificandTest, Widening) { + using HF = spvutils::HexFloat>; + using HF16 = spvutils::HexFloat>; + bool carry_bit = false; + + spvutils::round_direction rounding[] = { + spvutils::round_direction::kToZero, + spvutils::round_direction::kToNearestEven, + spvutils::round_direction::kToPositiveInfinity, + spvutils::round_direction::kToNegativeInfinity}; + + // Everything fits, so everything should just be bit-shifts. + for (spvutils::round_direction round : rounding) { + carry_bit = false; + HF16 input_value(GetParam().source_half); + EXPECT_EQ( + GetParam().expected_result, + input_value.getRoundedNormalizedSignificand(round, &carry_bit)) + << std::hex << "0x" + << input_value.getRoundedNormalizedSignificand(round, &carry_bit) + << " 0x" << GetParam().expected_result; + EXPECT_FALSE(carry_bit); + } +} + +INSTANTIATE_TEST_CASE_P(F16toF32, HexFloatRoundUpSignificandTest, + // 0xFC00 of the source 16-bit hex value cover the sign and the exponent. + // They are ignored for this test. + ::testing::ValuesIn(std::vector( + { + {0x3F00, 0x600000}, + {0x0F00, 0x600000}, + {0x0F01, 0x602000}, + {0x0FFF, 0x7FE000}, + }))); + +struct DownCastTest { + float source_float; + uint16_t expected_half; + std::vector directions; +}; + +std::string get_round_text(spvutils::round_direction direction) { +#define CASE(round_direction) \ + case round_direction: \ + return #round_direction + + switch (direction) { + CASE(spvutils::round_direction::kToZero); + CASE(spvutils::round_direction::kToPositiveInfinity); + CASE(spvutils::round_direction::kToNegativeInfinity); + CASE(spvutils::round_direction::kToNearestEven); + } +#undef CASE + return ""; +} + +using HexFloatFP32To16Tests = ::testing::TestWithParam; + +TEST_P(HexFloatFP32To16Tests, NarrowingCasts) { + using HF = spvutils::HexFloat>; + using HF16 = spvutils::HexFloat>; + HF f(GetParam().source_float); + for (auto round : GetParam().directions) { + HF16 half(0); + f.castTo(half, round); + EXPECT_EQ(GetParam().expected_half, half.value().getAsFloat().get_value()) + << get_round_text(round) << " " << std::hex + << spvutils::BitwiseCast(GetParam().source_float) + << " cast to: " << half.value().getAsFloat().get_value(); + } +} + +const uint16_t positive_infinity = 0x7C00; +const uint16_t negative_infinity = 0xFC00; + +INSTANTIATE_TEST_CASE_P(F32ToF16, HexFloatFP32To16Tests, + ::testing::ValuesIn(std::vector( + { + // Exactly representable as half. + {0.f, 0x0, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + {-0.f, 0x8000, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + {1.0f, 0x3C00, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + {-1.0f, 0xBC00, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + + {float_fractions({0, 1, 10}) , 0x3E01, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + {-float_fractions({0, 1, 10}) , 0xBE01, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + {static_cast(ldexp(float_fractions({0, 1, 10}), 3)), 0x4A01, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + {static_cast(-ldexp(float_fractions({0, 1, 10}), 3)), 0xCA01, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + + + // Underflow + {static_cast(ldexp(1.0f, -25)), 0x0, {RD::kToZero, RD::kToNegativeInfinity, RD::kToNearestEven}}, + {static_cast(ldexp(1.0f, -25)), 0x1, {RD::kToPositiveInfinity}}, + {static_cast(-ldexp(1.0f, -25)), 0x8000, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNearestEven}}, + {static_cast(-ldexp(1.0f, -25)), 0x8001, {RD::kToNegativeInfinity}}, + {static_cast(ldexp(1.0f, -24)), 0x1, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + + // Overflow + {static_cast(ldexp(1.0f, 16)), positive_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + {static_cast(ldexp(1.0f, 18)), positive_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + {static_cast(ldexp(1.3f, 16)), positive_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + {static_cast(-ldexp(1.0f, 16)), negative_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + {static_cast(-ldexp(1.0f, 18)), negative_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + {static_cast(-ldexp(1.3f, 16)), negative_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + + // Transfer of Infinities + {std::numeric_limits::infinity(), positive_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + {-std::numeric_limits::infinity(), negative_infinity, {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, RD::kToNearestEven}}, + + // Nans are below because we cannot test for equality. + }))); + +struct UpCastCase{ + uint16_t source_half; + float expected_float; +}; + +using HexFloatFP16To32Tests = ::testing::TestWithParam; +TEST_P(HexFloatFP16To32Tests, WideningCasts) { + using HF = spvutils::HexFloat>; + using HF16 = spvutils::HexFloat>; + HF16 f(GetParam().source_half); + + spvutils::round_direction rounding[] = { + spvutils::round_direction::kToZero, + spvutils::round_direction::kToNearestEven, + spvutils::round_direction::kToPositiveInfinity, + spvutils::round_direction::kToNegativeInfinity}; + + // Everything fits, so everything should just be bit-shifts. + for (spvutils::round_direction round : rounding) { + HF flt(0.f); + f.castTo(flt, round); + EXPECT_EQ(GetParam().expected_float, flt.value().getAsFloat()) + << get_round_text(round) << " " << std::hex + << spvutils::BitwiseCast(GetParam().source_half) + << " cast to: " << flt.value().getAsFloat(); + } +} + +INSTANTIATE_TEST_CASE_P(F16ToF32, HexFloatFP16To32Tests, + ::testing::ValuesIn(std::vector( + { + {0x0000, 0.f}, + {0x8000, -0.f}, + {0x3C00, 1.0f}, + {0xBC00, -1.0f}, + {0x3F00, float_fractions({0, 1, 2})}, + {0xBF00, -float_fractions({0, 1, 2})}, + {0x3F01, float_fractions({0, 1, 2, 10})}, + {0xBF01, -float_fractions({0, 1, 2, 10})}, + + // denorm + {0x0001, static_cast(ldexp(1.0, -24))}, + {0x0002, static_cast(ldexp(1.0, -23))}, + {0x8001, static_cast(-ldexp(1.0, -24))}, + {0x8011, static_cast(-ldexp(1.0, -20) + -ldexp(1.0, -24))}, + + // inf + {0x7C00, std::numeric_limits::infinity()}, + {0xFC00, -std::numeric_limits::infinity()}, + }))); + +TEST(HexFloatOperationTests, NanTests) { + using HF = spvutils::HexFloat>; + using HF16 = spvutils::HexFloat>; + spvutils::round_direction rounding[] = { + spvutils::round_direction::kToZero, + spvutils::round_direction::kToNearestEven, + spvutils::round_direction::kToPositiveInfinity, + spvutils::round_direction::kToNegativeInfinity}; + + // Everything fits, so everything should just be bit-shifts. + for (spvutils::round_direction round : rounding) { + HF16 f16(0); + HF f(0.f); + HF(std::numeric_limits::quiet_NaN()).castTo(f16, round); + EXPECT_TRUE(f16.value().isNan()); + HF(std::numeric_limits::signaling_NaN()).castTo(f16, round); + EXPECT_TRUE(f16.value().isNan()); + + HF16(0x7C01).castTo(f, round); + EXPECT_TRUE(f.value().isNan()); + HF16(0x7C11).castTo(f, round); + EXPECT_TRUE(f.value().isNan()); + HF16(0xFC01).castTo(f, round); + EXPECT_TRUE(f.value().isNan()); + HF16(0x7C10).castTo(f, round); + EXPECT_TRUE(f.value().isNan()); + HF16(0xFF00).castTo(f, round); + EXPECT_TRUE(f.value().isNan()); + } +} + // TODO(awoloszyn): Add fp16 tests and HexFloatTraits. }