Skip to content

Commit 8c980ea

Browse files
authored
io: add write_all_vectored to tokio-util (#7768)
1 parent e35fd6d commit 8c980ea

3 files changed

Lines changed: 309 additions & 0 deletions

File tree

tokio-util/src/io/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ mod reader_stream;
1717
pub mod simplex;
1818
mod sink_writer;
1919
mod stream_reader;
20+
mod write_all_vectored;
2021

2122
cfg_io_util! {
2223
mod read_arc;
@@ -32,4 +33,5 @@ pub use self::read_buf::read_buf;
3233
pub use self::reader_stream::ReaderStream;
3334
pub use self::sink_writer::SinkWriter;
3435
pub use self::stream_reader::StreamReader;
36+
pub use self::write_all_vectored::{write_all_vectored, WriteAllVectored};
3537
pub use crate::util::{poll_read_buf, poll_write_buf};
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
use tokio::io::AsyncWrite;
2+
3+
use pin_project_lite::pin_project;
4+
use std::marker::PhantomPinned;
5+
use std::pin::Pin;
6+
use std::task::{ready, Context, Poll};
7+
use std::{future::Future, io::IoSlice};
8+
use std::{io, mem};
9+
10+
pin_project! {
11+
/// A future that writes all data from multiple buffers to a writer.
12+
#[derive(Debug)]
13+
#[must_use = "futures do nothing unless you `.await` or poll them"]
14+
pub struct WriteAllVectored<'a, 'b, W: ?Sized> {
15+
writer: &'a mut W,
16+
bufs: &'a mut [IoSlice<'b>],
17+
// Make this future `!Unpin` for compatibility with async trait methods.
18+
#[pin]
19+
_pin: PhantomPinned,
20+
}
21+
}
22+
/// Like [`write_all`] but writes all data from multiple buffers into this writer.
23+
///
24+
/// This function writes multiple (possibly non-contiguous) buffers into the writer,
25+
/// using the `writev` syscall to potentially write in a single system call.
26+
///
27+
/// Equivalent to:
28+
///
29+
/// ```ignore
30+
/// async fn write_all_vectored<W: AsyncWrite + Unpin + ?Sized>(
31+
/// writer: &mut W,
32+
/// mut bufs: &mut [IoSlice<'_>]
33+
/// ) -> io::Result<()> {
34+
/// while !bufs.is_empty() {
35+
/// let n = write_vectored(writer, bufs).await?;
36+
/// if n == 0 {
37+
/// return Err(io::ErrorKind::WriteZero.into());
38+
/// }
39+
/// IoSlice::advance_slices(&mut bufs, n);
40+
/// }
41+
/// Ok(())
42+
/// }
43+
/// ```
44+
///
45+
/// # Cancel safety
46+
///
47+
/// This method is not cancellation safe. If it is used as the event
48+
/// in a `tokio::select!` statement and some other
49+
/// branch completes first, then the provided buffer may have been
50+
/// partially written, but future calls to `write_all_vectored` will
51+
/// have lost its place in the buffer.
52+
///
53+
/// # Examples
54+
///
55+
/// ```rust
56+
/// use tokio_util::io::write_all_vectored;
57+
/// use std::io::IoSlice;
58+
///
59+
/// #[tokio::main(flavor = "current_thread")]
60+
/// async fn main() -> std::io::Result<()> {
61+
///
62+
/// let mut writer = Vec::new();
63+
/// let bufs = &mut [
64+
/// IoSlice::new(&[1]),
65+
/// IoSlice::new(&[2, 3]),
66+
/// IoSlice::new(&[4, 5, 6]),
67+
/// ];
68+
///
69+
/// write_all_vectored(&mut writer, bufs).await?;
70+
///
71+
/// // Note: `bufs` has been modified by `IoSlice::advance_slices` and should not be reused.
72+
/// assert_eq!(writer, &[1, 2, 3, 4, 5, 6]);
73+
/// Ok(())
74+
/// }
75+
/// ```
76+
///
77+
/// # Notes
78+
///
79+
/// See the documentation for [`Write::write_all_vectored`] from std.
80+
/// After calling this function, the buffer slices may have
81+
/// been advanced and should not be reused.
82+
///
83+
/// [`Write::write_all_vectored`]: std::io::Write::write_all_vectored
84+
/// [`write_all`]: tokio::io::AsyncWriteExt::write_all
85+
/// [`writev`]: https://man7.org/linux/man-pages/man3/writev.3p.html
86+
pub fn write_all_vectored<'a, 'b, W>(
87+
writer: &'a mut W,
88+
bufs: &'a mut [IoSlice<'b>],
89+
) -> WriteAllVectored<'a, 'b, W>
90+
where
91+
W: AsyncWrite + Unpin + ?Sized,
92+
{
93+
WriteAllVectored {
94+
writer,
95+
bufs,
96+
_pin: PhantomPinned,
97+
}
98+
}
99+
100+
impl<W> Future for WriteAllVectored<'_, '_, W>
101+
where
102+
W: AsyncWrite + Unpin + ?Sized,
103+
{
104+
type Output = io::Result<()>;
105+
106+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
107+
let me = self.project();
108+
while !me.bufs.is_empty() {
109+
// advance to first non-empty buffer
110+
let non_empty = match me.bufs.iter().position(|b| !b.is_empty()) {
111+
Some(pos) => pos,
112+
None => return Poll::Ready(Ok(())),
113+
};
114+
115+
// drop empty buffers at the start
116+
*me.bufs = &mut mem::take(me.bufs)[non_empty..];
117+
118+
let n = ready!(Pin::new(&mut *me.writer).poll_write_vectored(cx, me.bufs))?;
119+
if n == 0 {
120+
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
121+
}
122+
self::advance_slices(me.bufs, n);
123+
}
124+
125+
Poll::Ready(Ok(()))
126+
}
127+
}
128+
129+
// copied from `std::IoSlice::advance_slices`
130+
// replace with method when MSRV is 1.81.0
131+
fn advance_slices<'a>(bufs: &mut &mut [IoSlice<'a>], n: usize) {
132+
// Number of buffers to remove.
133+
let mut remove = 0;
134+
// Remaining length before reaching n. This prevents overflow
135+
// that could happen if the length of slices in `bufs` were instead
136+
// accumulated. Those slice may be aliased and, if they are large
137+
// enough, their added length may overflow a `usize`.
138+
let mut left = n;
139+
for buf in bufs.iter() {
140+
if let Some(remainder) = left.checked_sub(buf.len()) {
141+
left = remainder;
142+
remove += 1;
143+
} else {
144+
break;
145+
}
146+
}
147+
148+
*bufs = &mut std::mem::take(bufs)[remove..];
149+
if let Some(first) = bufs.first_mut() {
150+
let buf = &first[..left];
151+
// necessary due to limitating in the borrow checker,
152+
// when tokio MSRV reaches 1.81.0 this entire function
153+
// can be replaced with `IoSlice::advance_slices`
154+
//
155+
// SAFETY: transmute a sub-slice of an IoSlice<'a> back to
156+
// the lifetime `'a`. This is safe because the underlying memory
157+
// is guaranteed to live for 'a, we have shared access, and no
158+
// underlying data is reinterpreted to a different type.
159+
unsafe {
160+
*first = IoSlice::new(std::mem::transmute::<&[u8], &'a [u8]>(buf));
161+
}
162+
} else {
163+
assert!(left == 0, "advancing io slices beyond their length");
164+
}
165+
}
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#![warn(rust_2018_idioms)]
2+
#![cfg(feature = "full")]
3+
4+
use tokio::io::AsyncWrite;
5+
use tokio_util::io::write_all_vectored;
6+
7+
use bytes::BytesMut;
8+
use std::io;
9+
use std::io::IoSlice;
10+
use std::pin::Pin;
11+
use std::task::{Context, Poll};
12+
13+
#[tokio::test]
14+
async fn test_write_all_vectored() {
15+
struct Wr {
16+
buf: BytesMut,
17+
}
18+
impl AsyncWrite for Wr {
19+
fn poll_write(
20+
self: Pin<&mut Self>,
21+
_cx: &mut Context<'_>,
22+
_buf: &[u8],
23+
) -> Poll<io::Result<usize>> {
24+
// When executing `write_all_buf` with this writer,
25+
// `poll_write` is not called.
26+
panic!("shouldn't be called")
27+
}
28+
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
29+
Ok(()).into()
30+
}
31+
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
32+
Ok(()).into()
33+
}
34+
fn poll_write_vectored(
35+
mut self: Pin<&mut Self>,
36+
_cx: &mut Context<'_>,
37+
bufs: &[io::IoSlice<'_>],
38+
) -> Poll<Result<usize, io::Error>> {
39+
for buf in bufs {
40+
self.buf.extend_from_slice(buf);
41+
}
42+
let n = self.buf.len();
43+
Ok(n).into()
44+
}
45+
fn is_write_vectored(&self) -> bool {
46+
// Enable vectored write. (doesn't need to be enabled explicitly for `write_all_vectored`)
47+
true
48+
}
49+
}
50+
51+
let mut wr = Wr {
52+
buf: BytesMut::with_capacity(64),
53+
};
54+
55+
let buf = &mut [
56+
IoSlice::new(&b"hello"[..]),
57+
IoSlice::new(&b" "[..]),
58+
IoSlice::new(&b"world"[..]),
59+
];
60+
61+
write_all_vectored(&mut wr, buf).await.unwrap();
62+
assert_eq!(&wr.buf[..], b"hello world");
63+
}
64+
65+
#[tokio::test]
66+
async fn write_all_vectored_with_empty_slice() {
67+
struct Wr {
68+
buf: BytesMut,
69+
}
70+
impl AsyncWrite for Wr {
71+
fn poll_write(
72+
self: Pin<&mut Self>,
73+
_cx: &mut Context<'_>,
74+
_buf: &[u8],
75+
) -> Poll<io::Result<usize>> {
76+
panic!("shouldn't be called")
77+
}
78+
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
79+
Ok(()).into()
80+
}
81+
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
82+
Ok(()).into()
83+
}
84+
fn poll_write_vectored(
85+
mut self: Pin<&mut Self>,
86+
_cx: &mut Context<'_>,
87+
bufs: &[io::IoSlice<'_>],
88+
) -> Poll<Result<usize, io::Error>> {
89+
for buf in bufs {
90+
self.buf.extend_from_slice(buf);
91+
}
92+
let n = self.buf.len();
93+
Ok(n).into()
94+
}
95+
fn is_write_vectored(&self) -> bool {
96+
// Enable vectored write.
97+
true
98+
}
99+
}
100+
101+
// case 1 middle empty slice
102+
let mut wr = Wr {
103+
buf: BytesMut::with_capacity(64),
104+
};
105+
106+
let buf = &mut [
107+
IoSlice::new(&b"hello"[..]),
108+
IoSlice::new(&[]),
109+
IoSlice::new(&b"world"[..]),
110+
];
111+
112+
write_all_vectored(&mut wr, buf).await.unwrap();
113+
assert_eq!(&wr.buf[..], b"helloworld");
114+
115+
// case 2 no slices
116+
let mut wr = Wr {
117+
buf: BytesMut::with_capacity(64),
118+
};
119+
120+
let buf = &mut [];
121+
122+
write_all_vectored(&mut wr, buf).await.unwrap();
123+
assert_eq!(&wr.buf[..], b"");
124+
125+
// case 3 just an empty slice
126+
let mut wr = Wr {
127+
buf: BytesMut::with_capacity(64),
128+
};
129+
let buf = &mut [IoSlice::new(&[])];
130+
131+
write_all_vectored(&mut wr, buf).await.unwrap();
132+
assert_eq!(&wr.buf[..], b"");
133+
134+
// case 4 ending with empty slice
135+
let mut wr = Wr {
136+
buf: BytesMut::with_capacity(64),
137+
};
138+
let buf = &mut [IoSlice::new(b"hello"), IoSlice::new(&[])];
139+
140+
write_all_vectored(&mut wr, buf).await.unwrap();
141+
assert_eq!(&wr.buf[..], b"hello");
142+
}

0 commit comments

Comments
 (0)