simplifications: basic arithmetic functions

This commit is contained in:
Daniel Mendler 2019-10-29 20:02:32 +01:00
parent e60149dec7
commit 143e0376a1
No known key found for this signature in database
GPG Key ID: D88ADB2A2693CA43
12 changed files with 181 additions and 276 deletions

View File

@ -6,9 +6,7 @@
/* single digit addition */ /* single digit addition */
mp_err mp_add_d(const mp_int *a, mp_digit b, mp_int *c) mp_err mp_add_d(const mp_int *a, mp_digit b, mp_int *c)
{ {
mp_err err; int oldused;
int ix, oldused;
mp_digit *tmpa, *tmpc;
/* fast path for a == c */ /* fast path for a == c */
if (a == c) { if (a == c) {
@ -27,6 +25,7 @@ mp_err mp_add_d(const mp_int *a, mp_digit b, mp_int *c)
/* grow c as required */ /* grow c as required */
if (c->alloc < (a->used + 1)) { if (c->alloc < (a->used + 1)) {
mp_err err;
if ((err = mp_grow(c, a->used + 1)) != MP_OKAY) { if ((err = mp_grow(c, a->used + 1)) != MP_OKAY) {
return err; return err;
} }
@ -34,6 +33,7 @@ mp_err mp_add_d(const mp_int *a, mp_digit b, mp_int *c)
/* if a is negative and |a| >= b, call c = |a| - b */ /* if a is negative and |a| >= b, call c = |a| - b */
if ((a->sign == MP_NEG) && ((a->used > 1) || (a->dp[0] >= b))) { if ((a->sign == MP_NEG) && ((a->used > 1) || (a->dp[0] >= b))) {
mp_err err;
mp_int a_ = *a; mp_int a_ = *a;
/* temporarily fix sign of a */ /* temporarily fix sign of a */
a_.sign = MP_ZPOS; a_.sign = MP_ZPOS;
@ -53,49 +53,34 @@ mp_err mp_add_d(const mp_int *a, mp_digit b, mp_int *c)
/* old number of used digits in c */ /* old number of used digits in c */
oldused = c->used; oldused = c->used;
/* source alias */
tmpa = a->dp;
/* destination alias */
tmpc = c->dp;
/* if a is positive */ /* if a is positive */
if (a->sign == MP_ZPOS) { if (a->sign == MP_ZPOS) {
/* add digits, mu is carry */ /* add digits, mu is carry */
int i;
mp_digit mu = b; mp_digit mu = b;
for (ix = 0; ix < a->used; ix++) { for (i = 0; i < a->used; i++) {
*tmpc = *tmpa++ + mu; c->dp[i] = a->dp[i] + mu;
mu = *tmpc >> MP_DIGIT_BIT; mu = c->dp[i] >> MP_DIGIT_BIT;
*tmpc++ &= MP_MASK; c->dp[i] &= MP_MASK;
} }
/* set final carry */ /* set final carry */
ix++; c->dp[i] = mu;
*tmpc++ = mu;
/* setup size */ /* setup size */
c->used = a->used + 1; c->used = a->used + 1;
} else { } else {
/* a was negative and |a| < b */ /* a was negative and |a| < b */
c->used = 1; c->used = 1;
/* the result is a single digit */ /* the result is a single digit */
if (a->used == 1) { c->dp[0] = (a->used == 1) ? b - a->dp[0] : b;
*tmpc++ = b - a->dp[0];
} else {
*tmpc++ = b;
}
/* setup count so the clearing of oldused
* can fall through correctly
*/
ix = 1;
} }
/* sign always positive */ /* sign always positive */
c->sign = MP_ZPOS; c->sign = MP_ZPOS;
/* now zero to oldused */ /* now zero to oldused */
MP_ZERO_DIGITS(tmpc, oldused - ix); MP_ZERO_DIGITS(c->dp + c->used, oldused - c->used);
mp_clamp(c); mp_clamp(c);
return MP_OKAY; return MP_OKAY;

View File

@ -6,12 +6,11 @@
/* b = a/2 */ /* b = a/2 */
mp_err mp_div_2(const mp_int *a, mp_int *b) mp_err mp_div_2(const mp_int *a, mp_int *b)
{ {
int x, oldused; int x, oldused;
mp_digit r, rr, *tmpa, *tmpb; mp_digit r;
mp_err err;
/* copy */
if (b->alloc < a->used) { if (b->alloc < a->used) {
mp_err err;
if ((err = mp_grow(b, a->used)) != MP_OKAY) { if ((err = mp_grow(b, a->used)) != MP_OKAY) {
return err; return err;
} }
@ -20,20 +19,14 @@ mp_err mp_div_2(const mp_int *a, mp_int *b)
oldused = b->used; oldused = b->used;
b->used = a->used; b->used = a->used;
/* source alias */
tmpa = a->dp + b->used - 1;
/* dest alias */
tmpb = b->dp + b->used - 1;
/* carry */ /* carry */
r = 0; r = 0;
for (x = b->used - 1; x >= 0; x--) { for (x = b->used; x --> 0;) {
/* get the carry for the next iteration */ /* get the carry for the next iteration */
rr = *tmpa & 1u; mp_digit rr = a->dp[x] & 1u;
/* shift the current digit, add in carry and store */ /* shift the current digit, add in carry and store */
*tmpb-- = (*tmpa-- >> 1) | (r << (MP_DIGIT_BIT - 1)); b->dp[x] = (a->dp[x] >> 1) | (r << (MP_DIGIT_BIT - 1));
/* forward carry to next iteration */ /* forward carry to next iteration */
r = rr; r = rr;

View File

@ -6,23 +6,16 @@
/* shift right by a certain bit count (store quotient in c, optional remainder in d) */ /* shift right by a certain bit count (store quotient in c, optional remainder in d) */
mp_err mp_div_2d(const mp_int *a, int b, mp_int *c, mp_int *d) mp_err mp_div_2d(const mp_int *a, int b, mp_int *c, mp_int *d)
{ {
mp_digit D, r, rr;
int x;
mp_err err; mp_err err;
/* if the shift count is <= 0 then we do no work */ if (b < 0) {
if (b <= 0) { return MP_VAL;
err = mp_copy(a, c);
if (d != NULL) {
mp_zero(d);
}
return err;
} }
/* copy */
if ((err = mp_copy(a, c)) != MP_OKAY) { if ((err = mp_copy(a, c)) != MP_OKAY) {
return err; return err;
} }
/* 'a' should not be used after here - it might be the same as d */ /* 'a' should not be used after here - it might be the same as d */
/* get the remainder */ /* get the remainder */
@ -38,28 +31,25 @@ mp_err mp_div_2d(const mp_int *a, int b, mp_int *c, mp_int *d)
} }
/* shift any bit count < MP_DIGIT_BIT */ /* shift any bit count < MP_DIGIT_BIT */
D = (mp_digit)(b % MP_DIGIT_BIT); b %= MP_DIGIT_BIT;
if (D != 0u) { if (b != 0u) {
mp_digit *tmpc, mask, shift; int x;
mp_digit r, mask, shift;
/* mask */ /* mask */
mask = ((mp_digit)1 << D) - 1uL; mask = ((mp_digit)1 << b) - 1uL;
/* shift for lsb */ /* shift for lsb */
shift = (mp_digit)MP_DIGIT_BIT - D; shift = (mp_digit)(MP_DIGIT_BIT - b);
/* alias */
tmpc = c->dp + (c->used - 1);
/* carry */ /* carry */
r = 0; r = 0;
for (x = c->used - 1; x >= 0; x--) { for (x = c->used; x --> 0;) {
/* get the lower bits of this word in a temp */ /* get the lower bits of this word in a temp */
rr = *tmpc & mask; mp_digit rr = c->dp[x] & mask;
/* shift the current word and mix in the carry bits from the previous word */ /* shift the current word and mix in the carry bits from the previous word */
*tmpc = (*tmpc >> D) | (r << shift); c->dp[x] = (c->dp[x] >> b) | (r << shift);
--tmpc;
/* set the carry to the carry bits of the current word found above */ /* set the carry to the carry bits of the current word found above */
r = rr; r = rr;

View File

@ -8,7 +8,6 @@ mp_err mp_div_d(const mp_int *a, mp_digit b, mp_int *c, mp_digit *d)
{ {
mp_int q; mp_int q;
mp_word w; mp_word w;
mp_digit t;
mp_err err; mp_err err;
int ix; int ix;
@ -56,14 +55,12 @@ mp_err mp_div_d(const mp_int *a, mp_digit b, mp_int *c, mp_digit *d)
q.used = a->used; q.used = a->used;
q.sign = a->sign; q.sign = a->sign;
w = 0; w = 0;
for (ix = a->used - 1; ix >= 0; ix--) { for (ix = a->used; ix --> 0;) {
mp_digit t = 0;
w = (w << (mp_word)MP_DIGIT_BIT) | (mp_word)a->dp[ix]; w = (w << (mp_word)MP_DIGIT_BIT) | (mp_word)a->dp[ix];
if (w >= b) { if (w >= b) {
t = (mp_digit)(w / b); t = (mp_digit)(w / b);
w -= (mp_word)t * (mp_word)b; w -= (mp_word)t * (mp_word)b;
} else {
t = 0;
} }
q.dp[ix] = t; q.dp[ix] = t;
} }
@ -78,7 +75,7 @@ mp_err mp_div_d(const mp_int *a, mp_digit b, mp_int *c, mp_digit *d)
} }
mp_clear(&q); mp_clear(&q);
return err; return MP_OKAY;
} }
#endif #endif

View File

@ -7,8 +7,8 @@
mp_err mp_mul(const mp_int *a, const mp_int *b, mp_int *c) mp_err mp_mul(const mp_int *a, const mp_int *b, mp_int *c)
{ {
mp_err err; mp_err err;
int min_len = MP_MIN(a->used, b->used), int min = MP_MIN(a->used, b->used),
max_len = MP_MAX(a->used, b->used), max = MP_MAX(a->used, b->used),
digs = a->used + b->used + 1; digs = a->used + b->used + 1;
mp_sign neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG; mp_sign neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
@ -20,16 +20,16 @@ mp_err mp_mul(const mp_int *a, const mp_int *b, mp_int *c)
* Using it to cut the input into slices small enough for s_mp_mul_digs_fast * Using it to cut the input into slices small enough for s_mp_mul_digs_fast
* was actually slower on the author's machine, but YMMV. * was actually slower on the author's machine, but YMMV.
*/ */
(min_len >= MP_KARATSUBA_MUL_CUTOFF) && (min >= MP_KARATSUBA_MUL_CUTOFF) &&
((max_len / 2) >= MP_KARATSUBA_MUL_CUTOFF) && ((max / 2) >= MP_KARATSUBA_MUL_CUTOFF) &&
/* Not much effect was observed below a ratio of 1:2, but again: YMMV. */ /* Not much effect was observed below a ratio of 1:2, but again: YMMV. */
(max_len >= (2 * min_len))) { (max >= (2 * min))) {
err = s_mp_balance_mul(a,b,c); err = s_mp_balance_mul(a,b,c);
} else if (MP_HAS(S_MP_TOOM_MUL) && } else if (MP_HAS(S_MP_TOOM_MUL) &&
(min_len >= MP_TOOM_MUL_CUTOFF)) { (min >= MP_TOOM_MUL_CUTOFF)) {
err = s_mp_toom_mul(a, b, c); err = s_mp_toom_mul(a, b, c);
} else if (MP_HAS(S_MP_KARATSUBA_MUL) && } else if (MP_HAS(S_MP_KARATSUBA_MUL) &&
(min_len >= MP_KARATSUBA_MUL_CUTOFF)) { (min >= MP_KARATSUBA_MUL_CUTOFF)) {
err = s_mp_karatsuba_mul(a, b, c); err = s_mp_karatsuba_mul(a, b, c);
} else if (MP_HAS(S_MP_MUL_DIGS_FAST) && } else if (MP_HAS(S_MP_MUL_DIGS_FAST) &&
/* can we use the fast multiplier? /* can we use the fast multiplier?
@ -39,7 +39,7 @@ mp_err mp_mul(const mp_int *a, const mp_int *b, mp_int *c)
* digits won't affect carry propagation * digits won't affect carry propagation
*/ */
(digs < MP_WARRAY) && (digs < MP_WARRAY) &&
(min_len <= MP_MAXFAST)) { (min <= MP_MAXFAST)) {
err = s_mp_mul_digs_fast(a, b, c, digs); err = s_mp_mul_digs_fast(a, b, c, digs);
} else if (MP_HAS(S_MP_MUL_DIGS)) { } else if (MP_HAS(S_MP_MUL_DIGS)) {
err = s_mp_mul_digs(a, b, c, digs); err = s_mp_mul_digs(a, b, c, digs);

View File

@ -6,11 +6,12 @@
/* b = a*2 */ /* b = a*2 */
mp_err mp_mul_2(const mp_int *a, mp_int *b) mp_err mp_mul_2(const mp_int *a, mp_int *b)
{ {
int x, oldused; int x, oldused;
mp_err err; mp_digit r;
/* grow to accomodate result */ /* grow to accomodate result */
if (b->alloc < (a->used + 1)) { if (b->alloc < (a->used + 1)) {
mp_err err;
if ((err = mp_grow(b, a->used + 1)) != MP_OKAY) { if ((err = mp_grow(b, a->used + 1)) != MP_OKAY) {
return err; return err;
} }
@ -19,45 +20,35 @@ mp_err mp_mul_2(const mp_int *a, mp_int *b)
oldused = b->used; oldused = b->used;
b->used = a->used; b->used = a->used;
{ /* carry */
mp_digit r, rr, *tmpa, *tmpb; r = 0;
for (x = 0; x < a->used; x++) {
/* alias for source */ /* get what will be the *next* carry bit from the
tmpa = a->dp; * MSB of the current digit
/* alias for dest */
tmpb = b->dp;
/* carry */
r = 0;
for (x = 0; x < a->used; x++) {
/* get what will be the *next* carry bit from the
* MSB of the current digit
*/
rr = *tmpa >> (mp_digit)(MP_DIGIT_BIT - 1);
/* now shift up this digit, add in the carry [from the previous] */
*tmpb++ = ((*tmpa++ << 1uL) | r) & MP_MASK;
/* copy the carry that would be from the source
* digit into the next iteration
*/
r = rr;
}
/* new leading digit? */
if (r != 0u) {
/* add a MSB which is always 1 at this point */
*tmpb = 1;
++(b->used);
}
/* now zero any excess digits on the destination
* that we didn't write to
*/ */
MP_ZERO_DIGITS(b->dp + b->used, oldused - b->used); mp_digit rr = a->dp[x] >> (mp_digit)(MP_DIGIT_BIT - 1);
/* now shift up this digit, add in the carry [from the previous] */
b->dp[x] = ((a->dp[x] << 1uL) | r) & MP_MASK;
/* copy the carry that would be from the source
* digit into the next iteration
*/
r = rr;
} }
/* new leading digit? */
if (r != 0u) {
/* add a MSB which is always 1 at this point */
b->dp[b->used++] = 1;
}
/* now zero any excess digits on the destination
* that we didn't write to
*/
MP_ZERO_DIGITS(b->dp + b->used, oldused - b->used);
b->sign = a->sign; b->sign = a->sign;
return MP_OKAY; return MP_OKAY;
} }

View File

@ -6,17 +6,19 @@
/* shift left by a certain bit count */ /* shift left by a certain bit count */
mp_err mp_mul_2d(const mp_int *a, int b, mp_int *c) mp_err mp_mul_2d(const mp_int *a, int b, mp_int *c)
{ {
mp_digit d; if (b < 0) {
mp_err err; return MP_VAL;
}
/* copy */
if (a != c) { if (a != c) {
mp_err err;
if ((err = mp_copy(a, c)) != MP_OKAY) { if ((err = mp_copy(a, c)) != MP_OKAY) {
return err; return err;
} }
} }
if (c->alloc < (c->used + (b / MP_DIGIT_BIT) + 1)) { if (c->alloc < (c->used + (b / MP_DIGIT_BIT) + 1)) {
mp_err err;
if ((err = mp_grow(c, c->used + (b / MP_DIGIT_BIT) + 1)) != MP_OKAY) { if ((err = mp_grow(c, c->used + (b / MP_DIGIT_BIT) + 1)) != MP_OKAY) {
return err; return err;
} }
@ -24,35 +26,32 @@ mp_err mp_mul_2d(const mp_int *a, int b, mp_int *c)
/* shift by as many digits in the bit count */ /* shift by as many digits in the bit count */
if (b >= MP_DIGIT_BIT) { if (b >= MP_DIGIT_BIT) {
mp_err err;
if ((err = mp_lshd(c, b / MP_DIGIT_BIT)) != MP_OKAY) { if ((err = mp_lshd(c, b / MP_DIGIT_BIT)) != MP_OKAY) {
return err; return err;
} }
} }
/* shift any bit count < MP_DIGIT_BIT */ /* shift any bit count < MP_DIGIT_BIT */
d = (mp_digit)(b % MP_DIGIT_BIT); b %= MP_DIGIT_BIT;
if (d != 0u) { if (b != 0u) {
mp_digit *tmpc, shift, mask, r, rr; mp_digit shift, mask, r;
int x; int x;
/* bitmask for carries */ /* bitmask for carries */
mask = ((mp_digit)1 << d) - (mp_digit)1; mask = ((mp_digit)1 << b) - (mp_digit)1;
/* shift for msbs */ /* shift for msbs */
shift = (mp_digit)MP_DIGIT_BIT - d; shift = (mp_digit)(MP_DIGIT_BIT - b);
/* alias */
tmpc = c->dp;
/* carry */ /* carry */
r = 0; r = 0;
for (x = 0; x < c->used; x++) { for (x = 0; x < c->used; x++) {
/* get the higher bits of the current word */ /* get the higher bits of the current word */
rr = (*tmpc >> shift) & mask; mp_digit rr = (c->dp[x] >> shift) & mask;
/* shift the current word and OR in the carry */ /* shift the current word and OR in the carry */
*tmpc = ((*tmpc << d) | r) & MP_MASK; c->dp[x] = ((c->dp[x] << b) | r) & MP_MASK;
++tmpc;
/* set the carry to the carry bits of the current word */ /* set the carry to the carry bits of the current word */
r = rr; r = rr;

View File

@ -6,10 +6,9 @@
/* multiply by a digit */ /* multiply by a digit */
mp_err mp_mul_d(const mp_int *a, mp_digit b, mp_int *c) mp_err mp_mul_d(const mp_int *a, mp_digit b, mp_int *c)
{ {
mp_digit u, *tmpa, *tmpc; mp_digit u;
mp_word r;
mp_err err; mp_err err;
int ix, olduse; int ix, oldused;
/* make sure c is big enough to hold a*b */ /* make sure c is big enough to hold a*b */
if (c->alloc < (a->used + 1)) { if (c->alloc < (a->used + 1)) {
@ -19,41 +18,35 @@ mp_err mp_mul_d(const mp_int *a, mp_digit b, mp_int *c)
} }
/* get the original destinations used count */ /* get the original destinations used count */
olduse = c->used; oldused = c->used;
/* set the sign */ /* set the sign */
c->sign = a->sign; c->sign = a->sign;
/* alias for a->dp [source] */
tmpa = a->dp;
/* alias for c->dp [dest] */
tmpc = c->dp;
/* zero carry */ /* zero carry */
u = 0; u = 0;
/* compute columns */ /* compute columns */
for (ix = 0; ix < a->used; ix++) { for (ix = 0; ix < a->used; ix++) {
/* compute product and carry sum for this term */ /* compute product and carry sum for this term */
r = (mp_word)u + ((mp_word)*tmpa++ * (mp_word)b); mp_word r = (mp_word)u + ((mp_word)a->dp[ix] * (mp_word)b);
/* mask off higher bits to get a single digit */ /* mask off higher bits to get a single digit */
*tmpc++ = (mp_digit)(r & (mp_word)MP_MASK); c->dp[ix] = (mp_digit)(r & (mp_word)MP_MASK);
/* send carry into next iteration */ /* send carry into next iteration */
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT); u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
} }
/* store final carry [if any] and increment ix offset */ /* store final carry [if any] and increment ix offset */
*tmpc++ = u; c->dp[ix] = u;
++ix;
/* now zero digits above the top */
MP_ZERO_DIGITS(tmpc, olduse - ix);
/* set used count */ /* set used count */
c->used = a->used + 1; c->used = a->used + 1;
/* now zero digits above the top */
MP_ZERO_DIGITS(c->dp + c->used, oldused - c->used);
mp_clamp(c); mp_clamp(c);
return MP_OKAY; return MP_OKAY;

View File

@ -6,9 +6,7 @@
/* single digit subtraction */ /* single digit subtraction */
mp_err mp_sub_d(const mp_int *a, mp_digit b, mp_int *c) mp_err mp_sub_d(const mp_int *a, mp_digit b, mp_int *c)
{ {
mp_digit *tmpa, *tmpc; int oldused;
mp_err err;
int ix, oldused;
/* fast path for a == c */ /* fast path for a == c */
if (a == c) { if (a == c) {
@ -26,6 +24,7 @@ mp_err mp_sub_d(const mp_int *a, mp_digit b, mp_int *c)
/* grow c as required */ /* grow c as required */
if (c->alloc < (a->used + 1)) { if (c->alloc < (a->used + 1)) {
mp_err err;
if ((err = mp_grow(c, a->used + 1)) != MP_OKAY) { if ((err = mp_grow(c, a->used + 1)) != MP_OKAY) {
return err; return err;
} }
@ -35,6 +34,7 @@ mp_err mp_sub_d(const mp_int *a, mp_digit b, mp_int *c)
* addition [with fudged signs] * addition [with fudged signs]
*/ */
if (a->sign == MP_NEG) { if (a->sign == MP_NEG) {
mp_err err;
mp_int a_ = *a; mp_int a_ = *a;
a_.sign = MP_ZPOS; a_.sign = MP_ZPOS;
err = mp_add_d(&a_, b, c); err = mp_add_d(&a_, b, c);
@ -46,24 +46,17 @@ mp_err mp_sub_d(const mp_int *a, mp_digit b, mp_int *c)
return err; return err;
} }
/* setup regs */
oldused = c->used; oldused = c->used;
tmpa = a->dp;
tmpc = c->dp;
/* if a <= b simply fix the single digit */ /* if a <= b simply fix the single digit */
if (((a->used == 1) && (a->dp[0] <= b)) || (a->used == 0)) { if (((a->used == 1) && (a->dp[0] <= b)) || (a->used == 0)) {
if (a->used == 1) { c->dp[0] = (a->used == 1) ? b - a->dp[0] : b;
*tmpc++ = b - *tmpa;
} else {
*tmpc++ = b;
}
ix = 1;
/* negative/1digit */ /* negative/1digit */
c->sign = MP_NEG; c->sign = MP_NEG;
c->used = 1; c->used = 1;
} else { } else {
int i;
mp_digit mu = b; mp_digit mu = b;
/* positive/size */ /* positive/size */
@ -71,15 +64,15 @@ mp_err mp_sub_d(const mp_int *a, mp_digit b, mp_int *c)
c->used = a->used; c->used = a->used;
/* subtract digits, mu is carry */ /* subtract digits, mu is carry */
for (ix = 0; ix < a->used; ix++) { for (i = 0; i < a->used; i++) {
*tmpc = *tmpa++ - mu; c->dp[i] = a->dp[i] - mu;
mu = *tmpc >> (MP_SIZEOF_BITS(mp_digit) - 1u); mu = c->dp[i] >> (MP_SIZEOF_BITS(mp_digit) - 1u);
*tmpc++ &= MP_MASK; c->dp[i] &= MP_MASK;
} }
} }
/* zero excess digits */ /* zero excess digits */
MP_ZERO_DIGITS(tmpc, oldused - ix); MP_ZERO_DIGITS(c->dp + c->used, oldused - c->used);
mp_clamp(c); mp_clamp(c);
return MP_OKAY; return MP_OKAY;

View File

@ -6,85 +6,66 @@
/* low level addition, based on HAC pp.594, Algorithm 14.7 */ /* low level addition, based on HAC pp.594, Algorithm 14.7 */
mp_err s_mp_add(const mp_int *a, const mp_int *b, mp_int *c) mp_err s_mp_add(const mp_int *a, const mp_int *b, mp_int *c)
{ {
const mp_int *x; int oldused, min, max, i;
mp_err err; mp_digit u;
int olduse, min, max;
/* find sizes, we let |a| <= |b| which means we have to sort /* find sizes, we let |a| <= |b| which means we have to sort
* them. "x" will point to the input with the most digits * them. "x" will point to the input with the most digits
*/ */
if (a->used > b->used) { if (a->used < b->used) {
min = b->used; MP_EXCH(const mp_int *, a, b);
max = a->used;
x = a;
} else {
min = a->used;
max = b->used;
x = b;
} }
min = b->used;
max = a->used;
/* init result */ /* init result */
if (c->alloc < (max + 1)) { if (c->alloc < (max + 1)) {
mp_err err;
if ((err = mp_grow(c, max + 1)) != MP_OKAY) { if ((err = mp_grow(c, max + 1)) != MP_OKAY) {
return err; return err;
} }
} }
/* get old used digit count and set new one */ /* get old used digit count and set new one */
olduse = c->used; oldused = c->used;
c->used = max + 1; c->used = max + 1;
{ /* zero the carry */
mp_digit u, *tmpa, *tmpb, *tmpc; u = 0;
int i; for (i = 0; i < min; i++) {
/* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
c->dp[i] = a->dp[i] + b->dp[i] + u;
/* alias for digit pointers */ /* U = carry bit of T[i] */
u = c->dp[i] >> (mp_digit)MP_DIGIT_BIT;
/* first input */ /* take away carry bit from T[i] */
tmpa = a->dp; c->dp[i] &= MP_MASK;
}
/* second input */ /* now copy higher words if any, that is in A+B
tmpb = b->dp; * if A or B has more digits add those in
*/
/* destination */ if (min != max) {
tmpc = c->dp; for (; i < max; i++) {
/* T[i] = A[i] + U */
/* zero the carry */ c->dp[i] = a->dp[i] + u;
u = 0;
for (i = 0; i < min; i++) {
/* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
*tmpc = *tmpa++ + *tmpb++ + u;
/* U = carry bit of T[i] */ /* U = carry bit of T[i] */
u = *tmpc >> (mp_digit)MP_DIGIT_BIT; u = c->dp[i] >> (mp_digit)MP_DIGIT_BIT;
/* take away carry bit from T[i] */ /* take away carry bit from T[i] */
*tmpc++ &= MP_MASK; c->dp[i] &= MP_MASK;
} }
/* now copy higher words if any, that is in A+B
* if A or B has more digits add those in
*/
if (min != max) {
for (; i < max; i++) {
/* T[i] = X[i] + U */
*tmpc = x->dp[i] + u;
/* U = carry bit of T[i] */
u = *tmpc >> (mp_digit)MP_DIGIT_BIT;
/* take away carry bit from T[i] */
*tmpc++ &= MP_MASK;
}
}
/* add carry */
*tmpc++ = u;
/* clear digits above oldused */
MP_ZERO_DIGITS(tmpc, olduse - c->used);
} }
/* add carry */
c->dp[i] = u;
/* clear digits above oldused */
MP_ZERO_DIGITS(c->dp + c->used, oldused - c->used);
mp_clamp(c); mp_clamp(c);
return MP_OKAY; return MP_OKAY;
} }

View File

@ -7,10 +7,8 @@
mp_err s_mp_sqr(const mp_int *a, mp_int *b) mp_err s_mp_sqr(const mp_int *a, mp_int *b)
{ {
mp_int t; mp_int t;
int ix, iy, pa; int ix, pa;
mp_err err; mp_err err;
mp_word r;
mp_digit u, tmpx, *tmpt;
pa = a->used; pa = a->used;
if ((err = mp_init_size(&t, (2 * pa) + 1)) != MP_OKAY) { if ((err = mp_init_size(&t, (2 * pa) + 1)) != MP_OKAY) {
@ -21,10 +19,13 @@ mp_err s_mp_sqr(const mp_int *a, mp_int *b)
t.used = (2 * pa) + 1; t.used = (2 * pa) + 1;
for (ix = 0; ix < pa; ix++) { for (ix = 0; ix < pa; ix++) {
mp_digit u;
int iy;
/* first calculate the digit at 2*ix */ /* first calculate the digit at 2*ix */
/* calculate double precision result */ /* calculate double precision result */
r = (mp_word)t.dp[2*ix] + mp_word r = (mp_word)t.dp[2*ix] +
((mp_word)a->dp[ix] * (mp_word)a->dp[ix]); ((mp_word)a->dp[ix] * (mp_word)a->dp[ix]);
/* store lower part in result */ /* store lower part in result */
t.dp[ix+ix] = (mp_digit)(r & (mp_word)MP_MASK); t.dp[ix+ix] = (mp_digit)(r & (mp_word)MP_MASK);
@ -32,32 +33,27 @@ mp_err s_mp_sqr(const mp_int *a, mp_int *b)
/* get the carry */ /* get the carry */
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT); u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
/* left hand side of A[ix] * A[iy] */
tmpx = a->dp[ix];
/* alias for where to store the results */
tmpt = t.dp + ((2 * ix) + 1);
for (iy = ix + 1; iy < pa; iy++) { for (iy = ix + 1; iy < pa; iy++) {
/* first calculate the product */ /* first calculate the product */
r = (mp_word)tmpx * (mp_word)a->dp[iy]; r = (mp_word)a->dp[ix] * (mp_word)a->dp[iy];
/* now calculate the double precision result, note we use /* now calculate the double precision result, note we use
* addition instead of *2 since it's easier to optimize * addition instead of *2 since it's easier to optimize
*/ */
r = (mp_word)*tmpt + r + r + (mp_word)u; r = (mp_word)t.dp[ix + iy] + r + r + (mp_word)u;
/* store lower part */ /* store lower part */
*tmpt++ = (mp_digit)(r & (mp_word)MP_MASK); t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
/* get carry */ /* get carry */
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT); u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
} }
/* propagate upwards */ /* propagate upwards */
while (u != 0uL) { while (u != 0uL) {
r = (mp_word)*tmpt + (mp_word)u; r = (mp_word)t.dp[ix + iy] + (mp_word)u;
*tmpt++ = (mp_digit)(r & (mp_word)MP_MASK); t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT); u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
++iy;
} }
} }

View File

@ -6,64 +6,51 @@
/* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */ /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
mp_err s_mp_sub(const mp_int *a, const mp_int *b, mp_int *c) mp_err s_mp_sub(const mp_int *a, const mp_int *b, mp_int *c)
{ {
int olduse, min, max; int oldused = c->used, min = b->used, max = a->used, i;
mp_err err; mp_digit u;
/* find sizes */
min = b->used;
max = a->used;
/* init result */ /* init result */
if (c->alloc < max) { if (c->alloc < max) {
mp_err err;
if ((err = mp_grow(c, max)) != MP_OKAY) { if ((err = mp_grow(c, max)) != MP_OKAY) {
return err; return err;
} }
} }
olduse = c->used;
c->used = max; c->used = max;
{ /* set carry to zero */
mp_digit u, *tmpa, *tmpb, *tmpc; u = 0;
int i; for (i = 0; i < min; i++) {
/* T[i] = A[i] - B[i] - U */
c->dp[i] = (a->dp[i] - b->dp[i]) - u;
/* alias for digit pointers */ /* U = carry bit of T[i]
tmpa = a->dp; * Note this saves performing an AND operation since
tmpb = b->dp; * if a carry does occur it will propagate all the way to the
tmpc = c->dp; * MSB. As a result a single shift is enough to get the carry
*/
u = c->dp[i] >> (MP_SIZEOF_BITS(mp_digit) - 1u);
/* set carry to zero */ /* Clear carry from T[i] */
u = 0; c->dp[i] &= MP_MASK;
for (i = 0; i < min; i++) {
/* T[i] = A[i] - B[i] - U */
*tmpc = (*tmpa++ - *tmpb++) - u;
/* U = carry bit of T[i]
* Note this saves performing an AND operation since
* if a carry does occur it will propagate all the way to the
* MSB. As a result a single shift is enough to get the carry
*/
u = *tmpc >> (MP_SIZEOF_BITS(mp_digit) - 1u);
/* Clear carry from T[i] */
*tmpc++ &= MP_MASK;
}
/* now copy higher words if any, e.g. if A has more digits than B */
for (; i < max; i++) {
/* T[i] = A[i] - U */
*tmpc = *tmpa++ - u;
/* U = carry bit of T[i] */
u = *tmpc >> (MP_SIZEOF_BITS(mp_digit) - 1u);
/* Clear carry from T[i] */
*tmpc++ &= MP_MASK;
}
/* clear digits above used (since we may not have grown result above) */
MP_ZERO_DIGITS(tmpc, olduse - c->used);
} }
/* now copy higher words if any, e.g. if A has more digits than B */
for (; i < max; i++) {
/* T[i] = A[i] - U */
c->dp[i] = a->dp[i] - u;
/* U = carry bit of T[i] */
u = c->dp[i] >> (MP_SIZEOF_BITS(mp_digit) - 1u);
/* Clear carry from T[i] */
c->dp[i] &= MP_MASK;
}
/* clear digits above used (since we may not have grown result above) */
MP_ZERO_DIGITS(c->dp + c->used, oldused - c->used);
mp_clamp(c); mp_clamp(c);
return MP_OKAY; return MP_OKAY;
} }