diff --git a/mp_abs.c b/mp_abs.c index 4ad1a4a..902279e 100644 --- a/mp_abs.c +++ b/mp_abs.c @@ -9,10 +9,9 @@ */ mp_err mp_abs(const mp_int *a, mp_int *b) { - mp_err err; - /* copy a to b */ if (a != b) { + mp_err err; if ((err = mp_copy(a, b)) != MP_OKAY) { return err; } diff --git a/mp_add.c b/mp_add.c index c78614b..bf7a61e 100644 --- a/mp_add.c +++ b/mp_add.c @@ -6,33 +6,24 @@ /* high level addition (handles signs) */ mp_err mp_add(const mp_int *a, const mp_int *b, mp_int *c) { - mp_sign sa, sb; - mp_err err; - - /* get sign of both inputs */ - sa = a->sign; - sb = b->sign; - /* handle two cases, not four */ - if (sa == sb) { + if (a->sign == b->sign) { /* both positive or both negative */ /* add their magnitudes, copy the sign */ - c->sign = sa; - err = s_mp_add(a, b, c); - } else { - /* one positive, the other negative */ - /* subtract the one with the greater magnitude from */ - /* the one of the lesser magnitude. The result gets */ - /* the sign of the one with the greater magnitude. */ - if (mp_cmp_mag(a, b) == MP_LT) { - c->sign = sb; - err = s_mp_sub(b, a, c); - } else { - c->sign = sa; - err = s_mp_sub(a, b, c); - } + c->sign = a->sign; + return s_mp_add(a, b, c); } - return err; + + /* one positive, the other negative */ + /* subtract the one with the greater magnitude from */ + /* the one of the lesser magnitude. The result gets */ + /* the sign of the one with the greater magnitude. */ + if (mp_cmp_mag(a, b) == MP_LT) { + MP_EXCH(const mp_int *, a, b); + } + + c->sign = a->sign; + return s_mp_sub(a, b, c); } #endif diff --git a/mp_clear_multi.c b/mp_clear_multi.c index 74406c7..9c7aed8 100644 --- a/mp_clear_multi.c +++ b/mp_clear_multi.c @@ -7,12 +7,11 @@ void mp_clear_multi(mp_int *mp, ...) { - mp_int *next_mp = mp; va_list args; va_start(args, mp); - while (next_mp != NULL) { - mp_clear(next_mp); - next_mp = va_arg(args, mp_int *); + while (mp != NULL) { + mp_clear(mp); + mp = va_arg(args, mp_int *); } va_end(args); } diff --git a/mp_cmp.c b/mp_cmp.c index a9bd910..b9c4592 100644 --- a/mp_cmp.c +++ b/mp_cmp.c @@ -8,19 +8,14 @@ mp_ord mp_cmp(const mp_int *a, const mp_int *b) { /* compare based on sign */ if (a->sign != b->sign) { - if (a->sign == MP_NEG) { - return MP_LT; - } else { - return MP_GT; - } + return a->sign == MP_NEG ? MP_LT : MP_GT; } - /* compare digits */ + /* if negative compare opposite direction */ if (a->sign == MP_NEG) { - /* if negative compare opposite direction */ - return mp_cmp_mag(b, a); - } else { - return mp_cmp_mag(a, b); + MP_EXCH(const mp_int *, a, b); } + + return mp_cmp_mag(a, b); } #endif diff --git a/mp_cmp_d.c b/mp_cmp_d.c index 03d8e2c..0d98e05 100644 --- a/mp_cmp_d.c +++ b/mp_cmp_d.c @@ -17,12 +17,10 @@ mp_ord mp_cmp_d(const mp_int *a, mp_digit b) } /* compare the only digit of a to b */ - if (a->dp[0] > b) { - return MP_GT; - } else if (a->dp[0] < b) { - return MP_LT; - } else { - return MP_EQ; + if (a->dp[0] != b) { + return a->dp[0] > b ? MP_GT : MP_LT; } + + return MP_EQ; } #endif diff --git a/mp_cmp_mag.c b/mp_cmp_mag.c index b3a7b04..e5e502b 100644 --- a/mp_cmp_mag.c +++ b/mp_cmp_mag.c @@ -6,34 +6,20 @@ /* compare maginitude of two ints (unsigned) */ mp_ord mp_cmp_mag(const mp_int *a, const mp_int *b) { - int n; - const mp_digit *tmpa, *tmpb; + int n; /* compare based on # of non-zero digits */ - if (a->used > b->used) { - return MP_GT; + if (a->used != b->used) { + return a->used > b->used ? MP_GT : MP_LT; } - if (a->used < b->used) { - return MP_LT; - } - - /* alias for a */ - tmpa = a->dp + (a->used - 1); - - /* alias for b */ - tmpb = b->dp + (a->used - 1); - /* compare based on digits */ - for (n = 0; n < a->used; ++n, --tmpa, --tmpb) { - if (*tmpa > *tmpb) { - return MP_GT; - } - - if (*tmpa < *tmpb) { - return MP_LT; + for (n = a->used; n --> 0;) { + if (a->dp[n] != b->dp[n]) { + return a->dp[n] > b->dp[n] ? MP_GT : MP_LT; } } + return MP_EQ; } #endif diff --git a/mp_cnt_lsb.c b/mp_cnt_lsb.c index 7ae8bc1..8519ad1 100644 --- a/mp_cnt_lsb.c +++ b/mp_cnt_lsb.c @@ -3,7 +3,7 @@ /* LibTomMath, multiple-precision integer library -- Tom St Denis */ /* SPDX-License-Identifier: Unlicense */ -static const int lnz[16] = { +static const char lnz[16] = { 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0 }; @@ -11,7 +11,7 @@ static const int lnz[16] = { int mp_cnt_lsb(const mp_int *a) { int x; - mp_digit q, qq; + mp_digit q; /* easy out */ if (mp_iszero(a)) { @@ -25,11 +25,12 @@ int mp_cnt_lsb(const mp_int *a) /* now scan this digit until a 1 is found */ if ((q & 1u) == 0u) { + mp_digit p; do { - qq = q & 15u; - x += lnz[qq]; + p = q & 15u; + x += lnz[p]; q >>= 4; - } while (qq == 0u); + } while (p == 0u); } return x; } diff --git a/mp_copy.c b/mp_copy.c index a7ac34a..cf4d5e0 100644 --- a/mp_copy.c +++ b/mp_copy.c @@ -7,8 +7,6 @@ mp_err mp_copy(const mp_int *a, mp_int *b) { int n; - mp_digit *tmpa, *tmpb; - mp_err err; /* if dst == src do nothing */ if (a == b) { @@ -17,27 +15,21 @@ mp_err mp_copy(const mp_int *a, mp_int *b) /* grow dest */ if (b->alloc < a->used) { + mp_err err; if ((err = mp_grow(b, a->used)) != MP_OKAY) { return err; } } /* zero b and copy the parameters over */ - /* pointer aliases */ - - /* source */ - tmpa = a->dp; - - /* destination */ - tmpb = b->dp; /* copy all the digits */ for (n = 0; n < a->used; n++) { - *tmpb++ = *tmpa++; + b->dp[n] = a->dp[n]; } /* clear high digits */ - MP_ZERO_DIGITS(tmpb, b->used - n); + MP_ZERO_DIGITS(b->dp + a->used, b->used - a->used); /* copy used count and sign */ b->used = a->used; diff --git a/mp_div.c b/mp_div.c index 23a2acf..05b96dd 100644 --- a/mp_div.c +++ b/mp_div.c @@ -15,14 +15,14 @@ mp_err mp_div(const mp_int *a, const mp_int *b, mp_int *c, mp_int *d) /* if a < b then q = 0, r = a */ if (mp_cmp_mag(a, b) == MP_LT) { if (d != NULL) { - err = mp_copy(a, d); - } else { - err = MP_OKAY; + if ((err = mp_copy(a, d)) != MP_OKAY) { + return err; + } } if (c != NULL) { mp_zero(c); } - return err; + return MP_OKAY; } if (MP_HAS(S_MP_DIV_RECURSIVE) @@ -31,11 +31,12 @@ mp_err mp_div(const mp_int *a, const mp_int *b, mp_int *c, mp_int *d) err = s_mp_div_recursive(a, b, c, d); } else if (MP_HAS(S_MP_DIV_SCHOOL)) { err = s_mp_div_school(a, b, c, d); - } else { + } else if (MP_HAS(S_MP_DIV_SMALL)) { err = s_mp_div_small(a, b, c, d); + } else { + err = MP_VAL; } return err; } #endif - diff --git a/mp_div_3.c b/mp_div_3.c index 5789b2d..c26692c 100644 --- a/mp_div_3.c +++ b/mp_div_3.c @@ -7,7 +7,7 @@ mp_err mp_div_3(const mp_int *a, mp_int *c, mp_digit *d) { mp_int q; - mp_word w, t; + mp_word w; mp_digit b; mp_err err; int ix; @@ -22,7 +22,8 @@ mp_err mp_div_3(const mp_int *a, mp_int *c, mp_digit *d) q.used = a->used; q.sign = a->sign; w = 0; - for (ix = a->used - 1; ix >= 0; ix--) { + for (ix = a->used; ix --> 0;) { + mp_word t; w = (w << (mp_word)MP_DIGIT_BIT) | (mp_word)a->dp[ix]; if (w >= 3u) { @@ -57,7 +58,7 @@ mp_err mp_div_3(const mp_int *a, mp_int *c, mp_digit *d) } mp_clear(&q); - return err; + return MP_OKAY; } #endif diff --git a/mp_exch.c b/mp_exch.c index 7bc4ee7..50b97d9 100644 --- a/mp_exch.c +++ b/mp_exch.c @@ -8,10 +8,6 @@ */ void mp_exch(mp_int *a, mp_int *b) { - mp_int t; - - t = *a; - *a = *b; - *b = t; + MP_EXCH(mp_int, *a, *b); } #endif diff --git a/mp_exptmod.c b/mp_exptmod.c index e643ded..b917c0b 100644 --- a/mp_exptmod.c +++ b/mp_exptmod.c @@ -64,13 +64,15 @@ LBL_ERR: /* if the modulus is odd or dr != 0 use the montgomery method */ if (MP_HAS(S_MP_EXPTMOD_FAST) && (mp_isodd(P) || (dr != 0))) { return s_mp_exptmod_fast(G, X, P, Y, dr); - } else if (MP_HAS(S_MP_EXPTMOD)) { - /* otherwise use the generic Barrett reduction technique */ - return s_mp_exptmod(G, X, P, Y, 0); - } else { - /* no exptmod for evens */ - return MP_VAL; } + + /* otherwise use the generic Barrett reduction technique */ + if (MP_HAS(S_MP_EXPTMOD)) { + return s_mp_exptmod(G, X, P, Y, 0); + } + + /* no exptmod for evens */ + return MP_VAL; } #endif diff --git a/mp_exteuclid.c b/mp_exteuclid.c index eb8ad37..0d0bfd3 100644 --- a/mp_exteuclid.c +++ b/mp_exteuclid.c @@ -65,7 +65,6 @@ mp_err mp_exteuclid(const mp_int *a, const mp_int *b, mp_int *U1, mp_int *U2, mp mp_exch(U3, &u3); } - err = MP_OKAY; LBL_ERR: mp_clear_multi(&u1, &u2, &u3, &v1, &v2, &v3, &t1, &t2, &t3, &q, &tmp, NULL); return err; diff --git a/mp_fread.c b/mp_fread.c index 767e5a3..005c62a 100644 --- a/mp_fread.c +++ b/mp_fread.c @@ -32,7 +32,7 @@ mp_err mp_fread(mp_int *a, int radix, FILE *stream) mp_zero(a); do { - int y; + uint8_t y; unsigned pos; ch = (radix <= 36) ? MP_TOUPPER(ch) : ch; pos = (unsigned)(ch - (int)'('); @@ -40,7 +40,7 @@ mp_err mp_fread(mp_int *a, int radix, FILE *stream) break; } - y = (int)s_mp_rmap_reverse[pos]; + y = s_mp_rmap_reverse[pos]; if ((y == 0xff) || (y >= radix)) { break; @@ -50,7 +50,7 @@ mp_err mp_fread(mp_int *a, int radix, FILE *stream) if ((err = mp_mul_d(a, (mp_digit)radix, a)) != MP_OKAY) { return err; } - if ((err = mp_add_d(a, (mp_digit)y, a)) != MP_OKAY) { + if ((err = mp_add_d(a, y, a)) != MP_OKAY) { return err; } } while ((ch = fgetc(stream)) != EOF); diff --git a/mp_from_sbin.c b/mp_from_sbin.c index c6e87d7..3ff7d5d 100644 --- a/mp_from_sbin.c +++ b/mp_from_sbin.c @@ -14,11 +14,7 @@ mp_err mp_from_sbin(mp_int *a, const uint8_t *buf, size_t size) } /* first byte is 0 for positive, non-zero for negative */ - if (buf[0] == (uint8_t)0) { - a->sign = MP_ZPOS; - } else { - a->sign = MP_NEG; - } + a->sign = (buf[0] == (uint8_t)0) ? MP_ZPOS : MP_NEG; return MP_OKAY; } diff --git a/mp_fwrite.c b/mp_fwrite.c index be78f7f..42d7287 100644 --- a/mp_fwrite.c +++ b/mp_fwrite.c @@ -8,31 +8,24 @@ mp_err mp_fwrite(const mp_int *a, int radix, FILE *stream) { char *buf; mp_err err; - size_t len, written; + size_t size, written; - /* TODO: this function is not in this PR */ - if ((err = mp_radix_size(a, radix, &len)) != MP_OKAY) { + if ((err = mp_radix_size(a, radix, &size)) != MP_OKAY) { return err; } - buf = (char *) MP_MALLOC(len); + buf = (char *) MP_MALLOC(size); if (buf == NULL) { return MP_MEM; } - if ((err = mp_to_radix(a, buf, len, &written, radix)) != MP_OKAY) { - goto LBL_ERR; + if ((err = mp_to_radix(a, buf, size, &written, radix)) == MP_OKAY) { + if (fwrite(buf, written, 1uL, stream) != 1uL) { + err = MP_ERR; + } } - if (fwrite(buf, written, 1uL, stream) != 1uL) { - err = MP_ERR; - goto LBL_ERR; - } - err = MP_OKAY; - - -LBL_ERR: - MP_FREE_BUFFER(buf, len); + MP_FREE_BUFFER(buf, size); return err; } #endif diff --git a/mp_grow.c b/mp_grow.c index 3354e59..25be5ed 100644 --- a/mp_grow.c +++ b/mp_grow.c @@ -6,9 +6,6 @@ /* grow as required */ mp_err mp_grow(mp_int *a, int size) { - int i; - mp_digit *tmp; - /* if the alloc size is smaller alloc more ram */ if (a->alloc < size) { /* reallocate the array a->dp @@ -17,21 +14,20 @@ mp_err mp_grow(mp_int *a, int size) * in case the operation failed we don't want * to overwrite the dp member of a. */ - tmp = (mp_digit *) MP_REALLOC(a->dp, - (size_t)a->alloc * sizeof(mp_digit), - (size_t)size * sizeof(mp_digit)); - if (tmp == NULL) { + mp_digit *dp = (mp_digit *) MP_REALLOC(a->dp, + (size_t)a->alloc * sizeof(mp_digit), + (size_t)size * sizeof(mp_digit)); + if (dp == NULL) { /* reallocation failed but "a" is still valid [can be freed] */ return MP_MEM; } /* reallocation succeeded so set a->dp */ - a->dp = tmp; + a->dp = dp; /* zero excess digits */ - i = a->alloc; + MP_ZERO_DIGITS(a->dp + a->alloc, size - a->alloc); a->alloc = size; - MP_ZERO_DIGITS(a->dp + i, a->alloc - i); } return MP_OKAY; } diff --git a/mp_is_square.c b/mp_is_square.c index f92ecbf..47f8300 100644 --- a/mp_is_square.c +++ b/mp_is_square.c @@ -28,10 +28,10 @@ static const char rem_105[105] = { /* Store non-zero to ret if arg is square, and zero if not */ mp_err mp_is_square(const mp_int *arg, bool *ret) { - mp_err err; - mp_digit c; - mp_int t; - unsigned long r; + mp_err err; + mp_digit c; + mp_int t; + uint32_t r; /* Default to Non-square :) */ *ret = false; diff --git a/mp_kronecker.c b/mp_kronecker.c index 0ac6338..b106f77 100644 --- a/mp_kronecker.c +++ b/mp_kronecker.c @@ -23,7 +23,7 @@ mp_err mp_kronecker(const mp_int *a, const mp_int *p, int *c) mp_err err; int v, k; - static const int table[8] = {0, 1, 0, -1, 0, -1, 0, 1}; + static const char table[] = {0, 1, 0, -1, 0, -1, 0, 1}; if (mp_iszero(p)) { if ((a->used == 1) && (a->dp[0] == 1u)) { diff --git a/mp_lshd.c b/mp_lshd.c index b0a8454..6c14402 100644 --- a/mp_lshd.c +++ b/mp_lshd.c @@ -7,8 +7,6 @@ mp_err mp_lshd(mp_int *a, int b) { int x; - mp_err err; - mp_digit *top, *bottom; /* if its less than zero return */ if (b <= 0) { @@ -21,6 +19,7 @@ mp_err mp_lshd(mp_int *a, int b) /* grow to fit the new digits */ if (a->alloc < (a->used + b)) { + mp_err err; if ((err = mp_grow(a, a->used + b)) != MP_OKAY) { return err; } @@ -29,18 +28,12 @@ mp_err mp_lshd(mp_int *a, int b) /* increment the used by the shift amount then copy upwards */ a->used += b; - /* top */ - top = a->dp + a->used - 1; - - /* base */ - bottom = (a->dp + a->used - 1) - b; - /* much like mp_rshd this is implemented using a sliding window * except the window goes the otherway around. Copying from * the bottom to the top. see mp_rshd.c for more info. */ - for (x = a->used - 1; x >= b; x--) { - *top-- = *bottom--; + for (x = a->used; x --> b;) { + a->dp[x] = a->dp[x - b]; } /* zero the lower digits */ diff --git a/mp_mod_2d.c b/mp_mod_2d.c index 651c79a..a94a314 100644 --- a/mp_mod_2d.c +++ b/mp_mod_2d.c @@ -9,8 +9,11 @@ mp_err mp_mod_2d(const mp_int *a, int b, mp_int *c) int x; mp_err err; - /* if b is <= 0 then zero the int */ - if (b <= 0) { + if (b < 0) { + return MP_VAL; + } + + if (b == 0) { mp_zero(c); return MP_OKAY; } @@ -20,7 +23,6 @@ mp_err mp_mod_2d(const mp_int *a, int b, mp_int *c) return mp_copy(a, c); } - /* copy */ if ((err = mp_copy(a, c)) != MP_OKAY) { return err; } diff --git a/mp_montgomery_calc_normalization.c b/mp_montgomery_calc_normalization.c index 0d0d5c4..cc07799 100644 --- a/mp_montgomery_calc_normalization.c +++ b/mp_montgomery_calc_normalization.c @@ -26,7 +26,6 @@ mp_err mp_montgomery_calc_normalization(mp_int *a, const mp_int *b) bits = 1; } - /* now compute C = A * B mod b */ for (x = bits - 1; x < (int)MP_DIGIT_BIT; x++) { if ((err = mp_mul_2(a, a)) != MP_OKAY) { diff --git a/mp_neg.c b/mp_neg.c index 2fc1854..f54ef3e 100644 --- a/mp_neg.c +++ b/mp_neg.c @@ -6,18 +6,14 @@ /* b = -a */ mp_err mp_neg(const mp_int *a, mp_int *b) { - mp_err err; if (a != b) { + mp_err err; if ((err = mp_copy(a, b)) != MP_OKAY) { return err; } } - if (!mp_iszero(b)) { - b->sign = (a->sign == MP_ZPOS) ? MP_NEG : MP_ZPOS; - } else { - b->sign = MP_ZPOS; - } + b->sign = mp_iszero(b) || b->sign == MP_NEG ? MP_ZPOS : MP_NEG; return MP_OKAY; } diff --git a/mp_prime_rabin_miller_trials.c b/mp_prime_rabin_miller_trials.c index 1728142..9f66f8d 100644 --- a/mp_prime_rabin_miller_trials.c +++ b/mp_prime_rabin_miller_trials.c @@ -36,7 +36,8 @@ int mp_prime_rabin_miller_trials(int size) for (x = 0; x < (int)(sizeof(sizes)/(sizeof(sizes[0]))); x++) { if (sizes[x].k == size) { return sizes[x].t; - } else if (sizes[x].k > size) { + } + if (sizes[x].k > size) { return (x == 0) ? sizes[0].t : sizes[x - 1].t; } } diff --git a/mp_read_radix.c b/mp_read_radix.c index d4a3d1e..df8059a 100644 --- a/mp_read_radix.c +++ b/mp_read_radix.c @@ -31,13 +31,13 @@ mp_err mp_read_radix(mp_int *a, const char *str, int radix) * this allows numbers like 1AB and 1ab to represent the same value * [e.g. in hex] */ - int y; + uint8_t y; char ch = (radix <= 36) ? (char)MP_TOUPPER((int)*str) : *str; unsigned pos = (unsigned)(ch - '('); if (MP_RMAP_REVERSE_SIZE < pos) { break; } - y = (int)s_mp_rmap_reverse[pos]; + y = s_mp_rmap_reverse[pos]; /* if the char was found in the map * and is less than the given radix add it @@ -49,14 +49,14 @@ mp_err mp_read_radix(mp_int *a, const char *str, int radix) if ((err = mp_mul_d(a, (mp_digit)radix, a)) != MP_OKAY) { return err; } - if ((err = mp_add_d(a, (mp_digit)y, a)) != MP_OKAY) { + if ((err = mp_add_d(a, y, a)) != MP_OKAY) { return err; } ++str; } /* if an illegal character was found, fail. */ - if (!((*str == '\0') || (*str == '\r') || (*str == '\n'))) { + if ((*str != '\0') && (*str != '\r') && (*str != '\n')) { return MP_VAL; } diff --git a/mp_reduce.c b/mp_reduce.c index 1b4435c..5226fe7 100644 --- a/mp_reduce.c +++ b/mp_reduce.c @@ -24,19 +24,19 @@ mp_err mp_reduce(mp_int *x, const mp_int *m, const mp_int *mu) /* according to HAC this optimization is ok */ if ((mp_digit)um > ((mp_digit)1 << (MP_DIGIT_BIT - 1))) { if ((err = mp_mul(&q, mu, &q)) != MP_OKAY) { - goto CLEANUP; + goto LBL_ERR; } } else if (MP_HAS(S_MP_MUL_HIGH_DIGS)) { if ((err = s_mp_mul_high_digs(&q, mu, &q, um)) != MP_OKAY) { - goto CLEANUP; + goto LBL_ERR; } } else if (MP_HAS(S_MP_MUL_HIGH_DIGS_FAST)) { if ((err = s_mp_mul_high_digs_fast(&q, mu, &q, um)) != MP_OKAY) { - goto CLEANUP; + goto LBL_ERR; } } else { err = MP_VAL; - goto CLEANUP; + goto LBL_ERR; } /* q3 = q2 / b**(k+1) */ @@ -44,38 +44,38 @@ mp_err mp_reduce(mp_int *x, const mp_int *m, const mp_int *mu) /* x = x mod b**(k+1), quick (no division) */ if ((err = mp_mod_2d(x, MP_DIGIT_BIT * (um + 1), x)) != MP_OKAY) { - goto CLEANUP; + goto LBL_ERR; } /* q = q * m mod b**(k+1), quick (no division) */ if ((err = s_mp_mul_digs(&q, m, &q, um + 1)) != MP_OKAY) { - goto CLEANUP; + goto LBL_ERR; } /* x = x - q */ if ((err = mp_sub(x, &q, x)) != MP_OKAY) { - goto CLEANUP; + goto LBL_ERR; } /* If x < 0, add b**(k+1) to it */ if (mp_cmp_d(x, 0uL) == MP_LT) { mp_set(&q, 1uL); if ((err = mp_lshd(&q, um + 1)) != MP_OKAY) { - goto CLEANUP; + goto LBL_ERR; } if ((err = mp_add(x, &q, x)) != MP_OKAY) { - goto CLEANUP; + goto LBL_ERR; } } /* Back off if it's too big */ while (mp_cmp(x, m) != MP_LT) { if ((err = s_mp_sub(x, m, x)) != MP_OKAY) { - goto CLEANUP; + goto LBL_ERR; } } -CLEANUP: +LBL_ERR: mp_clear(&q); return err; diff --git a/mp_reduce_is_2k.c b/mp_reduce_is_2k.c index 618ab54..a5798e6 100644 --- a/mp_reduce_is_2k.c +++ b/mp_reduce_is_2k.c @@ -6,17 +6,13 @@ /* determines if mp_reduce_2k can be used */ bool mp_reduce_is_2k(const mp_int *a) { - int ix, iy, iw; - mp_digit iz; - if (a->used == 0) { return false; } else if (a->used == 1) { return true; } else if (a->used > 1) { - iy = mp_count_bits(a); - iz = 1; - iw = 1; + int ix, iy = mp_count_bits(a), iw = 1; + mp_digit iz = 1; /* Test every bit from the second digit up, must be 1 */ for (ix = MP_DIGIT_BIT; ix < iy; ix++) { diff --git a/mp_reduce_is_2k_l.c b/mp_reduce_is_2k_l.c index 30fc10d..dca2d7e 100644 --- a/mp_reduce_is_2k_l.c +++ b/mp_reduce_is_2k_l.c @@ -6,14 +6,13 @@ /* determines if reduce_2k_l can be used */ bool mp_reduce_is_2k_l(const mp_int *a) { - int ix, iy; - if (a->used == 0) { return false; } else if (a->used == 1) { return true; } else if (a->used > 1) { /* if more than half of the digits are -1 we're sold */ + int ix, iy; for (iy = ix = 0; ix < a->used; ix++) { if (a->dp[ix] == MP_DIGIT_MAX) { ++iy; diff --git a/mp_rshd.c b/mp_rshd.c index 2eabb12..d798907 100644 --- a/mp_rshd.c +++ b/mp_rshd.c @@ -6,8 +6,7 @@ /* shift right a certain amount of digits */ void mp_rshd(mp_int *a, int b) { - int x; - mp_digit *bottom, *top; + int x; /* if b <= 0 then ignore it */ if (b <= 0) { @@ -20,15 +19,8 @@ void mp_rshd(mp_int *a, int b) return; } - /* shift the digits down */ - - /* bottom */ - bottom = a->dp; - - /* top [offset into digits] */ - top = a->dp + b; - - /* this is implemented as a sliding window where + /* shift the digits down. + * this is implemented as a sliding window where * the window is b-digits long and digits from * the top of the window are copied to the bottom * @@ -39,11 +31,11 @@ void mp_rshd(mp_int *a, int b) \-------------------/ ----> */ for (x = 0; x < (a->used - b); x++) { - *bottom++ = *top++; + a->dp[x] = a->dp[x + b]; } /* zero the top digits */ - MP_ZERO_DIGITS(bottom, a->used - x); + MP_ZERO_DIGITS(a->dp + a->used - b, b); /* remove excess digits */ a->used -= b; diff --git a/mp_set.c b/mp_set.c index 0777f09..3ee5f81 100644 --- a/mp_set.c +++ b/mp_set.c @@ -6,9 +6,10 @@ /* set to a digit */ void mp_set(mp_int *a, mp_digit b) { + int oldused = a->used; a->dp[0] = b & MP_MASK; a->sign = MP_ZPOS; a->used = (a->dp[0] != 0u) ? 1 : 0; - MP_ZERO_DIGITS(a->dp + a->used, a->alloc - a->used); + MP_ZERO_DIGITS(a->dp + a->used, oldused - a->used); } #endif diff --git a/mp_shrink.c b/mp_shrink.c index 6c3c95b..e5814cb 100644 --- a/mp_shrink.c +++ b/mp_shrink.c @@ -6,15 +6,15 @@ /* shrink a bignum */ mp_err mp_shrink(mp_int *a) { - mp_digit *tmp; int alloc = MP_MAX(MP_MIN_PREC, a->used); if (a->alloc != alloc) { - if ((tmp = (mp_digit *) MP_REALLOC(a->dp, - (size_t)a->alloc * sizeof(mp_digit), - (size_t)alloc * sizeof(mp_digit))) == NULL) { + mp_digit *dp = (mp_digit *) MP_REALLOC(a->dp, + (size_t)a->alloc * sizeof(mp_digit), + (size_t)alloc * sizeof(mp_digit)); + if (dp == NULL) { return MP_MEM; } - a->dp = tmp; + a->dp = dp; a->alloc = alloc; } return MP_OKAY; diff --git a/mp_signed_rsh.c b/mp_signed_rsh.c index c56dfba..ecaaa21 100644 --- a/mp_signed_rsh.c +++ b/mp_signed_rsh.c @@ -6,17 +6,16 @@ /* shift right by a certain bit count with sign extension */ mp_err mp_signed_rsh(const mp_int *a, int b, mp_int *c) { - mp_err res; + mp_err err; if (a->sign == MP_ZPOS) { return mp_div_2d(a, b, c, NULL); } - res = mp_add_d(a, 1uL, c); - if (res != MP_OKAY) { - return res; + if ((err = mp_add_d(a, 1uL, c)) != MP_OKAY) { + return err; } - res = mp_div_2d(c, b, c, NULL); - return (res == MP_OKAY) ? mp_sub_d(c, 1uL, c) : res; + err = mp_div_2d(c, b, c, NULL); + return (err == MP_OKAY) ? mp_sub_d(c, 1uL, c) : err; } #endif diff --git a/mp_sqrt.c b/mp_sqrt.c index b51f615..e36a81a 100644 --- a/mp_sqrt.c +++ b/mp_sqrt.c @@ -25,7 +25,7 @@ mp_err mp_sqrt(const mp_int *arg, mp_int *ret) } if ((err = mp_init(&t2)) != MP_OKAY) { - goto E2; + goto LBL_ERR2; } /* First approx. (not very bad for large arg) */ @@ -33,33 +33,33 @@ mp_err mp_sqrt(const mp_int *arg, mp_int *ret) /* t1 > 0 */ if ((err = mp_div(arg, &t1, &t2, NULL)) != MP_OKAY) { - goto E1; + goto LBL_ERR1; } if ((err = mp_add(&t1, &t2, &t1)) != MP_OKAY) { - goto E1; + goto LBL_ERR1; } if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) { - goto E1; + goto LBL_ERR1; } /* And now t1 > sqrt(arg) */ do { if ((err = mp_div(arg, &t1, &t2, NULL)) != MP_OKAY) { - goto E1; + goto LBL_ERR1; } if ((err = mp_add(&t1, &t2, &t1)) != MP_OKAY) { - goto E1; + goto LBL_ERR1; } if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) { - goto E1; + goto LBL_ERR1; } /* t1 >= sqrt(arg) >= t2 at this point */ } while (mp_cmp_mag(&t1, &t2) == MP_GT); mp_exch(&t1, ret); -E1: +LBL_ERR1: mp_clear(&t2); -E2: +LBL_ERR2: mp_clear(&t1); return err; } diff --git a/mp_sqrtmod_prime.c b/mp_sqrtmod_prime.c index 96b2836..8930184 100644 --- a/mp_sqrtmod_prime.c +++ b/mp_sqrtmod_prime.c @@ -33,28 +33,28 @@ mp_err mp_sqrtmod_prime(const mp_int *n, const mp_int *prime, mp_int *ret) * compute directly: err = n^(prime+1)/4 mod prime * Handbook of Applied Cryptography algorithm 3.36 */ - if ((err = mp_mod_d(prime, 4uL, &i)) != MP_OKAY) goto cleanup; + if ((err = mp_mod_d(prime, 4uL, &i)) != MP_OKAY) goto LBL_END; if (i == 3u) { - if ((err = mp_add_d(prime, 1uL, &t1)) != MP_OKAY) goto cleanup; - if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) goto cleanup; - if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) goto cleanup; - if ((err = mp_exptmod(n, &t1, prime, ret)) != MP_OKAY) goto cleanup; + if ((err = mp_add_d(prime, 1uL, &t1)) != MP_OKAY) goto LBL_END; + if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) goto LBL_END; + if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) goto LBL_END; + if ((err = mp_exptmod(n, &t1, prime, ret)) != MP_OKAY) goto LBL_END; err = MP_OKAY; - goto cleanup; + goto LBL_END; } /* NOW: Tonelli-Shanks algorithm */ /* factor out powers of 2 from prime-1, defining Q and S as: prime-1 = Q*2^S */ - if ((err = mp_copy(prime, &Q)) != MP_OKAY) goto cleanup; - if ((err = mp_sub_d(&Q, 1uL, &Q)) != MP_OKAY) goto cleanup; + if ((err = mp_copy(prime, &Q)) != MP_OKAY) goto LBL_END; + if ((err = mp_sub_d(&Q, 1uL, &Q)) != MP_OKAY) goto LBL_END; /* Q = prime - 1 */ mp_zero(&S); /* S = 0 */ while (mp_iseven(&Q)) { - if ((err = mp_div_2(&Q, &Q)) != MP_OKAY) goto cleanup; + if ((err = mp_div_2(&Q, &Q)) != MP_OKAY) goto LBL_END; /* Q = Q / 2 */ - if ((err = mp_add_d(&S, 1uL, &S)) != MP_OKAY) goto cleanup; + if ((err = mp_add_d(&S, 1uL, &S)) != MP_OKAY) goto LBL_END; /* S = S + 1 */ } @@ -62,55 +62,55 @@ mp_err mp_sqrtmod_prime(const mp_int *n, const mp_int *prime, mp_int *ret) mp_set(&Z, 2uL); /* Z = 2 */ for (;;) { - if ((err = mp_kronecker(&Z, prime, &legendre)) != MP_OKAY) goto cleanup; + if ((err = mp_kronecker(&Z, prime, &legendre)) != MP_OKAY) goto LBL_END; if (legendre == -1) break; - if ((err = mp_add_d(&Z, 1uL, &Z)) != MP_OKAY) goto cleanup; + if ((err = mp_add_d(&Z, 1uL, &Z)) != MP_OKAY) goto LBL_END; /* Z = Z + 1 */ } - if ((err = mp_exptmod(&Z, &Q, prime, &C)) != MP_OKAY) goto cleanup; + if ((err = mp_exptmod(&Z, &Q, prime, &C)) != MP_OKAY) goto LBL_END; /* C = Z ^ Q mod prime */ - if ((err = mp_add_d(&Q, 1uL, &t1)) != MP_OKAY) goto cleanup; - if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) goto cleanup; + if ((err = mp_add_d(&Q, 1uL, &t1)) != MP_OKAY) goto LBL_END; + if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) goto LBL_END; /* t1 = (Q + 1) / 2 */ - if ((err = mp_exptmod(n, &t1, prime, &R)) != MP_OKAY) goto cleanup; + if ((err = mp_exptmod(n, &t1, prime, &R)) != MP_OKAY) goto LBL_END; /* R = n ^ ((Q + 1) / 2) mod prime */ - if ((err = mp_exptmod(n, &Q, prime, &T)) != MP_OKAY) goto cleanup; + if ((err = mp_exptmod(n, &Q, prime, &T)) != MP_OKAY) goto LBL_END; /* T = n ^ Q mod prime */ - if ((err = mp_copy(&S, &M)) != MP_OKAY) goto cleanup; + if ((err = mp_copy(&S, &M)) != MP_OKAY) goto LBL_END; /* M = S */ mp_set(&two, 2uL); for (;;) { - if ((err = mp_copy(&T, &t1)) != MP_OKAY) goto cleanup; + if ((err = mp_copy(&T, &t1)) != MP_OKAY) goto LBL_END; i = 0; for (;;) { if (mp_cmp_d(&t1, 1uL) == MP_EQ) break; - if ((err = mp_exptmod(&t1, &two, prime, &t1)) != MP_OKAY) goto cleanup; + if ((err = mp_exptmod(&t1, &two, prime, &t1)) != MP_OKAY) goto LBL_END; i++; } if (i == 0u) { - if ((err = mp_copy(&R, ret)) != MP_OKAY) goto cleanup; + if ((err = mp_copy(&R, ret)) != MP_OKAY) goto LBL_END; err = MP_OKAY; - goto cleanup; + goto LBL_END; } - if ((err = mp_sub_d(&M, i, &t1)) != MP_OKAY) goto cleanup; - if ((err = mp_sub_d(&t1, 1uL, &t1)) != MP_OKAY) goto cleanup; - if ((err = mp_exptmod(&two, &t1, prime, &t1)) != MP_OKAY) goto cleanup; + if ((err = mp_sub_d(&M, i, &t1)) != MP_OKAY) goto LBL_END; + if ((err = mp_sub_d(&t1, 1uL, &t1)) != MP_OKAY) goto LBL_END; + if ((err = mp_exptmod(&two, &t1, prime, &t1)) != MP_OKAY) goto LBL_END; /* t1 = 2 ^ (M - i - 1) */ - if ((err = mp_exptmod(&C, &t1, prime, &t1)) != MP_OKAY) goto cleanup; + if ((err = mp_exptmod(&C, &t1, prime, &t1)) != MP_OKAY) goto LBL_END; /* t1 = C ^ (2 ^ (M - i - 1)) mod prime */ - if ((err = mp_sqrmod(&t1, prime, &C)) != MP_OKAY) goto cleanup; + if ((err = mp_sqrmod(&t1, prime, &C)) != MP_OKAY) goto LBL_END; /* C = (t1 * t1) mod prime */ - if ((err = mp_mulmod(&R, &t1, prime, &R)) != MP_OKAY) goto cleanup; + if ((err = mp_mulmod(&R, &t1, prime, &R)) != MP_OKAY) goto LBL_END; /* R = (R * t1) mod prime */ - if ((err = mp_mulmod(&T, &C, prime, &T)) != MP_OKAY) goto cleanup; + if ((err = mp_mulmod(&T, &C, prime, &T)) != MP_OKAY) goto LBL_END; /* T = (T * C) mod prime */ mp_set(&M, i); /* M = i */ } -cleanup: +LBL_END: mp_clear_multi(&t1, &C, &Q, &S, &Z, &M, &T, &R, &two, NULL); return err; } diff --git a/mp_sub.c b/mp_sub.c index c859026..8104740 100644 --- a/mp_sub.c +++ b/mp_sub.c @@ -6,35 +6,31 @@ /* high level subtraction (handles signs) */ mp_err mp_sub(const mp_int *a, const mp_int *b, mp_int *c) { - mp_sign sa = a->sign, sb = b->sign; - mp_err err; - - if (sa != sb) { + if (a->sign != b->sign) { /* subtract a negative from a positive, OR */ /* subtract a positive from a negative. */ /* In either case, ADD their magnitudes, */ /* and use the sign of the first number. */ - c->sign = sa; - err = s_mp_add(a, b, c); - } else { - /* subtract a positive from a positive, OR */ - /* subtract a negative from a negative. */ - /* First, take the difference between their */ - /* magnitudes, then... */ - if (mp_cmp_mag(a, b) != MP_LT) { - /* Copy the sign from the first */ - c->sign = sa; - /* The first has a larger or equal magnitude */ - err = s_mp_sub(a, b, c); - } else { - /* The result has the *opposite* sign from */ - /* the first number. */ - c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS; - /* The second has a larger magnitude */ - err = s_mp_sub(b, a, c); - } + c->sign = a->sign; + return s_mp_add(a, b, c); } - return err; + + /* subtract a positive from a positive, OR */ + /* subtract a negative from a negative. */ + /* First, take the difference between their */ + /* magnitudes, then... */ + if (mp_cmp_mag(a, b) == MP_LT) { + /* The second has a larger magnitude */ + /* The result has the *opposite* sign from */ + /* the first number. */ + c->sign = (a->sign == MP_ZPOS) ? MP_NEG : MP_ZPOS; + MP_EXCH(const mp_int *, a, b); + } else { + /* The first has a larger or equal magnitude */ + /* Copy the sign from the first */ + c->sign = a->sign; + } + return s_mp_sub(a, b, c); } #endif diff --git a/mp_to_radix.c b/mp_to_radix.c index c1ea233..8b7728d 100644 --- a/mp_to_radix.c +++ b/mp_to_radix.c @@ -4,17 +4,11 @@ /* SPDX-License-Identifier: Unlicense */ /* reverse an array, used for radix code */ -static void s_mp_reverse(uint8_t *s, size_t len) +static void s_mp_reverse(char *s, size_t len) { - size_t ix, iy; - uint8_t t; - - ix = 0u; - iy = len - 1u; + size_t ix = 0, iy = len - 1u; while (ix < iy) { - t = s[ix]; - s[ix] = s[iy]; - s[iy] = t; + MP_EXCH(char, s[ix], s[iy]); ++ix; --iy; } @@ -83,7 +77,7 @@ mp_err mp_to_radix(const mp_int *a, char *str, size_t maxlen, size_t *written, i /* reverse the digits of the string. In this case _s points * to the first digit [exluding the sign] of the number */ - s_mp_reverse((uint8_t *)_s, digs); + s_mp_reverse(_s, digs); /* append a NULL so the string is properly terminated */ *str = '\0'; diff --git a/mp_zero.c b/mp_zero.c index 0b79f50..b7dddd2 100644 --- a/mp_zero.c +++ b/mp_zero.c @@ -7,7 +7,7 @@ void mp_zero(mp_int *a) { a->sign = MP_ZPOS; + MP_ZERO_DIGITS(a->dp, a->used); a->used = 0; - MP_ZERO_DIGITS(a->dp, a->alloc); } #endif diff --git a/s_mp_div_recursive.c b/s_mp_div_recursive.c index d641123..7007aef 100644 --- a/s_mp_div_recursive.c +++ b/s_mp_div_recursive.c @@ -17,10 +17,9 @@ static mp_err s_mp_recursion(const mp_int *a, const mp_int *b, mp_int *q, mp_int *r) { mp_err err; - int m, k; mp_int A1, A2, B1, B0, Q1, Q0, R1, R0, t; + int m = a->used - b->used, k = m/2; - m = a->used - b->used; if (m < MP_KARATSUBA_MUL_CUTOFF) { return s_mp_div_school(a, b, q, r); } @@ -29,9 +28,6 @@ static mp_err s_mp_recursion(const mp_int *a, const mp_int *b, mp_int *q, mp_int goto LBL_ERR; } - /* k = floor(m/2) */ - k = m/2; - /* B1 = b / beta^k, B0 = b % beta^k*/ if ((err = mp_div_2d(b, k * MP_DIGIT_BIT, &B1, &B0)) != MP_OKAY) goto LBL_ERR; diff --git a/s_mp_div_school.c b/s_mp_div_school.c index 6ff427a..cf34cc9 100644 --- a/s_mp_div_school.c +++ b/s_mp_div_school.c @@ -140,8 +140,6 @@ mp_err s_mp_div_school(const mp_int *a, const mp_int *b, mp_int *c, mp_int *d) mp_exch(&x, d); } - err = MP_OKAY; - LBL_Y: mp_clear(&y); LBL_X: diff --git a/tommath_private.h b/tommath_private.h index 0f5ac93..31f1ea5 100644 --- a/tommath_private.h +++ b/tommath_private.h @@ -148,6 +148,8 @@ extern void MP_FREE(void *mem, size_t size); #define MP_TOUPPER(c) ((((c) >= 'a') && ((c) <= 'z')) ? (((c) + 'A') - 'a') : (c)) +#define MP_EXCH(t, a, b) do { t _c = a; a = b; b = _c; } while (0) + /* Static assertion */ #define MP_STATIC_ASSERT(msg, cond) typedef char mp_static_assert_##msg[(cond) ? 1 : -1]; @@ -267,9 +269,9 @@ extern MP_PRIVATE const mp_digit s_mp_prime_tab[]; #define MP_GET_MAG(name, type) \ type name(const mp_int* a) \ { \ - unsigned i = MP_MIN((unsigned)a->used, (unsigned)((MP_SIZEOF_BITS(type) + MP_DIGIT_BIT - 1) / MP_DIGIT_BIT)); \ + int i = MP_MIN(a->used, (int)((MP_SIZEOF_BITS(type) + MP_DIGIT_BIT - 1) / MP_DIGIT_BIT)); \ type res = 0u; \ - while (i --> 0u) { \ + while (i --> 0) { \ res <<= ((MP_SIZEOF_BITS(type) <= MP_DIGIT_BIT) ? 0 : MP_DIGIT_BIT); \ res |= (type)a->dp[i]; \ if (MP_SIZEOF_BITS(type) <= MP_DIGIT_BIT) { break; } \