From 795cd2013ff9a03a8b9af778e690e9e1fdbd2192 Mon Sep 17 00:00:00 2001 From: Daniel Mendler Date: Tue, 29 Oct 2019 21:48:50 +0100 Subject: [PATCH] simplifications: add s_mp_zero_(digs|buf) and s_mp_copy_digs Originally I made those as macros. However we have many other small functions like mp_clamp, mp_exch which are also not implemented as macros right now. If we would use c99, I would implement them as private static inline functions. And mp_exch would be a public static inline function. But since we are bound to c89, we simply use normal functions. To achieve optimal performance one should either use link time optimization or amalgamation. --- .travis.yml | 1 + etc/tune.c | 2 +- mp_add_d.c | 2 +- mp_clear.c | 2 +- mp_copy.c | 17 +++--------- mp_div_2.c | 2 +- mp_dr_reduce.c | 2 +- mp_fwrite.c | 2 +- mp_grow.c | 2 +- mp_lshd.c | 2 +- mp_mod_2d.c | 2 +- mp_mul_2.c | 2 +- mp_mul_d.c | 2 +- mp_prime_rand.c | 2 +- mp_rshd.c | 2 +- mp_set.c | 2 +- mp_sub_d.c | 2 +- mp_zero.c | 2 +- s_mp_add.c | 2 +- s_mp_balance_mul.c | 20 +++++++------- s_mp_copy_digs.c | 23 ++++++++++++++++ s_mp_karatsuba_mul.c | 16 ++++------- s_mp_karatsuba_sqr.c | 12 +++------ s_mp_montgomery_reduce_fast.c | 4 +-- s_mp_mul_digs_fast.c | 2 +- s_mp_mul_high_digs_fast.c | 2 +- s_mp_sqr_fast.c | 2 +- s_mp_sub.c | 2 +- s_mp_toom_mul.c | 47 ++++++++++++-------------------- s_mp_toom_sqr.c | 22 +++++---------- s_mp_zero_buf.c | 22 +++++++++++++++ s_mp_zero_digs.c | 23 ++++++++++++++++ tommath_private.h | 51 +++++++++-------------------------- 33 files changed, 149 insertions(+), 151 deletions(-) create mode 100644 s_mp_copy_digs.c create mode 100644 s_mp_zero_buf.c create mode 100644 s_mp_zero_digs.c diff --git a/.travis.yml b/.travis.yml index 6c78689..a269c4e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -144,6 +144,7 @@ matrix: # clang for x86-64 architecture (64-bit longs and 64-bit pointers) - env: SANITIZER=1 CONV_WARNINGS=relaxed BUILDOPTIONS='--with-cc=clang-7 --with-m64 --with-travis-valgrind' - env: SANITIZER=1 CONV_WARNINGS=strict BUILDOPTIONS='--with-cc=clang-7 --with-m64 --with-travis-valgrind' + - env: SANITIZER=1 CONV_WARNINGS=strict BUILDOPTIONS='--with-cc=clang-7 --cflags=-DMP_USE_MEMOPS --with-m64 --with-travis-valgrind' - env: SANITIZER=1 CONV_WARNINGS=strict BUILDOPTIONS='--with-cc=clang-7 --c89 --with-m64 --with-travis-valgrind' - env: SANITIZER=1 BUILDOPTIONS='--with-cc=clang-7 --with-m64 --with-travis-valgrind --cflags=-DMP_PREC=MP_MIN_PREC' - env: SANITIZER=1 BUILDOPTIONS='--with-cc=clang-6.0 --with-m64 --with-travis-valgrind' diff --git a/etc/tune.c b/etc/tune.c index be78ce3..383eb49 100644 --- a/etc/tune.c +++ b/etc/tune.c @@ -292,7 +292,7 @@ int main(int argc, char **argv) s_number_of_test_loops = 64; s_stabilization_extra = 3; - MP_ZERO_BUFFER(&args, sizeof(args)); + s_mp_zero_buf(&args, sizeof(args)); args.testmode = 0; args.verbose = 0; diff --git a/mp_add_d.c b/mp_add_d.c index 43d50e8..9ef4475 100644 --- a/mp_add_d.c +++ b/mp_add_d.c @@ -80,7 +80,7 @@ mp_err mp_add_d(const mp_int *a, mp_digit b, mp_int *c) c->sign = MP_ZPOS; /* now zero to oldused */ - MP_ZERO_DIGITS(c->dp + c->used, oldused - c->used); + s_mp_zero_digs(c->dp + c->used, oldused - c->used); mp_clamp(c); return MP_OKAY; diff --git a/mp_clear.c b/mp_clear.c index 55d76b2..11094b2 100644 --- a/mp_clear.c +++ b/mp_clear.c @@ -9,7 +9,7 @@ void mp_clear(mp_int *a) /* only do anything if a hasn't been freed previously */ if (a->dp != NULL) { /* free ram */ - MP_FREE_DIGITS(a->dp, a->alloc); + MP_FREE_DIGS(a->dp, a->alloc); /* reset members to make debugging easier */ a->dp = NULL; diff --git a/mp_copy.c b/mp_copy.c index cf4d5e0..cf93b04 100644 --- a/mp_copy.c +++ b/mp_copy.c @@ -6,8 +6,6 @@ /* copy, b = a */ mp_err mp_copy(const mp_int *a, mp_int *b) { - int n; - /* if dst == src do nothing */ if (a == b) { return MP_OKAY; @@ -21,19 +19,12 @@ mp_err mp_copy(const mp_int *a, mp_int *b) } } - /* zero b and copy the parameters over */ - - /* copy all the digits */ - for (n = 0; n < a->used; n++) { - b->dp[n] = a->dp[n]; - } - - /* clear high digits */ - MP_ZERO_DIGITS(b->dp + a->used, b->used - a->used); - - /* copy used count and sign */ + /* copy everything over and zero high digits */ + s_mp_copy_digs(b->dp, a->dp, a->used); + s_mp_zero_digs(b->dp + a->used, b->used - a->used); b->used = a->used; b->sign = a->sign; + return MP_OKAY; } #endif diff --git a/mp_div_2.c b/mp_div_2.c index 573570d..b15391e 100644 --- a/mp_div_2.c +++ b/mp_div_2.c @@ -33,7 +33,7 @@ mp_err mp_div_2(const mp_int *a, mp_int *b) } /* zero excess digits */ - MP_ZERO_DIGITS(b->dp + b->used, oldused - b->used); + s_mp_zero_digs(b->dp + b->used, oldused - b->used); b->sign = a->sign; mp_clamp(b); diff --git a/mp_dr_reduce.c b/mp_dr_reduce.c index d630246..1b97a1d 100644 --- a/mp_dr_reduce.c +++ b/mp_dr_reduce.c @@ -49,7 +49,7 @@ mp_err mp_dr_reduce(mp_int *x, const mp_int *n, mp_digit k) x->dp[i] = mu; /* zero words above m */ - MP_ZERO_DIGITS(x->dp + m + 1, (x->used - m) - 1); + s_mp_zero_digs(x->dp + m + 1, (x->used - m) - 1); /* clamp, sub and return */ mp_clamp(x); diff --git a/mp_fwrite.c b/mp_fwrite.c index 42d7287..6b8ea13 100644 --- a/mp_fwrite.c +++ b/mp_fwrite.c @@ -25,7 +25,7 @@ mp_err mp_fwrite(const mp_int *a, int radix, FILE *stream) } } - MP_FREE_BUFFER(buf, size); + MP_FREE_BUF(buf, size); return err; } #endif diff --git a/mp_grow.c b/mp_grow.c index 25be5ed..0de6679 100644 --- a/mp_grow.c +++ b/mp_grow.c @@ -26,7 +26,7 @@ mp_err mp_grow(mp_int *a, int size) a->dp = dp; /* zero excess digits */ - MP_ZERO_DIGITS(a->dp + a->alloc, size - a->alloc); + s_mp_zero_digs(a->dp + a->alloc, size - a->alloc); a->alloc = size; } return MP_OKAY; diff --git a/mp_lshd.c b/mp_lshd.c index 6c14402..2f56e5d 100644 --- a/mp_lshd.c +++ b/mp_lshd.c @@ -37,7 +37,7 @@ mp_err mp_lshd(mp_int *a, int b) } /* zero the lower digits */ - MP_ZERO_DIGITS(a->dp, b); + s_mp_zero_digs(a->dp, b); return MP_OKAY; } diff --git a/mp_mod_2d.c b/mp_mod_2d.c index a94a314..82c64f0 100644 --- a/mp_mod_2d.c +++ b/mp_mod_2d.c @@ -29,7 +29,7 @@ mp_err mp_mod_2d(const mp_int *a, int b, mp_int *c) /* zero digits above the last digit of the modulus */ x = (b / MP_DIGIT_BIT) + (((b % MP_DIGIT_BIT) == 0) ? 0 : 1); - MP_ZERO_DIGITS(c->dp + x, c->used - x); + s_mp_zero_digs(c->dp + x, c->used - x); /* clear the digit that is not completely outside/inside the modulus */ c->dp[b / MP_DIGIT_BIT] &= diff --git a/mp_mul_2.c b/mp_mul_2.c index 45b6f1c..9e549c9 100644 --- a/mp_mul_2.c +++ b/mp_mul_2.c @@ -47,7 +47,7 @@ mp_err mp_mul_2(const mp_int *a, mp_int *b) /* now zero any excess digits on the destination * that we didn't write to */ - MP_ZERO_DIGITS(b->dp + b->used, oldused - b->used); + s_mp_zero_digs(b->dp + b->used, oldused - b->used); b->sign = a->sign; return MP_OKAY; diff --git a/mp_mul_d.c b/mp_mul_d.c index 3e5335f..2be366f 100644 --- a/mp_mul_d.c +++ b/mp_mul_d.c @@ -45,7 +45,7 @@ mp_err mp_mul_d(const mp_int *a, mp_digit b, mp_int *c) c->used = a->used + 1; /* now zero digits above the top */ - MP_ZERO_DIGITS(c->dp + c->used, oldused - c->used); + s_mp_zero_digs(c->dp + c->used, oldused - c->used); mp_clamp(c); diff --git a/mp_prime_rand.c b/mp_prime_rand.c index 8476b4f..c5cebbd 100644 --- a/mp_prime_rand.c +++ b/mp_prime_rand.c @@ -116,7 +116,7 @@ mp_err mp_prime_rand(mp_int *a, int t, int size, int flags) err = MP_OKAY; LBL_ERR: - MP_FREE_BUFFER(tmp, (size_t)bsize); + MP_FREE_BUF(tmp, (size_t)bsize); return err; } diff --git a/mp_rshd.c b/mp_rshd.c index d798907..3f0a941 100644 --- a/mp_rshd.c +++ b/mp_rshd.c @@ -35,7 +35,7 @@ void mp_rshd(mp_int *a, int b) } /* zero the top digits */ - MP_ZERO_DIGITS(a->dp + a->used - b, b); + s_mp_zero_digs(a->dp + a->used - b, b); /* remove excess digits */ a->used -= b; diff --git a/mp_set.c b/mp_set.c index 3ee5f81..bc0c4da 100644 --- a/mp_set.c +++ b/mp_set.c @@ -10,6 +10,6 @@ void mp_set(mp_int *a, mp_digit b) a->dp[0] = b & MP_MASK; a->sign = MP_ZPOS; a->used = (a->dp[0] != 0u) ? 1 : 0; - MP_ZERO_DIGITS(a->dp + a->used, oldused - a->used); + s_mp_zero_digs(a->dp + a->used, oldused - a->used); } #endif diff --git a/mp_sub_d.c b/mp_sub_d.c index c5cf726..91437f8 100644 --- a/mp_sub_d.c +++ b/mp_sub_d.c @@ -72,7 +72,7 @@ mp_err mp_sub_d(const mp_int *a, mp_digit b, mp_int *c) } /* zero excess digits */ - MP_ZERO_DIGITS(c->dp + c->used, oldused - c->used); + s_mp_zero_digs(c->dp + c->used, oldused - c->used); mp_clamp(c); return MP_OKAY; diff --git a/mp_zero.c b/mp_zero.c index b7dddd2..48b60e4 100644 --- a/mp_zero.c +++ b/mp_zero.c @@ -7,7 +7,7 @@ void mp_zero(mp_int *a) { a->sign = MP_ZPOS; - MP_ZERO_DIGITS(a->dp, a->used); + s_mp_zero_digs(a->dp, a->used); a->used = 0; } #endif diff --git a/s_mp_add.c b/s_mp_add.c index 1dd09f8..1d799b7 100644 --- a/s_mp_add.c +++ b/s_mp_add.c @@ -64,7 +64,7 @@ mp_err s_mp_add(const mp_int *a, const mp_int *b, mp_int *c) c->dp[i] = u; /* clear digits above oldused */ - MP_ZERO_DIGITS(c->dp + c->used, oldused - c->used); + s_mp_zero_digs(c->dp + c->used, oldused - c->used); mp_clamp(c); return MP_OKAY; diff --git a/s_mp_balance_mul.c b/s_mp_balance_mul.c index 77852a4..167a928 100644 --- a/s_mp_balance_mul.c +++ b/s_mp_balance_mul.c @@ -8,7 +8,7 @@ mp_err s_mp_balance_mul(const mp_int *a, const mp_int *b, mp_int *c) { mp_int a0, tmp, r; mp_err err; - int i, j, count, + int i, j, nblocks = MP_MAX(a->used, b->used) / MP_MIN(a->used, b->used), bsize = MP_MIN(a->used, b->used); @@ -27,12 +27,11 @@ mp_err s_mp_balance_mul(const mp_int *a, const mp_int *b, mp_int *c) for (i = 0, j=0; i < nblocks; i++) { /* Cut a slice off of a */ - a0.used = 0; - for (count = 0; count < bsize; count++) { - a0.dp[count] = a->dp[ j++ ]; - a0.used++; - } + a0.used = bsize; + s_mp_copy_digs(a0.dp, a->dp + j, a0.used); + j += a0.used; mp_clamp(&a0); + /* Multiply with b */ if ((err = mp_mul(&a0, b, &tmp)) != MP_OKAY) { goto LBL_ERR; @@ -48,12 +47,11 @@ mp_err s_mp_balance_mul(const mp_int *a, const mp_int *b, mp_int *c) } /* The left-overs; there are always left-overs */ if (j < a->used) { - a0.used = 0; - for (count = 0; j < a->used; count++) { - a0.dp[count] = a->dp[ j++ ]; - a0.used++; - } + a0.used = a->used - j; + s_mp_copy_digs(a0.dp, a->dp + j, a0.used); + j += a0.used; mp_clamp(&a0); + if ((err = mp_mul(&a0, b, &tmp)) != MP_OKAY) { goto LBL_ERR; } diff --git a/s_mp_copy_digs.c b/s_mp_copy_digs.c new file mode 100644 index 0000000..4079c33 --- /dev/null +++ b/s_mp_copy_digs.c @@ -0,0 +1,23 @@ +#include "tommath_private.h" +#ifdef S_MP_COPY_DIGS_C +/* LibTomMath, multiple-precision integer library -- Tom St Denis */ +/* SPDX-License-Identifier: Unlicense */ + +#ifdef MP_USE_MEMOPS +# include +#endif + +void s_mp_copy_digs(mp_digit *d, const mp_digit *s, int digits) +{ +#ifdef MP_USE_MEMOPS + if (digits > 0) { + memcpy(d, s, (size_t)digits * sizeof(mp_digit)); + } +#else + while (digits-- > 0) { + *d++ = *s++; + } +#endif +} + +#endif diff --git a/s_mp_karatsuba_mul.c b/s_mp_karatsuba_mul.c index 762e5e2..6d607ea 100644 --- a/s_mp_karatsuba_mul.c +++ b/s_mp_karatsuba_mul.c @@ -35,7 +35,7 @@ mp_err s_mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c) { mp_int x0, x1, y0, y1, t1, x0y0, x1y1; - int B, i; + int B; mp_err err; /* min # of digits */ @@ -77,16 +77,10 @@ mp_err s_mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c) /* we copy the digits directly instead of using higher level functions * since we also need to shift the digits */ - for (i = 0; i < B; i++) { - x0.dp[i] = a->dp[i]; - y0.dp[i] = b->dp[i]; - } - for (i = B; i < a->used; i++) { - x1.dp[i - B] = a->dp[i]; - } - for (i = B; i < b->used; i++) { - y1.dp[i - B] = b->dp[i]; - } + s_mp_copy_digs(x0.dp, a->dp, x0.used); + s_mp_copy_digs(y0.dp, b->dp, y0.used); + s_mp_copy_digs(x1.dp, a->dp + B, x1.used); + s_mp_copy_digs(y1.dp, b->dp + B, y1.used); /* only need to clamp the lower words since by definition the * upper words x1/y1 must have a known number of digits diff --git a/s_mp_karatsuba_sqr.c b/s_mp_karatsuba_sqr.c index 824fcdc..eb92ccb 100644 --- a/s_mp_karatsuba_sqr.c +++ b/s_mp_karatsuba_sqr.c @@ -13,7 +13,7 @@ mp_err s_mp_karatsuba_sqr(const mp_int *a, mp_int *b) { mp_int x0, x1, t1, t2, x0x0, x1x1; - int B, x; + int B; mp_err err; /* min # of digits */ @@ -39,16 +39,10 @@ mp_err s_mp_karatsuba_sqr(const mp_int *a, mp_int *b) goto X0X0; /* now shift the digits */ - for (x = 0; x < B; x++) { - x0.dp[x] = a->dp[x]; - } - for (x = B; x < a->used; x++) { - x1.dp[x - B] = a->dp[x]; - } - x0.used = B; x1.used = a->used - B; - + s_mp_copy_digs(x0.dp, a->dp, x0.used); + s_mp_copy_digs(x1.dp, a->dp + B, x1.used); mp_clamp(&x0); /* now calc the products x0*x0 and x1*x1 */ diff --git a/s_mp_montgomery_reduce_fast.c b/s_mp_montgomery_reduce_fast.c index a78c537..9b08115 100644 --- a/s_mp_montgomery_reduce_fast.c +++ b/s_mp_montgomery_reduce_fast.c @@ -42,7 +42,7 @@ mp_err s_mp_montgomery_reduce_fast(mp_int *x, const mp_int *n, mp_digit rho) /* zero the high words of W[a->used..m->used*2] */ if (ix < ((n->used * 2) + 1)) { - MP_ZERO_BUFFER(W + x->used, sizeof(mp_word) * (size_t)(((n->used * 2) + 1) - ix)); + s_mp_zero_buf(W + x->used, sizeof(mp_word) * (size_t)(((n->used * 2) + 1) - ix)); } /* now we proceed to zero successive digits @@ -108,7 +108,7 @@ mp_err s_mp_montgomery_reduce_fast(mp_int *x, const mp_int *n, mp_digit rho) /* zero oldused digits, if the input a was larger than * m->used+1 we'll have to clear the digits */ - MP_ZERO_DIGITS(x->dp + x->used, oldused - x->used); + s_mp_zero_digs(x->dp + x->used, oldused - x->used); mp_clamp(x); diff --git a/s_mp_mul_digs_fast.c b/s_mp_mul_digs_fast.c index 44aabd0..3928d04 100644 --- a/s_mp_mul_digs_fast.c +++ b/s_mp_mul_digs_fast.c @@ -72,7 +72,7 @@ mp_err s_mp_mul_digs_fast(const mp_int *a, const mp_int *b, mp_int *c, int digs) } /* clear unused digits [that existed in the old copy of c] */ - MP_ZERO_DIGITS(c->dp + c->used, oldused - c->used); + s_mp_zero_digs(c->dp + c->used, oldused - c->used); mp_clamp(c); return MP_OKAY; diff --git a/s_mp_mul_high_digs_fast.c b/s_mp_mul_high_digs_fast.c index 1384765..01335a5 100644 --- a/s_mp_mul_high_digs_fast.c +++ b/s_mp_mul_high_digs_fast.c @@ -64,7 +64,7 @@ mp_err s_mp_mul_high_digs_fast(const mp_int *a, const mp_int *b, mp_int *c, int } /* clear unused digits [that existed in the old copy of c] */ - MP_ZERO_DIGITS(c->dp + c->used, oldused - c->used); + s_mp_zero_digs(c->dp + c->used, oldused - c->used); mp_clamp(c); return MP_OKAY; diff --git a/s_mp_sqr_fast.c b/s_mp_sqr_fast.c index 675d75d..daf4214 100644 --- a/s_mp_sqr_fast.c +++ b/s_mp_sqr_fast.c @@ -81,7 +81,7 @@ mp_err s_mp_sqr_fast(const mp_int *a, mp_int *b) } /* clear unused digits [that existed in the old copy of c] */ - MP_ZERO_DIGITS(b->dp + b->used, oldused - b->used); + s_mp_zero_digs(b->dp + b->used, oldused - b->used); mp_clamp(b); return MP_OKAY; diff --git a/s_mp_sub.c b/s_mp_sub.c index 05386e5..ead0b51 100644 --- a/s_mp_sub.c +++ b/s_mp_sub.c @@ -49,7 +49,7 @@ mp_err s_mp_sub(const mp_int *a, const mp_int *b, mp_int *c) } /* clear digits above used (since we may not have grown result above) */ - MP_ZERO_DIGITS(c->dp + c->used, oldused - c->used); + s_mp_zero_digs(c->dp + c->used, oldused - c->used); mp_clamp(c); return MP_OKAY; diff --git a/s_mp_toom_mul.c b/s_mp_toom_mul.c index 93113b8..f3dd96a 100644 --- a/s_mp_toom_mul.c +++ b/s_mp_toom_mul.c @@ -32,7 +32,7 @@ mp_err s_mp_toom_mul(const mp_int *a, const mp_int *b, mp_int *c) { mp_int S1, S2, T1, a0, a1, a2, b0, b1, b2; - int B, count; + int B; mp_err err; /* init temps */ @@ -45,43 +45,30 @@ mp_err s_mp_toom_mul(const mp_int *a, const mp_int *b, mp_int *c) /** a = a2 * x^2 + a1 * x + a0; */ if ((err = mp_init_size(&a0, B)) != MP_OKAY) goto LBL_ERRa0; - - for (count = 0; count < B; count++) { - a0.dp[count] = a->dp[count]; - a0.used++; - } - mp_clamp(&a0); if ((err = mp_init_size(&a1, B)) != MP_OKAY) goto LBL_ERRa1; - for (; count < (2 * B); count++) { - a1.dp[count - B] = a->dp[count]; - a1.used++; - } + if ((err = mp_init_size(&a2, a->used - 2 * B)) != MP_OKAY) goto LBL_ERRa2; + + a0.used = a1.used = B; + a2.used = a->used - 2 * B; + s_mp_copy_digs(a0.dp, a->dp, a0.used); + s_mp_copy_digs(a1.dp, a->dp + B, a1.used); + s_mp_copy_digs(a2.dp, a->dp + 2 * B, a2.used); + mp_clamp(&a0); mp_clamp(&a1); - if ((err = mp_init_size(&a2, B + (a->used - (3 * B)))) != MP_OKAY) goto LBL_ERRa2; - for (; count < a->used; count++) { - a2.dp[count - (2 * B)] = a->dp[count]; - a2.used++; - } mp_clamp(&a2); /** b = b2 * x^2 + b1 * x + b0; */ if ((err = mp_init_size(&b0, B)) != MP_OKAY) goto LBL_ERRb0; - for (count = 0; count < B; count++) { - b0.dp[count] = b->dp[count]; - b0.used++; - } - mp_clamp(&b0); if ((err = mp_init_size(&b1, B)) != MP_OKAY) goto LBL_ERRb1; - for (; count < (2 * B); count++) { - b1.dp[count - B] = b->dp[count]; - b1.used++; - } + if ((err = mp_init_size(&b2, b->used - 2 * B)) != MP_OKAY) goto LBL_ERRb2; + + b0.used = b1.used = B; + b2.used = b->used - 2 * B; + s_mp_copy_digs(b0.dp, b->dp, b0.used); + s_mp_copy_digs(b1.dp, b->dp + B, b1.used); + s_mp_copy_digs(b2.dp, b->dp + 2 * B, b2.used); + mp_clamp(&b0); mp_clamp(&b1); - if ((err = mp_init_size(&b2, B + (b->used - (3 * B)))) != MP_OKAY) goto LBL_ERRb2; - for (; count < b->used; count++) { - b2.dp[count - (2 * B)] = b->dp[count]; - b2.used++; - } mp_clamp(&b2); /** \\ S1 = (a2+a1+a0) * (b2+b1+b0); */ diff --git a/s_mp_toom_sqr.c b/s_mp_toom_sqr.c index d8f2f8e..1342d57 100644 --- a/s_mp_toom_sqr.c +++ b/s_mp_toom_sqr.c @@ -21,7 +21,7 @@ mp_err s_mp_toom_sqr(const mp_int *a, mp_int *b) { mp_int S0, a0, a1, a2; - int B, count; + int B; mp_err err; /* init temps */ @@ -34,22 +34,14 @@ mp_err s_mp_toom_sqr(const mp_int *a, mp_int *b) /** a = a2 * x^2 + a1 * x + a0; */ if ((err = mp_init_size(&a0, B)) != MP_OKAY) goto LBL_ERRa0; - - a0.used = B; if ((err = mp_init_size(&a1, B)) != MP_OKAY) goto LBL_ERRa1; - a1.used = B; - if ((err = mp_init_size(&a2, B + (a->used - (3 * B)))) != MP_OKAY) goto LBL_ERRa2; + if ((err = mp_init_size(&a2, a->used - (2 * B))) != MP_OKAY) goto LBL_ERRa2; - for (count = 0; count < B; count++) { - a0.dp[count] = a->dp[count]; - } - for (; count < (2 * B); count++) { - a1.dp[count - B] = a->dp[count]; - } - for (; count < a->used; count++) { - a2.dp[count - 2 * B] = a->dp[count]; - a2.used++; - } + a0.used = a1.used = B; + a2.used = a->used - 2 * B; + s_mp_copy_digs(a0.dp, a->dp, a0.used); + s_mp_copy_digs(a1.dp, a->dp + B, a1.used); + s_mp_copy_digs(a2.dp, a->dp + 2 * B, a2.used); mp_clamp(&a0); mp_clamp(&a1); mp_clamp(&a2); diff --git a/s_mp_zero_buf.c b/s_mp_zero_buf.c new file mode 100644 index 0000000..23a458d --- /dev/null +++ b/s_mp_zero_buf.c @@ -0,0 +1,22 @@ +#include "tommath_private.h" +#ifdef S_MP_ZERO_BUF_C +/* LibTomMath, multiple-precision integer library -- Tom St Denis */ +/* SPDX-License-Identifier: Unlicense */ + +#ifdef MP_USE_MEMOPS +# include +#endif + +void s_mp_zero_buf(void *mem, size_t size) +{ +#ifdef MP_USE_MEMOPS + memset(mem, 0, size); +#else + char *m = (char *)mem; + while (size-- > 0u) { + *m++ = '\0'; + } +#endif +} + +#endif diff --git a/s_mp_zero_digs.c b/s_mp_zero_digs.c new file mode 100644 index 0000000..79e8377 --- /dev/null +++ b/s_mp_zero_digs.c @@ -0,0 +1,23 @@ +#include "tommath_private.h" +#ifdef S_MP_ZERO_DIGS_C +/* LibTomMath, multiple-precision integer library -- Tom St Denis */ +/* SPDX-License-Identifier: Unlicense */ + +#ifdef MP_USE_MEMOPS +# include +#endif + +void s_mp_zero_digs(mp_digit *d, int digits) +{ +#ifdef MP_USE_MEMOPS + if (digits > 0) { + memset(d, 0, (size_t)digits * sizeof(mp_digit)); + } +#else + while (digits-- > 0) { + *d++ = 0; + } +#endif +} + +#endif diff --git a/tommath_private.h b/tommath_private.h index aaa3d23..7aad433 100644 --- a/tommath_private.h +++ b/tommath_private.h @@ -42,55 +42,25 @@ * define MP_NO_ZERO_ON_FREE during compilation. */ #ifdef MP_NO_ZERO_ON_FREE -# define MP_FREE_BUFFER(mem, size) MP_FREE((mem), (size)) -# define MP_FREE_DIGITS(mem, digits) MP_FREE((mem), sizeof (mp_digit) * (size_t)(digits)) +# define MP_FREE_BUF(mem, size) MP_FREE((mem), (size)) +# define MP_FREE_DIGS(mem, digits) MP_FREE((mem), sizeof (mp_digit) * (size_t)(digits)) #else -# define MP_FREE_BUFFER(mem, size) \ +# define MP_FREE_BUF(mem, size) \ do { \ size_t fs_ = (size); \ void* fm_ = (mem); \ if (fm_ != NULL) { \ - MP_ZERO_BUFFER(fm_, fs_); \ + s_mp_zero_buf(fm_, fs_); \ MP_FREE(fm_, fs_); \ } \ } while (0) -# define MP_FREE_DIGITS(mem, digits) \ +# define MP_FREE_DIGS(mem, digits) \ do { \ int fd_ = (digits); \ - void* fm_ = (mem); \ + mp_digit* fm_ = (mem); \ if (fm_ != NULL) { \ - size_t fs_ = sizeof (mp_digit) * (size_t)fd_; \ - MP_ZERO_BUFFER(fm_, fs_); \ - MP_FREE(fm_, fs_); \ - } \ -} while (0) -#endif - -#ifdef MP_USE_MEMSET -# include -# define MP_ZERO_BUFFER(mem, size) memset((mem), 0, (size)) -# define MP_ZERO_DIGITS(mem, digits) \ -do { \ - int zd_ = (digits); \ - if (zd_ > 0) { \ - memset((mem), 0, sizeof(mp_digit) * (size_t)zd_); \ - } \ -} while (0) -#else -# define MP_ZERO_BUFFER(mem, size) \ -do { \ - size_t zs_ = (size); \ - char* zm_ = (char*)(mem); \ - while (zs_-- > 0u) { \ - *zm_++ = '\0'; \ - } \ -} while (0) -# define MP_ZERO_DIGITS(mem, digits) \ -do { \ - int zd_ = (digits); \ - mp_digit* zm_ = (mem); \ - while (zd_-- > 0) { \ - *zm_++ = 0; \ + s_mp_zero_digs(fm_, fd_); \ + MP_FREE(fm_, sizeof (mp_digit) * (size_t)fd_); \ } \ } while (0) #endif @@ -215,6 +185,9 @@ MP_PRIVATE uint32_t s_mp_log_pow2(const mp_int *a, uint32_t base); MP_PRIVATE mp_err s_mp_div_recursive(const mp_int *a, const mp_int *b, mp_int *q, mp_int *r); MP_PRIVATE mp_err s_mp_div_school(const mp_int *a, const mp_int *b, mp_int *c, mp_int *d); MP_PRIVATE mp_err s_mp_div_small(const mp_int *a, const mp_int *b, mp_int *c, mp_int *d); +MP_PRIVATE void s_mp_zero_buf(void *mem, size_t size); +MP_PRIVATE void s_mp_zero_digs(mp_digit *d, int digits); +MP_PRIVATE void s_mp_copy_digs(mp_digit *d, const mp_digit *s, int digits); /* TODO: jenkins prng is not thread safe as of now */ MP_PRIVATE mp_err s_mp_rand_jenkins(void *p, size_t n) MP_WUR; @@ -247,7 +220,7 @@ extern MP_PRIVATE const mp_digit s_mp_prime_tab[]; } \ a->used = i; \ a->sign = MP_ZPOS; \ - MP_ZERO_DIGITS(a->dp + a->used, a->alloc - a->used); \ + s_mp_zero_digs(a->dp + a->used, a->alloc - a->used); \ } #define MP_SET_SIGNED(name, uname, type, utype) \