simplifications: reduce functions

This commit is contained in:
Daniel Mendler 2019-10-29 20:08:42 +01:00
parent 448f35e2e1
commit 56144eed1e
No known key found for this signature in database
GPG Key ID: D88ADB2A2693CA43
6 changed files with 123 additions and 178 deletions

View File

@ -19,16 +19,12 @@
*/ */
mp_err mp_dr_reduce(mp_int *x, const mp_int *n, mp_digit k) mp_err mp_dr_reduce(mp_int *x, const mp_int *n, mp_digit k)
{ {
mp_err err;
int i, m;
mp_word r;
mp_digit mu, *tmpx1, *tmpx2;
/* m = digits in modulus */ /* m = digits in modulus */
m = n->used; int m = n->used;
/* ensure that "x" has at least 2m digits */ /* ensure that "x" has at least 2m digits */
if (x->alloc < (m + m)) { if (x->alloc < (m + m)) {
mp_err err;
if ((err = mp_grow(x, m + m)) != MP_OKAY) { if ((err = mp_grow(x, m + m)) != MP_OKAY) {
return err; return err;
} }
@ -37,41 +33,37 @@ mp_err mp_dr_reduce(mp_int *x, const mp_int *n, mp_digit k)
/* top of loop, this is where the code resumes if /* top of loop, this is where the code resumes if
* another reduction pass is required. * another reduction pass is required.
*/ */
top: for (;;) {
/* aliases for digits */ mp_err err;
/* alias for lower half of x */ int i;
tmpx1 = x->dp; mp_digit mu = 0;
/* alias for upper half of x, or x/B**m */ /* compute (x mod B**m) + k * [x/B**m] inline and inplace */
tmpx2 = x->dp + m; for (i = 0; i < m; i++) {
mp_word r = ((mp_word)x->dp[i + m] * (mp_word)k) + x->dp[i] + mu;
x->dp[i] = (mp_digit)(r & MP_MASK);
mu = (mp_digit)(r >> ((mp_word)MP_DIGIT_BIT));
}
/* set carry to zero */ /* set final carry */
mu = 0; x->dp[i] = mu;
/* compute (x mod B**m) + k * [x/B**m] inline and inplace */ /* zero words above m */
for (i = 0; i < m; i++) { MP_ZERO_DIGITS(x->dp + m + 1, (x->used - m) - 1);
r = ((mp_word)*tmpx2++ * (mp_word)k) + *tmpx1 + mu;
*tmpx1++ = (mp_digit)(r & MP_MASK);
mu = (mp_digit)(r >> ((mp_word)MP_DIGIT_BIT));
}
/* set final carry */ /* clamp, sub and return */
*tmpx1++ = mu; mp_clamp(x);
/* zero words above m */ /* if x >= n then subtract and reduce again
MP_ZERO_DIGITS(tmpx1, (x->used - m) - 1); * Each successive "recursion" makes the input smaller and smaller.
*/
if (mp_cmp_mag(x, n) == MP_LT) {
break;
}
/* clamp, sub and return */
mp_clamp(x);
/* if x >= n then subtract and reduce again
* Each successive "recursion" makes the input smaller and smaller.
*/
if (mp_cmp_mag(x, n) != MP_LT) {
if ((err = s_mp_sub(x, n, x)) != MP_OKAY) { if ((err = s_mp_sub(x, n, x)) != MP_OKAY) {
return err; return err;
} }
goto top;
} }
return MP_OKAY; return MP_OKAY;
} }

View File

@ -6,9 +6,7 @@
/* computes xR**-1 == x (mod N) via Montgomery Reduction */ /* computes xR**-1 == x (mod N) via Montgomery Reduction */
mp_err mp_montgomery_reduce(mp_int *x, const mp_int *n, mp_digit rho) mp_err mp_montgomery_reduce(mp_int *x, const mp_int *n, mp_digit rho)
{ {
int ix, digs; int ix, digs;
mp_err err;
mp_digit mu;
/* can the fast reduction [comba] method be used? /* can the fast reduction [comba] method be used?
* *
@ -25,6 +23,7 @@ mp_err mp_montgomery_reduce(mp_int *x, const mp_int *n, mp_digit rho)
/* grow the input as required */ /* grow the input as required */
if (x->alloc < digs) { if (x->alloc < digs) {
mp_err err;
if ((err = mp_grow(x, digs)) != MP_OKAY) { if ((err = mp_grow(x, digs)) != MP_OKAY) {
return err; return err;
} }
@ -32,6 +31,9 @@ mp_err mp_montgomery_reduce(mp_int *x, const mp_int *n, mp_digit rho)
x->used = digs; x->used = digs;
for (ix = 0; ix < n->used; ix++) { for (ix = 0; ix < n->used; ix++) {
int iy;
mp_digit u, mu;
/* mu = ai * rho mod b /* mu = ai * rho mod b
* *
* The value of rho must be precalculated via * The value of rho must be precalculated via
@ -43,41 +45,28 @@ mp_err mp_montgomery_reduce(mp_int *x, const mp_int *n, mp_digit rho)
mu = (mp_digit)(((mp_word)x->dp[ix] * (mp_word)rho) & MP_MASK); mu = (mp_digit)(((mp_word)x->dp[ix] * (mp_word)rho) & MP_MASK);
/* a = a + mu * m * b**i */ /* a = a + mu * m * b**i */
{
int iy;
mp_digit *tmpn, *tmpx, u;
mp_word r;
/* alias for digits of the modulus */ /* Multiply and add in place */
tmpn = n->dp; u = 0;
for (iy = 0; iy < n->used; iy++) {
/* compute product and sum */
mp_word r = ((mp_word)mu * (mp_word)n->dp[iy]) +
(mp_word)u + (mp_word)x->dp[ix + iy];
/* alias for the digits of x [the input] */ /* get carry */
tmpx = x->dp + ix; u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
/* set the carry to zero */ /* fix digit */
u = 0; x->dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
}
/* At this point the ix'th digit of x should be zero */
/* Multiply and add in place */ /* propagate carries upwards as required*/
for (iy = 0; iy < n->used; iy++) { while (u != 0u) {
/* compute product and sum */ x->dp[ix + iy] += u;
r = ((mp_word)mu * (mp_word)*tmpn++) + u = x->dp[ix + iy] >> MP_DIGIT_BIT;
(mp_word)u + (mp_word)*tmpx; x->dp[ix + iy] &= MP_MASK;
++iy;
/* get carry */
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
/* fix digit */
*tmpx++ = (mp_digit)(r & (mp_word)MP_MASK);
}
/* At this point the ix'th digit of x should be zero */
/* propagate carries upwards as required*/
while (u != 0u) {
*tmpx += u;
u = *tmpx >> MP_DIGIT_BIT;
*tmpx++ &= MP_MASK;
}
} }
} }

View File

@ -8,36 +8,37 @@ mp_err mp_reduce_2k(mp_int *a, const mp_int *n, mp_digit d)
{ {
mp_int q; mp_int q;
mp_err err; mp_err err;
int p; int p;
if ((err = mp_init(&q)) != MP_OKAY) { if ((err = mp_init(&q)) != MP_OKAY) {
return err; return err;
} }
p = mp_count_bits(n); p = mp_count_bits(n);
top: for (;;) {
/* q = a/2**p, a = a mod 2**p */ /* q = a/2**p, a = a mod 2**p */
if ((err = mp_div_2d(a, p, &q, a)) != MP_OKAY) { if ((err = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
goto LBL_ERR;
}
if (d != 1u) {
/* q = q * d */
if ((err = mp_mul_d(&q, d, &q)) != MP_OKAY) {
goto LBL_ERR; goto LBL_ERR;
} }
}
/* a = a + q */ if (d != 1u) {
if ((err = s_mp_add(a, &q, a)) != MP_OKAY) { /* q = q * d */
goto LBL_ERR; if ((err = mp_mul_d(&q, d, &q)) != MP_OKAY) {
} goto LBL_ERR;
}
}
if (mp_cmp_mag(a, n) != MP_LT) { /* a = a + q */
if ((err = s_mp_add(a, &q, a)) != MP_OKAY) {
goto LBL_ERR;
}
if (mp_cmp_mag(a, n) == MP_LT) {
break;
}
if ((err = s_mp_sub(a, n, a)) != MP_OKAY) { if ((err = s_mp_sub(a, n, a)) != MP_OKAY) {
goto LBL_ERR; goto LBL_ERR;
} }
goto top;
} }
LBL_ERR: LBL_ERR:

View File

@ -18,27 +18,30 @@ mp_err mp_reduce_2k_l(mp_int *a, const mp_int *n, const mp_int *d)
} }
p = mp_count_bits(n); p = mp_count_bits(n);
top:
/* q = a/2**p, a = a mod 2**p */
if ((err = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
goto LBL_ERR;
}
/* q = q * d */ for (;;) {
if ((err = mp_mul(&q, d, &q)) != MP_OKAY) { /* q = a/2**p, a = a mod 2**p */
goto LBL_ERR; if ((err = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
} goto LBL_ERR;
}
/* a = a + q */ /* q = q * d */
if ((err = s_mp_add(a, &q, a)) != MP_OKAY) { if ((err = mp_mul(&q, d, &q)) != MP_OKAY) {
goto LBL_ERR; goto LBL_ERR;
} }
if (mp_cmp_mag(a, n) != MP_LT) { /* a = a + q */
if ((err = s_mp_add(a, &q, a)) != MP_OKAY) {
goto LBL_ERR;
}
if (mp_cmp_mag(a, n) == MP_LT) {
break;
}
if ((err = s_mp_sub(a, n, a)) != MP_OKAY) { if ((err = s_mp_sub(a, n, a)) != MP_OKAY) {
goto LBL_ERR; goto LBL_ERR;
} }
goto top;
} }
LBL_ERR: LBL_ERR:

View File

@ -8,25 +8,23 @@ mp_err mp_reduce_2k_setup(const mp_int *a, mp_digit *d)
{ {
mp_err err; mp_err err;
mp_int tmp; mp_int tmp;
int p;
if ((err = mp_init(&tmp)) != MP_OKAY) { if ((err = mp_init(&tmp)) != MP_OKAY) {
return err; return err;
} }
p = mp_count_bits(a); if ((err = mp_2expt(&tmp, mp_count_bits(a))) != MP_OKAY) {
if ((err = mp_2expt(&tmp, p)) != MP_OKAY) { goto LBL_ERR;
mp_clear(&tmp);
return err;
} }
if ((err = s_mp_sub(&tmp, a, &tmp)) != MP_OKAY) { if ((err = s_mp_sub(&tmp, a, &tmp)) != MP_OKAY) {
mp_clear(&tmp); goto LBL_ERR;
return err;
} }
*d = tmp.dp[0]; *d = tmp.dp[0];
LBL_ERR:
mp_clear(&tmp); mp_clear(&tmp);
return MP_OKAY; return err;
} }
#endif #endif

View File

@ -13,7 +13,7 @@
*/ */
mp_err s_mp_montgomery_reduce_fast(mp_int *x, const mp_int *n, mp_digit rho) mp_err s_mp_montgomery_reduce_fast(mp_int *x, const mp_int *n, mp_digit rho)
{ {
int ix, olduse; int ix, oldused;
mp_err err; mp_err err;
mp_word W[MP_WARRAY]; mp_word W[MP_WARRAY];
@ -22,7 +22,7 @@ mp_err s_mp_montgomery_reduce_fast(mp_int *x, const mp_int *n, mp_digit rho)
} }
/* get old used count */ /* get old used count */
olduse = x->used; oldused = x->used;
/* grow a as required */ /* grow a as required */
if (x->alloc < (n->used + 1)) { if (x->alloc < (n->used + 1)) {
@ -34,38 +34,30 @@ mp_err s_mp_montgomery_reduce_fast(mp_int *x, const mp_int *n, mp_digit rho)
/* first we have to get the digits of the input into /* first we have to get the digits of the input into
* an array of double precision words W[...] * an array of double precision words W[...]
*/ */
{
mp_word *_W;
mp_digit *tmpx;
/* alias for the W[] array */ /* copy the digits of a into W[0..a->used-1] */
_W = W; for (ix = 0; ix < x->used; ix++) {
W[ix] = x->dp[ix];
}
/* alias for the digits of x*/ /* zero the high words of W[a->used..m->used*2] */
tmpx = x->dp; if (ix < ((n->used * 2) + 1)) {
MP_ZERO_BUFFER(W + x->used, sizeof(mp_word) * (size_t)(((n->used * 2) + 1) - ix));
/* copy the digits of a into W[0..a->used-1] */
for (ix = 0; ix < x->used; ix++) {
*_W++ = *tmpx++;
}
/* zero the high words of W[a->used..m->used*2] */
if (ix < ((n->used * 2) + 1)) {
MP_ZERO_BUFFER(_W, sizeof(mp_word) * (size_t)(((n->used * 2) + 1) - ix));
}
} }
/* now we proceed to zero successive digits /* now we proceed to zero successive digits
* from the least significant upwards * from the least significant upwards
*/ */
for (ix = 0; ix < n->used; ix++) { for (ix = 0; ix < n->used; ix++) {
int iy;
mp_digit mu;
/* mu = ai * m' mod b /* mu = ai * m' mod b
* *
* We avoid a double precision multiplication (which isn't required) * We avoid a double precision multiplication (which isn't required)
* by casting the value down to a mp_digit. Note this requires * by casting the value down to a mp_digit. Note this requires
* that W[ix-1] have the carry cleared (see after the inner loop) * that W[ix-1] have the carry cleared (see after the inner loop)
*/ */
mp_digit mu;
mu = ((W[ix] & MP_MASK) * rho) & MP_MASK; mu = ((W[ix] & MP_MASK) * rho) & MP_MASK;
/* a = a + mu * m * b**i /* a = a + mu * m * b**i
@ -82,21 +74,8 @@ mp_err s_mp_montgomery_reduce_fast(mp_int *x, const mp_int *n, mp_digit rho)
* carry fixups are done in order so after these loops the * carry fixups are done in order so after these loops the
* first m->used words of W[] have the carries fixed * first m->used words of W[] have the carries fixed
*/ */
{ for (iy = 0; iy < n->used; iy++) {
int iy; W[ix + iy] += (mp_word)mu * (mp_word)n->dp[iy];
mp_digit *tmpn;
mp_word *_W;
/* alias for the digits of the modulus */
tmpn = n->dp;
/* Alias for the columns set by an offset of ix */
_W = W + ix;
/* inner loop */
for (iy = 0; iy < n->used; iy++) {
*_W++ += (mp_word)mu * (mp_word)*tmpn++;
}
} }
/* now fix carry for next digit, W[ix+1] */ /* now fix carry for next digit, W[ix+1] */
@ -107,47 +86,30 @@ mp_err s_mp_montgomery_reduce_fast(mp_int *x, const mp_int *n, mp_digit rho)
* shift the words downward [all those least * shift the words downward [all those least
* significant digits we zeroed]. * significant digits we zeroed].
*/ */
{
mp_digit *tmpx;
mp_word *_W, *_W1;
/* nox fix rest of carries */ for (; ix < (n->used * 2); ix++) {
W[ix + 1] += W[ix] >> (mp_word)MP_DIGIT_BIT;
/* alias for current word */
_W1 = W + ix;
/* alias for next word, where the carry goes */
_W = W + ++ix;
for (; ix < ((n->used * 2) + 1); ix++) {
*_W++ += *_W1++ >> (mp_word)MP_DIGIT_BIT;
}
/* copy out, A = A/b**n
*
* The result is A/b**n but instead of converting from an
* array of mp_word to mp_digit than calling mp_rshd
* we just copy them in the right order
*/
/* alias for destination word */
tmpx = x->dp;
/* alias for shifted double precision result */
_W = W + n->used;
for (ix = 0; ix < (n->used + 1); ix++) {
*tmpx++ = *_W++ & (mp_word)MP_MASK;
}
/* zero oldused digits, if the input a was larger than
* m->used+1 we'll have to clear the digits
*/
MP_ZERO_DIGITS(tmpx, olduse - ix);
} }
/* set the max used and clamp */ /* copy out, A = A/b**n
*
* The result is A/b**n but instead of converting from an
* array of mp_word to mp_digit than calling mp_rshd
* we just copy them in the right order
*/
for (ix = 0; ix < (n->used + 1); ix++) {
x->dp[ix] = W[n->used + ix] & (mp_word)MP_MASK;
}
/* set the max used */
x->used = n->used + 1; x->used = n->used + 1;
/* zero oldused digits, if the input a was larger than
* m->used+1 we'll have to clear the digits
*/
MP_ZERO_DIGITS(x->dp + x->used, oldused - x->used);
mp_clamp(x); mp_clamp(x);
/* if A >= m then A = A - m */ /* if A >= m then A = A - m */