diff --git a/Cargo.toml b/Cargo.toml index 6b69ab8..4235c03 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,13 @@ hmac = "0.12" mock_instant = { version = "0.6.0", optional = true } portable-atomic = "1" +eyre = "0.6.12" +zerocopy = { version = "0.8.27", features = ["std", "derive"] } +bytes = "1.10.1" +duplicate = { version = "2.0.0", default-features = false } +either = "1.15.0" +bitfield-struct = "0.11.0" + [target.'cfg(unix)'.dependencies] nix = { version = "0.30.1", default-features = false, features = [ "time", diff --git a/src/lib.rs b/src/lib.rs index 42415d1..af9f91e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ //! git clone https://github.com/cloudflare/boringtun.git pub mod noise; +pub mod packet; #[cfg(not(feature = "mock-instant"))] pub(crate) mod sleepyinstant; diff --git a/src/noise/errors.rs b/src/noise/errors.rs index 10513ae..ea1a8ef 100644 --- a/src/noise/errors.rs +++ b/src/noise/errors.rs @@ -19,5 +19,4 @@ pub enum WireGuardError { NoCurrentSession, LockFailed, ConnectionExpired, - UnderLoad, } diff --git a/src/noise/handshake.rs b/src/noise/handshake.rs index e8e49ff..0ca5e47 100644 --- a/src/noise/handshake.rs +++ b/src/noise/handshake.rs @@ -1,10 +1,14 @@ // Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// +// Modified by Mullvad VPN. +// Copyright (c) 2025 Mullvad VPN. +// // SPDX-License-Identifier: BSD-3-Clause use super::tls::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; -use super::{HandshakeInit, HandshakeResponse, PacketCookieReply}; use crate::noise::errors::WireGuardError; use crate::noise::session::Session; +use crate::packet::{Packet, WgCookieReply, WgHandshakeBase, WgHandshakeInit, WgHandshakeResp}; #[cfg(not(feature = "mock-instant"))] use crate::sleepyinstant::Instant; use crate::x25519; @@ -15,6 +19,7 @@ use chacha20poly1305::XChaCha20Poly1305; use rand_core::OsRng; use std::convert::TryInto; use std::time::{Duration, SystemTime}; +use zerocopy::IntoBytes; #[cfg(feature = "mock-instant")] use mock_instant::Instant; @@ -55,7 +60,7 @@ pub(crate) fn b2s_hmac(key: &[u8], data1: &[u8]) -> [u8; 32] { } #[inline] -/// Like b2s_hmac, but chain data1 and data2 together +/// Like [`b2s_hmac`], but chain data1 and data2 together pub(crate) fn b2s_hmac2(key: &[u8], data1: &[u8], data2: &[u8]) -> [u8; 32] { use blake2::digest::Update; type HmacBlake2s = hmac::SimpleHmac; @@ -159,21 +164,21 @@ fn aead_chacha20_open_inner( } #[derive(Debug)] -/// This struct represents a 12 byte [Tai64N](https://cr.yp.to/libtai/tai64.html) timestamp +/// This struct represents a 12 byte [`Tai64N`](https://cr.yp.to/libtai/tai64.html) timestamp struct Tai64N { secs: u64, nano: u32, } #[derive(Debug)] -/// This struct computes a [Tai64N](https://cr.yp.to/libtai/tai64.html) timestamp from current system time +/// This struct computes a [`Tai64N`](https://cr.yp.to/libtai/tai64.html) timestamp from current system time struct TimeStamper { duration_at_start: Duration, instant_at_start: Instant, } impl TimeStamper { - /// Create a new TimeStamper + /// Create a new [`TimeStamper`] pub fn new() -> TimeStamper { TimeStamper { duration_at_start: SystemTime::now() @@ -236,9 +241,9 @@ struct NoiseParams { static_private: x25519::StaticSecret, /// Static public key of the other party peer_static_public: x25519::PublicKey, - /// A shared key = DH(static_private, peer_static_public) + /// A shared key = DH(`static_private`, `peer_static_public`) static_shared: x25519::SharedSecret, - /// A pre-computation of HASH("mac1----", peer_static_public) for this peer + /// A pre-computation of HASH("mac1----", `peer_static_public`) for this peer sending_mac1_key: [u8; KEY_LEN], /// An optional preshared key preshared_key: Option<[u8; KEY_LEN]>, @@ -326,16 +331,16 @@ pub struct HalfHandshake { pub fn parse_handshake_anon( static_private: &x25519::StaticSecret, static_public: &x25519::PublicKey, - packet: &HandshakeInit, + packet: &WgHandshakeInit, ) -> Result { - let peer_index = packet.sender_idx; + let peer_index = packet.sender_idx.get(); // initiator.chaining_key = HASH(CONSTRUCTION) let mut chaining_key = INITIAL_CHAIN_KEY; // initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public) let mut hash = INITIAL_CHAIN_HASH; hash = b2s_hash(&hash, static_public.as_bytes()); // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) - let peer_ephemeral_public = x25519::PublicKey::from(*packet.unencrypted_ephemeral); + let peer_ephemeral_public = x25519::PublicKey::from(packet.unencrypted_ephemeral); // initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral) hash = b2s_hash(&hash, peer_ephemeral_public.as_bytes()); // temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral) @@ -358,7 +363,7 @@ pub fn parse_handshake_anon( &mut peer_static_public, &key, 0, - packet.encrypted_static, + packet.encrypted_static.as_bytes(), &hash, )?; @@ -478,11 +483,10 @@ impl Handshake { self.params.set_static_private(private_key, public_key) } - pub(super) fn receive_handshake_initialization<'a>( + pub(super) fn receive_handshake_initialization( &mut self, - packet: HandshakeInit, - dst: &'a mut [u8], - ) -> Result<(&'a mut [u8], Session), WireGuardError> { + packet: crate::packet::Packet, + ) -> Result<(crate::packet::Packet, Session), WireGuardError> { // initiator.chaining_key = HASH(CONSTRUCTION) let mut chaining_key = INITIAL_CHAIN_KEY; // initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public) @@ -491,7 +495,7 @@ impl Handshake { // msg.sender_index = little_endian(initiator.sender_index) let peer_index = packet.sender_idx; // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) - let peer_ephemeral_public = x25519::PublicKey::from(*packet.unencrypted_ephemeral); + let peer_ephemeral_public = x25519::PublicKey::from(packet.unencrypted_ephemeral); // initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral) hash = b2s_hash(&hash, peer_ephemeral_public.as_bytes()); // temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral) @@ -517,7 +521,7 @@ impl Handshake { &mut peer_static_public_decrypted, &key, 0, - packet.encrypted_static, + packet.encrypted_static.as_bytes(), &hash, )?; @@ -528,7 +532,7 @@ impl Handshake { .map_err(|_| WireGuardError::WrongKey)?; // initiator.hash = HASH(initiator.hash || msg.encrypted_static) - hash = b2s_hash(&hash, packet.encrypted_static); + hash = b2s_hash(&hash, packet.encrypted_static.as_bytes()); // temp = HMAC(initiator.chaining_key, DH(initiator.static_private, responder.static_public)) let temp = b2s_hmac(&chaining_key, self.params.static_shared.as_bytes()); // initiator.chaining_key = HMAC(temp, 0x1) @@ -537,7 +541,7 @@ impl Handshake { let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); // msg.encrypted_timestamp = AEAD(key, 0, TAI64N(), initiator.hash) let mut timestamp = [0u8; TIMESTAMP_LEN]; - aead_chacha20_open(&mut timestamp, &key, 0, packet.encrypted_timestamp, &hash)?; + aead_chacha20_open(&mut timestamp, &key, 0, packet.timestamp.as_bytes(), &hash)?; let timestamp = Tai64N::parse(×tamp)?; if !timestamp.after(&self.last_handshake_timestamp) { @@ -547,7 +551,7 @@ impl Handshake { self.last_handshake_timestamp = timestamp; // initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp) - hash = b2s_hash(&hash, packet.encrypted_timestamp); + hash = b2s_hash(&hash, packet.timestamp.as_bytes()); self.previous = std::mem::replace( &mut self.state, @@ -555,28 +559,32 @@ impl Handshake { chaining_key, hash, peer_ephemeral_public, - peer_index, + peer_index: peer_index.get(), }, ); - self.format_handshake_response(dst) + Ok(self.format_handshake_response(packet.into_bytes())) } pub(super) fn receive_handshake_response( &mut self, - packet: HandshakeResponse, + packet: &WgHandshakeResp, ) -> Result { // Check if there is a handshake awaiting a response and return the correct one let (state, is_previous) = match (&self.state, &self.previous) { - (HandshakeState::InitSent(s), _) if s.local_index == packet.receiver_idx => (s, false), - (_, HandshakeState::InitSent(s)) if s.local_index == packet.receiver_idx => (s, true), + (HandshakeState::InitSent(s), _) if s.local_index == packet.receiver_idx.get() => { + (s, false) + } + (_, HandshakeState::InitSent(s)) if s.local_index == packet.receiver_idx.get() => { + (s, true) + } _ => return Err(WireGuardError::UnexpectedPacket), }; let peer_index = packet.sender_idx; let local_index = state.local_index; - let unencrypted_ephemeral = x25519::PublicKey::from(*packet.unencrypted_ephemeral); + let unencrypted_ephemeral = x25519::PublicKey::from(packet.unencrypted_ephemeral); // msg.unencrypted_ephemeral = DH_PUBKEY(responder.ephemeral_private) // responder.hash = HASH(responder.hash || msg.unencrypted_ephemeral) let mut hash = b2s_hash(&state.hash, unencrypted_ephemeral.as_bytes()); @@ -616,7 +624,7 @@ impl Handshake { // responder.hash = HASH(responder.hash || temp2) hash = b2s_hash(&hash, &temp2); // msg.encrypted_nothing = AEAD(key, 0, [empty], responder.hash) - aead_chacha20_open(&mut [], &key, 0, packet.encrypted_nothing, &hash)?; + aead_chacha20_open(&mut [], &key, 0, packet.encrypted_nothing.as_bytes(), &hash)?; // responder.hash = HASH(responder.hash || msg.encrypted_nothing) // hash = b2s_hash(hash, buf[ENC_NOTHING_OFF..ENC_NOTHING_OFF + ENC_NOTHING_SZ]); @@ -641,12 +649,12 @@ impl Handshake { } else { self.state = HandshakeState::None; } - Ok(Session::new(local_index, peer_index, temp3, temp2)) + Ok(Session::new(local_index, peer_index.get(), temp3, temp2)) } pub(super) fn receive_cookie_reply( &mut self, - packet: PacketCookieReply, + packet: &WgCookieReply, ) -> Result<(), WireGuardError> { let mac1 = match self.cookies.last_mac1 { Some(mac) => mac, @@ -663,12 +671,12 @@ impl Handshake { let key = b2s_hash(LABEL_COOKIE, self.params.peer_static_public.as_bytes()); // TODO: pre-compute let payload = Payload { - aad: &mac1[0..16], - msg: packet.encrypted_cookie, + aad: &mac1, + msg: packet.encrypted_cookie.as_bytes(), }; let plaintext = XChaCha20Poly1305::new_from_slice(&key) .unwrap() - .decrypt(packet.nonce.into(), payload) + .decrypt(&packet.nonce.into(), payload) .map_err(|_| WireGuardError::InvalidAeadTag)?; let cookie = plaintext @@ -679,46 +687,23 @@ impl Handshake { } // Compute and append mac1 and mac2 to a handshake message - fn append_mac1_and_mac2<'a>( - &mut self, - local_index: u32, - dst: &'a mut [u8], - ) -> Result<&'a mut [u8], WireGuardError> { - let mac1_off = dst.len() - 32; - let mac2_off = dst.len() - 16; - + fn init_mac1_and_mac2(&mut self, packet: &mut T, local_index: u32) { // msg.mac1 = MAC(HASH(LABEL_MAC1 || responder.static_public), msg[0:offsetof(msg.mac1)]) - let msg_mac1 = b2s_keyed_mac_16(&self.params.sending_mac1_key, &dst[..mac1_off]); - - dst[mac1_off..mac2_off].copy_from_slice(&msg_mac1[..]); + *packet.mac1_mut() = b2s_keyed_mac_16(&self.params.sending_mac1_key, packet.until_mac1()); //msg.mac2 = MAC(initiator.last_received_cookie, msg[0:offsetof(msg.mac2)]) - let msg_mac2: [u8; 16] = if let Some(cookie) = self.cookies.write_cookie { - b2s_keyed_mac_16(&cookie, &dst[..mac2_off]) + *packet.mac2_mut() = if let Some(cookie) = &self.cookies.write_cookie { + b2s_keyed_mac_16(cookie, packet.until_mac2()) } else { [0u8; 16] }; - dst[mac2_off..].copy_from_slice(&msg_mac2[..]); - self.cookies.index = local_index; - self.cookies.last_mac1 = Some(msg_mac1); - Ok(dst) + self.cookies.last_mac1 = Some(*packet.mac1()); } - pub(super) fn format_handshake_initiation<'a>( - &mut self, - dst: &'a mut [u8], - ) -> Result<&'a mut [u8], WireGuardError> { - if dst.len() < super::HANDSHAKE_INIT_SZ { - return Err(WireGuardError::DestinationBufferTooSmall); - } - - let (message_type, rest) = dst.split_at_mut(4); - let (sender_index, rest) = rest.split_at_mut(4); - let (unencrypted_ephemeral, rest) = rest.split_at_mut(32); - let (encrypted_static, rest) = rest.split_at_mut(32 + 16); - let (encrypted_timestamp, _) = rest.split_at_mut(12 + 16); + pub(super) fn format_handshake_initiation(&mut self) -> crate::packet::Packet { + let mut handshake = WgHandshakeInit::new(); let local_index = self.inc_index(); @@ -731,17 +716,20 @@ impl Handshake { let ephemeral_private = x25519::ReusableSecret::random_from_rng(OsRng); // msg.message_type = 1 // msg.reserved_zero = { 0, 0, 0 } - message_type.copy_from_slice(&super::HANDSHAKE_INIT.to_le_bytes()); // msg.sender_index = little_endian(initiator.sender_index) - sender_index.copy_from_slice(&local_index.to_le_bytes()); + handshake.sender_idx.set(local_index); // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) - unencrypted_ephemeral + handshake + .unencrypted_ephemeral .copy_from_slice(x25519::PublicKey::from(&ephemeral_private).as_bytes()); // initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral) - hash = b2s_hash(&hash, unencrypted_ephemeral); + hash = b2s_hash(&hash, &handshake.unencrypted_ephemeral); // temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral) // initiator.chaining_key = HMAC(temp, 0x1) - chaining_key = b2s_hmac(&b2s_hmac(&chaining_key, unencrypted_ephemeral), &[0x01]); + chaining_key = b2s_hmac( + &b2s_hmac(&chaining_key, &handshake.unencrypted_ephemeral), + &[0x01], + ); // temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public)) let ephemeral_shared = ephemeral_private.diffie_hellman(&self.params.peer_static_public); let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); @@ -751,14 +739,14 @@ impl Handshake { let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); // msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash) aead_chacha20_seal( - encrypted_static, + handshake.encrypted_static.as_mut_bytes(), &key, 0, self.params.static_public.as_bytes(), &hash, ); // initiator.hash = HASH(initiator.hash || msg.encrypted_static) - hash = b2s_hash(&hash, encrypted_static); + hash = b2s_hash(&hash, handshake.encrypted_static.as_bytes()); // temp = HMAC(initiator.chaining_key, DH(initiator.static_private, responder.static_public)) let temp = b2s_hmac(&chaining_key, self.params.static_shared.as_bytes()); // initiator.chaining_key = HMAC(temp, 0x1) @@ -767,9 +755,15 @@ impl Handshake { let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); // msg.encrypted_timestamp = AEAD(key, 0, TAI64N(), initiator.hash) let timestamp = self.stamper.stamp(); - aead_chacha20_seal(encrypted_timestamp, &key, 0, ×tamp, &hash); + aead_chacha20_seal( + handshake.timestamp.as_mut_bytes(), + &key, + 0, + ×tamp, + &hash, + ); // initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp) - hash = b2s_hash(&hash, encrypted_timestamp); + hash = b2s_hash(&hash, handshake.timestamp.as_bytes()); let time_now = Instant::now(); self.previous = std::mem::replace( @@ -783,17 +777,15 @@ impl Handshake { }), ); - self.append_mac1_and_mac2(local_index, &mut dst[..super::HANDSHAKE_INIT_SZ]) + self.init_mac1_and_mac2(&mut handshake, local_index); + + Packet::copy_from(&handshake) } - fn format_handshake_response<'a>( + fn format_handshake_response( &mut self, - dst: &'a mut [u8], - ) -> Result<(&'a mut [u8], Session), WireGuardError> { - if dst.len() < super::HANDSHAKE_RESP_SZ { - return Err(WireGuardError::DestinationBufferTooSmall); - } - + buf: crate::packet::Packet, + ) -> (crate::packet::Packet, Session) { let state = std::mem::replace(&mut self.state, HandshakeState::None); let (mut chaining_key, mut hash, peer_ephemeral_public, peer_index) = match state { HandshakeState::InitReceived { @@ -807,29 +799,23 @@ impl Handshake { } }; - let (message_type, rest) = dst.split_at_mut(4); - let (sender_index, rest) = rest.split_at_mut(4); - let (receiver_index, rest) = rest.split_at_mut(4); - let (unencrypted_ephemeral, rest) = rest.split_at_mut(32); - let (encrypted_nothing, _) = rest.split_at_mut(16); - // responder.ephemeral_private = DH_GENERATE() let ephemeral_private = x25519::ReusableSecret::random_from_rng(OsRng); let local_index = self.inc_index(); // msg.message_type = 2 // msg.reserved_zero = { 0, 0, 0 } - message_type.copy_from_slice(&super::HANDSHAKE_RESP.to_le_bytes()); + let mut resp = WgHandshakeResp::new( + local_index, + peer_index, + *x25519::PublicKey::from(&ephemeral_private).as_bytes(), + ); // msg.sender_index = little_endian(responder.sender_index) - sender_index.copy_from_slice(&local_index.to_le_bytes()); // msg.receiver_index = little_endian(initiator.sender_index) - receiver_index.copy_from_slice(&peer_index.to_le_bytes()); // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) - unencrypted_ephemeral - .copy_from_slice(x25519::PublicKey::from(&ephemeral_private).as_bytes()); // responder.hash = HASH(responder.hash || msg.unencrypted_ephemeral) - hash = b2s_hash(&hash, unencrypted_ephemeral); + hash = b2s_hash(&hash, &resp.unencrypted_ephemeral); // temp = HMAC(responder.chaining_key, msg.unencrypted_ephemeral) - let temp = b2s_hmac(&chaining_key, unencrypted_ephemeral); + let temp = b2s_hmac(&chaining_key, &resp.unencrypted_ephemeral); // responder.chaining_key = HMAC(temp, 0x1) chaining_key = b2s_hmac(&temp, &[0x01]); // temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.ephemeral_public)) @@ -860,7 +846,7 @@ impl Handshake { // responder.hash = HASH(responder.hash || temp2) hash = b2s_hash(&hash, &temp2); // msg.encrypted_nothing = AEAD(key, 0, [empty], responder.hash) - aead_chacha20_seal(encrypted_nothing, &key, 0, &[], &hash); + aead_chacha20_seal(resp.encrypted_nothing.as_mut_bytes(), &key, 0, &[], &hash); // Derive keys // temp1 = HMAC(initiator.chaining_key, [empty]) @@ -874,9 +860,11 @@ impl Handshake { let temp2 = b2s_hmac(&temp1, &[0x01]); let temp3 = b2s_hmac2(&temp1, &temp2, &[0x02]); - let dst = self.append_mac1_and_mac2(local_index, &mut dst[..super::HANDSHAKE_RESP_SZ])?; + self.init_mac1_and_mac2(&mut resp, local_index); + + let packet = buf.overwrite_with(&resp); - Ok((dst, Session::new(local_index, peer_index, temp2, temp3))) + (packet, Session::new(local_index, peer_index, temp2, temp3)) } } @@ -932,7 +920,7 @@ mod tests { aead_chacha20_seal(&mut encrypted_nothing, &key, counter, &[], &aad); - eprintln!("encrypted_nothing: {:?}", encrypted_nothing); + eprintln!("encrypted_nothing: {encrypted_nothing:?}"); aead_chacha20_open(&mut [], &key, counter, &encrypted_nothing, &aad) .expect("Should open what we just sealed"); diff --git a/src/noise/mod.rs b/src/noise/mod.rs index d36dbdc..5c6a38a 100644 --- a/src/noise/mod.rs +++ b/src/noise/mod.rs @@ -1,4 +1,8 @@ // Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// +// Modified by Mullvad VPN. +// Copyright (c) 2025 Mullvad VPN. +// // SPDX-License-Identifier: BSD-3-Clause pub mod errors; @@ -9,189 +13,64 @@ pub(crate) mod tls; mod session; mod timers; +use zerocopy::IntoBytes; + use crate::noise::errors::WireGuardError; use crate::noise::handshake::Handshake; use crate::noise::rate_limiter::RateLimiter; use crate::noise::timers::{TimerName, Timers}; +use crate::packet::{Packet, WgCookieReply, WgData, WgHandshakeInit, WgHandshakeResp, WgKind}; use crate::x25519; use std::collections::VecDeque; -use std::convert::{TryFrom, TryInto}; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::sync::Arc; use std::time::Duration; -/// The default value to use for rate limiting, when no other rate limiter is defined +/// The default value to use for rate limiting, when no other rate limiter is defined. const PEER_HANDSHAKE_RATE_LIMIT: u64 = 10; -const IPV4_MIN_HEADER_SIZE: usize = 20; -const IPV4_LEN_OFF: usize = 2; -const IPV4_SRC_IP_OFF: usize = 12; -const IPV4_DST_IP_OFF: usize = 16; -const IPV4_IP_SZ: usize = 4; - -const IPV6_MIN_HEADER_SIZE: usize = 40; -const IPV6_LEN_OFF: usize = 4; -const IPV6_SRC_IP_OFF: usize = 8; -const IPV6_DST_IP_OFF: usize = 24; -const IPV6_IP_SZ: usize = 16; - -const IP_LEN_SZ: usize = 2; - const MAX_QUEUE_DEPTH: usize = 256; -/// number of sessions in the ring, better keep a PoT +/// number of sessions in the ring, better keep a PoT. const N_SESSIONS: usize = 8; #[derive(Debug)] -pub enum TunnResult<'a> { +pub enum TunnResult { Done, Err(WireGuardError), - WriteToNetwork(&'a mut [u8]), - WriteToTunnelV4(&'a mut [u8], Ipv4Addr), - WriteToTunnelV6(&'a mut [u8], Ipv6Addr), + WriteToNetwork(WgKind), + WriteToTunnel(Packet), } -impl<'a> From for TunnResult<'a> { - fn from(err: WireGuardError) -> TunnResult<'a> { +impl From for TunnResult { + fn from(err: WireGuardError) -> TunnResult { TunnResult::Err(err) } } -/// Tunnel represents a point-to-point WireGuard connection +/// Tunnel represents a point-to-point WireGuard connection. pub struct Tunn { - /// The handshake currently in progress + /// The handshake currently in progress. handshake: handshake::Handshake, - /// The N_SESSIONS most recent sessions, index is session id modulo N_SESSIONS + /// The [`N_SESSIONS`] most recent sessions, index is session id modulo [`N_SESSIONS`]. sessions: [Option; N_SESSIONS], - /// Index of most recently used session + /// Index of most recently used session. current: usize, - /// Queue to store blocked packets - packet_queue: VecDeque>, - /// Keeps tabs on the expiring timers + /// Queue to store blocked packets. + packet_queue: VecDeque, + + /// Keeps tabs on the expiring timers. timers: timers::Timers, tx_bytes: usize, rx_bytes: usize, rate_limiter: Arc, } -type MessageType = u32; -const HANDSHAKE_INIT: MessageType = 1; -const HANDSHAKE_RESP: MessageType = 2; -const COOKIE_REPLY: MessageType = 3; -const DATA: MessageType = 4; - -const HANDSHAKE_INIT_SZ: usize = 148; -const HANDSHAKE_RESP_SZ: usize = 92; -const COOKIE_REPLY_SZ: usize = 64; -const DATA_OVERHEAD_SZ: usize = 32; - -#[derive(Debug)] -pub struct HandshakeInit<'a> { - sender_idx: u32, - unencrypted_ephemeral: &'a [u8; 32], - encrypted_static: &'a [u8], - encrypted_timestamp: &'a [u8], -} - -#[derive(Debug)] -pub struct HandshakeResponse<'a> { - sender_idx: u32, - pub receiver_idx: u32, - unencrypted_ephemeral: &'a [u8; 32], - encrypted_nothing: &'a [u8], -} - -#[derive(Debug)] -pub struct PacketCookieReply<'a> { - pub receiver_idx: u32, - nonce: &'a [u8], - encrypted_cookie: &'a [u8], -} - -#[derive(Debug)] -pub struct PacketData<'a> { - pub receiver_idx: u32, - counter: u64, - encrypted_encapsulated_packet: &'a [u8], -} - -/// Describes a packet from network -#[derive(Debug)] -pub enum Packet<'a> { - HandshakeInit(HandshakeInit<'a>), - HandshakeResponse(HandshakeResponse<'a>), - PacketCookieReply(PacketCookieReply<'a>), - PacketData(PacketData<'a>), -} - impl Tunn { - #[inline(always)] - pub fn parse_incoming_packet(src: &'_ [u8]) -> Result, WireGuardError> { - if src.len() < 4 { - return Err(WireGuardError::InvalidPacket); - } - - // Checks the type, as well as the reserved zero fields - let packet_type = u32::from_le_bytes(src[0..4].try_into().unwrap()); - - Ok(match (packet_type, src.len()) { - (HANDSHAKE_INIT, HANDSHAKE_INIT_SZ) => Packet::HandshakeInit(HandshakeInit { - sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), - unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[8..40]) - .expect("length already checked above"), - encrypted_static: &src[40..88], - encrypted_timestamp: &src[88..116], - }), - (HANDSHAKE_RESP, HANDSHAKE_RESP_SZ) => Packet::HandshakeResponse(HandshakeResponse { - sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), - receiver_idx: u32::from_le_bytes(src[8..12].try_into().unwrap()), - unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[12..44]) - .expect("length already checked above"), - encrypted_nothing: &src[44..60], - }), - (COOKIE_REPLY, COOKIE_REPLY_SZ) => Packet::PacketCookieReply(PacketCookieReply { - receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), - nonce: &src[8..32], - encrypted_cookie: &src[32..64], - }), - (DATA, DATA_OVERHEAD_SZ..=std::usize::MAX) => Packet::PacketData(PacketData { - receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), - counter: u64::from_le_bytes(src[8..16].try_into().unwrap()), - encrypted_encapsulated_packet: &src[16..], - }), - _ => return Err(WireGuardError::InvalidPacket), - }) - } - pub fn is_expired(&self) -> bool { self.handshake.is_expired() } - pub fn dst_address(packet: &[u8]) -> Option { - if packet.is_empty() { - return None; - } - - match packet[0] >> 4 { - 4 if packet.len() >= IPV4_MIN_HEADER_SIZE => { - let addr_bytes: [u8; IPV4_IP_SZ] = packet - [IPV4_DST_IP_OFF..IPV4_DST_IP_OFF + IPV4_IP_SZ] - .try_into() - .unwrap(); - Some(IpAddr::from(addr_bytes)) - } - 6 if packet.len() >= IPV6_MIN_HEADER_SIZE => { - let addr_bytes: [u8; IPV6_IP_SZ] = packet - [IPV6_DST_IP_OFF..IPV6_DST_IP_OFF + IPV6_IP_SZ] - .try_into() - .unwrap(); - Some(IpAddr::from(addr_bytes)) - } - _ => None, - } - } - - /// Create a new tunnel using own private key and the peer public key + /// Create a new tunnel using own private key and the peer public key. pub fn new( static_private: x25519::StaticSecret, peer_static_public: x25519::PublicKey, @@ -224,7 +103,7 @@ impl Tunn { } } - /// Update the private key and clear existing sessions + /// Update the private key and clear existing sessions. pub fn set_static_private( &mut self, static_private: x25519::StaticSecret, @@ -242,91 +121,63 @@ impl Tunn { } } - /// Encapsulate a single packet from the tunnel interface. - /// Returns TunnResult. + /// Encapsulate a single packet. + /// + /// If there's an active session, return the encapsulated packet. Otherwise, if needed, return + /// a handshake initiation. `None` is returned if a handshake is already in progress. In that + /// case, the packet is added to a queue. + pub fn handle_outgoing_packet(&mut self, packet: Packet) -> Option { + match self.encapsulate_with_session(packet) { + Ok(encapsulated_packet) => Some(encapsulated_packet.into()), + Err(packet) => { + // If there is no session, queue the packet for future retry + self.queue_packet(packet); + // Initiate a new handshake if none is in progress + self.format_handshake_initiation(false).map(Into::into) + } + } + } + + /// Encapsulate a single packet into a [`WgData`]. /// - /// # Panics - /// Panics if dst buffer is too small. - /// Size of dst should be at least src.len() + 32, and no less than 148 bytes. - pub fn encapsulate<'a>(&mut self, src: &[u8], dst: &'a mut [u8]) -> TunnResult<'a> { + /// Returns `Err(original_packet)` if there is no active session. + pub fn encapsulate_with_session(&mut self, packet: Packet) -> Result, Packet> { let current = self.current; if let Some(ref session) = self.sessions[current % N_SESSIONS] { // Send the packet using an established session - let packet = session.format_packet_data(src, dst); + let packet = session.format_packet_data(packet); self.timer_tick(TimerName::TimeLastPacketSent); // Exclude Keepalive packets from timer update. - if !src.is_empty() { + if !packet.as_bytes().is_empty() { self.timer_tick(TimerName::TimeLastDataPacketSent); } - self.tx_bytes += src.len(); - return TunnResult::WriteToNetwork(packet); - } - - // If there is no session, queue the packet for future retry - self.queue_packet(src); - // Initiate a new handshake if none is in progress - self.format_handshake_initiation(dst, false) - } - - /// Receives a UDP datagram from the network and parses it. - /// Returns TunnResult. - /// - /// If the result is of type TunnResult::WriteToNetwork, should repeat the call with empty datagram, - /// until TunnResult::Done is returned. If batch processing packets, it is OK to defer until last - /// packet is processed. - pub fn decapsulate<'a>( - &mut self, - src_addr: Option, - datagram: &[u8], - dst: &'a mut [u8], - ) -> TunnResult<'a> { - if datagram.is_empty() { - // Indicates a repeated call - return self.send_queued_packet(dst); + self.tx_bytes += packet.as_bytes().len(); + Ok(packet) + } else { + Err(packet) } - - let mut cookie = [0u8; COOKIE_REPLY_SZ]; - let packet = match self - .rate_limiter - .verify_packet(src_addr, datagram, &mut cookie) - { - Ok(packet) => packet, - Err(TunnResult::WriteToNetwork(cookie)) => { - dst[..cookie.len()].copy_from_slice(cookie); - return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]); - } - Err(TunnResult::Err(e)) => return TunnResult::Err(e), - _ => unreachable!(), - }; - - self.handle_verified_packet(packet, dst) } - pub(crate) fn handle_verified_packet<'a>( - &mut self, - packet: Packet, - dst: &'a mut [u8], - ) -> TunnResult<'a> { + pub fn handle_incoming_packet(&mut self, packet: WgKind) -> TunnResult { match packet { - Packet::HandshakeInit(p) => self.handle_handshake_init(p, dst), - Packet::HandshakeResponse(p) => self.handle_handshake_response(p, dst), - Packet::PacketCookieReply(p) => self.handle_cookie_reply(p), - Packet::PacketData(p) => self.handle_data(p, dst), + WgKind::HandshakeInit(p) => self.handle_handshake_init(p), + WgKind::HandshakeResp(p) => self.handle_handshake_response(p), + WgKind::CookieReply(p) => self.handle_cookie_reply(&p), + WgKind::Data(p) => self.handle_data(p), } .unwrap_or_else(TunnResult::from) } - fn handle_handshake_init<'a>( + fn handle_handshake_init( &mut self, - p: HandshakeInit, - dst: &'a mut [u8], - ) -> Result, WireGuardError> { + p: Packet, + ) -> Result { tracing::debug!( message = "Received handshake_initiation", - remote_idx = p.sender_idx + sender_idx = p.sender_idx.get() ); - let (packet, session) = self.handshake.receive_handshake_initialization(p, dst)?; + let (packet, session) = self.handshake.receive_handshake_initialization(p)?; // Store new session in ring buffer let index = session.local_index(); @@ -338,23 +189,25 @@ impl Tunn { tracing::debug!(message = "Sending handshake_response", local_idx = index); - Ok(TunnResult::WriteToNetwork(packet)) + Ok(TunnResult::WriteToNetwork(packet.into())) } - fn handle_handshake_response<'a>( + fn handle_handshake_response( &mut self, - p: HandshakeResponse, - dst: &'a mut [u8], - ) -> Result, WireGuardError> { + p: Packet, + ) -> Result { tracing::debug!( message = "Received handshake_response", - local_idx = p.receiver_idx, - remote_idx = p.sender_idx + local_idx = p.receiver_idx.get(), + remote_idx = p.sender_idx.get() ); - let session = self.handshake.receive_handshake_response(p)?; + let session = self.handshake.receive_handshake_response(&p)?; + + let mut p = p.into_bytes(); + p.truncate(0); - let keepalive_packet = session.format_packet_data(&[], dst); + let keepalive_packet = session.format_packet_data(p); // Store new session in ring buffer let l_idx = session.local_index(); let index = l_idx % N_SESSIONS; @@ -366,16 +219,13 @@ impl Tunn { tracing::debug!("Sending keepalive"); - Ok(TunnResult::WriteToNetwork(keepalive_packet)) // Send a keepalive as a response + Ok(TunnResult::WriteToNetwork(keepalive_packet.into())) // Send a keepalive as a response } - fn handle_cookie_reply<'a>( - &mut self, - p: PacketCookieReply, - ) -> Result, WireGuardError> { + fn handle_cookie_reply(&mut self, p: &WgCookieReply) -> Result { tracing::debug!( message = "Received cookie_reply", - local_idx = p.receiver_idx + local_idx = p.receiver_idx.get() ); self.handshake.receive_cookie_reply(p)?; @@ -403,13 +253,21 @@ impl Tunn { } } - /// Decrypts a data packet, and stores the decapsulated packet in dst. - fn handle_data<'a>( + /// Decrypt a data packet, and return a [`TunnResult::WriteToTunnel`] (`Ipv4` or `Ipv6`) if successful. + fn handle_data(&mut self, packet: Packet) -> Result { + let decapsulated_packet = self.decapsulate_with_session(packet)?; + + self.timer_tick(TimerName::TimeLastDataPacketReceived); + self.rx_bytes += decapsulated_packet.as_bytes().len(); + + Ok(TunnResult::WriteToTunnel(decapsulated_packet)) + } + + pub fn decapsulate_with_session( &mut self, - packet: PacketData, - dst: &'a mut [u8], - ) -> Result, WireGuardError> { - let r_idx = packet.receiver_idx as usize; + packet: Packet, + ) -> Result { + let r_idx = packet.header.receiver_idx.get() as usize; let idx = r_idx % N_SESSIONS; // Get the (probably) right session @@ -419,25 +277,26 @@ impl Tunn { tracing::trace!(message = "No current session available", remote_idx = r_idx); WireGuardError::NoCurrentSession })?; - session.receive_packet_data(packet, dst)? + session.receive_packet_data(packet)? }; self.set_current_session(r_idx); self.timer_tick(TimerName::TimeLastPacketReceived); - Ok(self.validate_decapsulated_packet(decapsulated_packet)) + Ok(decapsulated_packet) } - /// Formats a new handshake initiation message and store it in dst. If force_resend is true will send - /// a new handshake, even if a handshake is already in progress (for example when a handshake times out) - pub fn format_handshake_initiation<'a>( + /// Return a new handshake if appropriate, or `None` otherwise. + /// + /// If `force_resend` is true will send a new handshake, even if a handshake + /// is already in progress (for example when a handshake times out) + pub fn format_handshake_initiation( &mut self, - dst: &'a mut [u8], force_resend: bool, - ) -> TunnResult<'a> { + ) -> Option> { if self.handshake.is_in_progress() && !force_resend { - return TunnResult::Done; + return None; } if self.handshake.is_expired() { @@ -446,98 +305,31 @@ impl Tunn { let starting_new_handshake = !self.handshake.is_in_progress(); - match self.handshake.format_handshake_initiation(dst) { - Ok(packet) => { - tracing::debug!("Sending handshake_initiation"); - - if starting_new_handshake { - self.timer_tick(TimerName::TimeLastHandshakeStarted); - } - self.timer_tick(TimerName::TimeLastPacketSent); - TunnResult::WriteToNetwork(packet) - } - Err(e) => TunnResult::Err(e), - } - } - - /// Check if an IP packet is v4 or v6, truncate to the length indicated by the length field - /// Returns the truncated packet and the source IP as TunnResult - fn validate_decapsulated_packet<'a>(&mut self, packet: &'a mut [u8]) -> TunnResult<'a> { - let (computed_len, src_ip_address) = match packet.len() { - 0 => return TunnResult::Done, // This is keepalive, and not an error - _ if packet[0] >> 4 == 4 && packet.len() >= IPV4_MIN_HEADER_SIZE => { - let len_bytes: [u8; IP_LEN_SZ] = packet[IPV4_LEN_OFF..IPV4_LEN_OFF + IP_LEN_SZ] - .try_into() - .unwrap(); - let addr_bytes: [u8; IPV4_IP_SZ] = packet - [IPV4_SRC_IP_OFF..IPV4_SRC_IP_OFF + IPV4_IP_SZ] - .try_into() - .unwrap(); - ( - u16::from_be_bytes(len_bytes) as usize, - IpAddr::from(addr_bytes), - ) - } - _ if packet[0] >> 4 == 6 && packet.len() >= IPV6_MIN_HEADER_SIZE => { - let len_bytes: [u8; IP_LEN_SZ] = packet[IPV6_LEN_OFF..IPV6_LEN_OFF + IP_LEN_SZ] - .try_into() - .unwrap(); - let addr_bytes: [u8; IPV6_IP_SZ] = packet - [IPV6_SRC_IP_OFF..IPV6_SRC_IP_OFF + IPV6_IP_SZ] - .try_into() - .unwrap(); - ( - u16::from_be_bytes(len_bytes) as usize + IPV6_MIN_HEADER_SIZE, - IpAddr::from(addr_bytes), - ) - } - _ => return TunnResult::Err(WireGuardError::InvalidPacket), - }; - - if computed_len > packet.len() { - return TunnResult::Err(WireGuardError::InvalidPacket); - } - - self.timer_tick(TimerName::TimeLastDataPacketReceived); - self.rx_bytes += computed_len; + let packet = self.handshake.format_handshake_initiation(); + tracing::debug!("Sending handshake_initiation"); - match src_ip_address { - IpAddr::V4(addr) => TunnResult::WriteToTunnelV4(&mut packet[..computed_len], addr), - IpAddr::V6(addr) => TunnResult::WriteToTunnelV6(&mut packet[..computed_len], addr), - } - } - - /// Get a packet from the queue, and try to encapsulate it - fn send_queued_packet<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { - if let Some(packet) = self.dequeue_packet() { - match self.encapsulate(&packet, dst) { - TunnResult::Err(_) => { - // On error, return packet to the queue - self.requeue_packet(packet); - } - r => return r, - } + if starting_new_handshake { + self.timer_tick(TimerName::TimeLastHandshakeStarted); } - TunnResult::Done + self.timer_tick(TimerName::TimeLastPacketSent); + Some(packet) } - /// Push packet to the back of the queue - fn queue_packet(&mut self, packet: &[u8]) { - if self.packet_queue.len() < MAX_QUEUE_DEPTH { - // Drop if too many are already in queue - self.packet_queue.push_back(packet.to_vec()); - } + /// Pop the first queued packet if it exists and try to encapsulate it. + pub fn next_queued_packet(&mut self) -> Option { + self.dequeue_packet() + .and_then(|packet| self.handle_outgoing_packet(packet)) } - /// Push packet to the front of the queue - fn requeue_packet(&mut self, packet: Vec) { + /// Push packet to the back of the queue. + fn queue_packet(&mut self, packet: Packet) { if self.packet_queue.len() < MAX_QUEUE_DEPTH { // Drop if too many are already in queue - self.packet_queue.push_front(packet); + self.packet_queue.push_back(packet); } } - fn dequeue_packet(&mut self) -> Option> { + fn dequeue_packet(&mut self) -> Option { self.packet_queue.pop_front() } @@ -588,10 +380,16 @@ impl Tunn { #[cfg(test)] mod tests { + use std::net::Ipv4Addr; + #[cfg(feature = "mock-instant")] use crate::noise::timers::{REKEY_AFTER_TIME, REKEY_TIMEOUT}; + use crate::packet::Ipv4; use super::*; + use bytes::BytesMut; + #[cfg(feature = "mock-instant")] + use mock_instant::MockClock; use rand_core::{OsRng, RngCore}; fn create_two_tuns() -> (Tunn, Tunn) { @@ -610,84 +408,83 @@ mod tests { (my_tun, their_tun) } - fn create_handshake_init(tun: &mut Tunn) -> Vec { - let mut dst = vec![0u8; 2048]; - let handshake_init = tun.format_handshake_initiation(&mut dst, false); - assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_))); - let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init { - sent - } else { - unreachable!(); - }; - - handshake_init.into() + fn create_handshake_init(tun: &mut Tunn) -> Packet { + tun.format_handshake_initiation(false) + .expect("expected handshake init") } - fn create_handshake_response(tun: &mut Tunn, handshake_init: &[u8]) -> Vec { - let mut dst = vec![0u8; 2048]; - let handshake_resp = tun.decapsulate(None, handshake_init, &mut dst); - assert!(matches!(handshake_resp, TunnResult::WriteToNetwork(_))); + fn create_handshake_response( + tun: &mut Tunn, + handshake_init: Packet, + ) -> Packet { + let handshake_resp = tun.handle_incoming_packet(WgKind::HandshakeInit(handshake_init)); + assert!( + matches!(handshake_resp, TunnResult::WriteToNetwork(_)), + "expected WriteToNetwork, {handshake_resp:?}" + ); - let handshake_resp = if let TunnResult::WriteToNetwork(sent) = handshake_resp { - sent - } else { - unreachable!(); + let TunnResult::WriteToNetwork(handshake_resp) = handshake_resp else { + unreachable!("expected WriteToNetwork"); + }; + + let WgKind::HandshakeResp(handshake_resp) = handshake_resp else { + unreachable!("expected WgHandshakeResp, got {handshake_resp:?}"); }; - handshake_resp.into() + handshake_resp } - fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec { - let mut dst = vec![0u8; 2048]; - let keepalive = tun.decapsulate(None, handshake_resp, &mut dst); + fn parse_handshake_resp( + tun: &mut Tunn, + handshake_resp: Packet, + ) -> Packet { + let keepalive = tun.handle_incoming_packet(WgKind::HandshakeResp(handshake_resp)); assert!(matches!(keepalive, TunnResult::WriteToNetwork(_))); - let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive { - sent - } else { - unreachable!(); + let TunnResult::WriteToNetwork(keepalive) = keepalive else { + unreachable!("expected WriteToNetwork") + }; + + let WgKind::Data(keepalive) = keepalive else { + unreachable!("expected WgData, got {keepalive:?}"); }; - keepalive.into() + keepalive } - fn parse_keepalive(tun: &mut Tunn, keepalive: &[u8]) { - let mut dst = vec![0u8; 2048]; - let keepalive = tun.decapsulate(None, keepalive, &mut dst); - assert!(matches!(keepalive, TunnResult::Done)); + fn parse_keepalive(tun: &mut Tunn, keepalive: Packet) { + let result = tun.handle_incoming_packet(WgKind::Data(keepalive)); + assert!(matches!(result, TunnResult::WriteToTunnel(p) if p.is_empty())); } fn create_two_tuns_and_handshake() -> (Tunn, Tunn) { let (mut my_tun, mut their_tun) = create_two_tuns(); let init = create_handshake_init(&mut my_tun); - let resp = create_handshake_response(&mut their_tun, &init); - let keepalive = parse_handshake_resp(&mut my_tun, &resp); - parse_keepalive(&mut their_tun, &keepalive); + let resp = create_handshake_response(&mut their_tun, init); + let keepalive = parse_handshake_resp(&mut my_tun, resp); + parse_keepalive(&mut their_tun, keepalive); (my_tun, their_tun) } - fn create_ipv4_udp_packet() -> Vec { + fn create_ipv4_udp_packet() -> Packet { let header = etherparse::PacketBuilder::ipv4([192, 168, 1, 2], [192, 168, 1, 3], 5).udp(5678, 23); let payload = [0, 1, 2, 3]; let mut packet = Vec::::with_capacity(header.size(payload.len())); header.write(&mut packet, &payload).unwrap(); - packet + let packet = Packet::from_bytes(BytesMut::from(&packet[..])); + + packet.try_into_ipvx().unwrap().unwrap_left() } #[cfg(feature = "mock-instant")] fn update_timer_results_in_handshake(tun: &mut Tunn) { - let mut dst = vec![0u8; 2048]; - let result = tun.update_timers(&mut dst); - assert!(matches!(result, TunnResult::WriteToNetwork(_))); - let packet_data = if let TunnResult::WriteToNetwork(data) = result { - data - } else { - unreachable!(); - }; - let packet = Tunn::parse_incoming_packet(packet_data).unwrap(); - assert!(matches!(packet, Packet::HandshakeInit(_))); + let packet = tun + .update_timers() + .expect("update_timers should succeed") + .unwrap(); + assert!(matches!(packet, WgKind::HandshakeInit(..))); } #[test] @@ -698,59 +495,136 @@ mod tests { #[test] fn handshake_init() { let (mut my_tun, _their_tun) = create_two_tuns(); + let _init = create_handshake_init(&mut my_tun); + } + + #[test] + // Verify that a valid hanshake is accepted by two linked peers when rate limiting is not + // applied. + fn verify_handshake() { + let (mut my_tun, mut their_tun) = create_two_tuns(); let init = create_handshake_init(&mut my_tun); - let packet = Tunn::parse_incoming_packet(&init).unwrap(); - assert!(matches!(packet, Packet::HandshakeInit(_))); + let resp = create_handshake_response(&mut their_tun, init.clone()); + + their_tun + .rate_limiter + .verify_handshake(Ipv4Addr::LOCALHOST.into(), init) + .expect("Handshake init to be valid"); + + my_tun + .rate_limiter + .verify_handshake(Ipv4Addr::LOCALHOST.into(), resp) + .expect("Handshake response to be valid"); + } + + #[test] + #[cfg(feature = "mock-instant")] + /// Verify that cookie reply is sent when rate limit is hit. + /// And that handshakes are accepted under load with a valid mac2. + fn verify_cookie_reply() { + let forced_handshake_init = |tun: &mut Tunn| { + tun.format_handshake_initiation(true) + .expect("expected handshake init") + }; + + let (mut my_tun, their_tun) = create_two_tuns(); + + for _ in 0..HANDSHAKE_RATE_LIMIT { + let init = forced_handshake_init(&mut my_tun); + their_tun + .rate_limiter + .verify_handshake(Ipv4Addr::LOCALHOST.into(), init) + .expect("Handshake init to be valid"); + + MockClock::advance(Duration::from_micros(1)); + } + + // Next handshake should trigger rate limiting + let init = forced_handshake_init(&mut my_tun); + let Err(TunnResult::WriteToNetwork(WgKind::CookieReply(cookie_resp))) = their_tun + .rate_limiter + .verify_handshake(Ipv4Addr::LOCALHOST.into(), init) + else { + panic!("expected cookie reply due to rate limiting"); + }; + + // Verify that cookie reply can be processed + // And that the peer accepts our handshake after that + my_tun + .handle_cookie_reply(&cookie_resp) + .expect("expected cookie reply to be valid"); + + let init = forced_handshake_init(&mut my_tun); + their_tun + .rate_limiter + .verify_handshake(Ipv4Addr::LOCALHOST.into(), init) + .expect("should accept handshake with cookie"); + } + + #[test] + // Verify that an invalid hanshake is rejected by both linked peers. + fn reject_handshake() { + let (mut my_tun, mut their_tun) = create_two_tuns(); + let mut init = create_handshake_init(&mut my_tun); + let mut resp = create_handshake_response(&mut their_tun, init.clone()); + + // Mess with the mac of both the handshake init & handshake response packets. + std::mem::swap(&mut init.mac1, &mut resp.mac1); + + their_tun + .rate_limiter + .verify_handshake(Ipv4Addr::LOCALHOST.into(), init.clone()) + .map(|packet| packet.mac1) + .expect_err("Handshake init to be invalid"); + + my_tun + .rate_limiter + .verify_handshake(Ipv4Addr::LOCALHOST.into(), resp) + .map(|packet| packet.mac1) + .expect_err("Handshake response to be invalid"); } #[test] fn handshake_init_and_response() { let (mut my_tun, mut their_tun) = create_two_tuns(); let init = create_handshake_init(&mut my_tun); - let resp = create_handshake_response(&mut their_tun, &init); - let packet = Tunn::parse_incoming_packet(&resp).unwrap(); - assert!(matches!(packet, Packet::HandshakeResponse(_))); + let _resp = create_handshake_response(&mut their_tun, init); } #[test] fn full_handshake() { let (mut my_tun, mut their_tun) = create_two_tuns(); let init = create_handshake_init(&mut my_tun); - let resp = create_handshake_response(&mut their_tun, &init); - let keepalive = parse_handshake_resp(&mut my_tun, &resp); - let packet = Tunn::parse_incoming_packet(&keepalive).unwrap(); - assert!(matches!(packet, Packet::PacketData(_))); + let resp = create_handshake_response(&mut their_tun, init); + let _keepalive = parse_handshake_resp(&mut my_tun, resp); } #[test] fn full_handshake_plus_timers() { let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake(); // Time has not yet advanced so their is nothing to do - assert!(matches!(my_tun.update_timers(&mut []), TunnResult::Done)); - assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done)); + assert!(matches!(my_tun.update_timers(), Ok(None))); + assert!(matches!(their_tun.update_timers(), Ok(None))); } #[test] #[cfg(feature = "mock-instant")] fn new_handshake_after_two_mins() { let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake(); - let mut my_dst = [0u8; 1024]; // Advance time 1 second and "send" 1 packet so that we send a handshake // after the timeout mock_instant::MockClock::advance(Duration::from_secs(1)); - assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done)); - assert!(matches!( - my_tun.update_timers(&mut my_dst), - TunnResult::Done - )); + assert!(matches!(their_tun.update_timers(), Ok(None))); + assert!(matches!(my_tun.update_timers(), Ok(None))); let sent_packet_buf = create_ipv4_udp_packet(); - let data = my_tun.encapsulate(&sent_packet_buf, &mut my_dst); - assert!(matches!(data, TunnResult::WriteToNetwork(_))); + let _data = my_tun + .handle_outgoing_packet(sent_packet_buf.into_bytes()) + .expect("expected encapsulated packet"); //Advance to timeout mock_instant::MockClock::advance(REKEY_AFTER_TIME); - assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done)); + assert!(matches!(their_tun.update_timers(), Ok(None))); update_timer_results_in_handshake(&mut my_tun); } @@ -759,9 +633,7 @@ mod tests { fn handshake_no_resp_rekey_timeout() { let (mut my_tun, _their_tun) = create_two_tuns(); - let init = create_handshake_init(&mut my_tun); - let packet = Tunn::parse_incoming_packet(&init).unwrap(); - assert!(matches!(packet, Packet::HandshakeInit(_))); + let _init = create_handshake_init(&mut my_tun); mock_instant::MockClock::advance(REKEY_TIMEOUT); update_timer_results_in_handshake(&mut my_tun) @@ -770,26 +642,21 @@ mod tests { #[test] fn one_ip_packet() { let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake(); - let mut my_dst = [0u8; 1024]; - let mut their_dst = [0u8; 1024]; let sent_packet_buf = create_ipv4_udp_packet(); - let data = my_tun.encapsulate(&sent_packet_buf, &mut my_dst); - assert!(matches!(data, TunnResult::WriteToNetwork(_))); - let data = if let TunnResult::WriteToNetwork(sent) = data { - sent - } else { - unreachable!(); - }; + let data = my_tun + .handle_outgoing_packet(sent_packet_buf.clone().into_bytes()) + .unwrap(); + + assert!(matches!(data, WgKind::Data(..))); - let data = their_tun.decapsulate(None, data, &mut their_dst); - assert!(matches!(data, TunnResult::WriteToTunnelV4(..))); - let recv_packet_buf = if let TunnResult::WriteToTunnelV4(recv, _addr) = data { + let data = their_tun.handle_incoming_packet(data); + let recv_packet_buf = if let TunnResult::WriteToTunnel(recv) = data { recv } else { - unreachable!(); + unreachable!("expected WritetoTunnelV4"); }; - assert_eq!(sent_packet_buf, recv_packet_buf); + assert_eq!(sent_packet_buf.as_bytes(), recv_packet_buf.as_bytes()); } } diff --git a/src/noise/rate_limiter.rs b/src/noise/rate_limiter.rs index 2bd5f96..b9d6add 100644 --- a/src/noise/rate_limiter.rs +++ b/src/noise/rate_limiter.rs @@ -1,12 +1,20 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// +// Modified by Mullvad VPN. +// Copyright (c) 2025 Mullvad VPN. +// +// SPDX-License-Identifier: BSD-3-Clause + use super::handshake::{b2s_hash, b2s_keyed_mac_16, b2s_keyed_mac_16_2, b2s_mac_24}; use crate::noise::handshake::{LABEL_COOKIE, LABEL_MAC1}; -use crate::noise::{HandshakeInit, HandshakeResponse, Packet, Tunn, TunnResult, WireGuardError}; +use crate::noise::{TunnResult, WireGuardError}; +use crate::packet::{Packet, WgCookieReply, WgHandshakeBase, WgKind}; #[cfg(feature = "mock-instant")] use mock_instant::Instant; -use portable_atomic::AtomicU64; use std::net::IpAddr; -use std::sync::atomic::Ordering; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Duration; #[cfg(not(feature = "mock-instant"))] use crate::sleepyinstant::Instant; @@ -22,14 +30,15 @@ const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can opti const COOKIE_SIZE: usize = 16; const COOKIE_NONCE_SIZE: usize = 24; -/// How often should reset count in seconds -const RESET_PERIOD: u64 = 1; +/// How often to reset the under-load counter +const RESET_PERIOD: Duration = Duration::from_secs(1); type Cookie = [u8; COOKIE_SIZE]; /// There are two places where WireGuard requires "randomness" for cookies /// * The 24 byte nonce in the cookie massage - here the only goal is to avoid nonce reuse /// * A secret value that changes every two minutes +/// /// Because the main goal of the cookie is simply for a party to prove ownership of an IP address /// we can relax the randomness definition a bit, in order to avoid locking, because using less /// resources is the main goal of any DoS prevention mechanism. @@ -76,11 +85,11 @@ impl RateLimiter { } /// Reset packet count (ideally should be called with a period of 1 second) - pub fn reset_count(&self) { + pub fn try_reset_count(&self) { // The rate limiter is not very accurate, but at the scale we care about it doesn't matter much let current_time = Instant::now(); let mut last_reset_time = self.last_reset.lock(); - if current_time.duration_since(*last_reset_time).as_secs() >= RESET_PERIOD { + if current_time.duration_since(*last_reset_time) >= RESET_PERIOD { self.count.store(0, Ordering::SeqCst); *last_reset_time = current_time; } @@ -113,82 +122,81 @@ impl RateLimiter { self.count.fetch_add(1, Ordering::SeqCst) >= self.limit } - pub(crate) fn format_cookie_reply<'a>( + pub(crate) fn format_cookie_reply( &self, idx: u32, cookie: Cookie, mac1: &[u8], - dst: &'a mut [u8], - ) -> Result<&'a mut [u8], WireGuardError> { - if dst.len() < super::COOKIE_REPLY_SZ { - return Err(WireGuardError::DestinationBufferTooSmall); - } - - let (message_type, rest) = dst.split_at_mut(4); - let (receiver_index, rest) = rest.split_at_mut(4); - let (nonce, rest) = rest.split_at_mut(24); - let (encrypted_cookie, _) = rest.split_at_mut(16 + 16); + ) -> WgCookieReply { + let mut wg_cookie_reply = WgCookieReply::new(); // msg.message_type = 3 // msg.reserved_zero = { 0, 0, 0 } - message_type.copy_from_slice(&super::COOKIE_REPLY.to_le_bytes()); // msg.receiver_index = little_endian(initiator.sender_index) - receiver_index.copy_from_slice(&idx.to_le_bytes()); - nonce.copy_from_slice(&self.nonce()[..]); + wg_cookie_reply.receiver_idx.set(idx); + wg_cookie_reply.nonce = self.nonce(); let cipher = XChaCha20Poly1305::new(&self.cookie_key); - let iv = GenericArray::from_slice(nonce); + let iv = GenericArray::from_slice(&wg_cookie_reply.nonce); - encrypted_cookie[..16].copy_from_slice(&cookie); + wg_cookie_reply.encrypted_cookie.encrypted = cookie; let tag = cipher - .encrypt_in_place_detached(iv, mac1, &mut encrypted_cookie[..16]) - .map_err(|_| WireGuardError::DestinationBufferTooSmall)?; + .encrypt_in_place_detached(iv, mac1, &mut wg_cookie_reply.encrypted_cookie.encrypted) + .expect("wg_cookie_reply is large enough"); - encrypted_cookie[16..].copy_from_slice(&tag); + wg_cookie_reply.encrypted_cookie.tag = tag.into(); + wg_cookie_reply + } + + /// Decode the packet as wireguard packet. + /// Then, verify the MAC fields on the packet (if any), and apply rate limiting if needed. + pub fn verify_packet(&self, src_addr: IpAddr, packet: Packet) -> Result { + let packet = packet + .try_into_wg() + .map_err(|_err| TunnResult::Err(WireGuardError::InvalidPacket))?; - Ok(&mut dst[..super::COOKIE_REPLY_SZ]) + // Verify and rate limit handshake messages only + match packet { + WgKind::HandshakeInit(packet) => self + .verify_handshake(src_addr, packet) + .map(WgKind::HandshakeInit), + WgKind::HandshakeResp(packet) => self + .verify_handshake(src_addr, packet) + .map(WgKind::HandshakeResp), + _ => Ok(packet), + } } - /// Verify the MAC fields on the datagram, and apply rate limiting if needed - pub fn verify_packet<'a, 'b>( + /// Verify the MAC fields on the handshake, and apply rate limiting if needed. + pub(crate) fn verify_handshake( &self, - src_addr: Option, - src: &'a [u8], - dst: &'b mut [u8], - ) -> Result, TunnResult<'b>> { - let packet = Tunn::parse_incoming_packet(src)?; + src_addr: IpAddr, + handshake: Packet

, + ) -> Result, TunnResult> { + let sender_idx = handshake.sender_idx(); + let mac1 = handshake.mac1(); + let mac2 = handshake.mac2(); + + let computed_mac1 = b2s_keyed_mac_16(&self.mac1_key, handshake.until_mac1()); + if verify_slices_are_equal(&computed_mac1, mac1).is_err() { + return Err(TunnResult::Err(WireGuardError::InvalidMac)); + } - // Verify and rate limit handshake messages only - if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. }) - | Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet - { - let (msg, macs) = src.split_at(src.len() - 32); - let (mac1, mac2) = macs.split_at(16); - - let computed_mac1 = b2s_keyed_mac_16(&self.mac1_key, msg); - verify_slices_are_equal(&computed_mac1[..16], mac1) - .map_err(|_| TunnResult::Err(WireGuardError::InvalidMac))?; - - if self.is_under_load() { - let addr = match src_addr { - None => return Err(TunnResult::Err(WireGuardError::UnderLoad)), - Some(addr) => addr, - }; - - // Only given an address can we validate mac2 - let cookie = self.current_cookie(addr); - let computed_mac2 = b2s_keyed_mac_16_2(&cookie, msg, mac1); - - if verify_slices_are_equal(&computed_mac2[..16], mac2).is_err() { - let cookie_packet = self - .format_cookie_reply(sender_idx, cookie, mac1, dst) - .map_err(TunnResult::Err)?; - return Err(TunnResult::WriteToNetwork(cookie_packet)); - } + if self.is_under_load() { + let cookie = self.current_cookie(src_addr); + let computed_mac2 = b2s_keyed_mac_16_2(&cookie, handshake.until_mac1(), mac1); + + if verify_slices_are_equal(&computed_mac2, mac2).is_err() { + let cookie_reply = self.format_cookie_reply(sender_idx, cookie, mac1); + let packet = handshake.overwrite_with(&cookie_reply); + return Err(TunnResult::WriteToNetwork(packet.into())); } + + // If under load but mac2 is valid, allow the handshake + return Ok(handshake); } - Ok(packet) + Ok(handshake) } } diff --git a/src/noise/session.rs b/src/noise/session.rs index b523f53..af9a314 100644 --- a/src/noise/session.rs +++ b/src/noise/session.rs @@ -1,11 +1,19 @@ // Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// +// Modified by Mullvad VPN. +// Copyright (c) 2025 Mullvad VPN. +// // SPDX-License-Identifier: BSD-3-Clause use super::tls::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; -use super::PacketData; -use crate::noise::errors::WireGuardError; +use crate::{ + noise::errors::WireGuardError, + packet::{Packet, WgData, WgDataHeader, WgKind}, +}; +use bytes::{Buf, BytesMut}; use parking_lot::Mutex; use std::sync::atomic::{AtomicUsize, Ordering}; +use zerocopy::FromBytes; pub struct Session { pub(crate) receiving_index: u32, @@ -91,10 +99,10 @@ impl ReceivingKeyCounterValidator { // Drop if too far back return Err(WireGuardError::InvalidCounter); } - if !self.check_bit(counter) { - Ok(()) - } else { + if self.check_bit(counter) { Err(WireGuardError::DuplicateCounter) + } else { + Ok(()) } } @@ -190,81 +198,85 @@ impl Session { ret } - /// src - an IP packet from the interface - /// dst - pre-allocated space to hold the encapsulating UDP packet to send over the network - /// returns the size of the formatted packet - pub(super) fn format_packet_data<'a>(&self, src: &[u8], dst: &'a mut [u8]) -> &'a mut [u8] { - if dst.len() < src.len() + super::DATA_OVERHEAD_SZ { - panic!("The destination buffer is too small"); - } - + /// Encapsulate `packet` into a [`WgData`]. + pub(super) fn format_packet_data(&self, packet: Packet) -> Packet { let sending_key_counter = self.sending_key_counter.fetch_add(1, Ordering::Relaxed) as u64; - let (message_type, rest) = dst.split_at_mut(4); - let (receiver_index, rest) = rest.split_at_mut(4); - let (counter, data) = rest.split_at_mut(8); + let len = DATA_OFFSET + AEAD_SIZE + packet.len(); - message_type.copy_from_slice(&super::DATA.to_le_bytes()); - receiver_index.copy_from_slice(&self.sending_index.to_le_bytes()); - counter.copy_from_slice(&sending_key_counter.to_le_bytes()); + // TODO: we can remove this allocation by pre-allocating some extra + // space at the beginning of `packet`s allocation, and using that. + let mut buf = Packet::from_bytes(BytesMut::zeroed(len)); + + let data = WgData::mut_from_bytes(buf.buf_mut()).unwrap(); + + data.header = WgDataHeader::new() + .with_receiver_idx(self.sending_index) + .with_counter(sending_key_counter); // TODO: spec requires padding to 16 bytes, but actually works fine without it - let n = { - let mut nonce = [0u8; 12]; - nonce[4..12].copy_from_slice(&sending_key_counter.to_le_bytes()); - data[..src.len()].copy_from_slice(src); - self.sender - .seal_in_place_separate_tag( - Nonce::assume_unique_for_key(nonce), - Aad::from(&[]), - &mut data[..src.len()], - ) - .map(|tag| { - data[src.len()..src.len() + AEAD_SIZE].copy_from_slice(tag.as_ref()); - src.len() + AEAD_SIZE - }) - .unwrap() + let mut nonce = [0u8; 12]; + nonce[4..12].copy_from_slice(&sending_key_counter.to_le_bytes()); + data.encrypted_encapsulated_packet_mut() + .copy_from_slice(&packet); + self.sender + .seal_in_place_separate_tag( + Nonce::assume_unique_for_key(nonce), + Aad::from(&[]), + data.encrypted_encapsulated_packet_mut(), + ) + .map(|tag| { + data.tag_mut().copy_from_slice(tag.as_ref()); + packet.len() + AEAD_SIZE + }) + .expect("encryption must succeed"); + + // this won't panic since we've correctly initialized a WgData packet + let packet = buf.try_into_wg().expect("is a wireguard packet"); + let WgKind::Data(packet) = packet else { + unreachable!("is a wireguard data packet"); }; - &mut dst[..DATA_OFFSET + n] + packet } - /// packet - a data packet we received from the network - /// dst - pre-allocated space to hold the encapsulated IP packet, to send to the interface - /// dst will always take less space than src - /// return the size of the encapsulated packet on success - pub(super) fn receive_packet_data<'a>( + /// Decapsulate `packet` and return the decrypted data. + pub(super) fn receive_packet_data( &self, - packet: PacketData, - dst: &'a mut [u8], - ) -> Result<&'a mut [u8], WireGuardError> { - let ct_len = packet.encrypted_encapsulated_packet.len(); - if dst.len() < ct_len { - // This is a very incorrect use of the library, therefore panic and not error - panic!("The destination buffer is too small"); - } - if packet.receiver_idx != self.receiving_index { + mut packet: Packet, + ) -> Result { + if packet.header.receiver_idx != self.receiving_index { return Err(WireGuardError::WrongIndex); } + + let counter = packet.header.counter.get(); + // Don't reuse counters, in case this is a replay attack we want to quickly check the counter without running expensive decryption - self.receiving_counter_quick_check(packet.counter)?; - - let ret = { - let mut nonce = [0u8; 12]; - nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes()); - dst[..ct_len].copy_from_slice(packet.encrypted_encapsulated_packet); - self.receiver - .open_in_place( - Nonce::assume_unique_for_key(nonce), - Aad::from(&[]), - &mut dst[..ct_len], - ) - .map_err(|_| WireGuardError::InvalidAeadTag)? - }; + self.receiving_counter_quick_check(counter)?; + + let mut nonce = [0u8; 12]; + nonce[4..12].copy_from_slice(&packet.header.counter.to_bytes()); + + // decrypt the data in-place + let decrypted_len = self + .receiver + .open_in_place( + Nonce::assume_unique_for_key(nonce), + Aad::from(&[]), + &mut packet.encrypted_encapsulated_packet_and_tag, + ) + .map_err(|_| WireGuardError::InvalidAeadTag)? + .len(); + + // shift the packet buffer slice onto the decrypted data + let mut packet = packet.into_bytes(); + let buf = packet.buf_mut(); + buf.advance(WgDataHeader::LEN); + buf.truncate(decrypted_len); // After decryption is done, check counter again, and mark as received - self.receiving_counter_mark(packet.counter)?; - Ok(ret) + self.receiving_counter_mark(counter)?; + Ok(packet) } /// Returns the estimated downstream packet loss for this session diff --git a/src/noise/timers.rs b/src/noise/timers.rs index 6b91d57..7b46159 100644 --- a/src/noise/timers.rs +++ b/src/noise/timers.rs @@ -1,13 +1,19 @@ // Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// +// Modified by Mullvad VPN. +// Copyright (c) 2025 Mullvad VPN. +// // SPDX-License-Identifier: BSD-3-Clause use super::errors::WireGuardError; -use crate::noise::{Tunn, TunnResult}; +use crate::noise::Tunn; +use crate::packet::WgKind; + use std::mem; use std::ops::{Index, IndexMut}; - use std::time::Duration; +use bytes::BytesMut; #[cfg(feature = "mock-instant")] use mock_instant::Instant; @@ -156,8 +162,8 @@ impl Tunn { if time_now - *t > REJECT_AFTER_TIME { if let Some(session) = self.sessions[i].take() { tracing::debug!( - message = "SESSION_EXPIRED(REJECT_AFTER_TIME)", - session = session.receiving_index + "SESSION_EXPIRED(REJECT_AFTER_TIME): {}", + session.receiving_index ); } *t = time_now; @@ -165,14 +171,18 @@ impl Tunn { } } - pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { + /// Update the tunnel timers + /// + /// This returns `Ok(None)` if no action is needed, `Ok(Some(packet))` if a packet + /// (keepalive or handshake) should be sent, or an error if something went wrong. + pub fn update_timers(&mut self) -> Result, WireGuardError> { let mut handshake_initiation_required = false; let mut keepalive_required = false; let time = Instant::now(); if self.timers.should_reset_rr { - self.rate_limiter.reset_count(); + self.rate_limiter.try_reset_count(); } // All the times are counted from tunnel initiation, for efficiency our timers are rounded @@ -193,7 +203,7 @@ impl Tunn { { if self.handshake.is_expired() { - return TunnResult::Err(WireGuardError::ConnectionExpired); + return Err(WireGuardError::ConnectionExpired); } // Clear cookie after COOKIE_EXPIRATION_TIME @@ -209,7 +219,7 @@ impl Tunn { tracing::error!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)"); self.handshake.set_expired(); self.clear_all(); - return TunnResult::Err(WireGuardError::ConnectionExpired); + return Err(WireGuardError::ConnectionExpired); } if let Some(time_init_sent) = self.handshake.timer() { @@ -222,7 +232,7 @@ impl Tunn { tracing::error!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)"); self.handshake.set_expired(); self.clear_all(); - return TunnResult::Err(WireGuardError::ConnectionExpired); + return Err(WireGuardError::ConnectionExpired); } if time_init_sent.elapsed() >= REKEY_TIMEOUT { @@ -301,14 +311,16 @@ impl Tunn { } if handshake_initiation_required { - return self.format_handshake_initiation(dst, true); + return Ok(self.format_handshake_initiation(true).map(Into::into)); } if keepalive_required { - return self.encapsulate(&[], dst); + return Ok( + self.handle_outgoing_packet(crate::packet::Packet::from_bytes(BytesMut::new())) + ); } - TunnResult::Done + Ok(None) } pub fn time_since_last_handshake(&self) -> Option { diff --git a/src/packet/ip.rs b/src/packet/ip.rs new file mode 100644 index 0000000..ab49034 --- /dev/null +++ b/src/packet/ip.rs @@ -0,0 +1,67 @@ +// Copyright (c) 2025 Mullvad VPN AB. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use std::net::IpAddr; + +use bitfield_struct::bitfield; +use either::Either; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned}; + +use crate::packet::{Ipv4, Ipv6}; + +/// A packet bitfield-struct containing the `version`-field that is shared between IPv4 and IPv6. +#[bitfield(u8)] +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq)] +pub struct IpvxVersion { + #[bits(4)] + pub _unknown: u8, + #[bits(4)] + pub version: u8, +} + +/// An IP packet, including headers, that may be either IPv4 or IPv6. +/// [Read more](crate::packet) +#[repr(C, packed)] +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)] +pub struct Ip { + /// The IP version field. [Read more](IpvxVersion) + pub header: IpvxVersion, + + /// The rest of the IP packet, + /// i.e. everything in the header that comes after the first byte, and the payload. + /// + /// You probably don't want to access this directly. + pub rest: [u8], +} + +impl Ip { + fn as_v4_or_v6(&self) -> Option> { + let b = self.as_bytes(); + match self.header.version() { + 4 => Ipv4::<[u8]>::ref_from_bytes(b).ok().map(Either::Left), + 6 => Ipv6::<[u8]>::ref_from_bytes(b).ok().map(Either::Right), + _ => None, + } + } + /// Try to extract the source [`IpAddr`]. + /// + /// Returns `None` if the version field is not `4` or `6`, or if the packet is too small. + /// Other than that, no checks are done to ensure this is a valid ip packet. + pub fn source(&self) -> Option { + Some(match self.as_v4_or_v6()? { + Either::Left(ipv4) => ipv4.header.source().into(), + Either::Right(ipv6) => ipv6.header.source().into(), + }) + } + + /// Try to extract the destination [`IpAddr`]. + /// + /// Returns `None` if the version field is not `4` or `6`, or if the packet is too small. + /// Other than that, no checks are done to ensure this is a valid ip packet. + pub fn destination(&self) -> Option { + Some(match self.as_v4_or_v6()? { + Either::Left(ipv4) => ipv4.header.destination().into(), + Either::Right(ipv6) => ipv6.header.destination().into(), + }) + } +} diff --git a/src/packet/ipv4/mod.rs b/src/packet/ipv4/mod.rs new file mode 100644 index 0000000..62dd83b --- /dev/null +++ b/src/packet/ipv4/mod.rs @@ -0,0 +1,266 @@ +// Copyright (c) 2025 Mullvad VPN AB. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use bitfield_struct::bitfield; +use std::{fmt::Debug, net::Ipv4Addr}; +use zerocopy::{big_endian, FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned}; + +mod protocol; +pub use protocol::*; + +use super::util::size_must_be; + +/// An Ipv4 packet. +/// +/// This is a dynamically sized zerocopy type, which means you can compose packet types like +/// `Ipv4>` and cast them to/from byte slices using [`FromBytes`] and [`IntoBytes`]. +/// [Read more](crate::packet) +#[repr(C)] +#[derive(Debug, FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)] +pub struct Ipv4 { + /// IPv4 header. + pub header: Ipv4Header, + /// IPv4 payload. + pub payload: Payload, +} + +/// A bitfield struct containing the IPv4 fields `version` and `ihl`. +#[bitfield(u8)] +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq)] +pub struct Ipv4VersionIhl { + /// IPv4 `ihl` field (Internet Header Length). + /// + /// This determines the length in `u32`s of the IPv4 header, including optional fields. + /// The minimum value is `5`, which implies no optional fields. + #[bits(4)] + pub ihl: u8, + + /// IPv4 `version` field. This must be `4`. + #[bits(4)] + pub version: u8, +} + +/// A bitfield struct containing the IPv4 fields `dscp` and `ecn`. +#[bitfield(u8)] +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq)] +pub struct Ipv4DscpEcn { + #[bits(2)] + pub ecn: u8, + #[bits(6)] + pub dscp: u8, +} + +/// A bitfield struct containing the IPv4 bitflags and the `fragment_offset` field. +#[bitfield(u16, order = Msb, repr = big_endian::U16, from = big_endian::U16::new, into = big_endian::U16::get)] +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq)] +pub struct Ipv4FlagsFragmentOffset { + _reserved: bool, + /// IPv4 `dont_fragment` flag. + pub dont_fragment: bool, + /// IPv4 `more_fragments` flag. + pub more_fragments: bool, + /// IPv4 `fragment_offset` field. + #[bits(13)] + pub fragment_offset: u16, +} + +/// An IPv4 header. +#[repr(C, packed)] +#[derive(Clone, Copy, FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq)] +pub struct Ipv4Header { + /// IPv4 `version`, and `ihl` fields. + pub version_and_ihl: Ipv4VersionIhl, + /// IPv4 `dscp`, and `ecn` fields. + pub dscp_and_ecn: Ipv4DscpEcn, + /// Length of the IPv4 packet, including headers. + pub total_len: big_endian::U16, + /// IPv4 `identification`. This is used for fragmentation. + pub identification: big_endian::U16, + /// IPv4 bitflags, and `fragment_offset` fields. + pub flags_and_fragment_offset: Ipv4FlagsFragmentOffset, + /// Maximum number of hops for the IPv4 packet. + pub time_to_live: u8, + /// Protocol of the IPv4 payload. + pub protocol: IpNextProtocol, + /// Checksum of the IPv4 header. + pub header_checksum: big_endian::U16, + /// IPv4 source address. Use [`Ipv4Header::source`]. + pub source_address: big_endian::U32, + /// IPv4 destination address. Use [`Ipv4Header::destination`]. + pub destination_address: big_endian::U32, +} + +impl Ipv4Header { + /// Construct an IPv4 header with the reasonable defaults. + /// + /// `payload` field is used to set the `total_len` field. + #[allow(dead_code)] + pub const fn new( + source: Ipv4Addr, + destination: Ipv4Addr, + protocol: IpNextProtocol, + payload: &[u8], + ) -> Self { + Self::new_for_length(source, destination, protocol, payload.len() as u16) + } + + /// Construct an IPv4 header with the reasonable defaults. + /// + /// `payload_len` is used to set the `total_len` field. + /// The checksum is initialized to `0`. + pub const fn new_for_length( + source: Ipv4Addr, + destination: Ipv4Addr, + protocol: IpNextProtocol, + payload_len: u16, + ) -> Self { + let header_len = size_of::() as u16; + let total_len = header_len + payload_len; + + Self { + protocol, + + version_and_ihl: Ipv4VersionIhl::new().with_version(4).with_ihl(5), + dscp_and_ecn: Ipv4DscpEcn::new(), + total_len: big_endian::U16::new(total_len), + identification: big_endian::U16::ZERO, + flags_and_fragment_offset: Ipv4FlagsFragmentOffset::new(), + time_to_live: 64, // default TTL in linux + source_address: big_endian::U32::from_bytes(source.octets()), + destination_address: big_endian::U32::from_bytes(destination.octets()), + + // TODO: + header_checksum: big_endian::U16::ZERO, + } + } +} + +impl Ipv4Header { + /// Length of an [`Ipv4Header`], in bytes. + pub const LEN: usize = size_must_be::(20); + + /// Get IP version. Must be `4` for a valid IPv4 header. + pub const fn version(&self) -> u8 { + self.version_and_ihl.version() + } + + /// Get [`ihl`](Ipv4VersionIhl::ihl) + pub const fn ihl(&self) -> u8 { + self.version_and_ihl.ihl() + } + + /// Get [`source_address`](Ipv4Header::source_address). + pub const fn source(&self) -> Ipv4Addr { + let bits = self.source_address.get(); + Ipv4Addr::from_bits(bits) + } + + /// Get [`destination_address`](Ipv4Header::destination_address). + pub const fn destination(&self) -> Ipv4Addr { + let bits = self.destination_address.get(); + Ipv4Addr::from_bits(bits) + } + + /// Get [`protocol`](Ipv4Header::protocol). + pub const fn next_protocol(&self) -> IpNextProtocol { + self.protocol + } + + /// Get [`dscp`](Ipv4DscpEcn::dscp). + pub const fn dscp(&self) -> u8 { + self.dscp_and_ecn.dscp() + } + + /// Get [`ecn`](Ipv4DscpEcn::ecn). + pub const fn ecn(&self) -> u8 { + self.dscp_and_ecn.ecn() + } + + /// Get [`dont_fragment`](Ipv4FlagsFragmentOffset::dont_fragment). + pub const fn dont_fragment(&self) -> bool { + self.flags_and_fragment_offset.dont_fragment() + } + + /// Get [`more_fragments`](Ipv4FlagsFragmentOffset::more_fragments). + pub const fn more_fragments(&self) -> bool { + self.flags_and_fragment_offset.more_fragments() + } + + /// Get [`fragment_offset`](Ipv4FlagsFragmentOffset::fragment_offset). + /// + /// This is the offset of IP fragment payload relative to the start of payload of the original + /// packet. Note that the value returned is in units of 8 bytes. + pub const fn fragment_offset(&self) -> u16 { + self.flags_and_fragment_offset.fragment_offset() + } +} + +impl Ipv4 { + /// Maximum possible length of an IPv4 packet. + pub const MAX_LEN: usize = 65535; +} + +impl Debug for Ipv4Header { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Ipv4Header") + .field("version", &self.version()) + .field("ihl", &self.ihl()) + .field("dscp", &self.dscp()) + .field("ecn", &self.ecn()) + .field("total_len", &self.total_len.get()) + .field("identification", &self.identification.get()) + .field("dont_fragment", &self.dont_fragment()) + .field("more_fragments", &self.more_fragments()) + .field("fragment_offset", &self.fragment_offset()) + .field("time_to_live", &self.time_to_live) + .field("protocol", &self.protocol) + .field("header_checksum", &self.header_checksum.get()) + .field("source_address", &self.source()) + .field("destination_address", &self.destination()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use zerocopy::FromBytes; + + use super::{Ipv4, Ipv4Header}; + use crate::packet::IpNextProtocol; + use std::net::Ipv4Addr; + + const EXAMPLE_IPV4_ICMP: &[u8] = &[ + 0x45, 0x83, 0x0, 0x54, 0xa3, 0x13, 0x40, 0x0, 0x40, 0x1, 0xc6, 0x26, 0xa, 0x8c, 0xc2, 0xdd, + 0x1, 0x2, 0x3, 0x4, 0x8, 0x0, 0x51, 0x13, 0x0, 0x2b, 0x0, 0x1, 0xb1, 0x5c, 0x87, 0x68, 0x0, + 0x0, 0x0, 0x0, 0xa8, 0x28, 0x7, 0x0, 0x0, 0x0, 0x0, 0x0, 0x10, 0x11, 0x12, 0x13, 0x14, + 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, + 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, + 0x33, 0x34, 0x35, 0x36, 0x37, + ]; + + #[test] + fn ipv4_header_layout() { + let packet = Ipv4::<[u8]>::ref_from_bytes(EXAMPLE_IPV4_ICMP).unwrap(); + let header = &packet.header; + + assert_eq!(header.version(), 4); + assert_eq!(header.ihl(), 5); + assert_eq!(header.dscp(), 32); + assert_eq!(header.ecn(), 0x3); + assert_eq!(header.total_len, 84); + assert_eq!(header.identification, 41747); + assert!(header.dont_fragment()); + assert!(!header.more_fragments()); + assert_eq!(header.fragment_offset(), 0); + assert_eq!(header.time_to_live, 64); + assert_eq!(header.protocol, IpNextProtocol::Icmp); + assert_eq!(header.header_checksum, 0xc626); + assert_eq!(header.source(), Ipv4Addr::new(10, 140, 194, 221)); + assert_eq!(header.destination(), Ipv4Addr::new(1, 2, 3, 4)); + + assert_eq!( + packet.payload.len() + Ipv4Header::LEN, + usize::from(header.total_len) + ); + } +} diff --git a/src/packet/ipv4/protocol.rs b/src/packet/ipv4/protocol.rs new file mode 100644 index 0000000..fadee81 --- /dev/null +++ b/src/packet/ipv4/protocol.rs @@ -0,0 +1,602 @@ +// Copyright (c) 2025 Mullvad VPN AB. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +#![allow(clippy::doc_markdown)] + +use std::fmt::Debug; + +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned}; + +/// Type of the `protocol`/`next_header` fields of IPv4 and IPv6. +/// +/// The value indicates what is stored in the IP packet. +#[repr(transparent)] +#[derive(Clone, Copy, PartialEq, Eq, Immutable, Unaligned, FromBytes, IntoBytes, KnownLayout)] +pub struct IpNextProtocol(u8); + +// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml +impl IpNextProtocol { + #![allow(non_upper_case_globals)] + + /// IPv6 Hop-by-Hop Option \[RFC2460\] + pub const Hopopt: IpNextProtocol = IpNextProtocol(0); + + /// Internet Control Message \[RFC792\] + pub const Icmp: IpNextProtocol = IpNextProtocol(1); + + /// Internet Group Management \[RFC1112\] + pub const Igmp: IpNextProtocol = IpNextProtocol(2); + + /// Gateway-to-Gateway \[RFC823\] + pub const Ggp: IpNextProtocol = IpNextProtocol(3); + + /// IPv4 encapsulation \[RFC2003\] + pub const Ipv4: IpNextProtocol = IpNextProtocol(4); + + /// Stream \[RFC1190\]\[RFC1819\] + pub const St: IpNextProtocol = IpNextProtocol(5); + + /// Transmission Control \[RFC793\] + pub const Tcp: IpNextProtocol = IpNextProtocol(6); + + /// CBT + pub const Cbt: IpNextProtocol = IpNextProtocol(7); + + /// Exterior Gateway Protocol \[RFC888\] + pub const Egp: IpNextProtocol = IpNextProtocol(8); + + /// any private interior gateway (used by Cisco for their IGRP) + pub const Igp: IpNextProtocol = IpNextProtocol(9); + + /// BBN RCC Monitoring + pub const BbnRccMon: IpNextProtocol = IpNextProtocol(10); + + /// Network Voice Protocol \[RFC741\] + pub const NvpII: IpNextProtocol = IpNextProtocol(11); + + /// PUP + pub const Pup: IpNextProtocol = IpNextProtocol(12); + + /// ARGUS + pub const Argus: IpNextProtocol = IpNextProtocol(13); + + /// EMCON + pub const Emcon: IpNextProtocol = IpNextProtocol(14); + + /// Cross Net Debugger + pub const Xnet: IpNextProtocol = IpNextProtocol(15); + + /// Chaos + pub const Chaos: IpNextProtocol = IpNextProtocol(16); + + /// User Datagram \[RFC768\] + pub const Udp: IpNextProtocol = IpNextProtocol(17); + + /// Multiplexing + pub const Mux: IpNextProtocol = IpNextProtocol(18); + + /// DCN Measurement Subsystems + pub const DcnMeas: IpNextProtocol = IpNextProtocol(19); + + /// Host Monitoring \[RFC869\] + pub const Hmp: IpNextProtocol = IpNextProtocol(20); + + /// Packet Radio Measurement + pub const Prm: IpNextProtocol = IpNextProtocol(21); + + /// XEROX NS IDP + pub const XnsIdp: IpNextProtocol = IpNextProtocol(22); + + /// Trunk-1 + pub const Trunk1: IpNextProtocol = IpNextProtocol(23); + + /// Trunk-2 + pub const Trunk2: IpNextProtocol = IpNextProtocol(24); + + /// Leaf-1 + pub const Leaf1: IpNextProtocol = IpNextProtocol(25); + + /// Leaf-2 + pub const Leaf2: IpNextProtocol = IpNextProtocol(26); + + /// Reliable Data Protocol \[RFC908\] + pub const Rdp: IpNextProtocol = IpNextProtocol(27); + + /// Internet Reliable Transaction \[RFC938\] + pub const Irtp: IpNextProtocol = IpNextProtocol(28); + + /// ISO Transport Protocol Class 4 \[RFC905\] + pub const IsoTp4: IpNextProtocol = IpNextProtocol(29); + + /// Bulk Data Transfer Protocol \[RFC969\] + pub const Netblt: IpNextProtocol = IpNextProtocol(30); + + /// MFE Network Services Protocol + pub const MfeNsp: IpNextProtocol = IpNextProtocol(31); + + /// MERIT Internodal Protocol + pub const MeritInp: IpNextProtocol = IpNextProtocol(32); + + /// Datagram Congestion Control Protocol \[RFC4340\] + pub const Dccp: IpNextProtocol = IpNextProtocol(33); + + /// Third Party Connect Protocol + pub const ThreePc: IpNextProtocol = IpNextProtocol(34); + + /// Inter-Domain Policy Routing Protocol + pub const Idpr: IpNextProtocol = IpNextProtocol(35); + + /// XTP + pub const Xtp: IpNextProtocol = IpNextProtocol(36); + + /// Datagram Delivery Protocol + pub const Ddp: IpNextProtocol = IpNextProtocol(37); + + /// IDPR Control Message Transport Proto + pub const IdprCmtp: IpNextProtocol = IpNextProtocol(38); + + /// TP++ Transport Protocol + pub const TpPlusPlus: IpNextProtocol = IpNextProtocol(39); + + /// IL Transport Protocol + pub const Il: IpNextProtocol = IpNextProtocol(40); + + /// IPv6 encapsulation \[RFC2473\] + pub const Ipv6: IpNextProtocol = IpNextProtocol(41); + + /// Source Demand Routing Protocol + pub const Sdrp: IpNextProtocol = IpNextProtocol(42); + + /// Routing Header for IPv6 + pub const Ipv6Route: IpNextProtocol = IpNextProtocol(43); + + /// Fragment Header for IPv6 + pub const Ipv6Frag: IpNextProtocol = IpNextProtocol(44); + + /// Inter-Domain Routing Protocol + pub const Idrp: IpNextProtocol = IpNextProtocol(45); + + /// Reservation Protocol \[RFC2205\]\[RFC3209\] + pub const Rsvp: IpNextProtocol = IpNextProtocol(46); + + /// Generic Routing Encapsulation \[RFC1701\] + pub const Gre: IpNextProtocol = IpNextProtocol(47); + + /// Dynamic Source Routing Protocol \[RFC4728\] + pub const Dsr: IpNextProtocol = IpNextProtocol(48); + + /// BNA + pub const Bna: IpNextProtocol = IpNextProtocol(49); + + /// Encap Security Payload \[RFC4303\] + pub const Esp: IpNextProtocol = IpNextProtocol(50); + + /// Authentication Header \[RFC4302\] + pub const Ah: IpNextProtocol = IpNextProtocol(51); + + /// Integrated Net Layer Security TUBA + pub const INlsp: IpNextProtocol = IpNextProtocol(52); + + /// IP with Encryption + pub const Swipe: IpNextProtocol = IpNextProtocol(53); + + /// NBMA Address Resolution Protocol \[RFC1735\] + pub const Narp: IpNextProtocol = IpNextProtocol(54); + + /// IP Mobility + pub const Mobile: IpNextProtocol = IpNextProtocol(55); + + /// Transport Layer Security Protocol using Kryptonet key management + pub const Tlsp: IpNextProtocol = IpNextProtocol(56); + + /// SKIP + pub const Skip: IpNextProtocol = IpNextProtocol(57); + + /// ICMPv6 \[RFC4443\] + pub const Icmpv6: IpNextProtocol = IpNextProtocol(58); + + /// No Next Header for IPv6 \[RFC2460\] + pub const Ipv6NoNxt: IpNextProtocol = IpNextProtocol(59); + + /// Destination Options for IPv6 \[RFC2460\] + pub const Ipv6Opts: IpNextProtocol = IpNextProtocol(60); + + /// any host internal protocol + pub const HostInternal: IpNextProtocol = IpNextProtocol(61); + + /// CFTP + pub const Cftp: IpNextProtocol = IpNextProtocol(62); + + /// any local network + pub const LocalNetwork: IpNextProtocol = IpNextProtocol(63); + + /// SATNET and Backroom EXPAK + pub const SatExpak: IpNextProtocol = IpNextProtocol(64); + + /// Kryptolan + pub const Kryptolan: IpNextProtocol = IpNextProtocol(65); + + /// MIT Remote Virtual Disk Protocol + pub const Rvd: IpNextProtocol = IpNextProtocol(66); + + /// Internet Pluribus Packet Core + pub const Ippc: IpNextProtocol = IpNextProtocol(67); + + /// any distributed file system + pub const DistributedFs: IpNextProtocol = IpNextProtocol(68); + + /// SATNET Monitoring + pub const SatMon: IpNextProtocol = IpNextProtocol(69); + + /// VISA Protocol + pub const Visa: IpNextProtocol = IpNextProtocol(70); + + /// Internet Packet Core Utility + pub const Ipcv: IpNextProtocol = IpNextProtocol(71); + + /// Computer Protocol Network Executive + pub const Cpnx: IpNextProtocol = IpNextProtocol(72); + + /// Computer Protocol Heart Beat + pub const Cphb: IpNextProtocol = IpNextProtocol(73); + + /// Wang Span Network + pub const Wsn: IpNextProtocol = IpNextProtocol(74); + + /// Packet Video Protocol + pub const Pvp: IpNextProtocol = IpNextProtocol(75); + + /// Backroom SATNET Monitoring + pub const BrSatMon: IpNextProtocol = IpNextProtocol(76); + + /// SUN ND PROTOCOL-Temporary + pub const SunNd: IpNextProtocol = IpNextProtocol(77); + + /// WIDEBAND Monitoring + pub const WbMon: IpNextProtocol = IpNextProtocol(78); + + /// WIDEBAND EXPAK + pub const WbExpak: IpNextProtocol = IpNextProtocol(79); + + /// ISO Internet Protocol + pub const IsoIp: IpNextProtocol = IpNextProtocol(80); + + /// VMTP + pub const Vmtp: IpNextProtocol = IpNextProtocol(81); + + /// SECURE-VMTP + pub const SecureVmtp: IpNextProtocol = IpNextProtocol(82); + + /// VINES + pub const Vines: IpNextProtocol = IpNextProtocol(83); + + /// Transaction Transport Protocol/IP Traffic Manager + pub const TtpOrIptm: IpNextProtocol = IpNextProtocol(84); + + /// NSFNET-IGP + pub const NsfnetIgp: IpNextProtocol = IpNextProtocol(85); + + /// Dissimilar Gateway Protocol + pub const Dgp: IpNextProtocol = IpNextProtocol(86); + + /// TCF + pub const Tcf: IpNextProtocol = IpNextProtocol(87); + + /// EIGRP + pub const Eigrp: IpNextProtocol = IpNextProtocol(88); + + /// OSPFIGP \[RFC1583\]\[RFC2328\]\[RFC5340\] + pub const OspfigP: IpNextProtocol = IpNextProtocol(89); + + /// Sprite RPC Protocol + pub const SpriteRpc: IpNextProtocol = IpNextProtocol(90); + + /// Locus Address Resolution Protocol + pub const Larp: IpNextProtocol = IpNextProtocol(91); + + /// Multicast Transport Protocol + pub const Mtp: IpNextProtocol = IpNextProtocol(92); + + /// AX.25 Frames + pub const Ax25: IpNextProtocol = IpNextProtocol(93); + + /// IP-within-IP Encapsulation Protocol + pub const IpIp: IpNextProtocol = IpNextProtocol(94); + + /// Mobile Internetworking Control Pro. + pub const Micp: IpNextProtocol = IpNextProtocol(95); + + /// Semaphore Communications Sec. Pro. + pub const SccSp: IpNextProtocol = IpNextProtocol(96); + + /// Ethernet-within-IP Encapsulation \[RFC3378\] + pub const Etherip: IpNextProtocol = IpNextProtocol(97); + + /// Encapsulation Header \[RFC1241\] + pub const Encap: IpNextProtocol = IpNextProtocol(98); + + /// any private encryption scheme + pub const PrivEncryption: IpNextProtocol = IpNextProtocol(99); + + /// GMTP + pub const Gmtp: IpNextProtocol = IpNextProtocol(100); + + /// Ipsilon Flow Management Protocol + pub const Ifmp: IpNextProtocol = IpNextProtocol(101); + + /// PNNI over IP + pub const Pnni: IpNextProtocol = IpNextProtocol(102); + + /// Protocol Independent Multicast \[RFC4601\] + pub const Pim: IpNextProtocol = IpNextProtocol(103); + + /// ARIS + pub const Aris: IpNextProtocol = IpNextProtocol(104); + + /// SCPS + pub const Scps: IpNextProtocol = IpNextProtocol(105); + + /// QNX + pub const Qnx: IpNextProtocol = IpNextProtocol(106); + + /// Active Networks + pub const AN: IpNextProtocol = IpNextProtocol(107); + + /// IP Payload Compression Protocol \[RFC2393\] + pub const IpComp: IpNextProtocol = IpNextProtocol(108); + + /// Sitara Networks Protocol + pub const Snp: IpNextProtocol = IpNextProtocol(109); + + /// Compaq Peer Protocol + pub const CompaqPeer: IpNextProtocol = IpNextProtocol(110); + + /// IPX in IP + pub const IpxInIp: IpNextProtocol = IpNextProtocol(111); + + /// Virtual Router Redundancy Protocol \[RFC5798\] + pub const Vrrp: IpNextProtocol = IpNextProtocol(112); + + /// PGM Reliable Transport Protocol + pub const Pgm: IpNextProtocol = IpNextProtocol(113); + + /// any 0-hop protocol + pub const ZeroHop: IpNextProtocol = IpNextProtocol(114); + + /// Layer Two Tunneling Protocol \[RFC3931\] + pub const L2tp: IpNextProtocol = IpNextProtocol(115); + + /// D-II Data Exchange (DDX) + pub const Ddx: IpNextProtocol = IpNextProtocol(116); + + /// Interactive Agent Transfer Protocol + pub const Iatp: IpNextProtocol = IpNextProtocol(117); + + /// Schedule Transfer Protocol + pub const Stp: IpNextProtocol = IpNextProtocol(118); + + /// SpectraLink Radio Protocol + pub const Srp: IpNextProtocol = IpNextProtocol(119); + + /// UTI + pub const Uti: IpNextProtocol = IpNextProtocol(120); + + /// Simple Message Protocol + pub const Smp: IpNextProtocol = IpNextProtocol(121); + + /// Simple Multicast Protocol + pub const Sm: IpNextProtocol = IpNextProtocol(122); + + /// Performance Transparency Protocol + pub const Ptp: IpNextProtocol = IpNextProtocol(123); + + /// Intermediate System to Intermediate System (IS-IS) Protocol over IPv4 + pub const IsisOverIpv4: IpNextProtocol = IpNextProtocol(124); + + /// Flexible Intra-AS Routing Environment + pub const Fire: IpNextProtocol = IpNextProtocol(125); + + /// Combat Radio Transport Protocol + pub const Crtp: IpNextProtocol = IpNextProtocol(126); + + /// Combat Radio User Datagram + pub const Crudp: IpNextProtocol = IpNextProtocol(127); + + /// Service-Specific Connection-Oriented Protocol in a Multilink and Connectionless Environment + pub const Sscopmce: IpNextProtocol = IpNextProtocol(128); + + // This protocol doesn't seem to be documented by IANA. + #[allow(missing_docs)] + pub const Iplt: IpNextProtocol = IpNextProtocol(129); + + /// Secure Packet Shield + pub const Sps: IpNextProtocol = IpNextProtocol(130); + + /// Private IP Encapsulation within IP + pub const Pipe: IpNextProtocol = IpNextProtocol(131); + + /// Stream Control Transmission Protocol + pub const Sctp: IpNextProtocol = IpNextProtocol(132); + + /// Fibre Channel \[RFC6172\] + pub const Fc: IpNextProtocol = IpNextProtocol(133); + + /// \[RFC3175\] + pub const RsvpE2eIgnore: IpNextProtocol = IpNextProtocol(134); + + /// \[RFC6275\] + pub const MobilityHeader: IpNextProtocol = IpNextProtocol(135); + + /// \[RFC3828\] + pub const UdpLite: IpNextProtocol = IpNextProtocol(136); + + /// \[RFC4023\] + pub const MplsInIp: IpNextProtocol = IpNextProtocol(137); + + /// MANET Protocols \[RFC5498\] + pub const Manet: IpNextProtocol = IpNextProtocol(138); + + /// Host Identity Protocol \[RFC5201\] + pub const Hip: IpNextProtocol = IpNextProtocol(139); + + /// Shim6 Protocol \[RFC5533\] + pub const Shim6: IpNextProtocol = IpNextProtocol(140); + + /// Wrapped Encapsulating Security Payload \[RFC5840\] + pub const Wesp: IpNextProtocol = IpNextProtocol(141); + + /// Robust Header Compression \[RFC5858\] + pub const Rohc: IpNextProtocol = IpNextProtocol(142); +} + +impl Debug for IpNextProtocol { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let name = match *self { + IpNextProtocol::Hopopt => "Hopopt", + IpNextProtocol::Icmp => "Icmp", + IpNextProtocol::Igmp => "Igmp", + IpNextProtocol::Ggp => "Ggp", + IpNextProtocol::Ipv4 => "Ipv4", + IpNextProtocol::St => "St", + IpNextProtocol::Tcp => "Tcp", + IpNextProtocol::Cbt => "Cbt", + IpNextProtocol::Egp => "Egp", + IpNextProtocol::Igp => "Igp", + IpNextProtocol::BbnRccMon => "BbnRccMon", + IpNextProtocol::NvpII => "NvpII", + IpNextProtocol::Pup => "Pup", + IpNextProtocol::Argus => "Argus", + IpNextProtocol::Emcon => "Emcon", + IpNextProtocol::Xnet => "Xnet", + IpNextProtocol::Chaos => "Chaos", + IpNextProtocol::Udp => "Udp", + IpNextProtocol::Mux => "Mux", + IpNextProtocol::DcnMeas => "DcnMeas", + IpNextProtocol::Hmp => "Hmp", + IpNextProtocol::Prm => "Prm", + IpNextProtocol::XnsIdp => "XnsIdp", + IpNextProtocol::Trunk1 => "Trunk1", + IpNextProtocol::Trunk2 => "Trunk2", + IpNextProtocol::Leaf1 => "Leaf1", + IpNextProtocol::Leaf2 => "Leaf2", + IpNextProtocol::Rdp => "Rdp", + IpNextProtocol::Irtp => "Irtp", + IpNextProtocol::IsoTp4 => "IsoTp4", + IpNextProtocol::Netblt => "Netblt", + IpNextProtocol::MfeNsp => "MfeNsp", + IpNextProtocol::MeritInp => "MeritInp", + IpNextProtocol::Dccp => "Dccp", + IpNextProtocol::ThreePc => "ThreePc", + IpNextProtocol::Idpr => "Idpr", + IpNextProtocol::Xtp => "Xtp", + IpNextProtocol::Ddp => "Ddp", + IpNextProtocol::IdprCmtp => "IdprCmtp", + IpNextProtocol::TpPlusPlus => "TpPlusPlus", + IpNextProtocol::Il => "Il", + IpNextProtocol::Ipv6 => "Ipv6", + IpNextProtocol::Sdrp => "Sdrp", + IpNextProtocol::Ipv6Route => "Ipv6Route", + IpNextProtocol::Ipv6Frag => "Ipv6Frag", + IpNextProtocol::Idrp => "Idrp", + IpNextProtocol::Rsvp => "Rsvp", + IpNextProtocol::Gre => "Gre", + IpNextProtocol::Dsr => "Dsr", + IpNextProtocol::Bna => "Bna", + IpNextProtocol::Esp => "Esp", + IpNextProtocol::Ah => "Ah", + IpNextProtocol::INlsp => "INlsp", + IpNextProtocol::Swipe => "Swipe", + IpNextProtocol::Narp => "Narp", + IpNextProtocol::Mobile => "Mobile", + IpNextProtocol::Tlsp => "Tlsp", + IpNextProtocol::Skip => "Skip", + IpNextProtocol::Icmpv6 => "Icmpv6", + IpNextProtocol::Ipv6NoNxt => "Ipv6NoNxt", + IpNextProtocol::Ipv6Opts => "Ipv6Opts", + IpNextProtocol::HostInternal => "HostInternal", + IpNextProtocol::Cftp => "Cftp", + IpNextProtocol::LocalNetwork => "LocalNetwork", + IpNextProtocol::SatExpak => "SatExpak", + IpNextProtocol::Kryptolan => "Kryptolan", + IpNextProtocol::Rvd => "Rvd", + IpNextProtocol::Ippc => "Ippc", + IpNextProtocol::DistributedFs => "DistributedFs", + IpNextProtocol::SatMon => "SatMon", + IpNextProtocol::Visa => "Visa", + IpNextProtocol::Ipcv => "Ipcv", + IpNextProtocol::Cpnx => "Cpnx", + IpNextProtocol::Cphb => "Cphb", + IpNextProtocol::Wsn => "Wsn", + IpNextProtocol::Pvp => "Pvp", + IpNextProtocol::BrSatMon => "BrSatMon", + IpNextProtocol::SunNd => "SunNd", + IpNextProtocol::WbMon => "WbMon", + IpNextProtocol::WbExpak => "WbExpak", + IpNextProtocol::IsoIp => "IsoIp", + IpNextProtocol::Vmtp => "Vmtp", + IpNextProtocol::SecureVmtp => "SecureVmtp", + IpNextProtocol::Vines => "Vines", + IpNextProtocol::TtpOrIptm => "TtpOrIptm", + IpNextProtocol::NsfnetIgp => "NsfnetIgp", + IpNextProtocol::Dgp => "Dgp", + IpNextProtocol::Tcf => "Tcf", + IpNextProtocol::Eigrp => "Eigrp", + IpNextProtocol::OspfigP => "OspfigP", + IpNextProtocol::SpriteRpc => "SpriteRpc", + IpNextProtocol::Larp => "Larp", + IpNextProtocol::Mtp => "Mtp", + IpNextProtocol::Ax25 => "Ax25", + IpNextProtocol::IpIp => "IpIp", + IpNextProtocol::Micp => "Micp", + IpNextProtocol::SccSp => "SccSp", + IpNextProtocol::Etherip => "Etherip", + IpNextProtocol::Encap => "Encap", + IpNextProtocol::PrivEncryption => "PrivEncryption", + IpNextProtocol::Gmtp => "Gmtp", + IpNextProtocol::Ifmp => "Ifmp", + IpNextProtocol::Pnni => "Pnni", + IpNextProtocol::Pim => "Pim", + IpNextProtocol::Aris => "Aris", + IpNextProtocol::Scps => "Scps", + IpNextProtocol::Qnx => "Qnx", + IpNextProtocol::AN => "AN", + IpNextProtocol::IpComp => "IpComp", + IpNextProtocol::Snp => "Snp", + IpNextProtocol::CompaqPeer => "CompaqPeer", + IpNextProtocol::IpxInIp => "IpxInIp", + IpNextProtocol::Vrrp => "Vrrp", + IpNextProtocol::Pgm => "Pgm", + IpNextProtocol::ZeroHop => "ZeroHop", + IpNextProtocol::L2tp => "L2tp", + IpNextProtocol::Ddx => "Ddx", + IpNextProtocol::Iatp => "Iatp", + IpNextProtocol::Stp => "Stp", + IpNextProtocol::Srp => "Srp", + IpNextProtocol::Uti => "Uti", + IpNextProtocol::Smp => "Smp", + IpNextProtocol::Sm => "Sm", + IpNextProtocol::Ptp => "Ptp", + IpNextProtocol::IsisOverIpv4 => "IsisOverIpv4", + IpNextProtocol::Fire => "Fire", + IpNextProtocol::Crtp => "Crtp", + IpNextProtocol::Crudp => "Crudp", + IpNextProtocol::Sscopmce => "Sscopmce", + IpNextProtocol::Iplt => "Iplt", + IpNextProtocol::Sps => "Sps", + IpNextProtocol::Pipe => "Pipe", + IpNextProtocol::Sctp => "Sctp", + IpNextProtocol::Fc => "Fc", + IpNextProtocol::RsvpE2eIgnore => "RsvpE2eIgnore", + IpNextProtocol::MobilityHeader => "MobilityHeader", + IpNextProtocol::UdpLite => "UdpLite", + IpNextProtocol::MplsInIp => "MplsInIp", + IpNextProtocol::Manet => "Manet", + IpNextProtocol::Hip => "Hip", + IpNextProtocol::Shim6 => "Shim6", + IpNextProtocol::Wesp => "Wesp", + IpNextProtocol::Rohc => "Rohc", + _ => "Unknown", + }; + + f.debug_tuple(name).finish() + } +} diff --git a/src/packet/ipv6/mod.rs b/src/packet/ipv6/mod.rs new file mode 100644 index 0000000..5fc81a7 --- /dev/null +++ b/src/packet/ipv6/mod.rs @@ -0,0 +1,176 @@ +// Copyright (c) 2025 Mullvad VPN AB. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use bitfield_struct::bitfield; +use std::{fmt::Debug, net::Ipv6Addr}; +use zerocopy::{big_endian, FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned}; + +use super::{util::size_must_be, IpNextProtocol}; + +/// An IPv6 packet. +/// +/// This is a dynamically sized zerocopy type, which means you can compose packet types like +/// `Ipv6>` and cast them to/from byte slices using [`FromBytes`] and [`IntoBytes`]. +/// [Read more](crate::packet) +#[repr(C)] +#[derive(Debug, FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)] +pub struct Ipv6 { + /// IPv6 header. + pub header: Ipv6Header, + /// IPv6 payload. The type of this is `[u8]` by default, but it may be any zerocopy type, + /// e.g. a `Udp` + pub payload: Payload, +} + +/// An IPv6 header. +#[repr(C, packed)] +#[derive(Clone, Copy, FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq)] +pub struct Ipv6Header { + /// IPv6 `flow_label`, `traffic_class` and `version` fields. + pub version_traffic_flow: Ipv6VersionTrafficFlow, + /// Length of the IPv6 payload. + pub payload_length: big_endian::U16, + /// Protocol of the IPv6 payload. + pub next_header: IpNextProtocol, + /// Maximum number of hops for the IPv6 packet. + pub hop_limit: u8, + /// IPv6 source address. + pub source_address: big_endian::U128, + /// IPv6 destination address. + pub destination_address: big_endian::U128, +} + +/// A bitfield struct containing the IPv6 fields `flow_label`, `traffic_class` and `version`. +#[bitfield(u32, repr = big_endian::U32, from = big_endian::U32::new, into = big_endian::U32::get)] +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq)] +pub struct Ipv6VersionTrafficFlow { + /// IPv6 flow label. + #[bits(20)] + pub flow_label: u32, + /// IPv6 traffic class. + #[bits(8)] + pub traffic_class: u8, + /// IPv6 version. This must be `6`. + #[bits(4)] + pub version: u8, +} + +impl Ipv6Header { + /// Length of an [`Ipv6Header`], in bytes. + #[allow(dead_code)] + pub const LEN: usize = size_must_be::(40); + + /// Get [`version`](Ipv6VersionTrafficFlow::version). This is expected to be `6`. + pub const fn version(&self) -> u8 { + self.version_traffic_flow.version() + } + + /// Get [`traffic_class`](Ipv6VersionTrafficFlow::traffic_class). + pub const fn traffic_class(&self) -> u8 { + self.version_traffic_flow.traffic_class() + } + + /// Get [`flow_label`](Ipv6VersionTrafficFlow::flow_label). + pub const fn flow_label(&self) -> u32 { + self.version_traffic_flow.flow_label() + } + + /// Set [`version`](Ipv6VersionTrafficFlow::version). + // If you're setting it to anything other than `6`, you're probably doing it wrong. + pub const fn set_version(&mut self, version: u8) { + self.version_traffic_flow.set_version(version); + } + + /// Set [`traffic_class`](Ipv6VersionTrafficFlow::traffic_class). + pub const fn set_traffic_class(&mut self, tc: u8) { + self.version_traffic_flow.set_traffic_class(tc); + } + + /// Set [`flow_label`](Ipv6VersionTrafficFlow::flow_label). + pub const fn set_flow_label(&mut self, flow: u32) { + self.version_traffic_flow.set_flow_label(flow); + } + + /// Set [next header protocol](Ipv6Header::next_protocol). + pub const fn next_protocol(&self) -> IpNextProtocol { + self.next_header + } + + /// Get source address. + pub const fn source(&self) -> Ipv6Addr { + let bits = self.source_address.get(); + Ipv6Addr::from_bits(bits) + } + + /// Get destination address. + pub const fn destination(&self) -> Ipv6Addr { + let bits = self.destination_address.get(); + Ipv6Addr::from_bits(bits) + } + + /// Get [`Ipv6Header::payload_length`] plus [`Ipv6Header::LEN`]. + /// This is a [`usize`] because the length might exceed [`u16::MAX`]. + pub const fn total_length(&self) -> usize { + self.payload_length.get() as usize + Ipv6Header::LEN + } +} + +impl Debug for Ipv6Header { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Ipv6Header") + .field("version", &self.version()) + .field("traffic_class", &self.traffic_class()) + .field("flow_label", &self.flow_label()) + .field("payload_length", &self.payload_length.get()) + .field("next_header", &self.next_header) + .field("hop_limit", &self.hop_limit) + .field("source_address", &self.source()) + .field("destination_address", &self.destination()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use zerocopy::FromBytes; + + use super::Ipv6; + use crate::packet::{IpNextProtocol, Ipv6Header}; + use std::{net::Ipv6Addr, str::FromStr}; + + const EXAMPLE_IPV6_ICMP: &[u8] = &[ + 0x60, 0x8, 0xc7, 0xf3, 0x0, 0x40, 0x3a, 0x40, 0xfc, 0x0, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0x1, + 0x0, 0xd, 0x0, 0x0, 0x0, 0xc, 0xc2, 0xdd, 0x26, 0x6, 0x47, 0x0, 0x47, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x11, 0x11, 0x80, 0x0, 0x2d, 0xc5, 0x0, 0x2f, 0x0, 0xb, 0x1c, + 0xa7, 0x87, 0x68, 0x0, 0x0, 0x0, 0x0, 0x35, 0x1b, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x10, 0x11, + 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + ]; + + #[test] + fn ipv6_header_layout() { + let packet = Ipv6::<[u8]>::ref_from_bytes(EXAMPLE_IPV6_ICMP).unwrap(); + let header = &packet.header; + + assert_eq!(header.version(), 6); + assert_eq!(header.traffic_class(), 0); + assert_eq!(header.flow_label(), 0x8c7f3); + assert_eq!(header.payload_length, 64); + assert_eq!(usize::from(header.payload_length), packet.payload.len()); + assert_eq!(header.next_protocol(), IpNextProtocol::Icmpv6); + assert_eq!(header.hop_limit, 64); + assert_eq!( + header.source(), + Ipv6Addr::from_str("fc00:bbbb:bbbb:bb01:d:0:c:c2dd").unwrap(), + ); + assert_eq!( + header.destination(), + Ipv6Addr::from_str("2606:4700:4700::1111").unwrap(), + ); + assert_eq!( + Ipv6Header::LEN + packet.payload.len(), + EXAMPLE_IPV6_ICMP.len(), + ); + } +} diff --git a/src/packet/mod.rs b/src/packet/mod.rs new file mode 100644 index 0000000..018f781 --- /dev/null +++ b/src/packet/mod.rs @@ -0,0 +1,503 @@ +// Copyright (c) 2025 Mullvad VPN AB. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +//! Types to create, parse, and move network packets around in a zero-copy manner. +//! +//! See [`Packet`] for an implementation of a [`bytes`]-backed owned packet buffer. +//! +//! Any of the [`zerocopy`]-enabled definitions such as [`Ipv4`] or [`Udp`] can be used to cheaply +//! construct or parse packets: +//! ``` +//! let example_ipv4_icmp: &mut [u8] = &mut [ +//! 0x45, 0x83, 0x0, 0x54, 0xa3, 0x13, 0x40, 0x0, 0x40, 0x1, 0xc6, 0x26, 0xa, 0x8c, 0xc2, 0xdd, +//! 0x1, 0x2, 0x3, 0x4, 0x8, 0x0, 0x51, 0x13, 0x0, 0x2b, 0x0, 0x1, 0xb1, 0x5c, 0x87, 0x68, 0x0, +//! 0x0, 0x0, 0x0, 0xa8, 0x28, 0x7, 0x0, 0x0, 0x0, 0x0, 0x0, 0x10, 0x11, 0x12, 0x13, 0x14, +//! 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, +//! 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, +//! 0x33, 0x34, 0x35, 0x36, 0x37, +//! ]; +//! +//! use gotatun::packet::{Ipv4, Ipv4Header, IpNextProtocol}; +//! use zerocopy::FromBytes; +//! use std::net::Ipv4Addr; +//! +//! // Cast the `&[u8]` to an &Ipv4. +//! // Note that this doesn't validate anything about the packet, +//! // except that it's at least Ipv4Header::LEN bytes long. +//! let packet = Ipv4::<[u8]>::mut_from_bytes(example_ipv4_icmp) +//! .expect("Packet must be large enough to be IPv4"); +//! let header: &mut Ipv4Header = &mut packet.header; +//! let payload: &mut [u8] = &mut packet.payload; +//! +//! // Read stuff from the IPv4 header +//! assert_eq!(header.version(), 4); +//! assert_eq!(header.source(), Ipv4Addr::new(10, 140, 194, 221)); +//! assert_eq!(header.destination(), Ipv4Addr::new(1, 2, 3, 4)); +//! assert_eq!(header.header_checksum, 0xc626); +//! assert_eq!(header.protocol, IpNextProtocol::Icmp); +//! +//! // Write stuff to the header. Note that this invalidates the checksum. +//! header.time_to_live = 123; +//! +//! // Write stuff to the payload. Note that this clobbers the ICMP packet stored here. +//! payload[..12].copy_from_slice(b"Hello there!"); +//! assert_eq!(&example_ipv4_icmp[20..][..12], b"Hello there!"); +//! ``` + +use std::{ + fmt::{self, Debug}, + marker::PhantomData, + ops::{Deref, DerefMut}, +}; + +use bytes::{Buf, BytesMut}; +use duplicate::duplicate_item; +use either::Either; +use eyre::{bail, eyre, Context}; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned}; + +mod ip; +mod ipv4; +mod ipv6; +mod pool; +mod udp; +mod util; +mod wg; + +pub use ip::*; +pub use ipv4::*; +pub use ipv6::*; +pub use pool::*; +pub use udp::*; +pub use wg::*; + +/// An owned packet of some type. +/// +/// The generic type `Kind` represents the type of packet. +/// For example, a `Packet<[u8]>` is an untyped packet containing arbitrary bytes. +/// It can be safely decoded into a `Packet` using [`Packet::try_into_ip`], +/// and further decoded into a `Packet>` using [`Packet::try_into_udp`]. +/// +/// [`Packet`] uses [`BytesMut`] as the backing buffer. +/// +/// ``` +/// use gotatun::packet::*; +/// use std::net::Ipv4Addr; +/// use zerocopy::IntoBytes; +/// +/// let ip_header = Ipv4Header::new( +/// Ipv4Addr::new(10, 0, 0, 1), +/// Ipv4Addr::new(1, 2, 3, 4), +/// IpNextProtocol::Icmp, +/// &[], +/// ); +/// +/// let ip_header_bytes = ip_header.as_bytes(); +/// +/// let raw_packet: Packet<[u8]> = Packet::copy_from(ip_header_bytes); +/// let ipv4_packet: Packet = raw_packet.try_into_ipvx().unwrap().unwrap_left(); +/// assert_eq!(&ip_header, &ipv4_packet.header); +/// ``` +pub struct Packet { + inner: PacketInner, + + /// Marker type defining what type `Bytes` is. + /// + /// INVARIANT: + /// `buf` must have been ensured to actually contain a packet of this type. + _kind: PhantomData, +} + +struct PacketInner { + buf: BytesMut, + + // If the [BytesMut] was allocated by a [PacketBufPool], this will return the buffer to be re-used later. + _return_to_pool: Option, +} + +/// A marker trait that indicates that a [Packet] contains a valid payload of a specific type. +/// +/// For example, [`CheckedPayload`] is implemented for [`Ipv4<[u8]>`], and a [`Packet>>`] +/// can only be constructed through [`Packet::<[u8]>::try_into_ipvx`], which checks that the IPv4 +/// header is valid. +pub trait CheckedPayload: FromBytes + IntoBytes + KnownLayout + Immutable + Unaligned {} + +impl CheckedPayload for [u8] {} +impl CheckedPayload for Ip {} +impl CheckedPayload for Ipv6

{} +impl CheckedPayload for Ipv4

{} +impl CheckedPayload for Udp

{} +impl CheckedPayload for WgHandshakeInit {} +impl CheckedPayload for WgHandshakeResp {} +impl CheckedPayload for WgCookieReply {} +impl CheckedPayload for WgData {} + +impl Packet { + /// Cast `T` to `Y` without checking anything. + /// + /// Only invoke this after checking that the backing buffer contain a bitwise valid `Y` type. + /// Incorrect usage of this function will cause [`Packet::deref`] to panic. + fn cast(self) -> Packet { + Packet { + inner: self.inner, + _kind: PhantomData::, + } + } + + /// Discard the type of this packet and treat it as a pile of bytes. + pub fn into_bytes(self) -> Packet<[u8]> { + self.cast() + } + + fn buf(&self) -> &[u8] { + &self.inner.buf + } + + /// Create a `Packet` from a `&T`. + pub fn copy_from(payload: &T) -> Self { + Self { + inner: PacketInner { + buf: BytesMut::from(payload.as_bytes()), + _return_to_pool: None, + }, + _kind: PhantomData::, + } + } + + /// Create a `Packet` from a `&Y` by copying its bytes into the backing buffer of this + /// `Packet`. + /// + /// If the `Y` won't fit into the backing buffer, this call will allocate, and effectively + /// devolves into [`Packet::copy_from`]. + pub fn overwrite_with(mut self, payload: &Y) -> Packet { + self.inner.buf.clear(); + self.inner.buf.extend_from_slice(payload.as_bytes()); + self.cast() + } +} + +// Trivial `From`-conversions between packet types +#[duplicate_item( + FromType ToType; + [Ipv4] [Ipv4]; + [Ipv6] [Ipv6]; + + [Ipv4] [Ip]; + [Ipv6] [Ip]; + [Ipv4] [Ip]; + [Ipv6] [Ip]; + + [Ipv4] [[u8]]; + [Ipv6] [[u8]]; + [Ipv4] [[u8]]; + [Ipv6] [[u8]]; + [Ip] [[u8]]; + [WgData] [[u8]]; + [WgHandshakeInit] [[u8]]; + [WgHandshakeResp] [[u8]]; + [WgCookieReply] [[u8]]; +)] +impl From> for Packet { + fn from(value: Packet) -> Packet { + value.cast() + } +} + +impl Default for Packet<[u8]> { + fn default() -> Self { + Self { + inner: PacketInner { + buf: BytesMut::default(), + _return_to_pool: None, + }, + _kind: PhantomData, + } + } +} + +impl Packet<[u8]> { + pub fn new_from_pool(return_to_pool: ReturnToPool, bytes: BytesMut) -> Self { + Self { + inner: PacketInner { + buf: bytes, + _return_to_pool: Some(return_to_pool), + }, + _kind: PhantomData::<[u8]>, + } + } + + /// Create a `Packet::` from a [`BytesMut`]. + pub fn from_bytes(bytes: BytesMut) -> Self { + Self { + inner: PacketInner { + buf: bytes, + _return_to_pool: None, + }, + _kind: PhantomData::<[u8]>, + } + } + + /// See [`BytesMut::truncate`]. + pub fn truncate(&mut self, new_len: usize) { + self.inner.buf.truncate(new_len); + } + + /// Get direct mutable access to the backing buffer. + pub fn buf_mut(&mut self) -> &mut BytesMut { + &mut self.inner.buf + } + + /// Try to cast this untyped packet into an [`Ip`]. + /// + /// This is a stepping stone to casting the packet into an [`Ipv4`] or an [`Ipv6`]. + /// See also [`Packet::try_into_ipvx`]. + /// + /// # Errors + /// + /// Returns [`Err`] if this packet is smaller than [`Ipv4Header::LEN`] bytes. + pub fn try_into_ip(self) -> eyre::Result> { + let buf_len = self.buf().len(); + + // IPv6 packets are larger, but their length after we know the packet IP version. + // This is the smallest any packet can be. + if buf_len < Ipv4Header::LEN { + bail!("Packet too small ({buf_len} < {})", Ipv4Header::LEN); + } + + // we have asserted that the packet is long enough to _maybe_ be an IP packet. + Ok(self.cast::()) + } + + /// Try to cast this untyped packet into either an [`Ipv4`] or [`Ipv6`] packet. + /// + /// The buffer will be truncated to [`Ipv4Header::total_len`] or [`Ipv6Header::total_length`]. + /// + /// # Errors + /// + /// Returns [`Err`] if any of the following checks fail: + /// - The IP version field is `4` or `6` + /// - The packet is smaller than the minimum header length. + /// - The IPv4 packet is smaller than [`Ipv4Header::total_len`]. + /// - The IPv6 payload is smaller than [`Ipv6Header::payload_length`]. + pub fn try_into_ipvx(self) -> eyre::Result, Packet>> { + self.try_into_ip()?.try_into_ipvx() + } +} + +impl Packet { + /// Try to cast this [`Ip`] packet into either an [`Ipv4`] or [`Ipv6`] packet. + /// + /// The buffer will be truncated to [`Ipv4Header::total_len`] or [`Ipv6Header::total_length`]. + /// + /// # Errors + /// + /// Returns [`Err`] if any of the following checks fail: + /// - The IP version field is `4` or `6` + /// - The IPv4 packet is smaller than [`Ipv4Header::total_len`]. + /// - The IPv6 payload is smaller than [`Ipv6Header::payload_length`]. + pub fn try_into_ipvx(mut self) -> eyre::Result, Packet>> { + match self.header.version() { + 4 => { + let buf_len = self.buf().len(); + + let ipv4 = Ipv4::<[u8]>::ref_from_bytes(self.buf()) + .map_err(|e| eyre!("Bad IPv4 packet: {e:?}"))?; + + let ip_len = usize::from(ipv4.header.total_len.get()); + if ip_len > buf_len { + bail!("IPv4 `total_len` exceeded actual packet length: {ip_len} > {buf_len}"); + } + if ip_len < Ipv4Header::LEN { + bail!( + "IPv4 `total_len` less than packet header len: {ip_len} < {}", + Ipv4Header::LEN + ); + } + + self.inner.buf.truncate(ip_len); + + // TODO: validate checksum + + // we have asserted that the packet is a valid IPv4 packet. + // update `_kind` to reflect this. + Ok(Either::Left(self.cast::())) + } + 6 => { + let ipv6 = Ipv6::<[u8]>::ref_from_bytes(self.buf()) + .map_err(|e| eyre!("Bad IPv6 packet: {e:?}"))?; + + let payload_len = usize::from(ipv6.header.payload_length.get()); + if payload_len > ipv6.payload.len() { + bail!( + "IPv6 `payload_len` exceeded actual payload length: {payload_len} > {}", + ipv6.payload.len() + ); + } + + self.inner.buf.truncate(payload_len + Ipv6Header::LEN); + + // TODO: validate checksum + + // we have asserted that the packet is a valid IPv6 packet. + // update `_kind` to reflect this. + Ok(Either::Right(self.cast::())) + } + v => bail!("Bad IP version: {v}"), + } + } +} + +impl Packet { + /// Try to cast this [`Ipv4`] packet into an [`Udp`] packet. + /// + /// Returns `Packet>` if the packet is a valid, + /// non-fragmented IPv4 UDP packet with no options (IHL == `5`). + /// + /// # Errors + /// Returns an error if + /// - the packet is a fragment + /// - the IHL is not `5` + /// - UDP validation fails + pub fn try_into_udp(self) -> eyre::Result>> { + let ip = self.deref(); + + // We validate the IHL here, instead of in the `try_into_ipvx` method, + // because there we can still parse the part of the Ipv4 header that is always present + // and ignore the options. To parse the UDP packet, we must know that the IHL is 5, + // otherwise it will not start at the right offset. + match ip.header.ihl() { + 5 => {} + 6.. => { + return Err(eyre!("IP header: {:?}", ip.header)) + .wrap_err(eyre!("IPv4 packets with options are not supported")); + } + ihl @ ..5 => { + return Err(eyre!("IP header: {:?}", ip.header)) + .wrap_err(eyre!("Bad IHL value: {ihl}")); + } + } + + if ip.header.fragment_offset() != 0 || ip.header.more_fragments() { + eyre::bail!("IPv4 packet is a fragment: {:?}", ip.header); + } + + validate_udp(ip.header.next_protocol(), &ip.payload) + .wrap_err_with(|| eyre!("IP header: {:?}", ip.header))?; + + // we have asserted that the packet is a valid IPv4 UDP packet. + // update `_kind` to reflect this. + Ok(self.cast::>()) + } +} + +impl Packet { + /// Try to cast this [`Ipv6`] packet into an [`Udp`] packet. + /// + /// Returns `Packet>` if the packet is a valid IPv6 UDP packet. + /// + /// # Errors + /// Returns an error if UDP validation fails + pub fn try_into_udp(self) -> eyre::Result>> { + let ip = self.deref(); + + validate_udp(ip.header.next_protocol(), &ip.payload) + .wrap_err_with(|| eyre!("IP header: {:?}", ip.header))?; + + // we have asserted that the packet is a valid IPv6 UDP packet. + // update `_kind` to reflect this. + Ok(self.cast::>()) + } +} + +impl Packet> { + pub fn into_payload(mut self) -> Packet { + debug_assert_eq!( + self.header.ihl() as usize * 4, + Ipv4Header::LEN, + "IPv4 header length must be 20 bytes (IHL = 5)" + ); + self.inner.buf.advance(Ipv4Header::LEN); + self.cast::() + } +} +impl Packet> { + pub fn into_payload(mut self) -> Packet { + self.inner.buf.advance(Ipv6Header::LEN); + self.cast::() + } +} +impl Packet> { + pub fn into_payload(mut self) -> Packet { + self.inner.buf.advance(UdpHeader::LEN); + self.cast::() + } +} + +fn validate_udp(next_protocol: IpNextProtocol, payload: &[u8]) -> eyre::Result<()> { + let IpNextProtocol::Udp = next_protocol else { + bail!("Expected UDP, but packet was {next_protocol:?}"); + }; + + let ip_payload_len = payload.len(); + let udp = Udp::<[u8]>::ref_from_bytes(payload).map_err(|e| eyre!("Bad UDP packet: {e:?}"))?; + + let udp_len = usize::from(udp.header.length.get()); + if udp_len != ip_payload_len { + return Err(eyre!("UDP header: {:?}", udp.header)).wrap_err_with(|| { + eyre!( + "UDP header length did not match IP payload length: {} != {}", + udp_len, + ip_payload_len, + ) + }); + } + + // TODO: validate checksum? + + Ok(()) +} + +impl Deref for Packet +where + Kind: CheckedPayload + ?Sized, +{ + type Target = Kind; + + fn deref(&self) -> &Self::Target { + Self::Target::ref_from_bytes(&self.inner.buf) + .expect("We have previously checked that the payload is valid") + } +} + +impl DerefMut for Packet +where + Kind: CheckedPayload + ?Sized, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + Self::Target::mut_from_bytes(&mut self.inner.buf) + .expect("We have previously checked that the payload is valid") + } +} + +// This clone implementation is only for tests, as the clone will cause an allocation and will not return the buffer to the pool. +#[cfg(test)] +impl Clone for Packet { + fn clone(&self) -> Self { + Self { + inner: PacketInner { + buf: self.inner.buf.clone(), + _return_to_pool: None, // Clone does not return to pool + }, + _kind: PhantomData, + } + } +} + +impl Debug for Packet +where + Kind: CheckedPayload + ?Sized, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Packet").field(&self.deref()).finish() + } +} diff --git a/src/packet/pool.rs b/src/packet/pool.rs new file mode 100644 index 0000000..41d7554 --- /dev/null +++ b/src/packet/pool.rs @@ -0,0 +1,163 @@ +// Copyright (c) 2025 Mullvad VPN AB. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use bytes::BytesMut; +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; + +use crate::packet::Packet; + +/// A pool of packet buffers. +#[derive(Clone)] +pub struct PacketBufPool { + queue: Arc>>, + capacity: usize, +} + +impl PacketBufPool { + /// Create a new [`PacketBufPool`] with space for at least `capacity` packets, + /// each allocated with a capacity of `N` bytes. + pub fn new(capacity: usize) -> Self { + let mut queue = VecDeque::with_capacity(capacity); + + // pre-allocate contiguous backing buffer + let mut backing_buffer = BytesMut::zeroed(N * capacity); + for _ in 0..capacity { + let buf = backing_buffer.split_to(N).split_to(0); + queue.push_back(buf); + } + + PacketBufPool { + queue: Arc::new(Mutex::new(queue)), + capacity, + } + } + + pub fn capacity(&self) -> usize { + self.capacity + } + + /// Get a new [`Packet`] from the pool. + /// + /// This will try to re-use an already allocated packet if possible, or allocate one otherwise. + pub fn get(&self) -> Packet<[u8]> { + while let Some(mut pointer_to_start_of_allocation) = + { self.queue.lock().unwrap().pop_front() } + { + debug_assert_eq!(pointer_to_start_of_allocation.len(), 0); + if pointer_to_start_of_allocation.try_reclaim(N) { + let mut buf = pointer_to_start_of_allocation.split_off(0); + + debug_assert!(buf.capacity() >= N); + + // SAFETY: + // - buf was split from the BytesMut allocated below. + // - buf has not been mutated, and still points to the original allocation. + // - try_reclaim succeeded, so the capacity is at least `N`. + // - the allocation was created using `BytesMut::zeroed`, so the bytes are initialized. + unsafe { buf.set_len(N) }; + + let return_to_pool = ReturnToPool { + pointer_to_start_of_allocation: Some(pointer_to_start_of_allocation), + queue: self.queue.clone(), + }; + + return Packet::new_from_pool(return_to_pool, buf); + } else { + // Backing buffer is still in use. Someone probably called split_* on it. + continue; + } + } + + let mut buf = BytesMut::zeroed(N); + let pointer_to_start_of_allocation = buf.split_to(0); + + debug_assert_eq!(pointer_to_start_of_allocation.len(), 0); + debug_assert_eq!(buf.len(), N); + + let return_to_pool = ReturnToPool { + pointer_to_start_of_allocation: Some(pointer_to_start_of_allocation), + queue: self.queue.clone(), + }; + + Packet::new_from_pool(return_to_pool, buf) + } +} + +/// This sends a previously allocated [`BytesMut`] back to [`PacketBufPool`] when its dropped. +pub struct ReturnToPool { + /// This is a pointer to the allocation allocated by [`PacketBufPool::get`]. + /// By making sure we never modify this (by calling reserve, etc), we can efficiently re-use this allocation later. + /// + /// INVARIANT: + /// - Points to the start of an `N`-sized allocation. + // Note: Option is faster than mem::take + pointer_to_start_of_allocation: Option, + queue: Arc>>, +} + +impl Drop for ReturnToPool { + fn drop(&mut self) { + let p = self.pointer_to_start_of_allocation.take().unwrap(); + let mut queue_g = self.queue.lock().unwrap(); + if queue_g.len() < queue_g.capacity() { + // Add the packet back to the pool unless we're at capacity + queue_g.push_back(p); + } + } +} + +#[cfg(test)] +mod tests { + use std::{hint::black_box, thread}; + + use super::PacketBufPool; + + /// Test buffer recycle semantics of [PacketBufPool]. + #[test] + fn pool_buffer_recycle() { + let pool = PacketBufPool::<4096>::new(1); + + for i in 0..10 { + // Get a packet and record its address. + let mut packet1 = black_box(pool.get()); + let packet1_addr = packet1.buf().as_ptr(); + + // Mutate the packet for good measure + let data = format!("Hello there. x{i}\nGeneral Kenobi! You are a bold one."); + let data = data.as_bytes(); + packet1.truncate(data.len()); + packet1.copy_from_slice(data); + + // Drop the packet, allowing it to be re-used. + // Do it on another thread for good measure. + thread::spawn(move || drop(packet1)).join().unwrap(); + + // Get another packet. This should be the same as packet1. + let packet2 = black_box(pool.get()); + let packet2_addr = packet2.buf().as_ptr(); + + // Get a third packet. + // Since we're still holding packet2, this will result in an allocation. + let packet3 = black_box(pool.get()); + let packet3_addr = packet3.buf().as_ptr(); + + assert!( + packet2.starts_with(data), + "old data should remain in the recycled buffer", + ); + + assert!( + !packet3.starts_with(data), + "old data should not exist in the new buffer", + ); + + assert_eq!(packet1_addr, packet2_addr); + assert_ne!(packet1_addr, packet3_addr); + + drop((packet2, packet3)); + } + } +} diff --git a/src/packet/udp.rs b/src/packet/udp.rs new file mode 100644 index 0000000..379f610 --- /dev/null +++ b/src/packet/udp.rs @@ -0,0 +1,54 @@ +// Copyright (c) 2025 Mullvad VPN AB. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use std::fmt; + +use zerocopy::{big_endian, FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned}; + +use super::util::size_must_be; + +/// A UDP packet. +/// +/// This is a dynamically sized zerocopy type, which means you can compose packet types like +/// `Ipv6>` and cast them to/from byte slices using [`FromBytes`] and [`IntoBytes`]. +/// [Read more](crate::packet) +#[repr(C)] +#[derive(Debug, FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)] +pub struct Udp { + /// UDP header. + pub header: UdpHeader, + /// UDP payload. The type of this is `[u8]` by default, but it may be any zerocopy type, + /// e.g. a `WgData` + pub payload: Payload, +} + +/// A UDP header. +#[repr(C, packed)] +#[derive(Clone, Copy, FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)] +pub struct UdpHeader { + /// UDP source port. + pub source_port: big_endian::U16, + /// UDP destination port. + pub destination_port: big_endian::U16, + /// Length of the UDP packet (including header) in bytes. + pub length: big_endian::U16, + /// Checksum of the UDP packet + pub checksum: big_endian::U16, +} + +impl UdpHeader { + /// Length of a [`UdpHeader`], in bytes. + #[allow(dead_code)] + pub const LEN: usize = size_must_be::(8); +} + +impl fmt::Debug for UdpHeader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("UdpHeader") + .field("source_port", &self.source_port.get()) + .field("destination_port", &self.destination_port.get()) + .field("length", &self.length.get()) + .field("checksum", &self.checksum.get()) + .finish() + } +} diff --git a/src/packet/util.rs b/src/packet/util.rs new file mode 100644 index 0000000..37395a0 --- /dev/null +++ b/src/packet/util.rs @@ -0,0 +1,13 @@ +// Copyright (c) 2025 Mullvad VPN AB. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +/// Check that the size of type `T` is `size`. If not, panic. +/// +/// Returns `size` for convenience. +pub const fn size_must_be(size: usize) -> usize { + if size_of::() == size { + size + } else { + panic!("Size of T is wrong!") + } +} diff --git a/src/packet/wg.rs b/src/packet/wg.rs new file mode 100644 index 0000000..acf42b0 --- /dev/null +++ b/src/packet/wg.rs @@ -0,0 +1,507 @@ +// Copyright (c) 2025 Mullvad VPN AB. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +#![deny(clippy::unwrap_used)] +use std::fmt::{self, Debug}; +use std::mem::offset_of; +use std::ops::Deref; + +use eyre::{bail, eyre}; +use zerocopy::{little_endian, FromBytes, FromZeros, Immutable, IntoBytes, KnownLayout, Unaligned}; + +use crate::packet::util::size_must_be; +use crate::packet::{CheckedPayload, Packet}; + +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)] +#[repr(C, packed)] +struct Wg { + pub packet_type: WgPacketType, + rest: [u8], +} + +impl Debug for Wg { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Wg") + .field("packet_type", &self.packet_type) + .finish() + } +} + +/// An owned WireGuard [`Packet`] whose [`WgPacketType`] is known. See [`Packet::try_into_wg`]. +pub enum WgKind { + /// An owned [`WgHandshakeInit`] packet. + HandshakeInit(Packet), + + /// An owned [`WgHandshakeResp`] packet. + HandshakeResp(Packet), + + /// An owned [`WgCookieReply`] packet. + CookieReply(Packet), + + /// An owned [`WgData`] packet. + Data(Packet), +} + +impl Debug for WgKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::HandshakeInit(_) => f.debug_tuple("HandshakeInit").finish(), + Self::HandshakeResp(_) => f.debug_tuple("HandshakeResp").finish(), + Self::CookieReply(_) => f.debug_tuple("CookieReply").finish(), + Self::Data(_) => f.debug_tuple("Data").finish(), + } + } +} + +impl From> for WgKind { + fn from(p: Packet) -> Self { + WgKind::HandshakeInit(p) + } +} + +impl From> for WgKind { + fn from(p: Packet) -> Self { + WgKind::HandshakeResp(p) + } +} + +impl From> for WgKind { + fn from(p: Packet) -> Self { + WgKind::CookieReply(p) + } +} + +impl From> for WgKind { + fn from(p: Packet) -> Self { + WgKind::Data(p) + } +} + +impl From for Packet { + fn from(kind: WgKind) -> Self { + match kind { + WgKind::HandshakeInit(packet) => packet.into(), + WgKind::HandshakeResp(packet) => packet.into(), + WgKind::CookieReply(packet) => packet.into(), + WgKind::Data(packet) => packet.into(), + } + } +} + +/// The first byte of a WireGuard packet. This identifies its type. +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq, Clone, Copy)] +#[repr(transparent)] +pub struct WgPacketType(pub u8); + +impl WgPacketType { + #![allow(non_upper_case_globals)] + + /// The type discriminant of a [`WgHandshakeInit`] packet. + pub const HandshakeInit: WgPacketType = WgPacketType(1); + + /// The type discriminant of a [`WgHandshakeResp`] packet. + pub const HandshakeResp: WgPacketType = WgPacketType(2); + + /// The type discriminant of a [`WgCookieReply`] packet. + pub const CookieReply: WgPacketType = WgPacketType(3); + + /// The type discriminant of a [`WgData`] packet. + pub const Data: WgPacketType = WgPacketType(4); +} + +/// Header of [`WgData`]. +/// See section 5.4.6 of the [whitepaper](https://www.wireguard.com/papers/wireguard.pdf). +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)] +#[repr(C)] +pub struct WgDataHeader { + // INVARIANT: Must be WgPacketType::Data + packet_type: WgPacketType, + _reserved_zeros: [u8; 4 - size_of::()], + + /// An integer that identifies the WireGuard session for the receiving peer. + pub receiver_idx: little_endian::U32, + + /// A counter that must be incremented for every data packet to prevent replay attacks. + pub counter: little_endian::U64, +} + +impl WgDataHeader { + /// Header length + pub const LEN: usize = size_must_be::(16); +} + +/// WireGuard data packet. +/// See section 5.4.6 of the [whitepaper](https://www.wireguard.com/papers/wireguard.pdf). +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)] +#[repr(C, packed)] +pub struct WgData { + /// Data packet header. + pub header: WgDataHeader, + + /// Data packet payload and tag. + pub encrypted_encapsulated_packet_and_tag: WgDataAndTag, +} + +/// WireGuard data payload with a trailing tag. +/// +/// This is essentially a byte slice that is at least [`WgData::TAG_LEN`] long. +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)] +#[repr(C)] +pub struct WgDataAndTag { + // Don't access these field directly. The tag is actually at the end of the struct. + _tag_size: [u8; WgData::TAG_LEN], + _extra: [u8], +} + +/// An encrypted value with an attached Poly1305 authentication tag. +#[derive(Clone, Copy, FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq)] +#[repr(C)] +pub struct EncryptedWithTag { + pub encrypted: T, + pub tag: [u8; 16], +} + +impl WgData { + /// Data packet overhead: header and tag (16 bytes) + pub const OVERHEAD: usize = WgDataHeader::LEN + WgData::TAG_LEN; + + /// Length of the trailing `tag` field, in bytes. + pub const TAG_LEN: usize = 16; + + /// Strip the tag from the encapsulated packet. + fn split_encapsulated_packet_and_tag(&mut self) -> (&mut [u8], &mut [u8; WgData::TAG_LEN]) { + self.encrypted_encapsulated_packet_and_tag + .split_last_chunk_mut::<{ WgData::TAG_LEN }>() + .expect("WgDataAndTag is at least TAG_LEN bytes long") + } + + /// Get a reference to the encapsulated packet, without the trailing tag. + pub fn encrypted_encapsulated_packet_mut(&mut self) -> &mut [u8] { + let (encrypted_encapsulated_packet, _) = self.split_encapsulated_packet_and_tag(); + encrypted_encapsulated_packet + } + + /// Get a reference to the tag of the encapsulated packet. + /// + /// Returns None if if the encapsulated packet + tag is less than 16 bytes. + pub fn tag_mut(&mut self) -> &mut [u8; WgData::TAG_LEN] { + let (_, tag) = self.split_encapsulated_packet_and_tag(); + tag + } + + /// Returns true if the payload is empty. + pub const fn is_empty(&self) -> bool { + self.encrypted_encapsulated_packet_and_tag._extra.is_empty() + } + + /// [`Self::is_empty`]. Keepalive packets are just data packets with no payload. + pub const fn is_keepalive(&self) -> bool { + self.is_empty() + } +} + +impl WgDataHeader { + /// Construct a [`WgDataHeader`] where all fields except `packet_type` are zeroed. + pub fn new() -> Self { + Self { + packet_type: WgPacketType::Data, + ..WgDataHeader::new_zeroed() + } + } + + /// Set `receiver_idx`. + pub const fn with_receiver_idx(mut self, receiver_idx: u32) -> Self { + self.receiver_idx = little_endian::U32::new(receiver_idx); + self + } + + /// Set `counter`. + pub const fn with_counter(mut self, counter: u64) -> Self { + self.counter = little_endian::U64::new(counter); + self + } +} + +impl Default for WgDataHeader { + fn default() -> Self { + Self::new() + } +} + +impl Deref for WgDataAndTag { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.as_bytes() + } +} + +impl std::ops::DerefMut for WgDataAndTag { + fn deref_mut(&mut self) -> &mut Self::Target { + self.as_mut_bytes() + } +} + +/// Trait for fields common to both [`WgHandshakeInit`] and [`WgHandshakeResp`]. +pub trait WgHandshakeBase: + FromBytes + IntoBytes + KnownLayout + Unaligned + Immutable + CheckedPayload +{ + /// Length of the handshake packet, in bytes. + const LEN: usize; + + /// Offset of the `mac1` field. + /// This is used for getting a byte slice up until `mac1`, i.e. `&packet[..MAC1_OFF]`. + const MAC1_OFF: usize; + + /// Offset of the `mac2` field. + /// This is used for getting a byte slice up until `mac2`, i.e. `&packet[..MAC2_OFF]`. + const MAC2_OFF: usize; + + /// Get `sender_id`. + fn sender_idx(&self) -> u32; + + /// Get a mutable reference to `mac1`. + fn mac1_mut(&mut self) -> &mut [u8; 16]; + + /// Get a mutable reference to `mac2`. + fn mac2_mut(&mut self) -> &mut [u8; 16]; + + /// Get `mac1`. + fn mac1(&self) -> &[u8; 16]; + + /// Get `mac2`. + fn mac2(&self) -> &[u8; 16]; + + /// Get packet until MAC1. Precisely equivalent to `packet[0..offsetof(packet.mac1)]`. + #[inline(always)] + fn until_mac1(&self) -> &[u8] { + &self.as_bytes()[..Self::MAC1_OFF] + } + + /// Get packet until MAC2. Precisely equivalent to `packet[0..offsetof(packet.mac2)]`. + #[inline(always)] + fn until_mac2(&self) -> &[u8] { + &self.as_bytes()[..Self::MAC2_OFF] + } +} + +/// WireGuard handshake initialization packet. +/// See section 5.4.2 of the [whitepaper](https://www.wireguard.com/papers/wireguard.pdf). +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)] +#[repr(C, packed)] +pub struct WgHandshakeInit { + // INVARIANT: Must be WgPacketType::HandshakeInit + packet_type: WgPacketType, + _reserved_zeros: [u8; 4 - size_of::()], + + /// An integer that identifies the WireGuard session for the initiating peer. + pub sender_idx: little_endian::U32, + + /// Ephemeral public key of the initiating peer. + pub unencrypted_ephemeral: [u8; 32], + + /// Encrypted static public key. + pub encrypted_static: EncryptedWithTag<[u8; 32]>, + + /// A TAI64N timestamp. Used to avoid replay attacks. + pub timestamp: EncryptedWithTag<[u8; 12]>, + + /// Message authentication code 1. + pub mac1: [u8; 16], + + /// Message authentication code 2. + pub mac2: [u8; 16], +} + +impl WgHandshakeInit { + /// Length of the packet, in bytes. + pub const LEN: usize = size_must_be::(148); + + /// Construct a [`WgHandshakeInit`] where all fields except `packet_type` are zeroed. + pub fn new() -> Self { + Self { + packet_type: WgPacketType::HandshakeInit, + ..WgHandshakeInit::new_zeroed() + } + } +} + +impl WgHandshakeBase for WgHandshakeInit { + const LEN: usize = Self::LEN; + const MAC1_OFF: usize = offset_of!(Self, mac1); + const MAC2_OFF: usize = offset_of!(Self, mac2); + + fn sender_idx(&self) -> u32 { + self.sender_idx.get() + } + + fn mac1_mut(&mut self) -> &mut [u8; 16] { + &mut self.mac1 + } + + fn mac2_mut(&mut self) -> &mut [u8; 16] { + &mut self.mac2 + } + + fn mac1(&self) -> &[u8; 16] { + &self.mac1 + } + + fn mac2(&self) -> &[u8; 16] { + &self.mac2 + } +} + +impl Default for WgHandshakeInit { + fn default() -> Self { + Self::new() + } +} + +/// WireGuard handshake response packet. +/// See section 5.4.3 of the [whitepaper](https://www.wireguard.com/papers/wireguard.pdf). +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)] +#[repr(C, packed)] +pub struct WgHandshakeResp { + // INVARIANT: Must be WgPacketType::HandshakeResp + packet_type: WgPacketType, + _reserved_zeros: [u8; 4 - size_of::()], + + /// An integer that identifies the WireGuard session for the responding peer. + pub sender_idx: little_endian::U32, + + /// An integer that identifies the WireGuard session for the initiating peer. + pub receiver_idx: little_endian::U32, + + /// Ephemeral public key of the responding peer. + pub unencrypted_ephemeral: [u8; 32], + + /// A Poly1305 authentication tag generated from an empty message. + pub encrypted_nothing: EncryptedWithTag<()>, + + /// Message authentication code 1. + pub mac1: [u8; 16], + + /// Message authentication code 2. + pub mac2: [u8; 16], +} + +impl WgHandshakeResp { + /// Length of the packet, in bytes. + pub const LEN: usize = size_must_be::(92); + + /// Construct a [`WgHandshakeResp`]. + pub fn new(sender_idx: u32, receiver_idx: u32, unencrypted_ephemeral: [u8; 32]) -> Self { + Self { + packet_type: WgPacketType::HandshakeResp, + _reserved_zeros: [0; 3], + sender_idx: sender_idx.into(), + receiver_idx: receiver_idx.into(), + unencrypted_ephemeral, + encrypted_nothing: EncryptedWithTag::new_zeroed(), + mac1: [0u8; 16], + mac2: [0u8; 16], + } + } +} + +impl WgHandshakeBase for WgHandshakeResp { + const LEN: usize = Self::LEN; + const MAC1_OFF: usize = offset_of!(Self, mac1); + const MAC2_OFF: usize = offset_of!(Self, mac2); + + fn sender_idx(&self) -> u32 { + self.sender_idx.get() + } + + fn mac1_mut(&mut self) -> &mut [u8; 16] { + &mut self.mac1 + } + + fn mac2_mut(&mut self) -> &mut [u8; 16] { + &mut self.mac2 + } + + fn mac1(&self) -> &[u8; 16] { + &self.mac1 + } + + fn mac2(&self) -> &[u8; 16] { + &self.mac2 + } +} + +/// WireGuard cookie reply packet. +/// See section 5.4.7 of the [whitepaper](https://www.wireguard.com/papers/wireguard.pdf). +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)] +#[repr(C, packed)] +pub struct WgCookieReply { + // INVARIANT: Must be WgPacketType::CookieReply + packet_type: WgPacketType, + _reserved_zeros: [u8; 4 - size_of::()], + + /// An integer that identifies the WireGuard session for the handshake-initiating peer. + pub receiver_idx: little_endian::U32, + + /// Number only used once. + pub nonce: [u8; 24], + + /// An encrypted 16-byte value that identifies the [`WgHandshakeInit`] that this packet is in response to. + /// Plus a 16 byte Poly1305 authentication tag. + pub encrypted_cookie: EncryptedWithTag<[u8; 16]>, +} + +impl WgCookieReply { + /// Length of the packet, in bytes. + pub const LEN: usize = size_must_be::(64); + + /// Construct a [`WgCookieReply`] where all fields except `packet_type` are zeroed. + pub fn new() -> Self { + Self { + packet_type: WgPacketType::CookieReply, + ..Self::new_zeroed() + } + } +} + +impl Default for WgCookieReply { + fn default() -> Self { + Self::new() + } +} + +impl Packet { + /// Try to cast to a WireGuard packet while sanity-checking packet type and size. + pub fn try_into_wg(self) -> eyre::Result { + let wg = Wg::ref_from_bytes(self.as_bytes()) + .map_err(|_| eyre!("Not a wireguard packet, too small."))?; + + let len = wg.as_bytes().len(); + match (wg.packet_type, len) { + (WgPacketType::HandshakeInit, WgHandshakeInit::LEN) => { + Ok(WgKind::HandshakeInit(self.cast())) + } + (WgPacketType::HandshakeResp, WgHandshakeResp::LEN) => { + Ok(WgKind::HandshakeResp(self.cast())) + } + (WgPacketType::CookieReply, WgCookieReply::LEN) => Ok(WgKind::CookieReply(self.cast())), + (WgPacketType::Data, WgData::OVERHEAD..) => Ok(WgKind::Data(self.cast())), + _ => bail!("Not a wireguard packet, bad type/size."), + } + } +} + +impl Debug for WgPacketType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let name = match self { + &WgPacketType::HandshakeInit => "HandshakeInit", + &WgPacketType::HandshakeResp => "HandshakeResp", + &WgPacketType::CookieReply => "CookieReply", + &WgPacketType::Data => "Data", + + WgPacketType(t) => return Debug::fmt(t, f), + }; + + f.debug_tuple(name).finish() + } +} diff --git a/src/sleepyinstant/mod.rs b/src/sleepyinstant/mod.rs index 542beea..2af7b67 100644 --- a/src/sleepyinstant/mod.rs +++ b/src/sleepyinstant/mod.rs @@ -1,5 +1,5 @@ #![forbid(unsafe_code)] -//! Attempts to provide the same functionality as std::time::Instant, except it +//! Attempts to provide the same functionality as `std::time::Instant`, except it //! uses a timer which accounts for time when the system is asleep use std::time::Duration;