Skip to content

Commit 67bd3f2

Browse files
mcp: add automatic DNS rebinding protection for localhost servers (#760)
## Summary Add DNS rebinding protection that is automatically enabled when requests arrive via localhost (127.0.0.1, [::1]). This protects against malicious websites using DNS rebinding to interact with local MCP servers. (note: this was a claude assisted PR, mostly wanted to see how difficult passing this test would be for this SDK) ## Design Goal: Secure by Default The primary goal is to make it difficult to run a localhost server without these protections by mistake. There are other approaches that could provide secure defaults (e.g., a helper for `ListenAndServe` that does the localhost check), but using `http.LocalAddrContextKey` for runtime detection seemed like the most backwards compatible approach and least likely to be disabled by accident. With this implementation: - **No code changes required** - existing servers get protection automatically - **No opt-in needed** - protection activates based on the connection's local address - **Explicit opt-out** - users must deliberately set `DisableLocalhostProtection: true` ## Changes - Add `DisableLocalhostProtection` option to `StreamableHTTPOptions` - Add `isLocalhostAddr` and `isLocalhostHost` helper functions - Validate Host header at start of `ServeHTTP`, rejecting non-localhost Host headers with 403 Forbidden ## How it works The protection uses `http.LocalAddrContextKey` to detect the connection's local address at runtime. When a request arrives via localhost (127.0.0.1 or [::1]), the handler validates that the Host header also matches a localhost value. If not, the request is rejected with 403 Forbidden. This approach means: - Protection is enabled for requests arriving via localhost, regardless of whether the server listens on `127.0.0.1` or `0.0.0.0` - Requests arriving via non-localhost IPs (e.g., external network requests) are **not** affected ### Edge case: Reverse proxies If a reverse proxy (e.g., Envoy, nginx) runs on the same host and forwards requests to the MCP server via localhost while preserving the original Host header, those requests would be rejected. In this case, users should either: 1. Set `DisableLocalhostProtection: true` 2. Configure the proxy to rewrite the Host header to localhost ## Testing - Added unit tests for `isLocalhostAddr` and `isLocalhostHost` helper functions - Added integration tests for the full protection flow - Verified against the MCP conformance test suite: - `localhost-host-rebinding-rejected`: PASS - `localhost-host-valid-accepted`: PASS ## Related - Spec: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise - TypeScript SDK implementation: `localhostHostValidation()` middleware - Conformance test: `dns-rebinding-protection` scenario --------- Co-authored-by: Maciek Kisiel <[email protected]>
1 parent c952ab0 commit 67bd3f2

File tree

5 files changed

+203
-2
lines changed

5 files changed

+203
-2
lines changed

conformance/baseline.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
server:
2-
- dns-rebinding-protection
1+
server: [] # All tests pass!
32
client:
43
- auth/basic-cimd
54
- auth/metadata-default

internal/util/net.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
2+
// Use of this source code is governed by the license
3+
// that can be found in the LICENSE file.
4+
package util
5+
6+
import (
7+
"net"
8+
"net/netip"
9+
"strings"
10+
)
11+
12+
func IsLoopback(addr string) bool {
13+
host, _, err := net.SplitHostPort(addr)
14+
if err != nil {
15+
// If SplitHostPort fails, it might be just a host without a port.
16+
host = strings.Trim(addr, "[]")
17+
}
18+
if host == "localhost" {
19+
return true
20+
}
21+
ip, err := netip.ParseAddr(host)
22+
if err != nil {
23+
return false
24+
}
25+
return ip.IsLoopback()
26+
}

internal/util/net_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
2+
// Use of this source code is governed by the license
3+
// that can be found in the LICENSE file.
4+
package util
5+
6+
import "testing"
7+
8+
// TestIsLoopback tests the IsLoopback helper function.
9+
func TestIsLoopback(t *testing.T) {
10+
tests := []struct {
11+
addr string
12+
want bool
13+
}{
14+
{"localhost", true},
15+
{"localhost:3000", true},
16+
{"127.0.0.1", true},
17+
{"127.0.0.1:3000", true},
18+
{"[::1]", true},
19+
{"[::1]:3000", true},
20+
{"::1", true},
21+
{"", false},
22+
{"evil.com", false},
23+
{"evil.com:80", false},
24+
{"localhost.evil.com", false},
25+
{"127.0.0.1.evil.com", false},
26+
}
27+
28+
for _, tt := range tests {
29+
t.Run(tt.addr, func(t *testing.T) {
30+
if got := IsLoopback(tt.addr); got != tt.want {
31+
t.Errorf("IsLoopback(%q) = %v, want %v", tt.addr, got, tt.want)
32+
}
33+
})
34+
}
35+
}

mcp/streamable.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"maps"
2121
"math"
2222
"math/rand/v2"
23+
"net"
2324
"net/http"
2425
"slices"
2526
"strconv"
@@ -30,6 +31,8 @@ import (
3031

3132
"github.com/modelcontextprotocol/go-sdk/auth"
3233
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
34+
"github.com/modelcontextprotocol/go-sdk/internal/mcpgodebug"
35+
"github.com/modelcontextprotocol/go-sdk/internal/util"
3336
"github.com/modelcontextprotocol/go-sdk/internal/xcontext"
3437
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
3538
)
@@ -161,6 +164,16 @@ type StreamableHTTPOptions struct {
161164
//
162165
// If SessionTimeout is the zero value, idle sessions are never closed.
163166
SessionTimeout time.Duration
167+
168+
// DisableLocalhostProtection disables automatic DNS rebinding protection.
169+
// By default, requests arriving via a localhost address (127.0.0.1, [::1])
170+
// that have a non-localhost Host header are rejected with 403 Forbidden.
171+
// This protects against DNS rebinding attacks regardless of whether the
172+
// server is listening on localhost specifically or on 0.0.0.0.
173+
//
174+
// Only disable this if you understand the security implications.
175+
// See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise
176+
DisableLocalhostProtection bool
164177
}
165178

166179
// NewStreamableHTTPHandler returns a new [StreamableHTTPHandler].
@@ -207,7 +220,24 @@ func (h *StreamableHTTPHandler) closeAll() {
207220
}
208221
}
209222

223+
// disablelocalhostprotection is a compatibility parameter that allows to disable
224+
// DNS rebinding protection, which was added in the 1.4.0 version of the SDK.
225+
// See the documentation for the mcpgodebug package for instructions how to enable it.
226+
// The option will be removed in the 1.6.0 version of the SDK.
227+
var disablelocalhostprotection = mcpgodebug.Value("disablelocalhostprotection")
228+
210229
func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
230+
// DNS rebinding protection: auto-enabled for localhost servers.
231+
// See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise
232+
if !h.opts.DisableLocalhostProtection && disablelocalhostprotection != "1" {
233+
if localAddr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok && localAddr != nil {
234+
if util.IsLoopback(localAddr.String()) && !util.IsLoopback(req.Host) {
235+
http.Error(w, fmt.Sprintf("Forbidden: invalid Host header %q", req.Host), http.StatusForbidden)
236+
return
237+
}
238+
}
239+
}
240+
211241
// Allow multiple 'Accept' headers.
212242
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept#syntax
213243
accept := strings.Split(strings.Join(req.Header.Values("Accept"), ","), ",")

mcp/streamable_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2374,3 +2374,114 @@ func Test_ExportErrSessionMissing(t *testing.T) {
23742374
t.Errorf("expected error to wrap ErrSessionMissing, got: %v", err)
23752375
}
23762376
}
2377+
2378+
// TestStreamableLocalhostProtection verifies that DNS rebinding protection
2379+
// is automatically enabled for localhost servers.
2380+
func TestStreamableLocalhostProtection(t *testing.T) {
2381+
server := NewServer(testImpl, nil)
2382+
2383+
tests := []struct {
2384+
name string
2385+
listenAddr string // Address to listen on
2386+
hostHeader string // Host header in request
2387+
disableProtection bool // DisableLocalhostProtection setting
2388+
wantStatus int
2389+
}{
2390+
// Auto-enabled for localhost listeners (127.0.0.1).
2391+
{
2392+
name: "127.0.0.1 accepts 127.0.0.1",
2393+
listenAddr: "127.0.0.1:0",
2394+
hostHeader: "127.0.0.1:1234",
2395+
disableProtection: false,
2396+
wantStatus: http.StatusOK,
2397+
},
2398+
{
2399+
name: "127.0.0.1 accepts localhost",
2400+
listenAddr: "127.0.0.1:0",
2401+
hostHeader: "localhost:1234",
2402+
disableProtection: false,
2403+
wantStatus: http.StatusOK,
2404+
},
2405+
{
2406+
name: "127.0.0.1 rejects evil.com",
2407+
listenAddr: "127.0.0.1:0",
2408+
hostHeader: "evil.com",
2409+
disableProtection: false,
2410+
wantStatus: http.StatusForbidden,
2411+
},
2412+
{
2413+
name: "127.0.0.1 rejects evil.com:80",
2414+
listenAddr: "127.0.0.1:0",
2415+
hostHeader: "evil.com:80",
2416+
disableProtection: false,
2417+
wantStatus: http.StatusForbidden,
2418+
},
2419+
{
2420+
name: "127.0.0.1 rejects localhost.evil.com",
2421+
listenAddr: "127.0.0.1:0",
2422+
hostHeader: "localhost.evil.com",
2423+
disableProtection: false,
2424+
wantStatus: http.StatusForbidden,
2425+
},
2426+
2427+
// When listening on 0.0.0.0, requests arriving via localhost are still protected
2428+
// because LocalAddrContextKey returns the actual connection's local address.
2429+
// This is actually more secure - DNS rebinding attacks target localhost regardless
2430+
// of the listener configuration.
2431+
{
2432+
name: "0.0.0.0 via localhost rejects evil.com",
2433+
listenAddr: "0.0.0.0:0",
2434+
hostHeader: "evil.com",
2435+
disableProtection: false,
2436+
wantStatus: http.StatusForbidden,
2437+
},
2438+
2439+
// Explicit disable
2440+
{
2441+
name: "disabled accepts evil.com",
2442+
listenAddr: "127.0.0.1:0",
2443+
hostHeader: "evil.com",
2444+
disableProtection: true,
2445+
wantStatus: http.StatusOK,
2446+
},
2447+
}
2448+
2449+
for _, tt := range tests {
2450+
t.Run(tt.name, func(t *testing.T) {
2451+
opts := &StreamableHTTPOptions{
2452+
Stateless: true, // Simpler for testing
2453+
DisableLocalhostProtection: tt.disableProtection,
2454+
}
2455+
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, opts)
2456+
2457+
listener, err := net.Listen("tcp", tt.listenAddr)
2458+
if err != nil {
2459+
t.Fatalf("Failed to listen on %s: %v", tt.listenAddr, err)
2460+
}
2461+
defer listener.Close()
2462+
2463+
srv := &http.Server{Handler: handler}
2464+
go srv.Serve(listener)
2465+
defer srv.Close()
2466+
2467+
reqReader := strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}`)
2468+
req, err := http.NewRequest("POST", fmt.Sprintf("http://%s", listener.Addr().String()), reqReader)
2469+
if err != nil {
2470+
t.Fatal(err)
2471+
}
2472+
req.Host = tt.hostHeader
2473+
req.Header.Set("Content-Type", "application/json")
2474+
req.Header.Set("Accept", "application/json, text/event-stream")
2475+
2476+
resp, err := http.DefaultClient.Do(req)
2477+
if err != nil {
2478+
t.Fatal(err)
2479+
}
2480+
defer resp.Body.Close()
2481+
2482+
if got := resp.StatusCode; got != tt.wantStatus {
2483+
t.Errorf("Status code: got %d, want %d", got, tt.wantStatus)
2484+
}
2485+
})
2486+
}
2487+
}

0 commit comments

Comments
 (0)