diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h index 1a7722c9d..63a75286e 100644 --- a/include/polarssl/ssl.h +++ b/include/polarssl/ssl.h @@ -644,6 +644,12 @@ struct _ssl_handshake_params unsigned char retransmit_state; /*!< Retransmission state */ ssl_flight_item *flight; /*!< Current outgoing flight */ ssl_flight_item *cur_msg; /*!< Current message in flight */ + unsigned int in_flight_start_seq; /*!< Minimum message sequence in the + flight being received */ + ssl_transform *alt_transform_out; /*!< Alternative transform for + resending messages */ + unsigned char alt_out_ctr[8]; /*!< Alternative record epoch/counter + for resending messages */ #endif /* @@ -719,7 +725,8 @@ struct _ssl_key_cert struct _ssl_flight_item { unsigned char *p; /*!< message, including handshake headers */ - size_t len; /*!< length of hs_msg */ + size_t len; /*!< length of p */ + unsigned char type; /*!< type of the message: handshake or CCS */ ssl_flight_item *next; /*!< next handshake message(s) */ }; #endif /* POLARSSL_SSL_PROTO_DTLS */ @@ -2031,6 +2038,11 @@ static inline size_t ssl_hs_hdr_len( const ssl_context *ssl ) return( 4 ); } +#if defined(POLARSSL_SSL_PROTO_DTLS) +void ssl_recv_flight_completed( ssl_context *ssl ); +int ssl_resend( ssl_context *ssl ); +#endif + /* constant-time buffer comparison */ static inline int safer_memcmp( const void *a, const void *b, size_t n ) { diff --git a/library/ssl_cli.c b/library/ssl_cli.c index ce14f5866..b5dae23a6 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -1011,6 +1011,8 @@ static int ssl_parse_hello_verify_request( ssl_context *ssl ) ssl->state = SSL_CLIENT_HELLO; ssl_reset_checksum( ssl ); + ssl_recv_flight_completed( ssl ); + SSL_DEBUG_MSG( 2, ( "<= parse hello verify request" ) ); return( 0 ); @@ -2229,6 +2231,11 @@ static int ssl_parse_server_hello_done( ssl_context *ssl ) ssl->state++; +#if defined(POLARSSL_SSL_PROTO_DTLS) + if( ssl->transport == SSL_TRANSPORT_DATAGRAM ) + ssl_recv_flight_completed( ssl ); +#endif + SSL_DEBUG_MSG( 2, ( "<= parse server hello done" ) ); return( 0 ); @@ -2734,6 +2741,16 @@ int ssl_handshake_client_step( ssl_context *ssl ) if( ( ret = ssl_flush_output( ssl ) ) != 0 ) return( ret ); +#if defined(POLARSSL_SSL_PROTO_DTLS) + if( ssl->transport == SSL_TRANSPORT_DATAGRAM && + ssl->handshake != NULL && + ssl->handshake->retransmit_state == SSL_RETRANS_SENDING ) + { + if( ( ret = ssl_resend( ssl ) ) != 0 ) + return( ret ); + } +#endif + switch( ssl->state ) { case SSL_HELLO_REQUEST: diff --git a/library/ssl_srv.c b/library/ssl_srv.c index c839ea7e5..a0bb6538b 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -1805,6 +1805,11 @@ have_ciphersuite: ssl->state++; +#if defined(POLARSSL_SSL_PROTO_DTLS) + if( ssl->transport == SSL_TRANSPORT_DATAGRAM ) + ssl_recv_flight_completed( ssl ); +#endif + SSL_DEBUG_MSG( 2, ( "<= parse client hello" ) ); return( 0 ); @@ -3485,6 +3490,16 @@ int ssl_handshake_server_step( ssl_context *ssl ) if( ( ret = ssl_flush_output( ssl ) ) != 0 ) return( ret ); +#if defined(POLARSSL_SSL_PROTO_DTLS) + if( ssl->transport == SSL_TRANSPORT_DATAGRAM && + ssl->handshake != NULL && + ssl->handshake->retransmit_state == SSL_RETRANS_SENDING ) + { + if( ( ret = ssl_resend( ssl ) ) != 0 ) + return( ret ); + } +#endif + switch( ssl->state ) { case SSL_HELLO_REQUEST: diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 423bc0bd1..9160baa72 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -2013,6 +2013,9 @@ int ssl_flush_output( ssl_context *ssl ) return( 0 ); } +/* + * Functions to handle the DTLS retransmission state machine + */ #if defined(POLARSSL_SSL_PROTO_DTLS) /* * Append current handshake message to current outgoing flight @@ -2038,14 +2041,12 @@ static int ssl_flight_append( ssl_context *ssl ) /* Copy current handshake message with headers */ memcpy( msg->p, ssl->out_msg, ssl->out_msglen ); msg->len = ssl->out_msglen; + msg->type = ssl->out_msgtype; msg->next = NULL; /* Append to the current flight */ if( ssl->handshake->flight == NULL ) - { - ssl->handshake->flight = msg; - ssl->handshake->cur_msg = msg; - } + ssl->handshake->flight = msg; else { ssl_flight_item *cur = ssl->handshake->flight; @@ -2077,30 +2078,72 @@ static void ssl_flight_free( ssl_flight_item *flight ) } /* - * Send current flight of messages. + * Swap transform_out and out_ctr with the alternative ones + */ +static void ssl_swap_epochs( ssl_context *ssl ) +{ + ssl_transform *tmp_transform; + unsigned char tmp_out_ctr[8]; + + if( ssl->transform_out == ssl->handshake->alt_transform_out ) + { + SSL_DEBUG_MSG( 3, ( "skip swap epochs" ) ); + return; + } + + SSL_DEBUG_MSG( 3, ( "swap epochs" ) ); + + tmp_transform = ssl->transform_out; + ssl->transform_out = ssl->handshake->alt_transform_out; + ssl->handshake->alt_transform_out = tmp_transform; + + memcpy( tmp_out_ctr, ssl->out_ctr, 8 ); + memcpy( ssl->out_ctr, ssl->handshake->alt_out_ctr, 8 ); + memcpy( ssl->handshake->alt_out_ctr, tmp_out_ctr, 8 ); +} + +/* + * Retransmit the current flight of messages. * * Need to remember the current message in case flush_output returns * WANT_WRITE, causing us to exit this function and come back later. + * This function must be called until state is no longer SENDING. */ -static int ssl_send_current_flight( ssl_context *ssl ) +int ssl_resend( ssl_context *ssl ) { - ssl->handshake->retransmit_state = SSL_RETRANS_SENDING; + SSL_DEBUG_MSG( 2, ( "=> ssl_resend" ) ); - SSL_DEBUG_MSG( 2, ( "=> ssl_send_current_flight" ) ); + if( ssl->handshake->retransmit_state != SSL_RETRANS_SENDING ) + { + SSL_DEBUG_MSG( 2, ( "initialise resending" ) ); + + ssl->handshake->cur_msg = ssl->handshake->flight; + ssl_swap_epochs( ssl ); + + ssl->handshake->retransmit_state = SSL_RETRANS_SENDING; + } while( ssl->handshake->cur_msg != NULL ) { int ret; ssl_flight_item *cur = ssl->handshake->cur_msg; + memcpy( ssl->out_msg, cur->p, cur->len ); ssl->out_msglen = cur->len; - memcpy( ssl->out_msg, cur->p, ssl->out_msglen ); - ssl->out_msgtype = SSL_MSG_HANDSHAKE; + ssl->out_msgtype = cur->type; ssl->handshake->cur_msg = cur->next; SSL_DEBUG_BUF( 3, "resent handshake message header", ssl->out_msg, 12 ); + /* Swap epochs before sending Finished: we can't do it right after + * sending ChangeCipherSpec, in case write returns WANT_READ */ + if( ssl->out_msgtype == SSL_MSG_HANDSHAKE && + ssl->out_msg[0] == SSL_HS_FINISHED ) + { + ssl_swap_epochs( ssl ); + } + if( ( ret = ssl_write_record( ssl ) ) != 0 ) { SSL_DEBUG_RET( 1, "ssl_write_record", ret ); @@ -2110,10 +2153,32 @@ static int ssl_send_current_flight( ssl_context *ssl ) ssl->handshake->retransmit_state = SSL_RETRANS_WAITING; - SSL_DEBUG_MSG( 2, ( "<= ssl_send_current_flight" ) ); + SSL_DEBUG_MSG( 2, ( "<= ssl_resend" ) ); return( 0 ); } + +/* + * To be called when the last message of an incoming flight is received. + */ +void ssl_recv_flight_completed( ssl_context *ssl ) +{ + /* We won't need to resend that one any more */ + ssl_flight_free( ssl->handshake->flight ); + ssl->handshake->flight = NULL; + ssl->handshake->cur_msg = NULL; + + /* The next incoming flight will start with this msg_seq */ + ssl->handshake->in_flight_start_seq = ssl->handshake->in_msg_seq; + + if( ssl->in_msgtype == SSL_MSG_HANDSHAKE && + ssl->in_msg[0] == SSL_HS_FINISHED ) + { + ssl->handshake->retransmit_state = SSL_RETRANS_FINISHED; + } + else + ssl->handshake->retransmit_state = SSL_RETRANS_PREPARING; +} #endif /* POLARSSL_SSL_PROTO_DTLS */ /* @@ -3818,14 +3883,16 @@ int ssl_write_finished( ssl_context *ssl ) * data. */ SSL_DEBUG_MSG( 3, ( "switching to new transform spec for outbound data" ) ); - ssl->transform_out = ssl->transform_negotiate; - ssl->session_out = ssl->session_negotiate; #if defined(POLARSSL_SSL_PROTO_DTLS) if( ssl->transport == SSL_TRANSPORT_DATAGRAM ) { unsigned char i; + /* Remember current epoch settings for resending */ + ssl->handshake->alt_transform_out = ssl->transform_out; + memcpy( ssl->handshake->alt_out_ctr, ssl->out_ctr, 8 ); + /* Set sequence_number to zero */ memset( ssl->out_ctr + 2, 0, 6 ); @@ -3845,6 +3912,9 @@ int ssl_write_finished( ssl_context *ssl ) #endif /* POLARSSL_SSL_PROTO_DTLS */ memset( ssl->out_ctr, 0, 8 ); + ssl->transform_out = ssl->transform_negotiate; + ssl->session_out = ssl->session_negotiate; + #if defined(POLARSSL_SSL_HW_RECORD_ACCEL) if( ssl_hw_record_activate != NULL ) { @@ -3985,6 +4055,11 @@ int ssl_parse_finished( ssl_context *ssl ) else ssl->state++; +#if defined(POLARSSL_SSL_PROTO_DTLS) + if( ssl->transport == SSL_TRANSPORT_DATAGRAM ) + ssl_recv_flight_completed( ssl ); +#endif + SSL_DEBUG_MSG( 2, ( "<= parse finished" ) ); return( 0 ); @@ -4098,6 +4173,18 @@ static int ssl_handshake_init( ssl_context *ssl ) ssl->handshake->key_cert = ssl->key_cert; #endif +#if defined(POLARSSL_SSL_PROTO_DTLS) + if( ssl->transport == SSL_TRANSPORT_DATAGRAM ) + { + ssl->handshake->alt_transform_out = ssl->transform_out; + + if( ssl->endpoint == SSL_IS_CLIENT ) + ssl->handshake->retransmit_state = SSL_RETRANS_PREPARING; + else + ssl->handshake->retransmit_state = SSL_RETRANS_WAITING; + } +#endif + return( 0 ); }