Skip to content

Commit 309c58e

Browse files
authored
Merge pull request #207 from hashicorp/sebasslash/handle-go-away
Sets request's GetBody field on create wrapper
2 parents 571a88b + f95735f commit 309c58e

3 files changed

Lines changed: 82 additions & 7 deletions

File tree

.github/workflows/go-retryablehttp.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jobs:
1010
- name: Setup go
1111
uses: actions/setup-go@4d34df0c2316fe8122ab82dc22947d607c0c91f9 # v4.0.0
1212
with:
13-
go-version: 1.14.2
13+
go-version: 1.18
1414
- uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0
1515
- run: mkdir -p "$TEST_RESULTS"/go-retryablyhttp
1616
- name: restore_cache
@@ -20,6 +20,7 @@ jobs:
2020
restore-keys: go-mod-v1-{{ checksum "go.sum" }}
2121
path: "/go/pkg/mod"
2222
- run: go mod download
23+
- run: go mod tidy
2324
- name: Run go format
2425
run: |-
2526
files=$(go fmt ./...)
@@ -29,7 +30,7 @@ jobs:
2930
exit 1
3031
fi
3132
- name: Install gotestsum
32-
run: go get gotest.tools/gotestsum
33+
run: go install gotest.tools/gotestsum@latest
3334
- name: Run unit tests
3435
run: |-
3536
PACKAGE_NAMES=$(go list ./...)

client.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,20 @@ func (r *Request) SetBody(rawBody interface{}) error {
160160
}
161161
r.body = bodyReader
162162
r.ContentLength = contentLength
163+
if bodyReader != nil {
164+
r.GetBody = func() (io.ReadCloser, error) {
165+
body, err := bodyReader()
166+
if err != nil {
167+
return nil, err
168+
}
169+
if rc, ok := body.(io.ReadCloser); ok {
170+
return rc, nil
171+
}
172+
return io.NopCloser(body), nil
173+
}
174+
} else {
175+
r.GetBody = func() (io.ReadCloser, error) { return http.NoBody, nil }
176+
}
163177
return nil
164178
}
165179

@@ -302,18 +316,19 @@ func NewRequest(method, url string, rawBody interface{}) (*Request, error) {
302316
// The context controls the entire lifetime of a request and its response:
303317
// obtaining a connection, sending the request, and reading the response headers and body.
304318
func NewRequestWithContext(ctx context.Context, method, url string, rawBody interface{}) (*Request, error) {
305-
bodyReader, contentLength, err := getBodyReaderAndContentLength(rawBody)
319+
httpReq, err := http.NewRequestWithContext(ctx, method, url, nil)
306320
if err != nil {
307321
return nil, err
308322
}
309323

310-
httpReq, err := http.NewRequestWithContext(ctx, method, url, nil)
311-
if err != nil {
324+
req := &Request{
325+
Request: httpReq,
326+
}
327+
if err := req.SetBody(rawBody); err != nil {
312328
return nil, err
313329
}
314-
httpReq.ContentLength = contentLength
315330

316-
return &Request{body: bodyReader, Request: httpReq}, nil
331+
return req, nil
317332
}
318333

319334
// Logger interface allows to use other loggers than

client_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,3 +978,62 @@ func TestClient_StandardClient(t *testing.T) {
978978
t.Fatalf("expected %v, got %v", client, v)
979979
}
980980
}
981+
982+
func TestClient_RedirectWithBody(t *testing.T) {
983+
var redirects int32
984+
// Mock server which always responds 200.
985+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
986+
switch r.RequestURI {
987+
case "/redirect":
988+
w.Header().Set("Location", "/target")
989+
w.WriteHeader(http.StatusTemporaryRedirect)
990+
case "/target":
991+
atomic.AddInt32(&redirects, 1)
992+
w.WriteHeader(http.StatusCreated)
993+
default:
994+
t.Fatalf("bad uri: %s", r.RequestURI)
995+
}
996+
}))
997+
defer ts.Close()
998+
999+
client := NewClient()
1000+
client.RequestLogHook = func(logger Logger, req *http.Request, retryNumber int) {
1001+
if _, err := req.GetBody(); err != nil {
1002+
t.Fatalf("unexpected error with GetBody: %v", err)
1003+
}
1004+
}
1005+
// create a request with a body
1006+
req, err := NewRequest(http.MethodPost, ts.URL+"/redirect", strings.NewReader(`{"foo":"bar"}`))
1007+
if err != nil {
1008+
t.Fatalf("err: %v", err)
1009+
}
1010+
1011+
resp, err := client.Do(req)
1012+
if err != nil {
1013+
t.Fatalf("err: %v", err)
1014+
}
1015+
resp.Body.Close()
1016+
1017+
if resp.StatusCode != http.StatusCreated {
1018+
t.Fatalf("expected status code 201, got: %d", resp.StatusCode)
1019+
}
1020+
1021+
// now one without a body
1022+
if err := req.SetBody(nil); err != nil {
1023+
t.Fatalf("err: %v", err)
1024+
}
1025+
1026+
resp, err = client.Do(req)
1027+
if err != nil {
1028+
t.Fatalf("err: %v", err)
1029+
}
1030+
resp.Body.Close()
1031+
1032+
if resp.StatusCode != http.StatusCreated {
1033+
t.Fatalf("expected status code 201, got: %d", resp.StatusCode)
1034+
}
1035+
1036+
if atomic.LoadInt32(&redirects) != 2 {
1037+
t.Fatalf("Expected the client to be redirected 2 times, got: %d", atomic.LoadInt32(&redirects))
1038+
}
1039+
}

0 commit comments

Comments
 (0)