|
2 | 2 | use std::future::Future; |
3 | 3 | use std::{ |
4 | 4 | io, |
| 5 | + mem::MaybeUninit, |
5 | 6 | pin::Pin, |
6 | 7 | task::{Context, Poll}, |
7 | 8 | time::Duration, |
8 | 9 | }; |
9 | 10 |
|
| 11 | +use bytes::BytesMut; |
10 | 12 | use futures::ready; |
11 | 13 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
12 | 14 |
|
@@ -91,7 +93,7 @@ impl CopyBuffer { |
91 | 93 | }) |
92 | 94 | } |
93 | 95 |
|
94 | | - pub fn amount_transfered(&self) -> u64 { |
| 96 | + pub fn amount_transferred(&self) -> u64 { |
95 | 97 | self.amt |
96 | 98 | } |
97 | 99 |
|
@@ -235,7 +237,7 @@ where |
235 | 237 | match delay.as_mut().poll(cx) { |
236 | 238 | Poll::Ready(()) => { |
237 | 239 | *a_to_b = TransferState::ShuttingDown( |
238 | | - buf.amount_transfered(), |
| 240 | + buf.amount_transferred(), |
239 | 241 | ); |
240 | 242 | continue; |
241 | 243 | } |
@@ -285,7 +287,7 @@ where |
285 | 287 | match delay.as_mut().poll(cx) { |
286 | 288 | Poll::Ready(()) => { |
287 | 289 | *b_to_a = TransferState::ShuttingDown( |
288 | | - buf.amount_transfered(), |
| 290 | + buf.amount_transferred(), |
289 | 291 | ); |
290 | 292 | continue; |
291 | 293 | } |
@@ -352,3 +354,56 @@ where |
352 | 354 | } |
353 | 355 | .await |
354 | 356 | } |
| 357 | + |
| 358 | +pub trait ReadExactBase { |
| 359 | + /// inner stream to be polled |
| 360 | + type I: AsyncRead + Unpin; |
| 361 | + /// prepare the inner stream, read buffer and read position |
| 362 | + fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize); |
| 363 | +} |
| 364 | + |
| 365 | +pub trait ReadExt: ReadExactBase { |
| 366 | + fn poll_read_exact( |
| 367 | + &mut self, |
| 368 | + cx: &mut std::task::Context, |
| 369 | + size: usize, |
| 370 | + ) -> Poll<std::io::Result<()>>; |
| 371 | +} |
| 372 | + |
| 373 | +impl<T: ReadExactBase> ReadExt for T { |
| 374 | + fn poll_read_exact( |
| 375 | + &mut self, |
| 376 | + cx: &mut std::task::Context, |
| 377 | + size: usize, |
| 378 | + ) -> Poll<std::io::Result<()>> { |
| 379 | + let (raw, read_buf, read_pos) = self.decompose(); |
| 380 | + read_buf.reserve(size); |
| 381 | + // # safety: read_buf has reserved `size` |
| 382 | + unsafe { read_buf.set_len(size) } |
| 383 | + loop { |
| 384 | + if *read_pos < size { |
| 385 | + // # safety: read_pos<size==read_buf.len(), and |
| 386 | + // read_buf[0..read_pos] is initialized |
| 387 | + let dst = unsafe { |
| 388 | + &mut *((&mut read_buf[*read_pos..size]) as *mut _ |
| 389 | + as *mut [MaybeUninit<u8>]) |
| 390 | + }; |
| 391 | + let mut buf = ReadBuf::uninit(dst); |
| 392 | + let ptr = buf.filled().as_ptr(); |
| 393 | + ready!(Pin::new(&mut *raw).poll_read(cx, &mut buf))?; |
| 394 | + assert_eq!(ptr, buf.filled().as_ptr()); |
| 395 | + if buf.filled().is_empty() { |
| 396 | + return Poll::Ready(Err(std::io::Error::new( |
| 397 | + std::io::ErrorKind::UnexpectedEof, |
| 398 | + "unexpected eof", |
| 399 | + ))); |
| 400 | + } |
| 401 | + *read_pos += buf.filled().len(); |
| 402 | + } else { |
| 403 | + assert!(*read_pos == size); |
| 404 | + *read_pos = 0; |
| 405 | + return Poll::Ready(Ok(())); |
| 406 | + } |
| 407 | + } |
| 408 | + } |
| 409 | +} |
0 commit comments