simplifications: toom and karatsuba

This commit is contained in:
Daniel Mendler 2019-10-29 20:05:30 +01:00
parent 143e0376a1
commit 7b6c6965bb
No known key found for this signature in database
GPG Key ID: D88ADB2A2693CA43
4 changed files with 71 additions and 114 deletions

View File

@ -6,14 +6,10 @@
/* single-digit multiplication with the smaller number as the single-digit */
mp_err s_mp_balance_mul(const mp_int *a, const mp_int *b, mp_int *c)
{
int count, len_a, len_b, nblocks, i, j, bsize;
mp_int a0, tmp, A, B, r;
mp_int a0, tmp, r;
mp_err err;
len_a = a->used;
len_b = b->used;
nblocks = MP_MAX(a->used, b->used) / MP_MIN(a->used, b->used);
int i, j, count,
nblocks = MP_MAX(a->used, b->used) / MP_MIN(a->used, b->used),
bsize = MP_MIN(a->used, b->used);
if ((err = mp_init_size(&a0, bsize + 2)) != MP_OKAY) {
@ -25,24 +21,20 @@ mp_err s_mp_balance_mul(const mp_int *a, const mp_int *b, mp_int *c)
}
/* Make sure that A is the larger one*/
if (len_a < len_b) {
B = *a;
A = *b;
} else {
A = *a;
B = *b;
if (a->used < b->used) {
MP_EXCH(const mp_int *, a, b);
}
for (i = 0, j=0; i < nblocks; i++) {
/* Cut a slice off of a */
a0.used = 0;
for (count = 0; count < bsize; count++) {
a0.dp[count] = A.dp[ j++ ];
a0.dp[count] = a->dp[ j++ ];
a0.used++;
}
mp_clamp(&a0);
/* Multiply with b */
if ((err = mp_mul(&a0, &B, &tmp)) != MP_OKAY) {
if ((err = mp_mul(&a0, b, &tmp)) != MP_OKAY) {
goto LBL_ERR;
}
/* Shift tmp to the correct position */
@ -55,14 +47,14 @@ mp_err s_mp_balance_mul(const mp_int *a, const mp_int *b, mp_int *c)
}
}
/* The left-overs; there are always left-overs */
if (j < A.used) {
if (j < a->used) {
a0.used = 0;
for (count = 0; j < A.used; count++) {
a0.dp[count] = A.dp[ j++ ];
for (count = 0; j < a->used; count++) {
a0.dp[count] = a->dp[ j++ ];
a0.used++;
}
mp_clamp(&a0);
if ((err = mp_mul(&a0, &B, &tmp)) != MP_OKAY) {
if ((err = mp_mul(&a0, b, &tmp)) != MP_OKAY) {
goto LBL_ERR;
}
if ((err = mp_lshd(&tmp, bsize * i)) != MP_OKAY) {

View File

@ -35,8 +35,8 @@
mp_err s_mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c)
{
mp_int x0, x1, y0, y1, t1, x0y0, x1y1;
int B;
mp_err err = MP_MEM; /* default the return code to an error */
int B, i;
mp_err err;
/* min # of digits */
B = MP_MIN(a->used, b->used);
@ -45,27 +45,27 @@ mp_err s_mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c)
B = B >> 1;
/* init copy all the temps */
if (mp_init_size(&x0, B) != MP_OKAY) {
if ((err = mp_init_size(&x0, B)) != MP_OKAY) {
goto LBL_ERR;
}
if (mp_init_size(&x1, a->used - B) != MP_OKAY) {
if ((err = mp_init_size(&x1, a->used - B)) != MP_OKAY) {
goto X0;
}
if (mp_init_size(&y0, B) != MP_OKAY) {
if ((err = mp_init_size(&y0, B)) != MP_OKAY) {
goto X1;
}
if (mp_init_size(&y1, b->used - B) != MP_OKAY) {
if ((err = mp_init_size(&y1, b->used - B)) != MP_OKAY) {
goto Y0;
}
/* init temps */
if (mp_init_size(&t1, B * 2) != MP_OKAY) {
if ((err = mp_init_size(&t1, B * 2)) != MP_OKAY) {
goto Y1;
}
if (mp_init_size(&x0y0, B * 2) != MP_OKAY) {
if ((err = mp_init_size(&x0y0, B * 2)) != MP_OKAY) {
goto T1;
}
if (mp_init_size(&x1y1, B * 2) != MP_OKAY) {
if ((err = mp_init_size(&x1y1, B * 2)) != MP_OKAY) {
goto X0Y0;
}
@ -74,32 +74,18 @@ mp_err s_mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c)
x1.used = a->used - B;
y1.used = b->used - B;
{
int x;
mp_digit *tmpa, *tmpb, *tmpx, *tmpy;
/* we copy the digits directly instead of using higher level functions
* since we also need to shift the digits
*/
tmpa = a->dp;
tmpb = b->dp;
tmpx = x0.dp;
tmpy = y0.dp;
for (x = 0; x < B; x++) {
*tmpx++ = *tmpa++;
*tmpy++ = *tmpb++;
for (i = 0; i < B; i++) {
x0.dp[i] = a->dp[i];
y0.dp[i] = b->dp[i];
}
tmpx = x1.dp;
for (x = B; x < a->used; x++) {
*tmpx++ = *tmpa++;
}
tmpy = y1.dp;
for (x = B; x < b->used; x++) {
*tmpy++ = *tmpb++;
for (i = B; i < a->used; i++) {
x1.dp[i - B] = a->dp[i];
}
for (i = B; i < b->used; i++) {
y1.dp[i - B] = b->dp[i];
}
/* only need to clamp the lower words since by definition the
@ -110,50 +96,47 @@ mp_err s_mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c)
/* now calc the products x0y0 and x1y1 */
/* after this x0 is no longer required, free temp [x0==t2]! */
if (mp_mul(&x0, &y0, &x0y0) != MP_OKAY) {
if ((err = mp_mul(&x0, &y0, &x0y0)) != MP_OKAY) {
goto X1Y1; /* x0y0 = x0*y0 */
}
if (mp_mul(&x1, &y1, &x1y1) != MP_OKAY) {
if ((err = mp_mul(&x1, &y1, &x1y1)) != MP_OKAY) {
goto X1Y1; /* x1y1 = x1*y1 */
}
/* now calc x1+x0 and y1+y0 */
if (s_mp_add(&x1, &x0, &t1) != MP_OKAY) {
if ((err = s_mp_add(&x1, &x0, &t1)) != MP_OKAY) {
goto X1Y1; /* t1 = x1 - x0 */
}
if (s_mp_add(&y1, &y0, &x0) != MP_OKAY) {
if ((err = s_mp_add(&y1, &y0, &x0)) != MP_OKAY) {
goto X1Y1; /* t2 = y1 - y0 */
}
if (mp_mul(&t1, &x0, &t1) != MP_OKAY) {
if ((err = mp_mul(&t1, &x0, &t1)) != MP_OKAY) {
goto X1Y1; /* t1 = (x1 + x0) * (y1 + y0) */
}
/* add x0y0 */
if (mp_add(&x0y0, &x1y1, &x0) != MP_OKAY) {
if ((err = mp_add(&x0y0, &x1y1, &x0)) != MP_OKAY) {
goto X1Y1; /* t2 = x0y0 + x1y1 */
}
if (s_mp_sub(&t1, &x0, &t1) != MP_OKAY) {
if ((err = s_mp_sub(&t1, &x0, &t1)) != MP_OKAY) {
goto X1Y1; /* t1 = (x1+x0)*(y1+y0) - (x1y1 + x0y0) */
}
/* shift by B */
if (mp_lshd(&t1, B) != MP_OKAY) {
if ((err = mp_lshd(&t1, B)) != MP_OKAY) {
goto X1Y1; /* t1 = (x0y0 + x1y1 - (x1-x0)*(y1-y0))<<B */
}
if (mp_lshd(&x1y1, B * 2) != MP_OKAY) {
if ((err = mp_lshd(&x1y1, B * 2)) != MP_OKAY) {
goto X1Y1; /* x1y1 = x1y1 << 2*B */
}
if (mp_add(&x0y0, &t1, &t1) != MP_OKAY) {
if ((err = mp_add(&x0y0, &t1, &t1)) != MP_OKAY) {
goto X1Y1; /* t1 = x0y0 + t1 */
}
if (mp_add(&t1, &x1y1, c) != MP_OKAY) {
if ((err = mp_add(&t1, &x1y1, c)) != MP_OKAY) {
goto X1Y1; /* t1 = x0y0 + t1 + x1y1 */
}
/* Algorithm succeeded set the return code to MP_OKAY */
err = MP_OKAY;
X1Y1:
mp_clear(&x1y1);
X0Y0:

View File

@ -13,8 +13,8 @@
mp_err s_mp_karatsuba_sqr(const mp_int *a, mp_int *b)
{
mp_int x0, x1, t1, t2, x0x0, x1x1;
int B;
mp_err err = MP_MEM;
int B, x;
mp_err err;
/* min # of digits */
B = a->used;
@ -23,37 +23,27 @@ mp_err s_mp_karatsuba_sqr(const mp_int *a, mp_int *b)
B = B >> 1;
/* init copy all the temps */
if (mp_init_size(&x0, B) != MP_OKAY)
if ((err = mp_init_size(&x0, B)) != MP_OKAY)
goto LBL_ERR;
if (mp_init_size(&x1, a->used - B) != MP_OKAY)
if ((err = mp_init_size(&x1, a->used - B)) != MP_OKAY)
goto X0;
/* init temps */
if (mp_init_size(&t1, a->used * 2) != MP_OKAY)
if ((err = mp_init_size(&t1, a->used * 2)) != MP_OKAY)
goto X1;
if (mp_init_size(&t2, a->used * 2) != MP_OKAY)
if ((err = mp_init_size(&t2, a->used * 2)) != MP_OKAY)
goto T1;
if (mp_init_size(&x0x0, B * 2) != MP_OKAY)
if ((err = mp_init_size(&x0x0, B * 2)) != MP_OKAY)
goto T2;
if (mp_init_size(&x1x1, (a->used - B) * 2) != MP_OKAY)
if ((err = mp_init_size(&x1x1, (a->used - B) * 2)) != MP_OKAY)
goto X0X0;
{
int x;
mp_digit *dst, *src;
src = a->dp;
/* now shift the digits */
dst = x0.dp;
for (x = 0; x < B; x++) {
*dst++ = *src++;
x0.dp[x] = a->dp[x];
}
dst = x1.dp;
for (x = B; x < a->used; x++) {
*dst++ = *src++;
}
x1.dp[x - B] = a->dp[x];
}
x0.used = B;
@ -62,36 +52,34 @@ mp_err s_mp_karatsuba_sqr(const mp_int *a, mp_int *b)
mp_clamp(&x0);
/* now calc the products x0*x0 and x1*x1 */
if (mp_sqr(&x0, &x0x0) != MP_OKAY)
if ((err = mp_sqr(&x0, &x0x0)) != MP_OKAY)
goto X1X1; /* x0x0 = x0*x0 */
if (mp_sqr(&x1, &x1x1) != MP_OKAY)
if ((err = mp_sqr(&x1, &x1x1)) != MP_OKAY)
goto X1X1; /* x1x1 = x1*x1 */
/* now calc (x1+x0)**2 */
if (s_mp_add(&x1, &x0, &t1) != MP_OKAY)
if ((err = s_mp_add(&x1, &x0, &t1)) != MP_OKAY)
goto X1X1; /* t1 = x1 - x0 */
if (mp_sqr(&t1, &t1) != MP_OKAY)
if ((err = mp_sqr(&t1, &t1)) != MP_OKAY)
goto X1X1; /* t1 = (x1 - x0) * (x1 - x0) */
/* add x0y0 */
if (s_mp_add(&x0x0, &x1x1, &t2) != MP_OKAY)
if ((err = s_mp_add(&x0x0, &x1x1, &t2)) != MP_OKAY)
goto X1X1; /* t2 = x0x0 + x1x1 */
if (s_mp_sub(&t1, &t2, &t1) != MP_OKAY)
if ((err = s_mp_sub(&t1, &t2, &t1)) != MP_OKAY)
goto X1X1; /* t1 = (x1+x0)**2 - (x0x0 + x1x1) */
/* shift by B */
if (mp_lshd(&t1, B) != MP_OKAY)
if ((err = mp_lshd(&t1, B)) != MP_OKAY)
goto X1X1; /* t1 = (x0x0 + x1x1 - (x1-x0)*(x1-x0))<<B */
if (mp_lshd(&x1x1, B * 2) != MP_OKAY)
if ((err = mp_lshd(&x1x1, B * 2)) != MP_OKAY)
goto X1X1; /* x1x1 = x1x1 << 2*B */
if (mp_add(&x0x0, &t1, &t1) != MP_OKAY)
if ((err = mp_add(&x0x0, &t1, &t1)) != MP_OKAY)
goto X1X1; /* t1 = x0x0 + t1 */
if (mp_add(&t1, &x1x1, b) != MP_OKAY)
if ((err = mp_add(&t1, &x1x1, b)) != MP_OKAY)
goto X1X1; /* t1 = x0x0 + t1 + x1x1 */
err = MP_OKAY;
X1X1:
mp_clear(&x1x1);
X0X0:

View File

@ -21,11 +21,9 @@
mp_err s_mp_toom_sqr(const mp_int *a, mp_int *b)
{
mp_int S0, a0, a1, a2;
mp_digit *tmpa, *tmpc;
int B, count;
mp_err err;
/* init temps */
if ((err = mp_init(&S0)) != MP_OKAY) {
return err;
@ -42,18 +40,14 @@ mp_err s_mp_toom_sqr(const mp_int *a, mp_int *b)
a1.used = B;
if ((err = mp_init_size(&a2, B + (a->used - (3 * B)))) != MP_OKAY) goto LBL_ERRa2;
tmpa = a->dp;
tmpc = a0.dp;
for (count = 0; count < B; count++) {
*tmpc++ = *tmpa++;
a0.dp[count] = a->dp[count];
}
tmpc = a1.dp;
for (; count < (2 * B); count++) {
*tmpc++ = *tmpa++;
a1.dp[count - B] = a->dp[count];
}
tmpc = a2.dp;
for (; count < a->used; count++) {
*tmpc++ = *tmpa++;
a2.dp[count - 2 * B] = a->dp[count];
a2.used++;
}
mp_clamp(&a0);