Skip to content

Commit dee2ab5

Browse files
authored
feat(transport): add unix socket support in server (#861)
1 parent d6c0fc1 commit dee2ab5

5 files changed

Lines changed: 120 additions & 79 deletions

File tree

examples/src/uds/server.rs

Lines changed: 8 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
#![cfg_attr(not(unix), allow(unused_imports))]
22

3-
use futures::TryFutureExt;
43
use std::path::Path;
54
#[cfg(unix)]
65
use tokio::net::UnixListener;
6+
#[cfg(unix)]
7+
use tokio_stream::wrappers::UnixListenerStream;
8+
#[cfg(unix)]
9+
use tonic::transport::server::UdsConnectInfo;
710
use tonic::{transport::Server, Request, Response, Status};
811

912
pub mod hello_world {
@@ -26,7 +29,7 @@ impl Greeter for MyGreeter {
2629
) -> Result<Response<HelloReply>, Status> {
2730
#[cfg(unix)]
2831
{
29-
let conn_info = request.extensions().get::<unix::UdsConnectInfo>().unwrap();
32+
let conn_info = request.extensions().get::<UdsConnectInfo>().unwrap();
3033
println!("Got a request {:?} with info {:?}", request, conn_info);
3134
}
3235

@@ -46,89 +49,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
4649

4750
let greeter = MyGreeter::default();
4851

49-
let incoming = {
50-
let uds = UnixListener::bind(path)?;
51-
52-
async_stream::stream! {
53-
loop {
54-
let item = uds.accept().map_ok(|(st, _)| unix::UnixStream(st)).await;
55-
56-
yield item;
57-
}
58-
}
59-
};
52+
let uds = UnixListener::bind(path)?;
53+
let uds_stream = UnixListenerStream::new(uds);
6054

6155
Server::builder()
6256
.add_service(GreeterServer::new(greeter))
63-
.serve_with_incoming(incoming)
57+
.serve_with_incoming(uds_stream)
6458
.await?;
6559

6660
Ok(())
6761
}
6862

69-
#[cfg(unix)]
70-
mod unix {
71-
use std::{
72-
pin::Pin,
73-
sync::Arc,
74-
task::{Context, Poll},
75-
};
76-
77-
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
78-
use tonic::transport::server::Connected;
79-
80-
#[derive(Debug)]
81-
pub struct UnixStream(pub tokio::net::UnixStream);
82-
83-
impl Connected for UnixStream {
84-
type ConnectInfo = UdsConnectInfo;
85-
86-
fn connect_info(&self) -> Self::ConnectInfo {
87-
UdsConnectInfo {
88-
peer_addr: self.0.peer_addr().ok().map(Arc::new),
89-
peer_cred: self.0.peer_cred().ok(),
90-
}
91-
}
92-
}
93-
94-
#[derive(Clone, Debug)]
95-
pub struct UdsConnectInfo {
96-
pub peer_addr: Option<Arc<tokio::net::unix::SocketAddr>>,
97-
pub peer_cred: Option<tokio::net::unix::UCred>,
98-
}
99-
100-
impl AsyncRead for UnixStream {
101-
fn poll_read(
102-
mut self: Pin<&mut Self>,
103-
cx: &mut Context<'_>,
104-
buf: &mut ReadBuf<'_>,
105-
) -> Poll<std::io::Result<()>> {
106-
Pin::new(&mut self.0).poll_read(cx, buf)
107-
}
108-
}
109-
110-
impl AsyncWrite for UnixStream {
111-
fn poll_write(
112-
mut self: Pin<&mut Self>,
113-
cx: &mut Context<'_>,
114-
buf: &[u8],
115-
) -> Poll<std::io::Result<usize>> {
116-
Pin::new(&mut self.0).poll_write(cx, buf)
117-
}
118-
119-
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
120-
Pin::new(&mut self.0).poll_flush(cx)
121-
}
122-
123-
fn poll_shutdown(
124-
mut self: Pin<&mut Self>,
125-
cx: &mut Context<'_>,
126-
) -> Poll<std::io::Result<()>> {
127-
Pin::new(&mut self.0).poll_shutdown(cx)
128-
}
129-
}
130-
}
131-
13263
#[cfg(not(unix))]
13364
fn main() {
13465
panic!("The `uds` example only works on unix systems!");

tests/integration_tests/tests/connect_info.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,77 @@ async fn getting_connect_info() {
4848

4949
jh.await.unwrap();
5050
}
51+
52+
#[cfg(unix)]
53+
pub mod unix {
54+
use std::convert::TryFrom as _;
55+
56+
use futures_util::FutureExt;
57+
use tokio::{
58+
net::{UnixListener, UnixStream},
59+
sync::oneshot,
60+
};
61+
use tokio_stream::wrappers::UnixListenerStream;
62+
use tonic::{
63+
transport::{server::UdsConnectInfo, Endpoint, Server, Uri},
64+
Request, Response, Status,
65+
};
66+
use tower::service_fn;
67+
68+
use integration_tests::pb::{test_client, test_server, Input, Output};
69+
70+
struct Svc {}
71+
72+
#[tonic::async_trait]
73+
impl test_server::Test for Svc {
74+
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
75+
let conn_info = req.extensions().get::<UdsConnectInfo>().unwrap();
76+
77+
// Client-side unix sockets are unnamed.
78+
assert!(req.remote_addr().is_none());
79+
assert!(conn_info.peer_addr.as_ref().unwrap().is_unnamed());
80+
// This should contain process credentials for the client socket.
81+
assert!(conn_info.peer_cred.as_ref().is_some());
82+
83+
Ok(Response::new(Output {}))
84+
}
85+
}
86+
87+
#[tokio::test]
88+
async fn getting_connect_info() {
89+
let mut unix_socket_path = std::env::temp_dir();
90+
unix_socket_path.push("uds-integration-test");
91+
92+
let uds = UnixListener::bind(&unix_socket_path).unwrap();
93+
let uds_stream = UnixListenerStream::new(uds);
94+
95+
let service = test_server::TestServer::new(Svc {});
96+
let (tx, rx) = oneshot::channel::<()>();
97+
98+
let jh = tokio::spawn(async move {
99+
Server::builder()
100+
.add_service(service)
101+
.serve_with_incoming_shutdown(uds_stream, rx.map(drop))
102+
.await
103+
.unwrap();
104+
});
105+
106+
// Take a copy before moving into the `service_fn` closure so that the closure
107+
// can implement `FnMut`.
108+
let path = unix_socket_path.clone();
109+
let channel = Endpoint::try_from("http://[::]:50051")
110+
.unwrap()
111+
.connect_with_connector(service_fn(move |_: Uri| UnixStream::connect(path.clone())))
112+
.await
113+
.unwrap();
114+
115+
let mut client = test_client::TestClient::new(channel);
116+
117+
client.unary_call(Input {}).await.unwrap();
118+
119+
tx.send(()).unwrap();
120+
jh.await.unwrap();
121+
122+
std::fs::remove_file(unix_socket_path).unwrap();
123+
}
124+
}

tonic/src/request.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,8 @@ impl<T> Request<T> {
202202
/// Get the remote address of this connection.
203203
///
204204
/// This will return `None` if the `IO` type used
205-
/// does not implement `Connected`. This currently,
206-
/// only works on the server side.
205+
/// does not implement `Connected` or when using a unix domain socket.
206+
/// This currently only works on the server side.
207207
pub fn remote_addr(&self) -> Option<SocketAddr> {
208208
#[cfg(feature = "transport")]
209209
{

tonic/src/transport/server/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ mod recover_error;
66
#[cfg(feature = "tls")]
77
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
88
mod tls;
9+
#[cfg(unix)]
10+
mod unix;
911

1012
pub use conn::{Connected, TcpConnectInfo};
1113
#[cfg(feature = "tls")]
@@ -17,6 +19,9 @@ pub use conn::TlsConnectInfo;
1719
#[cfg(feature = "tls")]
1820
use super::service::TlsAcceptor;
1921

22+
#[cfg(unix)]
23+
pub use unix::UdsConnectInfo;
24+
2025
use incoming::TcpIncoming;
2126

2227
#[cfg(feature = "tls")]

tonic/src/transport/server/unix.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
use super::Connected;
2+
use std::sync::Arc;
3+
4+
/// Connection info for Unix domain socket streams.
5+
///
6+
/// This type will be accessible through [request extensions][ext] if you're using
7+
/// a unix stream.
8+
///
9+
/// See [Connected] for more details.
10+
///
11+
/// [ext]: crate::Request::extensions
12+
/// [Connected]: crate::transport::server::Connected
13+
#[cfg_attr(docsrs, doc(cfg(unix)))]
14+
#[derive(Clone, Debug)]
15+
pub struct UdsConnectInfo {
16+
/// Peer address. This will be "unnamed" for client unix sockets.
17+
pub peer_addr: Option<Arc<tokio::net::unix::SocketAddr>>,
18+
/// Process credentials for the unix socket.
19+
pub peer_cred: Option<tokio::net::unix::UCred>,
20+
}
21+
22+
impl Connected for tokio::net::UnixStream {
23+
type ConnectInfo = UdsConnectInfo;
24+
25+
fn connect_info(&self) -> Self::ConnectInfo {
26+
UdsConnectInfo {
27+
peer_addr: self.peer_addr().ok().map(Arc::new),
28+
peer_cred: self.peer_cred().ok(),
29+
}
30+
}
31+
}

0 commit comments

Comments
 (0)