Skip to content

Commit 06c3c42

Browse files
committed
Refactor IncompleteMessage
1 parent c0a099e commit 06c3c42

1 file changed

Lines changed: 21 additions & 118 deletions

File tree

src/protocol/message.rs

Lines changed: 21 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -3,150 +3,53 @@ use crate::{
33
error::{CapacityError, Error, Result},
44
protocol::frame::Utf8Bytes,
55
};
6+
use bytes::{Bytes, BytesMut};
67
use std::{fmt, result::Result as StdResult, str};
78

8-
mod string_collect {
9-
use utf8::DecodeError;
10-
11-
use crate::error::{Error, Result};
12-
13-
#[derive(Debug)]
14-
pub struct StringCollector {
15-
data: String,
16-
incomplete: Option<utf8::Incomplete>,
17-
}
18-
19-
impl StringCollector {
20-
pub fn new() -> Self {
21-
StringCollector { data: String::new(), incomplete: None }
22-
}
23-
24-
pub fn len(&self) -> usize {
25-
self.data
26-
.len()
27-
.saturating_add(self.incomplete.map(|i| i.buffer_len as usize).unwrap_or(0))
28-
}
29-
30-
pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<()> {
31-
let mut input: &[u8] = tail.as_ref();
32-
33-
if let Some(mut incomplete) = self.incomplete.take() {
34-
if let Some((result, rest)) = incomplete.try_complete(input) {
35-
input = rest;
36-
match result {
37-
Ok(text) => self.data.push_str(text),
38-
Err(result_bytes) => {
39-
return Err(Error::Utf8(String::from_utf8_lossy(result_bytes).into()))
40-
}
41-
}
42-
} else {
43-
input = &[];
44-
self.incomplete = Some(incomplete);
45-
}
46-
}
47-
48-
if !input.is_empty() {
49-
match utf8::decode(input) {
50-
Ok(text) => {
51-
self.data.push_str(text);
52-
Ok(())
53-
}
54-
Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => {
55-
self.data.push_str(valid_prefix);
56-
self.incomplete = Some(incomplete_suffix);
57-
Ok(())
58-
}
59-
Err(DecodeError::Invalid { valid_prefix, invalid_sequence, .. }) => {
60-
self.data.push_str(valid_prefix);
61-
Err(Error::Utf8(String::from_utf8_lossy(invalid_sequence).into()))
62-
}
63-
}
64-
} else {
65-
Ok(())
66-
}
67-
}
68-
69-
pub fn into_string(self) -> Result<String> {
70-
if let Some(incomplete) = self.incomplete {
71-
Err(Error::Utf8(format!("incomplete string: {incomplete:?}")))
72-
} else {
73-
Ok(self.data)
74-
}
75-
}
76-
}
77-
}
78-
79-
use self::string_collect::StringCollector;
80-
use bytes::Bytes;
81-
829
/// A struct representing the incomplete message.
10+
///
11+
/// Note: Text messages are utf8 validated on calling [`Self::complete`].
8312
#[derive(Debug)]
84-
pub struct IncompleteMessage {
85-
collector: IncompleteMessageCollector,
86-
}
87-
88-
#[derive(Debug)]
89-
enum IncompleteMessageCollector {
90-
Text(StringCollector),
91-
Binary(Vec<u8>),
13+
pub(crate) struct IncompleteMessage {
14+
kind: MessageType,
15+
buf: BytesMut,
9216
}
9317

9418
impl IncompleteMessage {
95-
/// Create new.
96-
pub fn new(message_type: MessageType) -> Self {
97-
IncompleteMessage {
98-
collector: match message_type {
99-
MessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
100-
MessageType::Text => IncompleteMessageCollector::Text(StringCollector::new()),
101-
},
102-
}
103-
}
104-
105-
/// Get the current filled size of the buffer.
106-
pub fn len(&self) -> usize {
107-
match self.collector {
108-
IncompleteMessageCollector::Text(ref t) => t.len(),
109-
IncompleteMessageCollector::Binary(ref b) => b.len(),
110-
}
19+
pub fn new(kind: MessageType) -> Self {
20+
Self { kind, buf: BytesMut::new() }
11121
}
11222

11323
/// Add more data to an existing message.
114-
pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T, size_limit: Option<usize>) -> Result<()> {
24+
pub fn extend(&mut self, tail: Bytes, size_limit: Option<usize>) -> Result<()> {
11525
// Always have a max size. This ensures an error in case of concatenating two buffers
116-
// of more than `usize::max_value()` bytes in total.
117-
let max_size = size_limit.unwrap_or_else(usize::max_value);
118-
let my_size = self.len();
119-
let portion_size = tail.as_ref().len();
26+
// of more than `usize::MAX` bytes in total.
27+
let max_size = size_limit.unwrap_or(usize::MAX);
28+
let my_size = self.buf.len();
29+
let portion_size = tail.len();
12030
// Be careful about integer overflows here.
12131
if my_size > max_size || portion_size > max_size - my_size {
12232
return Err(Error::Capacity(CapacityError::MessageTooLong {
123-
size: my_size + portion_size,
33+
size: my_size.saturating_add(portion_size),
12434
max_size,
12535
}));
12636
}
12737

128-
match self.collector {
129-
IncompleteMessageCollector::Binary(ref mut v) => {
130-
v.extend(tail.as_ref());
131-
Ok(())
132-
}
133-
IncompleteMessageCollector::Text(ref mut t) => t.extend(tail),
134-
}
38+
self.buf.extend_from_slice(&tail);
39+
Ok(())
13540
}
13641

13742
/// Convert an incomplete message into a complete one.
13843
pub fn complete(self) -> Result<Message> {
139-
match self.collector {
140-
IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v.into())),
141-
IncompleteMessageCollector::Text(t) => {
142-
let text = t.into_string()?;
143-
Ok(Message::text(text))
144-
}
145-
}
44+
Ok(match self.kind {
45+
MessageType::Binary => Message::Binary(self.buf.freeze()),
46+
MessageType::Text => Message::Text(self.buf.try_into()?),
47+
})
14648
}
14749
}
14850

14951
/// The type of incomplete message.
52+
#[derive(Debug)]
15053
pub enum MessageType {
15154
Text,
15255
Binary,

0 commit comments

Comments
 (0)