Skip to content

Commit 6326315

Browse files
committed
Add missing go file
1 parent 584e932 commit 6326315

1 file changed

Lines changed: 187 additions & 0 deletions

File tree

src/mcps/lib/security.go

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
package mcplib
2+
3+
import (
4+
"net"
5+
"net/http"
6+
"strings"
7+
"sync"
8+
"time"
9+
)
10+
11+
// HTTPSecurity configures HTTP-layer security for MCP servers.
12+
// The zero value is secure: CORS rejects cross-origin requests,
13+
// DNS rebinding is blocked on loopback, and Content-Type is enforced.
14+
type HTTPSecurity struct {
15+
// AllowedOrigins lists accepted Origin header values for CORS.
16+
// Empty (default) rejects all cross-origin requests.
17+
// Use ["*"] to allow any origin (not recommended for public ports).
18+
AllowedOrigins []string
19+
20+
// AllowDNSRebinding disables Host header validation on loopback listeners.
21+
AllowDNSRebinding bool
22+
23+
// SessionTimeout is the idle duration before SSE sessions expire.
24+
// Zero means no timeout.
25+
SessionTimeout time.Duration
26+
27+
// MaxSessions limits concurrent SSE sessions. Zero means unlimited.
28+
MaxSessions int
29+
30+
// RateLimit is the max requests per second per source IP. Zero means unlimited.
31+
RateLimit float64
32+
33+
// RateBurst is the token bucket burst size. Defaults to max(1, int(RateLimit)).
34+
RateBurst int
35+
}
36+
37+
// SetHTTPSecurity configures HTTP security options.
38+
// The zero value is secure: CORS restricted, DNS rebinding blocked, Content-Type enforced.
39+
func (s *MCPServer) SetHTTPSecurity(sec HTTPSecurity) {
40+
s.httpSecurity = sec
41+
if sec.RateLimit > 0 {
42+
s.limiter = newRateLimiter(sec.RateLimit, sec.RateBurst)
43+
} else {
44+
s.limiter = nil
45+
}
46+
}
47+
48+
// httpSecurityCheck runs all security checks on an incoming HTTP request.
49+
// Returns true if the request was rejected (response already written).
50+
func (s *MCPServer) httpSecurityCheck(w http.ResponseWriter, r *http.Request) bool {
51+
// DNS rebinding protection on loopback listeners
52+
if !s.httpSecurity.AllowDNSRebinding {
53+
if laddr, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
54+
if isLoopback(laddr.String()) && !isLoopback(r.Host) {
55+
http.Error(w, "Forbidden", http.StatusForbidden)
56+
return true
57+
}
58+
}
59+
}
60+
61+
// CORS: validate Origin header
62+
origin := r.Header.Get("Origin")
63+
if origin != "" {
64+
if !s.originAllowed(origin) {
65+
http.Error(w, "Forbidden", http.StatusForbidden)
66+
return true
67+
}
68+
w.Header().Set("Access-Control-Allow-Origin", origin)
69+
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, Cache-Control, X-SSE-Session-ID")
70+
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
71+
w.Header().Set("Vary", "Origin")
72+
}
73+
74+
// CORS preflight
75+
if r.Method == http.MethodOptions {
76+
if origin != "" {
77+
w.Header().Set("Access-Control-Max-Age", "86400")
78+
}
79+
w.WriteHeader(http.StatusNoContent)
80+
return true
81+
}
82+
83+
// Content-Type enforcement for POST
84+
if r.Method == http.MethodPost {
85+
ct := strings.ToLower(r.Header.Get("Content-Type"))
86+
if !strings.HasPrefix(ct, "application/json") {
87+
http.Error(w, "Unsupported Media Type", http.StatusUnsupportedMediaType)
88+
return true
89+
}
90+
}
91+
92+
// Rate limiting
93+
if s.limiter != nil && !s.limiter.allow(clientIP(r)) {
94+
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
95+
return true
96+
}
97+
98+
return false
99+
}
100+
101+
func (s *MCPServer) originAllowed(origin string) bool {
102+
for _, o := range s.httpSecurity.AllowedOrigins {
103+
if o == "*" || strings.EqualFold(o, origin) {
104+
return true
105+
}
106+
}
107+
return false
108+
}
109+
110+
// isLoopback reports whether addr refers to a loopback address.
111+
func isLoopback(addr string) bool {
112+
host, _, err := net.SplitHostPort(addr)
113+
if err != nil {
114+
host = addr
115+
}
116+
if strings.EqualFold(host, "localhost") {
117+
return true
118+
}
119+
ip := net.ParseIP(strings.Trim(host, "[]"))
120+
return ip != nil && ip.IsLoopback()
121+
}
122+
123+
// clientIP extracts the remote IP from the request.
124+
func clientIP(r *http.Request) string {
125+
host, _, err := net.SplitHostPort(r.RemoteAddr)
126+
if err != nil {
127+
return r.RemoteAddr
128+
}
129+
return host
130+
}
131+
132+
// rateLimiter implements a per-IP token bucket rate limiter.
133+
type rateLimiter struct {
134+
mu sync.Mutex
135+
rate float64
136+
burst int
137+
clients map[string]*rateBucket
138+
}
139+
140+
type rateBucket struct {
141+
tokens float64
142+
last time.Time
143+
}
144+
145+
func newRateLimiter(rate float64, burst int) *rateLimiter {
146+
if burst < 1 {
147+
burst = int(rate)
148+
if burst < 1 {
149+
burst = 1
150+
}
151+
}
152+
return &rateLimiter{
153+
rate: rate,
154+
burst: burst,
155+
clients: make(map[string]*rateBucket),
156+
}
157+
}
158+
159+
func (rl *rateLimiter) allow(ip string) bool {
160+
rl.mu.Lock()
161+
defer rl.mu.Unlock()
162+
now := time.Now()
163+
// Purge stale entries when map grows large
164+
if len(rl.clients) > 1000 {
165+
cutoff := now.Add(-5 * time.Minute)
166+
for k, v := range rl.clients {
167+
if v.last.Before(cutoff) {
168+
delete(rl.clients, k)
169+
}
170+
}
171+
}
172+
b, ok := rl.clients[ip]
173+
if !ok {
174+
rl.clients[ip] = &rateBucket{tokens: float64(rl.burst) - 1, last: now}
175+
return true
176+
}
177+
b.tokens += now.Sub(b.last).Seconds() * rl.rate
178+
b.last = now
179+
if b.tokens > float64(rl.burst) {
180+
b.tokens = float64(rl.burst)
181+
}
182+
if b.tokens < 1 {
183+
return false
184+
}
185+
b.tokens--
186+
return true
187+
}

0 commit comments

Comments
 (0)