simplifications: toom and karatsuba
This commit is contained in:
parent
143e0376a1
commit
7b6c6965bb
@ -6,14 +6,10 @@
|
||||
/* single-digit multiplication with the smaller number as the single-digit */
|
||||
mp_err s_mp_balance_mul(const mp_int *a, const mp_int *b, mp_int *c)
|
||||
{
|
||||
int count, len_a, len_b, nblocks, i, j, bsize;
|
||||
mp_int a0, tmp, A, B, r;
|
||||
mp_int a0, tmp, r;
|
||||
mp_err err;
|
||||
|
||||
len_a = a->used;
|
||||
len_b = b->used;
|
||||
|
||||
nblocks = MP_MAX(a->used, b->used) / MP_MIN(a->used, b->used);
|
||||
int i, j, count,
|
||||
nblocks = MP_MAX(a->used, b->used) / MP_MIN(a->used, b->used),
|
||||
bsize = MP_MIN(a->used, b->used);
|
||||
|
||||
if ((err = mp_init_size(&a0, bsize + 2)) != MP_OKAY) {
|
||||
@ -25,24 +21,20 @@ mp_err s_mp_balance_mul(const mp_int *a, const mp_int *b, mp_int *c)
|
||||
}
|
||||
|
||||
/* Make sure that A is the larger one*/
|
||||
if (len_a < len_b) {
|
||||
B = *a;
|
||||
A = *b;
|
||||
} else {
|
||||
A = *a;
|
||||
B = *b;
|
||||
if (a->used < b->used) {
|
||||
MP_EXCH(const mp_int *, a, b);
|
||||
}
|
||||
|
||||
for (i = 0, j=0; i < nblocks; i++) {
|
||||
/* Cut a slice off of a */
|
||||
a0.used = 0;
|
||||
for (count = 0; count < bsize; count++) {
|
||||
a0.dp[count] = A.dp[ j++ ];
|
||||
a0.dp[count] = a->dp[ j++ ];
|
||||
a0.used++;
|
||||
}
|
||||
mp_clamp(&a0);
|
||||
/* Multiply with b */
|
||||
if ((err = mp_mul(&a0, &B, &tmp)) != MP_OKAY) {
|
||||
if ((err = mp_mul(&a0, b, &tmp)) != MP_OKAY) {
|
||||
goto LBL_ERR;
|
||||
}
|
||||
/* Shift tmp to the correct position */
|
||||
@ -55,14 +47,14 @@ mp_err s_mp_balance_mul(const mp_int *a, const mp_int *b, mp_int *c)
|
||||
}
|
||||
}
|
||||
/* The left-overs; there are always left-overs */
|
||||
if (j < A.used) {
|
||||
if (j < a->used) {
|
||||
a0.used = 0;
|
||||
for (count = 0; j < A.used; count++) {
|
||||
a0.dp[count] = A.dp[ j++ ];
|
||||
for (count = 0; j < a->used; count++) {
|
||||
a0.dp[count] = a->dp[ j++ ];
|
||||
a0.used++;
|
||||
}
|
||||
mp_clamp(&a0);
|
||||
if ((err = mp_mul(&a0, &B, &tmp)) != MP_OKAY) {
|
||||
if ((err = mp_mul(&a0, b, &tmp)) != MP_OKAY) {
|
||||
goto LBL_ERR;
|
||||
}
|
||||
if ((err = mp_lshd(&tmp, bsize * i)) != MP_OKAY) {
|
||||
|
@ -35,8 +35,8 @@
|
||||
mp_err s_mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c)
|
||||
{
|
||||
mp_int x0, x1, y0, y1, t1, x0y0, x1y1;
|
||||
int B;
|
||||
mp_err err = MP_MEM; /* default the return code to an error */
|
||||
int B, i;
|
||||
mp_err err;
|
||||
|
||||
/* min # of digits */
|
||||
B = MP_MIN(a->used, b->used);
|
||||
@ -45,27 +45,27 @@ mp_err s_mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c)
|
||||
B = B >> 1;
|
||||
|
||||
/* init copy all the temps */
|
||||
if (mp_init_size(&x0, B) != MP_OKAY) {
|
||||
if ((err = mp_init_size(&x0, B)) != MP_OKAY) {
|
||||
goto LBL_ERR;
|
||||
}
|
||||
if (mp_init_size(&x1, a->used - B) != MP_OKAY) {
|
||||
if ((err = mp_init_size(&x1, a->used - B)) != MP_OKAY) {
|
||||
goto X0;
|
||||
}
|
||||
if (mp_init_size(&y0, B) != MP_OKAY) {
|
||||
if ((err = mp_init_size(&y0, B)) != MP_OKAY) {
|
||||
goto X1;
|
||||
}
|
||||
if (mp_init_size(&y1, b->used - B) != MP_OKAY) {
|
||||
if ((err = mp_init_size(&y1, b->used - B)) != MP_OKAY) {
|
||||
goto Y0;
|
||||
}
|
||||
|
||||
/* init temps */
|
||||
if (mp_init_size(&t1, B * 2) != MP_OKAY) {
|
||||
if ((err = mp_init_size(&t1, B * 2)) != MP_OKAY) {
|
||||
goto Y1;
|
||||
}
|
||||
if (mp_init_size(&x0y0, B * 2) != MP_OKAY) {
|
||||
if ((err = mp_init_size(&x0y0, B * 2)) != MP_OKAY) {
|
||||
goto T1;
|
||||
}
|
||||
if (mp_init_size(&x1y1, B * 2) != MP_OKAY) {
|
||||
if ((err = mp_init_size(&x1y1, B * 2)) != MP_OKAY) {
|
||||
goto X0Y0;
|
||||
}
|
||||
|
||||
@ -74,32 +74,18 @@ mp_err s_mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c)
|
||||
x1.used = a->used - B;
|
||||
y1.used = b->used - B;
|
||||
|
||||
{
|
||||
int x;
|
||||
mp_digit *tmpa, *tmpb, *tmpx, *tmpy;
|
||||
|
||||
/* we copy the digits directly instead of using higher level functions
|
||||
* since we also need to shift the digits
|
||||
*/
|
||||
tmpa = a->dp;
|
||||
tmpb = b->dp;
|
||||
|
||||
tmpx = x0.dp;
|
||||
tmpy = y0.dp;
|
||||
for (x = 0; x < B; x++) {
|
||||
*tmpx++ = *tmpa++;
|
||||
*tmpy++ = *tmpb++;
|
||||
for (i = 0; i < B; i++) {
|
||||
x0.dp[i] = a->dp[i];
|
||||
y0.dp[i] = b->dp[i];
|
||||
}
|
||||
|
||||
tmpx = x1.dp;
|
||||
for (x = B; x < a->used; x++) {
|
||||
*tmpx++ = *tmpa++;
|
||||
}
|
||||
|
||||
tmpy = y1.dp;
|
||||
for (x = B; x < b->used; x++) {
|
||||
*tmpy++ = *tmpb++;
|
||||
for (i = B; i < a->used; i++) {
|
||||
x1.dp[i - B] = a->dp[i];
|
||||
}
|
||||
for (i = B; i < b->used; i++) {
|
||||
y1.dp[i - B] = b->dp[i];
|
||||
}
|
||||
|
||||
/* only need to clamp the lower words since by definition the
|
||||
@ -110,50 +96,47 @@ mp_err s_mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c)
|
||||
|
||||
/* now calc the products x0y0 and x1y1 */
|
||||
/* after this x0 is no longer required, free temp [x0==t2]! */
|
||||
if (mp_mul(&x0, &y0, &x0y0) != MP_OKAY) {
|
||||
if ((err = mp_mul(&x0, &y0, &x0y0)) != MP_OKAY) {
|
||||
goto X1Y1; /* x0y0 = x0*y0 */
|
||||
}
|
||||
if (mp_mul(&x1, &y1, &x1y1) != MP_OKAY) {
|
||||
if ((err = mp_mul(&x1, &y1, &x1y1)) != MP_OKAY) {
|
||||
goto X1Y1; /* x1y1 = x1*y1 */
|
||||
}
|
||||
|
||||
/* now calc x1+x0 and y1+y0 */
|
||||
if (s_mp_add(&x1, &x0, &t1) != MP_OKAY) {
|
||||
if ((err = s_mp_add(&x1, &x0, &t1)) != MP_OKAY) {
|
||||
goto X1Y1; /* t1 = x1 - x0 */
|
||||
}
|
||||
if (s_mp_add(&y1, &y0, &x0) != MP_OKAY) {
|
||||
if ((err = s_mp_add(&y1, &y0, &x0)) != MP_OKAY) {
|
||||
goto X1Y1; /* t2 = y1 - y0 */
|
||||
}
|
||||
if (mp_mul(&t1, &x0, &t1) != MP_OKAY) {
|
||||
if ((err = mp_mul(&t1, &x0, &t1)) != MP_OKAY) {
|
||||
goto X1Y1; /* t1 = (x1 + x0) * (y1 + y0) */
|
||||
}
|
||||
|
||||
/* add x0y0 */
|
||||
if (mp_add(&x0y0, &x1y1, &x0) != MP_OKAY) {
|
||||
if ((err = mp_add(&x0y0, &x1y1, &x0)) != MP_OKAY) {
|
||||
goto X1Y1; /* t2 = x0y0 + x1y1 */
|
||||
}
|
||||
if (s_mp_sub(&t1, &x0, &t1) != MP_OKAY) {
|
||||
if ((err = s_mp_sub(&t1, &x0, &t1)) != MP_OKAY) {
|
||||
goto X1Y1; /* t1 = (x1+x0)*(y1+y0) - (x1y1 + x0y0) */
|
||||
}
|
||||
|
||||
/* shift by B */
|
||||
if (mp_lshd(&t1, B) != MP_OKAY) {
|
||||
if ((err = mp_lshd(&t1, B)) != MP_OKAY) {
|
||||
goto X1Y1; /* t1 = (x0y0 + x1y1 - (x1-x0)*(y1-y0))<<B */
|
||||
}
|
||||
if (mp_lshd(&x1y1, B * 2) != MP_OKAY) {
|
||||
if ((err = mp_lshd(&x1y1, B * 2)) != MP_OKAY) {
|
||||
goto X1Y1; /* x1y1 = x1y1 << 2*B */
|
||||
}
|
||||
|
||||
if (mp_add(&x0y0, &t1, &t1) != MP_OKAY) {
|
||||
if ((err = mp_add(&x0y0, &t1, &t1)) != MP_OKAY) {
|
||||
goto X1Y1; /* t1 = x0y0 + t1 */
|
||||
}
|
||||
if (mp_add(&t1, &x1y1, c) != MP_OKAY) {
|
||||
if ((err = mp_add(&t1, &x1y1, c)) != MP_OKAY) {
|
||||
goto X1Y1; /* t1 = x0y0 + t1 + x1y1 */
|
||||
}
|
||||
|
||||
/* Algorithm succeeded set the return code to MP_OKAY */
|
||||
err = MP_OKAY;
|
||||
|
||||
X1Y1:
|
||||
mp_clear(&x1y1);
|
||||
X0Y0:
|
||||
|
@ -13,8 +13,8 @@
|
||||
mp_err s_mp_karatsuba_sqr(const mp_int *a, mp_int *b)
|
||||
{
|
||||
mp_int x0, x1, t1, t2, x0x0, x1x1;
|
||||
int B;
|
||||
mp_err err = MP_MEM;
|
||||
int B, x;
|
||||
mp_err err;
|
||||
|
||||
/* min # of digits */
|
||||
B = a->used;
|
||||
@ -23,37 +23,27 @@ mp_err s_mp_karatsuba_sqr(const mp_int *a, mp_int *b)
|
||||
B = B >> 1;
|
||||
|
||||
/* init copy all the temps */
|
||||
if (mp_init_size(&x0, B) != MP_OKAY)
|
||||
if ((err = mp_init_size(&x0, B)) != MP_OKAY)
|
||||
goto LBL_ERR;
|
||||
if (mp_init_size(&x1, a->used - B) != MP_OKAY)
|
||||
if ((err = mp_init_size(&x1, a->used - B)) != MP_OKAY)
|
||||
goto X0;
|
||||
|
||||
/* init temps */
|
||||
if (mp_init_size(&t1, a->used * 2) != MP_OKAY)
|
||||
if ((err = mp_init_size(&t1, a->used * 2)) != MP_OKAY)
|
||||
goto X1;
|
||||
if (mp_init_size(&t2, a->used * 2) != MP_OKAY)
|
||||
if ((err = mp_init_size(&t2, a->used * 2)) != MP_OKAY)
|
||||
goto T1;
|
||||
if (mp_init_size(&x0x0, B * 2) != MP_OKAY)
|
||||
if ((err = mp_init_size(&x0x0, B * 2)) != MP_OKAY)
|
||||
goto T2;
|
||||
if (mp_init_size(&x1x1, (a->used - B) * 2) != MP_OKAY)
|
||||
if ((err = mp_init_size(&x1x1, (a->used - B) * 2)) != MP_OKAY)
|
||||
goto X0X0;
|
||||
|
||||
{
|
||||
int x;
|
||||
mp_digit *dst, *src;
|
||||
|
||||
src = a->dp;
|
||||
|
||||
/* now shift the digits */
|
||||
dst = x0.dp;
|
||||
for (x = 0; x < B; x++) {
|
||||
*dst++ = *src++;
|
||||
x0.dp[x] = a->dp[x];
|
||||
}
|
||||
|
||||
dst = x1.dp;
|
||||
for (x = B; x < a->used; x++) {
|
||||
*dst++ = *src++;
|
||||
}
|
||||
x1.dp[x - B] = a->dp[x];
|
||||
}
|
||||
|
||||
x0.used = B;
|
||||
@ -62,36 +52,34 @@ mp_err s_mp_karatsuba_sqr(const mp_int *a, mp_int *b)
|
||||
mp_clamp(&x0);
|
||||
|
||||
/* now calc the products x0*x0 and x1*x1 */
|
||||
if (mp_sqr(&x0, &x0x0) != MP_OKAY)
|
||||
if ((err = mp_sqr(&x0, &x0x0)) != MP_OKAY)
|
||||
goto X1X1; /* x0x0 = x0*x0 */
|
||||
if (mp_sqr(&x1, &x1x1) != MP_OKAY)
|
||||
if ((err = mp_sqr(&x1, &x1x1)) != MP_OKAY)
|
||||
goto X1X1; /* x1x1 = x1*x1 */
|
||||
|
||||
/* now calc (x1+x0)**2 */
|
||||
if (s_mp_add(&x1, &x0, &t1) != MP_OKAY)
|
||||
if ((err = s_mp_add(&x1, &x0, &t1)) != MP_OKAY)
|
||||
goto X1X1; /* t1 = x1 - x0 */
|
||||
if (mp_sqr(&t1, &t1) != MP_OKAY)
|
||||
if ((err = mp_sqr(&t1, &t1)) != MP_OKAY)
|
||||
goto X1X1; /* t1 = (x1 - x0) * (x1 - x0) */
|
||||
|
||||
/* add x0y0 */
|
||||
if (s_mp_add(&x0x0, &x1x1, &t2) != MP_OKAY)
|
||||
if ((err = s_mp_add(&x0x0, &x1x1, &t2)) != MP_OKAY)
|
||||
goto X1X1; /* t2 = x0x0 + x1x1 */
|
||||
if (s_mp_sub(&t1, &t2, &t1) != MP_OKAY)
|
||||
if ((err = s_mp_sub(&t1, &t2, &t1)) != MP_OKAY)
|
||||
goto X1X1; /* t1 = (x1+x0)**2 - (x0x0 + x1x1) */
|
||||
|
||||
/* shift by B */
|
||||
if (mp_lshd(&t1, B) != MP_OKAY)
|
||||
if ((err = mp_lshd(&t1, B)) != MP_OKAY)
|
||||
goto X1X1; /* t1 = (x0x0 + x1x1 - (x1-x0)*(x1-x0))<<B */
|
||||
if (mp_lshd(&x1x1, B * 2) != MP_OKAY)
|
||||
if ((err = mp_lshd(&x1x1, B * 2)) != MP_OKAY)
|
||||
goto X1X1; /* x1x1 = x1x1 << 2*B */
|
||||
|
||||
if (mp_add(&x0x0, &t1, &t1) != MP_OKAY)
|
||||
if ((err = mp_add(&x0x0, &t1, &t1)) != MP_OKAY)
|
||||
goto X1X1; /* t1 = x0x0 + t1 */
|
||||
if (mp_add(&t1, &x1x1, b) != MP_OKAY)
|
||||
if ((err = mp_add(&t1, &x1x1, b)) != MP_OKAY)
|
||||
goto X1X1; /* t1 = x0x0 + t1 + x1x1 */
|
||||
|
||||
err = MP_OKAY;
|
||||
|
||||
X1X1:
|
||||
mp_clear(&x1x1);
|
||||
X0X0:
|
||||
|
@ -21,11 +21,9 @@
|
||||
mp_err s_mp_toom_sqr(const mp_int *a, mp_int *b)
|
||||
{
|
||||
mp_int S0, a0, a1, a2;
|
||||
mp_digit *tmpa, *tmpc;
|
||||
int B, count;
|
||||
mp_err err;
|
||||
|
||||
|
||||
/* init temps */
|
||||
if ((err = mp_init(&S0)) != MP_OKAY) {
|
||||
return err;
|
||||
@ -42,18 +40,14 @@ mp_err s_mp_toom_sqr(const mp_int *a, mp_int *b)
|
||||
a1.used = B;
|
||||
if ((err = mp_init_size(&a2, B + (a->used - (3 * B)))) != MP_OKAY) goto LBL_ERRa2;
|
||||
|
||||
tmpa = a->dp;
|
||||
tmpc = a0.dp;
|
||||
for (count = 0; count < B; count++) {
|
||||
*tmpc++ = *tmpa++;
|
||||
a0.dp[count] = a->dp[count];
|
||||
}
|
||||
tmpc = a1.dp;
|
||||
for (; count < (2 * B); count++) {
|
||||
*tmpc++ = *tmpa++;
|
||||
a1.dp[count - B] = a->dp[count];
|
||||
}
|
||||
tmpc = a2.dp;
|
||||
for (; count < a->used; count++) {
|
||||
*tmpc++ = *tmpa++;
|
||||
a2.dp[count - 2 * B] = a->dp[count];
|
||||
a2.used++;
|
||||
}
|
||||
mp_clamp(&a0);
|
||||
|
Loading…
Reference in New Issue
Block a user