From 5cb47d6a884e7e3665ba8d3852558b0fd42bdf11 Mon Sep 17 00:00:00 2001 From: Mike Klein Date: Fri, 10 Jul 2020 15:46:46 -0500 Subject: [PATCH] refactor skvx::if_then_else() First move if_then_else() specializations inline using a quasi-constexpr-if approach, letting them operate on any types of the right vector and lane size. We can't use constexpr-if per se because this header is sometimes used in C++14 contexts. Then, add AVX specialization for 8x32-bit types. SkVM's interpreter uses if_then_else() on three i32x16, and these changes allow that to vectorize ideally, as two vblendvps instructions. Change-Id: I8355c47975c736c1fbc32b1f8ceddb772978d271 Reviewed-on: https://skia-review.googlesource.com/c/skia/+/302080 Auto-Submit: Mike Klein Commit-Queue: Brian Osman Reviewed-by: Brian Osman --- include/private/SkVx.h | 75 +++++++++++++++++++++++------------------- 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/include/private/SkVx.h b/include/private/SkVx.h index d3dc6c957b..f2da619a2c 100644 --- a/include/private/SkVx.h +++ b/include/private/SkVx.h @@ -139,13 +139,18 @@ struct Vec<1,T> { }; template -static inline D bit_pun(const S& s) { - static_assert(sizeof(D) == sizeof(S), ""); +static inline D unchecked_bit_pun(const S& s) { D d; memcpy(&d, &s, sizeof(D)); return d; } +template +static inline D bit_pun(const S& s) { + static_assert(sizeof(D) == sizeof(S), ""); + return unchecked_bit_pun(s); +} + // Translate from a value type T to its corresponding Mask, the result of a comparison. template struct Mask { using type = T; }; template <> struct Mask { using type = int32_t; }; @@ -272,9 +277,9 @@ SINT Vec<2*N,T> join(const Vec& lo, const Vec& hi) { // N == 1 scalar implementations. SIT Vec<1,T> if_then_else(const Vec<1,M>& cond, const Vec<1,T>& t, const Vec<1,T>& e) { - auto t_bits = bit_pun>(t), - e_bits = bit_pun>(e); - return bit_pun( (cond.val & t_bits) | (~cond.val & e_bits) ); + // In practice this scalar implementation is unlikely to be used. See if_then_else() below. + return bit_pun>(( cond & bit_pun>>(t)) | + (~cond & bit_pun>>(e)) ); } SIT bool any(const Vec<1,T>& x) { return x.val != 0; } @@ -308,8 +313,37 @@ SIT Vec<1,T> mad(const Vec<1,T>& f, // All default N != 1 implementations just recurse on lo and hi halves. SINT Vec if_then_else(const Vec>& cond, const Vec& t, const Vec& e) { - return join(if_then_else(cond.lo, t.lo, e.lo), - if_then_else(cond.hi, t.hi, e.hi)); + // Specializations inline here so they can generalize what types the apply to. + // (This header is used in C++14 contexts, so we have to kind of fake constexpr if.) +#if defined(__AVX__) + if /*constexpr*/ (N == 8 && sizeof(T) == 4) { + return unchecked_bit_pun>(_mm256_blendv_ps(unchecked_bit_pun<__m256>(e), + unchecked_bit_pun<__m256>(t), + unchecked_bit_pun<__m256>(cond))); + } +#endif +#if defined(__SSE4_1__) + if /*constexpr*/ (N == 4 && sizeof(T) == 4) { + return unchecked_bit_pun>(_mm_blendv_ps(unchecked_bit_pun<__m128>(e), + unchecked_bit_pun<__m128>(t), + unchecked_bit_pun<__m128>(cond))); + } +#endif +#if defined(__ARM_NEON) + if /*constexpr*/ (N == 4 && sizeof(T) == 4) { + return unchecked_bit_pun>(vbslq_f32(unchecked_bit_pun< uint32x4_t>(cond), + unchecked_bit_pun(t), + unchecked_bit_pun(e))); + } +#endif + // Recurse for large vectors to try to hit the specializations above. + if /*constexpr*/ (N > 4) { + return join(if_then_else(cond.lo, t.lo, e.lo), + if_then_else(cond.hi, t.hi, e.hi)); + } + // This default can lead to better code than the recursing onto scalars. + return bit_pun>(( cond & bit_pun>>(t)) | + (~cond & bit_pun>>(e)) ); } SINT bool any(const Vec& x) { return any(x.lo) || any(x.hi); } @@ -556,33 +590,6 @@ static inline Vec approx_scale(const Vec& x, const Vec if_then_else(const Vec<4,int >& c, - const Vec<4,float>& t, - const Vec<4,float>& e) { - return bit_pun>(_mm_blendv_ps(bit_pun<__m128>(e), - bit_pun<__m128>(t), - bit_pun<__m128>(c))); - } - #elif defined(__SSE__) - static inline Vec<4,float> if_then_else(const Vec<4,int >& c, - const Vec<4,float>& t, - const Vec<4,float>& e) { - return bit_pun>(_mm_or_ps(_mm_and_ps (bit_pun<__m128>(c), - bit_pun<__m128>(t)), - _mm_andnot_ps(bit_pun<__m128>(c), - bit_pun<__m128>(e)))); - } - #elif defined(__ARM_NEON) - static inline Vec<4,float> if_then_else(const Vec<4,int >& c, - const Vec<4,float>& t, - const Vec<4,float>& e) { - return bit_pun>(vbslq_f32(bit_pun (c), - bit_pun(t), - bit_pun(e))); - } - #endif - #if defined(__AVX2__) static inline Vec<4,float> fma(const Vec<4,float>& x, const Vec<4,float>& y,