Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ byteorder = "1.3.2"
bytes = "1.0"
http = "0.2"
httparse = "1.3.4"
input_buffer = "0.4.0"
log = "0.4.8"
rand = "0.8.0"
sha-1 = "0.9"
Expand All @@ -53,5 +52,12 @@ optional = true
version = "0.5.0"

[dev-dependencies]
criterion = "0.3.4"
env_logger = "0.8.1"
input_buffer = "0.5.0"
net2 = "0.2.33"
rand = "0.8.4"

[[bench]]
name = "buffer"
harness = false
36 changes: 36 additions & 0 deletions benches/buffer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use std::io::{Cursor, Read};

use criterion::*;
use input_buffer::InputBuffer;
use tungstenite::buffer::ReadBuffer;

const CHUNK_SIZE: usize = 4096;

#[inline]
fn current_input_buffer(mut stream: impl Read) {
    // Pull the entire stream through the crate's current `InputBuffer`,
    // stopping once `read_from` reports EOF (a zero-byte read).
    let mut buf = InputBuffer::with_capacity(CHUNK_SIZE);
    loop {
        let read = buf.read_from(&mut stream).unwrap();
        if read == 0 {
            break;
        }
    }
}

#[inline]
fn fast_input_buffer(mut stream: impl Read) {
    // Drain the stream to EOF through the new const-generic `ReadBuffer`,
    // mirroring the loop used for the legacy buffer above.
    let mut buf = ReadBuffer::<CHUNK_SIZE>::new();
    loop {
        let read = buf.read_from(&mut stream).unwrap();
        if read == 0 {
            break;
        }
    }
}

/// Compare throughput of the legacy `InputBuffer` against the new `ReadBuffer`
/// when draining a 4 MiB stream of random bytes.
fn benchmark(c: &mut Criterion) {
    const STREAM_SIZE: usize = 1024 * 1024 * 4;
    let data: Vec<u8> = (0..STREAM_SIZE).map(|_| rand::random()).collect();
    // Wrap a *borrowed slice* in the cursor: `stream.clone()` below then only
    // copies a slice reference per iteration instead of the whole 4 MiB
    // vector, so setup cost is kept out of the measured numbers.
    let stream = Cursor::new(data.as_slice());

    let mut group = c.benchmark_group("buffers");
    group.throughput(Throughput::Bytes(STREAM_SIZE as u64));
    group.bench_function("InputBuffer", |b| {
        b.iter(|| current_input_buffer(black_box(stream.clone())))
    });
    group.bench_function("ReadBuffer", |b| b.iter(|| fast_input_buffer(black_box(stream.clone()))));
    group.finish();
}

criterion_group!(benches, benchmark);
criterion_main!(benches);
119 changes: 119 additions & 0 deletions src/buffer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
//! A buffer for reading data from the network.
//!
//! The `ReadBuffer` is a buffer of bytes similar to a first-in, first-out queue.
//! It is filled by reading from a stream supporting `Read` and is then
//! accessible as a cursor for reading bytes.

use std::io::{Cursor, Read, Result as IoResult};

use bytes::Buf;

/// A FIFO buffer for reading packets from the network.
///
/// Data is appended to an internal vector via a fixed-size stack scratch
/// chunk; the cursor position marks how much of the vector has already been
/// consumed, and the consumed prefix is drained lazily before each new read.
#[derive(Debug)]
pub struct ReadBuffer<const CHUNK_SIZE: usize> {
    // Bytes read so far; the cursor position tracks the consumed prefix.
    storage: Cursor<Vec<u8>>,
    // Scratch space filled by each call to `Read::read`.
    chunk: [u8; CHUNK_SIZE],
}

impl<const CHUNK_SIZE: usize> ReadBuffer<CHUNK_SIZE> {
    /// Create a new empty input buffer.
    pub fn new() -> Self {
        Self::with_capacity(CHUNK_SIZE)
    }

    /// Create a new empty input buffer with a given `capacity` (in bytes)
    /// pre-allocated for the internal storage.
    pub fn with_capacity(capacity: usize) -> Self {
        Self::from_partially_read(Vec::with_capacity(capacity))
    }

    /// Create an input buffer filled with previously read data.
    pub fn from_partially_read(part: Vec<u8>) -> Self {
        Self { storage: Cursor::new(part), chunk: [0; CHUNK_SIZE] }
    }

    /// Get a cursor to the data storage.
    pub fn as_cursor(&self) -> &Cursor<Vec<u8>> {
        &self.storage
    }

    /// Get a cursor to the mutable data storage.
    pub fn as_cursor_mut(&mut self) -> &mut Cursor<Vec<u8>> {
        &mut self.storage
    }

    /// Consume the `ReadBuffer` and get the internal storage.
    pub fn into_vec(mut self) -> Vec<u8> {
        // Current implementation of `tungstenite-rs` expects that `into_vec()`
        // drains the data from the container that has already been read by the
        // cursor.
        self.clean_up();

        // Now we can safely return the internal container.
        self.storage.into_inner()
    }

    /// Read the next portion of data from the given input stream.
    ///
    /// Returns the number of bytes read — `0` signals EOF, matching the
    /// `Read::read` contract — or propagates the underlying I/O error.
    pub fn read_from<S: Read>(&mut self, stream: &mut S) -> IoResult<usize> {
        self.clean_up();
        let size = stream.read(&mut self.chunk)?;
        self.storage.get_mut().extend_from_slice(&self.chunk[..size]);
        Ok(size)
    }

    /// Cleans up the part of the vector that has already been read by the cursor.
    fn clean_up(&mut self) {
        let pos = self.storage.position() as usize;
        // Dropping the `Drain` removes the consumed prefix; iterating it with
        // `.count()` (as before) did needless extra work.
        self.storage.get_mut().drain(0..pos);
        self.storage.set_position(0);
    }
}

impl<const CHUNK_SIZE: usize> Default for ReadBuffer<CHUNK_SIZE> {
    /// An empty buffer with `CHUNK_SIZE` bytes of pre-allocated storage.
    fn default() -> Self {
        Self::new()
    }
}

/// Delegate all `Buf` operations to the inner cursor (`bytes` implements
/// `Buf` for `Cursor<Vec<u8>>`), so the buffer itself can be consumed as a
/// byte source.
impl<const CHUNK_SIZE: usize> Buf for ReadBuffer<CHUNK_SIZE> {
    fn remaining(&self) -> usize {
        self.as_cursor().remaining()
    }

    fn chunk(&self) -> &[u8] {
        self.as_cursor().chunk()
    }

    fn advance(&mut self, cnt: usize) {
        self.as_cursor_mut().advance(cnt)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn simple_reading() {
        // A single read pulls everything that fits into one chunk.
        let mut input = Cursor::new(b"Hello World!".to_vec());
        let mut buffer = ReadBuffer::<4096>::new();
        let size = buffer.read_from(&mut input).unwrap();
        assert_eq!(size, 12);
        assert_eq!(buffer.chunk(), b"Hello World!");
    }

    #[test]
    fn reading_in_chunks() {
        let mut input = Cursor::new(b"Hello World!".to_vec());
        let mut buffer = ReadBuffer::<4>::new();

        let size = buffer.read_from(&mut input).unwrap();
        assert_eq!(size, 4);
        assert_eq!(buffer.chunk(), b"Hell");

        // Advancing moves the cursor but does not yet shrink the storage.
        buffer.advance(2);
        assert_eq!(buffer.chunk(), b"ll");
        assert_eq!(buffer.storage.get_mut(), b"Hell");

        // The next read drains the consumed prefix before appending new data.
        let size = buffer.read_from(&mut input).unwrap();
        assert_eq!(size, 4);
        assert_eq!(buffer.chunk(), b"llo Wo");
        assert_eq!(buffer.storage.get_mut(), b"llo Wo");

        let size = buffer.read_from(&mut input).unwrap();
        assert_eq!(size, 4);
        assert_eq!(buffer.chunk(), b"llo World!");
    }
}
3 changes: 2 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ mod encryption {
Mode::Tls => {
let config = {
let mut config = ClientConfig::new();
config.root_store = rustls_native_certs::load_native_certs().map_err(|(_, err)| err)?;
config.root_store =
rustls_native_certs::load_native_certs().map_err(|(_, err)| err)?;

Arc::new(config)
};
Expand Down
5 changes: 0 additions & 5 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ pub enum CapacityError {
#[error("Too many headers")]
TooManyHeaders,
/// Received header is too long.
#[error("Header too long")]
HeaderTooLong,
/// Message is bigger than the maximum allowed size.
#[error("Message too long: {size} > {max_size}")]
MessageTooLong {
Expand All @@ -137,9 +135,6 @@ pub enum CapacityError {
/// The maximum allowed message size.
max_size: usize,
},
/// TCP buffer is full.
#[error("Incoming TCP buffer is full")]
TcpBufferFull,
}

/// Indicates the specific type/cause of a protocol error.
Expand Down
18 changes: 5 additions & 13 deletions src/handshake/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use log::*;
use std::io::{Cursor, Read, Write};

use crate::{
error::{CapacityError, Error, ProtocolError, Result},
error::{Error, ProtocolError, Result},
util::NonBlockingResult,
ReadBuffer,
};
use input_buffer::{InputBuffer, MIN_READ};

/// A generic handshake state machine.
#[derive(Debug)]
Expand All @@ -18,10 +18,7 @@ pub struct HandshakeMachine<Stream> {
impl<Stream> HandshakeMachine<Stream> {
/// Start reading data from the peer.
pub fn start_read(stream: Stream) -> Self {
HandshakeMachine {
stream,
state: HandshakeState::Reading(InputBuffer::with_capacity(MIN_READ)),
}
HandshakeMachine { stream, state: HandshakeState::Reading(ReadBuffer::new()) }
}
/// Start writing data to the peer.
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
Expand All @@ -43,12 +40,7 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
trace!("Doing handshake round.");
match self.state {
HandshakeState::Reading(mut buf) => {
let read = buf
.prepare_reserve(MIN_READ)
.with_limit(usize::max_value()) // TODO limit size
.map_err(|_| Error::Capacity(CapacityError::HeaderTooLong))?
.read_from(&mut self.stream)
.no_block()?;
let read = buf.read_from(&mut self.stream).no_block()?;
match read {
Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)),
Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
Expand Down Expand Up @@ -124,7 +116,7 @@ pub trait TryParse: Sized {
#[derive(Debug)]
enum HandshakeState {
/// Reading data from the peer.
Reading(InputBuffer),
Reading(ReadBuffer),
/// Sending data to the peer.
Writing(Cursor<Vec<u8>>),
}
5 changes: 1 addition & 4 deletions src/handshake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,6 @@ mod tests {
#[test]
fn key_conversion() {
// example from RFC 6455
assert_eq!(
derive_accept_key(b"dGhlIHNhbXBsZSBub25jZQ=="),
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
);
assert_eq!(derive_accept_key(b"dGhlIHNhbXBsZSBub25jZQ=="), "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
}
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

pub use http;

pub mod buffer;
pub mod client;
pub mod error;
pub mod handshake;
Expand All @@ -22,6 +23,9 @@ pub mod server;
pub mod stream;
pub mod util;

const READ_BUFFER_CHUNK_SIZE: usize = 4096;
type ReadBuffer = buffer::ReadBuffer<READ_BUFFER_CHUNK_SIZE>;

pub use crate::{
client::{client, connect},
error::{Error, Result},
Expand Down
28 changes: 11 additions & 17 deletions src/protocol/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ pub mod coding;
mod frame;
mod mask;

pub use self::frame::{CloseFrame, Frame, FrameHeader};
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};

use crate::error::{CapacityError, Error, Result};
use input_buffer::{InputBuffer, MIN_READ};
use log::*;
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};

pub use self::frame::{CloseFrame, Frame, FrameHeader};
use crate::{
error::{CapacityError, Error, Result},
ReadBuffer,
};

/// A reader and writer for WebSocket frames.
#[derive(Debug)]
Expand Down Expand Up @@ -82,7 +85,7 @@ where
#[derive(Debug)]
pub(super) struct FrameCodec {
/// Buffer to read data from the stream.
in_buffer: InputBuffer,
in_buffer: ReadBuffer,
/// Buffer to send packets to the network.
out_buffer: Vec<u8>,
/// Header and remaining size of the incoming packet being processed.
Expand All @@ -92,17 +95,13 @@ pub(super) struct FrameCodec {
impl FrameCodec {
/// Create a new frame codec.
pub(super) fn new() -> Self {
Self {
in_buffer: InputBuffer::with_capacity(MIN_READ),
out_buffer: Vec::new(),
header: None,
}
Self { in_buffer: ReadBuffer::new(), out_buffer: Vec::new(), header: None }
}

/// Create a new frame codec from partially read data.
pub(super) fn from_partially_read(part: Vec<u8>) -> Self {
Self {
in_buffer: InputBuffer::from_partially_read(part),
in_buffer: ReadBuffer::from_partially_read(part),
out_buffer: Vec::new(),
header: None,
}
Expand Down Expand Up @@ -152,12 +151,7 @@ impl FrameCodec {
}

// Not enough data in buffer.
let size = self
.in_buffer
.prepare_reserve(MIN_READ)
.with_limit(usize::max_value())
.map_err(|_| Error::Capacity(CapacityError::TcpBufferFull))?
.read_from(stream)?;
let size = self.in_buffer.read_from(stream)?;
if size == 0 {
trace!("no frame received");
return Ok(None);
Expand Down