Reland "[bigint] FFT-based multiplication"

The Schönhage-Strassen method for *very* large inputs.

This is a reland of 347ba35716,
with added zero-initialization to pacify MSan (spurious report).

Originally:
> Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/3000742
> Commit-Queue: Jakob Kummerow <jkummerow@chromium.org>
> Reviewed-by: Maya Lekova <mslekova@chromium.org>
> Cr-Commit-Position: refs/heads/master@{#75659}

Bug: v8:11515
Change-Id: Ieac6e174bde6eb09af0a9a9a49969feabca79e81
Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/3018081
Reviewed-by: Maya Lekova <mslekova@chromium.org>
Commit-Queue: Jakob Kummerow <jkummerow@chromium.org>
Cr-Commit-Position: refs/heads/master@{#75663}
This commit is contained in:
Jakob Kummerow 2021-07-09 15:12:52 +02:00 committed by V8 LUCI CQ
parent 2a6b205594
commit afa6126921
7 changed files with 822 additions and 10 deletions

View File

@ -2510,6 +2510,7 @@ filegroup(
"src/bigint/div-helpers.cc",
"src/bigint/div-helpers.h",
"src/bigint/div-schoolbook.cc",
"src/bigint/mul-fft.cc",
"src/bigint/mul-karatsuba.cc",
"src/bigint/mul-schoolbook.cc",
"src/bigint/mul-toom.cc",

View File

@ -4964,7 +4964,10 @@ v8_source_set("v8_bigint") {
]
if (v8_advanced_bigint_algorithms) {
sources += [ "src/bigint/mul-toom.cc" ]
sources += [
"src/bigint/mul-fft.cc",
"src/bigint/mul-toom.cc",
]
defines = [ "V8_ADVANCED_BIGINT_ALGORITHMS" ]
}

View File

@ -35,7 +35,8 @@ void ProcessorImpl::Multiply(RWDigits Z, Digits X, Digits Y) {
return MultiplyKaratsuba(Z, X, Y);
#else
if (Y.len() < kToomThreshold) return MultiplyKaratsuba(Z, X, Y);
return MultiplyToomCook(Z, X, Y);
if (Y.len() < kFftThreshold) return MultiplyToomCook(Z, X, Y);
return MultiplyFFT(Z, X, Y);
#endif
}

View File

@ -14,6 +14,9 @@ namespace bigint {
constexpr int kKaratsubaThreshold = 34;
constexpr int kToomThreshold = 193;
constexpr int kFftThreshold = 1500;
constexpr int kFftInnerThreshold = 200;
constexpr int kBurnikelThreshold = 57;
class ProcessorImpl : public Processor {
@ -42,6 +45,8 @@ class ProcessorImpl : public Processor {
#if V8_ADVANCED_BIGINT_ALGORITHMS
void MultiplyToomCook(RWDigits Z, Digits X, Digits Y);
void Toom3Main(RWDigits Z, Digits X, Digits Y);
void MultiplyFFT(RWDigits Z, Digits X, Digits Y);
#endif // V8_ADVANCED_BIGINT_ALGORITHMS
// {out_length} initially contains the allocated capacity of {out}, and

769
src/bigint/mul-fft.cc Normal file
View File

@ -0,0 +1,769 @@
// Copyright 2021 the V8 project authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// FFT-based multiplication, due to Schönhage and Strassen.
// This implementation mostly follows the description given in:
// Christoph Lüders: Fast Multiplication of Large Integers,
// http://arxiv.org/abs/1503.04955
#include "src/bigint/bigint-internal.h"
#include "src/bigint/digit-arithmetic.h"
#include "src/bigint/util.h"
#include "src/bigint/vector-arithmetic.h"
namespace v8 {
namespace bigint {
namespace {
////////////////////////////////////////////////////////////////////////////////
// Part 1: Functions for "mod F_n" arithmetic.
// F_n is of the shape 2^K + 1, and for convenience we use K to count the
// number of digits rather than the number of bits, so F_n (or K) are implicit
// and deduced from the length {len} of the digits array.
// Helper function for {ModFn} below.
void ModFn_Helper(digit_t* x, int len, signed_digit_t high) {
if (high > 0) {
digit_t borrow = high;
x[len - 1] = 0;
for (int i = 0; i < len; i++) {
x[i] = digit_sub(x[i], borrow, &borrow);
if (borrow == 0) break;
}
} else {
digit_t carry = -high;
x[len - 1] = 0;
for (int i = 0; i < len; i++) {
x[i] = digit_add2(x[i], carry, &carry);
if (carry == 0) break;
}
}
}
// {x} := {x} mod F_n, assuming that {x} is "slightly" larger than F_n (e.g.
// after addition of two numbers that were mod-F_n-normalized before).
void ModFn(digit_t* x, int len) {
int K = len - 1;
signed_digit_t high = x[K];
if (high == 0) return;
ModFn_Helper(x, len, high);
high = x[K];
if (high == 0) return;
DCHECK(high == 1 || high == -1);
ModFn_Helper(x, len, high);
high = x[K];
if (high == -1) ModFn_Helper(x, len, high);
}
// {dest} := {src} mod F_n, assuming that {src} is about twice as long as F_n
// (e.g. after multiplication of two numbers that were mod-F_n-normalized
// before).
// {len} is length of {dest}; {src} is twice as long.
void ModFnDoubleWidth(digit_t* dest, const digit_t* src, int len) {
int K = len - 1;
digit_t borrow = 0;
for (int i = 0; i < K; i++) {
dest[i] = digit_sub2(src[i], src[i + K], borrow, &borrow);
}
dest[K] = digit_sub(0, borrow, &borrow);
ModFn(dest, len);
}
// Sets {sum} := {a} + {b} and {diff} := {a} - {b}, which is more efficient
// than computing sum and difference separately. Applies "mod F_n" normalization
// to both results.
void SumDiff(digit_t* sum, digit_t* diff, const digit_t* a, const digit_t* b,
int len) {
digit_t carry = 0;
digit_t borrow = 0;
for (int i = 0; i < len; i++) {
// Read both values first, because inputs and outputs can overlap.
digit_t ai = a[i];
digit_t bi = b[i];
sum[i] = digit_add3(ai, bi, carry, &carry);
diff[i] = digit_sub2(ai, bi, borrow, &borrow);
}
ModFn(sum, len);
ModFn(diff, len);
}
// {result} := ({input} << shift) mod F_n, where shift >= K.
void ShiftModFn_Large(digit_t* result, const digit_t* input, int digit_shift,
int bits_shift, int K) {
// If {digit_shift} is greater than K, we use the following transformation
// (where, since everything is mod 2^K + 1, we are allowed to add or
// subtract any multiple of 2^K + 1 at any time):
// x * 2^{K+m} mod 2^K + 1
// == x * 2^K * 2^m - (2^K + 1)*(x * 2^m) mod 2^K + 1
// == x * 2^K * 2^m - x * 2^K * 2^m - x * 2^m mod 2^K + 1
// == -x * 2^m mod 2^K + 1
// So the flow is the same as for m < K, but we invert the subtraction's
// operands. In order to avoid underflow, we virtually initialize the
// result to 2^K + 1:
// input = [ iK ][iK-1] .... .... [ i1 ][ i0 ]
// result = [ 1][0000] .... .... [0000][0001]
// + [ iK ] .... [ iX ]
// - [iX-1] .... [ i0 ]
DCHECK(digit_shift >= K);
digit_shift -= K;
digit_t borrow = 0;
if (bits_shift == 0) {
digit_t carry = 1;
for (int i = 0; i < digit_shift; i++) {
result[i] = digit_add2(input[i + K - digit_shift], carry, &carry);
}
result[digit_shift] = digit_sub(input[K] + carry, input[0], &borrow);
for (int i = digit_shift + 1; i < K; i++) {
digit_t d = input[i - digit_shift];
result[i] = digit_sub2(0, d, borrow, &borrow);
}
} else {
digit_t add_carry = 1;
digit_t input_carry =
input[K - digit_shift - 1] >> (kDigitBits - bits_shift);
for (int i = 0; i < digit_shift; i++) {
digit_t d = input[i + K - digit_shift];
digit_t summand = (d << bits_shift) | input_carry;
result[i] = digit_add2(summand, add_carry, &add_carry);
input_carry = d >> (kDigitBits - bits_shift);
}
{
// result[digit_shift] = (add_carry + iK_part) - i0_part
digit_t d = input[K];
digit_t iK_part = (d << bits_shift) | input_carry;
digit_t iK_carry = d >> (kDigitBits - bits_shift);
digit_t sum = digit_add2(add_carry, iK_part, &add_carry);
// {iK_carry} is less than a full digit, so we can merge {add_carry}
// into it without overflow.
iK_carry += add_carry;
d = input[0];
digit_t i0_part = d << bits_shift;
result[digit_shift] = digit_sub(sum, i0_part, &borrow);
input_carry = d >> (kDigitBits - bits_shift);
if (digit_shift + 1 < K) {
digit_t d = input[1];
digit_t subtrahend = (d << bits_shift) | input_carry;
result[digit_shift + 1] =
digit_sub2(iK_carry, subtrahend, borrow, &borrow);
input_carry = d >> (kDigitBits - bits_shift);
}
}
for (int i = digit_shift + 2; i < K; i++) {
digit_t d = input[i - digit_shift];
digit_t subtrahend = (d << bits_shift) | input_carry;
result[i] = digit_sub2(0, subtrahend, borrow, &borrow);
input_carry = d >> (kDigitBits - bits_shift);
}
}
// The virtual 1 in result[K] should be eliminated by {borrow}. If there
// is no borrow, then the virtual initialization was too much. Subtract
// 2^K + 1.
result[K] = 0;
if (borrow != 1) {
borrow = 1;
for (int i = 0; i < K; i++) {
result[i] = digit_sub(result[i], borrow, &borrow);
if (borrow == 0) break;
}
if (borrow != 0) {
// The result must be 2^K.
for (int i = 0; i < K; i++) result[i] = 0;
result[K] = 1;
}
}
}
// Sets {result} := {input} * 2^{power_of_two} mod 2^{K} + 1.
// This function is highly relevant for overall performance.
void ShiftModFn(digit_t* result, const digit_t* input, int power_of_two, int K,
int zero_above = 0x7FFFFFFF) {
// The modulo-reduction amounts to a subtraction, which we combine
// with the shift as follows:
// input = [ iK ][iK-1] .... .... [ i1 ][ i0 ]
// result = [iX-1] .... [ i0 ] <<<<<<<<<<< shift by {power_of_two}
// - [ iK ] .... [ iX ]
// where "X" is the index "K - digit_shift".
int digit_shift = power_of_two / kDigitBits;
int bits_shift = power_of_two % kDigitBits;
// By an analogous construction to the "digit_shift >= K" case,
// it turns out that:
// x * 2^{2K+m} == x * 2^m mod 2^K + 1.
while (digit_shift >= 2 * K) digit_shift -= 2 * K; // Faster than '%'!
digit_t borrow = 0;
if (digit_shift >= K) {
return ShiftModFn_Large(result, input, digit_shift, bits_shift, K);
}
if (bits_shift == 0) {
int i = 1;
// Regular loop: read input digits unless we know they are zero.
int cap = std::min(K - digit_shift, zero_above);
for (; i < cap; i++) {
result[i + digit_shift] = input[i];
}
cap = std::min(K, zero_above);
for (; i < cap; i++) {
digit_t d = input[i];
result[i - K + digit_shift] = digit_sub2(0, d, borrow, &borrow);
}
// Fallthrough: any remaining work can hard-code the knowledge that
// input[i] == 0.
for (; i < K - digit_shift; i++) {
DCHECK(input[i] == 0); // NOLINT(readability/check)
result[i + digit_shift] = 0;
}
for (; i < K; i++) {
DCHECK(input[i] == 0); // NOLINT(readability/check)
result[i - K + digit_shift] = digit_sub(0, borrow, &borrow);
}
result[digit_shift] = digit_sub2(input[0], input[K], borrow, &borrow);
} else {
digit_t carry = 0;
int i = 0;
// Regular loop: read input digits unless we know they are zero.
int cap = std::min(K - digit_shift, zero_above);
for (; i < cap; i++) {
digit_t d = input[i];
result[i + digit_shift] = (d << bits_shift) | carry;
carry = d >> (kDigitBits - bits_shift);
}
cap = std::min(K, zero_above);
for (; i < cap; i++) {
digit_t d = input[i];
result[i - K + digit_shift] =
digit_sub2(0, (d << bits_shift) | carry, borrow, &borrow);
carry = d >> (kDigitBits - bits_shift);
}
// Fallthrough: any remaining work can hard-code the knowledge that
// input[i] == 0.
for (; i < K - digit_shift; i++) {
DCHECK(input[i] == 0); // NOLINT(readability/check)
result[i + digit_shift] = carry;
carry = 0;
}
if (i < K) {
DCHECK(input[i] == 0); // NOLINT(readability/check)
result[i - K + digit_shift] = digit_sub2(0, carry, borrow, &borrow);
carry = 0;
i++;
}
for (; i < K; i++) {
DCHECK(input[i] == 0); // NOLINT(readability/check)
result[i - K + digit_shift] = digit_sub(0, borrow, &borrow);
}
digit_t d = input[K];
result[digit_shift] = digit_sub2(
result[digit_shift], (d << bits_shift) | carry, borrow, &borrow);
// No carry left.
DCHECK((d >> (kDigitBits - bits_shift)) == 0); // NOLINT(readability/check)
}
result[K] = 0;
for (int i = digit_shift + 1; i <= K && borrow > 0; i++) {
result[i] = digit_sub(result[i], borrow, &borrow);
}
if (borrow > 0) {
// Underflow means we subtracted too much. Add 2^K + 1.
digit_t carry = 1;
for (int i = 0; i <= K; i++) {
result[i] = digit_add2(result[i], carry, &carry);
if (carry == 0) break;
}
result[K] = digit_add2(result[K], 1, &carry);
}
}
////////////////////////////////////////////////////////////////////////////////
// Part 2: FFT-based multiplication is very sensitive to appropriate choice
// of parameters. The following functions choose the parameters that the
// subsequent actual computation will use. This is partially based on formal
// constraints and partially on experimentally-determined heuristics.
struct Parameters {
// We never use the default values, but skipping zero-initialization
// of these fields saddens and confuses MSan.
int m{0};
int K{0};
int n{0};
int s{0};
int r{0};
};
// Computes parameters for the main calculation, given a bit length {N} and
// an {m}. See the paper for details.
void ComputeParameters(int N, int m, Parameters* params) {
N *= kDigitBits;
int n = 1 << m; // 2^m
int nhalf = n >> 1;
int s = (N + n - 1) >> m; // ceil(N/n)
s = RoundUp(s, kDigitBits);
int K = m + 2 * s + 1; // K must be at least this big...
K = RoundUp(K, nhalf); // ...and a multiple of n/2.
int r = K >> (m - 1); // Which multiple?
// We want recursive calls to make progress, so force K to be a multiple
// of 8 if it's above the recursion threshold. Otherwise, K must be a
// multiple of kDigitBits.
const int threshold = (K + 1 >= kFftInnerThreshold * kDigitBits)
? 3 + kLog2DigitBits
: kLog2DigitBits;
int K_tz = CountTrailingZeros(K);
while (K_tz < threshold) {
K += (1 << K_tz);
r = K >> (m - 1);
K_tz = CountTrailingZeros(K);
}
DCHECK(K % kDigitBits == 0); // NOLINT(readability/check)
DCHECK(s % kDigitBits == 0); // NOLINT(readability/check)
params->K = K / kDigitBits;
params->s = s / kDigitBits;
params->n = n;
params->r = r;
}
// Computes parameters for recursive invocations ("inner layer").
void ComputeParameters_Inner(int N, Parameters* params) {
int max_m = CountTrailingZeros(N);
int N_bits = BitLength(N);
int m = N_bits - 4; // Don't let s get too small.
m = std::min(max_m, m);
N *= kDigitBits;
int n = 1 << m; // 2^m
// We can't round up s in the inner layer, because N = n*s is fixed.
int s = N >> m;
DCHECK(N == s * n);
int K = m + 2 * s + 1; // K must be at least this big...
K = RoundUp(K, n); // ...and a multiple of n and kDigitBits.
K = RoundUp(K, kDigitBits);
params->r = K >> m; // Which multiple?
DCHECK(K % kDigitBits == 0); // NOLINT(readability/check)
DCHECK(s % kDigitBits == 0); // NOLINT(readability/check)
params->K = K / kDigitBits;
params->s = s / kDigitBits;
params->n = n;
params->m = m;
}
int PredictInnerK(int N) {
Parameters params;
ComputeParameters_Inner(N, &params);
return params.K;
}
// Applies heuristics to decide whether {m} should be decremented, by looking
// at what would happen to {K} and {s} if {m} was decremented.
bool ShouldDecrementM(const Parameters& current, const Parameters& next,
const Parameters& after_next) {
// K == 64 seems to work particularly well.
if (current.K == 64 && next.K >= 112) return false;
// Small values for s are never efficient.
if (current.s < 6) return true;
// The time is roughly determined by K * n. When we decrement m, then
// n always halves, and K usually gets bigger, by up to 2x.
// For not-quite-so-small s, look at how much bigger K would get: if
// the K increase is small enough, making n smaller is worth it.
// Empirically, it's most meaningful to look at the K *after* next.
// The specific threshold values have been chosen by running many
// benchmarks on inputs of many sizes, and manually selecting thresholds
// that seemed to produce good results.
double factor = static_cast<double>(after_next.K) / current.K;
if ((current.s == 6 && factor < 3.85) || // --
(current.s == 7 && factor < 3.73) || // --
(current.s == 8 && factor < 3.55) || // --
(current.s == 9 && factor < 3.50) || // --
factor < 3.4) {
return true;
}
// If K is just below the recursion threshold, make sure we do recurse,
// unless doing so would be particularly inefficient (large inner_K).
// If K is just above the recursion threshold, doubling it often makes
// the inner call more efficient.
if (current.K >= 160 && current.K < 250 && PredictInnerK(next.K) < 28) {
return true;
}
// If we found no reason to decrement, keep m as large as possible.
return false;
}
// Decides what parameters to use for a given input bit length {N}.
// Returns the chosen m.
int GetParameters(int N, Parameters* params) {
int N_bits = BitLength(N);
int max_m = N_bits - 3; // Larger m make s too small.
max_m = std::max(kLog2DigitBits, max_m); // Smaller m break the logic below.
int m = max_m;
Parameters current;
ComputeParameters(N, m, &current);
Parameters next;
ComputeParameters(N, m - 1, &next);
while (m > 2) {
Parameters after_next;
ComputeParameters(N, m - 2, &after_next);
if (ShouldDecrementM(current, next, after_next)) {
m--;
current = next;
next = after_next;
} else {
break;
}
}
*params = current;
return m;
}
////////////////////////////////////////////////////////////////////////////////
// Part 3: Fast Fourier Transformation.
class FFTContainer {
public:
// {n} is the number of chunks, whose length is {K}+1.
// {K} determines F_n = 2^(K * kDigitBits) + 1.
FFTContainer(int n, int K, ProcessorImpl* processor)
: n_(n), K_(K), length_(K + 1), processor_(processor) {
storage_ = new digit_t[length_ * n_];
part_ = new digit_t*[n_];
digit_t* ptr = storage_;
for (int i = 0; i < n; i++, ptr += length_) {
part_[i] = ptr;
}
temp_ = new digit_t[length_ * 2];
}
FFTContainer() = delete;
FFTContainer(const FFTContainer&) = delete;
FFTContainer& operator=(const FFTContainer&) = delete;
~FFTContainer() {
delete[] storage_;
delete[] part_;
delete[] temp_;
}
void Start_Default(Digits X, int chunk_size, int theta, int omega);
void Start(Digits X, int chunk_size, int theta, int omega);
void NormalizeAndRecombine(int omega, int m, RWDigits Z, int chunk_size);
void CounterWeightAndRecombine(int theta, int m, RWDigits Z, int chunk_size);
void FFT_ReturnShuffledThreadsafe(int start, int len, int omega,
digit_t* temp);
void FFT_Recurse(int start, int half, int omega, digit_t* temp);
void BackwardFFT(int start, int len, int omega);
void BackwardFFT_Threadsafe(int start, int len, int omega, digit_t* temp);
void PointwiseMultiply(const FFTContainer& other);
void DoPointwiseMultiplication(const FFTContainer& other, int start, int end,
digit_t* temp);
int length() const { return length_; }
private:
const int n_; // Number of parts.
const int K_; // Always length_ - 1.
const int length_; // Length of each part, in digits.
ProcessorImpl* processor_;
digit_t* storage_; // Combined storage of all parts.
digit_t** part_; // Pointers to each part.
digit_t* temp_; // Temporary storage with size 2 * length_.
};
inline void CopyAndZeroExtend(digit_t* dst, const digit_t* src,
int digits_to_copy, size_t total_bytes) {
size_t bytes_to_copy = digits_to_copy * sizeof(digit_t);
memcpy(dst, src, bytes_to_copy);
memset(dst + digits_to_copy, 0, total_bytes - bytes_to_copy);
}
// Reads {X} into the FFTContainer's internal storage, dividing it into chunks
// while doing so; then performs the forward FFT.
void FFTContainer::Start_Default(Digits X, int chunk_size, int theta,
int omega) {
int len = X.len();
const digit_t* pointer = X.digits();
const size_t part_length_in_bytes = length_ * sizeof(digit_t);
int current_theta = 0;
int i = 0;
for (; i < n_ && len > 0; i++, current_theta += theta) {
chunk_size = std::min(chunk_size, len);
if (current_theta != 0) {
// Multiply with theta^i, and reduce modulo 2^K + 1.
// We pass theta as a shift amount; it really means 2^theta.
CopyAndZeroExtend(temp_, pointer, chunk_size, part_length_in_bytes);
ShiftModFn(part_[i], temp_, current_theta, K_, chunk_size);
} else {
CopyAndZeroExtend(part_[i], pointer, chunk_size, part_length_in_bytes);
}
pointer += chunk_size;
len -= chunk_size;
}
for (; i < n_; i++) {
memset(part_[i], 0, part_length_in_bytes);
}
FFT_ReturnShuffledThreadsafe(0, n_, omega, temp_);
}
// This version of Start is optimized for the case where ~half of the
// container will be filled with padding zeros.
void FFTContainer::Start(Digits X, int chunk_size, int theta, int omega) {
int len = X.len();
if (len > n_ * chunk_size / 2) {
return Start_Default(X, chunk_size, theta, omega);
}
DCHECK(theta == 0); // NOLINT(readability/check)
const digit_t* pointer = X.digits();
const size_t part_length_in_bytes = length_ * sizeof(digit_t);
int nhalf = n_ / 2;
// Unrolled first iteration.
CopyAndZeroExtend(part_[0], pointer, chunk_size, part_length_in_bytes);
CopyAndZeroExtend(part_[nhalf], pointer, chunk_size, part_length_in_bytes);
pointer += chunk_size;
len -= chunk_size;
int i = 1;
for (; i < nhalf && len > 0; i++) {
chunk_size = std::min(chunk_size, len);
CopyAndZeroExtend(part_[i], pointer, chunk_size, part_length_in_bytes);
int w = omega * i;
ShiftModFn(part_[i + nhalf], part_[i], w, K_, chunk_size);
pointer += chunk_size;
len -= chunk_size;
}
for (; i < nhalf; i++) {
memset(part_[i], 0, part_length_in_bytes);
memset(part_[i + nhalf], 0, part_length_in_bytes);
}
FFT_Recurse(0, nhalf, omega, temp_);
}
// Forward transformation.
// We use the "DIF" aka "decimation in frequency" transform, because it
// leaves the result in "bit reversed" order, which is precisely what we
// need as input for the "DIT" aka "decimation in time" backwards transform.
void FFTContainer::FFT_ReturnShuffledThreadsafe(int start, int len, int omega,
digit_t* temp) {
DCHECK((len & 1) == 0); // {len} must be even. NOLINT(readability/check)
int half = len / 2;
SumDiff(part_[start], part_[start + half], part_[start], part_[start + half],
length_);
for (int k = 1; k < half; k++) {
SumDiff(part_[start + k], temp, part_[start + k], part_[start + half + k],
length_);
int w = omega * k;
ShiftModFn(part_[start + half + k], temp, w, K_);
}
FFT_Recurse(start, half, omega, temp);
}
// Recursive step of the above, factored out for additional callers.
void FFTContainer::FFT_Recurse(int start, int half, int omega, digit_t* temp) {
if (half > 1) {
FFT_ReturnShuffledThreadsafe(start, half, 2 * omega, temp);
FFT_ReturnShuffledThreadsafe(start + half, half, 2 * omega, temp);
}
}
// Backward transformation.
// We use the "DIT" aka "decimation in time" transform here, because it
// turns bit-reversed input into normally sorted output.
void FFTContainer::BackwardFFT(int start, int len, int omega) {
BackwardFFT_Threadsafe(start, len, omega, temp_);
}
void FFTContainer::BackwardFFT_Threadsafe(int start, int len, int omega,
digit_t* temp) {
DCHECK((len & 1) == 0); // {len} must be even. NOLINT(readability/check)
int half = len / 2;
// Don't recurse for half == 2, as PointwiseMultiply already performed
// the first level of the backwards FFT.
if (half > 2) {
BackwardFFT_Threadsafe(start, half, 2 * omega, temp);
BackwardFFT_Threadsafe(start + half, half, 2 * omega, temp);
}
SumDiff(part_[start], part_[start + half], part_[start], part_[start + half],
length_);
for (int k = 1; k < half; k++) {
int w = omega * (len - k);
ShiftModFn(temp, part_[start + half + k], w, K_);
SumDiff(part_[start + k], part_[start + half + k], part_[start + k], temp,
length_);
}
}
// Recombines the result's parts into {Z}, after backwards FFT.
void FFTContainer::NormalizeAndRecombine(int omega, int m, RWDigits Z,
int chunk_size) {
Z.Clear();
int z_index = 0;
const int shift = n_ * omega - m;
for (int i = 0; i < n_; i++, z_index += chunk_size) {
digit_t* part = part_[i];
ShiftModFn(temp_, part, shift, K_);
digit_t carry = 0;
int zi = z_index;
int j = 0;
for (; j < length_ && zi < Z.len(); j++, zi++) {
Z[zi] = digit_add3(Z[zi], temp_[j], carry, &carry);
}
for (; j < length_; j++) {
DCHECK(temp_[j] == 0); // NOLINT(readability/check)
}
if (carry != 0) {
DCHECK(zi < Z.len());
Z[zi] = carry;
}
}
}
// Helper function for {CounterWeightAndRecombine} below.
bool ShouldBeNegative(const digit_t* x, int xlen, digit_t threshold, int s) {
if (x[2 * s] >= threshold) return true;
for (int i = 2 * s + 1; i < xlen; i++) {
if (x[i] > 0) return true;
}
return false;
}
// Same as {NormalizeAndRecombine} above, but for the needs of the recursive
// invocation ("inner layer") of FFT multiplication, where an additional
// counter-weighting step is required.
void FFTContainer::CounterWeightAndRecombine(int theta, int m, RWDigits Z,
int s) {
Z.Clear();
int z_index = 0;
for (int k = 0; k < n_; k++, z_index += s) {
int shift = -theta * k - m;
if (shift < 0) shift += 2 * n_ * theta;
DCHECK(shift >= 0); // NOLINT(readability/check)
digit_t* input = part_[k];
ShiftModFn(temp_, input, shift, K_);
int remaining_z = Z.len() - z_index;
if (ShouldBeNegative(temp_, length_, k + 1, s)) {
// Subtract F_n from input before adding to result. We use the following
// transformation (knowing that X < F_n):
// Z + (X - F_n) == Z - (F_n - X)
digit_t borrow_z = 0;
digit_t borrow_Fn = 0;
{
// i == 0:
digit_t d = digit_sub(1, temp_[0], &borrow_Fn);
Z[z_index] = digit_sub(Z[z_index], d, &borrow_z);
}
int i = 1;
for (; i < K_ && i < remaining_z; i++) {
digit_t d = digit_sub2(0, temp_[i], borrow_Fn, &borrow_Fn);
Z[z_index + i] = digit_sub2(Z[z_index + i], d, borrow_z, &borrow_z);
}
DCHECK(i == K_ && K_ == length_ - 1);
for (; i < length_ && i < remaining_z; i++) {
digit_t d = digit_sub2(1, temp_[i], borrow_Fn, &borrow_Fn);
Z[z_index + i] = digit_sub2(Z[z_index + i], d, borrow_z, &borrow_z);
}
DCHECK(borrow_Fn == 0); // NOLINT(readability/check)
for (; borrow_z > 0 && i < remaining_z; i++) {
Z[z_index + i] = digit_sub(Z[z_index + i], borrow_z, &borrow_z);
}
} else {
digit_t carry = 0;
int i = 0;
for (; i < length_ && i < remaining_z; i++) {
Z[z_index + i] = digit_add3(Z[z_index + i], temp_[i], carry, &carry);
}
for (; i < length_; i++) {
DCHECK(temp_[i] == 0); // NOLINT(readability/check)
}
for (; carry > 0 && i < remaining_z; i++) {
Z[z_index + i] = digit_add2(Z[z_index + i], carry, &carry);
}
// {carry} might be != 0 here if Z was negative before. That's fine.
}
}
}
// Main FFT function for recursive invocations ("inner layer").
void MultiplyFFT_Inner(RWDigits Z, Digits X, Digits Y, const Parameters& params,
ProcessorImpl* processor) {
int omega = 2 * params.r; // really: 2^(2r)
int theta = params.r; // really: 2^r
FFTContainer a(params.n, params.K, processor);
a.Start_Default(X, params.s, theta, omega);
FFTContainer b(params.n, params.K, processor);
b.Start_Default(Y, params.s, theta, omega);
a.PointwiseMultiply(b);
if (processor->should_terminate()) return;
FFTContainer& c = a;
c.BackwardFFT(0, params.n, omega);
c.CounterWeightAndRecombine(theta, params.m, Z, params.s);
}
// Actual implementation of pointwise multiplications.
void FFTContainer::DoPointwiseMultiplication(const FFTContainer& other,
int start, int end,
digit_t* temp) {
// The (K_ & 3) != 0 condition makes sure that the inner FFT gets
// to split the work into at least 4 chunks.
bool use_fft = length_ >= kFftInnerThreshold && (K_ & 3) == 0;
Parameters params;
if (use_fft) ComputeParameters_Inner(K_, &params);
RWDigits result(temp, 2 * length_);
for (int i = start; i < end; i++) {
Digits A(part_[i], length_);
Digits B(other.part_[i], length_);
if (use_fft) {
MultiplyFFT_Inner(result, A, B, params, processor_);
} else {
processor_->Multiply(result, A, B);
}
if (processor_->should_terminate()) return;
ModFnDoubleWidth(part_[i], result.digits(), length_);
// To improve cache friendliness, we perform the first level of the
// backwards FFT here.
if ((i & 1) == 1) {
SumDiff(part_[i - 1], part_[i], part_[i - 1], part_[i], length_);
}
}
}
// Convenient entry point for pointwise multiplications.
void FFTContainer::PointwiseMultiply(const FFTContainer& other) {
DCHECK(n_ == other.n_);
DoPointwiseMultiplication(other, 0, n_, temp_);
}
} // namespace
////////////////////////////////////////////////////////////////////////////////
// Part 4: Tying everything together into a multiplication algorithm.
// TODO(jkummerow): Consider doing a "Mersenne transform" and CRT reconstruction
// of the final result. Might yield a few percent of perf improvement.
// TODO(jkummerow): Consider implementing the "sqrt(2) trick".
// Gaudry/Kruppa/Zimmerman report that it saved them around 10%.
void ProcessorImpl::MultiplyFFT(RWDigits Z, Digits X, Digits Y) {
Parameters params;
int m = GetParameters(X.len() + Y.len(), &params);
int omega = params.r; // really: 2^r
FFTContainer a(params.n, params.K, this);
a.Start(X, params.s, 0, omega);
if (X == Y) {
// Squaring.
a.PointwiseMultiply(a);
} else {
FFTContainer b(params.n, params.K, this);
b.Start(Y, params.s, 0, omega);
a.PointwiseMultiply(b);
}
if (should_terminate()) return;
a.BackwardFFT(0, params.n, omega);
a.NormalizeAndRecombine(omega, m, Z, params.s);
}
} // namespace bigint
} // namespace v8

View File

@ -123,11 +123,8 @@ void ProcessorImpl::Toom3Main(RWDigits Z, Digits X, Digits Y) {
// Phase 3a: Pointwise multiplication, steps 0, 1, m1.
Multiply(r_0, X0, Y0);
if (should_terminate()) return;
Multiply(r_1, p_1, q_1);
if (should_terminate()) return;
Multiply(r_m1, p_m1, q_m1);
if (should_terminate()) return;
bool r_m1_sign = p_m1_sign != q_m1_sign;
// Phase 2b: Evaluation, steps m2 and inf.
@ -152,14 +149,12 @@ void ProcessorImpl::Toom3Main(RWDigits Z, Digits X, Digits Y) {
MARK_INVALID(p_m1);
MARK_INVALID(q_m1);
Multiply(r_m2, p_m2, q_m2);
if (should_terminate()) return;
bool r_m2_sign = p_m2_sign != q_m2_sign;
RWDigits r_inf(t + r_len, r_len);
MARK_INVALID(p_m2);
MARK_INVALID(q_m2);
Multiply(r_inf, X2, Y2);
if (should_terminate()) return;
// Phase 4: Interpolation.
Digits R0 = r_0;
@ -215,7 +210,6 @@ void ProcessorImpl::MultiplyToomCook(RWDigits Z, Digits X, Digits Y) {
if (X.len() > Y.len()) {
ScratchDigits T(2 * k);
for (int i = k; i < X.len(); i += k) {
if (should_terminate()) return;
Digits Xi(X, i, k);
// TODO(jkummerow): would it be a measurable improvement to craft a
// "ToomChunk" method in the style of {KaratsubaChunk}?

View File

@ -27,8 +27,9 @@ int PrintHelp(char** argv) {
return 1;
}
#define TESTS(V) \
V(kKaratsuba, "karatsuba") V(kBurnikel, "burnikel") V(kToom, "toom")
#define TESTS(V) \
V(kKaratsuba, "karatsuba") \
V(kBurnikel, "burnikel") V(kToom, "toom") V(kFFT, "fft")
enum Operation { kNoOp, kList, kTest };
@ -168,6 +169,10 @@ class Runner {
for (int i = 0; i < runs_; i++) {
TestToom(&count);
}
} else if (test_ == kFFT) {
for (int i = 0; i < runs_; i++) {
TestFFT(&count);
}
} else {
DCHECK(false); // Unreachable.
}
@ -225,6 +230,40 @@ class Runner {
#endif // V8_ADVANCED_BIGINT_ALGORITHMS
}
void TestFFT(int* count) {
#if V8_ADVANCED_BIGINT_ALGORITHMS
// Larger multiplications are slower, so to keep individual runs fast,
// we test a few random samples. With build bots running 24/7, we'll
// get decent coverage over time.
uint64_t random_bits = rng_.NextUint64();
int min = kFftThreshold - static_cast<int>(random_bits & 1023);
random_bits >>= 10;
int max = kFftThreshold + static_cast<int>(random_bits & 1023);
random_bits >>= 10;
// If delta is too small, then this run gets too slow. If it happened
// to be zero, we'd even loop forever!
int delta = 10 + (random_bits & 127);
std::cout << "min " << min << " max " << max << " delta " << delta << "\n";
for (int right_size = min; right_size <= max; right_size += delta) {
for (int left_size = right_size; left_size <= max; left_size += delta) {
ScratchDigits A(left_size);
ScratchDigits B(right_size);
int result_len = MultiplyResultLength(A, B);
ScratchDigits result(result_len);
ScratchDigits result_toom(result_len);
GenerateRandom(A);
GenerateRandom(B);
processor()->MultiplyFFT(result, A, B);
// Using Toom-Cook as reference.
processor()->MultiplyToomCook(result_toom, A, B);
AssertEquals(A, B, result_toom, result);
if (error_) return;
(*count)++;
}
}
#endif // V8_ADVANCED_BIGINT_ALGORITHMS
}
void TestBurnikel(int* count) {
// Start small to save test execution time.
constexpr int kMin = kBurnikelThreshold / 2;