refactor skvx::if_then_else()

First move if_then_else() specializations inline using a
quasi-constexpr-if approach, letting them operate on any types of the
right vector and lane size.  We can't use constexpr-if per se because
this header is sometimes used in C++14 contexts.

Then, add AVX specialization for 8x32-bit types.

SkVM's interpreter uses if_then_else() on three i32x16, and these
changes allow that to vectorize ideally, as two vblendvps instructions.

Change-Id: I8355c47975c736c1fbc32b1f8ceddb772978d271
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/302080
Auto-Submit: Mike Klein <mtklein@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
This commit is contained in:
Mike Klein 2020-07-10 15:46:46 -05:00 committed by Skia Commit-Bot
parent 872a32c58d
commit 5cb47d6a88

View File

@ -139,13 +139,18 @@ struct Vec<1,T> {
};
template <typename D, typename S>
static inline D bit_pun(const S& s) {
static_assert(sizeof(D) == sizeof(S), "");
static inline D unchecked_bit_pun(const S& s) {
D d;
memcpy(&d, &s, sizeof(D));
return d;
}
template <typename D, typename S>
static inline D bit_pun(const S& s) {
static_assert(sizeof(D) == sizeof(S), "");
return unchecked_bit_pun<D>(s);
}
// Translate from a value type T to its corresponding Mask, the result of a comparison.
template <typename T> struct Mask { using type = T; };
template <> struct Mask<float > { using type = int32_t; };
@ -272,9 +277,9 @@ SINT Vec<2*N,T> join(const Vec<N,T>& lo, const Vec<N,T>& hi) {
// N == 1 scalar implementations.
SIT Vec<1,T> if_then_else(const Vec<1,M<T>>& cond, const Vec<1,T>& t, const Vec<1,T>& e) {
auto t_bits = bit_pun<M<T>>(t),
e_bits = bit_pun<M<T>>(e);
return bit_pun<T>( (cond.val & t_bits) | (~cond.val & e_bits) );
// In practice this scalar implementation is unlikely to be used. See if_then_else() below.
return bit_pun<Vec<1,T>>(( cond & bit_pun<Vec<1, M<T>>>(t)) |
(~cond & bit_pun<Vec<1, M<T>>>(e)) );
}
SIT bool any(const Vec<1,T>& x) { return x.val != 0; }
@ -308,8 +313,37 @@ SIT Vec<1,T> mad(const Vec<1,T>& f,
// All default N != 1 implementations just recurse on lo and hi halves.
SINT Vec<N,T> if_then_else(const Vec<N,M<T>>& cond, const Vec<N,T>& t, const Vec<N,T>& e) {
return join(if_then_else(cond.lo, t.lo, e.lo),
if_then_else(cond.hi, t.hi, e.hi));
// Specializations inline here so they can generalize what types the apply to.
// (This header is used in C++14 contexts, so we have to kind of fake constexpr if.)
#if defined(__AVX__)
if /*constexpr*/ (N == 8 && sizeof(T) == 4) {
return unchecked_bit_pun<Vec<N,T>>(_mm256_blendv_ps(unchecked_bit_pun<__m256>(e),
unchecked_bit_pun<__m256>(t),
unchecked_bit_pun<__m256>(cond)));
}
#endif
#if defined(__SSE4_1__)
if /*constexpr*/ (N == 4 && sizeof(T) == 4) {
return unchecked_bit_pun<Vec<N,T>>(_mm_blendv_ps(unchecked_bit_pun<__m128>(e),
unchecked_bit_pun<__m128>(t),
unchecked_bit_pun<__m128>(cond)));
}
#endif
#if defined(__ARM_NEON)
if /*constexpr*/ (N == 4 && sizeof(T) == 4) {
return unchecked_bit_pun<Vec<N,T>>(vbslq_f32(unchecked_bit_pun< uint32x4_t>(cond),
unchecked_bit_pun<float32x4_t>(t),
unchecked_bit_pun<float32x4_t>(e)));
}
#endif
// Recurse for large vectors to try to hit the specializations above.
if /*constexpr*/ (N > 4) {
return join(if_then_else(cond.lo, t.lo, e.lo),
if_then_else(cond.hi, t.hi, e.hi));
}
// This default can lead to better code than the recursing onto scalars.
return bit_pun<Vec<N,T>>(( cond & bit_pun<Vec<N, M<T>>>(t)) |
(~cond & bit_pun<Vec<N, M<T>>>(e)) );
}
SINT bool any(const Vec<N,T>& x) { return any(x.lo) || any(x.hi); }
@ -556,33 +590,6 @@ static inline Vec<N,uint8_t> approx_scale(const Vec<N,uint8_t>& x, const Vec<N,u
}
#endif
#if defined(__SSE4_1__)
static inline Vec<4,float> if_then_else(const Vec<4,int >& c,
const Vec<4,float>& t,
const Vec<4,float>& e) {
return bit_pun<Vec<4,float>>(_mm_blendv_ps(bit_pun<__m128>(e),
bit_pun<__m128>(t),
bit_pun<__m128>(c)));
}
#elif defined(__SSE__)
static inline Vec<4,float> if_then_else(const Vec<4,int >& c,
const Vec<4,float>& t,
const Vec<4,float>& e) {
return bit_pun<Vec<4,float>>(_mm_or_ps(_mm_and_ps (bit_pun<__m128>(c),
bit_pun<__m128>(t)),
_mm_andnot_ps(bit_pun<__m128>(c),
bit_pun<__m128>(e))));
}
#elif defined(__ARM_NEON)
static inline Vec<4,float> if_then_else(const Vec<4,int >& c,
const Vec<4,float>& t,
const Vec<4,float>& e) {
return bit_pun<Vec<4,float>>(vbslq_f32(bit_pun<uint32x4_t> (c),
bit_pun<float32x4_t>(t),
bit_pun<float32x4_t>(e)));
}
#endif
#if defined(__AVX2__)
static inline Vec<4,float> fma(const Vec<4,float>& x,
const Vec<4,float>& y,