@@ -96,28 +96,29 @@ std::shared_ptr<Connection_Cipher_State> Channel_Impl_12::write_cipher_state_epo
9696}
9797
9898std::vector<X509_Certificate> Channel_Impl_12::peer_cert_chain () const {
99- if (const auto * active = active_state ()) {
100- return get_peer_cert_chain (*active );
99+ if (m_active_state. has_value ()) {
100+ return m_active_state-> peer_certs ( );
101101 }
102102 return std::vector<X509_Certificate>();
103103}
104104
105105std::optional<std::string> Channel_Impl_12::external_psk_identity () const {
106- const auto * state = (active_state () != nullptr ) ? active_state () : pending_state ();
107- if (state != nullptr ) {
106+ if (m_active_state.has_value ()) {
107+ return m_active_state->psk_identity ();
108+ }
109+ if (const auto * state = pending_state ()) {
108110 return state->psk_identity ();
109- } else {
110- return std::nullopt ;
111111 }
112+ return std::nullopt ;
112113}
113114
114115Handshake_State& Channel_Impl_12::create_handshake_state (Protocol_Version version) {
115116 if (pending_state () != nullptr ) {
116117 throw Internal_Error (" create_handshake_state called during handshake" );
117118 }
118119
119- if (const auto * active = active_state ()) {
120- const Protocol_Version active_version = active ->version ();
120+ if (m_active_state. has_value ()) {
121+ const Protocol_Version active_version = m_active_state ->version ();
121122
122123 if (active_version.is_datagram_protocol () != version.is_datagram_protocol ()) {
123124 throw TLS_Exception (Alert::ProtocolVersion,
@@ -156,8 +157,8 @@ Handshake_State& Channel_Impl_12::create_handshake_state(Protocol_Version versio
156157
157158 m_pending_state = new_handshake_state (std::move (io));
158159
159- if (const auto * active = active_state ()) {
160- m_pending_state->set_version (active ->version ());
160+ if (m_active_state. has_value ()) {
161+ m_pending_state->set_version (m_active_state ->version ());
161162 }
162163
163164 return *m_pending_state;
@@ -177,12 +178,12 @@ void Channel_Impl_12::renegotiate(bool force_full_renegotiation) {
177178 return ;
178179 }
179180
180- if (const auto * active = active_state ()) {
181+ if (m_active_state. has_value ()) {
181182 if (!force_full_renegotiation) {
182183 force_full_renegotiation = !policy ().allow_resumption_for_renegotiation ();
183184 }
184185
185- initiate_handshake (create_handshake_state (active ->version ()), force_full_renegotiation);
186+ initiate_handshake (create_handshake_state (m_active_state ->version ()), force_full_renegotiation);
186187 } else {
187188 throw Invalid_State (" Cannot renegotiate on inactive connection" );
188189 }
@@ -245,7 +246,7 @@ void Channel_Impl_12::change_cipher_spec_writer(Connection_Side side) {
245246}
246247
247248bool Channel_Impl_12::is_handshake_complete () const {
248- return ( active_state () != nullptr );
249+ return m_active_state. has_value ( );
249250}
250251
251252bool Channel_Impl_12::is_active () const {
@@ -257,10 +258,11 @@ bool Channel_Impl_12::is_closed() const {
257258}
258259
259260void Channel_Impl_12::activate_session () {
260- std::swap (m_active_state, m_pending_state);
261- m_pending_state.reset ();
261+ BOTAN_ASSERT_NONNULL (m_pending_state);
262262
263- if (!m_active_state->version ().is_datagram_protocol ()) {
263+ const auto & state = *m_pending_state;
264+
265+ if (!state.version ().is_datagram_protocol ()) {
264266 // TLS is easy just remove all but the current state
265267 const uint16_t current_epoch = sequence_numbers ().current_write_epoch ();
266268
@@ -270,6 +272,29 @@ void Channel_Impl_12::activate_session() {
270272 map_remove_if (not_current_epoch, m_read_cipher_states);
271273 }
272274
275+ m_active_state = Active_Connection_State_12 (state.version (),
276+ state.server_hello ()->ciphersuite (),
277+ state.client_hello ()->random (),
278+ application_protocol (),
279+ get_peer_cert_chain (state),
280+ state.psk_identity (),
281+ state.server_hello ()->random (),
282+ Session_ID (state.server_hello ()->session_id ()),
283+ state.session_keys ().master_secret (),
284+ state.ciphersuite ().prf_algo (),
285+ state.client_hello ()->secure_renegotiation (),
286+ state.server_hello ()->secure_renegotiation (),
287+ state.client_finished ()->verify_data (),
288+ state.server_finished ()->verify_data (),
289+ state.server_hello ()->supports_extended_master_secret ());
290+
291+ // For DTLS, keep the handshake IO for last-flight retransmission.
292+ if (m_is_datagram) {
293+ m_active_state->set_dtls_handshake_io (m_pending_state->take_handshake_io ());
294+ }
295+
296+ m_pending_state.reset ();
297+
273298 callbacks ().tls_session_activated ();
274299}
275300
@@ -319,10 +344,10 @@ size_t Channel_Impl_12::from_peer(std::span<const uint8_t> data) {
319344 throw TLS_Exception (Alert::RecordOverflow, " TLS plaintext record is larger than allowed maximum" );
320345 }
321346
322- const bool epoch0_restart = m_is_datagram && record.epoch () == 0 && active_state () != nullptr ;
347+ const bool epoch0_restart = m_is_datagram && record.epoch () == 0 && m_active_state. has_value () ;
323348 BOTAN_ASSERT_IMPLICATION (epoch0_restart, allow_epoch0_restart, " Allowed state" );
324349
325- const bool initial_record = epoch0_restart || (pending_state () == nullptr && active_state () == nullptr );
350+ const bool initial_record = epoch0_restart || (pending_state () == nullptr && !m_active_state. has_value () );
326351 bool initial_handshake_message = false ;
327352 if (record.type () == Record_Type::Handshake && !m_record_buf.empty ()) {
328353 const Handshake_Type type = static_cast <Handshake_Type>(m_record_buf[0 ]);
@@ -340,8 +365,8 @@ size_t Channel_Impl_12::from_peer(std::span<const uint8_t> data) {
340365 record.version () != pending->version ()) {
341366 throw TLS_Exception (Alert::ProtocolVersion, " Received unexpected record version" );
342367 }
343- } else if (const auto * active = active_state ()) {
344- if (record.version () != active ->version () && !initial_handshake_message) {
368+ } else if (m_active_state. has_value ()) {
369+ if (record.version () != m_active_state ->version () && !initial_handshake_message) {
345370 throw TLS_Exception (Alert::ProtocolVersion, " Received unexpected record version" );
346371 }
347372 }
@@ -404,8 +429,10 @@ void Channel_Impl_12::process_handshake_ccs(const secure_vector<uint8_t>& record
404429 if (epoch == sequence_numbers ().current_read_epoch ()) {
405430 create_handshake_state (record_version);
406431 } else if (epoch == sequence_numbers ().current_read_epoch () - 1 ) {
407- BOTAN_ASSERT (m_active_state, " Have active state here" );
408- m_active_state->handshake_io ().add_record (record.data (), record.size (), record_type, record_sequence);
432+ BOTAN_ASSERT (m_active_state.has_value () && m_active_state->dtls_handshake_io (),
433+ " Have DTLS handshake IO for retransmission" );
434+ m_active_state->dtls_handshake_io ()->add_record (
435+ record.data (), record.size (), record_type, record_sequence);
409436 }
410437 } else {
411438 create_handshake_state (record_version);
@@ -426,7 +453,7 @@ void Channel_Impl_12::process_handshake_ccs(const secure_vector<uint8_t>& record
426453 break ;
427454 }
428455
429- process_handshake_msg (active_state (), *pending, msg.first , msg.second , epoch0_restart);
456+ process_handshake_msg (*pending, msg.first , msg.second , epoch0_restart);
430457
431458 if (!m_pending_state) {
432459 break ;
@@ -436,7 +463,7 @@ void Channel_Impl_12::process_handshake_ccs(const secure_vector<uint8_t>& record
436463}
437464
438465void Channel_Impl_12::process_application_data (uint64_t seq_no, const secure_vector<uint8_t >& record) {
439- if (active_state () == nullptr ) {
466+ if (!m_active_state. has_value () ) {
440467 throw Unexpected_Message (" Application data before handshake done" );
441468 }
442469
@@ -453,11 +480,10 @@ void Channel_Impl_12::process_alert(const secure_vector<uint8_t>& record) {
453480 callbacks ().tls_alert (alert_msg);
454481
455482 if (alert_msg.is_fatal ()) {
456- if (const auto * active = active_state ()) {
457- BOTAN_ASSERT_NONNULL (active->server_hello ());
458- const auto & session_id = active->server_hello ()->session_id ();
459- if (!session_id.empty ()) {
460- session_manager ().remove (Session_Handle (session_id));
483+ if (m_active_state.has_value ()) {
484+ const auto & sid = m_active_state->session_id ();
485+ if (!sid.empty ()) {
486+ session_manager ().remove (Session_Handle (sid));
461487 }
462488 }
463489 }
@@ -479,10 +505,9 @@ void Channel_Impl_12::write_record(Connection_Cipher_State* cipher_state,
479505 Record_Type record_type,
480506 const uint8_t input[],
481507 size_t length) {
482- BOTAN_ASSERT (m_pending_state || m_active_state, " Some connection state exists" );
508+ BOTAN_ASSERT (m_pending_state || m_active_state. has_value () , " Some connection state exists" );
483509
484- const Protocol_Version record_version =
485- (m_pending_state) ? (m_pending_state->version ()) : (m_active_state->version ());
510+ const Protocol_Version record_version = (m_pending_state) ? (m_pending_state->version ()) : m_active_state->version ();
486511
487512 const uint64_t next_seq = sequence_numbers ().next_write_sequence (epoch);
488513
@@ -543,11 +568,10 @@ void Channel_Impl_12::send_alert(const Alert& alert) {
543568 }
544569
545570 if (alert.is_fatal ()) {
546- if (const auto * active = active_state ()) {
547- BOTAN_ASSERT_NONNULL (active->server_hello ());
548- const auto & session_id = active->server_hello ()->session_id ();
549- if (!session_id.empty ()) {
550- session_manager ().remove (Session_Handle (Session_ID (session_id)));
571+ if (m_active_state.has_value ()) {
572+ const auto & sid = m_active_state->session_id ();
573+ if (!sid.empty ()) {
574+ session_manager ().remove (Session_Handle (sid));
551575 }
552576 }
553577 reset_state ();
@@ -562,11 +586,8 @@ void Channel_Impl_12::secure_renegotiation_check(const Client_Hello_12* client_h
562586 BOTAN_ASSERT_NONNULL (client_hello);
563587 const bool secure_renegotiation = client_hello->secure_renegotiation ();
564588
565- if (const auto * active = active_state ()) {
566- BOTAN_ASSERT_NONNULL (active->client_hello ());
567- const bool active_sr = active->client_hello ()->secure_renegotiation ();
568-
569- if (active_sr != secure_renegotiation) {
589+ if (m_active_state.has_value ()) {
590+ if (m_active_state->client_supports_secure_renegotiation () != secure_renegotiation) {
570591 throw TLS_Exception (Alert::HandshakeFailure, " Client changed its mind about secure renegotiation" );
571592 }
572593 }
@@ -584,11 +605,8 @@ void Channel_Impl_12::secure_renegotiation_check(const Server_Hello_12* server_h
584605 BOTAN_ASSERT_NONNULL (server_hello);
585606 const bool secure_renegotiation = server_hello->secure_renegotiation ();
586607
587- if (const auto * active = active_state ()) {
588- BOTAN_ASSERT_NONNULL (active->server_hello ());
589- const bool active_sr = active->server_hello ()->secure_renegotiation ();
590-
591- if (active_sr != secure_renegotiation) {
608+ if (m_active_state.has_value ()) {
609+ if (m_active_state->server_supports_secure_renegotiation () != secure_renegotiation) {
592610 throw TLS_Exception (Alert::HandshakeFailure, " Server changed its mind about secure renegotiation" );
593611 }
594612 }
@@ -603,28 +621,25 @@ void Channel_Impl_12::secure_renegotiation_check(const Server_Hello_12* server_h
603621}
604622
605623std::vector<uint8_t > Channel_Impl_12::secure_renegotiation_data_for_client_hello () const {
606- if (const auto * active = active_state ()) {
607- BOTAN_ASSERT_NONNULL (active->client_finished ());
608- return active->client_finished ()->verify_data ();
624+ if (m_active_state.has_value ()) {
625+ return m_active_state->client_finished_verify_data ();
609626 }
610627 return std::vector<uint8_t >();
611628}
612629
613630std::vector<uint8_t > Channel_Impl_12::secure_renegotiation_data_for_server_hello () const {
614- if (const auto * active = active_state ()) {
615- BOTAN_ASSERT_NONNULL (active->client_finished ());
616- BOTAN_ASSERT_NONNULL (active->server_finished ());
617- std::vector<uint8_t > buf = active->client_finished ()->verify_data ();
618- buf += active->server_finished ()->verify_data ();
631+ if (m_active_state.has_value ()) {
632+ std::vector<uint8_t > buf = m_active_state->client_finished_verify_data ();
633+ buf += m_active_state->server_finished_verify_data ();
619634 return buf;
620635 }
621636
622637 return std::vector<uint8_t >();
623638}
624639
625640bool Channel_Impl_12::secure_renegotiation_supported () const {
626- if (const auto * active = active_state ()) {
627- return active-> server_hello ()-> secure_renegotiation ();
641+ if (m_active_state. has_value ()) {
642+ return m_active_state-> server_supports_secure_renegotiation ();
628643 }
629644
630645 if (const auto * pending = pending_state ()) {
@@ -639,35 +654,31 @@ bool Channel_Impl_12::secure_renegotiation_supported() const {
639654SymmetricKey Channel_Impl_12::key_material_export (std::string_view label,
640655 std::string_view context,
641656 size_t length) const {
642- if (const auto * active = active_state ()) {
643- if (pending_state () != nullptr ) {
644- throw Invalid_State (" Channel_Impl_12::key_material_export cannot export during renegotiation" );
645- }
657+ if (!m_active_state.has_value ()) {
658+ throw Invalid_State (" Channel_Impl_12::key_material_export connection not active" );
659+ }
646660
647- auto prf = active->protocol_specific_prf ();
661+ if (pending_state () != nullptr ) {
662+ throw Invalid_State (" Channel_Impl_12::key_material_export cannot export during renegotiation" );
663+ }
648664
649- const secure_vector< uint8_t >& master_secret = active-> session_keys ().master_secret ( );
665+ auto prf = callbacks ().tls12_protocol_specific_kdf (m_active_state-> prf_algo () );
650666
651- BOTAN_ASSERT_NONNULL (active->client_hello ());
652- BOTAN_ASSERT_NONNULL (active->server_hello ());
653- std::vector<uint8_t > salt;
654- salt += active->client_hello ()->random ();
655- salt += active->server_hello ()->random ();
667+ std::vector<uint8_t > salt;
668+ salt += m_active_state->client_random ();
669+ salt += m_active_state->server_random ();
656670
657- if (!context.empty ()) {
658- const size_t context_size = context.length ();
659- if (context_size > 0xFFFF ) {
660- throw Invalid_Argument (" key_material_export context is too long" );
661- }
662- salt.push_back (get_byte<0 >(static_cast <uint16_t >(context_size)));
663- salt.push_back (get_byte<1 >(static_cast <uint16_t >(context_size)));
664- salt += as_span_of_bytes (context);
671+ if (!context.empty ()) {
672+ const size_t context_size = context.length ();
673+ if (context_size > 0xFFFF ) {
674+ throw Invalid_Argument (" key_material_export context is too long" );
665675 }
666-
667- return SymmetricKey (prf->derive_key (length, master_secret, salt, as_span_of_bytes (label)));
668- } else {
669- throw Invalid_State (" Channel_Impl_12::key_material_export connection not active" );
676+ salt.push_back (get_byte<0 >(static_cast <uint16_t >(context_size)));
677+ salt.push_back (get_byte<1 >(static_cast <uint16_t >(context_size)));
678+ salt += as_span_of_bytes (context);
670679 }
680+
681+ return SymmetricKey (prf->derive_key (length, m_active_state->master_secret (), salt, as_span_of_bytes (label)));
671682}
672683
673684} // namespace Botan::TLS
0 commit comments