x86: Use VMM API in memcmpeq-evex.S and minor changes

Changes to generated code are:
    1. In a few places use `vpcmpeqb` instead of `vpcmpneqb` to save a
       byte of code size (see the encoding note after this list).
    2. Add a branch for length <= (VEC_SIZE * 6) as opposed to doing
       the entire block of [VEC_SIZE * 4 + 1, VEC_SIZE * 8] in a
       single basic block (the space for the extra branch, without
       changing code size, is bought by the above change).
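
As a sketch of why item (1) saves a byte (the exact counts are my
reading of the EVEX encoding, not measurements from the patch):
`vpcmpeqb` has a dedicated opcode and takes no immediate, while
`vpcmpneqb` is an alias for `vpcmpb` with comparison predicate 4 and
therefore carries an extra imm8:

	/* 6 bytes: 4-byte EVEX prefix + opcode + ModRM.  */
	vpcmpeqb  (%rdi), %ymm16, %k1
	/* 7 bytes: the same plus the imm8 predicate; this is what
	   vpcmpneqb (%rdi), %ymm16, %k1 assembles to.  */
	vpcmpb	$4, (%rdi), %ymm16, %k1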

Change (2) gives roughly a 20-25% speedup for sizes in [VEC_SIZE * 4 +
1, VEC_SIZE * 6] and has negligible to no cost for [VEC_SIZE * 6 + 1,
VEC_SIZE * 8].
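
For reference, the new tail path for this range, pieced together from
the hunks below (a sketch only; macro spellings as in the patch):

	/* [VEC_SIZE * 4 + 1, VEC_SIZE * 8]: always diff the last two
	   vectors using overlapping loads from the end.  */
	VMOVU	-(VEC_SIZE * 1)(%rsi, %rdx), %VMM(1)
	VMOVU	-(VEC_SIZE * 2)(%rsi, %rdx), %VMM(2)
	addq	%rdx, %rdi
	vpxorq	-(VEC_SIZE * 1)(%rdi), %VMM(1), %VMM(1)
	vpternlogd $0xde, -(VEC_SIZE * 2)(%rdi), %VMM(1), %VMM(2)
	/* The new branch: lengths <= VEC_SIZE * 6 skip the third and
	   fourth tail vectors entirely.  */
	cmpl	$(VEC_SIZE * 6), %edx
	jbe	L(4x_last_2x_vec)
	...	/* diff vectors 3 and 4 and fold them into VMM(2).  */
L(4x_last_2x_vec):
	VPTEST	%VMM(2), %VMM(2), %k1
	KMOV	%k1, %VRAX
	TO_32BIT (VRAX)
	ret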

From N=10 runs on Tigerlake:

align1,align2 ,length ,result               ,New Time ,Cur Time ,New Time / Cur Time
0     ,0      ,129    ,0                    ,5.404    ,6.887    ,0.785
0     ,0      ,129    ,1                    ,5.308    ,6.826    ,0.778
0     ,0      ,129    ,18446744073709551615 ,5.359    ,6.823    ,0.785
0     ,0      ,161    ,0                    ,5.284    ,6.827    ,0.774
0     ,0      ,161    ,1                    ,5.317    ,6.745    ,0.788
0     ,0      ,161    ,18446744073709551615 ,5.406    ,6.778    ,0.798

0     ,0      ,193    ,0                    ,6.804    ,6.802    ,1.000
0     ,0      ,193    ,1                    ,6.950    ,6.754    ,1.029
0     ,0      ,193    ,18446744073709551615 ,6.792    ,6.719    ,1.011
0     ,0      ,225    ,0                    ,6.625    ,6.699    ,0.989
0     ,0      ,225    ,1                    ,6.776    ,6.735    ,1.003
0     ,0      ,225    ,18446744073709551615 ,6.758    ,6.738    ,0.992
0     ,0      ,256    ,0                    ,5.402    ,5.462    ,0.989
0     ,0      ,256    ,1                    ,5.364    ,5.483    ,0.978
0     ,0      ,256    ,18446744073709551615 ,5.341    ,5.539    ,0.964

Rewriting with the VMM API allows memcmpeq-evex to be used for
evex512 by including "x86-evex512-vecs.h" at the top of the file.
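
My reading of what those macros expand to under the two
configurations (the exact register numbers and kmov width come from
the x86-evex*-vecs.h/reg-macros.h headers, not from this patch), for a
line like the first compare:

	/* Size-agnostic source:  */
	VMOVU	(%rsi), %VMM(1)
	VPCMPNEQ (%rdi), %VMM(1), %k1
	KMOV	%k1, %VRAX

	/* evex256 ("x86-evex256-vecs.h", VEC_SIZE == 32):  */
	vmovdqu64 (%rsi), %ymm17
	vpcmpneqb (%rdi), %ymm17, %k1
	kmovd	%k1, %eax

	/* evex512 ("x86-evex512-vecs.h", VEC_SIZE == 64):  */
	vmovdqu64 (%rsi), %zmm17
	vpcmpneqb (%rdi), %zmm17, %k1
	kmovq	%k1, %rax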

Complete check passes on x86-64.
Author: Noah Goldstein
Date:   2022-10-29 15:19:59 -05:00
Parent: 419c832aba
Commit: 2d2493a644

@@ -41,24 +41,53 @@
# define MEMCMPEQ __memcmpeq_evex
# endif
# ifndef VEC_SIZE
# include "x86-evex512-vecs.h"
# endif
# include "reg-macros.h"
# if VEC_SIZE == 32
# define TEST_ZERO_VCMP(reg) inc %VGPR(reg)
# define TEST_ZERO(reg) test %VGPR(reg), %VGPR(reg)
# define TO_32BIT_P1(reg) /* Do nothing. */
# define TO_32BIT_P2(reg) /* Do nothing. */
# define TO_32BIT(reg) /* Do nothing. */
# define VEC_CMP VPCMPEQ
# elif VEC_SIZE == 64
# define TEST_ZERO_VCMP(reg) TEST_ZERO(reg)
# define TEST_ZERO(reg) neg %VGPR(reg)
/* VEC_SIZE == 64 needs to reduce the 64-bit mask to a 32-bit
int. We have two methods for this. If the mask was branched
on, we use `neg` for the branch then `sbb` to get the 32-bit
return. If the mask was not branched on, we just use
`popcntq`. */
# define TO_32BIT_P1(reg) TEST_ZERO(reg)
# define TO_32BIT_P2(reg) sbb %VGPR_SZ(reg, 32), %VGPR_SZ(reg, 32)
# define TO_32BIT(reg) popcntq %reg, %reg
# define VEC_CMP VPCMPNEQ
# else
# error "Unsupported VEC_SIZE"
# endif
# define VMOVU_MASK vmovdqu8
# define VMOVU vmovdqu64
# define VPCMP vpcmpub
# define VPCMPNEQ vpcmpneqb
# define VPCMPEQ vpcmpeqb
# define VPTEST vptestmb
# define VEC_SIZE 32
# define PAGE_SIZE 4096
# define YMM0 ymm16
# define YMM1 ymm17
# define YMM2 ymm18
# define YMM3 ymm19
# define YMM4 ymm20
# define YMM5 ymm21
# define YMM6 ymm22
.section .text.evex, "ax", @progbits
.section SECTION(.text), "ax", @progbits
ENTRY_P2ALIGN (MEMCMPEQ, 6)
# ifdef __ILP32__
/* Clear the upper 32 bits. */
@@ -69,47 +98,54 @@ ENTRY_P2ALIGN (MEMCMPEQ, 6)
ja L(more_1x_vec)
/* Create mask of bytes that are guaranteed to be valid because
of length (edx). Using masked movs allows us to skip checks for
page crosses/zero size. */
movl $-1, %ecx
bzhil %edx, %ecx, %ecx
kmovd %ecx, %k2
of length (edx). Using masked movs allows us to skip checks
for page crosses/zero size. */
mov $-1, %VRAX
bzhi %VRDX, %VRAX, %VRAX
/* NB: A `jz` might be useful here. Page-faults that are
invalidated by predicate execution (the evex mask) can be
very slow. The expectation is that this is not the norm and
"most" code will not regularly call 'memcmp' with length = 0
and memory that is not wired up. */
KMOV %VRAX, %k2
/* Use masked loads as VEC_SIZE could page cross where length
(edx) would not. */
VMOVU_MASK (%rsi), %YMM2{%k2}
VPCMP $4,(%rdi), %YMM2, %k1{%k2}
kmovd %k1, %eax
VMOVU_MASK (%rsi), %VMM(2){%k2}{z}
VPCMPNEQ (%rdi), %VMM(2), %k1{%k2}
KMOV %k1, %VRAX
TO_32BIT (VRAX)
ret
.p2align 4,, 3
L(last_1x_vec):
VMOVU -(VEC_SIZE * 1)(%rsi, %rdx), %YMM1
VPCMP $4, -(VEC_SIZE * 1)(%rdi, %rdx), %YMM1, %k1
kmovd %k1, %eax
VMOVU -(VEC_SIZE * 1)(%rsi, %rdx), %VMM(1)
VPCMPNEQ -(VEC_SIZE * 1)(%rdi, %rdx), %VMM(1), %k1
KMOV %k1, %VRAX
TO_32BIT_P1 (rax)
L(return_neq0):
TO_32BIT_P2 (rax)
ret
.p2align 4
.p2align 4,, 12
L(more_1x_vec):
/* From VEC + 1 to 2 * VEC. */
VMOVU (%rsi), %YMM1
VMOVU (%rsi), %VMM(1)
/* Use compare not equals to directly check for mismatch. */
VPCMP $4,(%rdi), %YMM1, %k1
kmovd %k1, %eax
testl %eax, %eax
VPCMPNEQ (%rdi), %VMM(1), %k1
KMOV %k1, %VRAX
TEST_ZERO (rax)
jnz L(return_neq0)
cmpq $(VEC_SIZE * 2), %rdx
jbe L(last_1x_vec)
/* Check second VEC no matter what. */
VMOVU VEC_SIZE(%rsi), %YMM2
VPCMP $4, VEC_SIZE(%rdi), %YMM2, %k1
kmovd %k1, %eax
testl %eax, %eax
VMOVU VEC_SIZE(%rsi), %VMM(2)
VPCMPNEQ VEC_SIZE(%rdi), %VMM(2), %k1
KMOV %k1, %VRAX
TEST_ZERO (rax)
jnz L(return_neq0)
/* Less than 4 * VEC. */
@@ -117,16 +153,16 @@ L(more_1x_vec):
jbe L(last_2x_vec)
/* Check third and fourth VEC no matter what. */
VMOVU (VEC_SIZE * 2)(%rsi), %YMM3
VPCMP $4,(VEC_SIZE * 2)(%rdi), %YMM3, %k1
kmovd %k1, %eax
testl %eax, %eax
VMOVU (VEC_SIZE * 2)(%rsi), %VMM(3)
VEC_CMP (VEC_SIZE * 2)(%rdi), %VMM(3), %k1
KMOV %k1, %VRAX
TEST_ZERO_VCMP (rax)
jnz L(return_neq0)
VMOVU (VEC_SIZE * 3)(%rsi), %YMM4
VPCMP $4,(VEC_SIZE * 3)(%rdi), %YMM4, %k1
kmovd %k1, %eax
testl %eax, %eax
VMOVU (VEC_SIZE * 3)(%rsi), %VMM(4)
VEC_CMP (VEC_SIZE * 3)(%rdi), %VMM(4), %k1
KMOV %k1, %VRAX
TEST_ZERO_VCMP (rax)
jnz L(return_neq0)
/* Go to 4x VEC loop. */
@@ -136,8 +172,8 @@ L(more_1x_vec):
/* Handle remainder of size = 4 * VEC + 1 to 8 * VEC without any
branches. */
VMOVU -(VEC_SIZE * 4)(%rsi, %rdx), %YMM1
VMOVU -(VEC_SIZE * 3)(%rsi, %rdx), %YMM2
VMOVU -(VEC_SIZE * 1)(%rsi, %rdx), %VMM(1)
VMOVU -(VEC_SIZE * 2)(%rsi, %rdx), %VMM(2)
addq %rdx, %rdi
/* Wait to load from s1 until addressed adjust due to
@@ -145,26 +181,32 @@ L(more_1x_vec):
/* vpxor will be all 0s if s1 and s2 are equal. Otherwise it
will have some 1s. */
vpxorq -(VEC_SIZE * 4)(%rdi), %YMM1, %YMM1
/* Ternary logic to xor -(VEC_SIZE * 3)(%rdi) with YMM2 while
oring with YMM1. Result is stored in YMM1. */
vpternlogd $0xde, -(VEC_SIZE * 3)(%rdi), %YMM1, %YMM2
vpxorq -(VEC_SIZE * 1)(%rdi), %VMM(1), %VMM(1)
/* Ternary logic to xor -(VEC_SIZE * 3)(%rdi) with VEC(2) while
oring with VEC(1). Result is stored in VEC(1). */
vpternlogd $0xde, -(VEC_SIZE * 2)(%rdi), %VMM(1), %VMM(2)
VMOVU -(VEC_SIZE * 2)(%rsi, %rdx), %YMM3
vpxorq -(VEC_SIZE * 2)(%rdi), %YMM3, %YMM3
/* Or together YMM1, YMM2, and YMM3 into YMM3. */
VMOVU -(VEC_SIZE)(%rsi, %rdx), %YMM4
vpxorq -(VEC_SIZE)(%rdi), %YMM4, %YMM4
cmpl $(VEC_SIZE * 6), %edx
jbe L(4x_last_2x_vec)
/* Or together YMM2, YMM3, and YMM4 into YMM4. */
vpternlogd $0xfe, %YMM2, %YMM3, %YMM4
VMOVU -(VEC_SIZE * 3)(%rsi, %rdx), %VMM(3)
vpxorq -(VEC_SIZE * 3)(%rdi), %VMM(3), %VMM(3)
/* Or together VEC(1), VEC(2), and VEC(3) into VEC(3). */
VMOVU -(VEC_SIZE * 4)(%rsi, %rdx), %VMM(4)
vpxorq -(VEC_SIZE * 4)(%rdi), %VMM(4), %VMM(4)
/* Compare YMM4 with 0. If any 1s s1 and s2 don't match. */
VPTEST %YMM4, %YMM4, %k1
kmovd %k1, %eax
/* Or together VEC(4), VEC(3), and VEC(2) into VEC(2). */
vpternlogd $0xfe, %VMM(4), %VMM(3), %VMM(2)
/* Compare VEC(4) with 0. If any 1s s1 and s2 don't match. */
L(4x_last_2x_vec):
VPTEST %VMM(2), %VMM(2), %k1
KMOV %k1, %VRAX
TO_32BIT (VRAX)
ret
.p2align 4
.p2align 4,, 10
L(more_8x_vec):
/* Set end of s1 in rdx. */
leaq -(VEC_SIZE * 4)(%rdi, %rdx), %rdx
@@ -175,67 +217,80 @@ L(more_8x_vec):
andq $-VEC_SIZE, %rdi
/* Adjust because first 4x vec were checked already. */
subq $-(VEC_SIZE * 4), %rdi
.p2align 4
.p2align 5,, 12
.p2align 4,, 8
L(loop_4x_vec):
VMOVU (%rsi, %rdi), %YMM1
vpxorq (%rdi), %YMM1, %YMM1
VMOVU (%rsi, %rdi), %VMM(1)
vpxorq (%rdi), %VMM(1), %VMM(1)
VMOVU VEC_SIZE(%rsi, %rdi), %YMM2
vpternlogd $0xde,(VEC_SIZE)(%rdi), %YMM1, %YMM2
VMOVU VEC_SIZE(%rsi, %rdi), %VMM(2)
vpternlogd $0xde, (VEC_SIZE)(%rdi), %VMM(1), %VMM(2)
VMOVU (VEC_SIZE * 2)(%rsi, %rdi), %YMM3
vpxorq (VEC_SIZE * 2)(%rdi), %YMM3, %YMM3
VMOVU (VEC_SIZE * 2)(%rsi, %rdi), %VMM(3)
vpxorq (VEC_SIZE * 2)(%rdi), %VMM(3), %VMM(3)
VMOVU (VEC_SIZE * 3)(%rsi, %rdi), %YMM4
vpxorq (VEC_SIZE * 3)(%rdi), %YMM4, %YMM4
VMOVU (VEC_SIZE * 3)(%rsi, %rdi), %VMM(4)
vpxorq (VEC_SIZE * 3)(%rdi), %VMM(4), %VMM(4)
vpternlogd $0xfe, %YMM2, %YMM3, %YMM4
VPTEST %YMM4, %YMM4, %k1
kmovd %k1, %eax
testl %eax, %eax
vpternlogd $0xfe, %VMM(2), %VMM(3), %VMM(4)
VPTEST %VMM(4), %VMM(4), %k1
KMOV %k1, %VRAX
TEST_ZERO (rax)
jnz L(return_neq2)
subq $-(VEC_SIZE * 4), %rdi
cmpq %rdx, %rdi
jb L(loop_4x_vec)
subq %rdx, %rdi
VMOVU (VEC_SIZE * 3)(%rsi, %rdx), %YMM4
vpxorq (VEC_SIZE * 3)(%rdx), %YMM4, %YMM4
VMOVU (VEC_SIZE * 3)(%rsi, %rdx), %VMM(4)
vpxorq (VEC_SIZE * 3)(%rdx), %VMM(4), %VMM(4)
/* rdi has 4 * VEC_SIZE - remaining length. */
cmpl $(VEC_SIZE * 3), %edi
jae L(8x_last_1x_vec)
/* Load regardless of branch. */
VMOVU (VEC_SIZE * 2)(%rsi, %rdx), %YMM3
/* Ternary logic to xor (VEC_SIZE * 2)(%rdx) with YMM3 while
oring with YMM4. Result is stored in YMM4. */
vpternlogd $0xf6,(VEC_SIZE * 2)(%rdx), %YMM3, %YMM4
VMOVU (VEC_SIZE * 2)(%rsi, %rdx), %VMM(3)
/* Ternary logic to xor (VEC_SIZE * 2)(%rdx) with VEC(3) while
oring with VEC(4). Result is stored in VEC(4). */
vpternlogd $0xf6, (VEC_SIZE * 2)(%rdx), %VMM(3), %VMM(4)
/* Separate logic as we can only use testb for VEC_SIZE == 64.
*/
# if VEC_SIZE == 64
testb %dil, %dil
js L(8x_last_2x_vec)
# else
cmpl $(VEC_SIZE * 2), %edi
jae L(8x_last_2x_vec)
jge L(8x_last_2x_vec)
# endif
VMOVU VEC_SIZE(%rsi, %rdx), %YMM2
vpxorq VEC_SIZE(%rdx), %YMM2, %YMM2
VMOVU VEC_SIZE(%rsi, %rdx), %VMM(2)
vpxorq VEC_SIZE(%rdx), %VMM(2), %VMM(2)
VMOVU (%rsi, %rdx), %YMM1
vpxorq (%rdx), %YMM1, %YMM1
VMOVU (%rsi, %rdx), %VMM(1)
vpxorq (%rdx), %VMM(1), %VMM(1)
vpternlogd $0xfe, %YMM1, %YMM2, %YMM4
vpternlogd $0xfe, %VMM(1), %VMM(2), %VMM(4)
L(8x_last_1x_vec):
L(8x_last_2x_vec):
VPTEST %YMM4, %YMM4, %k1
kmovd %k1, %eax
VPTEST %VMM(4), %VMM(4), %k1
KMOV %k1, %VRAX
TO_32BIT_P1 (rax)
L(return_neq2):
TO_32BIT_P2 (rax)
ret
.p2align 4,, 8
.p2align 4,, 4
L(last_2x_vec):
VMOVU -(VEC_SIZE * 2)(%rsi, %rdx), %YMM1
vpxorq -(VEC_SIZE * 2)(%rdi, %rdx), %YMM1, %YMM1
VMOVU -(VEC_SIZE * 1)(%rsi, %rdx), %YMM2
vpternlogd $0xde, -(VEC_SIZE * 1)(%rdi, %rdx), %YMM1, %YMM2
VPTEST %YMM2, %YMM2, %k1
kmovd %k1, %eax
VMOVU -(VEC_SIZE * 2)(%rsi, %rdx), %VMM(1)
vpxorq -(VEC_SIZE * 2)(%rdi, %rdx), %VMM(1), %VMM(1)
VMOVU -(VEC_SIZE * 1)(%rsi, %rdx), %VMM(2)
vpternlogd $0xde, -(VEC_SIZE * 1)(%rdi, %rdx), %VMM(1), %VMM(2)
VPTEST %VMM(2), %VMM(2), %k1
KMOV %k1, %VRAX
TO_32BIT (VRAX)
ret
/* 1 byte from next cache line. */
/* evex256: 1 byte from next cache line. evex512: 15 bytes from
next cache line. */
END (MEMCMPEQ)
#endif
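
Footnote on the VEC_SIZE == 64 macros in the first hunk: because
__memcmpeq only returns zero/non-zero, the 64-bit mismatch mask never
has to be inspected, only reduced to a 32-bit int. A sketch of the two
reductions the macros implement:

	/* Branched case (TO_32BIT_P1 + TO_32BIT_P2): neg sets CF iff
	   the mask was non-zero, then sbb materializes 0 or -1.  */
	neg	%rax
	sbb	%eax, %eax

	/* Non-branched case (TO_32BIT): the popcount is non-zero iff
	   some byte differed and always fits in 32 bits.  */
	popcntq	%rax, %rax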