-
Notifications
You must be signed in to change notification settings - Fork 71
Expand file tree
/
Copy pathnegotiator.go
More file actions
331 lines (303 loc) · 9.65 KB
/
negotiator.go
File metadata and controls
331 lines (303 loc) · 9.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
package ntlmssp
import (
"bytes"
"encoding/base64"
"io"
"net/http"
"strings"
)
// negotiatorBody wraps an io.ReadSeeker so that a request body can be sent
// more than once. It lets the negotiator wait until the body has been
// released (signalled through Close) before rewinding and reusing it.
type negotiatorBody struct {
	// body is the rewindable source of the request payload.
	body io.ReadSeeker
	// closed receives a token when Close is called; rewind consumes it.
	// It is created with capacity 1 so that Close never blocks.
	closed chan struct{}
	// startPos is the offset body was at when it was wrapped; rewind seeks
	// back to it.
	startPos int64
}

// negotiatorBody is installed as a request body, so it must be an io.ReadCloser.
var _ io.ReadCloser = (*negotiatorBody)(nil)

// newNegotiatorBody creates a negotiatorBody from the provided io.Reader.
//
//   - If body is nil, it returns (nil, nil). All negotiatorBody methods accept
//     a nil receiver.
//   - If body is an io.ReadSeeker whose current position can be queried, it is
//     used directly and its current position becomes the rewind target.
//   - Otherwise the whole body is read into memory to allow rewinding. Once it
//     has been fully buffered, the original body is closed if it implements
//     io.Closer, because nothing else holds a reference to it afterwards.
//
// If buffering fails, the read error is returned and body is NOT closed; the
// caller keeps ownership of it.
func newNegotiatorBody(body io.Reader) (*negotiatorBody, error) {
	if body == nil {
		return nil, nil
	}

	// Check if body is already seekable to avoid buffering large bodies.
	if seeker, ok := body.(io.ReadSeeker); ok {
		// Remember the current position.
		if startPos, err := seeker.Seek(0, io.SeekCurrent); err == nil {
			// Seeking succeeded, use the seekable body directly.
			return &negotiatorBody{
				body:     seeker,
				closed:   make(chan struct{}, 1),
				startPos: startPos,
			}, nil
		}
		// Seeking failed (e.g., pipes), fall back to buffering.
	}

	// For non-seekable bodies, buffer in memory.
	data, err := io.ReadAll(body)
	if err != nil {
		return nil, err
	}

	// The original body has been fully consumed and is replaced below by an
	// in-memory copy, so close() will never reach it: release it here.
	// A Close error is ignored because the data was already read successfully.
	if closer, ok := body.(io.Closer); ok {
		_ = closer.Close()
	}

	return &negotiatorBody{
		body:   bytes.NewReader(data),
		closed: make(chan struct{}, 1),
	}, nil
}

// Read reads from the wrapped body. A nil receiver behaves like an empty
// body and returns io.EOF.
func (b *negotiatorBody) Read(p []byte) (n int, err error) {
	if b == nil {
		return 0, io.EOF
	}
	return b.body.Read(p)
}

// Close signals that the body is no longer needed for the current request.
// It allows the negotiator to rewind the body for potential reuse.
// The underlying body is not closed here; use close() for that.
// Close never blocks and always returns nil.
func (b *negotiatorBody) Close() error {
	if b == nil {
		return nil
	}
	select {
	case b.closed <- struct{}{}:
	default:
		// Already signaled.
	}
	return nil
}

// close closes the underlying body if it implements io.Closer.
// Any error from the underlying Close is discarded.
func (b *negotiatorBody) close() {
	if b == nil {
		return
	}
	if closer, ok := b.body.(io.Closer); ok {
		_ = closer.Close()
	}
}

// rewind seeks the body back to its start position so it can be sent again.
//
// It blocks until Close has been called since the last rewind; there is no
// timeout, so it must only be called after the body has been handed to
// something that will close it. A nil receiver is a no-op.
func (b *negotiatorBody) rewind() error {
	if b == nil {
		return nil
	}
	// Wait for the body to be closed before rewinding.
	<-b.closed
	_, err := b.body.Seek(b.startPos, io.SeekStart)
	return err
}
// GetDomain splits a username into its user and domain parts.
//
// A name containing a backslash is split at the first one: the text before it
// is returned as domain, the text after it as user, and domainNeeded is true.
// A name without a backslash is returned unchanged as user with an empty
// domain; domainNeeded is false if the name contains "@" and true otherwise.
//
// Deprecated: Pass the username directly to [ProcessChallenge], it will handle domain extraction.
// Don't pass the resulting domain to [NewNegotiateMessage], that function expects the client
// machine domain, not the user domain.
func GetDomain(username string) (user string, domain string, domainNeeded bool) {
	if sep := strings.Index(username, "\\"); sep >= 0 {
		// DOMAIN\user form: split on the first backslash only.
		return username[sep+1:], username[:sep], true
	}
	// No backslash: the whole string is the user name, and only names
	// without an "@" still need a domain.
	return username, "", !strings.Contains(username, "@")
}
// Negotiator decorates a [net/http.RoundTripper] and transparently upgrades
// basic authentication to NTLM/Negotiate authentication when appropriate.
//
// Credentials are supplied per request through [net/http.Request.SetBasicAuth].
//
// Unless [Negotiator.AllowBasicAuth] is true and the server asks for Basic
// authentication, the credentials are never sent to the server as-is.
type Negotiator struct {
	// RoundTripper is the wrapped round tripper.
	// http.DefaultTransport is used when it is nil.
	http.RoundTripper

	// AllowBasicAuth decides what happens when the server asks for Basic
	// authentication.
	//
	// When false (the default), such requests are ignored and only
	// NTLM/Negotiate authentication is performed.
	// When true, Basic authentication requests are honored.
	//
	// Enable this only for servers you trust: Basic authentication sends the
	// credentials in clear text and may be vulnerable to man-in-the-middle
	// attacks and compromised servers.
	AllowBasicAuth bool

	// WorkstationDomain is the domain of the client machine.
	// It is passed to the negotiate message and normally does not need to
	// be set.
	WorkstationDomain string

	// WorkstationName is the workstation name of the client machine.
	// It is passed to the negotiate and authenticate messages, which is
	// useful for auditing purposes on the server side.
	WorkstationName string
}
// RoundTrip sends the request to the server, handling any authentication
// re-sends as needed.
//
// Requests without basic-auth credentials are forwarded untouched. Otherwise
// the request is cloned, its body is made rewindable, and it is first sent
// without an Authorization header. If the server answers 401, the request is
// retried with Basic credentials (only when AllowBasicAuth is set and the
// server offers Basic) and/or an NTLM/Negotiate handshake is performed.
//
// An error is returned only if the body cannot be prepared or the first,
// anonymous round trip fails. If a later step fails, the most recent 401
// response that started that step is returned with its body still readable,
// so the caller can do its own handling. Every response that is superseded
// by a later one is drained and closed.
func (l Negotiator) RoundTrip(req *http.Request) (*http.Response, error) {
	// Use default round tripper if not provided.
	rt := l.RoundTripper
	if rt == nil {
		rt = http.DefaultTransport
	}

	// If it is not basic auth, just round trip the request as usual.
	username, password, ok := req.BasicAuth()
	if !ok {
		return rt.RoundTrip(req)
	}
	id := identity{
		username: username,
		password: password,
	}

	req = req.Clone(req.Context()) // Clone the request to avoid modifying the original.

	// We need to buffer or seek the request body to handle authentication challenges
	// that require resending the body multiple times during the NTLM handshake.
	body, err := newNegotiatorBody(req.Body)
	if err != nil {
		if req.Body != nil {
			_ = req.Body.Close()
		}
		return nil, err
	}
	defer body.close()

	// Only install the wrapper when there is a body. Storing a nil
	// *negotiatorBody in the io.ReadCloser field would make req.Body a
	// non-nil interface, so the underlying round tripper would no longer
	// see the request as bodyless.
	if body != nil {
		req.Body = body
	}

	// First try anonymous, in case the server still finds us authenticated from previous traffic.
	req.Header.Del("Authorization")
	resp, err := rt.RoundTrip(req)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusUnauthorized {
		// No authentication required, return the response as is.
		return resp, nil
	}

	// Note that from here on, the response returned in case of error or unsuccessful
	// negotiation is the one we just got from the server. This is to allow the caller
	// to do its own handling in case we can't do it in this roundtrip.
	originalResp := resp

	resauth := newAuthHeader(resp.Header)
	if l.AllowBasicAuth && resauth.isBasic() {
		// Basic auth requested instead of NTLM/Negotiate.
		//
		// Rewind the body, we will resend it.
		if body.rewind() != nil {
			return originalResp, nil
		}
		req.SetBasicAuth(id.username, id.password)
		basicResp, err := rt.RoundTrip(req)
		if err != nil {
			return originalResp, nil
		}

		// The server answered the Basic attempt, so its answer supersedes the
		// anonymous 401 on every path below: release the old response and
		// keep the new one readable for the caller.
		drainResponse(originalResp)
		originalResp = basicResp

		if basicResp.StatusCode != http.StatusUnauthorized {
			// Basic auth succeeded, return the new response.
			return basicResp, nil
		}
		resauth = newAuthHeader(basicResp.Header)
		if !resauth.isNTLM() {
			// No NTLM/Negotiate requested, return the response as is.
			return basicResp, nil
		}
		// Server upgraded from Basic to NTLM/Negotiate (rare but possible);
		// originalResp is now the NTLM-triggering response.
	} else if !resauth.isNTLM() {
		// No NTLM/Negotiate requested, return the response as is.
		return originalResp, nil
	}

	// Server requested Negotiate/NTLM, start handshake.

	// First step: send negotiate message.
	resp = clientHandshake(rt, req, resauth.schema, l.WorkstationDomain, l.WorkstationName)
	if resp == nil {
		return originalResp, nil
	}
	if resp.StatusCode != http.StatusUnauthorized {
		// We are expecting a 401 with challenge, but the server responded differently,
		// maybe it even accepted our negotiate message without further challenge, which is
		// valid per the spec (RFC 4559 Section 5).
		// Return the response as is, negotiation is over.
		drainResponse(originalResp)
		return resp, nil
	}
	resauth = newAuthHeader(resp.Header)
	drainResponse(resp)

	// Second step: process challenge and resend the original body with the authenticate message.
	resp = completeHandshake(rt, resauth, req, id, l.WorkstationName)
	if resp == nil {
		return originalResp, nil
	}

	// We could return the original response in case of 401 again, but at this point
	// it's better to return the latest response from the server, as it might be the case
	// that we are really not authorized.
	drainResponse(originalResp) // Done with the original response.
	return resp, nil
}
// identity holds the username and password taken from a request's
// basic-auth credentials.
type identity struct {
	username, password string
}
// drainResponse reads whatever is left of res.Body, discards it, and then
// closes the body so the connection can be reused. Read and close errors
// are ignored.
func drainResponse(res *http.Response) {
	body := res.Body
	defer func() { _ = body.Close() }()
	_, _ = io.Copy(io.Discard, body)
}
// rewindBody rewinds the request body when it is a *negotiatorBody and
// returns the result of that rewind. Requests with no body, or with any
// other body type, are left untouched and nil is returned.
func rewindBody(req *http.Request) error {
	// The type assertion also fails for a nil req.Body, so no separate
	// nil check is needed.
	nb, ok := req.Body.(*negotiatorBody)
	if !ok {
		return nil
	}
	return nb.rewind()
}
// clientHandshake performs the first step of the handshake: it rewinds the
// request body, obtains a negotiate message from NewNegotiateMessage for the
// given domain and workstation, sets it base64-encoded in the Authorization
// header under the given schema, and sends the request through rt.
//
// It returns the server's response, or nil if rewinding the body, building
// the message, or the round trip failed.
func clientHandshake(rt http.RoundTripper, req *http.Request, schema string, domain, workstation string) *http.Response {
	if err := rewindBody(req); err != nil {
		return nil
	}

	negotiate, err := NewNegotiateMessage(domain, workstation)
	if err != nil {
		return nil
	}

	token := base64.StdEncoding.EncodeToString(negotiate)
	req.Header.Set("Authorization", schema+" "+token)

	if resp, err := rt.RoundTrip(req); err == nil {
		return resp
	}
	return nil
}
// completeHandshake performs the second step of the handshake: it rewinds the
// request body, takes the challenge token from resauth, obtains an
// authenticate message from NewAuthenticateMessage for the given identity,
// sets it base64-encoded in the Authorization header under resauth's schema,
// and sends the request through rt.
//
// The workstation name is passed on through AuthenticateMessageOptions only
// when it is non-empty.
//
// It returns the server's response, or nil if the body cannot be rewound,
// the token cannot be read, resauth is not NTLM/Negotiate, the challenge is
// empty, the message cannot be built, or the round trip fails.
func completeHandshake(rt http.RoundTripper, resauth authheader, req *http.Request, id identity, workstation string) *http.Response {
	// The body is rewound first, before the challenge is inspected.
	if err := rewindBody(req); err != nil {
		return nil
	}

	challenge, err := resauth.token()
	switch {
	case err != nil:
		return nil
	case !resauth.isNTLM(), len(challenge) == 0:
		// The only expected schema here is NTLM/Negotiate with a challenge token,
		// otherwise the negotiation is over.
		return nil
	}

	var opts *AuthenticateMessageOptions
	if workstation != "" {
		opts = &AuthenticateMessageOptions{WorkstationName: workstation}
	}

	authenticate, err := NewAuthenticateMessage(challenge, id.username, id.password, opts)
	if err != nil {
		return nil
	}

	token := base64.StdEncoding.EncodeToString(authenticate)
	req.Header.Set("Authorization", resauth.schema+" "+token)

	if resp, err := rt.RoundTrip(req); err == nil {
		return resp
	}
	return nil
}