diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 5ad244d33..0b325616f 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -66,6 +66,16 @@ static void polarssl_zeroize( void *v, size_t n ) { volatile unsigned char *p = v; while( n-- ) *p++ = 0; } +/* Length of the "epoch" field in the record header */ +static inline size_t ssl_ep_len( const ssl_context *ssl ) +{ +#if defined(POLARSSL_SSL_PROTO_DTLS) + if( ssl->transport == SSL_TRANSPORT_DATAGRAM ) + return( 2 ); +#endif + return( 0 ); +} + #if defined(POLARSSL_SSL_MAX_FRAGMENT_LENGTH) /* * Convert max_fragment_length codes to length. @@ -1690,19 +1700,17 @@ static int ssl_decrypt_buf( ssl_context *ssl ) else ssl->nb_zero = 0; - /* For DTLS we don't maintain our own incoming counter (for now) */ - if( ssl->transport == SSL_TRANSPORT_STREAM ) - { - for( i = 8; i > 0; i-- ) - if( ++ssl->in_ctr[i - 1] != 0 ) - break; + /* Input counter not used with DTLS right now, + * but it doesn't hurt to have this part ready */ + for( i = 8; i > ssl_ep_len( ssl ); i-- ) + if( ++ssl->in_ctr[i - 1] != 0 ) + break; - /* The loop goes to its end iff the counter is wrapping */ - if( i == 0 ) - { - SSL_DEBUG_MSG( 1, ( "incoming message counter would wrap" ) ); - return( POLARSSL_ERR_SSL_COUNTER_WRAPPING ); - } + /* The loop goes to its end iff the counter is wrapping */ + if( i == ssl_ep_len( ssl ) ) + { + SSL_DEBUG_MSG( 1, ( "incoming message counter would wrap" ) ); + return( POLARSSL_ERR_SSL_COUNTER_WRAPPING ); } SSL_DEBUG_MSG( 2, ( "<= decrypt buf" ) ); @@ -1851,8 +1859,8 @@ int ssl_fetch_input( ssl_context *ssl, size_t nb_want ) */ int ssl_flush_output( ssl_context *ssl ) { - int ret, i; - unsigned char *buf; + int ret; + unsigned char *buf, i; SSL_DEBUG_MSG( 2, ( "=> flush output" ) ); @@ -1880,13 +1888,12 @@ int ssl_flush_output( ssl_context *ssl ) ssl->out_left -= ret; } - // TODO: adapt for DTLS (start from i = 6) - for( i = 8; i > 0; i-- ) + for( i = 8; i > ssl_ep_len( ssl ); i-- ) if( ++ssl->out_ctr[i - 1] != 0 ) break; /* The loop goes to its end iff the counter is wrapping */ - if( i == 0 ) + if( i == ssl_ep_len( ssl ) ) { SSL_DEBUG_MSG( 1, ( "outgoing message counter would wrap" ) ); return( POLARSSL_ERR_SSL_COUNTER_WRAPPING ); @@ -3122,7 +3129,7 @@ void ssl_handshake_wrapup( ssl_context *ssl ) int ssl_write_finished( ssl_context *ssl ) { - int ret, hash_len; + int ret, hash_len, i; SSL_DEBUG_MSG( 2, ( "=> write finished" ) ); @@ -3170,8 +3177,12 @@ int ssl_write_finished( ssl_context *ssl ) SSL_DEBUG_MSG( 3, ( "switching to new transform spec for outbound data" ) ); ssl->transform_out = ssl->transform_negotiate; ssl->session_out = ssl->session_negotiate; - // TODO: DTLS epoch? - memset( ssl->out_ctr, 0, 8 ); + + memset( ssl->out_ctr + ssl_ep_len( ssl ), 0, 8 - ssl_ep_len( ssl ) ); + for( i = ssl_ep_len( ssl ); i > 0; i-- ) + if( ++ssl->out_ctr[i - 1] != 0 ) + break; + // TODO: abort on epoch wrap! #if defined(POLARSSL_SSL_HW_RECORD_ACCEL) if( ssl_hw_record_activate != NULL ) @@ -3198,7 +3209,7 @@ int ssl_write_finished( ssl_context *ssl ) int ssl_parse_finished( ssl_context *ssl ) { int ret; - unsigned int hash_len; + unsigned int hash_len, i; unsigned char buf[36]; SSL_DEBUG_MSG( 2, ( "=> parse finished" ) ); @@ -3212,8 +3223,14 @@ int ssl_parse_finished( ssl_context *ssl ) SSL_DEBUG_MSG( 3, ( "switching to new transform spec for inbound data" ) ); ssl->transform_in = ssl->transform_negotiate; ssl->session_in = ssl->session_negotiate; - // TODO: DTLS epoch? - memset( ssl->in_ctr, 0, 8 ); + + /* Input counter/epoch not used with DTLS right now, + * but it doesn't hurt to have this part ready */ + memset( ssl->in_ctr + ssl_ep_len( ssl ), 0, 8 - ssl_ep_len( ssl ) ); + for( i = ssl_ep_len( ssl ); i > 0; i-- ) + if( ++ssl->in_ctr[i - 1] != 0 ) + break; + // TODO: abort on epoch wrap! /* * Set the in_msg pointer to the correct location based on IV length