[bigint] Fix bugs in FFT multiplication

A single ClusterFuzz report flushed out two minor issues in the
bit fiddling routines.

Bug: chromium:1227752,v8:11515
Change-Id: I16ab914b7c3859f55aa141ced371dd80171d0cb5
Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/3017809
Reviewed-by: Maya Lekova <mslekova@chromium.org>
Commit-Queue: Jakob Kummerow <jkummerow@chromium.org>
Cr-Commit-Position: refs/heads/master@{#75678}
This commit is contained in:
Jakob Kummerow 2021-07-11 17:44:46 +02:00 committed by V8 LUCI CQ
parent 0665568de4
commit 6018d479b6
2 changed files with 39 additions and 17 deletions

View File

@ -67,7 +67,8 @@ void ModFnDoubleWidth(digit_t* dest, const digit_t* src, int len) {
for (int i = 0; i < K; i++) { for (int i = 0; i < K; i++) {
dest[i] = digit_sub2(src[i], src[i + K], borrow, &borrow); dest[i] = digit_sub2(src[i], src[i + K], borrow, &borrow);
} }
dest[K] = digit_sub(0, borrow, &borrow); dest[K] = digit_sub2(0, src[2 * K], borrow, &borrow);
// {borrow} may be non-zero here, that's OK as {ModFn} will take care of it.
ModFn(dest, len); ModFn(dest, len);
} }
@ -191,43 +192,57 @@ void ShiftModFn(digit_t* result, const digit_t* input, int power_of_two, int K,
// it turns out that: // it turns out that:
// x * 2^{2K+m} == x * 2^m mod 2^K + 1. // x * 2^{2K+m} == x * 2^m mod 2^K + 1.
while (digit_shift >= 2 * K) digit_shift -= 2 * K; // Faster than '%'! while (digit_shift >= 2 * K) digit_shift -= 2 * K; // Faster than '%'!
digit_t borrow = 0;
if (digit_shift >= K) { if (digit_shift >= K) {
return ShiftModFn_Large(result, input, digit_shift, bits_shift, K); return ShiftModFn_Large(result, input, digit_shift, bits_shift, K);
} }
digit_t borrow = 0;
if (bits_shift == 0) { if (bits_shift == 0) {
// We do a single pass over {input}, starting by copying digits [i1] to
// [iX-1] to result indices digit_shift+1 to K-1.
int i = 1; int i = 1;
// Regular loop: read input digits unless we know they are zero. // Read input digits unless we know they are zero.
int cap = std::min(K - digit_shift, zero_above); int cap = std::min(K - digit_shift, zero_above);
for (; i < cap; i++) { for (; i < cap; i++) {
result[i + digit_shift] = input[i]; result[i + digit_shift] = input[i];
} }
// Any remaining work can hard-code the knowledge that input[i] == 0.
for (; i < K - digit_shift; i++) {
DCHECK(input[i] == 0); // NOLINT(readability/check)
result[i + digit_shift] = 0;
}
// Second phase: subtract input digits [iX] to [iK] from (virtually) zero-
// initialized result indices 0 to digit_shift-1.
cap = std::min(K, zero_above); cap = std::min(K, zero_above);
for (; i < cap; i++) { for (; i < cap; i++) {
digit_t d = input[i]; digit_t d = input[i];
result[i - K + digit_shift] = digit_sub2(0, d, borrow, &borrow); result[i - K + digit_shift] = digit_sub2(0, d, borrow, &borrow);
} }
// Fallthrough: any remaining work can hard-code the knowledge that // Any remaining work can hard-code the knowledge that input[i] == 0.
// input[i] == 0.
for (; i < K - digit_shift; i++) {
DCHECK(input[i] == 0); // NOLINT(readability/check)
result[i + digit_shift] = 0;
}
for (; i < K; i++) { for (; i < K; i++) {
DCHECK(input[i] == 0); // NOLINT(readability/check) DCHECK(input[i] == 0); // NOLINT(readability/check)
result[i - K + digit_shift] = digit_sub(0, borrow, &borrow); result[i - K + digit_shift] = digit_sub(0, borrow, &borrow);
} }
// Last step: subtract [iK] from [i0] and store at result index digit_shift.
result[digit_shift] = digit_sub2(input[0], input[K], borrow, &borrow); result[digit_shift] = digit_sub2(input[0], input[K], borrow, &borrow);
} else { } else {
// Same flow as before, but taking bits_shift != 0 into account.
// First phase: result indices digit_shift+1 to K.
digit_t carry = 0; digit_t carry = 0;
int i = 0; int i = 0;
// Regular loop: read input digits unless we know they are zero. // Read input digits unless we know they are zero.
int cap = std::min(K - digit_shift, zero_above); int cap = std::min(K - digit_shift, zero_above);
for (; i < cap; i++) { for (; i < cap; i++) {
digit_t d = input[i]; digit_t d = input[i];
result[i + digit_shift] = (d << bits_shift) | carry; result[i + digit_shift] = (d << bits_shift) | carry;
carry = d >> (kDigitBits - bits_shift); carry = d >> (kDigitBits - bits_shift);
} }
// Any remaining work can hard-code the knowledge that input[i] == 0.
for (; i < K - digit_shift; i++) {
DCHECK(input[i] == 0); // NOLINT(readability/check)
result[i + digit_shift] = carry;
carry = 0;
}
// Second phase: result indices 0 to digit_shift - 1.
cap = std::min(K, zero_above); cap = std::min(K, zero_above);
for (; i < cap; i++) { for (; i < cap; i++) {
digit_t d = input[i]; digit_t d = input[i];
@ -235,13 +250,7 @@ void ShiftModFn(digit_t* result, const digit_t* input, int power_of_two, int K,
digit_sub2(0, (d << bits_shift) | carry, borrow, &borrow); digit_sub2(0, (d << bits_shift) | carry, borrow, &borrow);
carry = d >> (kDigitBits - bits_shift); carry = d >> (kDigitBits - bits_shift);
} }
// Fallthrough: any remaining work can hard-code the knowledge that // Any remaining work can hard-code the knowledge that input[i] == 0.
// input[i] == 0.
for (; i < K - digit_shift; i++) {
DCHECK(input[i] == 0); // NOLINT(readability/check)
result[i + digit_shift] = carry;
carry = 0;
}
if (i < K) { if (i < K) {
DCHECK(input[i] == 0); // NOLINT(readability/check) DCHECK(input[i] == 0); // NOLINT(readability/check)
result[i - K + digit_shift] = digit_sub2(0, carry, borrow, &borrow); result[i - K + digit_shift] = digit_sub2(0, carry, borrow, &borrow);
@ -252,6 +261,7 @@ void ShiftModFn(digit_t* result, const digit_t* input, int power_of_two, int K,
DCHECK(input[i] == 0); // NOLINT(readability/check) DCHECK(input[i] == 0); // NOLINT(readability/check)
result[i - K + digit_shift] = digit_sub(0, borrow, &borrow); result[i - K + digit_shift] = digit_sub(0, borrow, &borrow);
} }
// Last step: compute result[digit_shift].
digit_t d = input[K]; digit_t d = input[K];
result[digit_shift] = digit_sub2( result[digit_shift] = digit_sub2(
result[digit_shift], (d << bits_shift) | carry, borrow, &borrow); result[digit_shift], (d << bits_shift) | carry, borrow, &borrow);

View File

@ -0,0 +1,12 @@
// Copyright 2021 the V8 project authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
function regress_1227752(power) {
let a = 2n ** power;
let a_squared = a * a;
let expected = 2n ** (2n * power);
assertEquals(expected, a_squared);
}
regress_1227752(48016n); // This triggered the bug on 32-bit platforms.
regress_1227752(95960n); // This triggered the bug on 64-bit platforms.