simplifications: mul/sqr comba

This commit is contained in:
Daniel Mendler 2019-10-29 20:06:20 +01:00
parent 7b6c6965bb
commit 8ac493512c
No known key found for this signature in database
GPG Key ID: D88ADB2A2693CA43
5 changed files with 50 additions and 102 deletions

View File

@ -11,10 +11,7 @@ mp_err s_mp_mul_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs)
{ {
mp_int t; mp_int t;
mp_err err; mp_err err;
int pa, pb, ix, iy; int pa, ix;
mp_digit u;
mp_word r;
mp_digit tmpx, *tmpt, *tmpy;
/* can we use the fast multiplier? */ /* can we use the fast multiplier? */
if ((digs < MP_WARRAY) && if ((digs < MP_WARRAY) &&
@ -30,38 +27,28 @@ mp_err s_mp_mul_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs)
/* compute the digits of the product directly */ /* compute the digits of the product directly */
pa = a->used; pa = a->used;
for (ix = 0; ix < pa; ix++) { for (ix = 0; ix < pa; ix++) {
/* set the carry to zero */ int iy, pb;
u = 0; mp_digit u = 0;
/* limit ourselves to making digs digits of output */ /* limit ourselves to making digs digits of output */
pb = MP_MIN(b->used, digs - ix); pb = MP_MIN(b->used, digs - ix);
/* setup some aliases */
/* copy of the digit from a used within the nested loop */
tmpx = a->dp[ix];
/* an alias for the destination shifted ix places */
tmpt = t.dp + ix;
/* an alias for the digits of b */
tmpy = b->dp;
/* compute the columns of the output and propagate the carry */ /* compute the columns of the output and propagate the carry */
for (iy = 0; iy < pb; iy++) { for (iy = 0; iy < pb; iy++) {
/* compute the column as a mp_word */ /* compute the column as a mp_word */
r = (mp_word)*tmpt + mp_word r = (mp_word)t.dp[ix + iy] +
((mp_word)tmpx * (mp_word)*tmpy++) + ((mp_word)a->dp[ix] * (mp_word)b->dp[iy]) +
(mp_word)u; (mp_word)u;
/* the new column is the lower part of the result */ /* the new column is the lower part of the result */
*tmpt++ = (mp_digit)(r & (mp_word)MP_MASK); t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
/* get the carry word from the result */ /* get the carry word from the result */
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT); u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
} }
/* set carry if it is placed below digs */ /* set carry if it is placed below digs */
if ((ix + iy) < digs) { if ((ix + iy) < digs) {
*tmpt = u; t.dp[ix + pb] = u;
} }
} }

View File

@ -21,7 +21,7 @@
*/ */
mp_err s_mp_mul_digs_fast(const mp_int *a, const mp_int *b, mp_int *c, int digs) mp_err s_mp_mul_digs_fast(const mp_int *a, const mp_int *b, mp_int *c, int digs)
{ {
int olduse, pa, ix, iz; int oldused, pa, ix;
mp_err err; mp_err err;
mp_digit W[MP_WARRAY]; mp_digit W[MP_WARRAY];
mp_word _W; mp_word _W;
@ -39,18 +39,12 @@ mp_err s_mp_mul_digs_fast(const mp_int *a, const mp_int *b, mp_int *c, int digs)
/* clear the carry */ /* clear the carry */
_W = 0; _W = 0;
for (ix = 0; ix < pa; ix++) { for (ix = 0; ix < pa; ix++) {
int tx, ty; int tx, ty, iy, iz;
int iy;
mp_digit *tmpx, *tmpy;
/* get offsets into the two bignums */ /* get offsets into the two bignums */
ty = MP_MIN(b->used-1, ix); ty = MP_MIN(b->used-1, ix);
tx = ix - ty; tx = ix - ty;
/* setup temp aliases */
tmpx = a->dp + tx;
tmpy = b->dp + ty;
/* this is the number of times the loop will iterrate, essentially /* this is the number of times the loop will iterrate, essentially
while (tx++ < a->used && ty-- >= 0) { ... } while (tx++ < a->used && ty-- >= 0) { ... }
*/ */
@ -58,8 +52,7 @@ mp_err s_mp_mul_digs_fast(const mp_int *a, const mp_int *b, mp_int *c, int digs)
/* execute loop */ /* execute loop */
for (iz = 0; iz < iy; ++iz) { for (iz = 0; iz < iy; ++iz) {
_W += (mp_word)*tmpx++ * (mp_word)*tmpy--; _W += (mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz];
} }
/* store term */ /* store term */
@ -70,20 +63,17 @@ mp_err s_mp_mul_digs_fast(const mp_int *a, const mp_int *b, mp_int *c, int digs)
} }
/* setup dest */ /* setup dest */
olduse = c->used; oldused = c->used;
c->used = pa; c->used = pa;
{ for (ix = 0; ix < pa; ix++) {
mp_digit *tmpc; /* now extract the previous digit [below the carry] */
tmpc = c->dp; c->dp[ix] = W[ix];
for (ix = 0; ix < pa; ix++) {
/* now extract the previous digit [below the carry] */
*tmpc++ = W[ix];
}
/* clear unused digits [that existed in the old copy of c] */
MP_ZERO_DIGITS(tmpc, olduse - ix);
} }
/* clear unused digits [that existed in the old copy of c] */
MP_ZERO_DIGITS(c->dp + c->used, oldused - c->used);
mp_clamp(c); mp_clamp(c);
return MP_OKAY; return MP_OKAY;
} }

View File

@ -9,11 +9,8 @@
mp_err s_mp_mul_high_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs) mp_err s_mp_mul_high_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs)
{ {
mp_int t; mp_int t;
int pa, pb, ix, iy; int pa, pb, ix;
mp_err err; mp_err err;
mp_digit u;
mp_word r;
mp_digit tmpx, *tmpt, *tmpy;
/* can we use the fast multiplier? */ /* can we use the fast multiplier? */
if (MP_HAS(S_MP_MUL_HIGH_DIGS_FAST) if (MP_HAS(S_MP_MUL_HIGH_DIGS_FAST)
@ -30,31 +27,22 @@ mp_err s_mp_mul_high_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs)
pa = a->used; pa = a->used;
pb = b->used; pb = b->used;
for (ix = 0; ix < pa; ix++) { for (ix = 0; ix < pa; ix++) {
/* clear the carry */ int iy;
u = 0; mp_digit u = 0;
/* left hand side of A[ix] * B[iy] */
tmpx = a->dp[ix];
/* alias to the address of where the digits will be stored */
tmpt = &(t.dp[digs]);
/* alias for where to read the right hand side from */
tmpy = b->dp + (digs - ix);
for (iy = digs - ix; iy < pb; iy++) { for (iy = digs - ix; iy < pb; iy++) {
/* calculate the double precision result */ /* calculate the double precision result */
r = (mp_word)*tmpt + mp_word r = (mp_word)t.dp[ix + iy] +
((mp_word)tmpx * (mp_word)*tmpy++) + ((mp_word)a->dp[ix] * (mp_word)b->dp[iy]) +
(mp_word)u; (mp_word)u;
/* get the lower part */ /* get the lower part */
*tmpt++ = (mp_digit)(r & (mp_word)MP_MASK); t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
/* carry the carry */ /* carry the carry */
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT); u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
} }
*tmpt = u; t.dp[ix + pb] = u;
} }
mp_clamp(&t); mp_clamp(&t);
mp_exch(&t, c); mp_exch(&t, c);

View File

@ -14,7 +14,7 @@
*/ */
mp_err s_mp_mul_high_digs_fast(const mp_int *a, const mp_int *b, mp_int *c, int digs) mp_err s_mp_mul_high_digs_fast(const mp_int *a, const mp_int *b, mp_int *c, int digs)
{ {
int olduse, pa, ix, iz; int oldused, pa, ix;
mp_err err; mp_err err;
mp_digit W[MP_WARRAY]; mp_digit W[MP_WARRAY];
mp_word _W; mp_word _W;
@ -31,17 +31,12 @@ mp_err s_mp_mul_high_digs_fast(const mp_int *a, const mp_int *b, mp_int *c, int
pa = a->used + b->used; pa = a->used + b->used;
_W = 0; _W = 0;
for (ix = digs; ix < pa; ix++) { for (ix = digs; ix < pa; ix++) {
int tx, ty, iy; int tx, ty, iy, iz;
mp_digit *tmpx, *tmpy;
/* get offsets into the two bignums */ /* get offsets into the two bignums */
ty = MP_MIN(b->used-1, ix); ty = MP_MIN(b->used-1, ix);
tx = ix - ty; tx = ix - ty;
/* setup temp aliases */
tmpx = a->dp + tx;
tmpy = b->dp + ty;
/* this is the number of times the loop will iterrate, essentially its /* this is the number of times the loop will iterrate, essentially its
while (tx++ < a->used && ty-- >= 0) { ... } while (tx++ < a->used && ty-- >= 0) { ... }
*/ */
@ -49,7 +44,7 @@ mp_err s_mp_mul_high_digs_fast(const mp_int *a, const mp_int *b, mp_int *c, int
/* execute loop */ /* execute loop */
for (iz = 0; iz < iy; iz++) { for (iz = 0; iz < iy; iz++) {
_W += (mp_word)*tmpx++ * (mp_word)*tmpy--; _W += (mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz];
} }
/* store term */ /* store term */
@ -60,21 +55,17 @@ mp_err s_mp_mul_high_digs_fast(const mp_int *a, const mp_int *b, mp_int *c, int
} }
/* setup dest */ /* setup dest */
olduse = c->used; oldused = c->used;
c->used = pa; c->used = pa;
{ for (ix = digs; ix < pa; ix++) {
mp_digit *tmpc; /* now extract the previous digit [below the carry] */
c->dp[ix] = W[ix];
tmpc = c->dp + digs;
for (ix = digs; ix < pa; ix++) {
/* now extract the previous digit [below the carry] */
*tmpc++ = W[ix];
}
/* clear unused digits [that existed in the old copy of c] */
MP_ZERO_DIGITS(tmpc, olduse - ix);
} }
/* clear unused digits [that existed in the old copy of c] */
MP_ZERO_DIGITS(c->dp + c->used, oldused - c->used);
mp_clamp(c); mp_clamp(c);
return MP_OKAY; return MP_OKAY;
} }

View File

@ -15,14 +15,14 @@ After that loop you do the squares and add them in.
mp_err s_mp_sqr_fast(const mp_int *a, mp_int *b) mp_err s_mp_sqr_fast(const mp_int *a, mp_int *b)
{ {
int olduse, pa, ix, iz; int oldused, pa, ix;
mp_digit W[MP_WARRAY], *tmpx; mp_digit W[MP_WARRAY];
mp_word W1; mp_word W1;
mp_err err;
/* grow the destination as required */ /* grow the destination as required */
pa = a->used + a->used; pa = a->used + a->used;
if (b->alloc < pa) { if (b->alloc < pa) {
mp_err err;
if ((err = mp_grow(b, pa)) != MP_OKAY) { if ((err = mp_grow(b, pa)) != MP_OKAY) {
return err; return err;
} }
@ -31,9 +31,8 @@ mp_err s_mp_sqr_fast(const mp_int *a, mp_int *b)
/* number of output digits to produce */ /* number of output digits to produce */
W1 = 0; W1 = 0;
for (ix = 0; ix < pa; ix++) { for (ix = 0; ix < pa; ix++) {
int tx, ty, iy; int tx, ty, iy, iz;
mp_word _W; mp_word _W;
mp_digit *tmpy;
/* clear counter */ /* clear counter */
_W = 0; _W = 0;
@ -42,10 +41,6 @@ mp_err s_mp_sqr_fast(const mp_int *a, mp_int *b)
ty = MP_MIN(a->used-1, ix); ty = MP_MIN(a->used-1, ix);
tx = ix - ty; tx = ix - ty;
/* setup temp aliases */
tmpx = a->dp + tx;
tmpy = a->dp + ty;
/* this is the number of times the loop will iterrate, essentially /* this is the number of times the loop will iterrate, essentially
while (tx++ < a->used && ty-- >= 0) { ... } while (tx++ < a->used && ty-- >= 0) { ... }
*/ */
@ -59,7 +54,7 @@ mp_err s_mp_sqr_fast(const mp_int *a, mp_int *b)
/* execute loop */ /* execute loop */
for (iz = 0; iz < iy; iz++) { for (iz = 0; iz < iy; iz++) {
_W += (mp_word)*tmpx++ * (mp_word)*tmpy--; _W += (mp_word)a->dp[tx + iz] * (mp_word)a->dp[ty - iz];
} }
/* double the inner product and add carry */ /* double the inner product and add carry */
@ -78,19 +73,16 @@ mp_err s_mp_sqr_fast(const mp_int *a, mp_int *b)
} }
/* setup dest */ /* setup dest */
olduse = b->used; oldused = b->used;
b->used = a->used+a->used; b->used = a->used+a->used;
{ for (ix = 0; ix < pa; ix++) {
mp_digit *tmpb; b->dp[ix] = W[ix] & MP_MASK;
tmpb = b->dp;
for (ix = 0; ix < pa; ix++) {
*tmpb++ = W[ix] & MP_MASK;
}
/* clear unused digits [that existed in the old copy of c] */
MP_ZERO_DIGITS(tmpb, olduse - ix);
} }
/* clear unused digits [that existed in the old copy of c] */
MP_ZERO_DIGITS(b->dp + b->used, oldused - b->used);
mp_clamp(b); mp_clamp(b);
return MP_OKAY; return MP_OKAY;
} }