Skip to content

Commit e92b3e8

Browse files
committed
add support for defining RedirectPolicy for a Client
1 parent 6ef73ae commit e92b3e8

File tree

5 files changed

+250
-24
lines changed

5 files changed

+250
-24
lines changed

src/client.rs

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::fmt;
22
use std::io::{self, Read};
3-
use std::sync::Arc;
3+
use std::sync::{Arc, Mutex};
44

55
use hyper::client::IntoUrl;
66
use hyper::header::{Headers, ContentType, Location, Referer, UserAgent};
@@ -14,6 +14,7 @@ use serde_json;
1414
use serde_urlencoded;
1515

1616
use ::body::{self, Body};
17+
use ::redirect::{RedirectPolicy, check_redirect};
1718

1819
static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
1920

@@ -24,8 +25,9 @@ static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", e
2425
///
2526
/// The `Client` holds a connection pool internally, so it is advised that
2627
/// you create one and reuse it.
28+
#[derive(Clone)]
2729
pub struct Client {
28-
inner: ClientRef, //::hyper::Client,
30+
inner: Arc<ClientRef>, //::hyper::Client,
2931
}
3032

3133
impl Client {
@@ -34,12 +36,18 @@ impl Client {
3436
let mut client = try!(new_hyper_client());
3537
client.set_redirect_policy(::hyper::client::RedirectPolicy::FollowNone);
3638
Ok(Client {
37-
inner: ClientRef {
38-
hyper: Arc::new(client),
39-
}
39+
inner: Arc::new(ClientRef {
40+
hyper: client,
41+
redirect_policy: Mutex::new(RedirectPolicy::default()),
42+
}),
4043
})
4144
}
4245

46+
/// Set a `RedirectPolicy` for this client.
47+
pub fn redirect(&mut self, policy: RedirectPolicy) {
48+
*self.inner.redirect_policy.lock().unwrap() = policy;
49+
}
50+
4351
/// Convenience method to make a `GET` request to a URL.
4452
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder {
4553
self.request(Method::Get, url)
@@ -75,13 +83,15 @@ impl Client {
7583

7684
impl fmt::Debug for Client {
7785
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
78-
f.pad("Client")
86+
f.debug_struct("Client")
87+
.field("redirect_policy", &self.inner.redirect_policy)
88+
.finish()
7989
}
8090
}
8191

82-
#[derive(Clone)]
8392
struct ClientRef {
84-
hyper: Arc<::hyper::Client>,
93+
hyper: ::hyper::Client,
94+
redirect_policy: Mutex<RedirectPolicy>,
8595
}
8696

8797
fn new_hyper_client() -> ::Result<::hyper::Client> {
@@ -97,7 +107,7 @@ fn new_hyper_client() -> ::Result<::hyper::Client> {
97107

98108
/// A builder to construct the properties of a `Request`.
99109
pub struct RequestBuilder {
100-
client: ClientRef,
110+
client: Arc<ClientRef>,
101111

102112
method: Method,
103113
url: Result<Url, ::UrlError>,
@@ -196,7 +206,7 @@ impl RequestBuilder {
196206
None => None,
197207
};
198208

199-
let mut redirect_count = 0;
209+
let mut urls = Vec::new();
200210

201211
loop {
202212
let res = {
@@ -237,14 +247,6 @@ impl RequestBuilder {
237247
};
238248

239249
if should_redirect {
240-
//TODO: turn this into self.redirect_policy.check()
241-
if redirect_count > 10 {
242-
return Err(::Error::TooManyRedirects);
243-
}
244-
redirect_count += 1;
245-
246-
headers.set(Referer(url.to_string()));
247-
248250
let loc = {
249251
let loc = res.headers.get::<Location>().map(|loc| url.join(loc));
250252
if let Some(loc) = loc {
@@ -257,7 +259,18 @@ impl RequestBuilder {
257259
};
258260

259261
url = match loc {
260-
Ok(u) => u,
262+
Ok(loc) => {
263+
headers.set(Referer(url.to_string()));
264+
urls.push(url);
265+
if check_redirect(&client.redirect_policy.lock().unwrap(), &loc, &urls)? {
266+
loc
267+
} else {
268+
debug!("redirect_policy disallowed redirection to '{}'", loc);
269+
return Ok(Response {
270+
inner: res
271+
})
272+
}
273+
},
261274
Err(e) => {
262275
debug!("Location header had invalid URI: {:?}", e);
263276
return Ok(Response {

src/error.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ pub enum Error {
1313
Serialize(Box<StdError + Send + Sync>),
1414
/// A request tried to redirect too many times.
1515
TooManyRedirects,
16+
/// An infinite redirect loop was detected.
17+
RedirectLoop,
1618
#[doc(hidden)]
1719
__DontMatchMe,
1820
}
@@ -22,9 +24,8 @@ impl fmt::Display for Error {
2224
match *self {
2325
Error::Http(ref e) => fmt::Display::fmt(e, f),
2426
Error::Serialize(ref e) => fmt::Display::fmt(e, f),
25-
Error::TooManyRedirects => {
26-
f.pad("Too many redirects")
27-
},
27+
Error::TooManyRedirects => f.pad("Too many redirects"),
28+
Error::RedirectLoop => f.pad("Infinite redirect loop"),
2829
Error::__DontMatchMe => unreachable!()
2930
}
3031
}
@@ -36,6 +37,7 @@ impl StdError for Error {
3637
Error::Http(ref e) => e.description(),
3738
Error::Serialize(ref e) => e.description(),
3839
Error::TooManyRedirects => "Too many redirects",
40+
Error::RedirectLoop => "Infinite redirect loop",
3941
Error::__DontMatchMe => unreachable!()
4042
}
4143
}
@@ -45,6 +47,7 @@ impl StdError for Error {
4547
Error::Http(ref e) => Some(e),
4648
Error::Serialize(ref e) => Some(&**e),
4749
Error::TooManyRedirects => None,
50+
Error::RedirectLoop => None,
4851
Error::__DontMatchMe => unreachable!()
4952
}
5053
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,12 @@ pub use url::ParseError as UrlError;
108108
pub use self::client::{Client, Response, RequestBuilder};
109109
pub use self::error::{Error, Result};
110110
pub use self::body::Body;
111+
pub use self::redirect::RedirectPolicy;
111112

112113
mod body;
113114
mod client;
114115
mod error;
116+
mod redirect;
115117
mod tls;
116118

117119

src/redirect.rs

Lines changed: 157 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,163 @@
1+
use std::fmt;
2+
3+
use ::Url;
4+
5+
/// A type that controls the policy on how to handle the following of redirects.
6+
///
7+
/// The default value will catch redirect loops, and has a maximum of 10
8+
/// redirects it will follow in a chain before returning an error.
19
#[derive(Debug)]
210
pub struct RedirectPolicy {
3-
inner: ()
11+
inner: Policy,
412
}
513

614
impl RedirectPolicy {
7-
15+
/// Create a RedirectPolicy with a maximum number of redirects.
16+
///
17+
/// A `Error::TooManyRedirects` will be returned if the max is reached.
18+
pub fn limited(max: usize) -> RedirectPolicy {
19+
RedirectPolicy {
20+
inner: Policy::Limit(max),
21+
}
22+
}
23+
24+
/// Create a RedirectPolicy that does not follow any redirect.
25+
pub fn none() -> RedirectPolicy {
26+
RedirectPolicy {
27+
inner: Policy::None,
28+
}
29+
}
30+
31+
/// Create a custom RedirectPolicy using the passed function.
32+
///
33+
/// # Note
34+
///
35+
/// The default RedirectPolicy handles redirect loops and a maximum loop
36+
/// chain, but the custom variant does not do that for you automatically.
37+
/// The custom policy should hanve some way of handling those.
38+
///
39+
/// There are variants on `::Error` for both cases that can be used as
40+
/// return values.
41+
///
42+
/// # Example
43+
///
44+
/// ```no_run
45+
/// # use reqwest::RedirectPolicy;
46+
/// # let mut client = reqwest::Client::new().unwrap();
47+
/// client.redirect(RedirectPolicy::custom(|next, previous| {
48+
/// if previous.len() > 5 {
49+
/// Err(reqwest::Error::TooManyRedirects)
50+
/// } else if next.host_str() == Some("example.domain") {
51+
/// // prevent redirects to 'example.domain'
52+
/// Ok(false)
53+
/// } else {
54+
/// Ok(true)
55+
/// }
56+
/// }));
57+
/// ```
58+
pub fn custom<T>(policy: T) -> RedirectPolicy
59+
where T: Fn(&Url, &[Url]) -> ::Result<bool> + Send + Sync + 'static {
60+
RedirectPolicy {
61+
inner: Policy::Custom(Box::new(policy)),
62+
}
63+
}
64+
65+
fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result<bool> {
66+
match self.inner {
67+
Policy::Custom(ref custom) => custom(next, previous),
68+
Policy::Limit(max) => {
69+
if previous.len() == max {
70+
Err(::Error::TooManyRedirects)
71+
} else if previous.contains(next) {
72+
Err(::Error::RedirectLoop)
73+
} else {
74+
Ok(true)
75+
}
76+
},
77+
Policy::None => Ok(false),
78+
}
79+
}
80+
}
81+
82+
impl Default for RedirectPolicy {
83+
fn default() -> RedirectPolicy {
84+
RedirectPolicy::limited(10)
85+
}
86+
}
87+
88+
enum Policy {
89+
Custom(Box<Fn(&Url, &[Url]) -> ::Result<bool> + Send + Sync + 'static>),
90+
Limit(usize),
91+
None,
92+
}
93+
94+
impl fmt::Debug for Policy {
95+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
96+
match *self {
97+
Policy::Custom(..) => f.pad("Custom"),
98+
Policy::Limit(max) => f.debug_tuple("Limit").field(&max).finish(),
99+
Policy::None => f.pad("None"),
100+
}
101+
}
102+
}
103+
104+
pub fn check_redirect(policy: &RedirectPolicy, next: &Url, previous: &[Url]) -> ::Result<bool> {
105+
policy.redirect(next, previous)
106+
}
107+
108+
/*
109+
This was the desired way of doing it, but ran in to inference issues when
110+
using closures, since the arguments received are references (&Url and &[Url]),
111+
and the compiler could not infer the lifetimes of those references. That means
112+
people would need to annotate the closure's argument types, which is garbase.
113+
114+
pub trait Redirect {
115+
fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result<bool>;
116+
}
117+
118+
impl<F> Redirect for F
119+
where F: Fn(&Url, &[Url]) -> ::Result<bool> {
120+
fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result<bool> {
121+
self(next, previous)
122+
}
123+
}
124+
*/
125+
126+
#[test]
127+
fn test_redirect_policy_limit() {
128+
let policy = RedirectPolicy::default();
129+
let next = Url::parse("http://x.y/z").unwrap();
130+
let mut previous = (0..9)
131+
.map(|i| Url::parse(&format!("http://a.b/c/{}", i)).unwrap())
132+
.collect::<Vec<_>>();
133+
134+
135+
match policy.redirect(&next, &previous) {
136+
Ok(true) => {},
137+
other => panic!("expected Ok(true), got: {:?}", other)
138+
}
139+
140+
previous.push(Url::parse("http://a.b.d/e/33").unwrap());
141+
142+
match policy.redirect(&next, &previous) {
143+
Err(::Error::TooManyRedirects) => {},
144+
other => panic!("expected TooManyRedirects, got: {:?}", other)
145+
}
146+
}
147+
148+
#[test]
149+
fn test_redirect_policy_custom() {
150+
let policy = RedirectPolicy::custom(|next, _previous| {
151+
if next.host_str() == Some("foo") {
152+
Ok(false)
153+
} else {
154+
Ok(true)
155+
}
156+
});
157+
158+
let next = Url::parse("http://bar/baz").unwrap();
159+
assert_eq!(policy.redirect(&next, &[]).unwrap(), true);
160+
161+
let next = Url::parse("http://foo/baz").unwrap();
162+
assert_eq!(policy.redirect(&next, &[]).unwrap(), false);
8163
}

tests/client.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,56 @@ fn test_redirect_307_does_not_try_if_reader_cannot_reset() {
160160
assert_eq!(res.status(), &reqwest::StatusCode::from_u16(code));
161161
}
162162
}
163+
164+
#[test]
165+
fn test_redirect_policy_can_return_errors() {
166+
let server = server! {
167+
request: b"\
168+
GET /loop HTTP/1.1\r\n\
169+
Host: $HOST\r\n\
170+
User-Agent: $USERAGENT\r\n\
171+
\r\n\
172+
",
173+
response: b"\
174+
HTTP/1.1 302 Found\r\n\
175+
Server: test\r\n\
176+
Location: /loop
177+
Content-Length: 0\r\n\
178+
\r\n\
179+
"
180+
};
181+
182+
let err = reqwest::get(&format!("http://{}/loop", server.addr())).unwrap_err();
183+
match err {
184+
reqwest::Error::RedirectLoop => (),
185+
e => panic!("wrong error received: {:?}", e),
186+
}
187+
}
188+
189+
#[test]
190+
fn test_redirect_policy_can_stop_redirects_without_an_error() {
191+
let server = server! {
192+
request: b"\
193+
GET /no-redirect HTTP/1.1\r\n\
194+
Host: $HOST\r\n\
195+
User-Agent: $USERAGENT\r\n\
196+
\r\n\
197+
",
198+
response: b"\
199+
HTTP/1.1 302 Found\r\n\
200+
Server: test-dont\r\n\
201+
Location: /dont
202+
Content-Length: 0\r\n\
203+
\r\n\
204+
"
205+
};
206+
let mut client = reqwest::Client::new().unwrap();
207+
client.redirect(reqwest::RedirectPolicy::none());
208+
209+
let res = client.get(&format!("http://{}/no-redirect", server.addr()))
210+
.send()
211+
.unwrap();
212+
213+
assert_eq!(res.status(), &reqwest::StatusCode::Found);
214+
assert_eq!(res.headers().get(), Some(&reqwest::header::Server("test-dont".to_string())));
215+
}

0 commit comments

Comments
 (0)