1 // Copyright 2021 the V8 project authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 // FFT-based multiplication, due to Schönhage and Strassen.
6 // This implementation mostly follows the description given in:
7 // Christoph Lüders: Fast Multiplication of Large Integers,
8 // http://arxiv.org/abs/1503.04955
9
10 #include "src/bigint/bigint-internal.h"
11 #include "src/bigint/digit-arithmetic.h"
12 #include "src/bigint/util.h"
13 #include "src/bigint/vector-arithmetic.h"
14
15 namespace v8 {
16 namespace bigint {
17
18 namespace {
19
20 ////////////////////////////////////////////////////////////////////////////////
21 // Part 1: Functions for "mod F_n" arithmetic.
22 // F_n is of the shape 2^K + 1, and for convenience we use K to count the
23 // number of digits rather than the number of bits, so F_n (or K) are implicit
24 // and deduced from the length {len} of the digits array.
25
26 // Helper function for {ModFn} below.
ModFn_Helper(digit_t * x,int len,signed_digit_t high)27 void ModFn_Helper(digit_t* x, int len, signed_digit_t high) {
28 if (high > 0) {
29 digit_t borrow = high;
30 x[len - 1] = 0;
31 for (int i = 0; i < len; i++) {
32 x[i] = digit_sub(x[i], borrow, &borrow);
33 if (borrow == 0) break;
34 }
35 } else {
36 digit_t carry = -high;
37 x[len - 1] = 0;
38 for (int i = 0; i < len; i++) {
39 x[i] = digit_add2(x[i], carry, &carry);
40 if (carry == 0) break;
41 }
42 }
43 }
44
45 // {x} := {x} mod F_n, assuming that {x} is "slightly" larger than F_n (e.g.
46 // after addition of two numbers that were mod-F_n-normalized before).
ModFn(digit_t * x,int len)47 void ModFn(digit_t* x, int len) {
48 int K = len - 1;
49 signed_digit_t high = x[K];
50 if (high == 0) return;
51 ModFn_Helper(x, len, high);
52 high = x[K];
53 if (high == 0) return;
54 DCHECK(high == 1 || high == -1);
55 ModFn_Helper(x, len, high);
56 high = x[K];
57 if (high == -1) ModFn_Helper(x, len, high);
58 }
59
60 // {dest} := {src} mod F_n, assuming that {src} is about twice as long as F_n
61 // (e.g. after multiplication of two numbers that were mod-F_n-normalized
62 // before).
63 // {len} is length of {dest}; {src} is twice as long.
ModFnDoubleWidth(digit_t * dest,const digit_t * src,int len)64 void ModFnDoubleWidth(digit_t* dest, const digit_t* src, int len) {
65 int K = len - 1;
66 digit_t borrow = 0;
67 for (int i = 0; i < K; i++) {
68 dest[i] = digit_sub2(src[i], src[i + K], borrow, &borrow);
69 }
70 dest[K] = digit_sub2(0, src[2 * K], borrow, &borrow);
71 // {borrow} may be non-zero here, that's OK as {ModFn} will take care of it.
72 ModFn(dest, len);
73 }
74
75 // Sets {sum} := {a} + {b} and {diff} := {a} - {b}, which is more efficient
76 // than computing sum and difference separately. Applies "mod F_n" normalization
77 // to both results.
SumDiff(digit_t * sum,digit_t * diff,const digit_t * a,const digit_t * b,int len)78 void SumDiff(digit_t* sum, digit_t* diff, const digit_t* a, const digit_t* b,
79 int len) {
80 digit_t carry = 0;
81 digit_t borrow = 0;
82 for (int i = 0; i < len; i++) {
83 // Read both values first, because inputs and outputs can overlap.
84 digit_t ai = a[i];
85 digit_t bi = b[i];
86 sum[i] = digit_add3(ai, bi, carry, &carry);
87 diff[i] = digit_sub2(ai, bi, borrow, &borrow);
88 }
89 ModFn(sum, len);
90 ModFn(diff, len);
91 }
92
93 // {result} := ({input} << shift) mod F_n, where shift >= K.
ShiftModFn_Large(digit_t * result,const digit_t * input,int digit_shift,int bits_shift,int K)94 void ShiftModFn_Large(digit_t* result, const digit_t* input, int digit_shift,
95 int bits_shift, int K) {
96 // If {digit_shift} is greater than K, we use the following transformation
97 // (where, since everything is mod 2^K + 1, we are allowed to add or
98 // subtract any multiple of 2^K + 1 at any time):
99 // x * 2^{K+m} mod 2^K + 1
100 // == x * 2^K * 2^m - (2^K + 1)*(x * 2^m) mod 2^K + 1
101 // == x * 2^K * 2^m - x * 2^K * 2^m - x * 2^m mod 2^K + 1
102 // == -x * 2^m mod 2^K + 1
103 // So the flow is the same as for m < K, but we invert the subtraction's
104 // operands. In order to avoid underflow, we virtually initialize the
105 // result to 2^K + 1:
106 // input = [ iK ][iK-1] .... .... [ i1 ][ i0 ]
107 // result = [ 1][0000] .... .... [0000][0001]
108 // + [ iK ] .... [ iX ]
109 // - [iX-1] .... [ i0 ]
110 DCHECK(digit_shift >= K);
111 digit_shift -= K;
112 digit_t borrow = 0;
113 if (bits_shift == 0) {
114 digit_t carry = 1;
115 for (int i = 0; i < digit_shift; i++) {
116 result[i] = digit_add2(input[i + K - digit_shift], carry, &carry);
117 }
118 result[digit_shift] = digit_sub(input[K] + carry, input[0], &borrow);
119 for (int i = digit_shift + 1; i < K; i++) {
120 digit_t d = input[i - digit_shift];
121 result[i] = digit_sub2(0, d, borrow, &borrow);
122 }
123 } else {
124 digit_t add_carry = 1;
125 digit_t input_carry =
126 input[K - digit_shift - 1] >> (kDigitBits - bits_shift);
127 for (int i = 0; i < digit_shift; i++) {
128 digit_t d = input[i + K - digit_shift];
129 digit_t summand = (d << bits_shift) | input_carry;
130 result[i] = digit_add2(summand, add_carry, &add_carry);
131 input_carry = d >> (kDigitBits - bits_shift);
132 }
133 {
134 // result[digit_shift] = (add_carry + iK_part) - i0_part
135 digit_t d = input[K];
136 digit_t iK_part = (d << bits_shift) | input_carry;
137 digit_t iK_carry = d >> (kDigitBits - bits_shift);
138 digit_t sum = digit_add2(add_carry, iK_part, &add_carry);
139 // {iK_carry} is less than a full digit, so we can merge {add_carry}
140 // into it without overflow.
141 iK_carry += add_carry;
142 d = input[0];
143 digit_t i0_part = d << bits_shift;
144 result[digit_shift] = digit_sub(sum, i0_part, &borrow);
145 input_carry = d >> (kDigitBits - bits_shift);
146 if (digit_shift + 1 < K) {
147 d = input[1];
148 digit_t subtrahend = (d << bits_shift) | input_carry;
149 result[digit_shift + 1] =
150 digit_sub2(iK_carry, subtrahend, borrow, &borrow);
151 input_carry = d >> (kDigitBits - bits_shift);
152 }
153 }
154 for (int i = digit_shift + 2; i < K; i++) {
155 digit_t d = input[i - digit_shift];
156 digit_t subtrahend = (d << bits_shift) | input_carry;
157 result[i] = digit_sub2(0, subtrahend, borrow, &borrow);
158 input_carry = d >> (kDigitBits - bits_shift);
159 }
160 }
161 // The virtual 1 in result[K] should be eliminated by {borrow}. If there
162 // is no borrow, then the virtual initialization was too much. Subtract
163 // 2^K + 1.
164 result[K] = 0;
165 if (borrow != 1) {
166 borrow = 1;
167 for (int i = 0; i < K; i++) {
168 result[i] = digit_sub(result[i], borrow, &borrow);
169 if (borrow == 0) break;
170 }
171 if (borrow != 0) {
172 // The result must be 2^K.
173 for (int i = 0; i < K; i++) result[i] = 0;
174 result[K] = 1;
175 }
176 }
177 }
178
179 // Sets {result} := {input} * 2^{power_of_two} mod 2^{K} + 1.
180 // This function is highly relevant for overall performance.
ShiftModFn(digit_t * result,const digit_t * input,int power_of_two,int K,int zero_above=0x7FFFFFFF)181 void ShiftModFn(digit_t* result, const digit_t* input, int power_of_two, int K,
182 int zero_above = 0x7FFFFFFF) {
183 // The modulo-reduction amounts to a subtraction, which we combine
184 // with the shift as follows:
185 // input = [ iK ][iK-1] .... .... [ i1 ][ i0 ]
186 // result = [iX-1] .... [ i0 ] <---------- shift by {power_of_two}
187 // - [ iK ] .... [ iX ]
188 // where "X" is the index "K - digit_shift".
189 int digit_shift = power_of_two / kDigitBits;
190 int bits_shift = power_of_two % kDigitBits;
191 // By an analogous construction to the "digit_shift >= K" case,
192 // it turns out that:
193 // x * 2^{2K+m} == x * 2^m mod 2^K + 1.
194 while (digit_shift >= 2 * K) digit_shift -= 2 * K; // Faster than '%'!
195 if (digit_shift >= K) {
196 return ShiftModFn_Large(result, input, digit_shift, bits_shift, K);
197 }
198 digit_t borrow = 0;
199 if (bits_shift == 0) {
200 // We do a single pass over {input}, starting by copying digits [i1] to
201 // [iX-1] to result indices digit_shift+1 to K-1.
202 int i = 1;
203 // Read input digits unless we know they are zero.
204 int cap = std::min(K - digit_shift, zero_above);
205 for (; i < cap; i++) {
206 result[i + digit_shift] = input[i];
207 }
208 // Any remaining work can hard-code the knowledge that input[i] == 0.
209 for (; i < K - digit_shift; i++) {
210 DCHECK(input[i] == 0);
211 result[i + digit_shift] = 0;
212 }
213 // Second phase: subtract input digits [iX] to [iK] from (virtually) zero-
214 // initialized result indices 0 to digit_shift-1.
215 cap = std::min(K, zero_above);
216 for (; i < cap; i++) {
217 digit_t d = input[i];
218 result[i - K + digit_shift] = digit_sub2(0, d, borrow, &borrow);
219 }
220 // Any remaining work can hard-code the knowledge that input[i] == 0.
221 for (; i < K; i++) {
222 DCHECK(input[i] == 0);
223 result[i - K + digit_shift] = digit_sub(0, borrow, &borrow);
224 }
225 // Last step: subtract [iK] from [i0] and store at result index digit_shift.
226 result[digit_shift] = digit_sub2(input[0], input[K], borrow, &borrow);
227 } else {
228 // Same flow as before, but taking bits_shift != 0 into account.
229 // First phase: result indices digit_shift+1 to K.
230 digit_t carry = 0;
231 int i = 0;
232 // Read input digits unless we know they are zero.
233 int cap = std::min(K - digit_shift, zero_above);
234 for (; i < cap; i++) {
235 digit_t d = input[i];
236 result[i + digit_shift] = (d << bits_shift) | carry;
237 carry = d >> (kDigitBits - bits_shift);
238 }
239 // Any remaining work can hard-code the knowledge that input[i] == 0.
240 for (; i < K - digit_shift; i++) {
241 DCHECK(input[i] == 0);
242 result[i + digit_shift] = carry;
243 carry = 0;
244 }
245 // Second phase: result indices 0 to digit_shift - 1.
246 cap = std::min(K, zero_above);
247 for (; i < cap; i++) {
248 digit_t d = input[i];
249 result[i - K + digit_shift] =
250 digit_sub2(0, (d << bits_shift) | carry, borrow, &borrow);
251 carry = d >> (kDigitBits - bits_shift);
252 }
253 // Any remaining work can hard-code the knowledge that input[i] == 0.
254 if (i < K) {
255 DCHECK(input[i] == 0);
256 result[i - K + digit_shift] = digit_sub2(0, carry, borrow, &borrow);
257 carry = 0;
258 i++;
259 }
260 for (; i < K; i++) {
261 DCHECK(input[i] == 0);
262 result[i - K + digit_shift] = digit_sub(0, borrow, &borrow);
263 }
264 // Last step: compute result[digit_shift].
265 digit_t d = input[K];
266 result[digit_shift] = digit_sub2(
267 result[digit_shift], (d << bits_shift) | carry, borrow, &borrow);
268 // No carry left.
269 DCHECK((d >> (kDigitBits - bits_shift)) == 0);
270 }
271 result[K] = 0;
272 for (int i = digit_shift + 1; i <= K && borrow > 0; i++) {
273 result[i] = digit_sub(result[i], borrow, &borrow);
274 }
275 if (borrow > 0) {
276 // Underflow means we subtracted too much. Add 2^K + 1.
277 digit_t carry = 1;
278 for (int i = 0; i <= K; i++) {
279 result[i] = digit_add2(result[i], carry, &carry);
280 if (carry == 0) break;
281 }
282 result[K] = digit_add2(result[K], 1, &carry);
283 }
284 }
285
286 ////////////////////////////////////////////////////////////////////////////////
287 // Part 2: FFT-based multiplication is very sensitive to appropriate choice
288 // of parameters. The following functions choose the parameters that the
289 // subsequent actual computation will use. This is partially based on formal
290 // constraints and partially on experimentally-determined heuristics.
291
// Bundle of FFT parameters produced by the ComputeParameters* functions.
struct Parameters {
  // The zero defaults are never relied upon by the algorithm, but leaving
  // these fields uninitialized saddens and confuses MSan.
  int m = 0;
  int K = 0;
  int n = 0;
  int s = 0;
  int r = 0;
};
301
302 // Computes parameters for the main calculation, given a bit length {N} and
303 // an {m}. See the paper for details.
ComputeParameters(int N,int m,Parameters * params)304 void ComputeParameters(int N, int m, Parameters* params) {
305 N *= kDigitBits;
306 int n = 1 << m; // 2^m
307 int nhalf = n >> 1;
308 int s = (N + n - 1) >> m; // ceil(N/n)
309 s = RoundUp(s, kDigitBits);
310 int K = m + 2 * s + 1; // K must be at least this big...
311 K = RoundUp(K, nhalf); // ...and a multiple of n/2.
312 int r = K >> (m - 1); // Which multiple?
313
314 // We want recursive calls to make progress, so force K to be a multiple
315 // of 8 if it's above the recursion threshold. Otherwise, K must be a
316 // multiple of kDigitBits.
317 const int threshold = (K + 1 >= kFftInnerThreshold * kDigitBits)
318 ? 3 + kLog2DigitBits
319 : kLog2DigitBits;
320 int K_tz = CountTrailingZeros(K);
321 while (K_tz < threshold) {
322 K += (1 << K_tz);
323 r = K >> (m - 1);
324 K_tz = CountTrailingZeros(K);
325 }
326
327 DCHECK(K % kDigitBits == 0);
328 DCHECK(s % kDigitBits == 0);
329 params->K = K / kDigitBits;
330 params->s = s / kDigitBits;
331 params->n = n;
332 params->r = r;
333 }
334
335 // Computes parameters for recursive invocations ("inner layer").
ComputeParameters_Inner(int N,Parameters * params)336 void ComputeParameters_Inner(int N, Parameters* params) {
337 int max_m = CountTrailingZeros(N);
338 int N_bits = BitLength(N);
339 int m = N_bits - 4; // Don't let s get too small.
340 m = std::min(max_m, m);
341 N *= kDigitBits;
342 int n = 1 << m; // 2^m
343 // We can't round up s in the inner layer, because N = n*s is fixed.
344 int s = N >> m;
345 DCHECK(N == s * n);
346 int K = m + 2 * s + 1; // K must be at least this big...
347 K = RoundUp(K, n); // ...and a multiple of n and kDigitBits.
348 K = RoundUp(K, kDigitBits);
349 params->r = K >> m; // Which multiple?
350 DCHECK(K % kDigitBits == 0);
351 DCHECK(s % kDigitBits == 0);
352 params->K = K / kDigitBits;
353 params->s = s / kDigitBits;
354 params->n = n;
355 params->m = m;
356 }
357
PredictInnerK(int N)358 int PredictInnerK(int N) {
359 Parameters params;
360 ComputeParameters_Inner(N, ¶ms);
361 return params.K;
362 }
363
364 // Applies heuristics to decide whether {m} should be decremented, by looking
365 // at what would happen to {K} and {s} if {m} was decremented.
ShouldDecrementM(const Parameters & current,const Parameters & next,const Parameters & after_next)366 bool ShouldDecrementM(const Parameters& current, const Parameters& next,
367 const Parameters& after_next) {
368 // K == 64 seems to work particularly well.
369 if (current.K == 64 && next.K >= 112) return false;
370 // Small values for s are never efficient.
371 if (current.s < 6) return true;
372 // The time is roughly determined by K * n. When we decrement m, then
373 // n always halves, and K usually gets bigger, by up to 2x.
374 // For not-quite-so-small s, look at how much bigger K would get: if
375 // the K increase is small enough, making n smaller is worth it.
376 // Empirically, it's most meaningful to look at the K *after* next.
377 // The specific threshold values have been chosen by running many
378 // benchmarks on inputs of many sizes, and manually selecting thresholds
379 // that seemed to produce good results.
380 double factor = static_cast<double>(after_next.K) / current.K;
381 if ((current.s == 6 && factor < 3.85) || // --
382 (current.s == 7 && factor < 3.73) || // --
383 (current.s == 8 && factor < 3.55) || // --
384 (current.s == 9 && factor < 3.50) || // --
385 factor < 3.4) {
386 return true;
387 }
388 // If K is just below the recursion threshold, make sure we do recurse,
389 // unless doing so would be particularly inefficient (large inner_K).
390 // If K is just above the recursion threshold, doubling it often makes
391 // the inner call more efficient.
392 if (current.K >= 160 && current.K < 250 && PredictInnerK(next.K) < 28) {
393 return true;
394 }
395 // If we found no reason to decrement, keep m as large as possible.
396 return false;
397 }
398
399 // Decides what parameters to use for a given input bit length {N}.
400 // Returns the chosen m.
GetParameters(int N,Parameters * params)401 int GetParameters(int N, Parameters* params) {
402 int N_bits = BitLength(N);
403 int max_m = N_bits - 3; // Larger m make s too small.
404 max_m = std::max(kLog2DigitBits, max_m); // Smaller m break the logic below.
405 int m = max_m;
406 Parameters current;
407 ComputeParameters(N, m, ¤t);
408 Parameters next;
409 ComputeParameters(N, m - 1, &next);
410 while (m > 2) {
411 Parameters after_next;
412 ComputeParameters(N, m - 2, &after_next);
413 if (ShouldDecrementM(current, next, after_next)) {
414 m--;
415 current = next;
416 next = after_next;
417 } else {
418 break;
419 }
420 }
421 *params = current;
422 return m;
423 }
424
425 ////////////////////////////////////////////////////////////////////////////////
426 // Part 3: Fast Fourier Transformation.
427
428 class FFTContainer {
429 public:
430 // {n} is the number of chunks, whose length is {K}+1.
431 // {K} determines F_n = 2^(K * kDigitBits) + 1.
FFTContainer(int n,int K,ProcessorImpl * processor)432 FFTContainer(int n, int K, ProcessorImpl* processor)
433 : n_(n), K_(K), length_(K + 1), processor_(processor) {
434 storage_ = new digit_t[length_ * n_];
435 part_ = new digit_t*[n_];
436 digit_t* ptr = storage_;
437 for (int i = 0; i < n; i++, ptr += length_) {
438 part_[i] = ptr;
439 }
440 temp_ = new digit_t[length_ * 2];
441 }
442 FFTContainer() = delete;
443 FFTContainer(const FFTContainer&) = delete;
444 FFTContainer& operator=(const FFTContainer&) = delete;
445
~FFTContainer()446 ~FFTContainer() {
447 delete[] storage_;
448 delete[] part_;
449 delete[] temp_;
450 }
451
452 void Start_Default(Digits X, int chunk_size, int theta, int omega);
453 void Start(Digits X, int chunk_size, int theta, int omega);
454
455 void NormalizeAndRecombine(int omega, int m, RWDigits Z, int chunk_size);
456 void CounterWeightAndRecombine(int theta, int m, RWDigits Z, int chunk_size);
457
458 void FFT_ReturnShuffledThreadsafe(int start, int len, int omega,
459 digit_t* temp);
460 void FFT_Recurse(int start, int half, int omega, digit_t* temp);
461
462 void BackwardFFT(int start, int len, int omega);
463 void BackwardFFT_Threadsafe(int start, int len, int omega, digit_t* temp);
464
465 void PointwiseMultiply(const FFTContainer& other);
466 void DoPointwiseMultiplication(const FFTContainer& other, int start, int end,
467 digit_t* temp);
468
length() const469 int length() const { return length_; }
470
471 private:
472 const int n_; // Number of parts.
473 const int K_; // Always length_ - 1.
474 const int length_; // Length of each part, in digits.
475 ProcessorImpl* processor_;
476 digit_t* storage_; // Combined storage of all parts.
477 digit_t** part_; // Pointers to each part.
478 digit_t* temp_; // Temporary storage with size 2 * length_.
479 };
480
CopyAndZeroExtend(digit_t * dst,const digit_t * src,int digits_to_copy,size_t total_bytes)481 inline void CopyAndZeroExtend(digit_t* dst, const digit_t* src,
482 int digits_to_copy, size_t total_bytes) {
483 size_t bytes_to_copy = digits_to_copy * sizeof(digit_t);
484 memcpy(dst, src, bytes_to_copy);
485 memset(dst + digits_to_copy, 0, total_bytes - bytes_to_copy);
486 }
487
488 // Reads {X} into the FFTContainer's internal storage, dividing it into chunks
489 // while doing so; then performs the forward FFT.
Start_Default(Digits X,int chunk_size,int theta,int omega)490 void FFTContainer::Start_Default(Digits X, int chunk_size, int theta,
491 int omega) {
492 int len = X.len();
493 const digit_t* pointer = X.digits();
494 const size_t part_length_in_bytes = length_ * sizeof(digit_t);
495 int current_theta = 0;
496 int i = 0;
497 for (; i < n_ && len > 0; i++, current_theta += theta) {
498 chunk_size = std::min(chunk_size, len);
499 // For invocations via MultiplyFFT_Inner, X.len() == n_ * chunk_size + 1,
500 // because the outer layer's "K" is passed as the inner layer's "N".
501 // Since X is (mod Fn)-normalized on the outer layer, there is the rare
502 // corner case where X[n_ * chunk_size] == 1. Detect that case, and handle
503 // the extra bit as part of the last chunk; we always have the space.
504 if (i == n_ - 1 && len == chunk_size + 1) {
505 DCHECK(X[n_ * chunk_size] <= 1);
506 DCHECK(length_ >= chunk_size + 1);
507 chunk_size++;
508 }
509 if (current_theta != 0) {
510 // Multiply with theta^i, and reduce modulo 2^K + 1.
511 // We pass theta as a shift amount; it really means 2^theta.
512 CopyAndZeroExtend(temp_, pointer, chunk_size, part_length_in_bytes);
513 ShiftModFn(part_[i], temp_, current_theta, K_, chunk_size);
514 } else {
515 CopyAndZeroExtend(part_[i], pointer, chunk_size, part_length_in_bytes);
516 }
517 pointer += chunk_size;
518 len -= chunk_size;
519 }
520 DCHECK(len == 0);
521 for (; i < n_; i++) {
522 memset(part_[i], 0, part_length_in_bytes);
523 }
524 FFT_ReturnShuffledThreadsafe(0, n_, omega, temp_);
525 }
526
527 // This version of Start is optimized for the case where ~half of the
528 // container will be filled with padding zeros.
Start(Digits X,int chunk_size,int theta,int omega)529 void FFTContainer::Start(Digits X, int chunk_size, int theta, int omega) {
530 int len = X.len();
531 if (len > n_ * chunk_size / 2) {
532 return Start_Default(X, chunk_size, theta, omega);
533 }
534 DCHECK(theta == 0);
535 const digit_t* pointer = X.digits();
536 const size_t part_length_in_bytes = length_ * sizeof(digit_t);
537 int nhalf = n_ / 2;
538 // Unrolled first iteration.
539 CopyAndZeroExtend(part_[0], pointer, chunk_size, part_length_in_bytes);
540 CopyAndZeroExtend(part_[nhalf], pointer, chunk_size, part_length_in_bytes);
541 pointer += chunk_size;
542 len -= chunk_size;
543 int i = 1;
544 for (; i < nhalf && len > 0; i++) {
545 chunk_size = std::min(chunk_size, len);
546 CopyAndZeroExtend(part_[i], pointer, chunk_size, part_length_in_bytes);
547 int w = omega * i;
548 ShiftModFn(part_[i + nhalf], part_[i], w, K_, chunk_size);
549 pointer += chunk_size;
550 len -= chunk_size;
551 }
552 for (; i < nhalf; i++) {
553 memset(part_[i], 0, part_length_in_bytes);
554 memset(part_[i + nhalf], 0, part_length_in_bytes);
555 }
556 FFT_Recurse(0, nhalf, omega, temp_);
557 }
558
559 // Forward transformation.
560 // We use the "DIF" aka "decimation in frequency" transform, because it
561 // leaves the result in "bit reversed" order, which is precisely what we
562 // need as input for the "DIT" aka "decimation in time" backwards transform.
FFT_ReturnShuffledThreadsafe(int start,int len,int omega,digit_t * temp)563 void FFTContainer::FFT_ReturnShuffledThreadsafe(int start, int len, int omega,
564 digit_t* temp) {
565 DCHECK((len & 1) == 0); // {len} must be even.
566 int half = len / 2;
567 SumDiff(part_[start], part_[start + half], part_[start], part_[start + half],
568 length_);
569 for (int k = 1; k < half; k++) {
570 SumDiff(part_[start + k], temp, part_[start + k], part_[start + half + k],
571 length_);
572 int w = omega * k;
573 ShiftModFn(part_[start + half + k], temp, w, K_);
574 }
575 FFT_Recurse(start, half, omega, temp);
576 }
577
578 // Recursive step of the above, factored out for additional callers.
FFT_Recurse(int start,int half,int omega,digit_t * temp)579 void FFTContainer::FFT_Recurse(int start, int half, int omega, digit_t* temp) {
580 if (half > 1) {
581 FFT_ReturnShuffledThreadsafe(start, half, 2 * omega, temp);
582 FFT_ReturnShuffledThreadsafe(start + half, half, 2 * omega, temp);
583 }
584 }
585
586 // Backward transformation.
587 // We use the "DIT" aka "decimation in time" transform here, because it
588 // turns bit-reversed input into normally sorted output.
BackwardFFT(int start,int len,int omega)589 void FFTContainer::BackwardFFT(int start, int len, int omega) {
590 BackwardFFT_Threadsafe(start, len, omega, temp_);
591 }
592
BackwardFFT_Threadsafe(int start,int len,int omega,digit_t * temp)593 void FFTContainer::BackwardFFT_Threadsafe(int start, int len, int omega,
594 digit_t* temp) {
595 DCHECK((len & 1) == 0); // {len} must be even.
596 int half = len / 2;
597 // Don't recurse for half == 2, as PointwiseMultiply already performed
598 // the first level of the backwards FFT.
599 if (half > 2) {
600 BackwardFFT_Threadsafe(start, half, 2 * omega, temp);
601 BackwardFFT_Threadsafe(start + half, half, 2 * omega, temp);
602 }
603 SumDiff(part_[start], part_[start + half], part_[start], part_[start + half],
604 length_);
605 for (int k = 1; k < half; k++) {
606 int w = omega * (len - k);
607 ShiftModFn(temp, part_[start + half + k], w, K_);
608 SumDiff(part_[start + k], part_[start + half + k], part_[start + k], temp,
609 length_);
610 }
611 }
612
613 // Recombines the result's parts into {Z}, after backwards FFT.
NormalizeAndRecombine(int omega,int m,RWDigits Z,int chunk_size)614 void FFTContainer::NormalizeAndRecombine(int omega, int m, RWDigits Z,
615 int chunk_size) {
616 Z.Clear();
617 int z_index = 0;
618 const int shift = n_ * omega - m;
619 for (int i = 0; i < n_; i++, z_index += chunk_size) {
620 digit_t* part = part_[i];
621 ShiftModFn(temp_, part, shift, K_);
622 digit_t carry = 0;
623 int zi = z_index;
624 int j = 0;
625 for (; j < length_ && zi < Z.len(); j++, zi++) {
626 Z[zi] = digit_add3(Z[zi], temp_[j], carry, &carry);
627 }
628 for (; j < length_; j++) {
629 DCHECK(temp_[j] == 0);
630 }
631 if (carry != 0) {
632 DCHECK(zi < Z.len());
633 Z[zi] = carry;
634 }
635 }
636 }
637
638 // Helper function for {CounterWeightAndRecombine} below.
ShouldBeNegative(const digit_t * x,int xlen,digit_t threshold,int s)639 bool ShouldBeNegative(const digit_t* x, int xlen, digit_t threshold, int s) {
640 if (x[2 * s] >= threshold) return true;
641 for (int i = 2 * s + 1; i < xlen; i++) {
642 if (x[i] > 0) return true;
643 }
644 return false;
645 }
646
647 // Same as {NormalizeAndRecombine} above, but for the needs of the recursive
648 // invocation ("inner layer") of FFT multiplication, where an additional
649 // counter-weighting step is required.
CounterWeightAndRecombine(int theta,int m,RWDigits Z,int s)650 void FFTContainer::CounterWeightAndRecombine(int theta, int m, RWDigits Z,
651 int s) {
652 Z.Clear();
653 int z_index = 0;
654 for (int k = 0; k < n_; k++, z_index += s) {
655 int shift = -theta * k - m;
656 if (shift < 0) shift += 2 * n_ * theta;
657 DCHECK(shift >= 0);
658 digit_t* input = part_[k];
659 ShiftModFn(temp_, input, shift, K_);
660 int remaining_z = Z.len() - z_index;
661 if (ShouldBeNegative(temp_, length_, k + 1, s)) {
662 // Subtract F_n from input before adding to result. We use the following
663 // transformation (knowing that X < F_n):
664 // Z + (X - F_n) == Z - (F_n - X)
665 digit_t borrow_z = 0;
666 digit_t borrow_Fn = 0;
667 {
668 // i == 0:
669 digit_t d = digit_sub(1, temp_[0], &borrow_Fn);
670 Z[z_index] = digit_sub(Z[z_index], d, &borrow_z);
671 }
672 int i = 1;
673 for (; i < K_ && i < remaining_z; i++) {
674 digit_t d = digit_sub2(0, temp_[i], borrow_Fn, &borrow_Fn);
675 Z[z_index + i] = digit_sub2(Z[z_index + i], d, borrow_z, &borrow_z);
676 }
677 DCHECK(i == K_ && K_ == length_ - 1);
678 for (; i < length_ && i < remaining_z; i++) {
679 digit_t d = digit_sub2(1, temp_[i], borrow_Fn, &borrow_Fn);
680 Z[z_index + i] = digit_sub2(Z[z_index + i], d, borrow_z, &borrow_z);
681 }
682 DCHECK(borrow_Fn == 0);
683 for (; borrow_z > 0 && i < remaining_z; i++) {
684 Z[z_index + i] = digit_sub(Z[z_index + i], borrow_z, &borrow_z);
685 }
686 } else {
687 digit_t carry = 0;
688 int i = 0;
689 for (; i < length_ && i < remaining_z; i++) {
690 Z[z_index + i] = digit_add3(Z[z_index + i], temp_[i], carry, &carry);
691 }
692 for (; i < length_; i++) {
693 DCHECK(temp_[i] == 0);
694 }
695 for (; carry > 0 && i < remaining_z; i++) {
696 Z[z_index + i] = digit_add2(Z[z_index + i], carry, &carry);
697 }
698 // {carry} might be != 0 here if Z was negative before. That's fine.
699 }
700 }
701 }
702
703 // Main FFT function for recursive invocations ("inner layer").
MultiplyFFT_Inner(RWDigits Z,Digits X,Digits Y,const Parameters & params,ProcessorImpl * processor)704 void MultiplyFFT_Inner(RWDigits Z, Digits X, Digits Y, const Parameters& params,
705 ProcessorImpl* processor) {
706 int omega = 2 * params.r; // really: 2^(2r)
707 int theta = params.r; // really: 2^r
708
709 FFTContainer a(params.n, params.K, processor);
710 a.Start_Default(X, params.s, theta, omega);
711 FFTContainer b(params.n, params.K, processor);
712 b.Start_Default(Y, params.s, theta, omega);
713
714 a.PointwiseMultiply(b);
715 if (processor->should_terminate()) return;
716
717 FFTContainer& c = a;
718 c.BackwardFFT(0, params.n, omega);
719
720 c.CounterWeightAndRecombine(theta, params.m, Z, params.s);
721 }
722
723 // Actual implementation of pointwise multiplications.
DoPointwiseMultiplication(const FFTContainer & other,int start,int end,digit_t * temp)724 void FFTContainer::DoPointwiseMultiplication(const FFTContainer& other,
725 int start, int end,
726 digit_t* temp) {
727 // The (K_ & 3) != 0 condition makes sure that the inner FFT gets
728 // to split the work into at least 4 chunks.
729 bool use_fft = length_ >= kFftInnerThreshold && (K_ & 3) == 0;
730 Parameters params;
731 if (use_fft) ComputeParameters_Inner(K_, ¶ms);
732 RWDigits result(temp, 2 * length_);
733 for (int i = start; i < end; i++) {
734 Digits A(part_[i], length_);
735 Digits B(other.part_[i], length_);
736 if (use_fft) {
737 MultiplyFFT_Inner(result, A, B, params, processor_);
738 } else {
739 processor_->Multiply(result, A, B);
740 }
741 if (processor_->should_terminate()) return;
742 ModFnDoubleWidth(part_[i], result.digits(), length_);
743 // To improve cache friendliness, we perform the first level of the
744 // backwards FFT here.
745 if ((i & 1) == 1) {
746 SumDiff(part_[i - 1], part_[i], part_[i - 1], part_[i], length_);
747 }
748 }
749 }
750
751 // Convenient entry point for pointwise multiplications.
PointwiseMultiply(const FFTContainer & other)752 void FFTContainer::PointwiseMultiply(const FFTContainer& other) {
753 DCHECK(n_ == other.n_);
754 DoPointwiseMultiplication(other, 0, n_, temp_);
755 }
756
757 } // namespace
758
759 ////////////////////////////////////////////////////////////////////////////////
760 // Part 4: Tying everything together into a multiplication algorithm.
761
762 // TODO(jkummerow): Consider doing a "Mersenne transform" and CRT reconstruction
763 // of the final result. Might yield a few percent of perf improvement.
764
765 // TODO(jkummerow): Consider implementing the "sqrt(2) trick".
766 // Gaudry/Kruppa/Zimmerman report that it saved them around 10%.
767
MultiplyFFT(RWDigits Z,Digits X,Digits Y)768 void ProcessorImpl::MultiplyFFT(RWDigits Z, Digits X, Digits Y) {
769 Parameters params;
770 int m = GetParameters(X.len() + Y.len(), ¶ms);
771 int omega = params.r; // really: 2^r
772
773 FFTContainer a(params.n, params.K, this);
774 a.Start(X, params.s, 0, omega);
775 if (X == Y) {
776 // Squaring.
777 a.PointwiseMultiply(a);
778 } else {
779 FFTContainer b(params.n, params.K, this);
780 b.Start(Y, params.s, 0, omega);
781 a.PointwiseMultiply(b);
782 }
783 if (should_terminate()) return;
784
785 a.BackwardFFT(0, params.n, omega);
786 a.NormalizeAndRecombine(omega, m, Z, params.s);
787 }
788
789 } // namespace bigint
790 } // namespace v8
791