4040
4141from __future__ import annotations
4242
43+ import ipaddress
44+ from collections .abc import Iterable , Sequence
4345from http import HTTPStatus
4446from typing import Any
4547from urllib .parse import urlsplit
4648
4749from starlette .types import ASGIApp , Receive , Scope , Send
4850
51+ IPNetwork = ipaddress .IPv4Network | ipaddress .IPv6Network
52+
4953LOOPBACK_HOSTNAMES : frozenset [str ] = frozenset ({"127.0.0.1" , "localhost" , "[::1]" })
5054LOOPBACK_ORIGIN_SCHEMES : frozenset [str ] = frozenset ({"http" , "https" , "ws" , "wss" })
5155
@@ -85,16 +89,87 @@ def _normalise_host(host: str) -> str:
8589 return normalised
8690
8791
88- def is_allowed_host (host_header : str | None ) -> bool :
89- """Whether ``host_header`` resolves to a loopback name.
92+ def parse_allow_hosts (values : Iterable [str ]) -> list [IPNetwork ]:
93+ """Parse ``--allow-host`` CLI values into IP networks (issue #421).
94+
95+ Each value may be a CIDR (``192.168.1.0/24``), a bare IP
96+ (``192.168.1.50`` → a /32 or /128), or a comma-separated list of
97+ either. ``host_bits`` set on a CIDR are tolerated (``strict=False``)
98+ so ``192.168.1.5/24`` is accepted as ``192.168.1.0/24``. Raises
99+ ``ValueError`` (with the offending token) on anything unparseable so
100+ a typo fails loudly at startup instead of silently exposing nothing.
101+ """
102+ networks : list [IPNetwork ] = []
103+ for raw in values :
104+ for token in str (raw ).split ("," ):
105+ token = token .strip ()
106+ if not token :
107+ continue
108+ try :
109+ networks .append (ipaddress .ip_network (token , strict = False ))
110+ except ValueError as exc :
111+ raise ValueError (f"invalid --allow-host value { token !r} : { exc } " ) from exc
112+ return networks
113+
114+
115+ def bind_host_for_networks (networks : Sequence [IPNetwork ] | None ) -> str | None :
116+ """Bind address that exposes the transports to ``networks`` (issue #421).
117+
118+ Returns ``None`` when no networks are named so callers keep their
119+ loopback default (the byte-for-byte unchanged path). Otherwise returns
120+ ``"::"`` when any requested network is IPv6 — on a dual-stack host that
121+ also accepts IPv4 — and ``"0.0.0.0"`` for an IPv4-only allowlist, so an
122+ IPv6 ``--allow-host`` actually listens on IPv6 instead of silently
123+ binding IPv4-only. Centralized so the HTTP bind, the WebSocket bind, and
124+ the reload runner can't disagree about where to listen.
125+ """
126+ if not networks :
127+ return None
128+ if any (isinstance (net , ipaddress .IPv6Network ) for net in networks ):
129+ return "::" # noqa: S104 — opt-in, and the guard still gates every request
130+ return "0.0.0.0" # noqa: S104 — same
131+
132+
133+ def _host_ip_in_networks (host_header : str , networks : Sequence [IPNetwork ] | None ) -> bool :
134+ """Whether the Host header's IP literal falls inside one of ``networks``.
135+
136+ Only IP literals match — a DNS name (the shape a rebinding attack
137+ presents) never parses to an address, so it can't slip into an
138+ allowed network. Bracketed IPv6 (``[192.168..]`` form) is unwrapped
139+ by ``_normalise_host`` first.
140+ """
141+ if not networks :
142+ return False
143+ candidate = _normalise_host (host_header .strip ())
144+ if candidate .startswith ("[" ) and candidate .endswith ("]" ):
145+ candidate = candidate [1 :- 1 ]
146+ try :
147+ ip = ipaddress .ip_address (candidate )
148+ except ValueError :
149+ return False
150+ return any (ip in net for net in networks )
151+
152+
153+ def is_allowed_host (
154+ host_header : str | None ,
155+ allowed_networks : Sequence [IPNetwork ] | None = None ,
156+ ) -> bool :
157+ """Whether ``host_header`` resolves to a loopback name (or an allowed LAN IP).
90158
91159 Empty or missing returns False — a properly formed HTTP/1.1 request
92160 always carries a Host header, and refusing the request is safer than
93161 guessing. The WebSocket guard mirrors this.
162+
163+ When ``allowed_networks`` is supplied (the ``--allow-host`` opt-in,
164+ #421), a Host header whose IP literal falls inside one of those
165+ networks is also accepted. ``allowed_networks=None`` (the default)
166+ is byte-for-byte the original loopback-only behavior.
94167 """
95168 if not host_header :
96169 return False
97- return _normalise_host (host_header .strip ()) in LOOPBACK_HOSTNAMES
170+ if _normalise_host (host_header .strip ()) in LOOPBACK_HOSTNAMES :
171+ return True
172+ return _host_ip_in_networks (host_header , allowed_networks )
98173
99174
100175def is_allowed_origin (origin_header : str | None ) -> bool :
@@ -146,6 +221,7 @@ def evaluate_loopback(
146221 hosts : list [str ],
147222 origins : list [str ],
148223 sec_fetch_sites : list [str ] | None = None ,
224+ allowed_networks : Sequence [IPNetwork ] | None = None ,
149225) -> bool :
150226 """Return True iff the request's headers pass the allowlist.
151227
@@ -155,6 +231,13 @@ def evaluate_loopback(
155231 the Sec-Fetch-Site cross-origin reject rule are evaluated identically.
156232 A divergence between the two transports would be a security
157233 regression — this helper exists to prevent it.
234+
235+ ``allowed_networks`` (the ``--allow-host`` opt-in, #421) only widens
236+ the *Host* allowlist to named LAN CIDRs. The Origin and Sec-Fetch-Site
237+ rules are deliberately left untouched: a browser on the LAN sends a
238+ non-loopback Origin (rejected) and a foreign Sec-Fetch-Site (rejected),
239+ so DNS-rebinding defense survives the opt-in. A native remote agent
240+ sends neither header, so it passes once its Host IP is allowed.
158241 """
159242 if len (hosts ) > 1 or len (origins ) > 1 :
160243 return False
@@ -164,7 +247,7 @@ def evaluate_loopback(
164247 origin = origins [0 ] if origins else None
165248 sec_fetch_site = sec_fetch_sites [0 ] if sec_fetch_sites else None
166249 return (
167- is_allowed_host (host )
250+ is_allowed_host (host , allowed_networks )
168251 and is_allowed_origin (origin )
169252 and is_allowed_sec_fetch_site (sec_fetch_site )
170253 )
@@ -178,8 +261,14 @@ class LocalhostOnlyHTTPMiddleware:
178261 before any inner middleware. Non-HTTP scopes (lifespan) pass through.
179262 """
180263
181- def __init__ (self , app : ASGIApp ) -> None :
264+ def __init__ (
265+ self ,
266+ app : ASGIApp ,
267+ allowed_networks : Sequence [IPNetwork ] | None = None ,
268+ ) -> None :
182269 self .app = app
270+ # #421: empty/None keeps the loopback-only behavior byte-for-byte.
271+ self .allowed_networks = list (allowed_networks ) if allowed_networks else None
183272
184273 def __getattr__ (self , name : str ) -> Any :
185274 # Mirror StaleMcpSessionDiagnosticMiddleware: FastMCP / uvicorn
@@ -203,7 +292,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
203292 elif key == b"sec-fetch-site" :
204293 sec_fetch_sites .append (raw_value .decode ("latin-1" ))
205294
206- if evaluate_loopback (hosts , origins , sec_fetch_sites ):
295+ if evaluate_loopback (hosts , origins , sec_fetch_sites , self . allowed_networks ):
207296 await self .app (scope , receive , send )
208297 return
209298 await _send_forbidden (send )
@@ -223,15 +312,20 @@ async def _send_forbidden(send: Send) -> None:
223312 await send ({"type" : "http.response.body" , "body" : FORBIDDEN_BODY , "more_body" : False })
224313
225314
226- def make_websocket_request_guard ():
315+ def make_websocket_request_guard (allowed_networks : Sequence [ IPNetwork ] | None = None ):
227316 """Return a ``process_request`` hook for ``websockets.asyncio.server.serve``.
228317
229318 The hook fires before the WebSocket upgrade. When Host or Origin
230319 fails the loopback allowlist the hook synthesizes an HTTP 403 via
231320 ``connection.respond(...)``; returning that response from
232321 ``process_request`` aborts the upgrade without ever creating a
233322 Session.
323+
324+ ``allowed_networks`` (the ``--allow-host`` opt-in, #421) widens the
325+ Host allowlist identically to the HTTP middleware so the two
326+ transports never diverge.
234327 """
328+ networks = list (allowed_networks ) if allowed_networks else None
235329
236330 async def guard (connection , request ):
237331 ## Use ``get_all`` so a smuggled duplicate (two ``Host:`` lines)
@@ -240,7 +334,7 @@ async def guard(connection, request):
240334 hosts = list (request .headers .get_all ("Host" ))
241335 origins = list (request .headers .get_all ("Origin" ))
242336 sec_fetch_sites = list (request .headers .get_all ("Sec-Fetch-Site" ))
243- if evaluate_loopback (hosts , origins , sec_fetch_sites ):
337+ if evaluate_loopback (hosts , origins , sec_fetch_sites , networks ):
244338 return None
245339 return connection .respond (HTTPStatus .FORBIDDEN , FORBIDDEN_BODY_TEXT )
246340
0 commit comments