Skip to content

Commit 4388d82

Browse files
author
Quentin Perez
authored
tonic: Introduce a new method on Endpoint to override the origin (#1013)
1 parent 8287988 commit 4388d82

File tree

4 files changed

+151
-15
lines changed

4 files changed

+151
-15
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
use futures::future::BoxFuture;
2+
use futures_util::FutureExt;
3+
use integration_tests::pb::test_client;
4+
use integration_tests::pb::{test_server, Input, Output};
5+
use std::task::Context;
6+
use std::task::Poll;
7+
use std::time::Duration;
8+
use tokio::sync::oneshot;
9+
use tonic::codegen::http::Request;
10+
use tonic::{
11+
transport::{Endpoint, Server},
12+
Response, Status,
13+
};
14+
use tower::Layer;
15+
use tower::Service;
16+
17+
#[tokio::test]
18+
async fn writes_origin_header() {
19+
struct Svc;
20+
21+
#[tonic::async_trait]
22+
impl test_server::Test for Svc {
23+
async fn unary_call(
24+
&self,
25+
_req: tonic::Request<Input>,
26+
) -> Result<Response<Output>, Status> {
27+
Ok(Response::new(Output {}))
28+
}
29+
}
30+
31+
let svc = test_server::TestServer::new(Svc);
32+
33+
let (tx, rx) = oneshot::channel::<()>();
34+
35+
let jh = tokio::spawn(async move {
36+
Server::builder()
37+
.layer(OriginLayer {})
38+
.add_service(svc)
39+
.serve_with_shutdown("127.0.0.1:1442".parse().unwrap(), rx.map(drop))
40+
.await
41+
.unwrap();
42+
});
43+
44+
tokio::time::sleep(Duration::from_millis(100)).await;
45+
46+
let channel = Endpoint::from_static("http://127.0.0.1:1442")
47+
.origin("https://docs.rs".parse().expect("valid uri"))
48+
.connect()
49+
.await
50+
.unwrap();
51+
52+
let mut client = test_client::TestClient::new(channel);
53+
54+
match client.unary_call(Input {}).await {
55+
Ok(_) => {}
56+
Err(status) => panic!("{}", status.message()),
57+
}
58+
59+
tx.send(()).unwrap();
60+
61+
jh.await.unwrap();
62+
}
63+
64+
#[derive(Clone)]
65+
struct OriginLayer {}
66+
67+
impl<S> Layer<S> for OriginLayer {
68+
type Service = OriginService<S>;
69+
70+
fn layer(&self, inner: S) -> Self::Service {
71+
OriginService { inner }
72+
}
73+
}
74+
75+
#[derive(Clone)]
76+
struct OriginService<S> {
77+
inner: S,
78+
}
79+
80+
impl<T> Service<Request<tonic::transport::Body>> for OriginService<T>
81+
where
82+
T: Service<Request<tonic::transport::Body>>,
83+
T::Future: Send + 'static,
84+
T::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
85+
{
86+
type Response = T::Response;
87+
type Error = Box<dyn std::error::Error + Send + Sync>;
88+
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
89+
90+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
91+
self.inner.poll_ready(cx).map_err(Into::into)
92+
}
93+
94+
fn call(&mut self, req: Request<tonic::transport::Body>) -> Self::Future {
95+
assert_eq!(req.uri().host(), Some("docs.rs"));
96+
let fut = self.inner.call(req);
97+
98+
Box::pin(async move { fut.await.map_err(Into::into) })
99+
}
100+
}

tonic/src/transport/channel/endpoint.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use tower::make::MakeConnection;
2424
#[derive(Clone)]
2525
pub struct Endpoint {
2626
pub(crate) uri: Uri,
27+
pub(crate) origin: Option<Uri>,
2728
pub(crate) user_agent: Option<HeaderValue>,
2829
pub(crate) timeout: Option<Duration>,
2930
pub(crate) concurrency_limit: Option<usize>,
@@ -106,6 +107,25 @@ impl Endpoint {
106107
.map_err(|_| Error::new_invalid_user_agent())
107108
}
108109

110+
/// Set a custom origin.
111+
///
112+
/// Override the `origin`, mainly useful when you are reaching a Server/LoadBalancer
113+
/// which serves multiple services at the same time.
114+
/// It will play the role of SNI (Server Name Indication).
115+
///
116+
/// ```
117+
/// # use tonic::transport::Endpoint;
118+
/// # let mut builder = Endpoint::from_static("https://proxy.com");
119+
/// builder.origin("https://example.com".parse().expect("http://example.com must be a valid URI"));
120+
/// // origin: "https://example.com"
121+
/// ```
122+
pub fn origin(self, origin: Uri) -> Self {
123+
Endpoint {
124+
origin: Some(origin),
125+
..self
126+
}
127+
}
128+
109129
/// Apply a timeout to each request.
110130
///
111131
/// ```
@@ -395,6 +415,7 @@ impl From<Uri> for Endpoint {
395415
fn from(uri: Uri) -> Self {
396416
Self {
397417
uri,
418+
origin: None,
398419
user_agent: None,
399420
concurrency_limit: None,
400421
rate_limit: None,

tonic/src/transport/service/add_origin.rs

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,28 @@
11
use futures_core::future::BoxFuture;
2+
use http::uri::Authority;
3+
use http::uri::Scheme;
24
use http::{Request, Uri};
35
use std::task::{Context, Poll};
46
use tower_service::Service;
57

68
#[derive(Debug)]
79
pub(crate) struct AddOrigin<T> {
810
inner: T,
9-
origin: Uri,
11+
scheme: Option<Scheme>,
12+
authority: Option<Authority>,
1013
}
1114

1215
impl<T> AddOrigin<T> {
1316
pub(crate) fn new(inner: T, origin: Uri) -> Self {
14-
Self { inner, origin }
17+
let http::uri::Parts {
18+
scheme, authority, ..
19+
} = origin.into_parts();
20+
21+
Self {
22+
inner,
23+
scheme,
24+
authority,
25+
}
1526
}
1627
}
1728

@@ -30,24 +41,24 @@ where
3041
}
3142

3243
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
33-
// Split the request into the head and the body.
34-
let (mut head, body) = req.into_parts();
35-
36-
// Split the request URI into parts.
37-
let mut uri: http::uri::Parts = head.uri.into();
38-
let set_uri = self.origin.clone().into_parts();
39-
40-
if set_uri.scheme.is_none() || set_uri.authority.is_none() {
44+
if self.scheme.is_none() || self.authority.is_none() {
4145
let err = crate::transport::Error::new_invalid_uri();
4246
return Box::pin(async move { Err::<Self::Response, _>(err.into()) });
4347
}
4448

45-
// Update the URI parts, setting hte scheme and authority
46-
uri.scheme = Some(set_uri.scheme.expect("expected scheme"));
47-
uri.authority = Some(set_uri.authority.expect("expected authority"));
49+
// Split the request into the head and the body.
50+
let (mut head, body) = req.into_parts();
4851

4952
// Update the the request URI
50-
head.uri = http::Uri::from_parts(uri).expect("valid uri");
53+
head.uri = {
54+
// Split the request URI into parts.
55+
let mut uri: http::uri::Parts = head.uri.into();
56+
// Update the URI parts, setting hte scheme and authority
57+
uri.scheme = self.scheme.clone();
58+
uri.authority = self.authority.clone();
59+
60+
http::Uri::from_parts(uri).expect("valid uri")
61+
};
5162

5263
let request = Request::from_parts(head, body);
5364

tonic/src/transport/service/connection.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@ impl Connection {
5555
}
5656

5757
let stack = ServiceBuilder::new()
58-
.layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone()))
58+
.layer_fn(|s| {
59+
let origin = endpoint.origin.as_ref().unwrap_or(&endpoint.uri).clone();
60+
61+
AddOrigin::new(s, origin)
62+
})
5963
.layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone()))
6064
.layer_fn(|s| GrpcTimeout::new(s, endpoint.timeout))
6165
.option_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new))

0 commit comments

Comments
 (0)