suffix _u32 -> _n of mp_(expt|log|root) functions, use int for now

This commit is contained in:
Daniel Mendler 2019-10-29 20:52:29 +01:00 committed by Steffen Jaeckel
parent 86648a0d23
commit f6a7bedb95
14 changed files with 135 additions and 145 deletions

View File

@ -729,7 +729,7 @@ static int test_mp_sqrt(void)
printf("\nmp_sqrt() error!"); printf("\nmp_sqrt() error!");
goto LBL_ERR; goto LBL_ERR;
} }
DO(mp_root_u32(&a, 2u, &c)); DO(mp_root_n(&a, 2u, &c));
if (mp_cmp_mag(&b, &c) != MP_EQ) { if (mp_cmp_mag(&b, &c) != MP_EQ) {
printf("mp_sqrt() bad result!\n"); printf("mp_sqrt() bad result!\n");
goto LBL_ERR; goto LBL_ERR;
@ -1396,10 +1396,10 @@ LBL_ERR:
/* stripped down version of mp_radix_size. The faster version can be off by up t /* stripped down version of mp_radix_size. The faster version can be off by up t
o +3 */ o +3 */
/* TODO: This function should be removed, replaced by mp_radix_size, mp_radix_size_overestimate in 2.0 */ /* TODO: This function should be removed, replaced by mp_radix_size, mp_radix_size_overestimate in 2.0 */
static mp_err s_rs(const mp_int *a, int radix, uint32_t *size) static mp_err s_rs(const mp_int *a, int radix, int *size)
{ {
mp_err res; mp_err res;
uint32_t digs = 0u; int digs = 0u;
mp_int t; mp_int t;
mp_digit d; mp_digit d;
*size = 0u; *size = 0u;
@ -1408,7 +1408,7 @@ static mp_err s_rs(const mp_int *a, int radix, uint32_t *size)
return MP_OKAY; return MP_OKAY;
} }
if (radix == 2) { if (radix == 2) {
*size = (uint32_t)mp_count_bits(a) + 1u; *size = mp_count_bits(a) + 1;
return MP_OKAY; return MP_OKAY;
} }
DOR(mp_init_copy(&t, a)); DOR(mp_init_copy(&t, a));
@ -1424,12 +1424,12 @@ static mp_err s_rs(const mp_int *a, int radix, uint32_t *size)
*size = digs + 1; *size = digs + 1;
return MP_OKAY; return MP_OKAY;
} }
static int test_mp_log_u32(void) static int test_mp_log_n(void)
{ {
mp_int a; mp_int a;
mp_digit d; mp_digit d;
uint32_t base, lb, size; int base, lb, size;
const uint32_t max_base = MP_MIN(UINT32_MAX, MP_DIGIT_MAX); const int max_base = MP_MIN(INT_MAX, MP_DIGIT_MAX);
DOR(mp_init(&a)); DOR(mp_init(&a));
@ -1440,11 +1440,11 @@ static int test_mp_log_u32(void)
*/ */
mp_set(&a, 42u); mp_set(&a, 42u);
base = 0u; base = 0u;
if (mp_log_u32(&a, base, &lb) != MP_VAL) { if (mp_log_n(&a, base, &lb) != MP_VAL) {
goto LBL_ERR; goto LBL_ERR;
} }
base = 1u; base = 1u;
if (mp_log_u32(&a, base, &lb) != MP_VAL) { if (mp_log_n(&a, base, &lb) != MP_VAL) {
goto LBL_ERR; goto LBL_ERR;
} }
/* /*
@ -1456,14 +1456,14 @@ static int test_mp_log_u32(void)
*/ */
base = 2u; base = 2u;
mp_zero(&a); mp_zero(&a);
if (mp_log_u32(&a, base, &lb) != MP_VAL) { if (mp_log_n(&a, base, &lb) != MP_VAL) {
goto LBL_ERR; goto LBL_ERR;
} }
for (d = 1; d < 4; d++) { for (d = 1; d < 4; d++) {
mp_set(&a, d); mp_set(&a, d);
DO(mp_log_u32(&a, base, &lb)); DO(mp_log_n(&a, base, &lb));
if (lb != ((d == 1)?0uL:1uL)) { if (lb != ((d == 1)?0:1)) {
goto LBL_ERR; goto LBL_ERR;
} }
} }
@ -1476,13 +1476,13 @@ static int test_mp_log_u32(void)
*/ */
base = 3u; base = 3u;
mp_zero(&a); mp_zero(&a);
if (mp_log_u32(&a, base, &lb) != MP_VAL) { if (mp_log_n(&a, base, &lb) != MP_VAL) {
goto LBL_ERR; goto LBL_ERR;
} }
for (d = 1; d < 4; d++) { for (d = 1; d < 4; d++) {
mp_set(&a, d); mp_set(&a, d);
DO(mp_log_u32(&a, base, &lb)); DO(mp_log_n(&a, base, &lb));
if (lb != ((d < base)?0uL:1uL)) { if (lb != (((int)d < base)?0:1)) {
goto LBL_ERR; goto LBL_ERR;
} }
} }
@ -1493,8 +1493,8 @@ static int test_mp_log_u32(void)
radix_size. radix_size.
*/ */
DO(mp_rand(&a, 10)); DO(mp_rand(&a, 10));
for (base = 2u; base < 65u; base++) { for (base = 2; base < 65; base++) {
DO(mp_log_u32(&a, base, &lb)); DO(mp_log_n(&a, base, &lb));
DO(s_rs(&a,(int)base, &size)); DO(s_rs(&a,(int)base, &size));
/* radix_size includes the memory needed for '\0', too*/ /* radix_size includes the memory needed for '\0', too*/
size -= 2; size -= 2;
@ -1508,8 +1508,8 @@ static int test_mp_log_u32(void)
test the part of mp_ilogb that uses native types. test the part of mp_ilogb that uses native types.
*/ */
DO(mp_rand(&a, 1)); DO(mp_rand(&a, 1));
for (base = 2u; base < 65u; base++) { for (base = 2; base < 65; base++) {
DO(mp_log_u32(&a, base, &lb)); DO(mp_log_n(&a, base, &lb));
DO(s_rs(&a,(int)base, &size)); DO(s_rs(&a,(int)base, &size));
size -= 2; size -= 2;
if (lb != size) { if (lb != size) {
@ -1519,9 +1519,9 @@ static int test_mp_log_u32(void)
/*Test upper edgecase with base UINT32_MAX and number (UINT32_MAX/2)*UINT32_MAX^10 */ /*Test upper edgecase with base UINT32_MAX and number (UINT32_MAX/2)*UINT32_MAX^10 */
mp_set(&a, max_base); mp_set(&a, max_base);
DO(mp_expt_u32(&a, 10u, &a)); DO(mp_expt_n(&a, 10uL, &a));
DO(mp_add_d(&a, max_base / 2u, &a)); DO(mp_add_d(&a, max_base / 2, &a));
DO(mp_log_u32(&a, max_base, &lb)); DO(mp_log_n(&a, max_base, &lb));
if (lb != 10u) { if (lb != 10u) {
goto LBL_ERR; goto LBL_ERR;
} }
@ -1636,7 +1636,7 @@ LBL_ERR:
} }
/* /*
Cannot test mp_exp(_d) without mp_root and vice versa. Cannot test mp_exp(_d) without mp_root_n and vice versa.
So one of the two has to be tested from scratch. So one of the two has to be tested from scratch.
Numbers generated by Numbers generated by
@ -1658,7 +1658,7 @@ LBL_ERR:
low-mp branch. low-mp branch.
*/ */
static int test_mp_root_u32(void) static int test_mp_root_n(void)
{ {
mp_int a, c, r; mp_int a, c, r;
int i, j; int i, j;
@ -1850,10 +1850,10 @@ static int test_mp_root_u32(void)
for (i = 0; i < 10; i++) { for (i = 0; i < 10; i++) {
DO(mp_read_radix(&a, input[i], 64)); DO(mp_read_radix(&a, input[i], 64));
for (j = 3; j < 100; j++) { for (j = 3; j < 100; j++) {
DO(mp_root_u32(&a, (uint32_t)j, &c)); DO(mp_root_n(&a, j, &c));
DO(mp_read_radix(&r, root[i][j-3], 10)); DO(mp_read_radix(&r, root[i][j-3], 10));
if (mp_cmp(&r, &c) != MP_EQ) { if (mp_cmp(&r, &c) != MP_EQ) {
fprintf(stderr, "mp_root_u32 failed at input #%d, root #%d\n", i, j); fprintf(stderr, "mp_root_n failed at input #%d, root #%d\n", i, j);
goto LBL_ERR; goto LBL_ERR;
} }
} }
@ -2037,8 +2037,8 @@ static int test_mp_radix_size(void)
DOR(mp_init(&a)); DOR(mp_init(&a));
/* number to result in a different size for every base: 67^(4 * 67) */ /* number to result in a different size for every base: 67^(4 * 67) */
mp_set(&a, 67u); mp_set(&a, 67);
DO(mp_expt_u32(&a, 268u, &a)); DO(mp_expt_n(&a, 268, &a));
for (radix = 2; radix < 65; radix++) { for (radix = 2; radix < 65; radix++) {
DO(mp_radix_size(&a, radix, &size)); DO(mp_radix_size(&a, radix, &size));
@ -2304,13 +2304,13 @@ static int unit_tests(int argc, char **argv)
T1(mp_get_u32, MP_GET_I32), T1(mp_get_u32, MP_GET_I32),
T1(mp_get_u64, MP_GET_I64), T1(mp_get_u64, MP_GET_I64),
T1(mp_get_ul, MP_GET_L), T1(mp_get_ul, MP_GET_L),
T1(mp_log_u32, MP_LOG_U32), T1(mp_log_n, MP_LOG_N),
T1(mp_incr, MP_ADD_D), T1(mp_incr, MP_ADD_D),
T1(mp_invmod, MP_INVMOD), T1(mp_invmod, MP_INVMOD),
T1(mp_is_square, MP_IS_SQUARE), T1(mp_is_square, MP_IS_SQUARE),
T1(mp_kronecker, MP_KRONECKER), T1(mp_kronecker, MP_KRONECKER),
T1(mp_montgomery_reduce, MP_MONTGOMERY_REDUCE), T1(mp_montgomery_reduce, MP_MONTGOMERY_REDUCE),
T1(mp_root_u32, MP_ROOT_U32), T1(mp_root_n, MP_ROOT_N),
T1(mp_or, MP_OR), T1(mp_or, MP_OR),
T1(mp_prime_is_prime, MP_PRIME_IS_PRIME), T1(mp_prime_is_prime, MP_PRIME_IS_PRIME),
T1(mp_prime_next_prime, MP_PRIME_NEXT_PRIME), T1(mp_prime_next_prime, MP_PRIME_NEXT_PRIME),
@ -2326,7 +2326,7 @@ static int unit_tests(int argc, char **argv)
T1(mp_set_double, MP_SET_DOUBLE), T1(mp_set_double, MP_SET_DOUBLE),
#endif #endif
T1(mp_signed_rsh, MP_SIGNED_RSH), T1(mp_signed_rsh, MP_SIGNED_RSH),
T1(mp_sqrt, MP_SQRT), T2(mp_sqrt, MP_SQRT, mp_root_n),
T1(mp_sqrtmod_prime, MP_SQRTMOD_PRIME), T1(mp_sqrtmod_prime, MP_SQRTMOD_PRIME),
T1(mp_xor, MP_XOR), T1(mp_xor, MP_XOR),
T2(s_mp_div_recursive, S_MP_DIV_RECURSIVE, S_MP_DIV_SCHOOL), T2(s_mp_div_recursive, S_MP_DIV_RECURSIVE, S_MP_DIV_SCHOOL),

View File

@ -1911,9 +1911,9 @@ mp_err mp_sqrmod(const mp_int *a, const mp_int *b, const mp_int *c, mp_int *d);
\chapter{Exponentiation} \chapter{Exponentiation}
\section{Single Digit Exponentiation} \section{Single Digit Exponentiation}
\index{mp\_expt\_u32} \index{mp\_expt\_n}
\begin{alltt} \begin{alltt}
mp_err mp_expt_u32 (const mp_int *a, uint32_t b, mp_int *c) mp_err mp_expt_n(const mp_int *a, int b, int *c)
\end{alltt} \end{alltt}
This function computes $c = a^b$. This function computes $c = a^b$.
@ -1940,9 +1940,9 @@ mp_err mp_mod_2d(const mp_int *a, int b, mp_int *c)
It calculates $c = a \mod 2^b$. It calculates $c = a \mod 2^b$.
\section{Root Finding} \section{Root Finding}
\index{mp\_root\_u32} \index{mp\_root\_n}
\begin{alltt} \begin{alltt}
mp_err mp_root_u32(const mp_int *a, uint32_t b, mp_int *c) mp_err mp_root_n(const mp_int *a, int b, mp_int *c)
\end{alltt} \end{alltt}
This computes $c = a^{1/b}$ such that $c^b \le a$ and $(c+1)^b > a$. Will return a positive root This computes $c = a^{1/b}$ such that $c^b \le a$ and $(c+1)^b > a$. Will return a positive root
only for even roots and return a root with the sign of the input for odd roots. For example, only for even roots and return a root with the sign of the input for odd roots. For example,
@ -1964,9 +1964,9 @@ mp_err mp_sqrt(const mp_int *arg, mp_int *ret)
A logarithm function for positive integer input \texttt{a, base} computing $\floor{\log_bx}$ such A logarithm function for positive integer input \texttt{a, base} computing $\floor{\log_bx}$ such
that $(\log_b x)^b \le x$. that $(\log_b x)^b \le x$.
\index{mp\_log\_u32} \index{mp\_log\_n}
\begin{alltt} \begin{alltt}
mp_err mp_log_u32(const mp_int *a, uint32_t base, uint32_t *c) mp_err mp_log_n(const mp_int *a, int base, int *c)
\end{alltt} \end{alltt}
\subsection{Example} \subsection{Example}
@ -1981,7 +1981,7 @@ mp_err mp_log_u32(const mp_int *a, uint32_t base, uint32_t *c)
int main(int argc, char **argv) int main(int argc, char **argv)
{ {
mp_int x, output; mp_int x, output;
uint32_t base; int base;
mp_err e; mp_err e;
if (argc != 3) { if (argc != 3) {
@ -1994,12 +1994,8 @@ int main(int argc, char **argv)
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
errno = 0; errno = 0;
#ifdef MP_64BIT base = (int)strtoul(argv[1], NULL, 10);
/* Check for overflow skipped */
base = (uint32_t)strtoull(argv[1], NULL, 10);
#else
base = (uint32_t)strtoul(argv[1], NULL, 10);
#endif
if (errno == ERANGE) { if (errno == ERANGE) {
fprintf(stderr,"strtoul(l) failed: input out of range\textbackslash{}n"); fprintf(stderr,"strtoul(l) failed: input out of range\textbackslash{}n");
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
@ -2009,8 +2005,8 @@ int main(int argc, char **argv)
mp_error_to_string(e)); mp_error_to_string(e));
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
if ((e = mp_log_u32(&x, base, &output)) != MP_OKAY) { if ((e = mp_log_n(&x, base, &output)) != MP_OKAY) {
fprintf(stderr,"mp_ilogb failed: \textbackslash{}"%s\textbackslash{}"\textbackslash{}n", fprintf(stderr,"mp_log_n failed: \textbackslash{}"%s\textbackslash{}"\textbackslash{}n",
mp_error_to_string(e)); mp_error_to_string(e));
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }

View File

@ -1,13 +1,12 @@
#include "tommath_private.h" #include "tommath_private.h"
#ifdef MP_EXPT_U32_C #ifdef MP_EXPT_N_C
/* LibTomMath, multiple-precision integer library -- Tom St Denis */ /* LibTomMath, multiple-precision integer library -- Tom St Denis */
/* SPDX-License-Identifier: Unlicense */ /* SPDX-License-Identifier: Unlicense */
/* calculate c = a**b using a square-multiply algorithm */ /* calculate c = a**b using a square-multiply algorithm */
mp_err mp_expt_u32(const mp_int *a, uint32_t b, mp_int *c) mp_err mp_expt_n(const mp_int *a, int b, mp_int *c)
{ {
mp_err err; mp_err err;
mp_int g; mp_int g;
if ((err = mp_init_copy(&g, a)) != MP_OKAY) { if ((err = mp_init_copy(&g, a)) != MP_OKAY) {
@ -17,16 +16,16 @@ mp_err mp_expt_u32(const mp_int *a, uint32_t b, mp_int *c)
/* set initial result */ /* set initial result */
mp_set(c, 1uL); mp_set(c, 1uL);
while (b > 0u) { while (b > 0) {
/* if the bit is set multiply */ /* if the bit is set multiply */
if ((b & 1u) != 0u) { if ((b & 1) != 0) {
if ((err = mp_mul(c, &g, c)) != MP_OKAY) { if ((err = mp_mul(c, &g, c)) != MP_OKAY) {
goto LBL_ERR; goto LBL_ERR;
} }
} }
/* square */ /* square */
if (b > 1u) { if (b > 1) {
if ((err = mp_sqr(&g, &g)) != MP_OKAY) { if ((err = mp_sqr(&g, &g)) != MP_OKAY) {
goto LBL_ERR; goto LBL_ERR;
} }
@ -36,8 +35,6 @@ mp_err mp_expt_u32(const mp_int *a, uint32_t b, mp_int *c)
b >>= 1; b >>= 1;
} }
err = MP_OKAY;
LBL_ERR: LBL_ERR:
mp_clear(&g); mp_clear(&g);
return err; return err;

29
mp_log_n.c Normal file
View File

@ -0,0 +1,29 @@
#include "tommath_private.h"
#ifdef MP_LOG_N_C
/* LibTomMath, multiple-precision integer library -- Tom St Denis */
/* SPDX-License-Identifier: Unlicense */
mp_err mp_log_n(const mp_int *a, int base, int *c)
{
if (mp_isneg(a) || mp_iszero(a) || (base < 2) || (unsigned)base > (unsigned)MP_DIGIT_MAX) {
return MP_VAL;
}
if (MP_HAS(S_MP_LOG_2EXPT) && MP_IS_2EXPT((mp_digit)base)) {
*c = s_mp_log_2expt(a, (mp_digit)base);
return MP_OKAY;
}
if (MP_HAS(S_MP_LOG_D) && (a->used == 1)) {
*c = s_mp_log_d((mp_digit)base, a->dp[0]);
return MP_OKAY;
}
if (MP_HAS(S_MP_LOG)) {
return s_mp_log(a, (mp_digit)base, c);
}
return MP_VAL;
}
#endif

View File

@ -1,29 +0,0 @@
#include "tommath_private.h"
#ifdef MP_LOG_U32_C
/* LibTomMath, multiple-precision integer library -- Tom St Denis */
/* SPDX-License-Identifier: Unlicense */
mp_err mp_log_u32(const mp_int *a, uint32_t base, uint32_t *c)
{
if (mp_isneg(a) || mp_iszero(a) || (base < 2u)) {
return MP_VAL;
}
if (MP_HAS(S_MP_LOG_POW2) && MP_IS_2EXPT(base)) {
*c = s_mp_log_pow2(a, base);
return MP_OKAY;
}
if (MP_HAS(S_MP_LOG_D) && (a->used == 1)) {
*c = (uint32_t)s_mp_log_d(base, a->dp[0]);
return MP_OKAY;
}
if (MP_HAS(S_MP_LOG)) {
return s_mp_log(a, base, c);
}
return MP_VAL;
}
#endif

View File

@ -8,7 +8,7 @@ mp_err mp_radix_size(const mp_int *a, int radix, size_t *size)
{ {
mp_err err; mp_err err;
mp_int a_; mp_int a_;
uint32_t b; int b;
/* make sure the radix is in range */ /* make sure the radix is in range */
if ((radix < 2) || (radix > 64)) { if ((radix < 2) || (radix > 64)) {
@ -22,14 +22,13 @@ mp_err mp_radix_size(const mp_int *a, int radix, size_t *size)
a_ = *a; a_ = *a;
a_.sign = MP_ZPOS; a_.sign = MP_ZPOS;
if ((err = mp_log_u32(&a_, (uint32_t)radix, &b)) != MP_OKAY) { if ((err = mp_log_n(&a_, radix, &b)) != MP_OKAY) {
goto LBL_ERR; return err;
} }
/* mp_ilogb truncates to zero, hence we need one extra put on top and one for `\0`. */ /* mp_ilogb truncates to zero, hence we need one extra put on top and one for `\0`. */
*size = (size_t)b + 2U + (mp_isneg(a) ? 1U : 0U); *size = (size_t)b + 2U + (mp_isneg(a) ? 1U : 0U);
LBL_ERR: return MP_OKAY;
return err;
} }
#endif #endif

View File

@ -1,5 +1,5 @@
#include "tommath_private.h" #include "tommath_private.h"
#ifdef MP_ROOT_U32_C #ifdef MP_ROOT_N_C
/* LibTomMath, multiple-precision integer library -- Tom St Denis */ /* LibTomMath, multiple-precision integer library -- Tom St Denis */
/* SPDX-License-Identifier: Unlicense */ /* SPDX-License-Identifier: Unlicense */
@ -12,15 +12,18 @@
* which will find the root in log(N) time where * which will find the root in log(N) time where
* each step involves a fair bit. * each step involves a fair bit.
*/ */
mp_err mp_root_u32(const mp_int *a, uint32_t b, mp_int *c) mp_err mp_root_n(const mp_int *a, int b, mp_int *c)
{ {
mp_int t1, t2, t3, a_; mp_int t1, t2, t3, a_;
mp_ord cmp;
int ilog2; int ilog2;
mp_err err; mp_err err;
if (b < 0 || (unsigned)b > (unsigned)MP_DIGIT_MAX) {
return MP_VAL;
}
/* input must be positive if b is even */ /* input must be positive if b is even */
if (((b & 1u) == 0u) && mp_isneg(a)) { if (((b & 1) == 0) && mp_isneg(a)) {
return MP_VAL; return MP_VAL;
} }
@ -40,7 +43,7 @@ mp_err mp_root_u32(const mp_int *a, uint32_t b, mp_int *c)
log_2(n) because the bit-length of the "n" is measured log_2(n) because the bit-length of the "n" is measured
with an int and hence the root is always < 2 (two). with an int and hence the root is always < 2 (two).
*/ */
if (b > (uint32_t)(INT_MAX/2)) { if (b > INT_MAX/2) {
mp_set(c, 1uL); mp_set(c, 1uL);
c->sign = a->sign; c->sign = a->sign;
err = MP_OKAY; err = MP_OKAY;
@ -48,13 +51,13 @@ mp_err mp_root_u32(const mp_int *a, uint32_t b, mp_int *c)
} }
/* "b" is smaller than INT_MAX, we can cast safely */ /* "b" is smaller than INT_MAX, we can cast safely */
if (ilog2 < (int)b) { if (ilog2 < b) {
mp_set(c, 1uL); mp_set(c, 1uL);
c->sign = a->sign; c->sign = a->sign;
err = MP_OKAY; err = MP_OKAY;
goto LBL_ERR; goto LBL_ERR;
} }
ilog2 = ilog2 / ((int)b); ilog2 = ilog2 / b;
if (ilog2 == 0) { if (ilog2 == 0) {
mp_set(c, 1uL); mp_set(c, 1uL);
c->sign = a->sign; c->sign = a->sign;
@ -71,7 +74,7 @@ mp_err mp_root_u32(const mp_int *a, uint32_t b, mp_int *c)
/* t2 = t1 - ((t1**b - a) / (b * t1**(b-1))) */ /* t2 = t1 - ((t1**b - a) / (b * t1**(b-1))) */
/* t3 = t1**(b-1) */ /* t3 = t1**(b-1) */
if ((err = mp_expt_u32(&t1, b - 1u, &t3)) != MP_OKAY) goto LBL_ERR; if ((err = mp_expt_n(&t1, b - 1, &t3)) != MP_OKAY) goto LBL_ERR;
/* numerator */ /* numerator */
/* t2 = t1**b */ /* t2 = t1**b */
@ -82,7 +85,7 @@ mp_err mp_root_u32(const mp_int *a, uint32_t b, mp_int *c)
/* denominator */ /* denominator */
/* t3 = t1**(b-1) * b */ /* t3 = t1**(b-1) * b */
if ((err = mp_mul_d(&t3, b, &t3)) != MP_OKAY) goto LBL_ERR; if ((err = mp_mul_d(&t3, (mp_digit)b, &t3)) != MP_OKAY) goto LBL_ERR;
/* t3 = (t1**b - a)/(b * t1**(b-1)) */ /* t3 = (t1**b - a)/(b * t1**(b-1)) */
if ((err = mp_div(&t2, &t3, &t3, NULL)) != MP_OKAY) goto LBL_ERR; if ((err = mp_div(&t2, &t3, &t3, NULL)) != MP_OKAY) goto LBL_ERR;
@ -101,7 +104,8 @@ mp_err mp_root_u32(const mp_int *a, uint32_t b, mp_int *c)
/* result can be off by a few so check */ /* result can be off by a few so check */
/* Loop beneath can overshoot by one if found root is smaller than actual root */ /* Loop beneath can overshoot by one if found root is smaller than actual root */
for (;;) { for (;;) {
if ((err = mp_expt_u32(&t1, b, &t2)) != MP_OKAY) goto LBL_ERR; mp_ord cmp;
if ((err = mp_expt_n(&t1, b, &t2)) != MP_OKAY) goto LBL_ERR;
cmp = mp_cmp(&t2, &a_); cmp = mp_cmp(&t2, &a_);
if (cmp == MP_EQ) { if (cmp == MP_EQ) {
err = MP_OKAY; err = MP_OKAY;
@ -115,7 +119,7 @@ mp_err mp_root_u32(const mp_int *a, uint32_t b, mp_int *c)
} }
/* correct overshoot from above or from recurrence */ /* correct overshoot from above or from recurrence */
for (;;) { for (;;) {
if ((err = mp_expt_u32(&t1, b, &t2)) != MP_OKAY) goto LBL_ERR; if ((err = mp_expt_n(&t1, b, &t2)) != MP_OKAY) goto LBL_ERR;
if (mp_cmp(&t2, &a_) == MP_GT) { if (mp_cmp(&t2, &a_) == MP_GT) {
if ((err = mp_sub_d(&t1, 1uL, &t1)) != MP_OKAY) goto LBL_ERR; if ((err = mp_sub_d(&t1, 1uL, &t1)) != MP_OKAY) goto LBL_ERR;
} else { } else {
@ -129,8 +133,6 @@ mp_err mp_root_u32(const mp_int *a, uint32_t b, mp_int *c)
/* set the sign of the result */ /* set the sign of the result */
c->sign = a->sign; c->sign = a->sign;
err = MP_OKAY;
LBL_ERR: LBL_ERR:
mp_clear_multi(&t1, &t2, &t3, NULL); mp_clear_multi(&t1, &t2, &t3, NULL);
return err; return err;

View File

@ -3,14 +3,13 @@
/* LibTomMath, multiple-precision integer library -- Tom St Denis */ /* LibTomMath, multiple-precision integer library -- Tom St Denis */
/* SPDX-License-Identifier: Unlicense */ /* SPDX-License-Identifier: Unlicense */
mp_err s_mp_log(const mp_int *a, uint32_t base, uint32_t *c) mp_err s_mp_log(const mp_int *a, mp_digit base, int *c)
{ {
mp_err err; mp_err err;
mp_ord cmp; int high, low;
uint32_t high, low, mid;
mp_int bracket_low, bracket_high, bracket_mid, t, bi_base; mp_int bracket_low, bracket_high, bracket_mid, t, bi_base;
cmp = mp_cmp_d(a, base); mp_ord cmp = mp_cmp_d(a, base);
if ((cmp == MP_LT) || (cmp == MP_EQ)) { if ((cmp == MP_LT) || (cmp == MP_EQ)) {
*c = cmp == MP_EQ; *c = cmp == MP_EQ;
return MP_OKAY; return MP_OKAY;
@ -22,9 +21,9 @@ mp_err s_mp_log(const mp_int *a, uint32_t base, uint32_t *c)
return err; return err;
} }
low = 0u; low = 0;
mp_set(&bracket_low, 1uL); mp_set(&bracket_low, 1uL);
high = 1u; high = 1;
mp_set(&bracket_high, base); mp_set(&bracket_high, base);
@ -46,10 +45,10 @@ mp_err s_mp_log(const mp_int *a, uint32_t base, uint32_t *c)
} }
mp_set(&bi_base, base); mp_set(&bi_base, base);
while ((high - low) > 1u) { while ((high - low) > 1) {
mid = (high + low) >> 1; int mid = (high + low) >> 1;
if ((err = mp_expt_u32(&bi_base, (uint32_t)(mid - low), &t)) != MP_OKAY) { if ((err = mp_expt_n(&bi_base, mid - low, &t)) != MP_OKAY) {
goto LBL_END; goto LBL_END;
} }
if ((err = mp_mul(&bracket_low, &t, &bracket_mid)) != MP_OKAY) { if ((err = mp_mul(&bracket_low, &t, &bracket_mid)) != MP_OKAY) {

12
s_mp_log_2expt.c Normal file
View File

@ -0,0 +1,12 @@
#include "tommath_private.h"
#ifdef S_MP_LOG_2EXPT_C
/* LibTomMath, multiple-precision integer library -- Tom St Denis */
/* SPDX-License-Identifier: Unlicense */
int s_mp_log_2expt(const mp_int *a, mp_digit base)
{
int y;
for (y = 0; (base & 1) == 0; y++, base >>= 1) {}
return (mp_count_bits(a) - 1) / y;
}
#endif

View File

@ -17,21 +17,18 @@ static mp_word s_pow(mp_word base, mp_word exponent)
return result; return result;
} }
mp_digit s_mp_log_d(mp_digit base, mp_digit n) int s_mp_log_d(mp_digit base, mp_digit n)
{ {
mp_word bracket_low = 1u, bracket_mid, bracket_high, N; mp_word bracket_low = 1uLL, bracket_high = base, N = n;
mp_digit ret, high = 1uL, low = 0uL, mid; int ret, high = 1, low = 0;
if (n < base) { if (n < base) {
return 0uL; return 0;
} }
if (n == base) { if (n == base) {
return 1uL; return 1;
} }
bracket_high = (mp_word) base ;
N = (mp_word) n;
while (bracket_high < N) { while (bracket_high < N) {
low = high; low = high;
bracket_low = bracket_high; bracket_low = bracket_high;
@ -40,8 +37,8 @@ mp_digit s_mp_log_d(mp_digit base, mp_digit n)
} }
while (((mp_digit)(high - low)) > 1uL) { while (((mp_digit)(high - low)) > 1uL) {
mid = (low + high) >> 1; int mid = (low + high) >> 1;
bracket_mid = bracket_low * s_pow(base, (mp_word)(mid - low)); mp_word bracket_mid = bracket_low * s_pow(base, (mp_word)(mid - low));
if (N < bracket_mid) { if (N < bracket_mid) {
high = mid ; high = mid ;
@ -52,7 +49,7 @@ mp_digit s_mp_log_d(mp_digit base, mp_digit n)
bracket_low = bracket_mid ; bracket_low = bracket_mid ;
} }
if (N == bracket_mid) { if (N == bracket_mid) {
return (mp_digit) mid; return mid;
} }
} }

View File

@ -1,12 +0,0 @@
#include "tommath_private.h"
#ifdef S_MP_LOG_POW2_C
/* LibTomMath, multiple-precision integer library -- Tom St Denis */
/* SPDX-License-Identifier: Unlicense */
uint32_t s_mp_log_pow2(const mp_int *a, uint32_t base)
{
int y;
for (y = 0; (base & 1u) == 0u; y++, base >>= 1) {}
return (uint32_t)((mp_count_bits(a) - 1) / y);
}
#endif

View File

@ -423,11 +423,17 @@ mp_err mp_exteuclid(const mp_int *a, const mp_int *b, mp_int *U1, mp_int *U2, mp
/* c = [a, b] or (a*b)/(a, b) */ /* c = [a, b] or (a*b)/(a, b) */
mp_err mp_lcm(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR; mp_err mp_lcm(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR;
/* Integer logarithm to integer base */
mp_err mp_log_n(const mp_int *a, int base, int *c) MP_WUR;
/* c = a**b */
mp_err mp_expt_n(const mp_int *a, int b, mp_int *c) MP_WUR;
/* finds one of the b'th root of a, such that |c|**b <= |a| /* finds one of the b'th root of a, such that |c|**b <= |a|
* *
* returns error if a < 0 and b is even * returns error if a < 0 and b is even
*/ */
mp_err mp_root_u32(const mp_int *a, uint32_t b, mp_int *c) MP_WUR; mp_err mp_root_n(const mp_int *a, int b, mp_int *c) MP_WUR;
/* special sqrt algo */ /* special sqrt algo */
mp_err mp_sqrt(const mp_int *arg, mp_int *ret) MP_WUR; mp_err mp_sqrt(const mp_int *arg, mp_int *ret) MP_WUR;
@ -557,12 +563,6 @@ mp_err mp_prime_next_prime(mp_int *a, int t, bool bbs_style) MP_WUR;
*/ */
mp_err mp_prime_rand(mp_int *a, int t, int size, int flags) MP_WUR; mp_err mp_prime_rand(mp_int *a, int t, int size, int flags) MP_WUR;
/* Integer logarithm to integer base */
mp_err mp_log_u32(const mp_int *a, uint32_t base, uint32_t *c) MP_WUR;
/* c = a**b */
mp_err mp_expt_u32(const mp_int *a, uint32_t b, mp_int *c) MP_WUR;
/* ---> radix conversion <--- */ /* ---> radix conversion <--- */
int mp_count_bits(const mp_int *a) MP_WUR; int mp_count_bits(const mp_int *a) MP_WUR;

View File

@ -161,7 +161,8 @@ extern MP_PRIVATE mp_err(*s_mp_rand_source)(void *out, size_t size);
/* lowlevel functions, do not call! */ /* lowlevel functions, do not call! */
MP_PRIVATE bool s_mp_get_bit(const mp_int *a, int b) MP_WUR; MP_PRIVATE bool s_mp_get_bit(const mp_int *a, int b) MP_WUR;
MP_PRIVATE mp_digit s_mp_log_d(mp_digit base, mp_digit n) MP_WUR; MP_PRIVATE int s_mp_log_2expt(const mp_int *a, mp_digit base) MP_WUR;
MP_PRIVATE int s_mp_log_d(mp_digit base, mp_digit n) MP_WUR;
MP_PRIVATE mp_err s_mp_add(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR; MP_PRIVATE mp_err s_mp_add(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR;
MP_PRIVATE mp_err s_mp_div_3(const mp_int *a, mp_int *c, mp_digit *d) MP_WUR; MP_PRIVATE mp_err s_mp_div_3(const mp_int *a, mp_int *c, mp_digit *d) MP_WUR;
MP_PRIVATE mp_err s_mp_div_recursive(const mp_int *a, const mp_int *b, mp_int *q, mp_int *r) MP_WUR; MP_PRIVATE mp_err s_mp_div_recursive(const mp_int *a, const mp_int *b, mp_int *q, mp_int *r) MP_WUR;
@ -171,7 +172,7 @@ MP_PRIVATE mp_err s_mp_exptmod(const mp_int *G, const mp_int *X, const mp_int *P
MP_PRIVATE mp_err s_mp_exptmod_fast(const mp_int *G, const mp_int *X, const mp_int *P, mp_int *Y, int redmode) MP_WUR; MP_PRIVATE mp_err s_mp_exptmod_fast(const mp_int *G, const mp_int *X, const mp_int *P, mp_int *Y, int redmode) MP_WUR;
MP_PRIVATE mp_err s_mp_invmod(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR; MP_PRIVATE mp_err s_mp_invmod(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR;
MP_PRIVATE mp_err s_mp_invmod_odd(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR; MP_PRIVATE mp_err s_mp_invmod_odd(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR;
MP_PRIVATE mp_err s_mp_log(const mp_int *a, uint32_t base, uint32_t *c) MP_WUR; MP_PRIVATE mp_err s_mp_log(const mp_int *a, mp_digit base, int *c) MP_WUR;
MP_PRIVATE mp_err s_mp_montgomery_reduce_comba(mp_int *x, const mp_int *n, mp_digit rho) MP_WUR; MP_PRIVATE mp_err s_mp_montgomery_reduce_comba(mp_int *x, const mp_int *n, mp_digit rho) MP_WUR;
MP_PRIVATE mp_err s_mp_mul(const mp_int *a, const mp_int *b, mp_int *c, int digs) MP_WUR; MP_PRIVATE mp_err s_mp_mul(const mp_int *a, const mp_int *b, mp_int *c, int digs) MP_WUR;
MP_PRIVATE mp_err s_mp_mul_balance(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR; MP_PRIVATE mp_err s_mp_mul_balance(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR;
@ -187,7 +188,6 @@ MP_PRIVATE mp_err s_mp_sqr_comba(const mp_int *a, mp_int *b) MP_WUR;
MP_PRIVATE mp_err s_mp_sqr_karatsuba(const mp_int *a, mp_int *b) MP_WUR; MP_PRIVATE mp_err s_mp_sqr_karatsuba(const mp_int *a, mp_int *b) MP_WUR;
MP_PRIVATE mp_err s_mp_sqr_toom(const mp_int *a, mp_int *b) MP_WUR; MP_PRIVATE mp_err s_mp_sqr_toom(const mp_int *a, mp_int *b) MP_WUR;
MP_PRIVATE mp_err s_mp_sub(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR; MP_PRIVATE mp_err s_mp_sub(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR;
MP_PRIVATE uint32_t s_mp_log_pow2(const mp_int *a, uint32_t base) MP_WUR;
MP_PRIVATE void s_mp_copy_digs(mp_digit *d, const mp_digit *s, int digits); MP_PRIVATE void s_mp_copy_digs(mp_digit *d, const mp_digit *s, int digits);
MP_PRIVATE void s_mp_zero_buf(void *mem, size_t size); MP_PRIVATE void s_mp_zero_buf(void *mem, size_t size);
MP_PRIVATE void s_mp_zero_digs(mp_digit *d, int digits); MP_PRIVATE void s_mp_zero_digs(mp_digit *d, int digits);

View File

@ -28,12 +28,12 @@
# define MP_NEG_C # define MP_NEG_C
# define MP_PRIME_FROBENIUS_UNDERWOOD_C # define MP_PRIME_FROBENIUS_UNDERWOOD_C
# define MP_RADIX_SIZE_C # define MP_RADIX_SIZE_C
# define MP_LOG_U32_C # define MP_LOG_N_C
# define MP_RAND_C # define MP_RAND_C
# define MP_REDUCE_C # define MP_REDUCE_C
# define MP_REDUCE_2K_L_C # define MP_REDUCE_2K_L_C
# define MP_FROM_SBIN_C # define MP_FROM_SBIN_C
# define MP_ROOT_U32_C # define MP_ROOT_N_C
# define MP_SET_L_C # define MP_SET_L_C
# define MP_SET_UL_C # define MP_SET_UL_C
# define MP_SBIN_SIZE_C # define MP_SBIN_SIZE_C