first batch of simplifications

This commit is contained in:
Daniel Mendler 2019-10-29 18:41:25 +01:00
parent b9977adfb8
commit 3cdcec43e6
No known key found for this signature in database
GPG Key ID: D88ADB2A2693CA43
40 changed files with 202 additions and 293 deletions

View File

@ -9,10 +9,9 @@
*/
mp_err mp_abs(const mp_int *a, mp_int *b)
{
mp_err err;
/* copy a to b */
if (a != b) {
mp_err err;
if ((err = mp_copy(a, b)) != MP_OKAY) {
return err;
}

View File

@ -6,33 +6,24 @@
/* high level addition (handles signs) */
mp_err mp_add(const mp_int *a, const mp_int *b, mp_int *c)
{
mp_sign sa, sb;
mp_err err;
/* get sign of both inputs */
sa = a->sign;
sb = b->sign;
/* handle two cases, not four */
if (sa == sb) {
if (a->sign == b->sign) {
/* both positive or both negative */
/* add their magnitudes, copy the sign */
c->sign = sa;
err = s_mp_add(a, b, c);
} else {
/* one positive, the other negative */
/* subtract the one with the greater magnitude from */
/* the one of the lesser magnitude. The result gets */
/* the sign of the one with the greater magnitude. */
if (mp_cmp_mag(a, b) == MP_LT) {
c->sign = sb;
err = s_mp_sub(b, a, c);
} else {
c->sign = sa;
err = s_mp_sub(a, b, c);
}
c->sign = a->sign;
return s_mp_add(a, b, c);
}
return err;
/* one positive, the other negative */
/* subtract the one with the greater magnitude from */
/* the one of the lesser magnitude. The result gets */
/* the sign of the one with the greater magnitude. */
if (mp_cmp_mag(a, b) == MP_LT) {
MP_EXCH(const mp_int *, a, b);
}
c->sign = a->sign;
return s_mp_sub(a, b, c);
}
#endif

View File

@ -7,12 +7,11 @@
void mp_clear_multi(mp_int *mp, ...)
{
mp_int *next_mp = mp;
va_list args;
va_start(args, mp);
while (next_mp != NULL) {
mp_clear(next_mp);
next_mp = va_arg(args, mp_int *);
while (mp != NULL) {
mp_clear(mp);
mp = va_arg(args, mp_int *);
}
va_end(args);
}

View File

@ -8,19 +8,14 @@ mp_ord mp_cmp(const mp_int *a, const mp_int *b)
{
/* compare based on sign */
if (a->sign != b->sign) {
if (a->sign == MP_NEG) {
return MP_LT;
} else {
return MP_GT;
}
return a->sign == MP_NEG ? MP_LT : MP_GT;
}
/* compare digits */
/* if negative compare opposite direction */
if (a->sign == MP_NEG) {
/* if negative compare opposite direction */
return mp_cmp_mag(b, a);
} else {
return mp_cmp_mag(a, b);
MP_EXCH(const mp_int *, a, b);
}
return mp_cmp_mag(a, b);
}
#endif

View File

@ -17,12 +17,10 @@ mp_ord mp_cmp_d(const mp_int *a, mp_digit b)
}
/* compare the only digit of a to b */
if (a->dp[0] > b) {
return MP_GT;
} else if (a->dp[0] < b) {
return MP_LT;
} else {
return MP_EQ;
if (a->dp[0] != b) {
return a->dp[0] > b ? MP_GT : MP_LT;
}
return MP_EQ;
}
#endif

View File

@ -6,34 +6,20 @@
/* compare maginitude of two ints (unsigned) */
mp_ord mp_cmp_mag(const mp_int *a, const mp_int *b)
{
int n;
const mp_digit *tmpa, *tmpb;
int n;
/* compare based on # of non-zero digits */
if (a->used > b->used) {
return MP_GT;
if (a->used != b->used) {
return a->used > b->used ? MP_GT : MP_LT;
}
if (a->used < b->used) {
return MP_LT;
}
/* alias for a */
tmpa = a->dp + (a->used - 1);
/* alias for b */
tmpb = b->dp + (a->used - 1);
/* compare based on digits */
for (n = 0; n < a->used; ++n, --tmpa, --tmpb) {
if (*tmpa > *tmpb) {
return MP_GT;
}
if (*tmpa < *tmpb) {
return MP_LT;
for (n = a->used; n --> 0;) {
if (a->dp[n] != b->dp[n]) {
return a->dp[n] > b->dp[n] ? MP_GT : MP_LT;
}
}
return MP_EQ;
}
#endif

View File

@ -3,7 +3,7 @@
/* LibTomMath, multiple-precision integer library -- Tom St Denis */
/* SPDX-License-Identifier: Unlicense */
static const int lnz[16] = {
static const char lnz[16] = {
4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
};
@ -11,7 +11,7 @@ static const int lnz[16] = {
int mp_cnt_lsb(const mp_int *a)
{
int x;
mp_digit q, qq;
mp_digit q;
/* easy out */
if (mp_iszero(a)) {
@ -25,11 +25,12 @@ int mp_cnt_lsb(const mp_int *a)
/* now scan this digit until a 1 is found */
if ((q & 1u) == 0u) {
mp_digit p;
do {
qq = q & 15u;
x += lnz[qq];
p = q & 15u;
x += lnz[p];
q >>= 4;
} while (qq == 0u);
} while (p == 0u);
}
return x;
}

View File

@ -7,8 +7,6 @@
mp_err mp_copy(const mp_int *a, mp_int *b)
{
int n;
mp_digit *tmpa, *tmpb;
mp_err err;
/* if dst == src do nothing */
if (a == b) {
@ -17,27 +15,21 @@ mp_err mp_copy(const mp_int *a, mp_int *b)
/* grow dest */
if (b->alloc < a->used) {
mp_err err;
if ((err = mp_grow(b, a->used)) != MP_OKAY) {
return err;
}
}
/* zero b and copy the parameters over */
/* pointer aliases */
/* source */
tmpa = a->dp;
/* destination */
tmpb = b->dp;
/* copy all the digits */
for (n = 0; n < a->used; n++) {
*tmpb++ = *tmpa++;
b->dp[n] = a->dp[n];
}
/* clear high digits */
MP_ZERO_DIGITS(tmpb, b->used - n);
MP_ZERO_DIGITS(b->dp + a->used, b->used - a->used);
/* copy used count and sign */
b->used = a->used;

View File

@ -15,14 +15,14 @@ mp_err mp_div(const mp_int *a, const mp_int *b, mp_int *c, mp_int *d)
/* if a < b then q = 0, r = a */
if (mp_cmp_mag(a, b) == MP_LT) {
if (d != NULL) {
err = mp_copy(a, d);
} else {
err = MP_OKAY;
if ((err = mp_copy(a, d)) != MP_OKAY) {
return err;
}
}
if (c != NULL) {
mp_zero(c);
}
return err;
return MP_OKAY;
}
if (MP_HAS(S_MP_DIV_RECURSIVE)
@ -31,11 +31,12 @@ mp_err mp_div(const mp_int *a, const mp_int *b, mp_int *c, mp_int *d)
err = s_mp_div_recursive(a, b, c, d);
} else if (MP_HAS(S_MP_DIV_SCHOOL)) {
err = s_mp_div_school(a, b, c, d);
} else {
} else if (MP_HAS(S_MP_DIV_SMALL)) {
err = s_mp_div_small(a, b, c, d);
} else {
err = MP_VAL;
}
return err;
}
#endif

View File

@ -7,7 +7,7 @@
mp_err mp_div_3(const mp_int *a, mp_int *c, mp_digit *d)
{
mp_int q;
mp_word w, t;
mp_word w;
mp_digit b;
mp_err err;
int ix;
@ -22,7 +22,8 @@ mp_err mp_div_3(const mp_int *a, mp_int *c, mp_digit *d)
q.used = a->used;
q.sign = a->sign;
w = 0;
for (ix = a->used - 1; ix >= 0; ix--) {
for (ix = a->used; ix --> 0;) {
mp_word t;
w = (w << (mp_word)MP_DIGIT_BIT) | (mp_word)a->dp[ix];
if (w >= 3u) {
@ -57,7 +58,7 @@ mp_err mp_div_3(const mp_int *a, mp_int *c, mp_digit *d)
}
mp_clear(&q);
return err;
return MP_OKAY;
}
#endif

View File

@ -8,10 +8,6 @@
*/
void mp_exch(mp_int *a, mp_int *b)
{
mp_int t;
t = *a;
*a = *b;
*b = t;
MP_EXCH(mp_int, *a, *b);
}
#endif

View File

@ -64,13 +64,15 @@ LBL_ERR:
/* if the modulus is odd or dr != 0 use the montgomery method */
if (MP_HAS(S_MP_EXPTMOD_FAST) && (mp_isodd(P) || (dr != 0))) {
return s_mp_exptmod_fast(G, X, P, Y, dr);
} else if (MP_HAS(S_MP_EXPTMOD)) {
/* otherwise use the generic Barrett reduction technique */
return s_mp_exptmod(G, X, P, Y, 0);
} else {
/* no exptmod for evens */
return MP_VAL;
}
/* otherwise use the generic Barrett reduction technique */
if (MP_HAS(S_MP_EXPTMOD)) {
return s_mp_exptmod(G, X, P, Y, 0);
}
/* no exptmod for evens */
return MP_VAL;
}
#endif

View File

@ -65,7 +65,6 @@ mp_err mp_exteuclid(const mp_int *a, const mp_int *b, mp_int *U1, mp_int *U2, mp
mp_exch(U3, &u3);
}
err = MP_OKAY;
LBL_ERR:
mp_clear_multi(&u1, &u2, &u3, &v1, &v2, &v3, &t1, &t2, &t3, &q, &tmp, NULL);
return err;

View File

@ -32,7 +32,7 @@ mp_err mp_fread(mp_int *a, int radix, FILE *stream)
mp_zero(a);
do {
int y;
uint8_t y;
unsigned pos;
ch = (radix <= 36) ? MP_TOUPPER(ch) : ch;
pos = (unsigned)(ch - (int)'(');
@ -40,7 +40,7 @@ mp_err mp_fread(mp_int *a, int radix, FILE *stream)
break;
}
y = (int)s_mp_rmap_reverse[pos];
y = s_mp_rmap_reverse[pos];
if ((y == 0xff) || (y >= radix)) {
break;
@ -50,7 +50,7 @@ mp_err mp_fread(mp_int *a, int radix, FILE *stream)
if ((err = mp_mul_d(a, (mp_digit)radix, a)) != MP_OKAY) {
return err;
}
if ((err = mp_add_d(a, (mp_digit)y, a)) != MP_OKAY) {
if ((err = mp_add_d(a, y, a)) != MP_OKAY) {
return err;
}
} while ((ch = fgetc(stream)) != EOF);

View File

@ -14,11 +14,7 @@ mp_err mp_from_sbin(mp_int *a, const uint8_t *buf, size_t size)
}
/* first byte is 0 for positive, non-zero for negative */
if (buf[0] == (uint8_t)0) {
a->sign = MP_ZPOS;
} else {
a->sign = MP_NEG;
}
a->sign = (buf[0] == (uint8_t)0) ? MP_ZPOS : MP_NEG;
return MP_OKAY;
}

View File

@ -8,31 +8,24 @@ mp_err mp_fwrite(const mp_int *a, int radix, FILE *stream)
{
char *buf;
mp_err err;
size_t len, written;
size_t size, written;
/* TODO: this function is not in this PR */
if ((err = mp_radix_size(a, radix, &len)) != MP_OKAY) {
if ((err = mp_radix_size(a, radix, &size)) != MP_OKAY) {
return err;
}
buf = (char *) MP_MALLOC(len);
buf = (char *) MP_MALLOC(size);
if (buf == NULL) {
return MP_MEM;
}
if ((err = mp_to_radix(a, buf, len, &written, radix)) != MP_OKAY) {
goto LBL_ERR;
if ((err = mp_to_radix(a, buf, size, &written, radix)) == MP_OKAY) {
if (fwrite(buf, written, 1uL, stream) != 1uL) {
err = MP_ERR;
}
}
if (fwrite(buf, written, 1uL, stream) != 1uL) {
err = MP_ERR;
goto LBL_ERR;
}
err = MP_OKAY;
LBL_ERR:
MP_FREE_BUFFER(buf, len);
MP_FREE_BUFFER(buf, size);
return err;
}
#endif

View File

@ -6,9 +6,6 @@
/* grow as required */
mp_err mp_grow(mp_int *a, int size)
{
int i;
mp_digit *tmp;
/* if the alloc size is smaller alloc more ram */
if (a->alloc < size) {
/* reallocate the array a->dp
@ -17,21 +14,20 @@ mp_err mp_grow(mp_int *a, int size)
* in case the operation failed we don't want
* to overwrite the dp member of a.
*/
tmp = (mp_digit *) MP_REALLOC(a->dp,
(size_t)a->alloc * sizeof(mp_digit),
(size_t)size * sizeof(mp_digit));
if (tmp == NULL) {
mp_digit *dp = (mp_digit *) MP_REALLOC(a->dp,
(size_t)a->alloc * sizeof(mp_digit),
(size_t)size * sizeof(mp_digit));
if (dp == NULL) {
/* reallocation failed but "a" is still valid [can be freed] */
return MP_MEM;
}
/* reallocation succeeded so set a->dp */
a->dp = tmp;
a->dp = dp;
/* zero excess digits */
i = a->alloc;
MP_ZERO_DIGITS(a->dp + a->alloc, size - a->alloc);
a->alloc = size;
MP_ZERO_DIGITS(a->dp + i, a->alloc - i);
}
return MP_OKAY;
}

View File

@ -28,10 +28,10 @@ static const char rem_105[105] = {
/* Store non-zero to ret if arg is square, and zero if not */
mp_err mp_is_square(const mp_int *arg, bool *ret)
{
mp_err err;
mp_digit c;
mp_int t;
unsigned long r;
mp_err err;
mp_digit c;
mp_int t;
uint32_t r;
/* Default to Non-square :) */
*ret = false;

View File

@ -23,7 +23,7 @@ mp_err mp_kronecker(const mp_int *a, const mp_int *p, int *c)
mp_err err;
int v, k;
static const int table[8] = {0, 1, 0, -1, 0, -1, 0, 1};
static const char table[] = {0, 1, 0, -1, 0, -1, 0, 1};
if (mp_iszero(p)) {
if ((a->used == 1) && (a->dp[0] == 1u)) {

View File

@ -7,8 +7,6 @@
mp_err mp_lshd(mp_int *a, int b)
{
int x;
mp_err err;
mp_digit *top, *bottom;
/* if its less than zero return */
if (b <= 0) {
@ -21,6 +19,7 @@ mp_err mp_lshd(mp_int *a, int b)
/* grow to fit the new digits */
if (a->alloc < (a->used + b)) {
mp_err err;
if ((err = mp_grow(a, a->used + b)) != MP_OKAY) {
return err;
}
@ -29,18 +28,12 @@ mp_err mp_lshd(mp_int *a, int b)
/* increment the used by the shift amount then copy upwards */
a->used += b;
/* top */
top = a->dp + a->used - 1;
/* base */
bottom = (a->dp + a->used - 1) - b;
/* much like mp_rshd this is implemented using a sliding window
* except the window goes the otherway around. Copying from
* the bottom to the top. see mp_rshd.c for more info.
*/
for (x = a->used - 1; x >= b; x--) {
*top-- = *bottom--;
for (x = a->used; x --> b;) {
a->dp[x] = a->dp[x - b];
}
/* zero the lower digits */

View File

@ -9,8 +9,11 @@ mp_err mp_mod_2d(const mp_int *a, int b, mp_int *c)
int x;
mp_err err;
/* if b is <= 0 then zero the int */
if (b <= 0) {
if (b < 0) {
return MP_VAL;
}
if (b == 0) {
mp_zero(c);
return MP_OKAY;
}
@ -20,7 +23,6 @@ mp_err mp_mod_2d(const mp_int *a, int b, mp_int *c)
return mp_copy(a, c);
}
/* copy */
if ((err = mp_copy(a, c)) != MP_OKAY) {
return err;
}

View File

@ -26,7 +26,6 @@ mp_err mp_montgomery_calc_normalization(mp_int *a, const mp_int *b)
bits = 1;
}
/* now compute C = A * B mod b */
for (x = bits - 1; x < (int)MP_DIGIT_BIT; x++) {
if ((err = mp_mul_2(a, a)) != MP_OKAY) {

View File

@ -6,18 +6,14 @@
/* b = -a */
mp_err mp_neg(const mp_int *a, mp_int *b)
{
mp_err err;
if (a != b) {
mp_err err;
if ((err = mp_copy(a, b)) != MP_OKAY) {
return err;
}
}
if (!mp_iszero(b)) {
b->sign = (a->sign == MP_ZPOS) ? MP_NEG : MP_ZPOS;
} else {
b->sign = MP_ZPOS;
}
b->sign = mp_iszero(b) || b->sign == MP_NEG ? MP_ZPOS : MP_NEG;
return MP_OKAY;
}

View File

@ -36,7 +36,8 @@ int mp_prime_rabin_miller_trials(int size)
for (x = 0; x < (int)(sizeof(sizes)/(sizeof(sizes[0]))); x++) {
if (sizes[x].k == size) {
return sizes[x].t;
} else if (sizes[x].k > size) {
}
if (sizes[x].k > size) {
return (x == 0) ? sizes[0].t : sizes[x - 1].t;
}
}

View File

@ -31,13 +31,13 @@ mp_err mp_read_radix(mp_int *a, const char *str, int radix)
* this allows numbers like 1AB and 1ab to represent the same value
* [e.g. in hex]
*/
int y;
uint8_t y;
char ch = (radix <= 36) ? (char)MP_TOUPPER((int)*str) : *str;
unsigned pos = (unsigned)(ch - '(');
if (MP_RMAP_REVERSE_SIZE < pos) {
break;
}
y = (int)s_mp_rmap_reverse[pos];
y = s_mp_rmap_reverse[pos];
/* if the char was found in the map
* and is less than the given radix add it
@ -49,14 +49,14 @@ mp_err mp_read_radix(mp_int *a, const char *str, int radix)
if ((err = mp_mul_d(a, (mp_digit)radix, a)) != MP_OKAY) {
return err;
}
if ((err = mp_add_d(a, (mp_digit)y, a)) != MP_OKAY) {
if ((err = mp_add_d(a, y, a)) != MP_OKAY) {
return err;
}
++str;
}
/* if an illegal character was found, fail. */
if (!((*str == '\0') || (*str == '\r') || (*str == '\n'))) {
if ((*str != '\0') && (*str != '\r') && (*str != '\n')) {
return MP_VAL;
}

View File

@ -24,19 +24,19 @@ mp_err mp_reduce(mp_int *x, const mp_int *m, const mp_int *mu)
/* according to HAC this optimization is ok */
if ((mp_digit)um > ((mp_digit)1 << (MP_DIGIT_BIT - 1))) {
if ((err = mp_mul(&q, mu, &q)) != MP_OKAY) {
goto CLEANUP;
goto LBL_ERR;
}
} else if (MP_HAS(S_MP_MUL_HIGH_DIGS)) {
if ((err = s_mp_mul_high_digs(&q, mu, &q, um)) != MP_OKAY) {
goto CLEANUP;
goto LBL_ERR;
}
} else if (MP_HAS(S_MP_MUL_HIGH_DIGS_FAST)) {
if ((err = s_mp_mul_high_digs_fast(&q, mu, &q, um)) != MP_OKAY) {
goto CLEANUP;
goto LBL_ERR;
}
} else {
err = MP_VAL;
goto CLEANUP;
goto LBL_ERR;
}
/* q3 = q2 / b**(k+1) */
@ -44,38 +44,38 @@ mp_err mp_reduce(mp_int *x, const mp_int *m, const mp_int *mu)
/* x = x mod b**(k+1), quick (no division) */
if ((err = mp_mod_2d(x, MP_DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
goto CLEANUP;
goto LBL_ERR;
}
/* q = q * m mod b**(k+1), quick (no division) */
if ((err = s_mp_mul_digs(&q, m, &q, um + 1)) != MP_OKAY) {
goto CLEANUP;
goto LBL_ERR;
}
/* x = x - q */
if ((err = mp_sub(x, &q, x)) != MP_OKAY) {
goto CLEANUP;
goto LBL_ERR;
}
/* If x < 0, add b**(k+1) to it */
if (mp_cmp_d(x, 0uL) == MP_LT) {
mp_set(&q, 1uL);
if ((err = mp_lshd(&q, um + 1)) != MP_OKAY) {
goto CLEANUP;
goto LBL_ERR;
}
if ((err = mp_add(x, &q, x)) != MP_OKAY) {
goto CLEANUP;
goto LBL_ERR;
}
}
/* Back off if it's too big */
while (mp_cmp(x, m) != MP_LT) {
if ((err = s_mp_sub(x, m, x)) != MP_OKAY) {
goto CLEANUP;
goto LBL_ERR;
}
}
CLEANUP:
LBL_ERR:
mp_clear(&q);
return err;

View File

@ -6,17 +6,13 @@
/* determines if mp_reduce_2k can be used */
bool mp_reduce_is_2k(const mp_int *a)
{
int ix, iy, iw;
mp_digit iz;
if (a->used == 0) {
return false;
} else if (a->used == 1) {
return true;
} else if (a->used > 1) {
iy = mp_count_bits(a);
iz = 1;
iw = 1;
int ix, iy = mp_count_bits(a), iw = 1;
mp_digit iz = 1;
/* Test every bit from the second digit up, must be 1 */
for (ix = MP_DIGIT_BIT; ix < iy; ix++) {

View File

@ -6,14 +6,13 @@
/* determines if reduce_2k_l can be used */
bool mp_reduce_is_2k_l(const mp_int *a)
{
int ix, iy;
if (a->used == 0) {
return false;
} else if (a->used == 1) {
return true;
} else if (a->used > 1) {
/* if more than half of the digits are -1 we're sold */
int ix, iy;
for (iy = ix = 0; ix < a->used; ix++) {
if (a->dp[ix] == MP_DIGIT_MAX) {
++iy;

View File

@ -6,8 +6,7 @@
/* shift right a certain amount of digits */
void mp_rshd(mp_int *a, int b)
{
int x;
mp_digit *bottom, *top;
int x;
/* if b <= 0 then ignore it */
if (b <= 0) {
@ -20,15 +19,8 @@ void mp_rshd(mp_int *a, int b)
return;
}
/* shift the digits down */
/* bottom */
bottom = a->dp;
/* top [offset into digits] */
top = a->dp + b;
/* this is implemented as a sliding window where
/* shift the digits down.
* this is implemented as a sliding window where
* the window is b-digits long and digits from
* the top of the window are copied to the bottom
*
@ -39,11 +31,11 @@ void mp_rshd(mp_int *a, int b)
\-------------------/ ---->
*/
for (x = 0; x < (a->used - b); x++) {
*bottom++ = *top++;
a->dp[x] = a->dp[x + b];
}
/* zero the top digits */
MP_ZERO_DIGITS(bottom, a->used - x);
MP_ZERO_DIGITS(a->dp + a->used - b, b);
/* remove excess digits */
a->used -= b;

View File

@ -6,9 +6,10 @@
/* set to a digit */
void mp_set(mp_int *a, mp_digit b)
{
int oldused = a->used;
a->dp[0] = b & MP_MASK;
a->sign = MP_ZPOS;
a->used = (a->dp[0] != 0u) ? 1 : 0;
MP_ZERO_DIGITS(a->dp + a->used, a->alloc - a->used);
MP_ZERO_DIGITS(a->dp + a->used, oldused - a->used);
}
#endif

View File

@ -6,15 +6,15 @@
/* shrink a bignum */
mp_err mp_shrink(mp_int *a)
{
mp_digit *tmp;
int alloc = MP_MAX(MP_MIN_PREC, a->used);
if (a->alloc != alloc) {
if ((tmp = (mp_digit *) MP_REALLOC(a->dp,
(size_t)a->alloc * sizeof(mp_digit),
(size_t)alloc * sizeof(mp_digit))) == NULL) {
mp_digit *dp = (mp_digit *) MP_REALLOC(a->dp,
(size_t)a->alloc * sizeof(mp_digit),
(size_t)alloc * sizeof(mp_digit));
if (dp == NULL) {
return MP_MEM;
}
a->dp = tmp;
a->dp = dp;
a->alloc = alloc;
}
return MP_OKAY;

View File

@ -6,17 +6,16 @@
/* shift right by a certain bit count with sign extension */
mp_err mp_signed_rsh(const mp_int *a, int b, mp_int *c)
{
mp_err res;
mp_err err;
if (a->sign == MP_ZPOS) {
return mp_div_2d(a, b, c, NULL);
}
res = mp_add_d(a, 1uL, c);
if (res != MP_OKAY) {
return res;
if ((err = mp_add_d(a, 1uL, c)) != MP_OKAY) {
return err;
}
res = mp_div_2d(c, b, c, NULL);
return (res == MP_OKAY) ? mp_sub_d(c, 1uL, c) : res;
err = mp_div_2d(c, b, c, NULL);
return (err == MP_OKAY) ? mp_sub_d(c, 1uL, c) : err;
}
#endif

View File

@ -25,7 +25,7 @@ mp_err mp_sqrt(const mp_int *arg, mp_int *ret)
}
if ((err = mp_init(&t2)) != MP_OKAY) {
goto E2;
goto LBL_ERR2;
}
/* First approx. (not very bad for large arg) */
@ -33,33 +33,33 @@ mp_err mp_sqrt(const mp_int *arg, mp_int *ret)
/* t1 > 0 */
if ((err = mp_div(arg, &t1, &t2, NULL)) != MP_OKAY) {
goto E1;
goto LBL_ERR1;
}
if ((err = mp_add(&t1, &t2, &t1)) != MP_OKAY) {
goto E1;
goto LBL_ERR1;
}
if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) {
goto E1;
goto LBL_ERR1;
}
/* And now t1 > sqrt(arg) */
do {
if ((err = mp_div(arg, &t1, &t2, NULL)) != MP_OKAY) {
goto E1;
goto LBL_ERR1;
}
if ((err = mp_add(&t1, &t2, &t1)) != MP_OKAY) {
goto E1;
goto LBL_ERR1;
}
if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) {
goto E1;
goto LBL_ERR1;
}
/* t1 >= sqrt(arg) >= t2 at this point */
} while (mp_cmp_mag(&t1, &t2) == MP_GT);
mp_exch(&t1, ret);
E1:
LBL_ERR1:
mp_clear(&t2);
E2:
LBL_ERR2:
mp_clear(&t1);
return err;
}

View File

@ -33,28 +33,28 @@ mp_err mp_sqrtmod_prime(const mp_int *n, const mp_int *prime, mp_int *ret)
* compute directly: err = n^(prime+1)/4 mod prime
* Handbook of Applied Cryptography algorithm 3.36
*/
if ((err = mp_mod_d(prime, 4uL, &i)) != MP_OKAY) goto cleanup;
if ((err = mp_mod_d(prime, 4uL, &i)) != MP_OKAY) goto LBL_END;
if (i == 3u) {
if ((err = mp_add_d(prime, 1uL, &t1)) != MP_OKAY) goto cleanup;
if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) goto cleanup;
if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) goto cleanup;
if ((err = mp_exptmod(n, &t1, prime, ret)) != MP_OKAY) goto cleanup;
if ((err = mp_add_d(prime, 1uL, &t1)) != MP_OKAY) goto LBL_END;
if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) goto LBL_END;
if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) goto LBL_END;
if ((err = mp_exptmod(n, &t1, prime, ret)) != MP_OKAY) goto LBL_END;
err = MP_OKAY;
goto cleanup;
goto LBL_END;
}
/* NOW: Tonelli-Shanks algorithm */
/* factor out powers of 2 from prime-1, defining Q and S as: prime-1 = Q*2^S */
if ((err = mp_copy(prime, &Q)) != MP_OKAY) goto cleanup;
if ((err = mp_sub_d(&Q, 1uL, &Q)) != MP_OKAY) goto cleanup;
if ((err = mp_copy(prime, &Q)) != MP_OKAY) goto LBL_END;
if ((err = mp_sub_d(&Q, 1uL, &Q)) != MP_OKAY) goto LBL_END;
/* Q = prime - 1 */
mp_zero(&S);
/* S = 0 */
while (mp_iseven(&Q)) {
if ((err = mp_div_2(&Q, &Q)) != MP_OKAY) goto cleanup;
if ((err = mp_div_2(&Q, &Q)) != MP_OKAY) goto LBL_END;
/* Q = Q / 2 */
if ((err = mp_add_d(&S, 1uL, &S)) != MP_OKAY) goto cleanup;
if ((err = mp_add_d(&S, 1uL, &S)) != MP_OKAY) goto LBL_END;
/* S = S + 1 */
}
@ -62,55 +62,55 @@ mp_err mp_sqrtmod_prime(const mp_int *n, const mp_int *prime, mp_int *ret)
mp_set(&Z, 2uL);
/* Z = 2 */
for (;;) {
if ((err = mp_kronecker(&Z, prime, &legendre)) != MP_OKAY) goto cleanup;
if ((err = mp_kronecker(&Z, prime, &legendre)) != MP_OKAY) goto LBL_END;
if (legendre == -1) break;
if ((err = mp_add_d(&Z, 1uL, &Z)) != MP_OKAY) goto cleanup;
if ((err = mp_add_d(&Z, 1uL, &Z)) != MP_OKAY) goto LBL_END;
/* Z = Z + 1 */
}
if ((err = mp_exptmod(&Z, &Q, prime, &C)) != MP_OKAY) goto cleanup;
if ((err = mp_exptmod(&Z, &Q, prime, &C)) != MP_OKAY) goto LBL_END;
/* C = Z ^ Q mod prime */
if ((err = mp_add_d(&Q, 1uL, &t1)) != MP_OKAY) goto cleanup;
if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) goto cleanup;
if ((err = mp_add_d(&Q, 1uL, &t1)) != MP_OKAY) goto LBL_END;
if ((err = mp_div_2(&t1, &t1)) != MP_OKAY) goto LBL_END;
/* t1 = (Q + 1) / 2 */
if ((err = mp_exptmod(n, &t1, prime, &R)) != MP_OKAY) goto cleanup;
if ((err = mp_exptmod(n, &t1, prime, &R)) != MP_OKAY) goto LBL_END;
/* R = n ^ ((Q + 1) / 2) mod prime */
if ((err = mp_exptmod(n, &Q, prime, &T)) != MP_OKAY) goto cleanup;
if ((err = mp_exptmod(n, &Q, prime, &T)) != MP_OKAY) goto LBL_END;
/* T = n ^ Q mod prime */
if ((err = mp_copy(&S, &M)) != MP_OKAY) goto cleanup;
if ((err = mp_copy(&S, &M)) != MP_OKAY) goto LBL_END;
/* M = S */
mp_set(&two, 2uL);
for (;;) {
if ((err = mp_copy(&T, &t1)) != MP_OKAY) goto cleanup;
if ((err = mp_copy(&T, &t1)) != MP_OKAY) goto LBL_END;
i = 0;
for (;;) {
if (mp_cmp_d(&t1, 1uL) == MP_EQ) break;
if ((err = mp_exptmod(&t1, &two, prime, &t1)) != MP_OKAY) goto cleanup;
if ((err = mp_exptmod(&t1, &two, prime, &t1)) != MP_OKAY) goto LBL_END;
i++;
}
if (i == 0u) {
if ((err = mp_copy(&R, ret)) != MP_OKAY) goto cleanup;
if ((err = mp_copy(&R, ret)) != MP_OKAY) goto LBL_END;
err = MP_OKAY;
goto cleanup;
goto LBL_END;
}
if ((err = mp_sub_d(&M, i, &t1)) != MP_OKAY) goto cleanup;
if ((err = mp_sub_d(&t1, 1uL, &t1)) != MP_OKAY) goto cleanup;
if ((err = mp_exptmod(&two, &t1, prime, &t1)) != MP_OKAY) goto cleanup;
if ((err = mp_sub_d(&M, i, &t1)) != MP_OKAY) goto LBL_END;
if ((err = mp_sub_d(&t1, 1uL, &t1)) != MP_OKAY) goto LBL_END;
if ((err = mp_exptmod(&two, &t1, prime, &t1)) != MP_OKAY) goto LBL_END;
/* t1 = 2 ^ (M - i - 1) */
if ((err = mp_exptmod(&C, &t1, prime, &t1)) != MP_OKAY) goto cleanup;
if ((err = mp_exptmod(&C, &t1, prime, &t1)) != MP_OKAY) goto LBL_END;
/* t1 = C ^ (2 ^ (M - i - 1)) mod prime */
if ((err = mp_sqrmod(&t1, prime, &C)) != MP_OKAY) goto cleanup;
if ((err = mp_sqrmod(&t1, prime, &C)) != MP_OKAY) goto LBL_END;
/* C = (t1 * t1) mod prime */
if ((err = mp_mulmod(&R, &t1, prime, &R)) != MP_OKAY) goto cleanup;
if ((err = mp_mulmod(&R, &t1, prime, &R)) != MP_OKAY) goto LBL_END;
/* R = (R * t1) mod prime */
if ((err = mp_mulmod(&T, &C, prime, &T)) != MP_OKAY) goto cleanup;
if ((err = mp_mulmod(&T, &C, prime, &T)) != MP_OKAY) goto LBL_END;
/* T = (T * C) mod prime */
mp_set(&M, i);
/* M = i */
}
cleanup:
LBL_END:
mp_clear_multi(&t1, &C, &Q, &S, &Z, &M, &T, &R, &two, NULL);
return err;
}

View File

@ -6,35 +6,31 @@
/* high level subtraction (handles signs) */
mp_err mp_sub(const mp_int *a, const mp_int *b, mp_int *c)
{
mp_sign sa = a->sign, sb = b->sign;
mp_err err;
if (sa != sb) {
if (a->sign != b->sign) {
/* subtract a negative from a positive, OR */
/* subtract a positive from a negative. */
/* In either case, ADD their magnitudes, */
/* and use the sign of the first number. */
c->sign = sa;
err = s_mp_add(a, b, c);
} else {
/* subtract a positive from a positive, OR */
/* subtract a negative from a negative. */
/* First, take the difference between their */
/* magnitudes, then... */
if (mp_cmp_mag(a, b) != MP_LT) {
/* Copy the sign from the first */
c->sign = sa;
/* The first has a larger or equal magnitude */
err = s_mp_sub(a, b, c);
} else {
/* The result has the *opposite* sign from */
/* the first number. */
c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;
/* The second has a larger magnitude */
err = s_mp_sub(b, a, c);
}
c->sign = a->sign;
return s_mp_add(a, b, c);
}
return err;
/* subtract a positive from a positive, OR */
/* subtract a negative from a negative. */
/* First, take the difference between their */
/* magnitudes, then... */
if (mp_cmp_mag(a, b) == MP_LT) {
/* The second has a larger magnitude */
/* The result has the *opposite* sign from */
/* the first number. */
c->sign = (a->sign == MP_ZPOS) ? MP_NEG : MP_ZPOS;
MP_EXCH(const mp_int *, a, b);
} else {
/* The first has a larger or equal magnitude */
/* Copy the sign from the first */
c->sign = a->sign;
}
return s_mp_sub(a, b, c);
}
#endif

View File

@ -4,17 +4,11 @@
/* SPDX-License-Identifier: Unlicense */
/* reverse an array, used for radix code */
static void s_mp_reverse(uint8_t *s, size_t len)
static void s_mp_reverse(char *s, size_t len)
{
size_t ix, iy;
uint8_t t;
ix = 0u;
iy = len - 1u;
size_t ix = 0, iy = len - 1u;
while (ix < iy) {
t = s[ix];
s[ix] = s[iy];
s[iy] = t;
MP_EXCH(char, s[ix], s[iy]);
++ix;
--iy;
}
@ -83,7 +77,7 @@ mp_err mp_to_radix(const mp_int *a, char *str, size_t maxlen, size_t *written, i
/* reverse the digits of the string. In this case _s points
* to the first digit [exluding the sign] of the number
*/
s_mp_reverse((uint8_t *)_s, digs);
s_mp_reverse(_s, digs);
/* append a NULL so the string is properly terminated */
*str = '\0';

View File

@ -7,7 +7,7 @@
void mp_zero(mp_int *a)
{
a->sign = MP_ZPOS;
MP_ZERO_DIGITS(a->dp, a->used);
a->used = 0;
MP_ZERO_DIGITS(a->dp, a->alloc);
}
#endif

View File

@ -17,10 +17,9 @@
static mp_err s_mp_recursion(const mp_int *a, const mp_int *b, mp_int *q, mp_int *r)
{
mp_err err;
int m, k;
mp_int A1, A2, B1, B0, Q1, Q0, R1, R0, t;
int m = a->used - b->used, k = m/2;
m = a->used - b->used;
if (m < MP_KARATSUBA_MUL_CUTOFF) {
return s_mp_div_school(a, b, q, r);
}
@ -29,9 +28,6 @@ static mp_err s_mp_recursion(const mp_int *a, const mp_int *b, mp_int *q, mp_int
goto LBL_ERR;
}
/* k = floor(m/2) */
k = m/2;
/* B1 = b / beta^k, B0 = b % beta^k*/
if ((err = mp_div_2d(b, k * MP_DIGIT_BIT, &B1, &B0)) != MP_OKAY) goto LBL_ERR;

View File

@ -140,8 +140,6 @@ mp_err s_mp_div_school(const mp_int *a, const mp_int *b, mp_int *c, mp_int *d)
mp_exch(&x, d);
}
err = MP_OKAY;
LBL_Y:
mp_clear(&y);
LBL_X:

View File

@ -148,6 +148,8 @@ extern void MP_FREE(void *mem, size_t size);
#define MP_TOUPPER(c) ((((c) >= 'a') && ((c) <= 'z')) ? (((c) + 'A') - 'a') : (c))
#define MP_EXCH(t, a, b) do { t _c = a; a = b; b = _c; } while (0)
/* Static assertion */
#define MP_STATIC_ASSERT(msg, cond) typedef char mp_static_assert_##msg[(cond) ? 1 : -1];
@ -267,9 +269,9 @@ extern MP_PRIVATE const mp_digit s_mp_prime_tab[];
#define MP_GET_MAG(name, type) \
type name(const mp_int* a) \
{ \
unsigned i = MP_MIN((unsigned)a->used, (unsigned)((MP_SIZEOF_BITS(type) + MP_DIGIT_BIT - 1) / MP_DIGIT_BIT)); \
int i = MP_MIN(a->used, (int)((MP_SIZEOF_BITS(type) + MP_DIGIT_BIT - 1) / MP_DIGIT_BIT)); \
type res = 0u; \
while (i --> 0u) { \
while (i --> 0) { \
res <<= ((MP_SIZEOF_BITS(type) <= MP_DIGIT_BIT) ? 0 : MP_DIGIT_BIT); \
res |= (type)a->dp[i]; \
if (MP_SIZEOF_BITS(type) <= MP_DIGIT_BIT) { break; } \