Skip to content

Commit 2b4b1dc

Browse files
authored
fix(openai): sanitize urls when counting tokens in images (#35143)
1 parent 0493b27 commit 2b4b1dc

4 files changed

Lines changed: 720 additions & 5 deletions

File tree

libs/core/langchain_core/_security/__init__.py

Whitespace-only changes.
Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
"""SSRF Protection for validating URLs against Server-Side Request Forgery attacks.
2+
3+
This module provides utilities to validate user-provided URLs and prevent SSRF attacks
4+
by blocking requests to:
5+
- Private IP ranges (RFC 1918, loopback, link-local)
6+
- Cloud metadata endpoints (AWS, GCP, Azure, etc.)
7+
- Localhost addresses
8+
- Invalid URL schemes
9+
10+
Usage:
11+
from lc_security.ssrf_protection import validate_safe_url, is_safe_url
12+
13+
# Validate a URL (raises ValueError if unsafe)
14+
safe_url = validate_safe_url("https://example.com/webhook")
15+
16+
# Check if URL is safe (returns bool)
17+
if is_safe_url("http://192.168.1.1"):
18+
# URL is safe
19+
pass
20+
21+
# Allow private IPs for development/testing (still blocks cloud metadata)
22+
safe_url = validate_safe_url("http://localhost:8080", allow_private=True)
23+
"""
24+
25+
import ipaddress
26+
import os
27+
import socket
28+
from typing import Annotated, Any
29+
from urllib.parse import urlparse
30+
31+
from pydantic import (
32+
AnyHttpUrl,
33+
BeforeValidator,
34+
HttpUrl,
35+
)
36+
37+
# Private IP ranges (RFC 1918, RFC 4193, RFC 3927, loopback)
38+
PRIVATE_IP_RANGES = [
39+
ipaddress.ip_network("10.0.0.0/8"), # Private Class A
40+
ipaddress.ip_network("172.16.0.0/12"), # Private Class B
41+
ipaddress.ip_network("192.168.0.0/16"), # Private Class C
42+
ipaddress.ip_network("127.0.0.0/8"), # Loopback
43+
ipaddress.ip_network("169.254.0.0/16"), # Link-local (includes cloud metadata)
44+
ipaddress.ip_network("0.0.0.0/8"), # Current network
45+
ipaddress.ip_network("::1/128"), # IPv6 loopback
46+
ipaddress.ip_network("fc00::/7"), # IPv6 unique local
47+
ipaddress.ip_network("fe80::/10"), # IPv6 link-local
48+
ipaddress.ip_network("ff00::/8"), # IPv6 multicast
49+
]
50+
51+
# Cloud provider metadata endpoints
52+
CLOUD_METADATA_IPS = [
53+
"169.254.169.254", # AWS, GCP, Azure, DigitalOcean, Oracle Cloud
54+
"169.254.170.2", # AWS ECS task metadata
55+
"100.100.100.200", # Alibaba Cloud metadata
56+
]
57+
58+
CLOUD_METADATA_HOSTNAMES = [
59+
"metadata.google.internal", # GCP
60+
"metadata", # Generic
61+
"instance-data", # AWS EC2
62+
]
63+
64+
# Localhost variations
65+
LOCALHOST_NAMES = [
66+
"localhost",
67+
"localhost.localdomain",
68+
]
69+
70+
71+
def is_private_ip(ip_str: str) -> bool:
72+
"""Check if an IP address is in a private range.
73+
74+
Args:
75+
ip_str: IP address as a string (e.g., "192.168.1.1")
76+
77+
Returns:
78+
True if IP is in a private range, False otherwise
79+
"""
80+
try:
81+
ip = ipaddress.ip_address(ip_str)
82+
return any(ip in range_ for range_ in PRIVATE_IP_RANGES)
83+
except ValueError:
84+
return False
85+
86+
87+
def is_cloud_metadata(hostname: str, ip_str: str | None = None) -> bool:
88+
"""Check if hostname or IP is a cloud metadata endpoint.
89+
90+
Args:
91+
hostname: Hostname to check
92+
ip_str: Optional IP address to check
93+
94+
Returns:
95+
True if hostname or IP is a known cloud metadata endpoint
96+
"""
97+
# Check hostname
98+
if hostname.lower() in CLOUD_METADATA_HOSTNAMES:
99+
return True
100+
101+
# Check IP
102+
if ip_str and ip_str in CLOUD_METADATA_IPS: # noqa: SIM103
103+
return True
104+
105+
return False
106+
107+
108+
def is_localhost(hostname: str, ip_str: str | None = None) -> bool:
109+
"""Check if hostname or IP is localhost.
110+
111+
Args:
112+
hostname: Hostname to check
113+
ip_str: Optional IP address to check
114+
115+
Returns:
116+
True if hostname or IP is localhost
117+
"""
118+
# Check hostname
119+
if hostname.lower() in LOCALHOST_NAMES:
120+
return True
121+
122+
# Check IP
123+
if ip_str:
124+
try:
125+
ip = ipaddress.ip_address(ip_str)
126+
# Check if loopback
127+
if ip.is_loopback:
128+
return True
129+
# Also check common localhost IPs
130+
if ip_str in ("127.0.0.1", "::1", "0.0.0.0"): # noqa: S104
131+
return True
132+
except ValueError:
133+
pass
134+
135+
return False
136+
137+
138+
def validate_safe_url(
139+
url: str | AnyHttpUrl,
140+
*,
141+
allow_private: bool = False,
142+
allow_http: bool = True,
143+
) -> str:
144+
"""Validate a URL for SSRF protection.
145+
146+
This function validates URLs to prevent Server-Side Request Forgery (SSRF) attacks
147+
by blocking requests to private networks and cloud metadata endpoints.
148+
149+
Args:
150+
url: The URL to validate (string or Pydantic HttpUrl)
151+
allow_private: If True, allows private IPs and localhost (for development).
152+
Cloud metadata endpoints are ALWAYS blocked.
153+
allow_http: If True, allows both HTTP and HTTPS. If False, only HTTPS.
154+
155+
Returns:
156+
The validated URL as a string
157+
158+
Raises:
159+
ValueError: If URL is invalid or potentially dangerous
160+
161+
Examples:
162+
>>> validate_safe_url("https://hooks.slack.com/services/xxx")
163+
'https://hooks.slack.com/services/xxx'
164+
165+
>>> validate_safe_url("http://127.0.0.1:8080")
166+
ValueError: Localhost URLs are not allowed
167+
168+
>>> validate_safe_url("http://192.168.1.1")
169+
ValueError: URL resolves to private IP: 192.168.1.1
170+
171+
>>> validate_safe_url("http://169.254.169.254/latest/meta-data/")
172+
ValueError: URL resolves to cloud metadata IP: 169.254.169.254
173+
174+
>>> validate_safe_url("http://localhost:8080", allow_private=True)
175+
'http://localhost:8080'
176+
"""
177+
url_str = str(url)
178+
parsed = urlparse(url_str)
179+
180+
# Validate URL scheme
181+
if not allow_http and parsed.scheme != "https":
182+
msg = "Only HTTPS URLs are allowed"
183+
raise ValueError(msg)
184+
185+
if parsed.scheme not in ("http", "https"):
186+
msg = f"Only HTTP/HTTPS URLs are allowed, got scheme: {parsed.scheme}"
187+
raise ValueError(msg)
188+
189+
# Extract hostname
190+
hostname = parsed.hostname
191+
if not hostname:
192+
msg = "URL must have a valid hostname"
193+
raise ValueError(msg)
194+
195+
# Special handling for test environments - allow test server hostnames
196+
# testserver is used by FastAPI/Starlette test clients and doesn't resolve via DNS
197+
# Only enabled when LANGCHAIN_ENV=local_test (set in conftest.py)
198+
if (
199+
os.environ.get("LANGCHAIN_ENV") == "local_test"
200+
and hostname.startswith("test")
201+
and "server" in hostname
202+
):
203+
return url_str
204+
205+
# ALWAYS block cloud metadata endpoints (even with allow_private=True)
206+
if is_cloud_metadata(hostname):
207+
msg = f"Cloud metadata endpoints are not allowed: {hostname}"
208+
raise ValueError(msg)
209+
210+
# Check for localhost
211+
if is_localhost(hostname) and not allow_private:
212+
msg = f"Localhost URLs are not allowed: {hostname}"
213+
raise ValueError(msg)
214+
215+
# Resolve hostname to IP addresses and validate each one.
216+
# Note: DNS resolution results are cached by the OS, so repeated calls are fast.
217+
try:
218+
# Get all IP addresses for this hostname
219+
addr_info = socket.getaddrinfo(
220+
hostname,
221+
parsed.port or (443 if parsed.scheme == "https" else 80),
222+
socket.AF_UNSPEC, # Allow both IPv4 and IPv6
223+
socket.SOCK_STREAM,
224+
)
225+
226+
for result in addr_info:
227+
ip_str: str = result[4][0] # type: ignore[assignment]
228+
229+
# ALWAYS block cloud metadata IPs
230+
if is_cloud_metadata(hostname, ip_str):
231+
msg = f"URL resolves to cloud metadata IP: {ip_str}"
232+
raise ValueError(msg)
233+
234+
# Check for localhost IPs
235+
if is_localhost(hostname, ip_str) and not allow_private:
236+
msg = f"URL resolves to localhost IP: {ip_str}"
237+
raise ValueError(msg)
238+
239+
# Check for private IPs
240+
if not allow_private and is_private_ip(ip_str):
241+
msg = f"URL resolves to private IP address: {ip_str}"
242+
raise ValueError(msg)
243+
244+
except socket.gaierror as e:
245+
# DNS resolution failed - fail closed for security
246+
msg = f"Failed to resolve hostname '{hostname}': {e}"
247+
raise ValueError(msg) from e
248+
except OSError as e:
249+
# Other network errors - fail closed
250+
msg = f"Network error while validating URL: {e}"
251+
raise ValueError(msg) from e
252+
253+
return url_str
254+
255+
256+
def is_safe_url(
257+
url: str | AnyHttpUrl,
258+
*,
259+
allow_private: bool = False,
260+
allow_http: bool = True,
261+
) -> bool:
262+
"""Check if a URL is safe (non-throwing version of validate_safe_url).
263+
264+
Args:
265+
url: The URL to check
266+
allow_private: If True, allows private IPs and localhost
267+
allow_http: If True, allows both HTTP and HTTPS
268+
269+
Returns:
270+
True if URL is safe, False otherwise
271+
272+
Examples:
273+
>>> is_safe_url("https://example.com")
274+
True
275+
276+
>>> is_safe_url("http://127.0.0.1:8080")
277+
False
278+
279+
>>> is_safe_url("http://localhost:8080", allow_private=True)
280+
True
281+
"""
282+
try:
283+
validate_safe_url(url, allow_private=allow_private, allow_http=allow_http)
284+
except ValueError:
285+
return False
286+
else:
287+
return True
288+
289+
290+
def _validate_url_ssrf_strict(v: Any) -> Any:
291+
"""Validate URL for SSRF protection (strict mode)."""
292+
if isinstance(v, str):
293+
validate_safe_url(v, allow_private=False, allow_http=True)
294+
return v
295+
296+
297+
def _validate_url_ssrf_https_only(v: Any) -> Any:
298+
"""Validate URL for SSRF protection (HTTPS only, strict mode)."""
299+
if isinstance(v, str):
300+
validate_safe_url(v, allow_private=False, allow_http=False)
301+
return v
302+
303+
304+
def _validate_url_ssrf_relaxed(v: Any) -> Any:
305+
"""Validate URL for SSRF protection (relaxed mode - allows private IPs)."""
306+
if isinstance(v, str):
307+
validate_safe_url(v, allow_private=True, allow_http=True)
308+
return v
309+
310+
311+
# Annotated types with SSRF protection
312+
SSRFProtectedUrl = Annotated[HttpUrl, BeforeValidator(_validate_url_ssrf_strict)]
313+
"""A Pydantic HttpUrl type with built-in SSRF protection.
314+
315+
This blocks private IPs, localhost, and cloud metadata endpoints.
316+
317+
Example:
318+
class WebhookSchema(BaseModel):
319+
url: SSRFProtectedUrl # Automatically validated for SSRF
320+
headers: dict[str, str] | None = None
321+
"""
322+
323+
SSRFProtectedUrlRelaxed = Annotated[
324+
HttpUrl, BeforeValidator(_validate_url_ssrf_relaxed)
325+
]
326+
"""A Pydantic HttpUrl with relaxed SSRF protection (allows private IPs).
327+
328+
Use this for development/testing webhooks where localhost/private IPs are needed.
329+
Cloud metadata endpoints are still blocked.
330+
331+
Example:
332+
class DevWebhookSchema(BaseModel):
333+
url: SSRFProtectedUrlRelaxed # Allows localhost, blocks cloud metadata
334+
"""
335+
336+
SSRFProtectedHttpsUrl = Annotated[
337+
HttpUrl, BeforeValidator(_validate_url_ssrf_https_only)
338+
]
339+
"""A Pydantic HttpUrl with SSRF protection that only allows HTTPS.
340+
341+
This blocks private IPs, localhost, cloud metadata endpoints, and HTTP URLs.
342+
343+
Example:
344+
class SecureWebhookSchema(BaseModel):
345+
url: SSRFProtectedHttpsUrl # Only HTTPS, blocks private IPs
346+
"""
347+
348+
SSRFProtectedHttpsUrlStr = Annotated[
349+
str, BeforeValidator(_validate_url_ssrf_https_only)
350+
]
351+
"""A string type with SSRF protection that only allows HTTPS URLs.
352+
353+
Same as SSRFProtectedHttpsUrl but returns a string instead of HttpUrl.
354+
Useful for FastAPI query parameters where you need a string URL.
355+
356+
Example:
357+
@router.get("/proxy")
358+
async def proxy_get(url: SSRFProtectedHttpsUrlStr):
359+
async with httpx.AsyncClient() as client:
360+
resp = await client.get(url)
361+
"""

0 commit comments

Comments
 (0)