Skip to content

Commit 01e5be5

Browse files
authored
Add example of detecting client drops in bidirectional streams on server side (#931)
1 parent e30bb7e commit 01e5be5

3 files changed

Lines changed: 174 additions & 43 deletions

File tree

examples/Cargo.toml

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -180,37 +180,40 @@ path = "src/streaming/server.rs"
180180

181181
[dependencies]
182182
async-stream = "0.3"
183-
futures = {version = "0.3", default-features = false, features = ["alloc"]}
183+
futures = { version = "0.3", default-features = false, features = ["alloc"] }
184184
prost = "0.9"
185-
tokio = {version = "1.0", features = ["rt-multi-thread", "time", "fs", "macros", "net"]}
186-
tokio-stream = {version = "0.1", features = ["net"]}
187-
tonic = {path = "../tonic", features = ["tls", "compression"]}
188-
tower = {version = "0.4"}
185+
tokio = { version = "1.0", features = [ "rt-multi-thread", "time", "fs", "macros", "net",] }
186+
tokio-stream = { version = "0.1", features = ["net"] }
187+
tonic = { path = "../tonic", features = ["tls", "compression"] }
188+
tower = { version = "0.4" }
189189
# Required for routeguide
190190
rand = "0.8"
191-
serde = {version = "1.0", features = ["derive"]}
191+
serde = { version = "1.0", features = ["derive"] }
192192
serde_json = "1.0"
193193
# Tracing
194194
tracing = "0.1.16"
195195
tracing-attributes = "0.1"
196196
tracing-futures = "0.2"
197-
tracing-subscriber = {version = "0.3", features = ["tracing-log"]}
197+
tracing-subscriber = { version = "0.3", features = ["tracing-log"] }
198198
# Required for wellknown types
199199
prost-types = "0.9"
200200
# Hyper example
201201
http = "0.2"
202202
http-body = "0.4.2"
203-
hyper = {version = "0.14", features = ["full"]}
203+
hyper = { version = "0.14", features = ["full"] }
204204
pin-project = "1.0"
205205
warp = "0.3"
206206
# Health example
207-
tonic-health = {path = "../tonic-health"}
207+
tonic-health = { path = "../tonic-health" }
208208
# Reflection example
209209
listenfd = "0.3"
210-
tonic-reflection = {path = "../tonic-reflection"}
210+
tonic-reflection = { path = "../tonic-reflection" }
211211
# grpc-web example
212212
bytes = "1"
213-
tonic-web = {path = "../tonic-web"}
213+
tonic-web = { path = "../tonic-web" }
214+
# streaming example
215+
h2 = "0.3"
216+
214217

215218
[build-dependencies]
216-
tonic-build = {path = "../tonic-build", features = ["prost", "compression"]}
219+
tonic-build = { path = "../tonic-build", features = ["prost", "compression"] }

examples/src/streaming/client.rs

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,85 @@ pub mod pb {
22
tonic::include_proto!("grpc.examples.echo");
33
}
44

5+
use futures::stream::Stream;
6+
use std::time::Duration;
7+
use tokio_stream::StreamExt;
8+
use tonic::transport::Channel;
9+
510
use pb::{echo_client::EchoClient, EchoRequest};
611

7-
#[tokio::main]
8-
async fn main() -> Result<(), Box<dyn std::error::Error>> {
9-
let mut client = EchoClient::connect("http://[::1]:50051").await.unwrap();
12+
fn echo_requests_iter() -> impl Stream<Item = EchoRequest> {
13+
tokio_stream::iter(1..usize::MAX).map(|i| EchoRequest {
14+
message: format!("msg {:02}", i),
15+
})
16+
}
1017

18+
async fn streaming_echo(client: &mut EchoClient<Channel>, num: usize) {
1119
let stream = client
1220
.server_streaming_echo(EchoRequest {
1321
message: "foo".into(),
1422
})
1523
.await
24+
.unwrap()
25+
.into_inner();
26+
27+
// stream is infinite - take just 5 elements and then disconnect
28+
let mut stream = stream.take(num);
29+
while let Some(item) = stream.next().await {
30+
println!("\trecived: {}", item.unwrap().message);
31+
}
32+
// stream is droped here and the disconnect info is send to server
33+
}
34+
35+
async fn bidirectional_streaming_echo(client: &mut EchoClient<Channel>, num: usize) {
36+
let in_stream = echo_requests_iter().take(num);
37+
38+
let response = client
39+
.bidirectional_streaming_echo(in_stream)
40+
.await
1641
.unwrap();
1742

18-
println!("Connected...now sleeping for 2 seconds...");
43+
let mut resp_stream = response.into_inner();
44+
45+
while let Some(recived) = resp_stream.next().await {
46+
let recived = recived.unwrap();
47+
println!("\trecived message: `{}`", recived.message);
48+
}
49+
}
50+
51+
async fn bidirectional_streaming_echo_throttle(client: &mut EchoClient<Channel>, dur: Duration) {
52+
let in_stream = echo_requests_iter().throttle(dur);
53+
54+
let response = client
55+
.bidirectional_streaming_echo(in_stream)
56+
.await
57+
.unwrap();
58+
59+
let mut resp_stream = response.into_inner();
60+
61+
while let Some(recived) = resp_stream.next().await {
62+
let recived = recived.unwrap();
63+
println!("\trecived message: `{}`", recived.message);
64+
}
65+
}
66+
67+
#[tokio::main]
68+
async fn main() -> Result<(), Box<dyn std::error::Error>> {
69+
let mut client = EchoClient::connect("http://[::1]:50051").await.unwrap();
1970

20-
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
71+
println!("Streaming echo:");
72+
streaming_echo(&mut client, 5).await;
73+
tokio::time::sleep(Duration::from_secs(1)).await; //do not mess server println functions
2174

22-
// Disconnect
23-
drop(stream);
24-
drop(client);
75+
// Echo stream that sends 17 requests then gracefull end that conection
76+
println!("\r\nBidirectional stream echo:");
77+
bidirectional_streaming_echo(&mut client, 17).await;
2578

26-
println!("Disconnected...");
79+
// Echo stream that sends up to `usize::MAX` requets. One request each 2s.
80+
// Exiting client with CTRL+C demostrate how to distinguise broken pipe from
81+
//gracefull client disconnection (above example) on the server side.
82+
println!("\r\nBidirectional stream echo (kill client with CTLR+C):");
83+
bidirectional_streaming_echo_throttle(&mut client, Duration::from_secs(2)).await;
2784

2885
Ok(())
2986
}

examples/src/streaming/server.rs

Lines changed: 93 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,39 @@ pub mod pb {
33
}
44

55
use futures::Stream;
6-
use std::net::ToSocketAddrs;
7-
use std::pin::Pin;
8-
use std::task::{Context, Poll};
9-
use tokio::sync::oneshot;
6+
use std::{error::Error, io::ErrorKind, net::ToSocketAddrs, pin::Pin, time::Duration};
7+
use tokio::sync::mpsc;
8+
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
109
use tonic::{transport::Server, Request, Response, Status, Streaming};
1110

1211
use pb::{EchoRequest, EchoResponse};
1312

1413
type EchoResult<T> = Result<Response<T>, Status>;
1514
type ResponseStream = Pin<Box<dyn Stream<Item = Result<EchoResponse, Status>> + Send>>;
1615

16+
fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
17+
let mut err: &(dyn Error + 'static) = err_status;
18+
19+
loop {
20+
if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
21+
return Some(io_err);
22+
}
23+
24+
// h2::Error do not expose std::io::Error with `source()`
25+
// https://github.com/hyperium/h2/pull/462
26+
if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
27+
if let Some(io_err) = h2_err.get_io() {
28+
return Some(io_err);
29+
}
30+
}
31+
32+
err = match err.source() {
33+
Some(err) => err,
34+
None => return None,
35+
};
36+
}
37+
}
38+
1739
#[derive(Debug)]
1840
pub struct EchoServer {}
1941

@@ -29,28 +51,36 @@ impl pb::echo_server::Echo for EchoServer {
2951
&self,
3052
req: Request<EchoRequest>,
3153
) -> EchoResult<Self::ServerStreamingEchoStream> {
32-
println!("Client connected from: {:?}", req.remote_addr());
54+
println!("EchoServer::server_streaming_echo");
55+
println!("\tclient connected from: {:?}", req.remote_addr());
3356

34-
let (tx, rx) = oneshot::channel::<()>();
35-
36-
tokio::spawn(async move {
37-
let _ = rx.await;
38-
println!("The rx resolved therefore the client disconnected!");
57+
// creating infinite stream with requested message
58+
let repeat = std::iter::repeat(EchoResponse {
59+
message: req.into_inner().message,
3960
});
61+
let mut stream = Box::pin(tokio_stream::iter(repeat).throttle(Duration::from_millis(200)));
4062

41-
struct ClientDisconnect(oneshot::Sender<()>);
42-
43-
impl Stream for ClientDisconnect {
44-
type Item = Result<EchoResponse, Status>;
45-
46-
fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
47-
// A stream that never resolves to anything....
48-
Poll::Pending
63+
// spawn and channel are required if you want handle "disconnect" functionality
64+
// the `out_stream` will not be polled after client disconnect
65+
let (tx, rx) = mpsc::channel(128);
66+
tokio::spawn(async move {
67+
while let Some(item) = stream.next().await {
68+
match tx.send(Result::<_, Status>::Ok(item)).await {
69+
Ok(_) => {
70+
// item (server response) was queued to be send to client
71+
}
72+
Err(_item) => {
73+
// output_stream was build from rx and both are dropped
74+
break;
75+
}
76+
}
4977
}
50-
}
78+
println!("\tclient disconnected");
79+
});
5180

81+
let output_stream = ReceiverStream::new(rx);
5282
Ok(Response::new(
53-
Box::pin(ClientDisconnect(tx)) as Self::ServerStreamingEchoStream
83+
Box::pin(output_stream) as Self::ServerStreamingEchoStream
5484
))
5585
}
5686

@@ -65,9 +95,50 @@ impl pb::echo_server::Echo for EchoServer {
6595

6696
async fn bidirectional_streaming_echo(
6797
&self,
68-
_: Request<Streaming<EchoRequest>>,
98+
req: Request<Streaming<EchoRequest>>,
6999
) -> EchoResult<Self::BidirectionalStreamingEchoStream> {
70-
Err(Status::unimplemented("not implemented"))
100+
println!("EchoServer::bidirectional_streaming_echo");
101+
102+
let mut in_stream = req.into_inner();
103+
let (tx, rx) = mpsc::channel(128);
104+
105+
// this spawn here is required if you want to handle connection error.
106+
// If we just map `in_stream` and write it back as `out_stream` the `out_stream`
107+
// will be drooped when connection error occurs and error will never be propagated
108+
// to mapped version of `in_stream`.
109+
tokio::spawn(async move {
110+
while let Some(result) = in_stream.next().await {
111+
match result {
112+
Ok(v) => tx
113+
.send(Ok(EchoResponse { message: v.message }))
114+
.await
115+
.expect("working rx"),
116+
Err(err) => {
117+
if let Some(io_err) = match_for_io_error(&err) {
118+
if io_err.kind() == ErrorKind::BrokenPipe {
119+
// here you can handle special case when client
120+
// disconnected in unexpected way
121+
eprintln!("\tclient disconnected: broken pipe");
122+
break;
123+
}
124+
}
125+
126+
match tx.send(Err(err)).await {
127+
Ok(_) => (),
128+
Err(_err) => break, // response was droped
129+
}
130+
}
131+
}
132+
}
133+
println!("\tstream ended");
134+
});
135+
136+
// echo just write the same data that was received
137+
let out_stream = ReceiverStream::new(rx);
138+
139+
Ok(Response::new(
140+
Box::pin(out_stream) as Self::BidirectionalStreamingEchoStream
141+
))
71142
}
72143
}
73144

0 commit comments

Comments
 (0)