Skip to content

Commit d33f4f6

Browse files
committed
Discard TLS handshake state after handshake has completed
Various bits of data relating to the handshake remained attached to the Channel even after the handshake completed, but in practice we only need a few pieces of information. Collect those into a new class, Active_Connection_State_{12,13}, and discard the rest.
1 parent 042e5bd commit d33f4f6

20 files changed

+761
-342
lines changed

src/lib/tls/tls12/info.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ tls_extensions_12.h
1414

1515
<header:internal>
1616
tls_channel_impl_12.h
17+
tls_connection_state_12.h
1718
tls_client_impl_12.h
1819
tls_record.h
1920
tls_server_impl_12.h

src/lib/tls/tls12/tls_channel_impl_12.cpp

Lines changed: 91 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -96,28 +96,29 @@ std::shared_ptr<Connection_Cipher_State> Channel_Impl_12::write_cipher_state_epo
9696
}
9797

9898
std::vector<X509_Certificate> Channel_Impl_12::peer_cert_chain() const {
99-
if(const auto* active = active_state()) {
100-
return get_peer_cert_chain(*active);
99+
if(m_active_state.has_value()) {
100+
return m_active_state->peer_certs();
101101
}
102102
return std::vector<X509_Certificate>();
103103
}
104104

105105
std::optional<std::string> Channel_Impl_12::external_psk_identity() const {
106-
const auto* state = (active_state() != nullptr) ? active_state() : pending_state();
107-
if(state != nullptr) {
106+
if(m_active_state.has_value()) {
107+
return m_active_state->psk_identity();
108+
}
109+
if(const auto* state = pending_state()) {
108110
return state->psk_identity();
109-
} else {
110-
return std::nullopt;
111111
}
112+
return std::nullopt;
112113
}
113114

114115
Handshake_State& Channel_Impl_12::create_handshake_state(Protocol_Version version) {
115116
if(pending_state() != nullptr) {
116117
throw Internal_Error("create_handshake_state called during handshake");
117118
}
118119

119-
if(const auto* active = active_state()) {
120-
const Protocol_Version active_version = active->version();
120+
if(m_active_state.has_value()) {
121+
const Protocol_Version active_version = m_active_state->version();
121122

122123
if(active_version.is_datagram_protocol() != version.is_datagram_protocol()) {
123124
throw TLS_Exception(Alert::ProtocolVersion,
@@ -156,8 +157,8 @@ Handshake_State& Channel_Impl_12::create_handshake_state(Protocol_Version versio
156157

157158
m_pending_state = new_handshake_state(std::move(io));
158159

159-
if(const auto* active = active_state()) {
160-
m_pending_state->set_version(active->version());
160+
if(m_active_state.has_value()) {
161+
m_pending_state->set_version(m_active_state->version());
161162
}
162163

163164
return *m_pending_state;
@@ -177,12 +178,12 @@ void Channel_Impl_12::renegotiate(bool force_full_renegotiation) {
177178
return;
178179
}
179180

180-
if(const auto* active = active_state()) {
181+
if(m_active_state.has_value()) {
181182
if(!force_full_renegotiation) {
182183
force_full_renegotiation = !policy().allow_resumption_for_renegotiation();
183184
}
184185

185-
initiate_handshake(create_handshake_state(active->version()), force_full_renegotiation);
186+
initiate_handshake(create_handshake_state(m_active_state->version()), force_full_renegotiation);
186187
} else {
187188
throw Invalid_State("Cannot renegotiate on inactive connection");
188189
}
@@ -245,7 +246,7 @@ void Channel_Impl_12::change_cipher_spec_writer(Connection_Side side) {
245246
}
246247

247248
bool Channel_Impl_12::is_handshake_complete() const {
248-
return (active_state() != nullptr);
249+
return m_active_state.has_value();
249250
}
250251

251252
bool Channel_Impl_12::is_active() const {
@@ -257,10 +258,11 @@ bool Channel_Impl_12::is_closed() const {
257258
}
258259

259260
void Channel_Impl_12::activate_session() {
260-
std::swap(m_active_state, m_pending_state);
261-
m_pending_state.reset();
261+
BOTAN_ASSERT_NONNULL(m_pending_state);
262262

263-
if(!m_active_state->version().is_datagram_protocol()) {
263+
const auto& state = *m_pending_state;
264+
265+
if(!state.version().is_datagram_protocol()) {
264266
// TLS is easy just remove all but the current state
265267
const uint16_t current_epoch = sequence_numbers().current_write_epoch();
266268

@@ -270,6 +272,29 @@ void Channel_Impl_12::activate_session() {
270272
map_remove_if(not_current_epoch, m_read_cipher_states);
271273
}
272274

275+
m_active_state = Active_Connection_State_12(state.version(),
276+
state.server_hello()->ciphersuite(),
277+
state.client_hello()->random(),
278+
application_protocol(),
279+
get_peer_cert_chain(state),
280+
state.psk_identity(),
281+
state.server_hello()->random(),
282+
Session_ID(state.server_hello()->session_id()),
283+
state.session_keys().master_secret(),
284+
state.ciphersuite().prf_algo(),
285+
state.client_hello()->secure_renegotiation(),
286+
state.server_hello()->secure_renegotiation(),
287+
state.client_finished()->verify_data(),
288+
state.server_finished()->verify_data(),
289+
state.server_hello()->supports_extended_master_secret());
290+
291+
// For DTLS, keep the handshake IO for last-flight retransmission.
292+
if(m_is_datagram) {
293+
m_active_state->set_dtls_handshake_io(m_pending_state->take_handshake_io());
294+
}
295+
296+
m_pending_state.reset();
297+
273298
callbacks().tls_session_activated();
274299
}
275300

@@ -319,10 +344,10 @@ size_t Channel_Impl_12::from_peer(std::span<const uint8_t> data) {
319344
throw TLS_Exception(Alert::RecordOverflow, "TLS plaintext record is larger than allowed maximum");
320345
}
321346

322-
const bool epoch0_restart = m_is_datagram && record.epoch() == 0 && active_state() != nullptr;
347+
const bool epoch0_restart = m_is_datagram && record.epoch() == 0 && m_active_state.has_value();
323348
BOTAN_ASSERT_IMPLICATION(epoch0_restart, allow_epoch0_restart, "Allowed state");
324349

325-
const bool initial_record = epoch0_restart || (pending_state() == nullptr && active_state() == nullptr);
350+
const bool initial_record = epoch0_restart || (pending_state() == nullptr && !m_active_state.has_value());
326351
bool initial_handshake_message = false;
327352
if(record.type() == Record_Type::Handshake && !m_record_buf.empty()) {
328353
const Handshake_Type type = static_cast<Handshake_Type>(m_record_buf[0]);
@@ -340,8 +365,8 @@ size_t Channel_Impl_12::from_peer(std::span<const uint8_t> data) {
340365
record.version() != pending->version()) {
341366
throw TLS_Exception(Alert::ProtocolVersion, "Received unexpected record version");
342367
}
343-
} else if(const auto* active = active_state()) {
344-
if(record.version() != active->version() && !initial_handshake_message) {
368+
} else if(m_active_state.has_value()) {
369+
if(record.version() != m_active_state->version() && !initial_handshake_message) {
345370
throw TLS_Exception(Alert::ProtocolVersion, "Received unexpected record version");
346371
}
347372
}
@@ -404,8 +429,10 @@ void Channel_Impl_12::process_handshake_ccs(const secure_vector<uint8_t>& record
404429
if(epoch == sequence_numbers().current_read_epoch()) {
405430
create_handshake_state(record_version);
406431
} else if(epoch == sequence_numbers().current_read_epoch() - 1) {
407-
BOTAN_ASSERT(m_active_state, "Have active state here");
408-
m_active_state->handshake_io().add_record(record.data(), record.size(), record_type, record_sequence);
432+
BOTAN_ASSERT(m_active_state.has_value() && m_active_state->dtls_handshake_io(),
433+
"Have DTLS handshake IO for retransmission");
434+
m_active_state->dtls_handshake_io()->add_record(
435+
record.data(), record.size(), record_type, record_sequence);
409436
}
410437
} else {
411438
create_handshake_state(record_version);
@@ -426,7 +453,7 @@ void Channel_Impl_12::process_handshake_ccs(const secure_vector<uint8_t>& record
426453
break;
427454
}
428455

429-
process_handshake_msg(active_state(), *pending, msg.first, msg.second, epoch0_restart);
456+
process_handshake_msg(*pending, msg.first, msg.second, epoch0_restart);
430457

431458
if(!m_pending_state) {
432459
break;
@@ -436,7 +463,7 @@ void Channel_Impl_12::process_handshake_ccs(const secure_vector<uint8_t>& record
436463
}
437464

438465
void Channel_Impl_12::process_application_data(uint64_t seq_no, const secure_vector<uint8_t>& record) {
439-
if(active_state() == nullptr) {
466+
if(!m_active_state.has_value()) {
440467
throw Unexpected_Message("Application data before handshake done");
441468
}
442469

@@ -453,11 +480,10 @@ void Channel_Impl_12::process_alert(const secure_vector<uint8_t>& record) {
453480
callbacks().tls_alert(alert_msg);
454481

455482
if(alert_msg.is_fatal()) {
456-
if(const auto* active = active_state()) {
457-
BOTAN_ASSERT_NONNULL(active->server_hello());
458-
const auto& session_id = active->server_hello()->session_id();
459-
if(!session_id.empty()) {
460-
session_manager().remove(Session_Handle(session_id));
483+
if(m_active_state.has_value()) {
484+
const auto& sid = m_active_state->session_id();
485+
if(!sid.empty()) {
486+
session_manager().remove(Session_Handle(sid));
461487
}
462488
}
463489
}
@@ -479,10 +505,9 @@ void Channel_Impl_12::write_record(Connection_Cipher_State* cipher_state,
479505
Record_Type record_type,
480506
const uint8_t input[],
481507
size_t length) {
482-
BOTAN_ASSERT(m_pending_state || m_active_state, "Some connection state exists");
508+
BOTAN_ASSERT(m_pending_state || m_active_state.has_value(), "Some connection state exists");
483509

484-
const Protocol_Version record_version =
485-
(m_pending_state) ? (m_pending_state->version()) : (m_active_state->version());
510+
const Protocol_Version record_version = (m_pending_state) ? (m_pending_state->version()) : m_active_state->version();
486511

487512
const uint64_t next_seq = sequence_numbers().next_write_sequence(epoch);
488513

@@ -543,11 +568,10 @@ void Channel_Impl_12::send_alert(const Alert& alert) {
543568
}
544569

545570
if(alert.is_fatal()) {
546-
if(const auto* active = active_state()) {
547-
BOTAN_ASSERT_NONNULL(active->server_hello());
548-
const auto& session_id = active->server_hello()->session_id();
549-
if(!session_id.empty()) {
550-
session_manager().remove(Session_Handle(Session_ID(session_id)));
571+
if(m_active_state.has_value()) {
572+
const auto& sid = m_active_state->session_id();
573+
if(!sid.empty()) {
574+
session_manager().remove(Session_Handle(sid));
551575
}
552576
}
553577
reset_state();
@@ -562,11 +586,8 @@ void Channel_Impl_12::secure_renegotiation_check(const Client_Hello_12* client_h
562586
BOTAN_ASSERT_NONNULL(client_hello);
563587
const bool secure_renegotiation = client_hello->secure_renegotiation();
564588

565-
if(const auto* active = active_state()) {
566-
BOTAN_ASSERT_NONNULL(active->client_hello());
567-
const bool active_sr = active->client_hello()->secure_renegotiation();
568-
569-
if(active_sr != secure_renegotiation) {
589+
if(m_active_state.has_value()) {
590+
if(m_active_state->client_supports_secure_renegotiation() != secure_renegotiation) {
570591
throw TLS_Exception(Alert::HandshakeFailure, "Client changed its mind about secure renegotiation");
571592
}
572593
}
@@ -584,11 +605,8 @@ void Channel_Impl_12::secure_renegotiation_check(const Server_Hello_12* server_h
584605
BOTAN_ASSERT_NONNULL(server_hello);
585606
const bool secure_renegotiation = server_hello->secure_renegotiation();
586607

587-
if(const auto* active = active_state()) {
588-
BOTAN_ASSERT_NONNULL(active->server_hello());
589-
const bool active_sr = active->server_hello()->secure_renegotiation();
590-
591-
if(active_sr != secure_renegotiation) {
608+
if(m_active_state.has_value()) {
609+
if(m_active_state->server_supports_secure_renegotiation() != secure_renegotiation) {
592610
throw TLS_Exception(Alert::HandshakeFailure, "Server changed its mind about secure renegotiation");
593611
}
594612
}
@@ -603,28 +621,25 @@ void Channel_Impl_12::secure_renegotiation_check(const Server_Hello_12* server_h
603621
}
604622

605623
std::vector<uint8_t> Channel_Impl_12::secure_renegotiation_data_for_client_hello() const {
606-
if(const auto* active = active_state()) {
607-
BOTAN_ASSERT_NONNULL(active->client_finished());
608-
return active->client_finished()->verify_data();
624+
if(m_active_state.has_value()) {
625+
return m_active_state->client_finished_verify_data();
609626
}
610627
return std::vector<uint8_t>();
611628
}
612629

613630
std::vector<uint8_t> Channel_Impl_12::secure_renegotiation_data_for_server_hello() const {
614-
if(const auto* active = active_state()) {
615-
BOTAN_ASSERT_NONNULL(active->client_finished());
616-
BOTAN_ASSERT_NONNULL(active->server_finished());
617-
std::vector<uint8_t> buf = active->client_finished()->verify_data();
618-
buf += active->server_finished()->verify_data();
631+
if(m_active_state.has_value()) {
632+
std::vector<uint8_t> buf = m_active_state->client_finished_verify_data();
633+
buf += m_active_state->server_finished_verify_data();
619634
return buf;
620635
}
621636

622637
return std::vector<uint8_t>();
623638
}
624639

625640
bool Channel_Impl_12::secure_renegotiation_supported() const {
626-
if(const auto* active = active_state()) {
627-
return active->server_hello()->secure_renegotiation();
641+
if(m_active_state.has_value()) {
642+
return m_active_state->server_supports_secure_renegotiation();
628643
}
629644

630645
if(const auto* pending = pending_state()) {
@@ -639,35 +654,31 @@ bool Channel_Impl_12::secure_renegotiation_supported() const {
639654
SymmetricKey Channel_Impl_12::key_material_export(std::string_view label,
640655
std::string_view context,
641656
size_t length) const {
642-
if(const auto* active = active_state()) {
643-
if(pending_state() != nullptr) {
644-
throw Invalid_State("Channel_Impl_12::key_material_export cannot export during renegotiation");
645-
}
657+
if(!m_active_state.has_value()) {
658+
throw Invalid_State("Channel_Impl_12::key_material_export connection not active");
659+
}
646660

647-
auto prf = active->protocol_specific_prf();
661+
if(pending_state() != nullptr) {
662+
throw Invalid_State("Channel_Impl_12::key_material_export cannot export during renegotiation");
663+
}
648664

649-
const secure_vector<uint8_t>& master_secret = active->session_keys().master_secret();
665+
auto prf = callbacks().tls12_protocol_specific_kdf(m_active_state->prf_algo());
650666

651-
BOTAN_ASSERT_NONNULL(active->client_hello());
652-
BOTAN_ASSERT_NONNULL(active->server_hello());
653-
std::vector<uint8_t> salt;
654-
salt += active->client_hello()->random();
655-
salt += active->server_hello()->random();
667+
std::vector<uint8_t> salt;
668+
salt += m_active_state->client_random();
669+
salt += m_active_state->server_random();
656670

657-
if(!context.empty()) {
658-
const size_t context_size = context.length();
659-
if(context_size > 0xFFFF) {
660-
throw Invalid_Argument("key_material_export context is too long");
661-
}
662-
salt.push_back(get_byte<0>(static_cast<uint16_t>(context_size)));
663-
salt.push_back(get_byte<1>(static_cast<uint16_t>(context_size)));
664-
salt += as_span_of_bytes(context);
671+
if(!context.empty()) {
672+
const size_t context_size = context.length();
673+
if(context_size > 0xFFFF) {
674+
throw Invalid_Argument("key_material_export context is too long");
665675
}
666-
667-
return SymmetricKey(prf->derive_key(length, master_secret, salt, as_span_of_bytes(label)));
668-
} else {
669-
throw Invalid_State("Channel_Impl_12::key_material_export connection not active");
676+
salt.push_back(get_byte<0>(static_cast<uint16_t>(context_size)));
677+
salt.push_back(get_byte<1>(static_cast<uint16_t>(context_size)));
678+
salt += as_span_of_bytes(context);
670679
}
680+
681+
return SymmetricKey(prf->derive_key(length, m_active_state->master_secret(), salt, as_span_of_bytes(label)));
671682
}
672683

673684
} // namespace Botan::TLS

src/lib/tls/tls12/tls_channel_impl_12.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <botan/tls_alert.h>
1313
#include <botan/tls_session_manager.h>
1414
#include <botan/internal/tls_channel_impl.h>
15+
#include <botan/internal/tls_connection_state_12.h>
1516
#include <map>
1617
#include <memory>
1718
#include <string>
@@ -149,8 +150,9 @@ class Channel_Impl_12 : public Channel_Impl {
149150
bool timeout_check() override;
150151

151152
protected:
152-
virtual void process_handshake_msg(const Handshake_State* active_state,
153-
Handshake_State& pending_state,
153+
const std::optional<Active_Connection_State_12>& active_state() const { return m_active_state; }
154+
155+
virtual void process_handshake_msg(Handshake_State& pending_state,
154156
Handshake_Type type,
155157
const std::vector<uint8_t>& contents,
156158
bool epoch0_restart) = 0;
@@ -206,8 +208,6 @@ class Channel_Impl_12 : public Channel_Impl {
206208

207209
std::shared_ptr<Connection_Cipher_State> write_cipher_state_epoch(uint16_t epoch) const;
208210

209-
const Handshake_State* active_state() const { return m_active_state.get(); }
210-
211211
const Handshake_State* pending_state() const { return m_pending_state.get(); }
212212

213213
/* methods to handle incoming traffic through Channel_Impl_12::receive_data. */
@@ -235,8 +235,7 @@ class Channel_Impl_12 : public Channel_Impl {
235235
/* sequence number state */
236236
std::unique_ptr<Connection_Sequence_Numbers> m_sequence_numbers;
237237

238-
/* pending and active connection states */
239-
std::unique_ptr<Handshake_State> m_active_state;
238+
/* pending handshake state (null when no handshake is in progress) */
240239
std::unique_ptr<Handshake_State> m_pending_state;
241240

242241
/* cipher states for each epoch */
@@ -249,6 +248,8 @@ class Channel_Impl_12 : public Channel_Impl {
249248
secure_vector<uint8_t> m_record_buf;
250249

251250
bool m_has_been_closed;
251+
252+
std::optional<Active_Connection_State_12> m_active_state;
252253
};
253254

254255
} // namespace TLS

0 commit comments

Comments
 (0)