Skip to content

Commit 8b44d12

Browse files
committed
Add deflate feature
1 parent c62eccc commit 8b44d12

11 files changed

Lines changed: 131 additions & 57 deletions

File tree

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ before_script:
1010

1111
script:
1212
- cargo test --release
13+
- cargo test --release --features=deflate
1314
- echo "Running Autobahn TestSuite for client" && ./scripts/autobahn-client.sh
1415
- echo "Running Autobahn TestSuite for server" && ./scripts/autobahn-server.sh

Cargo.toml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,21 @@ native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
2323
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
2424
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
2525
__rustls-tls = ["rustls", "webpki"]
26+
deflate = ["flate2"]
27+
# deflate-zlib = ["flate2/zlib"]
28+
29+
[[example]]
30+
name = "autobahn-client"
31+
required-features = ["deflate"]
32+
33+
[[example]]
34+
name = "autobahn-server"
35+
required-features = ["deflate"]
2636

2737
[dependencies]
2838
base64 = "0.13.0"
2939
byteorder = "1.3.2"
3040
bytes = "1.0"
31-
flate2 = "1.0"
3241
http = "0.2"
3342
httparse = "1.3.4"
3443
log = "0.4.8"
@@ -38,6 +47,10 @@ thiserror = "1.0.23"
3847
url = "2.1.0"
3948
utf-8 = "0.7.5"
4049

50+
[dependencies.flate2]
51+
optional = true
52+
version = "1.0"
53+
4154
[dependencies.native-tls-crate]
4255
optional = true
4356
package = "native-tls"

examples/srv_accept_unmasked_frames.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ fn main() {
3535
// rare cases where it is necessary to integrate with existing/legacy
3636
// clients which are sending unmasked frames
3737
accept_unmasked_frames: true,
38-
..WebSocketConfig::default()
38+
#[cfg(feature = "deflate")]
39+
compression: None,
3940
});
4041

4142
let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap();

scripts/autobahn-client.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,5 @@ docker run -d --rm \
3232
wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json'
3333

3434
sleep 3
35-
cargo run --release --example autobahn-client
35+
cargo run --release --example autobahn-client --features=deflate
3636
test_diff

scripts/autobahn-server.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function test_diff() {
2222
fi
2323
}
2424

25-
cargo run --release --example autobahn-server & WSSERVER_PID=$!
25+
cargo run --release --example autobahn-server --features=deflate & WSSERVER_PID=$!
2626
sleep 3
2727

2828
docker run --rm \

src/error.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22
33
use std::{io, result, str, string};
44

5-
use crate::{
6-
extensions,
7-
protocol::{frame::coding::Data, Message},
8-
};
5+
use crate::protocol::{frame::coding::Data, Message};
96
use http::Response;
107
use thiserror::Error;
118

@@ -71,8 +68,9 @@ pub enum Error {
7168
#[error("HTTP format error: {0}")]
7269
HttpFormat(#[from] http::Error),
7370
/// Error from `permessage-deflate` extension.
71+
#[cfg(feature = "deflate")]
7472
#[error("Deflate error: {0}")]
75-
Deflate(#[from] extensions::DeflateError),
73+
Deflate(#[from] crate::extensions::DeflateError),
7674
}
7775

7876
impl From<str::Utf8Error> for Error {

src/extensions/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
//! WebSocket extensions.
22
// Only `permessage-deflate` is supported at the moment.
3+
use http::HeaderValue;
34

5+
#[cfg(feature = "deflate")]
46
mod compression;
7+
#[cfg(feature = "deflate")]
58
use compression::deflate::DeflateContext;
9+
#[cfg(feature = "deflate")]
610
pub use compression::deflate::{DeflateConfig, DeflateError};
7-
use http::HeaderValue;
811

912
/// Container for configured extensions.
1013
#[derive(Debug, Default)]
14+
#[allow(missing_copy_implementations)]
1115
pub struct Extensions {
1216
// Per-Message Compression. Only `permessage-deflate` is supported.
17+
#[cfg(feature = "deflate")]
1318
pub(crate) compression: Option<DeflateContext>,
1419
}
1520

src/handshake/client.rs

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use super::{
1717
};
1818
use crate::{
1919
error::{Error, ProtocolError, Result, UrlError},
20-
extensions::{self, Extensions},
20+
extensions::Extensions,
2121
protocol::{Role, WebSocket, WebSocketConfig},
2222
};
2323

@@ -161,7 +161,7 @@ impl VerifyData {
161161
pub fn verify_response(
162162
&self,
163163
response: Response,
164-
config: &Option<WebSocketConfig>,
164+
_config: &Option<WebSocketConfig>,
165165
) -> Result<(Response, Option<Extensions>)> {
166166
// 1. If the status code received from the server is not 101, the
167167
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
@@ -202,43 +202,62 @@ impl VerifyData {
202202
if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) {
203203
return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch));
204204
}
205-
let mut extensions = None;
206205
// 5. If the response includes a |Sec-WebSocket-Extensions| header
207206
// field and this header field indicates the use of an extension
208207
// that was not present in the client's handshake (the server has
209208
// indicated an extension not requested by the client), the client
210209
// MUST _Fail the WebSocket Connection_. (RFC 6455)
211210
let mut extensions_values = headers.get_all("Sec-WebSocket-Extensions").iter();
212-
if let Some(value) = extensions_values.next() {
211+
let extensions = if let Some(value) = extensions_values.next() {
213212
if extensions_values.next().is_some() {
214213
return Err(Error::Protocol(ProtocolError::MultipleExtensionsHeaderInResponse));
215214
}
216215

217-
let mut exts = extensions::iter_all(std::iter::once(value));
218-
if let Some(compression) = &config.and_then(|c| c.compression) {
219-
for (name, params) in exts {
220-
if name != compression.name() {
216+
let mut exts = crate::extensions::iter_all(std::iter::once(value));
217+
#[cfg(feature = "deflate")]
218+
{
219+
let mut extensions = None;
220+
if let Some(config) = _config {
221+
if let Some(compression) = config.compression {
222+
for (name, params) in exts {
223+
if name != compression.name() {
224+
return Err(Error::Protocol(ProtocolError::InvalidExtension(
225+
name.to_string(),
226+
)));
227+
}
228+
229+
// Already had PMCE configured
230+
if extensions.is_some() {
231+
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
232+
name.to_string(),
233+
)));
234+
}
235+
236+
extensions = Some(Extensions {
237+
compression: Some(compression.accept_response(params)?),
238+
});
239+
}
240+
} else if let Some((name, _)) = exts.next() {
241+
// The client didn't request anything, but got something
221242
return Err(Error::Protocol(ProtocolError::InvalidExtension(
222243
name.to_string(),
223244
)));
224245
}
246+
}
247+
extensions
248+
}
225249

226-
// Already had PMCE configured
227-
if extensions.is_some() {
228-
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
229-
name.to_string(),
230-
)));
231-
}
232-
233-
extensions = Some(Extensions {
234-
compression: Some(compression.accept_response(params)?),
235-
});
250+
#[cfg(not(feature = "deflate"))]
251+
{
252+
if let Some((name, _)) = exts.next() {
253+
// The client didn't request anything, but got something
254+
return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string())));
236255
}
237-
} else if let Some((name, _)) = exts.next() {
238-
// The client didn't request anything, but got something
239-
return Err(Error::Protocol(ProtocolError::InvalidExtension(name.to_string())));
256+
None
240257
}
241-
}
258+
} else {
259+
None
260+
};
242261

243262
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
244263
// and this header field indicates the use of a subprotocol that was
@@ -292,7 +311,9 @@ fn generate_key() -> String {
292311
#[cfg(test)]
293312
mod tests {
294313
use super::{super::machine::TryParse, generate_key, generate_request, Response};
295-
use crate::{client::IntoClientRequest, extensions::DeflateConfig, protocol::WebSocketConfig};
314+
use crate::client::IntoClientRequest;
315+
#[cfg(feature = "deflate")]
316+
use crate::{extensions::DeflateConfig, protocol::WebSocketConfig};
296317

297318
#[test]
298319
fn random_keys() {
@@ -361,6 +382,7 @@ mod tests {
361382
assert_eq!(&request[..], &correct[..]);
362383
}
363384

385+
#[cfg(feature = "deflate")]
364386
#[test]
365387
fn request_with_compression() {
366388
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();

src/protocol/frame/frame.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ impl Frame {
306306

307307
/// Create a new compressed data frame.
308308
#[inline]
309+
#[cfg(feature = "deflate")]
309310
pub(crate) fn compressed_message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
310311
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
311312

src/protocol/message.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ impl IncompleteMessage {
107107
}
108108
}
109109

110+
#[cfg(feature = "deflate")]
110111
pub fn compressed(&self) -> bool {
111112
self.compressed
112113
}

0 commit comments

Comments
 (0)