x86: Use VMM API in memcmpeq-evex.S and minor changes

Changes to generated code are:
    1. In a few places use `vpcmpeqb` instead of `vpcmpneqb` to save a
       byte of code size (see the sketch after this list).
    2. Add a branch for length <= (VEC_SIZE * 6) as opposed to doing
       the entire block of [VEC_SIZE * 4 + 1, VEC_SIZE * 8] in a
       single basic block (the space to add the extra branch without
       changing code size is bought with the above change).
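
As a sketch of change (1) (illustrative only, 256-bit case; the
register and label here are not from the patch): `vpcmpneqb` is an
assembler alias for `vpcmpb`/`vpcmpub` with an imm8 selecting
"not equal", while `vpcmpeqb` has a dedicated opcode with no
immediate, and branching on the all-ones "equal" mask with `inc`
works just as well as `test` on a "not equal" mask:

	/* Not-equal compare: the imm8 ($4 selects "neq") costs a byte.  */
	vpcmpub	$4, (%rdi), %ymm16, %k1
	kmovd	%k1, %eax
	testl	%eax, %eax		/* Nonzero mask -> some byte differs.  */
	jnz	L(mismatch)

	/* Equal compare: no immediate byte.  `inc` sets ZF only when the
	   32-bit mask is all-ones, i.e. every byte matched.  */
	vpcmpeqb (%rdi), %ymm16, %k1
	kmovd	%k1, %eax
	inc	%eax
	jnz	L(mismatch)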

Change (2) gives roughly a 20-25% speedup for sizes in [VEC_SIZE * 4 +
1, VEC_SIZE * 6] (129 to 192 bytes with 32-byte vectors) and is
negligible to no cost for [VEC_SIZE * 6 + 1, VEC_SIZE * 8] (193 to 256
bytes).

From N=10 runs on Tigerlake:

align1 ,align2 ,length ,result               ,New Time ,Cur Time ,New Time / Cur Time
0     ,0      ,129    ,0                    ,5.404    ,6.887    ,0.785
0     ,0      ,129    ,1                    ,5.308    ,6.826    ,0.778
0     ,0      ,129    ,18446744073709551615 ,5.359    ,6.823    ,0.785
0     ,0      ,161    ,0                    ,5.284    ,6.827    ,0.774
0     ,0      ,161    ,1                    ,5.317    ,6.745    ,0.788
0     ,0      ,161    ,18446744073709551615 ,5.406    ,6.778    ,0.798

0     ,0      ,193    ,0                    ,6.804    ,6.802    ,1.000
0     ,0      ,193    ,1                    ,6.950    ,6.754    ,1.029
0     ,0      ,193    ,18446744073709551615 ,6.792    ,6.719    ,1.011
0     ,0      ,225    ,0                    ,6.625    ,6.699    ,0.989
0     ,0      ,225    ,1                    ,6.776    ,6.735    ,1.003
0     ,0      ,225    ,18446744073709551615 ,6.758    ,6.738    ,0.992
0     ,0      ,256    ,0                    ,5.402    ,5.462    ,0.989
0     ,0      ,256    ,1                    ,5.364    ,5.483    ,0.978
0     ,0      ,256    ,18446744073709551615 ,5.341    ,5.539    ,0.964

Rewriting with the VMM API allows memcmpeq-evex.S to be built for
evex512 by including "x86-evex512-vecs.h" at the top (sketched below).
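
A hypothetical evex512 wrapper would look something like the sketch
below (the __memcmpeq_evex512 name and file layout are assumptions,
not part of this commit; only the include mechanism is):

    /* Sketch: select 64-byte vectors, then reuse the evex
       implementation unchanged.  */
    #ifndef MEMCMPEQ
    # define MEMCMPEQ	__memcmpeq_evex512
    #endif
    #include "x86-evex512-vecs.h"
    #include "memcmpeq-evex.S"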

Complete check passes on x86-64.
Noah Goldstein 2022-10-29 15:19:59 -05:00
parent 419c832aba
commit 2d2493a644

@@ -41,24 +41,53 @@
 # define MEMCMPEQ	__memcmpeq_evex
 # endif
 
+# ifndef VEC_SIZE
+#  include "x86-evex512-vecs.h"
+# endif
+# include "reg-macros.h"
+
+# if VEC_SIZE == 32
+
+# define TEST_ZERO_VCMP(reg)	inc %VGPR(reg)
+# define TEST_ZERO(reg)	test %VGPR(reg), %VGPR(reg)
+
+# define TO_32BIT_P1(reg)	/* Do nothing. */
+# define TO_32BIT_P2(reg)	/* Do nothing. */
+# define TO_32BIT(reg)	/* Do nothing. */
+
+# define VEC_CMP	VPCMPEQ
+
+# elif VEC_SIZE == 64
+
+# define TEST_ZERO_VCMP(reg)	TEST_ZERO(reg)
+# define TEST_ZERO(reg)	neg %VGPR(reg)
+
+	/* VEC_SIZE == 64 needs to reduce the 64-bit mask to a 32-bit
+	   int.  We have two methods for this.  If the mask was branched
+	   on, we use `neg` for the branch, then `sbb` to get the 32-bit
+	   return.  If the mask was not branched on, we just use
+	   `popcntq`.  */
+# define TO_32BIT_P1(reg)	TEST_ZERO(reg)
+# define TO_32BIT_P2(reg)	sbb %VGPR_SZ(reg, 32), %VGPR_SZ(reg, 32)
+# define TO_32BIT(reg)	popcntq %reg, %reg
+
+# define VEC_CMP	VPCMPNEQ
+
+# else
+# error "Unsupported VEC_SIZE"
+# endif
+
 # define VMOVU_MASK	vmovdqu8
-# define VMOVU	vmovdqu64
-# define VPCMP	vpcmpub
+# define VPCMPNEQ	vpcmpneqb
+# define VPCMPEQ	vpcmpeqb
 # define VPTEST	vptestmb
 
-# define VEC_SIZE	32
 # define PAGE_SIZE	4096
 
-# define YMM0	ymm16
-# define YMM1	ymm17
-# define YMM2	ymm18
-# define YMM3	ymm19
-# define YMM4	ymm20
-# define YMM5	ymm21
-# define YMM6	ymm22
-
-	.section .text.evex, "ax", @progbits
+	.section SECTION(.text), "ax", @progbits
 ENTRY_P2ALIGN (MEMCMPEQ, 6)
 # ifdef __ILP32__
 	/* Clear the upper 32 bits.  */
@@ -69,47 +98,54 @@ ENTRY_P2ALIGN (MEMCMPEQ, 6)
 	ja	L(more_1x_vec)
 
 	/* Create mask of bytes that are guaranteed to be valid because
-	   of length (edx). Using masked movs allows us to skip checks for
-	   page crosses/zero size. */
-	movl	$-1, %ecx
-	bzhil	%edx, %ecx, %ecx
-	kmovd	%ecx, %k2
+	   of length (edx). Using masked movs allows us to skip checks
+	   for page crosses/zero size. */
+	mov	$-1, %VRAX
+	bzhi	%VRDX, %VRAX, %VRAX
+	/* NB: A `jz` might be useful here.  Page-faults that are
+	   invalidated by predicate execution (the evex mask) can be
+	   very slow.  The expectation is this is not the norm, so
+	   "most" code will not regularly call 'memcmp' with length = 0
+	   and memory that is not wired up.  */
+	KMOV	%VRAX, %k2
 
 	/* Use masked loads as VEC_SIZE could page cross where length
 	   (edx) would not. */
-	VMOVU_MASK (%rsi), %YMM2{%k2}
-	VPCMP	$4,(%rdi), %YMM2, %k1{%k2}
-	kmovd	%k1, %eax
+	VMOVU_MASK (%rsi), %VMM(2){%k2}{z}
+	VPCMPNEQ (%rdi), %VMM(2), %k1{%k2}
+	KMOV	%k1, %VRAX
+	TO_32BIT (VRAX)
 	ret
 
+	.p2align 4,, 3
 L(last_1x_vec):
-	VMOVU	-(VEC_SIZE * 1)(%rsi, %rdx), %YMM1
-	VPCMP	$4, -(VEC_SIZE * 1)(%rdi, %rdx), %YMM1, %k1
-	kmovd	%k1, %eax
+	VMOVU	-(VEC_SIZE * 1)(%rsi, %rdx), %VMM(1)
+	VPCMPNEQ -(VEC_SIZE * 1)(%rdi, %rdx), %VMM(1), %k1
+	KMOV	%k1, %VRAX
+	TO_32BIT_P1 (rax)
 L(return_neq0):
+	TO_32BIT_P2 (rax)
 	ret
 
-	.p2align 4
+	.p2align 4,, 12
 L(more_1x_vec):
 	/* From VEC + 1 to 2 * VEC. */
-	VMOVU	(%rsi), %YMM1
+	VMOVU	(%rsi), %VMM(1)
 	/* Use compare not equals to directly check for mismatch. */
-	VPCMP	$4,(%rdi), %YMM1, %k1
-	kmovd	%k1, %eax
-	testl	%eax, %eax
+	VPCMPNEQ (%rdi), %VMM(1), %k1
+	KMOV	%k1, %VRAX
+	TEST_ZERO (rax)
 	jnz	L(return_neq0)
 
 	cmpq	$(VEC_SIZE * 2), %rdx
 	jbe	L(last_1x_vec)
 
 	/* Check second VEC no matter what. */
-	VMOVU	VEC_SIZE(%rsi), %YMM2
-	VPCMP	$4, VEC_SIZE(%rdi), %YMM2, %k1
-	kmovd	%k1, %eax
-	testl	%eax, %eax
+	VMOVU	VEC_SIZE(%rsi), %VMM(2)
+	VPCMPNEQ VEC_SIZE(%rdi), %VMM(2), %k1
+	KMOV	%k1, %VRAX
+	TEST_ZERO (rax)
 	jnz	L(return_neq0)
 
 	/* Less than 4 * VEC. */
@@ -117,16 +153,16 @@ L(more_1x_vec):
 	jbe	L(last_2x_vec)
 
 	/* Check third and fourth VEC no matter what. */
-	VMOVU	(VEC_SIZE * 2)(%rsi), %YMM3
-	VPCMP	$4,(VEC_SIZE * 2)(%rdi), %YMM3, %k1
-	kmovd	%k1, %eax
-	testl	%eax, %eax
+	VMOVU	(VEC_SIZE * 2)(%rsi), %VMM(3)
+	VEC_CMP	(VEC_SIZE * 2)(%rdi), %VMM(3), %k1
+	KMOV	%k1, %VRAX
+	TEST_ZERO_VCMP (rax)
 	jnz	L(return_neq0)
 
-	VMOVU	(VEC_SIZE * 3)(%rsi), %YMM4
-	VPCMP	$4,(VEC_SIZE * 3)(%rdi), %YMM4, %k1
-	kmovd	%k1, %eax
-	testl	%eax, %eax
+	VMOVU	(VEC_SIZE * 3)(%rsi), %VMM(4)
+	VEC_CMP	(VEC_SIZE * 3)(%rdi), %VMM(4), %k1
+	KMOV	%k1, %VRAX
+	TEST_ZERO_VCMP (rax)
 	jnz	L(return_neq0)
 
 	/* Go to 4x VEC loop. */
@@ -136,8 +172,8 @@ L(more_1x_vec):
 	/* Handle remainder of size = 4 * VEC + 1 to 8 * VEC without any
	   branches. */
 
-	VMOVU	-(VEC_SIZE * 4)(%rsi, %rdx), %YMM1
-	VMOVU	-(VEC_SIZE * 3)(%rsi, %rdx), %YMM2
+	VMOVU	-(VEC_SIZE * 1)(%rsi, %rdx), %VMM(1)
+	VMOVU	-(VEC_SIZE * 2)(%rsi, %rdx), %VMM(2)
 	addq	%rdx, %rdi
 
 	/* Wait to load from s1 until address adjustment due to
@@ -145,26 +181,32 @@ L(more_1x_vec):
 	/* vpxor will be all 0s if s1 and s2 are equal. Otherwise it
	   will have some 1s. */
-	vpxorq	-(VEC_SIZE * 4)(%rdi), %YMM1, %YMM1
-	/* Ternary logic to xor -(VEC_SIZE * 3)(%rdi) with YMM2 while
-	   oring with YMM1. Result is stored in YMM1. */
-	vpternlogd $0xde, -(VEC_SIZE * 3)(%rdi), %YMM1, %YMM2
-
-	VMOVU	-(VEC_SIZE * 2)(%rsi, %rdx), %YMM3
-	vpxorq	-(VEC_SIZE * 2)(%rdi), %YMM3, %YMM3
-	/* Or together YMM1, YMM2, and YMM3 into YMM3. */
-	VMOVU	-(VEC_SIZE)(%rsi, %rdx), %YMM4
-	vpxorq	-(VEC_SIZE)(%rdi), %YMM4, %YMM4
-
-	/* Or together YMM2, YMM3, and YMM4 into YMM4. */
-	vpternlogd $0xfe, %YMM2, %YMM3, %YMM4
-
-	/* Compare YMM4 with 0. If any 1s s1 and s2 don't match. */
-	VPTEST	%YMM4, %YMM4, %k1
-	kmovd	%k1, %eax
+	vpxorq	-(VEC_SIZE * 1)(%rdi), %VMM(1), %VMM(1)
+	/* Ternary logic to xor -(VEC_SIZE * 2)(%rdi) with VEC(2) while
+	   oring with VEC(1). Result is stored in VEC(2). */
+	vpternlogd $0xde, -(VEC_SIZE * 2)(%rdi), %VMM(1), %VMM(2)
+
+	cmpl	$(VEC_SIZE * 6), %edx
+	jbe	L(4x_last_2x_vec)
+
+	VMOVU	-(VEC_SIZE * 3)(%rsi, %rdx), %VMM(3)
+	vpxorq	-(VEC_SIZE * 3)(%rdi), %VMM(3), %VMM(3)
+	/* Or together VEC(1), VEC(2), and VEC(3) into VEC(3). */
+	VMOVU	-(VEC_SIZE * 4)(%rsi, %rdx), %VMM(4)
+	vpxorq	-(VEC_SIZE * 4)(%rdi), %VMM(4), %VMM(4)
+
+	/* Or together VEC(4), VEC(3), and VEC(2) into VEC(2). */
+	vpternlogd $0xfe, %VMM(4), %VMM(3), %VMM(2)
+
+	/* Compare VEC(2) with 0. If any 1s s1 and s2 don't match. */
+L(4x_last_2x_vec):
+	VPTEST	%VMM(2), %VMM(2), %k1
+	KMOV	%k1, %VRAX
+	TO_32BIT (VRAX)
 	ret
 
-	.p2align 4
+	.p2align 4,, 10
 L(more_8x_vec):
 	/* Set end of s1 in rdx. */
 	leaq	-(VEC_SIZE * 4)(%rdi, %rdx), %rdx
@@ -175,67 +217,80 @@ L(more_8x_vec):
 	andq	$-VEC_SIZE, %rdi
 	/* Adjust because first 4x vec were checked already. */
 	subq	$-(VEC_SIZE * 4), %rdi
-	.p2align 4
+	.p2align 5,, 12
+	.p2align 4,, 8
 L(loop_4x_vec):
-	VMOVU	(%rsi, %rdi), %YMM1
-	vpxorq	(%rdi), %YMM1, %YMM1
-
-	VMOVU	VEC_SIZE(%rsi, %rdi), %YMM2
-	vpternlogd $0xde,(VEC_SIZE)(%rdi), %YMM1, %YMM2
-
-	VMOVU	(VEC_SIZE * 2)(%rsi, %rdi), %YMM3
-	vpxorq	(VEC_SIZE * 2)(%rdi), %YMM3, %YMM3
-
-	VMOVU	(VEC_SIZE * 3)(%rsi, %rdi), %YMM4
-	vpxorq	(VEC_SIZE * 3)(%rdi), %YMM4, %YMM4
-
-	vpternlogd $0xfe, %YMM2, %YMM3, %YMM4
-	VPTEST	%YMM4, %YMM4, %k1
-	kmovd	%k1, %eax
-	testl	%eax, %eax
+	VMOVU	(%rsi, %rdi), %VMM(1)
+	vpxorq	(%rdi), %VMM(1), %VMM(1)
+
+	VMOVU	VEC_SIZE(%rsi, %rdi), %VMM(2)
+	vpternlogd $0xde, (VEC_SIZE)(%rdi), %VMM(1), %VMM(2)
+
+	VMOVU	(VEC_SIZE * 2)(%rsi, %rdi), %VMM(3)
+	vpxorq	(VEC_SIZE * 2)(%rdi), %VMM(3), %VMM(3)
+
+	VMOVU	(VEC_SIZE * 3)(%rsi, %rdi), %VMM(4)
+	vpxorq	(VEC_SIZE * 3)(%rdi), %VMM(4), %VMM(4)
+
+	vpternlogd $0xfe, %VMM(2), %VMM(3), %VMM(4)
+	VPTEST	%VMM(4), %VMM(4), %k1
+	KMOV	%k1, %VRAX
+	TEST_ZERO (rax)
 	jnz	L(return_neq2)
 	subq	$-(VEC_SIZE * 4), %rdi
 	cmpq	%rdx, %rdi
 	jb	L(loop_4x_vec)
 
 	subq	%rdx, %rdi
-	VMOVU	(VEC_SIZE * 3)(%rsi, %rdx), %YMM4
-	vpxorq	(VEC_SIZE * 3)(%rdx), %YMM4, %YMM4
+
+	VMOVU	(VEC_SIZE * 3)(%rsi, %rdx), %VMM(4)
+	vpxorq	(VEC_SIZE * 3)(%rdx), %VMM(4), %VMM(4)
+
 	/* rdi has 4 * VEC_SIZE - remaining length. */
+	cmpl	$(VEC_SIZE * 3), %edi
+	jae	L(8x_last_1x_vec)
+
 	/* Load regardless of branch. */
-	VMOVU	(VEC_SIZE * 2)(%rsi, %rdx), %YMM3
-	/* Ternary logic to xor (VEC_SIZE * 2)(%rdx) with YMM3 while
-	   oring with YMM4. Result is stored in YMM4. */
-	vpternlogd $0xf6,(VEC_SIZE * 2)(%rdx), %YMM3, %YMM4
+	VMOVU	(VEC_SIZE * 2)(%rsi, %rdx), %VMM(3)
+	/* Ternary logic to xor (VEC_SIZE * 2)(%rdx) with VEC(3) while
+	   oring with VEC(4). Result is stored in VEC(4). */
+	vpternlogd $0xf6, (VEC_SIZE * 2)(%rdx), %VMM(3), %VMM(4)
+
+	/* Separate logic as we can only use testb for VEC_SIZE == 64.
+	 */
+# if VEC_SIZE == 64
+	testb	%dil, %dil
+	js	L(8x_last_2x_vec)
+# else
 	cmpl	$(VEC_SIZE * 2), %edi
-	jae	L(8x_last_2x_vec)
+	jge	L(8x_last_2x_vec)
+# endif
 
-	VMOVU	VEC_SIZE(%rsi, %rdx), %YMM2
-	vpxorq	VEC_SIZE(%rdx), %YMM2, %YMM2
+	VMOVU	VEC_SIZE(%rsi, %rdx), %VMM(2)
+	vpxorq	VEC_SIZE(%rdx), %VMM(2), %VMM(2)
 
-	VMOVU	(%rsi, %rdx), %YMM1
-	vpxorq	(%rdx), %YMM1, %YMM1
+	VMOVU	(%rsi, %rdx), %VMM(1)
+	vpxorq	(%rdx), %VMM(1), %VMM(1)
 
-	vpternlogd $0xfe, %YMM1, %YMM2, %YMM4
+	vpternlogd $0xfe, %VMM(1), %VMM(2), %VMM(4)
 L(8x_last_1x_vec):
 L(8x_last_2x_vec):
-	VPTEST	%YMM4, %YMM4, %k1
-	kmovd	%k1, %eax
+	VPTEST	%VMM(4), %VMM(4), %k1
+	KMOV	%k1, %VRAX
+	TO_32BIT_P1 (rax)
 L(return_neq2):
+	TO_32BIT_P2 (rax)
 	ret
 
-	.p2align 4,, 8
+	.p2align 4,, 4
 L(last_2x_vec):
-	VMOVU	-(VEC_SIZE * 2)(%rsi, %rdx), %YMM1
-	vpxorq	-(VEC_SIZE * 2)(%rdi, %rdx), %YMM1, %YMM1
-	VMOVU	-(VEC_SIZE * 1)(%rsi, %rdx), %YMM2
-	vpternlogd $0xde, -(VEC_SIZE * 1)(%rdi, %rdx), %YMM1, %YMM2
-	VPTEST	%YMM2, %YMM2, %k1
-	kmovd	%k1, %eax
+	VMOVU	-(VEC_SIZE * 2)(%rsi, %rdx), %VMM(1)
+	vpxorq	-(VEC_SIZE * 2)(%rdi, %rdx), %VMM(1), %VMM(1)
+	VMOVU	-(VEC_SIZE * 1)(%rsi, %rdx), %VMM(2)
+	vpternlogd $0xde, -(VEC_SIZE * 1)(%rdi, %rdx), %VMM(1), %VMM(2)
+	VPTEST	%VMM(2), %VMM(2), %k1
+	KMOV	%k1, %VRAX
+	TO_32BIT (VRAX)
 	ret
 
-	/* 1 Bytes from next cache line. */
+	/* evex256: 1 byte from next cache line.  evex512: 15 bytes from
+	   next cache line. */
 END (MEMCMPEQ)
 #endif
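
For reference, a minimal sketch (not part of the patch) of how the two
VEC_SIZE == 64 mask reductions defined above turn a 64-bit k-mask of
mismatching bytes into the 32-bit int return value:

	/* Branched path: TEST_ZERO/TO_32BIT_P1 then TO_32BIT_P2.
	   `neg` sets CF iff the mask is nonzero (and ZF for the branch);
	   `sbb` then materializes 0 or -1 as a 32-bit value.  */
	kmovq	%k1, %rax
	neg	%rax
	sbb	%eax, %eax
	ret

	/* Unbranched path: TO_32BIT.  popcnt of the 64-bit mask is in
	   0..64, so it fits in 32 bits and is nonzero iff some byte
	   differed.  */
	kmovq	%k1, %rax
	popcntq	%rax, %rax
	ret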