diff --git a/include/mbedtls/dhm.h b/include/mbedtls/dhm.h index e8c8a82f5..850813e37 100644 --- a/include/mbedtls/dhm.h +++ b/include/mbedtls/dhm.h @@ -279,10 +279,10 @@ int mbedtls_dhm_make_public( mbedtls_dhm_context *ctx, int x_size, * \param output_size The size of the destination buffer. This must be at * least the size of \c ctx->len (the size of \c P). * \param olen On exit, holds the actual number of Bytes written. - * \param f_rng The RNG function, for blinding purposes. This may - * b \c NULL if blinding isn't needed. - * \param p_rng The RNG context. This may be \c NULL if \p f_rng - * doesn't need a context argument. + * \param f_rng The RNG function. Must not be \c NULL. Used for + * blinding. + * \param p_rng The RNG context to be passed to \p f_rng. This may be + * \c NULL if \p f_rng doesn't need a context parameter. * * \return \c 0 on success. * \return An \c MBEDTLS_ERR_DHM_XXX error code on failure. diff --git a/library/dhm.c b/library/dhm.c index e88f3a2c7..29ce75598 100644 --- a/library/dhm.c +++ b/library/dhm.c @@ -444,6 +444,9 @@ int mbedtls_dhm_calc_secret( mbedtls_dhm_context *ctx, DHM_VALIDATE_RET( output != NULL ); DHM_VALIDATE_RET( olen != NULL ); + if( f_rng == NULL ) + return( MBEDTLS_ERR_DHM_BAD_INPUT_DATA ); + if( output_size < mbedtls_dhm_get_len( ctx ) ) return( MBEDTLS_ERR_DHM_BAD_INPUT_DATA ); @@ -453,25 +456,17 @@ int mbedtls_dhm_calc_secret( mbedtls_dhm_context *ctx, mbedtls_mpi_init( &GYb ); /* Blind peer's value */ - if( f_rng != NULL ) - { - MBEDTLS_MPI_CHK( dhm_update_blinding( ctx, f_rng, p_rng ) ); - MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &GYb, &ctx->GY, &ctx->Vi ) ); - MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &GYb, &GYb, &ctx->P ) ); - } - else - MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &GYb, &ctx->GY ) ); + MBEDTLS_MPI_CHK( dhm_update_blinding( ctx, f_rng, p_rng ) ); + MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &GYb, &ctx->GY, &ctx->Vi ) ); + MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &GYb, &GYb, &ctx->P ) ); /* Do modular exponentiation */ MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &ctx->K, &GYb, &ctx->X, &ctx->P, &ctx->RP ) ); /* Unblind secret value */ - if( f_rng != NULL ) - { - MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->K, &ctx->K, &ctx->Vf ) ); - MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &ctx->K, &ctx->K, &ctx->P ) ); - } + MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->K, &ctx->K, &ctx->Vf ) ); + MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &ctx->K, &ctx->K, &ctx->P ) ); /* Output the secret without any leading zero byte. This is mandatory * for TLS per RFC 5246 ยง8.1.2. */ diff --git a/tests/suites/test_suite_dhm.function b/tests/suites/test_suite_dhm.function index 62e634a7f..5286bc7b8 100644 --- a/tests/suites/test_suite_dhm.function +++ b/tests/suites/test_suite_dhm.function @@ -150,7 +150,10 @@ void dhm_do_dhm( int radix_P, char *input_P, int x_size, &sec_srv_len, &mbedtls_test_rnd_pseudo_rand, &rnd_info ) == 0 ); - TEST_ASSERT( mbedtls_dhm_calc_secret( &ctx_cli, sec_cli, sizeof( sec_cli ), &sec_cli_len, NULL, NULL ) == 0 ); + TEST_ASSERT( mbedtls_dhm_calc_secret( &ctx_cli, sec_cli, sizeof( sec_cli ), + &sec_cli_len, + &mbedtls_test_rnd_pseudo_rand, + &rnd_info ) == 0 ); TEST_ASSERT( sec_srv_len == sec_cli_len ); TEST_ASSERT( sec_srv_len != 0 ); @@ -206,7 +209,10 @@ void dhm_do_dhm( int radix_P, char *input_P, int x_size, &sec_srv_len, &mbedtls_test_rnd_pseudo_rand, &rnd_info ) == 0 ); - TEST_ASSERT( mbedtls_dhm_calc_secret( &ctx_cli, sec_cli, sizeof( sec_cli ), &sec_cli_len, NULL, NULL ) == 0 ); + TEST_ASSERT( mbedtls_dhm_calc_secret( &ctx_cli, sec_cli, sizeof( sec_cli ), + &sec_cli_len, + &mbedtls_test_rnd_pseudo_rand, + &rnd_info ) == 0 ); TEST_ASSERT( sec_srv_len == sec_cli_len ); TEST_ASSERT( sec_srv_len != 0 );