Merge pull request #471 from libtom/prevent-overflow

Prevent overflow
This commit is contained in:
Steffen Jaeckel 2020-01-07 18:35:50 +01:00 committed by GitHub
commit ffd80665d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 164 additions and 440 deletions

View File

@ -45,5 +45,5 @@ void print_header(void)
printf("Size of mp_digit: %u\n", (unsigned int)sizeof(mp_digit));
printf("Size of mp_word: %u\n", (unsigned int)sizeof(mp_word));
printf("MP_DIGIT_BIT: %d\n", MP_DIGIT_BIT);
printf("MP_PREC: %d\n", MP_PREC);
printf("MP_DEFAULT_DIGIT_COUNT: %d\n", MP_DEFAULT_DIGIT_COUNT);
}

View File

@ -17,8 +17,8 @@
#include "tommath_private.h"
#define EXPECT(a) do { if (!(a)) { fprintf(stderr, "%d: EXPECT(%s) failed\n", __LINE__, #a); goto LBL_ERR; } } while(0)
#define DO_WHAT(a, what) do { mp_err err; if ((err = (a)) != MP_OKAY) { fprintf(stderr, "%d: DO(%s) failed: %s\n", __LINE__, #a, mp_error_to_string(err)); what; } } while(0)
#define EXPECT(a) do { if (!(a)) { fprintf(stderr, "%s, line %d: EXPECT(%s) failed\n", __func__, __LINE__, #a); goto LBL_ERR; } } while(0)
#define DO_WHAT(a, what) do { mp_err err; if ((err = (a)) != MP_OKAY) { fprintf(stderr, "%s, line %d: DO(%s) failed: %s\n", __func__, __LINE__, #a, mp_error_to_string(err)); what; } } while(0)
#define DO(a) DO_WHAT(a, goto LBL_ERR)
#define DOR(a) DO_WHAT(a, return EXIT_FAILURE)

View File

@ -334,10 +334,7 @@ static int test_mp_kronecker(void)
mp_set_ul(&a, 0uL);
mp_set_ul(&b, 1uL);
DO(mp_kronecker(&a, &b, &i));
if (i != 1) {
printf("Failed trivial mp_kronecker(0 | 1) %d != 1\n", i);
goto LBL_ERR;
}
EXPECT(i == 1);
for (cnt = 0; cnt < (int)(sizeof(kronecker)/sizeof(kronecker[0])); ++cnt) {
k = kronecker[cnt].n;
mp_set_l(&a, k);
@ -345,10 +342,7 @@ static int test_mp_kronecker(void)
for (m = -10; m <= 10; m++) {
mp_set_l(&b, m);
DO(mp_kronecker(&a, &b, &i));
if (i != kronecker[cnt].c[m + 10]) {
printf("Failed trivial mp_kronecker(%ld | %ld) %d != %d\n", kronecker[cnt].n, m, i, kronecker[cnt].c[m + 10]);
goto LBL_ERR;
}
EXPECT(i == kronecker[cnt].c[m + 10]);
}
}
@ -374,10 +368,7 @@ static int test_mp_complement(void)
l = ~l;
mp_set_l(&c, l);
if (mp_cmp(&b, &c) != MP_EQ) {
printf("\nmp_complement() bad result!");
goto LBL_ERR;
}
EXPECT(mp_cmp(&b, &c) == MP_EQ);
}
mp_clear_multi(&a, &b, &c, NULL);
@ -406,10 +397,7 @@ static int test_mp_signed_rsh(void)
mp_set_l(&d, l >> em);
DO(mp_signed_rsh(&a, em, &b));
if (mp_cmp(&b, &d) != MP_EQ) {
printf("\nmp_signed_rsh() bad result!");
goto LBL_ERR;
}
EXPECT(mp_cmp(&b, &d) == MP_EQ);
}
mp_clear_multi(&a, &b, &d, NULL);
@ -439,10 +427,7 @@ static int test_mp_xor(void)
mp_set_l(&d, l ^ em);
DO(mp_xor(&a, &b, &c));
if (mp_cmp(&c, &d) != MP_EQ) {
printf("\nmp_xor() bad result!");
goto LBL_ERR;
}
EXPECT(mp_cmp(&c, &d) == MP_EQ);
}
mp_clear_multi(&a, &b, &c, &d, NULL);
@ -472,10 +457,7 @@ static int test_mp_or(void)
mp_set_l(&d, l | em);
DO(mp_or(&a, &b, &c));
if (mp_cmp(&c, &d) != MP_EQ) {
printf("\nmp_or() bad result!");
goto LBL_ERR;
}
EXPECT(mp_cmp(&c, &d) == MP_EQ);
}
mp_clear_multi(&a, &b, &c, &d, NULL);
@ -504,10 +486,7 @@ static int test_mp_and(void)
mp_set_l(&d, l & em);
DO(mp_and(&a, &b, &c));
if (mp_cmp(&c, &d) != MP_EQ) {
printf("\nmp_and() bad result!");
goto LBL_ERR;
}
EXPECT(mp_cmp(&c, &d) == MP_EQ);
}
mp_clear_multi(&a, &b, &c, &d, NULL);
@ -528,28 +507,11 @@ static int test_mp_invmod(void)
const char *b_ = "FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF";
const char *should_ = "0521A82E10376F8E4FDEF9A32A427AC2A0FFF686E00290D39E3E4B5522409596";
if (mp_read_radix(&a, a_, 16) != MP_OKAY) {
printf("\nmp_read_radix(a) failed!");
goto LBL_ERR;
}
if (mp_read_radix(&b, b_, 16) != MP_OKAY) {
printf("\nmp_read_radix(b) failed!");
goto LBL_ERR;
}
if (mp_read_radix(&c, should_, 16) != MP_OKAY) {
printf("\nmp_read_radix(should) failed!");
goto LBL_ERR;
}
if (mp_invmod(&a, &b, &d) != MP_OKAY) {
printf("\nmp_invmod() failed!");
goto LBL_ERR;
}
if (mp_cmp(&c, &d) != MP_EQ) {
printf("\nmp_invmod() bad result!");
goto LBL_ERR;
}
DO(mp_read_radix(&a, a_, 16));
DO(mp_read_radix(&b, b_, 16));
DO(mp_read_radix(&c, should_, 16));
DO(mp_invmod(&a, &b, &d));
EXPECT(mp_cmp(&c, &d) == MP_EQ);
}
mp_clear_multi(&a, &b, &c, &d, NULL);
@ -569,42 +531,18 @@ static int test_mp_set_double(void)
DOR(mp_init_multi(&a, &b, NULL));
/* test mp_get_double/mp_set_double */
if (mp_set_double(&a, +1.0/0.0) != MP_VAL) {
printf("\nmp_set_double should return MP_VAL for +inf");
goto LBL_ERR;
}
if (mp_set_double(&a, -1.0/0.0) != MP_VAL) {
printf("\nmp_set_double should return MP_VAL for -inf");
goto LBL_ERR;
}
if (mp_set_double(&a, +0.0/0.0) != MP_VAL) {
printf("\nmp_set_double should return MP_VAL for NaN");
goto LBL_ERR;
}
if (mp_set_double(&a, -0.0/0.0) != MP_VAL) {
printf("\nmp_set_double should return MP_VAL for NaN");
goto LBL_ERR;
}
EXPECT(mp_set_double(&a, +1.0/0.0) == MP_VAL);
EXPECT(mp_set_double(&a, -1.0/0.0) == MP_VAL);
EXPECT(mp_set_double(&a, +0.0/0.0) == MP_VAL);
EXPECT(mp_set_double(&a, -0.0/0.0) == MP_VAL);
for (i = 0; i < 1000; ++i) {
int tmp = rand_int();
double dbl = (double)tmp * rand_int() + 1;
if (mp_set_double(&a, dbl) != MP_OKAY) {
printf("\nmp_set_double() failed");
goto LBL_ERR;
}
if (dbl != mp_get_double(&a)) {
printf("\nmp_get_double() bad result!");
goto LBL_ERR;
}
if (mp_set_double(&a, -dbl) != MP_OKAY) {
printf("\nmp_set_double() failed");
goto LBL_ERR;
}
if (-dbl != mp_get_double(&a)) {
printf("\nmp_get_double() bad result!");
goto LBL_ERR;
}
DO(mp_set_double(&a, dbl));
EXPECT(dbl == mp_get_double(&a));
DO(mp_set_double(&a, -dbl));
EXPECT(-dbl == mp_get_double(&a));
}
mp_clear_multi(&a, &b, NULL);
@ -627,21 +565,12 @@ static int test_mp_get_u32(void)
for (i = 0; i < 1000; ++i) {
t = (uint32_t)rand_long();
mp_set_ul(&a, t);
if (t != mp_get_u32(&a)) {
printf("\nmp_get_u32() bad result!");
goto LBL_ERR;
}
EXPECT(t == mp_get_u32(&a));
}
mp_set_ul(&a, 0uL);
if (mp_get_u32(&a) != 0) {
printf("\nmp_get_u32() bad result!");
goto LBL_ERR;
}
EXPECT(mp_get_u32(&a) == 0);
mp_set_ul(&a, UINT32_MAX);
if (mp_get_u32(&a) != UINT32_MAX) {
printf("\nmp_get_u32() bad result!");
goto LBL_ERR;
}
EXPECT(mp_get_u32(&a) == UINT32_MAX);
mp_clear_multi(&a, &b, NULL);
return EXIT_SUCCESS;
@ -666,10 +595,7 @@ static int test_mp_get_ul(void)
do {
mp_set_ul(&a, t);
s = mp_get_ul(&a);
if (s != t) {
printf("\nmp_get_ul() bad result! 0x%lx != 0x%lx", s, t);
goto LBL_ERR;
}
EXPECT(s == t);
t <<= 1;
} while (t != 0uL);
}
@ -697,10 +623,7 @@ static int test_mp_get_u64(void)
do {
mp_set_u64(&a, r);
q = mp_get_u64(&a);
if (q != r) {
printf("\nmp_get_u64() bad result! 0x%" PRIx64 " != 0x%" PRIx64, q, r);
goto LBL_ERR;
}
EXPECT(q == r);
r <<= 1;
} while (r != 0u);
}
@ -725,15 +648,9 @@ static int test_mp_sqrt(void)
fflush(stdout);
n = (rand_int() & 15) + 1;
DO(mp_rand(&a, n));
if (mp_sqrt(&a, &b) != MP_OKAY) {
printf("\nmp_sqrt() error!");
goto LBL_ERR;
}
DO(mp_sqrt(&a, &b));
DO(mp_root_n(&a, 2, &c));
if (mp_cmp_mag(&b, &c) != MP_EQ) {
printf("mp_sqrt() bad result!\n");
goto LBL_ERR;
}
EXPECT(mp_cmp_mag(&b, &c) == MP_EQ);
}
mp_clear_multi(&a, &b, &c, NULL);
@ -760,26 +677,13 @@ static int test_mp_is_square(void)
n = (rand_int() & 7) + 1;
DO(mp_rand(&a, n));
DO(mp_sqr(&a, &a));
if (mp_is_square(&a, &res) != MP_OKAY) {
printf("\nfn:mp_is_square() error!");
goto LBL_ERR;
}
if (!res) {
printf("\nfn:mp_is_square() bad result!");
goto LBL_ERR;
}
DO(mp_is_square(&a, &res));
EXPECT(res);
/* test for false positives */
DO(mp_add_d(&a, 1u, &a));
if (mp_is_square(&a, &res) != MP_OKAY) {
printf("\nfp:mp_is_square() error!");
goto LBL_ERR;
}
if (res) {
printf("\nfp:mp_is_square() bad result!");
goto LBL_ERR;
}
DO(mp_is_square(&a, &res));
EXPECT(!res);
}
mp_clear_multi(&a, &b, NULL);
@ -811,15 +715,8 @@ static int test_mp_sqrtmod_prime(void)
for (i = 0; i < (int)(sizeof(sqrtmod_prime)/sizeof(sqrtmod_prime[0])); ++i) {
mp_set_ul(&a, sqrtmod_prime[i].p);
mp_set_ul(&b, sqrtmod_prime[i].n);
if (mp_sqrtmod_prime(&b, &a, &c) != MP_OKAY) {
printf("Failed executing %d. mp_sqrtmod_prime\n", (i+1));
goto LBL_ERR;
}
if (mp_cmp_d(&c, sqrtmod_prime[i].r) != MP_EQ) {
printf("Failed %d. trivial mp_sqrtmod_prime\n", (i+1));
ndraw(&c, "r");
goto LBL_ERR;
}
DO(mp_sqrtmod_prime(&b, &a, &c));
EXPECT(mp_cmp_d(&c, sqrtmod_prime[i].r) == MP_EQ);
}
mp_clear_multi(&a, &b, &c, NULL);
@ -832,23 +729,15 @@ LBL_ERR:
static int test_mp_prime_rand(void)
{
int ix;
mp_err e;
mp_int a, b;
DOR(mp_init_multi(&a, &b, NULL));
/* test for size */
for (ix = 10; ix < 128; ix++) {
printf("Testing (not safe-prime): %9d bits \r", ix);
printf("Testing (not safe-prime): %9d bits \n", ix);
fflush(stdout);
e = mp_prime_rand(&a, 8, ix, (rand_int() & 1) ? 0 : MP_PRIME_2MSB_ON);
if (e != MP_OKAY) {
printf("\nfailed with error: %s\n", mp_error_to_string(e));
goto LBL_ERR;
}
if (mp_count_bits(&a) != ix) {
printf("Prime is %d not %d bits!!!\n", mp_count_bits(&a), ix);
goto LBL_ERR;
}
DO(mp_prime_rand(&a, 8, ix, (rand_int() & 1) ? 0 : MP_PRIME_2MSB_ON));
EXPECT(mp_count_bits(&a) == ix);
}
mp_clear_multi(&a, &b, NULL);
@ -902,31 +791,16 @@ static int test_mp_prime_is_prime(void)
for (ix = 16; ix < 128; ix++) {
printf("\rTesting ( safe-prime): %9d bits ", ix);
fflush(stdout);
e = mp_prime_rand(&a, 8, ix, ((rand_int() & 1) ? 0 : MP_PRIME_2MSB_ON) | MP_PRIME_SAFE);
if (e != MP_OKAY) {
printf("\nfailed with error: %s\n", mp_error_to_string(e));
goto LBL_ERR;
}
if (mp_count_bits(&a) != ix) {
printf("Prime is %d not %d bits!!!\n", mp_count_bits(&a), ix);
goto LBL_ERR;
}
DO(mp_prime_rand(&a, 8, ix, ((rand_int() & 1) ? 0 : MP_PRIME_2MSB_ON) | MP_PRIME_SAFE));
EXPECT(mp_count_bits(&a) == ix);
/* let's see if it's really a safe prime */
DO(mp_sub_d(&a, 1u, &b));
DO(mp_div_2(&b, &b));
e = mp_prime_is_prime(&b, mp_prime_rabin_miller_trials(mp_count_bits(&b)), &cnt);
/* small problem */
if (e != MP_OKAY) {
printf("\nfailed with error: %s\n", mp_error_to_string(e));
}
DO(mp_prime_is_prime(&b, mp_prime_rabin_miller_trials(mp_count_bits(&b)), &cnt));
/* large problem */
if (!cnt) {
printf("\nsub is not prime!\n");
}
EXPECT(cnt);
DO(mp_prime_frobenius_underwood(&b, &fu));
if (!fu) {
printf("\nfrobenius-underwood says sub is not prime!\n");
}
EXPECT(fu);
if ((e != MP_OKAY) || !cnt) {
printf("prime tested was: 0x");
DO(mp_fwrite(&a,16,stdout));
@ -942,15 +816,9 @@ static int test_mp_prime_is_prime(void)
DO(mp_read_radix(&a,
"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A3620FFFFFFFFFFFFFFFF",
16));
e = mp_prime_strong_lucas_selfridge(&a, &cnt);
/* small problem */
if (e != MP_OKAY) {
printf("\nmp_prime_strong_lucas_selfridge failed with error: %s\n", mp_error_to_string(e));
}
DO(mp_prime_strong_lucas_selfridge(&a, &cnt));
/* large problem */
if (!cnt) {
printf("\n\nissue #143 - mp_prime_strong_lucas_selfridge FAILED!\n");
}
EXPECT(cnt);
if ((e != MP_OKAY) || !cnt) {
printf("prime tested was: 0x");
DO(mp_fwrite(&a,16,stdout));
@ -1185,10 +1053,7 @@ static int test_mp_cnt_lsb(void)
mp_set(&a, 1u);
for (ix = 0; ix < 1024; ix++) {
if (mp_cnt_lsb(&a) != ix) {
printf("Failed at %d, %d\n", ix, mp_cnt_lsb(&a));
goto LBL_ERR;
}
EXPECT(mp_cnt_lsb(&a) == ix);
DO(mp_mul_2(&a, &a));
}
@ -1227,10 +1092,7 @@ static int test_mp_reduce_2k(void)
DO(mp_copy(&c, &b));
DO(mp_mod(&c, &a, &c));
DO(mp_reduce_2k(&b, &a, 2u));
if (mp_cmp(&c, &b) != MP_EQ) {
printf("FAILED\n");
goto LBL_ERR;
}
EXPECT(mp_cmp(&c, &b) == MP_EQ);
}
}
@ -1261,10 +1123,7 @@ static int test_s_mp_div_3(void)
DO(mp_div(&a, &d, &b, &e));
DO(s_mp_div_3(&a, &c, &r2));
if (mp_cmp(&b, &c) || mp_cmp_d(&e, r2)) {
printf("\ns_mp_div_3 => Failure\n");
goto LBL_ERR;
}
EXPECT(!mp_cmp(&b, &c) && !mp_cmp_d(&e, r2));
}
printf("... passed!");
@ -1313,10 +1172,7 @@ static int test_mp_dr_reduce(void)
mp_dr_setup(&a, &mp);
DO(mp_dr_reduce(&c, &a, mp));
if (mp_cmp(&b, &c) != MP_EQ) {
printf("Failed on trial %u\n", rr);
goto LBL_ERR;
}
EXPECT(mp_cmp(&b, &c) == MP_EQ);
} while (++rr < 500);
printf(" passed");
fflush(stdout);
@ -1356,10 +1212,7 @@ static int test_mp_reduce_2k_l(void)
DO(mp_to_radix(&a, buf, sizeof(buf), &length, 10));
printf("\n\np==%s, length = %zu\n", buf, length);
/* now mp_reduce_is_2k_l() should return */
if (mp_reduce_is_2k_l(&a) != 1) {
printf("mp_reduce_is_2k_l() return 0, should be 1\n");
goto LBL_ERR;
}
EXPECT(mp_reduce_is_2k_l(&a) == 1);
DO(mp_reduce_2k_setup_l(&a, &d));
/* now do a million square+1 to see if it varies */
DO(mp_rand(&b, 64));
@ -1395,7 +1248,6 @@ LBL_ERR:
}
/* stripped down version of mp_radix_size. The faster version can be off by up t
o +3 */
/* TODO: This function should be removed, replaced by mp_radix_size, mp_radix_size_overestimate in 2.0 */
static mp_err s_rs(const mp_int *a, int radix, int *size)
{
mp_err res;
@ -1440,13 +1292,9 @@ static int test_mp_log_n(void)
*/
mp_set(&a, 42u);
base = 0u;
if (mp_log_n(&a, base, &lb) != MP_VAL) {
goto LBL_ERR;
}
EXPECT(mp_log_n(&a, base, &lb) == MP_VAL);
base = 1u;
if (mp_log_n(&a, base, &lb) != MP_VAL) {
goto LBL_ERR;
}
EXPECT(mp_log_n(&a, base, &lb) == MP_VAL);
/*
base a result
2 0 MP_VAL
@ -1456,16 +1304,12 @@ static int test_mp_log_n(void)
*/
base = 2u;
mp_zero(&a);
if (mp_log_n(&a, base, &lb) != MP_VAL) {
goto LBL_ERR;
}
EXPECT(mp_log_n(&a, base, &lb) == MP_VAL);
for (d = 1; d < 4; d++) {
mp_set(&a, d);
DO(mp_log_n(&a, base, &lb));
if (lb != ((d == 1)?0:1)) {
goto LBL_ERR;
}
EXPECT(lb == ((d == 1)?0:1));
}
/*
base a result
@ -1476,15 +1320,11 @@ static int test_mp_log_n(void)
*/
base = 3u;
mp_zero(&a);
if (mp_log_n(&a, base, &lb) != MP_VAL) {
goto LBL_ERR;
}
EXPECT(mp_log_n(&a, base, &lb) == MP_VAL);
for (d = 1; d < 4; d++) {
mp_set(&a, d);
DO(mp_log_n(&a, base, &lb));
if (lb != (((int)d < base)?0:1)) {
goto LBL_ERR;
}
EXPECT(lb == (((int)d < base)?0:1));
}
/*
@ -1498,9 +1338,7 @@ static int test_mp_log_n(void)
DO(s_rs(&a,(int)base, &size));
/* radix_size includes the memory needed for '\0', too*/
size -= 2;
if (lb != size) {
goto LBL_ERR;
}
EXPECT(lb == size);
}
/*
@ -1512,9 +1350,7 @@ static int test_mp_log_n(void)
DO(mp_log_n(&a, base, &lb));
DO(s_rs(&a,(int)base, &size));
size -= 2;
if (lb != size) {
goto LBL_ERR;
}
EXPECT(lb == size);
}
/*Test upper edgecase with base UINT32_MAX and number (UINT32_MAX/2)*UINT32_MAX^10 */
@ -1522,9 +1358,7 @@ static int test_mp_log_n(void)
DO(mp_expt_n(&a, 10uL, &a));
DO(mp_add_d(&a, max_base / 2, &a));
DO(mp_log_n(&a, max_base, &lb));
if (lb != 10u) {
goto LBL_ERR;
}
EXPECT(lb == 10u);
mp_clear(&a);
return EXIT_SUCCESS;
@ -1542,39 +1376,30 @@ static int test_mp_incr(void)
/* Does it increment inside the limits of a MP_xBIT limb? */
mp_set(&a, MP_MASK/2);
DO(mp_incr(&a));
if (mp_cmp_d(&a, (MP_MASK/2u) + 1u) != MP_EQ) {
goto LBL_ERR;
}
EXPECT(mp_cmp_d(&a, (MP_MASK/2u) + 1u) == MP_EQ);
/* Does it increment outside of the limits of a MP_xBIT limb? */
mp_set(&a, MP_MASK);
mp_set(&b, MP_MASK);
DO(mp_incr(&a));
DO(mp_add_d(&b, 1u, &b));
if (mp_cmp(&a, &b) != MP_EQ) {
goto LBL_ERR;
}
EXPECT(mp_cmp(&a, &b) == MP_EQ);
/* Does it increment from -1 to 0? */
mp_set(&a, 1u);
a.sign = MP_NEG;
DO(mp_incr(&a));
if (mp_cmp_d(&a, 0u) != MP_EQ) {
goto LBL_ERR;
}
EXPECT(mp_cmp_d(&a, 0u) == MP_EQ);
/* Does it increment from -(MP_MASK + 1) to -MP_MASK? */
mp_set(&a, MP_MASK);
DO(mp_add_d(&a, 1u, &a));
a.sign = MP_NEG;
DO(mp_incr(&a));
if (a.sign != MP_NEG) {
goto LBL_ERR;
}
EXPECT(a.sign == MP_NEG);
a.sign = MP_ZPOS;
if (mp_cmp_d(&a, MP_MASK) != MP_EQ) {
goto LBL_ERR;
}
EXPECT(mp_cmp_d(&a, MP_MASK) == MP_EQ);
mp_clear_multi(&a, &b, NULL);
return EXIT_SUCCESS;
@ -1592,26 +1417,20 @@ static int test_mp_decr(void)
/* Does it decrement inside the limits of a MP_xBIT limb? */
mp_set(&a, MP_MASK/2);
DO(mp_decr(&a));
if (mp_cmp_d(&a, (MP_MASK/2u) - 1u) != MP_EQ) {
goto LBL_ERR;
}
EXPECT(mp_cmp_d(&a, (MP_MASK/2u) - 1u) == MP_EQ);
/* Does it decrement outside of the limits of a MP_xBIT limb? */
mp_set(&a, MP_MASK);
DO(mp_add_d(&a, 1u, &a));
DO(mp_decr(&a));
if (mp_cmp_d(&a, MP_MASK) != MP_EQ) {
goto LBL_ERR;
}
EXPECT(mp_cmp_d(&a, MP_MASK) == MP_EQ);
/* Does it decrement from 0 to -1? */
mp_zero(&a);
DO(mp_decr(&a));
if (a.sign == MP_NEG) {
a.sign = MP_ZPOS;
if (mp_cmp_d(&a, 1u) != MP_EQ) {
goto LBL_ERR;
}
EXPECT(mp_cmp_d(&a, 1u) == MP_EQ);
} else {
goto LBL_ERR;
}
@ -1624,9 +1443,7 @@ static int test_mp_decr(void)
b.sign = MP_NEG;
DO(mp_sub_d(&b, 1u, &b));
DO(mp_decr(&a));
if (mp_cmp(&a, &b) != MP_EQ) {
goto LBL_ERR;
}
EXPECT(mp_cmp(&a, &b) == MP_EQ);
mp_clear_multi(&a, &b, NULL);
return EXIT_SUCCESS;
@ -1852,10 +1669,7 @@ static int test_mp_root_n(void)
for (j = 3; j < 100; j++) {
DO(mp_root_n(&a, j, &c));
DO(mp_read_radix(&r, root[i][j-3], 10));
if (mp_cmp(&r, &c) != MP_EQ) {
fprintf(stderr, "mp_root_n failed at input #%d, root #%d\n", i, j);
goto LBL_ERR;
}
EXPECT(mp_cmp(&r, &c) == MP_EQ);
}
}
mp_clear_multi(&a, &c, &r, NULL);
@ -1884,9 +1698,7 @@ static int test_s_mp_mul_balance(void)
DO(mp_read_radix(&b, nc, 64));
if (mp_cmp(&b, &c) != MP_EQ) {
goto LBL_ERR;
}
EXPECT(mp_cmp(&b, &c) == MP_EQ);
mp_clear_multi(&a, &b, &c, NULL);
return EXIT_SUCCESS;
@ -1907,10 +1719,7 @@ static int test_s_mp_mul_karatsuba(void)
DO(mp_rand(&b, size));
DO(s_mp_mul_karatsuba(&a, &b, &c));
DO(s_mp_mul_full(&a,&b,&d));
if (mp_cmp(&c, &d) != MP_EQ) {
fprintf(stderr, "Karatsuba multiplication failed at size %d\n", size);
goto LBL_ERR;
}
EXPECT(mp_cmp(&c, &d) == MP_EQ);
}
mp_clear_multi(&a, &b, &c, &d, NULL);
@ -1930,10 +1739,7 @@ static int test_s_mp_sqr_karatsuba(void)
DO(mp_rand(&a, size));
DO(s_mp_sqr_karatsuba(&a, &b));
DO(s_mp_sqr(&a, &c));
if (mp_cmp(&b, &c) != MP_EQ) {
fprintf(stderr, "Karatsuba squaring failed at size %d\n", size);
goto LBL_ERR;
}
EXPECT(mp_cmp(&b, &c) == MP_EQ);
}
mp_clear_multi(&a, &b, &c, NULL);
@ -1969,10 +1775,7 @@ static int test_s_mp_mul_toom(void)
DO(mp_mul(&a, &b, &c));
MP_MUL_TOOM_CUTOFF = tc_cutoff;
DO(mp_mul(&a, &b, &d));
if (mp_cmp(&c, &d) != MP_EQ) {
fprintf(stderr, "Toom-Cook 3-way multiplication failed for edgecase f1 * f2\n");
goto LBL_ERR;
}
EXPECT(mp_cmp(&c, &d) == MP_EQ);
#endif
for (size = MP_MUL_TOOM_CUTOFF; size < (MP_MUL_TOOM_CUTOFF + 20); size++) {
@ -1980,10 +1783,7 @@ static int test_s_mp_mul_toom(void)
DO(mp_rand(&b, size));
DO(s_mp_mul_toom(&a, &b, &c));
DO(s_mp_mul_full(&a,&b,&d));
if (mp_cmp(&c, &d) != MP_EQ) {
fprintf(stderr, "Toom-Cook 3-way multiplication failed at size %d\n", size);
goto LBL_ERR;
}
EXPECT(mp_cmp(&c, &d) == MP_EQ);
}
mp_clear_multi(&a, &b, &c, &d, NULL);
@ -2003,10 +1803,7 @@ static int test_s_mp_sqr_toom(void)
DO(mp_rand(&a, size));
DO(s_mp_sqr_toom(&a, &b));
DO(s_mp_sqr(&a, &c));
if (mp_cmp(&b, &c) != MP_EQ) {
fprintf(stderr, "Toom-Cook 3-way squaring failed at size %d\n", size);
goto LBL_ERR;
}
EXPECT(mp_cmp(&b, &c) == MP_EQ);
}
mp_clear_multi(&a, &b, &c, NULL);
@ -2043,18 +1840,10 @@ static int test_mp_radix_size(void)
for (radix = 2; radix < 65; radix++) {
DO(mp_radix_size(&a, radix, &size));
if (size != results[radix]) {
fprintf(stderr, "mp_radix_size: result for base %d was %zu instead of %zu\n",
radix, size, results[radix]);
goto LBL_ERR;
}
EXPECT(size == results[radix]);
a.sign = MP_NEG;
DO(mp_radix_size(&a, radix, &size));
if (size != (results[radix] + 1)) {
fprintf(stderr, "mp_radix_size: result for base %d was %zu instead of %zu\n",
radix, size, results[radix]);
goto LBL_ERR;
}
EXPECT(size == (results[radix] + 1));
a.sign = MP_ZPOS;
}
@ -2082,16 +1871,8 @@ static int test_s_mp_div_recursive(void)
DO(mp_rand(&b, size));
DO(s_mp_div_recursive(&a, &b, &c_q, &c_r));
DO(s_mp_div_school(&a, &b, &d_q, &d_r));
if (mp_cmp(&c_q, &d_q) != MP_EQ) {
fprintf(stderr, "1a. Recursive division failed at sizes %d / %d, wrong quotient\n",
10 * size, size);
goto LBL_ERR;
}
if (mp_cmp(&c_r, &d_r) != MP_EQ) {
fprintf(stderr, "1a. Recursive division failed at sizes %d / %d, wrong remainder\n",
10 * size, size);
goto LBL_ERR;
}
EXPECT(mp_cmp(&c_q, &d_q) == MP_EQ);
EXPECT(mp_cmp(&c_r, &d_r) == MP_EQ);
printf("\rsizes = %d / %d", 2 * size, size);
/* Relation 10:1 negative numerator*/
@ -2100,16 +1881,8 @@ static int test_s_mp_div_recursive(void)
DO(mp_rand(&b, size));
DO(s_mp_div_recursive(&a, &b, &c_q, &c_r));
DO(s_mp_div_school(&a, &b, &d_q, &d_r));
if (mp_cmp(&c_q, &d_q) != MP_EQ) {
fprintf(stderr, "1b. Recursive division failed at sizes %d / %d, wrong quotient\n",
10 * size, size);
goto LBL_ERR;
}
if (mp_cmp(&c_r, &d_r) != MP_EQ) {
fprintf(stderr, "1b. Recursive division failed at sizes %d / %d, wrong remainder\n",
10 * size, size);
goto LBL_ERR;
}
EXPECT(mp_cmp(&c_q, &d_q) == MP_EQ);
EXPECT(mp_cmp(&c_r, &d_r) == MP_EQ);
printf("\rsizes = %d / %d, negative numerator", 2 * size, size);
/* Relation 10:1 negative denominator*/
@ -2118,16 +1891,8 @@ static int test_s_mp_div_recursive(void)
DO(mp_neg(&b, &b));
DO(s_mp_div_recursive(&a, &b, &c_q, &c_r));
DO(s_mp_div_school(&a, &b, &d_q, &d_r));
if (mp_cmp(&c_q, &d_q) != MP_EQ) {
fprintf(stderr, "1c. Recursive division failed at sizes %d / %d, wrong quotient\n",
10 * size, size);
goto LBL_ERR;
}
if (mp_cmp(&c_r, &d_r) != MP_EQ) {
fprintf(stderr, "1c. Recursive division failed at sizes %d / %d, wrong remainder\n",
10 * size, size);
goto LBL_ERR;
}
EXPECT(mp_cmp(&c_q, &d_q) == MP_EQ);
EXPECT(mp_cmp(&c_r, &d_r) == MP_EQ);
printf("\rsizes = %d / %d, negative denominator", 2 * size, size);
/* Relation 2:1 */
@ -2135,32 +1900,16 @@ static int test_s_mp_div_recursive(void)
DO(mp_rand(&b, size));
DO(s_mp_div_recursive(&a, &b, &c_q, &c_r));
DO(s_mp_div_school(&a, &b, &d_q, &d_r));
if (mp_cmp(&c_q, &d_q) != MP_EQ) {
fprintf(stderr, "2. Recursive division failed at sizes %d / %d, wrong quotient\n",
2 * size, size);
goto LBL_ERR;
}
if (mp_cmp(&c_r, &d_r) != MP_EQ) {
fprintf(stderr, "2. Recursive division failed at sizes %d / %d, wrong remainder\n",
2 * size, size);
goto LBL_ERR;
}
EXPECT(mp_cmp(&c_q, &d_q) == MP_EQ);
EXPECT(mp_cmp(&c_r, &d_r) == MP_EQ);
printf("\rsizes = %d / %d", 3 * size, 2 * size);
/* Upper limit 3:2 */
DO(mp_rand(&a, 3 * size));
DO(mp_rand(&b, 2 * size));
DO(s_mp_div_recursive(&a, &b, &c_q, &c_r));
DO(s_mp_div_school(&a, &b, &d_q, &d_r));
if (mp_cmp(&c_q, &d_q) != MP_EQ) {
fprintf(stderr, "3. Recursive division failed at sizes %d / %d, wrong quotient\n",
3 * size, 2 * size);
goto LBL_ERR;
}
if (mp_cmp(&c_r, &d_r) != MP_EQ) {
fprintf(stderr, "3. Recursive division failed at sizes %d / %d, wrong remainder\n",
3 * size, 2 * size);
goto LBL_ERR;
}
EXPECT(mp_cmp(&c_q, &d_q) == MP_EQ);
EXPECT(mp_cmp(&c_r, &d_r) == MP_EQ);
}
mp_clear_multi(&a, &b, &c_q, &c_r, &d_q, &d_r, NULL);
@ -2183,16 +1932,8 @@ static int test_s_mp_div_small(void)
DO(mp_rand(&b, size));
DO(s_mp_div_small(&a, &b, &c_q, &c_r));
DO(s_mp_div_school(&a, &b, &d_q, &d_r));
if (mp_cmp(&c_q, &d_q) != MP_EQ) {
fprintf(stderr, "1. Small division failed at sizes %d / %d, wrong quotient\n",
2 * size, size);
goto LBL_ERR;
}
if (mp_cmp(&c_r, &d_r) != MP_EQ) {
fprintf(stderr, "1. Small division failed at sizes %d / %d, wrong remainder\n",
2 * size, size);
goto LBL_ERR;
}
EXPECT(mp_cmp(&c_q, &d_q) == MP_EQ);
EXPECT(mp_cmp(&c_r, &d_r) == MP_EQ);
}
mp_clear_multi(&a, &b, &c_q, &c_r, &d_q, &d_r, NULL);
return EXIT_SUCCESS;
@ -2205,10 +1946,9 @@ LBL_ERR:
static int test_s_mp_radix_size_overestimate(void)
{
mp_err err;
mp_int a;
int radix;
size_t size;
int radix, n;
size_t size, size2;
/* *INDENT-OFF* */
size_t results[65] = {
0u, 0u, 1627u, 1027u, 814u, 702u, 630u, 581u, 543u,
@ -2220,78 +1960,43 @@ static int test_s_mp_radix_size_overestimate(void)
284u, 283u, 281u, 280u, 279u, 278u, 277u, 276u, 275u,
273u, 272u
};
size_t big_results[65] = {
0u, 0u, 0u, 1354911329u, 1073741825u,
924870867u, 830760078u, 764949110u, 715827883u, 677455665u,
646456994u, 620761988u, 599025415u, 580332018u, 564035582u,
549665673u, 536870913u, 525383039u, 514993351u, 505536793u,
496880930u, 488918137u, 481559946u, 474732892u, 468375401u,
462435434u, 456868672u, 451637110u, 446707948u, 442052707u,
437646532u, 433467613u, 429496730u, 425716865u, 422112892u,
418671312u, 415380039u, 412228213u, 409206043u, 406304679u,
403516096u, 400833001u, 398248746u, 395757256u, 393352972u,
391030789u, 388786017u, 386614331u, 384511740u, 382474555u,
380499357u, 378582973u, 376722456u, 374915062u, 373158233u,
371449582u, 369786879u, 368168034u, 366591092u, 365054217u,
363555684u, 362093873u, 360667257u, 359274399u, 357913942
};
/* *INDENT-ON* */
if ((err = mp_init(&a)) != MP_OKAY) goto LBL_ERR;
DO(mp_init(&a));
/* number to result in a different size for every base: 67^(4 * 67) */
mp_set(&a, 67);
if ((err = mp_expt_n(&a, 268, &a)) != MP_OKAY) {
goto LBL_ERR;
}
DO(mp_expt_n(&a, 268, &a));
for (radix = 2; radix < 65; radix++) {
if ((err = s_mp_radix_size_overestimate(&a, radix, &size)) != MP_OKAY) {
goto LBL_ERR;
}
if (size < results[radix]) {
fprintf(stderr, "s_mp_radix_size_overestimate: result for base %d was %zu instead of %zu\n",
radix, size, results[radix]);
goto LBL_ERR;
}
DO(s_mp_radix_size_overestimate(&a, radix, &size));
EXPECT(size >= results[radix]);
EXPECT(size < results[radix] + 20); /* some error bound */
a.sign = MP_NEG;
if ((err = s_mp_radix_size_overestimate(&a, radix, &size)) != MP_OKAY) {
goto LBL_ERR;
}
if (size < results[radix]) {
fprintf(stderr, "s_mp_radix_size_overestimate: result for base %d was %zu instead of %zu\n",
radix, size, results[radix]);
goto LBL_ERR;
}
DO(s_mp_radix_size_overestimate(&a, radix, &size));
EXPECT(size >= results[radix]);
EXPECT(size < results[radix] + 20); /* some error bound */
a.sign = MP_ZPOS;
}
if ((err = mp_2expt(&a, INT_MAX - 1)) != MP_OKAY) {
goto LBL_ERR;
}
printf("bitcount = %d, alloc = %d\n", mp_count_bits(&a), a.alloc);
/* Start at 3 to avoid integer overflow */
for (radix = 3; radix < 65; radix++) {
printf("radix = %d, ",radix);
if ((err = s_mp_radix_size_overestimate(&a, radix, &size)) != MP_OKAY) {
goto LBL_ERR;
}
printf("size = %zu, diff = %zu\n", size, size - big_results[radix]);
if (size < big_results[radix]) {
fprintf(stderr, "s_mp_radix_size_overestimate: result for base %d was %zu instead of %zu\n",
radix, size, results[radix]);
goto LBL_ERR;
}
a.sign = MP_NEG;
if ((err = s_mp_radix_size_overestimate(&a, radix, &size)) != MP_OKAY) {
goto LBL_ERR;
}
if (size < big_results[radix]) {
fprintf(stderr, "s_mp_radix_size_overestimate: result for base %d was %zu instead of %zu\n",
radix, size, results[radix]);
goto LBL_ERR;
}
a.sign = MP_ZPOS;
/* randomized test */
for (n = 1; n < 1024; n += 1234) {
DO(mp_rand(&a, n));
for (radix = 2; radix < 65; radix++) {
DO(s_mp_radix_size_overestimate(&a, radix, &size));
DO(mp_radix_size(&a, radix, &size2));
EXPECT(size >= size2);
EXPECT(size < size2 + 20); /* some error bound */
a.sign = MP_NEG;
DO(s_mp_radix_size_overestimate(&a, radix, &size));
DO(mp_radix_size(&a, radix, &size2));
EXPECT(size >= size2);
EXPECT(size < size2 + 20); /* some error bound */
a.sign = MP_ZPOS;
}
}
mp_clear(&a);
return EXIT_SUCCESS;
LBL_ERR:

View File

@ -172,6 +172,7 @@ c89:
-e 's/\(PRI[iux]64\)/MP_\1/g' \
-e 's/uint\([0-9][0-9]*\)_t/mp_u\1/g' \
-e 's/int\([0-9][0-9]*\)_t/mp_i\1/g' \
-e 's/__func__/MP_FUNCTION_NAME/g' \
*.c tommath.h tommath_private.h demo/*.c demo/*.h etc/*.c
c99:
@ -194,6 +195,7 @@ c99:
-e 's/MP_\(PRI[iux]64\)/\1/g' \
-e 's/mp_u\([0-9][0-9]*\)/uint\1_t/g' \
-e 's/mp_i\([0-9][0-9]*\)/int\1_t/g' \
-e 's/MP_FUNCTION_NAME/__func__/g' \
*.c tommath.h tommath_private.h demo/*.c demo/*.h etc/*.c
astyle:

View File

@ -19,6 +19,8 @@ const char *mp_error_to_string(mp_err code)
return "Max. iterations reached";
case MP_BUF:
return "Buffer overflow";
case MP_OVF:
return "Integer overflow";
default:
return "Invalid error code";
}

View File

@ -8,15 +8,21 @@ mp_err mp_grow(mp_int *a, int size)
{
/* if the alloc size is smaller alloc more ram */
if (a->alloc < size) {
mp_digit *dp;
if (size > MP_MAX_DIGIT_COUNT) {
return MP_OVF;
}
/* reallocate the array a->dp
*
* We store the return in a temporary variable
* in case the operation failed we don't want
* to overwrite the dp member of a.
*/
mp_digit *dp = (mp_digit *) MP_REALLOC(a->dp,
(size_t)a->alloc * sizeof(mp_digit),
(size_t)size * sizeof(mp_digit));
dp = (mp_digit *) MP_REALLOC(a->dp,
(size_t)a->alloc * sizeof(mp_digit),
(size_t)size * sizeof(mp_digit));
if (dp == NULL) {
/* reallocation failed but "a" is still valid [can be freed] */
return MP_MEM;

View File

@ -7,7 +7,7 @@
mp_err mp_init(mp_int *a)
{
/* allocate memory required and clear it */
a->dp = (mp_digit *) MP_CALLOC((size_t)MP_PREC, sizeof(mp_digit));
a->dp = (mp_digit *) MP_CALLOC((size_t)MP_DEFAULT_DIGIT_COUNT, sizeof(mp_digit));
if (a->dp == NULL) {
return MP_MEM;
}
@ -15,7 +15,7 @@ mp_err mp_init(mp_int *a)
/* set the used to zero, allocated digits to the default precision
* and sign to positive */
a->used = 0;
a->alloc = MP_PREC;
a->alloc = MP_DEFAULT_DIGIT_COUNT;
a->sign = MP_ZPOS;
return MP_OKAY;

View File

@ -7,14 +7,15 @@
mp_err mp_init_multi(mp_int *mp, ...)
{
mp_err err = MP_OKAY; /* Assume ok until proven otherwise */
mp_err err = MP_OKAY;
int n = 0; /* Number of ok inits */
mp_int *cur_arg = mp;
va_list args;
va_start(args, mp); /* init args to next argument from caller */
while (cur_arg != NULL) {
if (mp_init(cur_arg) != MP_OKAY) {
err = mp_init(cur_arg);
if (err != MP_OKAY) {
/* Oops - error! Back-track and mp_clear what we already
succeeded in init-ing, then return error.
*/
@ -28,14 +29,13 @@ mp_err mp_init_multi(mp_int *mp, ...)
cur_arg = va_arg(clean_args, mp_int *);
}
va_end(clean_args);
err = MP_MEM;
break;
}
n++;
cur_arg = va_arg(args, mp_int *);
}
va_end(args);
return err; /* Assumed ok, if error flagged above. */
return err;
}
#endif

View File

@ -6,7 +6,11 @@
/* init an mp_init for a given size */
mp_err mp_init_size(mp_int *a, int size)
{
size = MP_MAX(MP_MIN_PREC, size);
size = MP_MAX(MP_MIN_DIGIT_COUNT, size);
if (size > MP_MAX_DIGIT_COUNT) {
return MP_OVF;
}
/* alloc mem */
a->dp = (mp_digit *) MP_CALLOC((size_t)size, sizeof(mp_digit));

View File

@ -6,7 +6,7 @@
/* shrink a bignum */
mp_err mp_shrink(mp_int *a)
{
int alloc = MP_MAX(MP_MIN_PREC, a->used);
int alloc = MP_MAX(MP_MIN_DIGIT_COUNT, a->used);
if (a->alloc != alloc) {
mp_digit *dp = (mp_digit *) MP_REALLOC(a->dp,
(size_t)a->alloc * sizeof(mp_digit),

View File

@ -52,12 +52,7 @@ mp_err s_mp_radix_size_overestimate(const mp_int *a, const int radix, size_t *si
if (MP_HAS(S_MP_LOG_2EXPT) && MP_IS_2EXPT((mp_digit)radix)) {
/* floor(log_{2^n}(a)) + 1 + EOS + sign */
*size = (size_t)(s_mp_log_2expt(a, (mp_digit)radix));
/* Would overflow with base 2 otherwise */
if (*size > (INT_MAX - 4)) {
return MP_VAL;
}
*size += 3u;
*size = (size_t)(s_mp_log_2expt(a, (mp_digit)radix) + 3);
return MP_OKAY;
}

View File

@ -100,7 +100,8 @@ typedef enum {
MP_MEM = -2, /* out of mem */
MP_VAL = -3, /* invalid input */
MP_ITER = -4, /* maximum iterations reached */
MP_BUF = -5 /* buffer overflow, supplied buffer too small */
MP_BUF = -5, /* buffer overflow, supplied buffer too small */
MP_OVF = -6 /* mp_int overflow, too many digits */
} mp_err;
typedef enum {

View File

@ -36,3 +36,5 @@ typedef __UINT64_TYPE__ mp_u64;
#define MP_PRIi64 MP_PRI64_PREFIX "i"
#define MP_PRIu64 MP_PRI64_PREFIX "u"
#define MP_PRIx64 MP_PRI64_PREFIX "x"
#define MP_FUNCTION_NAME __func__

View File

@ -140,22 +140,29 @@ typedef uint64_t mp_word;
MP_STATIC_ASSERT(correct_word_size, sizeof(mp_word) == (2u * sizeof(mp_digit)))
/* default precision */
#ifndef MP_PREC
/* default number of digits */
#ifndef MP_DEFAULT_DIGIT_COUNT
# ifndef MP_LOW_MEM
# define MP_PREC 32 /* default digits of precision */
# define MP_DEFAULT_DIGIT_COUNT 32
# else
# define MP_PREC 8 /* default digits of precision */
# define MP_DEFAULT_DIGIT_COUNT 8
# endif
#endif
/* Minimum number of available digits in mp_int, MP_PREC >= MP_MIN_PREC
/* Minimum number of available digits in mp_int, MP_DEFAULT_DIGIT_COUNT >= MP_MIN_DIGIT_COUNT
* - Must be at least 3 for s_mp_div_school.
* - Must be large enough such that the mp_set_u64 setter can
* store uint64_t in the mp_int without growing
*/
#define MP_MIN_PREC MP_MAX(3, (((int)MP_SIZEOF_BITS(uint64_t) + MP_DIGIT_BIT) - 1) / MP_DIGIT_BIT)
MP_STATIC_ASSERT(prec_geq_min_prec, MP_PREC >= MP_MIN_PREC)
#define MP_MIN_DIGIT_COUNT MP_MAX(3, (((int)MP_SIZEOF_BITS(uint64_t) + MP_DIGIT_BIT) - 1) / MP_DIGIT_BIT)
MP_STATIC_ASSERT(prec_geq_min_prec, MP_DEFAULT_DIGIT_COUNT >= MP_MIN_DIGIT_COUNT)
/* Maximum number of digits.
* - Must be small enough such that mp_bit_count does not overflow.
* - Must be small enough such that mp_radix_size for base 2 does not overflow.
* mp_radix_size needs two additional bytes for zero termination and sign.
*/
#define MP_MAX_DIGIT_COUNT ((INT_MAX - 2) / MP_DIGIT_BIT)
/* random number source */
extern MP_PRIVATE mp_err(*s_mp_rand_source)(void *out, size_t size);