Skip to content

Commit 2e404a4

Browse files
committed
Use headers
1 parent 491dad9 commit 2e404a4

6 files changed

Lines changed: 69 additions & 177 deletions

File tree

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ sha-1 = "0.10"
4747
thiserror = "1.0.23"
4848
url = "2.1.0"
4949
utf-8 = "0.7.5"
50+
headers = "0.3.7"
5051

5152
[dependencies.flate2]
5253
optional = true
@@ -83,3 +84,6 @@ rand = "0.8.4"
8384
[[bench]]
8485
name = "buffer"
8586
harness = false
87+
88+
[patch.crates-io]
89+
headers = { git = "https://github.com/kazk/headers", branch = "sec-websocket-extensions" }

src/error.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,9 @@ pub enum ProtocolError {
234234
/// The negotiation response included an extension more than once.
235235
#[error("Extension negotiation response had conflicting extension: {0}")]
236236
ExtensionConflict(String),
237+
/// The `Sec-WebSocket-Extensions` header is invalid.
238+
#[error("Invalid \"Sec-WebSocket-Extensions\" header")]
239+
InvalidExtensionsHeader,
237240
}
238241

239242
/// Indicates the specific type/cause of URL error.

src/extensions/mod.rs

Lines changed: 0 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
//! WebSocket extensions.
22
// Only `permessage-deflate` is supported at the moment.
3-
use http::HeaderValue;
43

54
#[cfg(feature = "deflate")]
65
mod compression;
@@ -17,128 +16,3 @@ pub struct Extensions {
1716
#[cfg(feature = "deflate")]
1817
pub(crate) compression: Option<DeflateContext>,
1918
}
20-
21-
/// Iterator of all extension offers/responses in `Sec-WebSocket-Extensions` values.
22-
pub(crate) fn iter_all<'a>(
23-
values: impl Iterator<Item = &'a HeaderValue>,
24-
) -> impl Iterator<Item = (&'a str, impl Iterator<Item = (&'a str, Option<&'a str>)>)> {
25-
values
26-
.filter_map(|h| h.to_str().ok())
27-
.map(|value_str| {
28-
split_iter(value_str, ',').filter_map(|offer| {
29-
// Parameters are separted by semicolons.
30-
// The first element is the name of the extension.
31-
let mut iter = split_iter(offer.trim(), ';').map(str::trim);
32-
let name = iter.next()?;
33-
let params = iter.filter_map(|kv| {
34-
let mut it = kv.splitn(2, '=');
35-
let key = it.next()?.trim();
36-
let val = it.next().map(|v| v.trim().trim_matches('"'));
37-
Some((key, val))
38-
});
39-
Some((name, params))
40-
})
41-
})
42-
.flatten()
43-
}
44-
45-
fn split_iter(input: &str, sep: char) -> impl Iterator<Item = &str> {
46-
let mut in_quotes = false;
47-
let mut prev = None;
48-
input.split(move |c| {
49-
if in_quotes {
50-
if c == '"' && prev != Some('\\') {
51-
in_quotes = false;
52-
}
53-
prev = Some(c);
54-
false
55-
} else if c == sep {
56-
prev = Some(c);
57-
true
58-
} else {
59-
if c == '"' {
60-
in_quotes = true;
61-
}
62-
prev = Some(c);
63-
false
64-
}
65-
})
66-
}
67-
68-
#[cfg(test)]
69-
mod tests {
70-
use http::{header::SEC_WEBSOCKET_EXTENSIONS, HeaderMap};
71-
72-
use super::*;
73-
74-
// Make sure comma separated offers and multiple headers are equivalent
75-
fn test_iteration<'a>(
76-
mut iter: impl Iterator<Item = (&'a str, impl Iterator<Item = (&'a str, Option<&'a str>)>)>,
77-
) {
78-
let (name, mut params) = iter.next().unwrap();
79-
assert_eq!(name, "permessage-deflate");
80-
assert_eq!(params.next(), Some(("client_max_window_bits", None)));
81-
assert_eq!(params.next(), Some(("server_max_window_bits", Some("10"))));
82-
assert!(params.next().is_none());
83-
84-
let (name, mut params) = iter.next().unwrap();
85-
assert_eq!(name, "permessage-deflate");
86-
assert_eq!(params.next(), Some(("client_max_window_bits", None)));
87-
assert!(params.next().is_none());
88-
89-
assert!(iter.next().is_none());
90-
}
91-
92-
#[test]
93-
fn iter_single() {
94-
let mut hm = HeaderMap::new();
95-
hm.append(
96-
SEC_WEBSOCKET_EXTENSIONS,
97-
HeaderValue::from_static(
98-
"permessage-deflate; client_max_window_bits; server_max_window_bits=10, permessage-deflate; client_max_window_bits",
99-
),
100-
);
101-
test_iteration(iter_all(std::iter::once(hm.get(SEC_WEBSOCKET_EXTENSIONS).unwrap())));
102-
}
103-
104-
#[test]
105-
fn iter_multiple() {
106-
let mut hm = HeaderMap::new();
107-
hm.append(
108-
SEC_WEBSOCKET_EXTENSIONS,
109-
HeaderValue::from_static(
110-
"permessage-deflate; client_max_window_bits; server_max_window_bits=10",
111-
),
112-
);
113-
hm.append(
114-
SEC_WEBSOCKET_EXTENSIONS,
115-
HeaderValue::from_static("permessage-deflate; client_max_window_bits"),
116-
);
117-
test_iteration(iter_all(hm.get_all(SEC_WEBSOCKET_EXTENSIONS).iter()));
118-
}
119-
}
120-
121-
// TODO More strict parsing
122-
// https://datatracker.ietf.org/doc/html/rfc6455#section-4.3
123-
// Sec-WebSocket-Extensions = extension-list
124-
// extension-list = 1#extension
125-
// extension = extension-token *( ";" extension-param )
126-
// extension-token = registered-token
127-
// registered-token = token
128-
// extension-param = token [ "=" (token | quoted-string) ]
129-
// ;When using the quoted-string syntax variant, the value
130-
// ;after quoted-string unescaping MUST conform to the
131-
// ;'token' ABNF.
132-
//
133-
// token = 1*<any CHAR except CTLs or separators>
134-
// CHAR = <any US-ASCII character (octets 0 - 127)>
135-
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
136-
// separators = "(" | ")" | "<" | ">" | "@"
137-
// | "," | ";" | ":" | "\" | <">
138-
// | "/" | "[" | "]" | "?" | "="
139-
// | "{" | "}" | SP | HT
140-
// SP = <US-ASCII SP, space (32)>
141-
// HT = <US-ASCII HT, horizontal-tab (9)>
142-
// quoted-string = ( <"> *(qdtext | quoted-pair ) <"> )
143-
// qdtext = <any TEXT except <">>
144-
// quoted-pair = "\" CHAR

src/handshake/client.rs

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::{
55
marker::PhantomData,
66
};
77

8+
use headers::{HeaderMapExt, SecWebsocketExtensions};
89
use http::{
910
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
1011
};
@@ -234,44 +235,53 @@ impl VerifyData {
234235
// indicated an extension not requested by the client), the client
235236
// MUST _Fail the WebSocket Connection_. (RFC 6455)
236237
let extensions = {
237-
// Note that multiple headers are allowed in response. See https://www.rfc-editor.org/errata/eid3433
238-
let mut agreed_extensions =
239-
crate::extensions::iter_all(headers.get_all("Sec-WebSocket-Extensions").iter());
240-
#[cfg(feature = "deflate")]
238+
if let Some(agreed_extensions) = headers
239+
.typed_try_get::<SecWebsocketExtensions>()
240+
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
241241
{
242-
let mut extensions = None;
243-
if let Some(compression) = _config.and_then(|c| c.compression) {
244-
for (name, params) in agreed_extensions {
245-
if name != compression.name() {
246-
return Err(Error::Protocol(ProtocolError::InvalidExtension(
247-
name.to_string(),
248-
)));
242+
let mut agreed_extensions = agreed_extensions.iter();
243+
#[cfg(feature = "deflate")]
244+
{
245+
let mut extensions = None;
246+
if let Some(compression) = _config.and_then(|c| c.compression) {
247+
for extension in agreed_extensions {
248+
if extension.name() != compression.name() {
249+
return Err(Error::Protocol(ProtocolError::InvalidExtension(
250+
extension.name().to_string(),
251+
)));
252+
}
253+
254+
// Already had PMCE configured
255+
if extensions.is_some() {
256+
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
257+
extension.name().to_string(),
258+
)));
259+
}
260+
261+
extensions = Some(Extensions {
262+
compression: Some(compression.accept_response(extension.params())?),
263+
});
249264
}
250-
251-
// Already had PMCE configured
252-
if extensions.is_some() {
253-
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
254-
name.to_string(),
255-
)));
256-
}
257-
258-
extensions = Some(Extensions {
259-
compression: Some(compression.accept_response(params)?),
260-
});
265+
} else if let Some(extension) = agreed_extensions.next() {
266+
// The client didn't request anything, but got something
267+
return Err(Error::Protocol(ProtocolError::InvalidExtension(
268+
extension.name().to_string(),
269+
)));
261270
}
262-
} else if let Some((name, _)) = agreed_extensions.next() {
263-
// The client didn't request anything, but got something
264-
return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string())));
271+
extensions
265272
}
266-
extensions
267-
}
268273

269-
#[cfg(not(feature = "deflate"))]
270-
{
271-
if let Some((name, _)) = agreed_extensions.next() {
272-
// The client didn't request anything, but got something
273-
return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string())));
274+
#[cfg(not(feature = "deflate"))]
275+
{
276+
if let Some(extension) = agreed_extensions.next() {
277+
// The client didn't request anything, but got something
278+
return Err(Error::Protocol(ProtocolError::InvalidExtension(
279+
extension.name().to_string(),
280+
)));
281+
}
282+
None
274283
}
284+
} else {
275285
None
276286
}
277287
};

src/handshake/server.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::{
66
result::Result as StdResult,
77
};
88

9+
use headers::{HeaderMapExt, SecWebsocketExtensions};
910
use http::{
1011
response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
1112
};
@@ -246,10 +247,15 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
246247

247248
let mut response = create_response(&result)?;
248249
if let Some(config) = &self.config {
249-
let values = result.headers().get_all("Sec-WebSocket-Extensions").iter();
250-
if let Some((agreed, extensions)) = config.accept_offers(values) {
251-
response.headers_mut().insert("Sec-WebSocket-Extensions", agreed);
252-
self.extensions = Some(extensions);
250+
if let Some(values) = result
251+
.headers()
252+
.typed_try_get::<SecWebsocketExtensions>()
253+
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
254+
{
255+
if let Some((agreed, extensions)) = config.accept_offers(values) {
256+
response.headers_mut().insert("Sec-WebSocket-Extensions", agreed);
257+
self.extensions = Some(extensions);
258+
}
253259
}
254260
}
255261

src/protocol/mod.rs

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -92,25 +92,20 @@ impl WebSocketConfig {
9292

9393
// This can be used with `WebSocket::from_raw_socket_with_extensions` for integration.
9494
/// Returns negotiation response based on offers and `Extensions` to manage extensions.
95-
pub fn accept_offers<'a>(
96-
&'a self,
97-
_extensions: impl Iterator<Item = &'a HeaderValue>,
95+
pub fn accept_offers(
96+
&self,
97+
_extensions: headers::SecWebsocketExtensions,
9898
) -> Option<(HeaderValue, Extensions)> {
9999
#[cfg(feature = "deflate")]
100100
{
101101
if let Some(compression) = &self.compression {
102-
let extensions = crate::extensions::iter_all(_extensions);
103-
let offers =
104-
extensions.filter_map(
105-
|(k, v)| {
106-
if k == compression.name() {
107-
Some(v)
108-
} else {
109-
None
110-
}
111-
},
112-
);
113-
102+
let offers = _extensions.iter().filter_map(|extension| {
103+
if extension.name() == compression.name() {
104+
Some(extension.params())
105+
} else {
106+
None
107+
}
108+
});
114109
// To support more extensions, store extension context in `Extensions` and
115110
// concatenate negotiation responses from each extension.
116111
compression

0 commit comments

Comments
 (0)