, TunnResult<'b>> {
- let packet = Tunn::parse_incoming_packet(src)?;
+ src_addr: IpAddr,
+ handshake: Packet,
+ ) -> Result, TunnResult> {
+ let sender_idx = handshake.sender_idx();
+ let mac1 = handshake.mac1();
+ let mac2 = handshake.mac2();
+
+ let computed_mac1 = b2s_keyed_mac_16(&self.mac1_key, handshake.until_mac1());
+ if verify_slices_are_equal(&computed_mac1, mac1).is_err() {
+ return Err(TunnResult::Err(WireGuardError::InvalidMac));
+ }
- // Verify and rate limit handshake messages only
- if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. })
- | Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet
- {
- let (msg, macs) = src.split_at(src.len() - 32);
- let (mac1, mac2) = macs.split_at(16);
-
- let computed_mac1 = b2s_keyed_mac_16(&self.mac1_key, msg);
- verify_slices_are_equal(&computed_mac1[..16], mac1)
- .map_err(|_| TunnResult::Err(WireGuardError::InvalidMac))?;
-
- if self.is_under_load() {
- let addr = match src_addr {
- None => return Err(TunnResult::Err(WireGuardError::UnderLoad)),
- Some(addr) => addr,
- };
-
- // Only given an address can we validate mac2
- let cookie = self.current_cookie(addr);
- let computed_mac2 = b2s_keyed_mac_16_2(&cookie, msg, mac1);
-
- if verify_slices_are_equal(&computed_mac2[..16], mac2).is_err() {
- let cookie_packet = self
- .format_cookie_reply(sender_idx, cookie, mac1, dst)
- .map_err(TunnResult::Err)?;
- return Err(TunnResult::WriteToNetwork(cookie_packet));
- }
+ if self.is_under_load() {
+ let cookie = self.current_cookie(src_addr);
+ let computed_mac2 = b2s_keyed_mac_16_2(&cookie, handshake.until_mac1(), mac1);
+
+ if verify_slices_are_equal(&computed_mac2, mac2).is_err() {
+ let cookie_reply = self.format_cookie_reply(sender_idx, cookie, mac1);
+ let packet = handshake.overwrite_with(&cookie_reply);
+ return Err(TunnResult::WriteToNetwork(packet.into()));
}
+
+ // If under load but mac2 is valid, allow the handshake
+ return Ok(handshake);
}
- Ok(packet)
+ Ok(handshake)
}
}
diff --git a/src/noise/session.rs b/src/noise/session.rs
index b523f53..af9a314 100644
--- a/src/noise/session.rs
+++ b/src/noise/session.rs
@@ -1,11 +1,19 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
+//
+// Modified by Mullvad VPN.
+// Copyright (c) 2025 Mullvad VPN.
+//
// SPDX-License-Identifier: BSD-3-Clause
use super::tls::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
-use super::PacketData;
-use crate::noise::errors::WireGuardError;
+use crate::{
+ noise::errors::WireGuardError,
+ packet::{Packet, WgData, WgDataHeader, WgKind},
+};
+use bytes::{Buf, BytesMut};
use parking_lot::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
+use zerocopy::FromBytes;
pub struct Session {
pub(crate) receiving_index: u32,
@@ -91,10 +99,10 @@ impl ReceivingKeyCounterValidator {
// Drop if too far back
return Err(WireGuardError::InvalidCounter);
}
- if !self.check_bit(counter) {
- Ok(())
- } else {
+ if self.check_bit(counter) {
Err(WireGuardError::DuplicateCounter)
+ } else {
+ Ok(())
}
}
@@ -190,81 +198,85 @@ impl Session {
ret
}
- /// src - an IP packet from the interface
- /// dst - pre-allocated space to hold the encapsulating UDP packet to send over the network
- /// returns the size of the formatted packet
- pub(super) fn format_packet_data<'a>(&self, src: &[u8], dst: &'a mut [u8]) -> &'a mut [u8] {
- if dst.len() < src.len() + super::DATA_OVERHEAD_SZ {
- panic!("The destination buffer is too small");
- }
-
+ /// Encapsulate `packet` into a [`WgData`].
+ pub(super) fn format_packet_data(&self, packet: Packet) -> Packet {
let sending_key_counter = self.sending_key_counter.fetch_add(1, Ordering::Relaxed) as u64;
- let (message_type, rest) = dst.split_at_mut(4);
- let (receiver_index, rest) = rest.split_at_mut(4);
- let (counter, data) = rest.split_at_mut(8);
+ let len = DATA_OFFSET + AEAD_SIZE + packet.len();
- message_type.copy_from_slice(&super::DATA.to_le_bytes());
- receiver_index.copy_from_slice(&self.sending_index.to_le_bytes());
- counter.copy_from_slice(&sending_key_counter.to_le_bytes());
+ // TODO: we can remove this allocation by pre-allocating some extra
+ // space at the beginning of `packet`s allocation, and using that.
+ let mut buf = Packet::from_bytes(BytesMut::zeroed(len));
+
+ let data = WgData::mut_from_bytes(buf.buf_mut()).unwrap();
+
+ data.header = WgDataHeader::new()
+ .with_receiver_idx(self.sending_index)
+ .with_counter(sending_key_counter);
// TODO: spec requires padding to 16 bytes, but actually works fine without it
- let n = {
- let mut nonce = [0u8; 12];
- nonce[4..12].copy_from_slice(&sending_key_counter.to_le_bytes());
- data[..src.len()].copy_from_slice(src);
- self.sender
- .seal_in_place_separate_tag(
- Nonce::assume_unique_for_key(nonce),
- Aad::from(&[]),
- &mut data[..src.len()],
- )
- .map(|tag| {
- data[src.len()..src.len() + AEAD_SIZE].copy_from_slice(tag.as_ref());
- src.len() + AEAD_SIZE
- })
- .unwrap()
+ let mut nonce = [0u8; 12];
+ nonce[4..12].copy_from_slice(&sending_key_counter.to_le_bytes());
+ data.encrypted_encapsulated_packet_mut()
+ .copy_from_slice(&packet);
+ self.sender
+ .seal_in_place_separate_tag(
+ Nonce::assume_unique_for_key(nonce),
+ Aad::from(&[]),
+ data.encrypted_encapsulated_packet_mut(),
+ )
+ .map(|tag| {
+ data.tag_mut().copy_from_slice(tag.as_ref());
+ packet.len() + AEAD_SIZE
+ })
+ .expect("encryption must succeed");
+
+ // this won't panic since we've correctly initialized a WgData packet
+ let packet = buf.try_into_wg().expect("is a wireguard packet");
+ let WgKind::Data(packet) = packet else {
+ unreachable!("is a wireguard data packet");
};
- &mut dst[..DATA_OFFSET + n]
+ packet
}
- /// packet - a data packet we received from the network
- /// dst - pre-allocated space to hold the encapsulated IP packet, to send to the interface
- /// dst will always take less space than src
- /// return the size of the encapsulated packet on success
- pub(super) fn receive_packet_data<'a>(
+ /// Decapsulate `packet` and return the decrypted data.
+ pub(super) fn receive_packet_data(
&self,
- packet: PacketData,
- dst: &'a mut [u8],
- ) -> Result<&'a mut [u8], WireGuardError> {
- let ct_len = packet.encrypted_encapsulated_packet.len();
- if dst.len() < ct_len {
- // This is a very incorrect use of the library, therefore panic and not error
- panic!("The destination buffer is too small");
- }
- if packet.receiver_idx != self.receiving_index {
+ mut packet: Packet,
+ ) -> Result {
+ if packet.header.receiver_idx != self.receiving_index {
return Err(WireGuardError::WrongIndex);
}
+
+ let counter = packet.header.counter.get();
+
// Don't reuse counters, in case this is a replay attack we want to quickly check the counter without running expensive decryption
- self.receiving_counter_quick_check(packet.counter)?;
-
- let ret = {
- let mut nonce = [0u8; 12];
- nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes());
- dst[..ct_len].copy_from_slice(packet.encrypted_encapsulated_packet);
- self.receiver
- .open_in_place(
- Nonce::assume_unique_for_key(nonce),
- Aad::from(&[]),
- &mut dst[..ct_len],
- )
- .map_err(|_| WireGuardError::InvalidAeadTag)?
- };
+ self.receiving_counter_quick_check(counter)?;
+
+ let mut nonce = [0u8; 12];
+ nonce[4..12].copy_from_slice(&packet.header.counter.to_bytes());
+
+ // decrypt the data in-place
+ let decrypted_len = self
+ .receiver
+ .open_in_place(
+ Nonce::assume_unique_for_key(nonce),
+ Aad::from(&[]),
+ &mut packet.encrypted_encapsulated_packet_and_tag,
+ )
+ .map_err(|_| WireGuardError::InvalidAeadTag)?
+ .len();
+
+ // shift the packet buffer slice onto the decrypted data
+ let mut packet = packet.into_bytes();
+ let buf = packet.buf_mut();
+ buf.advance(WgDataHeader::LEN);
+ buf.truncate(decrypted_len);
// After decryption is done, check counter again, and mark as received
- self.receiving_counter_mark(packet.counter)?;
- Ok(ret)
+ self.receiving_counter_mark(counter)?;
+ Ok(packet)
}
/// Returns the estimated downstream packet loss for this session
diff --git a/src/noise/timers.rs b/src/noise/timers.rs
index 6b91d57..7b46159 100644
--- a/src/noise/timers.rs
+++ b/src/noise/timers.rs
@@ -1,13 +1,19 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
+//
+// Modified by Mullvad VPN.
+// Copyright (c) 2025 Mullvad VPN.
+//
// SPDX-License-Identifier: BSD-3-Clause
use super::errors::WireGuardError;
-use crate::noise::{Tunn, TunnResult};
+use crate::noise::Tunn;
+use crate::packet::WgKind;
+
use std::mem;
use std::ops::{Index, IndexMut};
-
use std::time::Duration;
+use bytes::BytesMut;
#[cfg(feature = "mock-instant")]
use mock_instant::Instant;
@@ -156,8 +162,8 @@ impl Tunn {
if time_now - *t > REJECT_AFTER_TIME {
if let Some(session) = self.sessions[i].take() {
tracing::debug!(
- message = "SESSION_EXPIRED(REJECT_AFTER_TIME)",
- session = session.receiving_index
+ "SESSION_EXPIRED(REJECT_AFTER_TIME): {}",
+ session.receiving_index
);
}
*t = time_now;
@@ -165,14 +171,18 @@ impl Tunn {
}
}
- pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> {
+ /// Update the tunnel timers
+ ///
+ /// This returns `Ok(None)` if no action is needed, `Ok(Some(packet))` if a packet
+ /// (keepalive or handshake) should be sent, or an error if something went wrong.
+ pub fn update_timers(&mut self) -> Result