diff --git a/include/polarssl/dhm.h b/include/polarssl/dhm.h index 322a78860..064472f35 100644 --- a/include/polarssl/dhm.h +++ b/include/polarssl/dhm.h @@ -169,6 +169,13 @@ typedef struct } dhm_context; +/** + * \brief Initialize DHM context + * + * \param ctx DHM context to be initialized + */ +void dhm_init( dhm_context *ctx ); + /** * \brief Parse the ServerKeyExchange parameters * @@ -256,7 +263,9 @@ int dhm_calc_secret( dhm_context *ctx, void *p_rng ); /** - * \brief Free the components of a DHM key + * \brief Free and clear the components of a DHM key + * + * \param ctx DHM context to free and clear */ void dhm_free( dhm_context *ctx ); diff --git a/library/dhm.c b/library/dhm.c index 5a87e3ca7..1b1d6d65c 100644 --- a/library/dhm.c +++ b/library/dhm.c @@ -116,6 +116,11 @@ cleanup: return( ret ); } +void dhm_init( dhm_context *ctx ) +{ + memset( ctx, 0, sizeof( dhm_context ) ); +} + /* * Parse the ServerKeyExchange parameters */ @@ -125,8 +130,6 @@ int dhm_read_params( dhm_context *ctx, { int ret; - dhm_free( ctx ); - if( ( ret = dhm_read_bignum( &ctx->P, p, end ) ) != 0 || ( ret = dhm_read_bignum( &ctx->G, p, end ) ) != 0 || ( ret = dhm_read_bignum( &ctx->GY, p, end ) ) != 0 ) @@ -417,7 +420,6 @@ int dhm_parse_dhm( dhm_context *dhm, const unsigned char *dhmin, pem_context pem; pem_init( &pem ); - memset( dhm, 0, sizeof( dhm_context ) ); ret = pem_read_buffer( &pem, "-----BEGIN DH PARAMETERS-----", @@ -561,6 +563,8 @@ int dhm_self_test( int verbose ) int ret; dhm_context dhm; + dhm_init( &dhm ); + if( verbose != 0 ) polarssl_printf( " DHM parameter load: " ); @@ -570,15 +574,16 @@ int dhm_self_test( int verbose ) if( verbose != 0 ) polarssl_printf( "failed\n" ); - return( ret ); + goto exit; } if( verbose != 0 ) polarssl_printf( "passed\n\n" ); +exit: dhm_free( &dhm ); - return( 0 ); + return( ret ); #else if( verbose != 0 ) polarssl_printf( " DHM parameter load: skipped\n" ); diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 963f02b71..a6761adce 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -3334,6 +3334,9 @@ static int ssl_handshake_init( ssl_context *ssl ) ssl->handshake->update_checksum = ssl_update_checksum_start; ssl->handshake->sig_alg = SSL_HASH_SHA1; +#if defined(POLARSSL_DHM_C) + dhm_init( &ssl->handshake->dhm_ctx ); +#endif #if defined(POLARSSL_ECDH_C) ecdh_init( &ssl->handshake->ecdh_ctx ); #endif