diff --git a/library/ssl_tls.c b/library/ssl_tls.c index c4619106f..0333cd5e5 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -9961,10 +9961,11 @@ static unsigned char ssl_serialized_session_header[] = { * verify_result is put before peer_cert so that all mandatory fields come * together in one block. */ -int mbedtls_ssl_session_save( const mbedtls_ssl_session *session, - unsigned char *buf, - size_t buf_len, - size_t *olen ) +static int ssl_session_save( const mbedtls_ssl_session *session, + unsigned char omit_header, + unsigned char *buf, + size_t buf_len, + size_t *olen ) { unsigned char *p = buf; size_t used = 0; @@ -9978,17 +9979,20 @@ int mbedtls_ssl_session_save( const mbedtls_ssl_session *session, #endif /* MBEDTLS_X509_CRT_PARSE_C */ - /* - * Add version identifier - */ - - used += sizeof( ssl_serialized_session_header ); - - if( used <= buf_len ) + if( !omit_header ) { - memcpy( p, ssl_serialized_session_header, - sizeof( ssl_serialized_session_header ) ); - p += sizeof( ssl_serialized_session_header ); + /* + * Add version identifier + */ + + used += sizeof( ssl_serialized_session_header ); + + if( used <= buf_len ) + { + memcpy( p, ssl_serialized_session_header, + sizeof( ssl_serialized_session_header ) ); + p += sizeof( ssl_serialized_session_header ); + } } /* @@ -10149,13 +10153,25 @@ int mbedtls_ssl_session_save( const mbedtls_ssl_session *session, return( 0 ); } +/* + * Public wrapper for ssl_session_save() + */ +int mbedtls_ssl_session_save( const mbedtls_ssl_session *session, + unsigned char *buf, + size_t buf_len, + size_t *olen ) +{ + return( ssl_session_save( session, 0, buf, buf_len, olen ) ); +} + /* * Deserialize session, see mbedtls_ssl_session_save() for format. * * This internal version is wrapped by a public function that cleans up in - * case of error. + * case of error, and has an extra option omit_header. */ static int ssl_session_load( mbedtls_ssl_session *session, + unsigned char omit_header, const unsigned char *buf, size_t len ) { @@ -10170,19 +10186,22 @@ static int ssl_session_load( mbedtls_ssl_session *session, #endif /* MBEDTLS_SSL_KEEP_PEER_CERTIFICATE */ #endif /* MBEDTLS_X509_CRT_PARSE_C */ - /* - * Check version identifier - */ - - if( (size_t)( end - p ) < sizeof( ssl_serialized_session_header ) ) - return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - - if( memcmp( p, ssl_serialized_session_header, - sizeof( ssl_serialized_session_header ) ) != 0 ) + if( !omit_header ) { - return( MBEDTLS_ERR_SSL_VERSION_MISMATCH ); + /* + * Check version identifier + */ + + if( (size_t)( end - p ) < sizeof( ssl_serialized_session_header ) ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + + if( memcmp( p, ssl_serialized_session_header, + sizeof( ssl_serialized_session_header ) ) != 0 ) + { + return( MBEDTLS_ERR_SSL_VERSION_MISMATCH ); + } + p += sizeof( ssl_serialized_session_header ); } - p += sizeof( ssl_serialized_session_header ); /* * Time @@ -10381,7 +10400,7 @@ int mbedtls_ssl_session_load( mbedtls_ssl_session *session, const unsigned char *buf, size_t len ) { - int ret = ssl_session_load( session, buf, len ); + int ret = ssl_session_load( session, 0, buf, len ); if( ret != 0 ) mbedtls_ssl_session_free( session ); @@ -11424,7 +11443,7 @@ int mbedtls_ssl_context_save( mbedtls_ssl_context *ssl, /* * Session (length + data) */ - ret = mbedtls_ssl_session_save( ssl->session, NULL, 0, &session_len ); + ret = ssl_session_save( ssl->session, 1, NULL, 0, &session_len ); if( ret != MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL ) return( ret ); @@ -11436,8 +11455,8 @@ int mbedtls_ssl_context_save( mbedtls_ssl_context *ssl, *p++ = (unsigned char)( ( session_len >> 8 ) & 0xFF ); *p++ = (unsigned char)( ( session_len ) & 0xFF ); - ret = mbedtls_ssl_session_save( ssl->session, - p, session_len, &session_len ); + ret = ssl_session_save( ssl->session, 1, + p, session_len, &session_len ); if( ret != 0 ) return( ret ); @@ -11661,9 +11680,12 @@ static int ssl_context_load( mbedtls_ssl_context *ssl, if( (size_t)( end - p ) < session_len ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - ret = mbedtls_ssl_session_load( ssl->session, p, session_len ); + ret = ssl_session_load( ssl->session, 1, p, session_len ); if( ret != 0 ) + { + mbedtls_ssl_session_free( ssl->session ); return( ret ); + } p += session_len;