diff --git a/conformance/baseline.yml b/conformance/baseline.yml index f0e15132..ae1f9c63 100644 --- a/conformance/baseline.yml +++ b/conformance/baseline.yml @@ -1,5 +1,4 @@ -server: -- dns-rebinding-protection +server: [] # All tests pass! client: - auth/basic-cimd - auth/metadata-default diff --git a/internal/util/net.go b/internal/util/net.go new file mode 100644 index 00000000..6858614e --- /dev/null +++ b/internal/util/net.go @@ -0,0 +1,26 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. +package util + +import ( + "net" + "net/netip" + "strings" +) + +func IsLoopback(addr string) bool { + host, _, err := net.SplitHostPort(addr) + if err != nil { + // If SplitHostPort fails, it might be just a host without a port. + host = strings.Trim(addr, "[]") + } + if host == "localhost" { + return true + } + ip, err := netip.ParseAddr(host) + if err != nil { + return false + } + return ip.IsLoopback() +} diff --git a/internal/util/net_test.go b/internal/util/net_test.go new file mode 100644 index 00000000..fd3187ba --- /dev/null +++ b/internal/util/net_test.go @@ -0,0 +1,35 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. +package util + +import "testing" + +// TestIsLoopback tests the IsLoopback helper function. +func TestIsLoopback(t *testing.T) { + tests := []struct { + addr string + want bool + }{ + {"localhost", true}, + {"localhost:3000", true}, + {"127.0.0.1", true}, + {"127.0.0.1:3000", true}, + {"[::1]", true}, + {"[::1]:3000", true}, + {"::1", true}, + {"", false}, + {"evil.com", false}, + {"evil.com:80", false}, + {"localhost.evil.com", false}, + {"127.0.0.1.evil.com", false}, + } + + for _, tt := range tests { + t.Run(tt.addr, func(t *testing.T) { + if got := IsLoopback(tt.addr); got != tt.want { + t.Errorf("IsLoopback(%q) = %v, want %v", tt.addr, got, tt.want) + } + }) + } +} diff --git a/mcp/streamable.go b/mcp/streamable.go index c6b96bb0..c06877f5 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -20,6 +20,7 @@ import ( "maps" "math" "math/rand/v2" + "net" "net/http" "slices" "strconv" @@ -30,6 +31,8 @@ import ( "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/internal/mcpgodebug" + "github.com/modelcontextprotocol/go-sdk/internal/util" "github.com/modelcontextprotocol/go-sdk/internal/xcontext" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -161,6 +164,16 @@ type StreamableHTTPOptions struct { // // If SessionTimeout is the zero value, idle sessions are never closed. SessionTimeout time.Duration + + // DisableLocalhostProtection disables automatic DNS rebinding protection. + // By default, requests arriving via a localhost address (127.0.0.1, [::1]) + // that have a non-localhost Host header are rejected with 403 Forbidden. + // This protects against DNS rebinding attacks regardless of whether the + // server is listening on localhost specifically or on 0.0.0.0. + // + // Only disable this if you understand the security implications. + // See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise + DisableLocalhostProtection bool } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -207,7 +220,24 @@ func (h *StreamableHTTPHandler) closeAll() { } } +// disablelocalhostprotection is a compatibility parameter that allows to disable +// DNS rebinding protection, which was added in the 1.4.0 version of the SDK. +// See the documentation for the mcpgodebug package for instructions how to enable it. +// The option will be removed in the 1.6.0 version of the SDK. +var disablelocalhostprotection = mcpgodebug.Value("disablelocalhostprotection") + func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // DNS rebinding protection: auto-enabled for localhost servers. + // See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise + if !h.opts.DisableLocalhostProtection && disablelocalhostprotection != "1" { + if localAddr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok && localAddr != nil { + if util.IsLoopback(localAddr.String()) && !util.IsLoopback(req.Host) { + http.Error(w, fmt.Sprintf("Forbidden: invalid Host header %q", req.Host), http.StatusForbidden) + return + } + } + } + // Allow multiple 'Accept' headers. // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept#syntax accept := strings.Split(strings.Join(req.Header.Values("Accept"), ","), ",") diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index f1f6200f..2cbe4002 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -2374,3 +2374,114 @@ func Test_ExportErrSessionMissing(t *testing.T) { t.Errorf("expected error to wrap ErrSessionMissing, got: %v", err) } } + +// TestStreamableLocalhostProtection verifies that DNS rebinding protection +// is automatically enabled for localhost servers. +func TestStreamableLocalhostProtection(t *testing.T) { + server := NewServer(testImpl, nil) + + tests := []struct { + name string + listenAddr string // Address to listen on + hostHeader string // Host header in request + disableProtection bool // DisableLocalhostProtection setting + wantStatus int + }{ + // Auto-enabled for localhost listeners (127.0.0.1). + { + name: "127.0.0.1 accepts 127.0.0.1", + listenAddr: "127.0.0.1:0", + hostHeader: "127.0.0.1:1234", + disableProtection: false, + wantStatus: http.StatusOK, + }, + { + name: "127.0.0.1 accepts localhost", + listenAddr: "127.0.0.1:0", + hostHeader: "localhost:1234", + disableProtection: false, + wantStatus: http.StatusOK, + }, + { + name: "127.0.0.1 rejects evil.com", + listenAddr: "127.0.0.1:0", + hostHeader: "evil.com", + disableProtection: false, + wantStatus: http.StatusForbidden, + }, + { + name: "127.0.0.1 rejects evil.com:80", + listenAddr: "127.0.0.1:0", + hostHeader: "evil.com:80", + disableProtection: false, + wantStatus: http.StatusForbidden, + }, + { + name: "127.0.0.1 rejects localhost.evil.com", + listenAddr: "127.0.0.1:0", + hostHeader: "localhost.evil.com", + disableProtection: false, + wantStatus: http.StatusForbidden, + }, + + // When listening on 0.0.0.0, requests arriving via localhost are still protected + // because LocalAddrContextKey returns the actual connection's local address. + // This is actually more secure - DNS rebinding attacks target localhost regardless + // of the listener configuration. + { + name: "0.0.0.0 via localhost rejects evil.com", + listenAddr: "0.0.0.0:0", + hostHeader: "evil.com", + disableProtection: false, + wantStatus: http.StatusForbidden, + }, + + // Explicit disable + { + name: "disabled accepts evil.com", + listenAddr: "127.0.0.1:0", + hostHeader: "evil.com", + disableProtection: true, + wantStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := &StreamableHTTPOptions{ + Stateless: true, // Simpler for testing + DisableLocalhostProtection: tt.disableProtection, + } + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, opts) + + listener, err := net.Listen("tcp", tt.listenAddr) + if err != nil { + t.Fatalf("Failed to listen on %s: %v", tt.listenAddr, err) + } + defer listener.Close() + + srv := &http.Server{Handler: handler} + go srv.Serve(listener) + defer srv.Close() + + reqReader := strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}`) + req, err := http.NewRequest("POST", fmt.Sprintf("http://%s", listener.Addr().String()), reqReader) + if err != nil { + t.Fatal(err) + } + req.Host = tt.hostHeader + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if got := resp.StatusCode; got != tt.wantStatus { + t.Errorf("Status code: got %d, want %d", got, tt.wantStatus) + } + }) + } +}