Reland "[bigint] FFT-based multiplication"
The Schönhage-Strassen method for *very* large inputs.
This is a reland of 347ba35716
,
with added zero-initialization to pacify MSan (spurious report).
Originally:
> Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/3000742
> Commit-Queue: Jakob Kummerow <jkummerow@chromium.org>
> Reviewed-by: Maya Lekova <mslekova@chromium.org>
> Cr-Commit-Position: refs/heads/master@{#75659}
Bug: v8:11515
Change-Id: Ieac6e174bde6eb09af0a9a9a49969feabca79e81
Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/3018081
Reviewed-by: Maya Lekova <mslekova@chromium.org>
Commit-Queue: Jakob Kummerow <jkummerow@chromium.org>
Cr-Commit-Position: refs/heads/master@{#75663}
This commit is contained in:
parent
2a6b205594
commit
afa6126921
@ -2510,6 +2510,7 @@ filegroup(
|
||||
"src/bigint/div-helpers.cc",
|
||||
"src/bigint/div-helpers.h",
|
||||
"src/bigint/div-schoolbook.cc",
|
||||
"src/bigint/mul-fft.cc",
|
||||
"src/bigint/mul-karatsuba.cc",
|
||||
"src/bigint/mul-schoolbook.cc",
|
||||
"src/bigint/mul-toom.cc",
|
||||
|
5
BUILD.gn
5
BUILD.gn
@ -4964,7 +4964,10 @@ v8_source_set("v8_bigint") {
|
||||
]
|
||||
|
||||
if (v8_advanced_bigint_algorithms) {
|
||||
sources += [ "src/bigint/mul-toom.cc" ]
|
||||
sources += [
|
||||
"src/bigint/mul-fft.cc",
|
||||
"src/bigint/mul-toom.cc",
|
||||
]
|
||||
|
||||
defines = [ "V8_ADVANCED_BIGINT_ALGORITHMS" ]
|
||||
}
|
||||
|
@ -35,7 +35,8 @@ void ProcessorImpl::Multiply(RWDigits Z, Digits X, Digits Y) {
|
||||
return MultiplyKaratsuba(Z, X, Y);
|
||||
#else
|
||||
if (Y.len() < kToomThreshold) return MultiplyKaratsuba(Z, X, Y);
|
||||
return MultiplyToomCook(Z, X, Y);
|
||||
if (Y.len() < kFftThreshold) return MultiplyToomCook(Z, X, Y);
|
||||
return MultiplyFFT(Z, X, Y);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -14,6 +14,9 @@ namespace bigint {
|
||||
|
||||
constexpr int kKaratsubaThreshold = 34;
|
||||
constexpr int kToomThreshold = 193;
|
||||
constexpr int kFftThreshold = 1500;
|
||||
constexpr int kFftInnerThreshold = 200;
|
||||
|
||||
constexpr int kBurnikelThreshold = 57;
|
||||
|
||||
class ProcessorImpl : public Processor {
|
||||
@ -42,6 +45,8 @@ class ProcessorImpl : public Processor {
|
||||
#if V8_ADVANCED_BIGINT_ALGORITHMS
|
||||
void MultiplyToomCook(RWDigits Z, Digits X, Digits Y);
|
||||
void Toom3Main(RWDigits Z, Digits X, Digits Y);
|
||||
|
||||
void MultiplyFFT(RWDigits Z, Digits X, Digits Y);
|
||||
#endif // V8_ADVANCED_BIGINT_ALGORITHMS
|
||||
|
||||
// {out_length} initially contains the allocated capacity of {out}, and
|
||||
|
769
src/bigint/mul-fft.cc
Normal file
769
src/bigint/mul-fft.cc
Normal file
@ -0,0 +1,769 @@
|
||||
// Copyright 2021 the V8 project authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
// FFT-based multiplication, due to Schönhage and Strassen.
|
||||
// This implementation mostly follows the description given in:
|
||||
// Christoph Lüders: Fast Multiplication of Large Integers,
|
||||
// http://arxiv.org/abs/1503.04955
|
||||
|
||||
#include "src/bigint/bigint-internal.h"
|
||||
#include "src/bigint/digit-arithmetic.h"
|
||||
#include "src/bigint/util.h"
|
||||
#include "src/bigint/vector-arithmetic.h"
|
||||
|
||||
namespace v8 {
|
||||
namespace bigint {
|
||||
|
||||
namespace {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Part 1: Functions for "mod F_n" arithmetic.
|
||||
// F_n is of the shape 2^K + 1, and for convenience we use K to count the
|
||||
// number of digits rather than the number of bits, so F_n (or K) are implicit
|
||||
// and deduced from the length {len} of the digits array.
|
||||
|
||||
// Helper function for {ModFn} below.
|
||||
void ModFn_Helper(digit_t* x, int len, signed_digit_t high) {
|
||||
if (high > 0) {
|
||||
digit_t borrow = high;
|
||||
x[len - 1] = 0;
|
||||
for (int i = 0; i < len; i++) {
|
||||
x[i] = digit_sub(x[i], borrow, &borrow);
|
||||
if (borrow == 0) break;
|
||||
}
|
||||
} else {
|
||||
digit_t carry = -high;
|
||||
x[len - 1] = 0;
|
||||
for (int i = 0; i < len; i++) {
|
||||
x[i] = digit_add2(x[i], carry, &carry);
|
||||
if (carry == 0) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// {x} := {x} mod F_n, assuming that {x} is "slightly" larger than F_n (e.g.
|
||||
// after addition of two numbers that were mod-F_n-normalized before).
|
||||
void ModFn(digit_t* x, int len) {
|
||||
int K = len - 1;
|
||||
signed_digit_t high = x[K];
|
||||
if (high == 0) return;
|
||||
ModFn_Helper(x, len, high);
|
||||
high = x[K];
|
||||
if (high == 0) return;
|
||||
DCHECK(high == 1 || high == -1);
|
||||
ModFn_Helper(x, len, high);
|
||||
high = x[K];
|
||||
if (high == -1) ModFn_Helper(x, len, high);
|
||||
}
|
||||
|
||||
// {dest} := {src} mod F_n, assuming that {src} is about twice as long as F_n
|
||||
// (e.g. after multiplication of two numbers that were mod-F_n-normalized
|
||||
// before).
|
||||
// {len} is length of {dest}; {src} is twice as long.
|
||||
void ModFnDoubleWidth(digit_t* dest, const digit_t* src, int len) {
|
||||
int K = len - 1;
|
||||
digit_t borrow = 0;
|
||||
for (int i = 0; i < K; i++) {
|
||||
dest[i] = digit_sub2(src[i], src[i + K], borrow, &borrow);
|
||||
}
|
||||
dest[K] = digit_sub(0, borrow, &borrow);
|
||||
ModFn(dest, len);
|
||||
}
|
||||
|
||||
// Sets {sum} := {a} + {b} and {diff} := {a} - {b}, which is more efficient
|
||||
// than computing sum and difference separately. Applies "mod F_n" normalization
|
||||
// to both results.
|
||||
void SumDiff(digit_t* sum, digit_t* diff, const digit_t* a, const digit_t* b,
|
||||
int len) {
|
||||
digit_t carry = 0;
|
||||
digit_t borrow = 0;
|
||||
for (int i = 0; i < len; i++) {
|
||||
// Read both values first, because inputs and outputs can overlap.
|
||||
digit_t ai = a[i];
|
||||
digit_t bi = b[i];
|
||||
sum[i] = digit_add3(ai, bi, carry, &carry);
|
||||
diff[i] = digit_sub2(ai, bi, borrow, &borrow);
|
||||
}
|
||||
ModFn(sum, len);
|
||||
ModFn(diff, len);
|
||||
}
|
||||
|
||||
// {result} := ({input} << shift) mod F_n, where shift >= K.
|
||||
void ShiftModFn_Large(digit_t* result, const digit_t* input, int digit_shift,
|
||||
int bits_shift, int K) {
|
||||
// If {digit_shift} is greater than K, we use the following transformation
|
||||
// (where, since everything is mod 2^K + 1, we are allowed to add or
|
||||
// subtract any multiple of 2^K + 1 at any time):
|
||||
// x * 2^{K+m} mod 2^K + 1
|
||||
// == x * 2^K * 2^m - (2^K + 1)*(x * 2^m) mod 2^K + 1
|
||||
// == x * 2^K * 2^m - x * 2^K * 2^m - x * 2^m mod 2^K + 1
|
||||
// == -x * 2^m mod 2^K + 1
|
||||
// So the flow is the same as for m < K, but we invert the subtraction's
|
||||
// operands. In order to avoid underflow, we virtually initialize the
|
||||
// result to 2^K + 1:
|
||||
// input = [ iK ][iK-1] .... .... [ i1 ][ i0 ]
|
||||
// result = [ 1][0000] .... .... [0000][0001]
|
||||
// + [ iK ] .... [ iX ]
|
||||
// - [iX-1] .... [ i0 ]
|
||||
DCHECK(digit_shift >= K);
|
||||
digit_shift -= K;
|
||||
digit_t borrow = 0;
|
||||
if (bits_shift == 0) {
|
||||
digit_t carry = 1;
|
||||
for (int i = 0; i < digit_shift; i++) {
|
||||
result[i] = digit_add2(input[i + K - digit_shift], carry, &carry);
|
||||
}
|
||||
result[digit_shift] = digit_sub(input[K] + carry, input[0], &borrow);
|
||||
for (int i = digit_shift + 1; i < K; i++) {
|
||||
digit_t d = input[i - digit_shift];
|
||||
result[i] = digit_sub2(0, d, borrow, &borrow);
|
||||
}
|
||||
} else {
|
||||
digit_t add_carry = 1;
|
||||
digit_t input_carry =
|
||||
input[K - digit_shift - 1] >> (kDigitBits - bits_shift);
|
||||
for (int i = 0; i < digit_shift; i++) {
|
||||
digit_t d = input[i + K - digit_shift];
|
||||
digit_t summand = (d << bits_shift) | input_carry;
|
||||
result[i] = digit_add2(summand, add_carry, &add_carry);
|
||||
input_carry = d >> (kDigitBits - bits_shift);
|
||||
}
|
||||
{
|
||||
// result[digit_shift] = (add_carry + iK_part) - i0_part
|
||||
digit_t d = input[K];
|
||||
digit_t iK_part = (d << bits_shift) | input_carry;
|
||||
digit_t iK_carry = d >> (kDigitBits - bits_shift);
|
||||
digit_t sum = digit_add2(add_carry, iK_part, &add_carry);
|
||||
// {iK_carry} is less than a full digit, so we can merge {add_carry}
|
||||
// into it without overflow.
|
||||
iK_carry += add_carry;
|
||||
d = input[0];
|
||||
digit_t i0_part = d << bits_shift;
|
||||
result[digit_shift] = digit_sub(sum, i0_part, &borrow);
|
||||
input_carry = d >> (kDigitBits - bits_shift);
|
||||
if (digit_shift + 1 < K) {
|
||||
digit_t d = input[1];
|
||||
digit_t subtrahend = (d << bits_shift) | input_carry;
|
||||
result[digit_shift + 1] =
|
||||
digit_sub2(iK_carry, subtrahend, borrow, &borrow);
|
||||
input_carry = d >> (kDigitBits - bits_shift);
|
||||
}
|
||||
}
|
||||
for (int i = digit_shift + 2; i < K; i++) {
|
||||
digit_t d = input[i - digit_shift];
|
||||
digit_t subtrahend = (d << bits_shift) | input_carry;
|
||||
result[i] = digit_sub2(0, subtrahend, borrow, &borrow);
|
||||
input_carry = d >> (kDigitBits - bits_shift);
|
||||
}
|
||||
}
|
||||
// The virtual 1 in result[K] should be eliminated by {borrow}. If there
|
||||
// is no borrow, then the virtual initialization was too much. Subtract
|
||||
// 2^K + 1.
|
||||
result[K] = 0;
|
||||
if (borrow != 1) {
|
||||
borrow = 1;
|
||||
for (int i = 0; i < K; i++) {
|
||||
result[i] = digit_sub(result[i], borrow, &borrow);
|
||||
if (borrow == 0) break;
|
||||
}
|
||||
if (borrow != 0) {
|
||||
// The result must be 2^K.
|
||||
for (int i = 0; i < K; i++) result[i] = 0;
|
||||
result[K] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sets {result} := {input} * 2^{power_of_two} mod 2^{K} + 1.
|
||||
// This function is highly relevant for overall performance.
|
||||
void ShiftModFn(digit_t* result, const digit_t* input, int power_of_two, int K,
|
||||
int zero_above = 0x7FFFFFFF) {
|
||||
// The modulo-reduction amounts to a subtraction, which we combine
|
||||
// with the shift as follows:
|
||||
// input = [ iK ][iK-1] .... .... [ i1 ][ i0 ]
|
||||
// result = [iX-1] .... [ i0 ] <<<<<<<<<<< shift by {power_of_two}
|
||||
// - [ iK ] .... [ iX ]
|
||||
// where "X" is the index "K - digit_shift".
|
||||
int digit_shift = power_of_two / kDigitBits;
|
||||
int bits_shift = power_of_two % kDigitBits;
|
||||
// By an analogous construction to the "digit_shift >= K" case,
|
||||
// it turns out that:
|
||||
// x * 2^{2K+m} == x * 2^m mod 2^K + 1.
|
||||
while (digit_shift >= 2 * K) digit_shift -= 2 * K; // Faster than '%'!
|
||||
digit_t borrow = 0;
|
||||
if (digit_shift >= K) {
|
||||
return ShiftModFn_Large(result, input, digit_shift, bits_shift, K);
|
||||
}
|
||||
if (bits_shift == 0) {
|
||||
int i = 1;
|
||||
// Regular loop: read input digits unless we know they are zero.
|
||||
int cap = std::min(K - digit_shift, zero_above);
|
||||
for (; i < cap; i++) {
|
||||
result[i + digit_shift] = input[i];
|
||||
}
|
||||
cap = std::min(K, zero_above);
|
||||
for (; i < cap; i++) {
|
||||
digit_t d = input[i];
|
||||
result[i - K + digit_shift] = digit_sub2(0, d, borrow, &borrow);
|
||||
}
|
||||
// Fallthrough: any remaining work can hard-code the knowledge that
|
||||
// input[i] == 0.
|
||||
for (; i < K - digit_shift; i++) {
|
||||
DCHECK(input[i] == 0); // NOLINT(readability/check)
|
||||
result[i + digit_shift] = 0;
|
||||
}
|
||||
for (; i < K; i++) {
|
||||
DCHECK(input[i] == 0); // NOLINT(readability/check)
|
||||
result[i - K + digit_shift] = digit_sub(0, borrow, &borrow);
|
||||
}
|
||||
result[digit_shift] = digit_sub2(input[0], input[K], borrow, &borrow);
|
||||
} else {
|
||||
digit_t carry = 0;
|
||||
int i = 0;
|
||||
// Regular loop: read input digits unless we know they are zero.
|
||||
int cap = std::min(K - digit_shift, zero_above);
|
||||
for (; i < cap; i++) {
|
||||
digit_t d = input[i];
|
||||
result[i + digit_shift] = (d << bits_shift) | carry;
|
||||
carry = d >> (kDigitBits - bits_shift);
|
||||
}
|
||||
cap = std::min(K, zero_above);
|
||||
for (; i < cap; i++) {
|
||||
digit_t d = input[i];
|
||||
result[i - K + digit_shift] =
|
||||
digit_sub2(0, (d << bits_shift) | carry, borrow, &borrow);
|
||||
carry = d >> (kDigitBits - bits_shift);
|
||||
}
|
||||
// Fallthrough: any remaining work can hard-code the knowledge that
|
||||
// input[i] == 0.
|
||||
for (; i < K - digit_shift; i++) {
|
||||
DCHECK(input[i] == 0); // NOLINT(readability/check)
|
||||
result[i + digit_shift] = carry;
|
||||
carry = 0;
|
||||
}
|
||||
if (i < K) {
|
||||
DCHECK(input[i] == 0); // NOLINT(readability/check)
|
||||
result[i - K + digit_shift] = digit_sub2(0, carry, borrow, &borrow);
|
||||
carry = 0;
|
||||
i++;
|
||||
}
|
||||
for (; i < K; i++) {
|
||||
DCHECK(input[i] == 0); // NOLINT(readability/check)
|
||||
result[i - K + digit_shift] = digit_sub(0, borrow, &borrow);
|
||||
}
|
||||
digit_t d = input[K];
|
||||
result[digit_shift] = digit_sub2(
|
||||
result[digit_shift], (d << bits_shift) | carry, borrow, &borrow);
|
||||
// No carry left.
|
||||
DCHECK((d >> (kDigitBits - bits_shift)) == 0); // NOLINT(readability/check)
|
||||
}
|
||||
result[K] = 0;
|
||||
for (int i = digit_shift + 1; i <= K && borrow > 0; i++) {
|
||||
result[i] = digit_sub(result[i], borrow, &borrow);
|
||||
}
|
||||
if (borrow > 0) {
|
||||
// Underflow means we subtracted too much. Add 2^K + 1.
|
||||
digit_t carry = 1;
|
||||
for (int i = 0; i <= K; i++) {
|
||||
result[i] = digit_add2(result[i], carry, &carry);
|
||||
if (carry == 0) break;
|
||||
}
|
||||
result[K] = digit_add2(result[K], 1, &carry);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Part 2: FFT-based multiplication is very sensitive to appropriate choice
|
||||
// of parameters. The following functions choose the parameters that the
|
||||
// subsequent actual computation will use. This is partially based on formal
|
||||
// constraints and partially on experimentally-determined heuristics.
|
||||
|
||||
struct Parameters {
|
||||
// We never use the default values, but skipping zero-initialization
|
||||
// of these fields saddens and confuses MSan.
|
||||
int m{0};
|
||||
int K{0};
|
||||
int n{0};
|
||||
int s{0};
|
||||
int r{0};
|
||||
};
|
||||
|
||||
// Computes parameters for the main calculation, given a bit length {N} and
|
||||
// an {m}. See the paper for details.
|
||||
void ComputeParameters(int N, int m, Parameters* params) {
|
||||
N *= kDigitBits;
|
||||
int n = 1 << m; // 2^m
|
||||
int nhalf = n >> 1;
|
||||
int s = (N + n - 1) >> m; // ceil(N/n)
|
||||
s = RoundUp(s, kDigitBits);
|
||||
int K = m + 2 * s + 1; // K must be at least this big...
|
||||
K = RoundUp(K, nhalf); // ...and a multiple of n/2.
|
||||
int r = K >> (m - 1); // Which multiple?
|
||||
|
||||
// We want recursive calls to make progress, so force K to be a multiple
|
||||
// of 8 if it's above the recursion threshold. Otherwise, K must be a
|
||||
// multiple of kDigitBits.
|
||||
const int threshold = (K + 1 >= kFftInnerThreshold * kDigitBits)
|
||||
? 3 + kLog2DigitBits
|
||||
: kLog2DigitBits;
|
||||
int K_tz = CountTrailingZeros(K);
|
||||
while (K_tz < threshold) {
|
||||
K += (1 << K_tz);
|
||||
r = K >> (m - 1);
|
||||
K_tz = CountTrailingZeros(K);
|
||||
}
|
||||
|
||||
DCHECK(K % kDigitBits == 0); // NOLINT(readability/check)
|
||||
DCHECK(s % kDigitBits == 0); // NOLINT(readability/check)
|
||||
params->K = K / kDigitBits;
|
||||
params->s = s / kDigitBits;
|
||||
params->n = n;
|
||||
params->r = r;
|
||||
}
|
||||
|
||||
// Computes parameters for recursive invocations ("inner layer").
|
||||
void ComputeParameters_Inner(int N, Parameters* params) {
|
||||
int max_m = CountTrailingZeros(N);
|
||||
int N_bits = BitLength(N);
|
||||
int m = N_bits - 4; // Don't let s get too small.
|
||||
m = std::min(max_m, m);
|
||||
N *= kDigitBits;
|
||||
int n = 1 << m; // 2^m
|
||||
// We can't round up s in the inner layer, because N = n*s is fixed.
|
||||
int s = N >> m;
|
||||
DCHECK(N == s * n);
|
||||
int K = m + 2 * s + 1; // K must be at least this big...
|
||||
K = RoundUp(K, n); // ...and a multiple of n and kDigitBits.
|
||||
K = RoundUp(K, kDigitBits);
|
||||
params->r = K >> m; // Which multiple?
|
||||
DCHECK(K % kDigitBits == 0); // NOLINT(readability/check)
|
||||
DCHECK(s % kDigitBits == 0); // NOLINT(readability/check)
|
||||
params->K = K / kDigitBits;
|
||||
params->s = s / kDigitBits;
|
||||
params->n = n;
|
||||
params->m = m;
|
||||
}
|
||||
|
||||
int PredictInnerK(int N) {
|
||||
Parameters params;
|
||||
ComputeParameters_Inner(N, ¶ms);
|
||||
return params.K;
|
||||
}
|
||||
|
||||
// Applies heuristics to decide whether {m} should be decremented, by looking
|
||||
// at what would happen to {K} and {s} if {m} was decremented.
|
||||
bool ShouldDecrementM(const Parameters& current, const Parameters& next,
|
||||
const Parameters& after_next) {
|
||||
// K == 64 seems to work particularly well.
|
||||
if (current.K == 64 && next.K >= 112) return false;
|
||||
// Small values for s are never efficient.
|
||||
if (current.s < 6) return true;
|
||||
// The time is roughly determined by K * n. When we decrement m, then
|
||||
// n always halves, and K usually gets bigger, by up to 2x.
|
||||
// For not-quite-so-small s, look at how much bigger K would get: if
|
||||
// the K increase is small enough, making n smaller is worth it.
|
||||
// Empirically, it's most meaningful to look at the K *after* next.
|
||||
// The specific threshold values have been chosen by running many
|
||||
// benchmarks on inputs of many sizes, and manually selecting thresholds
|
||||
// that seemed to produce good results.
|
||||
double factor = static_cast<double>(after_next.K) / current.K;
|
||||
if ((current.s == 6 && factor < 3.85) || // --
|
||||
(current.s == 7 && factor < 3.73) || // --
|
||||
(current.s == 8 && factor < 3.55) || // --
|
||||
(current.s == 9 && factor < 3.50) || // --
|
||||
factor < 3.4) {
|
||||
return true;
|
||||
}
|
||||
// If K is just below the recursion threshold, make sure we do recurse,
|
||||
// unless doing so would be particularly inefficient (large inner_K).
|
||||
// If K is just above the recursion threshold, doubling it often makes
|
||||
// the inner call more efficient.
|
||||
if (current.K >= 160 && current.K < 250 && PredictInnerK(next.K) < 28) {
|
||||
return true;
|
||||
}
|
||||
// If we found no reason to decrement, keep m as large as possible.
|
||||
return false;
|
||||
}
|
||||
|
||||
// Decides what parameters to use for a given input bit length {N}.
|
||||
// Returns the chosen m.
|
||||
int GetParameters(int N, Parameters* params) {
|
||||
int N_bits = BitLength(N);
|
||||
int max_m = N_bits - 3; // Larger m make s too small.
|
||||
max_m = std::max(kLog2DigitBits, max_m); // Smaller m break the logic below.
|
||||
int m = max_m;
|
||||
Parameters current;
|
||||
ComputeParameters(N, m, ¤t);
|
||||
Parameters next;
|
||||
ComputeParameters(N, m - 1, &next);
|
||||
while (m > 2) {
|
||||
Parameters after_next;
|
||||
ComputeParameters(N, m - 2, &after_next);
|
||||
if (ShouldDecrementM(current, next, after_next)) {
|
||||
m--;
|
||||
current = next;
|
||||
next = after_next;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
*params = current;
|
||||
return m;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Part 3: Fast Fourier Transformation.
|
||||
|
||||
class FFTContainer {
|
||||
public:
|
||||
// {n} is the number of chunks, whose length is {K}+1.
|
||||
// {K} determines F_n = 2^(K * kDigitBits) + 1.
|
||||
FFTContainer(int n, int K, ProcessorImpl* processor)
|
||||
: n_(n), K_(K), length_(K + 1), processor_(processor) {
|
||||
storage_ = new digit_t[length_ * n_];
|
||||
part_ = new digit_t*[n_];
|
||||
digit_t* ptr = storage_;
|
||||
for (int i = 0; i < n; i++, ptr += length_) {
|
||||
part_[i] = ptr;
|
||||
}
|
||||
temp_ = new digit_t[length_ * 2];
|
||||
}
|
||||
FFTContainer() = delete;
|
||||
FFTContainer(const FFTContainer&) = delete;
|
||||
FFTContainer& operator=(const FFTContainer&) = delete;
|
||||
|
||||
~FFTContainer() {
|
||||
delete[] storage_;
|
||||
delete[] part_;
|
||||
delete[] temp_;
|
||||
}
|
||||
|
||||
void Start_Default(Digits X, int chunk_size, int theta, int omega);
|
||||
void Start(Digits X, int chunk_size, int theta, int omega);
|
||||
|
||||
void NormalizeAndRecombine(int omega, int m, RWDigits Z, int chunk_size);
|
||||
void CounterWeightAndRecombine(int theta, int m, RWDigits Z, int chunk_size);
|
||||
|
||||
void FFT_ReturnShuffledThreadsafe(int start, int len, int omega,
|
||||
digit_t* temp);
|
||||
void FFT_Recurse(int start, int half, int omega, digit_t* temp);
|
||||
|
||||
void BackwardFFT(int start, int len, int omega);
|
||||
void BackwardFFT_Threadsafe(int start, int len, int omega, digit_t* temp);
|
||||
|
||||
void PointwiseMultiply(const FFTContainer& other);
|
||||
void DoPointwiseMultiplication(const FFTContainer& other, int start, int end,
|
||||
digit_t* temp);
|
||||
|
||||
int length() const { return length_; }
|
||||
|
||||
private:
|
||||
const int n_; // Number of parts.
|
||||
const int K_; // Always length_ - 1.
|
||||
const int length_; // Length of each part, in digits.
|
||||
ProcessorImpl* processor_;
|
||||
digit_t* storage_; // Combined storage of all parts.
|
||||
digit_t** part_; // Pointers to each part.
|
||||
digit_t* temp_; // Temporary storage with size 2 * length_.
|
||||
};
|
||||
|
||||
inline void CopyAndZeroExtend(digit_t* dst, const digit_t* src,
|
||||
int digits_to_copy, size_t total_bytes) {
|
||||
size_t bytes_to_copy = digits_to_copy * sizeof(digit_t);
|
||||
memcpy(dst, src, bytes_to_copy);
|
||||
memset(dst + digits_to_copy, 0, total_bytes - bytes_to_copy);
|
||||
}
|
||||
|
||||
// Reads {X} into the FFTContainer's internal storage, dividing it into chunks
|
||||
// while doing so; then performs the forward FFT.
|
||||
void FFTContainer::Start_Default(Digits X, int chunk_size, int theta,
|
||||
int omega) {
|
||||
int len = X.len();
|
||||
const digit_t* pointer = X.digits();
|
||||
const size_t part_length_in_bytes = length_ * sizeof(digit_t);
|
||||
int current_theta = 0;
|
||||
int i = 0;
|
||||
for (; i < n_ && len > 0; i++, current_theta += theta) {
|
||||
chunk_size = std::min(chunk_size, len);
|
||||
if (current_theta != 0) {
|
||||
// Multiply with theta^i, and reduce modulo 2^K + 1.
|
||||
// We pass theta as a shift amount; it really means 2^theta.
|
||||
CopyAndZeroExtend(temp_, pointer, chunk_size, part_length_in_bytes);
|
||||
ShiftModFn(part_[i], temp_, current_theta, K_, chunk_size);
|
||||
} else {
|
||||
CopyAndZeroExtend(part_[i], pointer, chunk_size, part_length_in_bytes);
|
||||
}
|
||||
pointer += chunk_size;
|
||||
len -= chunk_size;
|
||||
}
|
||||
for (; i < n_; i++) {
|
||||
memset(part_[i], 0, part_length_in_bytes);
|
||||
}
|
||||
FFT_ReturnShuffledThreadsafe(0, n_, omega, temp_);
|
||||
}
|
||||
|
||||
// This version of Start is optimized for the case where ~half of the
|
||||
// container will be filled with padding zeros.
|
||||
void FFTContainer::Start(Digits X, int chunk_size, int theta, int omega) {
|
||||
int len = X.len();
|
||||
if (len > n_ * chunk_size / 2) {
|
||||
return Start_Default(X, chunk_size, theta, omega);
|
||||
}
|
||||
DCHECK(theta == 0); // NOLINT(readability/check)
|
||||
const digit_t* pointer = X.digits();
|
||||
const size_t part_length_in_bytes = length_ * sizeof(digit_t);
|
||||
int nhalf = n_ / 2;
|
||||
// Unrolled first iteration.
|
||||
CopyAndZeroExtend(part_[0], pointer, chunk_size, part_length_in_bytes);
|
||||
CopyAndZeroExtend(part_[nhalf], pointer, chunk_size, part_length_in_bytes);
|
||||
pointer += chunk_size;
|
||||
len -= chunk_size;
|
||||
int i = 1;
|
||||
for (; i < nhalf && len > 0; i++) {
|
||||
chunk_size = std::min(chunk_size, len);
|
||||
CopyAndZeroExtend(part_[i], pointer, chunk_size, part_length_in_bytes);
|
||||
int w = omega * i;
|
||||
ShiftModFn(part_[i + nhalf], part_[i], w, K_, chunk_size);
|
||||
pointer += chunk_size;
|
||||
len -= chunk_size;
|
||||
}
|
||||
for (; i < nhalf; i++) {
|
||||
memset(part_[i], 0, part_length_in_bytes);
|
||||
memset(part_[i + nhalf], 0, part_length_in_bytes);
|
||||
}
|
||||
FFT_Recurse(0, nhalf, omega, temp_);
|
||||
}
|
||||
|
||||
// Forward transformation.
|
||||
// We use the "DIF" aka "decimation in frequency" transform, because it
|
||||
// leaves the result in "bit reversed" order, which is precisely what we
|
||||
// need as input for the "DIT" aka "decimation in time" backwards transform.
|
||||
void FFTContainer::FFT_ReturnShuffledThreadsafe(int start, int len, int omega,
|
||||
digit_t* temp) {
|
||||
DCHECK((len & 1) == 0); // {len} must be even. NOLINT(readability/check)
|
||||
int half = len / 2;
|
||||
SumDiff(part_[start], part_[start + half], part_[start], part_[start + half],
|
||||
length_);
|
||||
for (int k = 1; k < half; k++) {
|
||||
SumDiff(part_[start + k], temp, part_[start + k], part_[start + half + k],
|
||||
length_);
|
||||
int w = omega * k;
|
||||
ShiftModFn(part_[start + half + k], temp, w, K_);
|
||||
}
|
||||
FFT_Recurse(start, half, omega, temp);
|
||||
}
|
||||
|
||||
// Recursive step of the above, factored out for additional callers.
|
||||
void FFTContainer::FFT_Recurse(int start, int half, int omega, digit_t* temp) {
|
||||
if (half > 1) {
|
||||
FFT_ReturnShuffledThreadsafe(start, half, 2 * omega, temp);
|
||||
FFT_ReturnShuffledThreadsafe(start + half, half, 2 * omega, temp);
|
||||
}
|
||||
}
|
||||
|
||||
// Backward transformation.
|
||||
// We use the "DIT" aka "decimation in time" transform here, because it
|
||||
// turns bit-reversed input into normally sorted output.
|
||||
void FFTContainer::BackwardFFT(int start, int len, int omega) {
|
||||
BackwardFFT_Threadsafe(start, len, omega, temp_);
|
||||
}
|
||||
|
||||
void FFTContainer::BackwardFFT_Threadsafe(int start, int len, int omega,
|
||||
digit_t* temp) {
|
||||
DCHECK((len & 1) == 0); // {len} must be even. NOLINT(readability/check)
|
||||
int half = len / 2;
|
||||
// Don't recurse for half == 2, as PointwiseMultiply already performed
|
||||
// the first level of the backwards FFT.
|
||||
if (half > 2) {
|
||||
BackwardFFT_Threadsafe(start, half, 2 * omega, temp);
|
||||
BackwardFFT_Threadsafe(start + half, half, 2 * omega, temp);
|
||||
}
|
||||
SumDiff(part_[start], part_[start + half], part_[start], part_[start + half],
|
||||
length_);
|
||||
for (int k = 1; k < half; k++) {
|
||||
int w = omega * (len - k);
|
||||
ShiftModFn(temp, part_[start + half + k], w, K_);
|
||||
SumDiff(part_[start + k], part_[start + half + k], part_[start + k], temp,
|
||||
length_);
|
||||
}
|
||||
}
|
||||
|
||||
// Recombines the result's parts into {Z}, after backwards FFT.
|
||||
void FFTContainer::NormalizeAndRecombine(int omega, int m, RWDigits Z,
|
||||
int chunk_size) {
|
||||
Z.Clear();
|
||||
int z_index = 0;
|
||||
const int shift = n_ * omega - m;
|
||||
for (int i = 0; i < n_; i++, z_index += chunk_size) {
|
||||
digit_t* part = part_[i];
|
||||
ShiftModFn(temp_, part, shift, K_);
|
||||
digit_t carry = 0;
|
||||
int zi = z_index;
|
||||
int j = 0;
|
||||
for (; j < length_ && zi < Z.len(); j++, zi++) {
|
||||
Z[zi] = digit_add3(Z[zi], temp_[j], carry, &carry);
|
||||
}
|
||||
for (; j < length_; j++) {
|
||||
DCHECK(temp_[j] == 0); // NOLINT(readability/check)
|
||||
}
|
||||
if (carry != 0) {
|
||||
DCHECK(zi < Z.len());
|
||||
Z[zi] = carry;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function for {CounterWeightAndRecombine} below.
|
||||
bool ShouldBeNegative(const digit_t* x, int xlen, digit_t threshold, int s) {
|
||||
if (x[2 * s] >= threshold) return true;
|
||||
for (int i = 2 * s + 1; i < xlen; i++) {
|
||||
if (x[i] > 0) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Same as {NormalizeAndRecombine} above, but for the needs of the recursive
|
||||
// invocation ("inner layer") of FFT multiplication, where an additional
|
||||
// counter-weighting step is required.
|
||||
void FFTContainer::CounterWeightAndRecombine(int theta, int m, RWDigits Z,
|
||||
int s) {
|
||||
Z.Clear();
|
||||
int z_index = 0;
|
||||
for (int k = 0; k < n_; k++, z_index += s) {
|
||||
int shift = -theta * k - m;
|
||||
if (shift < 0) shift += 2 * n_ * theta;
|
||||
DCHECK(shift >= 0); // NOLINT(readability/check)
|
||||
digit_t* input = part_[k];
|
||||
ShiftModFn(temp_, input, shift, K_);
|
||||
int remaining_z = Z.len() - z_index;
|
||||
if (ShouldBeNegative(temp_, length_, k + 1, s)) {
|
||||
// Subtract F_n from input before adding to result. We use the following
|
||||
// transformation (knowing that X < F_n):
|
||||
// Z + (X - F_n) == Z - (F_n - X)
|
||||
digit_t borrow_z = 0;
|
||||
digit_t borrow_Fn = 0;
|
||||
{
|
||||
// i == 0:
|
||||
digit_t d = digit_sub(1, temp_[0], &borrow_Fn);
|
||||
Z[z_index] = digit_sub(Z[z_index], d, &borrow_z);
|
||||
}
|
||||
int i = 1;
|
||||
for (; i < K_ && i < remaining_z; i++) {
|
||||
digit_t d = digit_sub2(0, temp_[i], borrow_Fn, &borrow_Fn);
|
||||
Z[z_index + i] = digit_sub2(Z[z_index + i], d, borrow_z, &borrow_z);
|
||||
}
|
||||
DCHECK(i == K_ && K_ == length_ - 1);
|
||||
for (; i < length_ && i < remaining_z; i++) {
|
||||
digit_t d = digit_sub2(1, temp_[i], borrow_Fn, &borrow_Fn);
|
||||
Z[z_index + i] = digit_sub2(Z[z_index + i], d, borrow_z, &borrow_z);
|
||||
}
|
||||
DCHECK(borrow_Fn == 0); // NOLINT(readability/check)
|
||||
for (; borrow_z > 0 && i < remaining_z; i++) {
|
||||
Z[z_index + i] = digit_sub(Z[z_index + i], borrow_z, &borrow_z);
|
||||
}
|
||||
} else {
|
||||
digit_t carry = 0;
|
||||
int i = 0;
|
||||
for (; i < length_ && i < remaining_z; i++) {
|
||||
Z[z_index + i] = digit_add3(Z[z_index + i], temp_[i], carry, &carry);
|
||||
}
|
||||
for (; i < length_; i++) {
|
||||
DCHECK(temp_[i] == 0); // NOLINT(readability/check)
|
||||
}
|
||||
for (; carry > 0 && i < remaining_z; i++) {
|
||||
Z[z_index + i] = digit_add2(Z[z_index + i], carry, &carry);
|
||||
}
|
||||
// {carry} might be != 0 here if Z was negative before. That's fine.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Main FFT function for recursive invocations ("inner layer").
|
||||
void MultiplyFFT_Inner(RWDigits Z, Digits X, Digits Y, const Parameters& params,
|
||||
ProcessorImpl* processor) {
|
||||
int omega = 2 * params.r; // really: 2^(2r)
|
||||
int theta = params.r; // really: 2^r
|
||||
|
||||
FFTContainer a(params.n, params.K, processor);
|
||||
a.Start_Default(X, params.s, theta, omega);
|
||||
FFTContainer b(params.n, params.K, processor);
|
||||
b.Start_Default(Y, params.s, theta, omega);
|
||||
|
||||
a.PointwiseMultiply(b);
|
||||
if (processor->should_terminate()) return;
|
||||
|
||||
FFTContainer& c = a;
|
||||
c.BackwardFFT(0, params.n, omega);
|
||||
|
||||
c.CounterWeightAndRecombine(theta, params.m, Z, params.s);
|
||||
}
|
||||
|
||||
// Actual implementation of pointwise multiplications.
|
||||
void FFTContainer::DoPointwiseMultiplication(const FFTContainer& other,
|
||||
int start, int end,
|
||||
digit_t* temp) {
|
||||
// The (K_ & 3) != 0 condition makes sure that the inner FFT gets
|
||||
// to split the work into at least 4 chunks.
|
||||
bool use_fft = length_ >= kFftInnerThreshold && (K_ & 3) == 0;
|
||||
Parameters params;
|
||||
if (use_fft) ComputeParameters_Inner(K_, ¶ms);
|
||||
RWDigits result(temp, 2 * length_);
|
||||
for (int i = start; i < end; i++) {
|
||||
Digits A(part_[i], length_);
|
||||
Digits B(other.part_[i], length_);
|
||||
if (use_fft) {
|
||||
MultiplyFFT_Inner(result, A, B, params, processor_);
|
||||
} else {
|
||||
processor_->Multiply(result, A, B);
|
||||
}
|
||||
if (processor_->should_terminate()) return;
|
||||
ModFnDoubleWidth(part_[i], result.digits(), length_);
|
||||
// To improve cache friendliness, we perform the first level of the
|
||||
// backwards FFT here.
|
||||
if ((i & 1) == 1) {
|
||||
SumDiff(part_[i - 1], part_[i], part_[i - 1], part_[i], length_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convenient entry point for pointwise multiplications.
|
||||
void FFTContainer::PointwiseMultiply(const FFTContainer& other) {
|
||||
DCHECK(n_ == other.n_);
|
||||
DoPointwiseMultiplication(other, 0, n_, temp_);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Part 4: Tying everything together into a multiplication algorithm.
|
||||
|
||||
// TODO(jkummerow): Consider doing a "Mersenne transform" and CRT reconstruction
|
||||
// of the final result. Might yield a few percent of perf improvement.
|
||||
|
||||
// TODO(jkummerow): Consider implementing the "sqrt(2) trick".
|
||||
// Gaudry/Kruppa/Zimmerman report that it saved them around 10%.
|
||||
|
||||
void ProcessorImpl::MultiplyFFT(RWDigits Z, Digits X, Digits Y) {
|
||||
Parameters params;
|
||||
int m = GetParameters(X.len() + Y.len(), ¶ms);
|
||||
int omega = params.r; // really: 2^r
|
||||
|
||||
FFTContainer a(params.n, params.K, this);
|
||||
a.Start(X, params.s, 0, omega);
|
||||
if (X == Y) {
|
||||
// Squaring.
|
||||
a.PointwiseMultiply(a);
|
||||
} else {
|
||||
FFTContainer b(params.n, params.K, this);
|
||||
b.Start(Y, params.s, 0, omega);
|
||||
a.PointwiseMultiply(b);
|
||||
}
|
||||
if (should_terminate()) return;
|
||||
|
||||
a.BackwardFFT(0, params.n, omega);
|
||||
a.NormalizeAndRecombine(omega, m, Z, params.s);
|
||||
}
|
||||
|
||||
} // namespace bigint
|
||||
} // namespace v8
|
@ -123,11 +123,8 @@ void ProcessorImpl::Toom3Main(RWDigits Z, Digits X, Digits Y) {
|
||||
|
||||
// Phase 3a: Pointwise multiplication, steps 0, 1, m1.
|
||||
Multiply(r_0, X0, Y0);
|
||||
if (should_terminate()) return;
|
||||
Multiply(r_1, p_1, q_1);
|
||||
if (should_terminate()) return;
|
||||
Multiply(r_m1, p_m1, q_m1);
|
||||
if (should_terminate()) return;
|
||||
bool r_m1_sign = p_m1_sign != q_m1_sign;
|
||||
|
||||
// Phase 2b: Evaluation, steps m2 and inf.
|
||||
@ -152,14 +149,12 @@ void ProcessorImpl::Toom3Main(RWDigits Z, Digits X, Digits Y) {
|
||||
MARK_INVALID(p_m1);
|
||||
MARK_INVALID(q_m1);
|
||||
Multiply(r_m2, p_m2, q_m2);
|
||||
if (should_terminate()) return;
|
||||
bool r_m2_sign = p_m2_sign != q_m2_sign;
|
||||
|
||||
RWDigits r_inf(t + r_len, r_len);
|
||||
MARK_INVALID(p_m2);
|
||||
MARK_INVALID(q_m2);
|
||||
Multiply(r_inf, X2, Y2);
|
||||
if (should_terminate()) return;
|
||||
|
||||
// Phase 4: Interpolation.
|
||||
Digits R0 = r_0;
|
||||
@ -215,7 +210,6 @@ void ProcessorImpl::MultiplyToomCook(RWDigits Z, Digits X, Digits Y) {
|
||||
if (X.len() > Y.len()) {
|
||||
ScratchDigits T(2 * k);
|
||||
for (int i = k; i < X.len(); i += k) {
|
||||
if (should_terminate()) return;
|
||||
Digits Xi(X, i, k);
|
||||
// TODO(jkummerow): would it be a measurable improvement to craft a
|
||||
// "ToomChunk" method in the style of {KaratsubaChunk}?
|
||||
|
@ -27,8 +27,9 @@ int PrintHelp(char** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
#define TESTS(V) \
|
||||
V(kKaratsuba, "karatsuba") V(kBurnikel, "burnikel") V(kToom, "toom")
|
||||
#define TESTS(V) \
|
||||
V(kKaratsuba, "karatsuba") \
|
||||
V(kBurnikel, "burnikel") V(kToom, "toom") V(kFFT, "fft")
|
||||
|
||||
enum Operation { kNoOp, kList, kTest };
|
||||
|
||||
@ -168,6 +169,10 @@ class Runner {
|
||||
for (int i = 0; i < runs_; i++) {
|
||||
TestToom(&count);
|
||||
}
|
||||
} else if (test_ == kFFT) {
|
||||
for (int i = 0; i < runs_; i++) {
|
||||
TestFFT(&count);
|
||||
}
|
||||
} else {
|
||||
DCHECK(false); // Unreachable.
|
||||
}
|
||||
@ -225,6 +230,40 @@ class Runner {
|
||||
#endif // V8_ADVANCED_BIGINT_ALGORITHMS
|
||||
}
|
||||
|
||||
void TestFFT(int* count) {
|
||||
#if V8_ADVANCED_BIGINT_ALGORITHMS
|
||||
// Larger multiplications are slower, so to keep individual runs fast,
|
||||
// we test a few random samples. With build bots running 24/7, we'll
|
||||
// get decent coverage over time.
|
||||
uint64_t random_bits = rng_.NextUint64();
|
||||
int min = kFftThreshold - static_cast<int>(random_bits & 1023);
|
||||
random_bits >>= 10;
|
||||
int max = kFftThreshold + static_cast<int>(random_bits & 1023);
|
||||
random_bits >>= 10;
|
||||
// If delta is too small, then this run gets too slow. If it happened
|
||||
// to be zero, we'd even loop forever!
|
||||
int delta = 10 + (random_bits & 127);
|
||||
std::cout << "min " << min << " max " << max << " delta " << delta << "\n";
|
||||
for (int right_size = min; right_size <= max; right_size += delta) {
|
||||
for (int left_size = right_size; left_size <= max; left_size += delta) {
|
||||
ScratchDigits A(left_size);
|
||||
ScratchDigits B(right_size);
|
||||
int result_len = MultiplyResultLength(A, B);
|
||||
ScratchDigits result(result_len);
|
||||
ScratchDigits result_toom(result_len);
|
||||
GenerateRandom(A);
|
||||
GenerateRandom(B);
|
||||
processor()->MultiplyFFT(result, A, B);
|
||||
// Using Toom-Cook as reference.
|
||||
processor()->MultiplyToomCook(result_toom, A, B);
|
||||
AssertEquals(A, B, result_toom, result);
|
||||
if (error_) return;
|
||||
(*count)++;
|
||||
}
|
||||
}
|
||||
#endif // V8_ADVANCED_BIGINT_ALGORITHMS
|
||||
}
|
||||
|
||||
void TestBurnikel(int* count) {
|
||||
// Start small to save test execution time.
|
||||
constexpr int kMin = kBurnikelThreshold / 2;
|
||||
|
Loading…
Reference in New Issue
Block a user