|
1 | 1 | """API ViewSet for proxying external images.""" |
2 | 2 |
|
3 | | -import ipaddress |
4 | 3 | import logging |
5 | | -import socket |
6 | | -from urllib.parse import ParseResult, unquote, urlparse, urlunparse |
| 4 | +from urllib.parse import unquote |
7 | 5 |
|
8 | 6 | from django.conf import settings |
9 | 7 | from django.http import HttpResponse |
10 | 8 |
|
11 | 9 | import magic |
12 | 10 | import requests |
13 | 11 | from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema |
14 | | -from requests.adapters import HTTPAdapter |
15 | 12 | from rest_framework import status as http_status |
16 | 13 | from rest_framework.response import Response |
17 | 14 | from rest_framework.viewsets import ViewSet |
18 | 15 |
|
19 | 16 | from core import enums, models |
20 | 17 | from core.api import permissions |
| 18 | +from core.services.ssrf import SSRFSafeSession, SSRFValidationError |
21 | 19 |
|
22 | 20 | logger = logging.getLogger(__name__) |
23 | 21 |
|
24 | 22 |
|
25 | | -class SSRFValidationError(Exception): |
26 | | - """Exception raised when URL validation fails due to SSRF protection.""" |
27 | | - |
28 | | - |
29 | | -class SSRFProtectedAdapter(HTTPAdapter): |
30 | | - """ |
31 | | - HTTPAdapter that connects to a pre-validated IP address while maintaining |
32 | | - proper TLS certificate verification against the original hostname. |
33 | | -
|
34 | | - This prevents TOCTOU DNS rebinding attacks by: |
35 | | - 1. Connecting to the IP address that was validated (not re-resolving DNS) |
36 | | - 2. Verifying TLS certificates against the original hostname (for HTTPS) |
37 | | - 3. Setting the Host header correctly for virtual hosting |
38 | | - """ |
39 | | - |
40 | | - def __init__( |
41 | | - self, |
42 | | - dest_ip: str, |
43 | | - dest_port: int, |
44 | | - original_hostname: str, |
45 | | - original_scheme: str, |
46 | | - **kwargs, |
47 | | - ): |
48 | | - self.dest_ip = dest_ip |
49 | | - self.dest_port = dest_port |
50 | | - self.original_hostname = original_hostname |
51 | | - self.original_scheme = original_scheme |
52 | | - super().__init__(**kwargs) |
53 | | - |
54 | | - def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs): |
55 | | - """Initialize pool manager with TLS hostname verification settings.""" |
56 | | - if self.original_scheme == "https": |
57 | | - # Ensure TLS certificate is verified against the original hostname |
58 | | - # even though we're connecting to an IP address |
59 | | - pool_kwargs["assert_hostname"] = self.original_hostname |
60 | | - pool_kwargs["server_hostname"] = self.original_hostname |
61 | | - super().init_poolmanager(connections, maxsize, block, **pool_kwargs) |
62 | | - |
63 | | - def send( |
64 | | - self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None |
65 | | - ): |
66 | | - """Send request, rewriting URL to connect to the validated IP address.""" |
67 | | - parsed = urlparse(request.url) |
68 | | - |
69 | | - # Build URL with validated IP instead of hostname |
70 | | - # IPv6 addresses need brackets in URLs |
71 | | - if ":" in self.dest_ip: |
72 | | - ip_netloc = f"[{self.dest_ip}]:{self.dest_port}" |
73 | | - else: |
74 | | - ip_netloc = f"{self.dest_ip}:{self.dest_port}" |
75 | | - |
76 | | - # Reconstruct URL with IP address |
77 | | - request.url = urlunparse( |
78 | | - ( |
79 | | - parsed.scheme, |
80 | | - ip_netloc, |
81 | | - parsed.path, |
82 | | - parsed.params, |
83 | | - parsed.query, |
84 | | - parsed.fragment, |
85 | | - ) |
86 | | - ) |
87 | | - |
88 | | - # Set Host header to original hostname for virtual hosting |
89 | | - # Include port only if non-standard |
90 | | - if parsed.port and parsed.port not in (80, 443): |
91 | | - request.headers["Host"] = f"{self.original_hostname}:{parsed.port}" |
92 | | - else: |
93 | | - request.headers["Host"] = self.original_hostname |
94 | | - |
95 | | - return super().send( |
96 | | - request, |
97 | | - stream=stream, |
98 | | - timeout=timeout, |
99 | | - verify=verify, |
100 | | - cert=cert, |
101 | | - proxies=proxies, |
102 | | - ) |
103 | | - |
104 | | - |
105 | | -class SSRFSafeSession: |
106 | | - """ |
107 | | - HTTP Session with built-in SSRF protection. |
108 | | -
|
109 | | - This class provides a safe way to make HTTP requests by: |
110 | | - 1. Validating URL scheme (only http/https allowed) |
111 | | - 2. Blocking direct IP addresses (legitimate services use domain names) |
112 | | - 3. Resolving hostnames and blocking private/internal IPs |
113 | | - 4. Pinning resolved IPs to prevent DNS rebinding attacks (TOCTOU) |
114 | | -
|
115 | | - Usage: |
116 | | - try: |
117 | | - response = SSRFSafeSession().get("https://example.com/image.png", timeout=10) |
118 | | - except SSRFValidationError: |
119 | | - # URL was blocked for security reasons |
120 | | - pass |
121 | | - """ |
122 | | - |
123 | | - def _validate_url(self, parsed_url: ParseResult) -> list[str]: |
124 | | - """ |
125 | | - Validate that a URL is safe to fetch (SSRF protection). |
126 | | -
|
127 | | - This function prevents Server-Side Request Forgery (SSRF) attacks by |
128 | | - validating URLs before making HTTP requests. It implements a defense-in-depth |
129 | | - approach: |
130 | | -
|
131 | | - 1. Only allows http/https schemes |
132 | | - 2. Blocks all IP addresses (legitimate emails use domain names) |
133 | | - 3. Resolves hostnames and blocks if they resolve to private/internal IPs |
134 | | - (prevents DNS rebinding attacks where attacker-controlled DNS returns |
135 | | - 127.0.0.1 or internal IPs) |
136 | | -
|
137 | | - Blocked addresses include: |
138 | | - - Any direct IP address (e.g., http://192.168.1.1/) |
139 | | - - Private IP ranges (RFC1918: 10.x.x.x, 172.16-31.x.x, 192.168.x.x) |
140 | | - - Loopback addresses (127.x.x.x, ::1) |
141 | | - - Link-local addresses (169.254.x.x, fe80::/10) |
142 | | - - Multicast and reserved addresses |
143 | | - - Cloud provider metadata endpoints (169.254.169.254, fd00:ec2::254) |
144 | | -
|
145 | | - Args: |
146 | | - parsed_url: The parsed URL to validate |
147 | | -
|
148 | | - Returns: |
149 | | - List of validated IP addresses that the hostname resolves to |
150 | | -
|
151 | | - Raises: |
152 | | - SSRFValidationError: If the URL is unsafe |
153 | | - """ |
154 | | - # Only allow http and https schemes |
155 | | - if parsed_url.scheme not in {"http", "https"}: |
156 | | - raise SSRFValidationError("Invalid URL scheme (only http/https allowed)") |
157 | | - |
158 | | - # Require a hostname |
159 | | - if not parsed_url.hostname: |
160 | | - raise SSRFValidationError("Invalid URL (missing hostname)") |
161 | | - |
162 | | - # Block all IP addresses (legitimate services use domain names) |
163 | | - try: |
164 | | - ipaddress.ip_address(parsed_url.hostname) |
165 | | - raise SSRFValidationError( |
166 | | - "IP addresses are not allowed (domain name required)" |
167 | | - ) |
168 | | - except ValueError: |
169 | | - # Not an IP address, continue validation |
170 | | - pass |
171 | | - |
172 | | - # Resolve hostname to IP addresses |
173 | | - try: |
174 | | - addr_info = socket.getaddrinfo( |
175 | | - parsed_url.hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM |
176 | | - ) |
177 | | - except socket.gaierror as exc: |
178 | | - raise SSRFValidationError("Unable to resolve hostname") from exc |
179 | | - |
180 | | - # Check all resolved IP addresses |
181 | | - valid_ips = [] |
182 | | - for _, _, _, _, sockaddr in addr_info: |
183 | | - ip_str = sockaddr[0] |
184 | | - try: |
185 | | - ip_addr = ipaddress.ip_address(ip_str) |
186 | | - |
187 | | - if ip_addr.is_private: |
188 | | - raise SSRFValidationError("Domain resolves to private IP address") |
189 | | - |
190 | | - if ip_addr.is_loopback: |
191 | | - raise SSRFValidationError("Domain resolves to loopback address") |
192 | | - |
193 | | - if ip_addr.is_link_local: |
194 | | - raise SSRFValidationError("Domain resolves to link-local address") |
195 | | - |
196 | | - if ip_addr.is_multicast: |
197 | | - raise SSRFValidationError("Domain resolves to multicast address") |
198 | | - |
199 | | - if ip_addr.is_reserved: |
200 | | - raise SSRFValidationError("Domain resolves to reserved address") |
201 | | - |
202 | | - # Block known cloud metadata IPs |
203 | | - if ip_str in ("169.254.169.254", "fd00:ec2::254"): |
204 | | - raise SSRFValidationError( |
205 | | - "Domain resolves to cloud metadata endpoint" |
206 | | - ) |
207 | | - |
208 | | - valid_ips.append(ip_str) |
209 | | - |
210 | | - except ValueError as exc: |
211 | | - raise SSRFValidationError("Invalid IP address in DNS response") from exc |
212 | | - |
213 | | - if not valid_ips: |
214 | | - raise SSRFValidationError("No valid IP addresses found") |
215 | | - |
216 | | - return valid_ips |
217 | | - |
218 | | - def get(self, url: str, timeout: int, **kwargs) -> requests.Response: |
219 | | - """ |
220 | | - Perform a safe HTTP GET request with SSRF protection and IP pinning. |
221 | | -
|
222 | | - This method: |
223 | | - 1. Parses and validates the URL |
224 | | - 2. Resolves DNS and validates all returned IPs |
225 | | - 3. Creates a requests Session with a custom HTTPAdapter that: |
226 | | - - Connects directly to the validated IP (preventing DNS rebinding) |
227 | | - - Maintains proper TLS certificate verification against the hostname |
228 | | - - Sets the Host header correctly for virtual hosting |
229 | | -
|
230 | | - Args: |
231 | | - url: The URL to fetch |
232 | | - timeout: Request timeout in seconds |
233 | | - **kwargs: Additional arguments passed to requests.Session.get() |
234 | | -
|
235 | | - Returns: |
236 | | - requests.Response object |
237 | | -
|
238 | | - Raises: |
239 | | - SSRFValidationError: If the URL fails security validation |
240 | | - requests.RequestException: If the HTTP request fails |
241 | | - """ |
242 | | - parsed_url = urlparse(url) |
243 | | - valid_ips = self._validate_url(parsed_url) |
244 | | - |
245 | | - # Determine the port (explicit or default based on scheme) |
246 | | - if parsed_url.port: |
247 | | - port = parsed_url.port |
248 | | - elif parsed_url.scheme == "http": |
249 | | - port = 80 |
250 | | - else: |
251 | | - port = 443 |
252 | | - |
253 | | - # Create a session with our SSRF-protected adapter that pins to the validated IP |
254 | | - session = requests.Session() |
255 | | - adapter = SSRFProtectedAdapter( |
256 | | - dest_ip=valid_ips[0], |
257 | | - dest_port=port, |
258 | | - original_hostname=parsed_url.hostname, |
259 | | - original_scheme=parsed_url.scheme, |
260 | | - ) |
261 | | - |
262 | | - # Mount the adapter for both http and https schemes |
263 | | - session.mount("http://", adapter) |
264 | | - session.mount("https://", adapter) |
265 | | - |
266 | | - return session.get(url, timeout=timeout, **kwargs) |
267 | | - |
268 | | - |
269 | 23 | class ImageProxySuspiciousResponse(HttpResponse): |
270 | 24 | """ |
271 | 25 | Response for suspicious content that has been blocked by our image proxy. |
@@ -356,7 +110,6 @@ def list(self, request, mailbox_id=None): |
356 | 110 | timeout=10, |
357 | 111 | stream=True, |
358 | 112 | headers={"User-Agent": "Messages-ImageProxy/1.0"}, |
359 | | - allow_redirects=False, |
360 | 113 | ) |
361 | 114 | response.raise_for_status() |
362 | 115 |
|
|
0 commit comments