simplifications: reduce functions

This commit is contained in:
Daniel Mendler 2019-10-29 20:08:42 +01:00
parent 448f35e2e1
commit 56144eed1e
No known key found for this signature in database
GPG Key ID: D88ADB2A2693CA43
6 changed files with 123 additions and 178 deletions

View File

@ -19,16 +19,12 @@
*/
mp_err mp_dr_reduce(mp_int *x, const mp_int *n, mp_digit k)
{
mp_err err;
int i, m;
mp_word r;
mp_digit mu, *tmpx1, *tmpx2;
/* m = digits in modulus */
m = n->used;
int m = n->used;
/* ensure that "x" has at least 2m digits */
if (x->alloc < (m + m)) {
mp_err err;
if ((err = mp_grow(x, m + m)) != MP_OKAY) {
return err;
}
@ -37,29 +33,23 @@ mp_err mp_dr_reduce(mp_int *x, const mp_int *n, mp_digit k)
/* top of loop, this is where the code resumes if
* another reduction pass is required.
*/
top:
/* aliases for digits */
/* alias for lower half of x */
tmpx1 = x->dp;
/* alias for upper half of x, or x/B**m */
tmpx2 = x->dp + m;
/* set carry to zero */
mu = 0;
for (;;) {
mp_err err;
int i;
mp_digit mu = 0;
/* compute (x mod B**m) + k * [x/B**m] inline and inplace */
for (i = 0; i < m; i++) {
r = ((mp_word)*tmpx2++ * (mp_word)k) + *tmpx1 + mu;
*tmpx1++ = (mp_digit)(r & MP_MASK);
mp_word r = ((mp_word)x->dp[i + m] * (mp_word)k) + x->dp[i] + mu;
x->dp[i] = (mp_digit)(r & MP_MASK);
mu = (mp_digit)(r >> ((mp_word)MP_DIGIT_BIT));
}
/* set final carry */
*tmpx1++ = mu;
x->dp[i] = mu;
/* zero words above m */
MP_ZERO_DIGITS(tmpx1, (x->used - m) - 1);
MP_ZERO_DIGITS(x->dp + m + 1, (x->used - m) - 1);
/* clamp, sub and return */
mp_clamp(x);
@ -67,11 +57,13 @@ top:
/* if x >= n then subtract and reduce again
* Each successive "recursion" makes the input smaller and smaller.
*/
if (mp_cmp_mag(x, n) != MP_LT) {
if (mp_cmp_mag(x, n) == MP_LT) {
break;
}
if ((err = s_mp_sub(x, n, x)) != MP_OKAY) {
return err;
}
goto top;
}
return MP_OKAY;
}

View File

@ -7,8 +7,6 @@
mp_err mp_montgomery_reduce(mp_int *x, const mp_int *n, mp_digit rho)
{
int ix, digs;
mp_err err;
mp_digit mu;
/* can the fast reduction [comba] method be used?
*
@ -25,6 +23,7 @@ mp_err mp_montgomery_reduce(mp_int *x, const mp_int *n, mp_digit rho)
/* grow the input as required */
if (x->alloc < digs) {
mp_err err;
if ((err = mp_grow(x, digs)) != MP_OKAY) {
return err;
}
@ -32,6 +31,9 @@ mp_err mp_montgomery_reduce(mp_int *x, const mp_int *n, mp_digit rho)
x->used = digs;
for (ix = 0; ix < n->used; ix++) {
int iy;
mp_digit u, mu;
/* mu = ai * rho mod b
*
* The value of rho must be precalculated via
@ -43,41 +45,28 @@ mp_err mp_montgomery_reduce(mp_int *x, const mp_int *n, mp_digit rho)
mu = (mp_digit)(((mp_word)x->dp[ix] * (mp_word)rho) & MP_MASK);
/* a = a + mu * m * b**i */
{
int iy;
mp_digit *tmpn, *tmpx, u;
mp_word r;
/* alias for digits of the modulus */
tmpn = n->dp;
/* alias for the digits of x [the input] */
tmpx = x->dp + ix;
/* set the carry to zero */
u = 0;
/* Multiply and add in place */
u = 0;
for (iy = 0; iy < n->used; iy++) {
/* compute product and sum */
r = ((mp_word)mu * (mp_word)*tmpn++) +
(mp_word)u + (mp_word)*tmpx;
mp_word r = ((mp_word)mu * (mp_word)n->dp[iy]) +
(mp_word)u + (mp_word)x->dp[ix + iy];
/* get carry */
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
/* fix digit */
*tmpx++ = (mp_digit)(r & (mp_word)MP_MASK);
x->dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
}
/* At this point the ix'th digit of x should be zero */
/* propagate carries upwards as required */
while (u != 0u) {
*tmpx += u;
u = *tmpx >> MP_DIGIT_BIT;
*tmpx++ &= MP_MASK;
}
x->dp[ix + iy] += u;
u = x->dp[ix + iy] >> MP_DIGIT_BIT;
x->dp[ix + iy] &= MP_MASK;
++iy;
}
}

View File

@ -15,7 +15,7 @@ mp_err mp_reduce_2k(mp_int *a, const mp_int *n, mp_digit d)
}
p = mp_count_bits(n);
top:
for (;;) {
/* q = a/2**p, a = a mod 2**p */
if ((err = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
goto LBL_ERR;
@ -33,11 +33,12 @@ top:
goto LBL_ERR;
}
if (mp_cmp_mag(a, n) != MP_LT) {
if (mp_cmp_mag(a, n) == MP_LT) {
break;
}
if ((err = s_mp_sub(a, n, a)) != MP_OKAY) {
goto LBL_ERR;
}
goto top;
}
LBL_ERR:

View File

@ -18,7 +18,8 @@ mp_err mp_reduce_2k_l(mp_int *a, const mp_int *n, const mp_int *d)
}
p = mp_count_bits(n);
top:
for (;;) {
/* q = a/2**p, a = a mod 2**p */
if ((err = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
goto LBL_ERR;
@ -34,11 +35,13 @@ top:
goto LBL_ERR;
}
if (mp_cmp_mag(a, n) != MP_LT) {
if (mp_cmp_mag(a, n) == MP_LT) {
break;
}
if ((err = s_mp_sub(a, n, a)) != MP_OKAY) {
goto LBL_ERR;
}
goto top;
}
LBL_ERR:

View File

@ -8,25 +8,23 @@ mp_err mp_reduce_2k_setup(const mp_int *a, mp_digit *d)
{
mp_err err;
mp_int tmp;
int p;
if ((err = mp_init(&tmp)) != MP_OKAY) {
return err;
}
p = mp_count_bits(a);
if ((err = mp_2expt(&tmp, p)) != MP_OKAY) {
mp_clear(&tmp);
return err;
if ((err = mp_2expt(&tmp, mp_count_bits(a))) != MP_OKAY) {
goto LBL_ERR;
}
if ((err = s_mp_sub(&tmp, a, &tmp)) != MP_OKAY) {
mp_clear(&tmp);
return err;
goto LBL_ERR;
}
*d = tmp.dp[0];
LBL_ERR:
mp_clear(&tmp);
return MP_OKAY;
return err;
}
#endif

View File

@ -13,7 +13,7 @@
*/
mp_err s_mp_montgomery_reduce_fast(mp_int *x, const mp_int *n, mp_digit rho)
{
int ix, olduse;
int ix, oldused;
mp_err err;
mp_word W[MP_WARRAY];
@ -22,7 +22,7 @@ mp_err s_mp_montgomery_reduce_fast(mp_int *x, const mp_int *n, mp_digit rho)
}
/* get old used count */
olduse = x->used;
oldused = x->used;
/* grow a as required */
if (x->alloc < (n->used + 1)) {
@ -34,38 +34,30 @@ mp_err s_mp_montgomery_reduce_fast(mp_int *x, const mp_int *n, mp_digit rho)
/* first we have to get the digits of the input into
* an array of double precision words W[...]
*/
{
mp_word *_W;
mp_digit *tmpx;
/* alias for the W[] array */
_W = W;
/* alias for the digits of x*/
tmpx = x->dp;
/* copy the digits of a into W[0..a->used-1] */
for (ix = 0; ix < x->used; ix++) {
*_W++ = *tmpx++;
W[ix] = x->dp[ix];
}
/* zero the high words of W[a->used..m->used*2] */
if (ix < ((n->used * 2) + 1)) {
MP_ZERO_BUFFER(_W, sizeof(mp_word) * (size_t)(((n->used * 2) + 1) - ix));
}
MP_ZERO_BUFFER(W + x->used, sizeof(mp_word) * (size_t)(((n->used * 2) + 1) - ix));
}
/* now we proceed to zero successive digits
* from the least significant upwards
*/
for (ix = 0; ix < n->used; ix++) {
int iy;
mp_digit mu;
/* mu = ai * m' mod b
*
* We avoid a double precision multiplication (which isn't required)
* by casting the value down to a mp_digit. Note this requires
* that W[ix-1] have the carry cleared (see after the inner loop)
*/
mp_digit mu;
mu = ((W[ix] & MP_MASK) * rho) & MP_MASK;
/* a = a + mu * m * b**i
@ -82,21 +74,8 @@ mp_err s_mp_montgomery_reduce_fast(mp_int *x, const mp_int *n, mp_digit rho)
* carry fixups are done in order so after these loops the
* first m->used words of W[] have the carries fixed
*/
{
int iy;
mp_digit *tmpn;
mp_word *_W;
/* alias for the digits of the modulus */
tmpn = n->dp;
/* Alias for the columns set by an offset of ix */
_W = W + ix;
/* inner loop */
for (iy = 0; iy < n->used; iy++) {
*_W++ += (mp_word)mu * (mp_word)*tmpn++;
}
W[ix + iy] += (mp_word)mu * (mp_word)n->dp[iy];
}
/* now fix carry for next digit, W[ix+1] */
@ -107,20 +86,9 @@ mp_err s_mp_montgomery_reduce_fast(mp_int *x, const mp_int *n, mp_digit rho)
* shift the words downward [all those least
* significant digits we zeroed].
*/
{
mp_digit *tmpx;
mp_word *_W, *_W1;
/* now fix rest of carries */
/* alias for current word */
_W1 = W + ix;
/* alias for next word, where the carry goes */
_W = W + ++ix;
for (; ix < ((n->used * 2) + 1); ix++) {
*_W++ += *_W1++ >> (mp_word)MP_DIGIT_BIT;
for (; ix < (n->used * 2); ix++) {
W[ix + 1] += W[ix] >> (mp_word)MP_DIGIT_BIT;
}
/* copy out, A = A/b**n
@ -130,24 +98,18 @@ mp_err s_mp_montgomery_reduce_fast(mp_int *x, const mp_int *n, mp_digit rho)
* we just copy them in the right order
*/
/* alias for destination word */
tmpx = x->dp;
/* alias for shifted double precision result */
_W = W + n->used;
for (ix = 0; ix < (n->used + 1); ix++) {
*tmpx++ = *_W++ & (mp_word)MP_MASK;
x->dp[ix] = W[n->used + ix] & (mp_word)MP_MASK;
}
/* set the max used */
x->used = n->used + 1;
/* zero oldused digits, if the input a was larger than
* m->used+1 we'll have to clear the digits
*/
MP_ZERO_DIGITS(tmpx, olduse - ix);
}
MP_ZERO_DIGITS(x->dp + x->used, oldused - x->used);
/* set the max used and clamp */
x->used = n->used + 1;
mp_clamp(x);
/* if A >= m then A = A - m */