Skip to content

Commit 359ad12

Browse files
author
trial
committed
Merge claude/issue-421-allow-host-lan into beta-clean for testing
2 parents 7cc6ce2 + b992148 commit 359ad12

8 files changed

Lines changed: 453 additions & 29 deletions

File tree

src/godot_ai/__init__.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,21 @@ def main(argv: Sequence[str] | None = None) -> None:
8585
"managed servers (CI, manual --reload)."
8686
),
8787
)
88+
parser.add_argument(
89+
"--allow-host",
90+
action="append",
91+
metavar="CIDR",
92+
default=None,
93+
help=(
94+
"Expose the server to a named LAN range for remote agents (issue "
95+
"#421). Takes a CIDR or bare IP (repeatable, or comma-separated), "
96+
"e.g. --allow-host 192.168.1.0/24. When set, both transports bind "
97+
"off loopback and the rebinding guard widens its Host allowlist to "
98+
"these networks ONLY — browser Origin / Sec-Fetch-Site checks stay "
99+
"on. Omit for the default loopback-only behavior. Prefer an SSH "
100+
"tunnel / Tailscale on untrusted networks; only name ranges you trust."
101+
),
102+
)
88103
parser.add_argument(
89104
"--exclude-domains",
90105
default="",
@@ -105,6 +120,23 @@ def main(argv: Sequence[str] | None = None) -> None:
105120
except ValueError as exc:
106121
parser.error(str(exc))
107122

123+
## #421: parse --allow-host CIDRs. A typo here fails loudly at startup
124+
## rather than silently binding loopback-only (or worse, wide open).
125+
from godot_ai.transport.origin_guard import bind_host_for_networks, parse_allow_hosts
126+
127+
try:
128+
allow_host_networks = parse_allow_hosts(args.allow_host or [])
129+
except ValueError as exc:
130+
parser.error(str(exc))
131+
132+
## Widen the HTTP bind off loopback only when an allowlist is named. The
133+
## DNS-rebinding guard still gates every request by the CIDR(s); binding
134+
## off loopback without the guard would be the footgun this flag avoids.
135+
if allow_host_networks and args.transport in ("sse", "streamable-http"):
136+
import fastmcp
137+
138+
fastmcp.settings.host = bind_host_for_networks(allow_host_networks)
139+
108140
from godot_ai.runtime_info import install_pid_file
109141

110142
install_pid_file(args.pid_file)
@@ -129,6 +161,7 @@ def main(argv: Sequence[str] | None = None) -> None:
129161
port=args.port,
130162
ws_port=args.ws_port,
131163
exclude_domains=exclude_domains,
164+
allow_host_networks=allow_host_networks,
132165
)
133166
return
134167

@@ -138,6 +171,7 @@ def main(argv: Sequence[str] | None = None) -> None:
138171
ws_port=args.ws_port,
139172
exclude_domains=exclude_domains,
140173
owner_pid=owner_pid,
174+
allow_host_networks=allow_host_networks,
141175
)
142176

143177
transport_kwargs = {}

src/godot_ai/asgi.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
DEV_TRANSPORT_ENV = "GODOT_AI_DEV_TRANSPORT"
1616
DEV_WS_PORT_ENV = "GODOT_AI_DEV_WS_PORT"
1717
DEV_EXCLUDE_DOMAINS_ENV = "GODOT_AI_DEV_EXCLUDE_DOMAINS"
18+
## #421: reload runs the app in a uvicorn-supervised subprocess via the
19+
## ``create_app`` factory, so --allow-host CIDRs ride through as an env var
20+
## (the comma-joined CIDR strings) rather than a function argument.
21+
DEV_ALLOW_HOST_ENV = "GODOT_AI_DEV_ALLOW_HOST"
1822
RELOADABLE_TRANSPORTS = {"sse", "streamable-http"}
1923

2024
STALE_MCP_SESSION_MESSAGE = (
@@ -157,9 +161,17 @@ def create_app():
157161
"""Create the FastMCP ASGI app for uvicorn's reload supervisor."""
158162
from godot_ai.server import create_server
159163
from godot_ai.tools.domains import parse_exclude_list
164+
from godot_ai.transport.origin_guard import parse_allow_hosts
160165

161166
exclude_domains = parse_exclude_list(os.environ.get(DEV_EXCLUDE_DOMAINS_ENV, ""))
162-
server = create_server(ws_port=_get_dev_ws_port(), exclude_domains=exclude_domains)
167+
allow_host_networks = parse_allow_hosts(
168+
[v for v in os.environ.get(DEV_ALLOW_HOST_ENV, "").split(",") if v]
169+
)
170+
server = create_server(
171+
ws_port=_get_dev_ws_port(),
172+
exclude_domains=exclude_domains,
173+
allow_host_networks=allow_host_networks,
174+
)
163175
return server.http_app(transport=_get_dev_transport())
164176

165177

@@ -169,6 +181,7 @@ def run_with_reload(
169181
port: int,
170182
ws_port: int,
171183
exclude_domains: set[str] | None = None,
184+
allow_host_networks: list | None = None,
172185
) -> None:
173186
"""Run the HTTP transport through uvicorn's supported reload path."""
174187
if transport not in RELOADABLE_TRANSPORTS:
@@ -177,12 +190,20 @@ def run_with_reload(
177190
os.environ[DEV_TRANSPORT_ENV] = transport
178191
os.environ[DEV_WS_PORT_ENV] = str(ws_port)
179192
os.environ[DEV_EXCLUDE_DOMAINS_ENV] = ",".join(sorted(exclude_domains or set()))
193+
## #421: pass the CIDRs to the factory subprocess as their string forms.
194+
os.environ[DEV_ALLOW_HOST_ENV] = ",".join(str(net) for net in (allow_host_networks or []))
195+
196+
## Bind off loopback only when an allowlist is named; the guard (rebuilt
197+
## inside create_app from the same env) still gates every request.
198+
from godot_ai.transport.origin_guard import bind_host_for_networks
199+
200+
bind_host = bind_host_for_networks(allow_host_networks) or fastmcp.settings.host
180201

181202
src_dir = str(Path(__file__).resolve().parent.parent)
182203
uvicorn.run(
183204
"godot_ai.asgi:create_app",
184205
factory=True,
185-
host=fastmcp.settings.host,
206+
host=bind_host,
186207
port=port,
187208
log_level=fastmcp.settings.log_level.lower(),
188209
timeout_graceful_shutdown=2,

src/godot_ai/server.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import asyncio
66
import logging
77
import time
8-
from collections.abc import AsyncIterator, Iterable
8+
from collections.abc import AsyncIterator, Iterable, Sequence
99
from contextlib import asynccontextmanager
1010
from dataclasses import dataclass
1111
from pathlib import Path
@@ -64,7 +64,11 @@
6464
from godot_ai.tools.testing import register_testing_tools
6565
from godot_ai.tools.theme import register_theme_tools
6666
from godot_ai.tools.ui import register_ui_tools
67-
from godot_ai.transport.origin_guard import LocalhostOnlyHTTPMiddleware
67+
from godot_ai.transport.origin_guard import (
68+
IPNetwork,
69+
LocalhostOnlyHTTPMiddleware,
70+
bind_host_for_networks,
71+
)
6872
from godot_ai.transport.websocket import GodotWebSocketServer
6973

7074
logger = logging.getLogger(__name__)
@@ -96,23 +100,35 @@ def http_app(self, *args: Any, **kwargs: Any):
96100
## Outermost wrap: refuse non-loopback Host/Origin (DNS-rebinding
97101
## guard, audit-v2 finding #1). Applied to every HTTP transport
98102
## including ``sse`` so ``/godot-ai/status`` and the FastMCP
99-
## endpoints are guarded uniformly.
100-
return LocalhostOnlyHTTPMiddleware(app)
103+
## endpoints are guarded uniformly. ``--allow-host`` (#421) widens
104+
## only the Host allowlist to named LAN CIDRs; None = loopback-only.
105+
return LocalhostOnlyHTTPMiddleware(app, getattr(self, "_allow_host_networks", None))
101106

102107

103108
def create_server(
104109
ws_port: int = 9500,
105110
*,
106111
exclude_domains: Iterable[str] | None = None,
107112
owner_pid: int | None = None,
113+
allow_host_networks: Sequence[IPNetwork] | None = None,
108114
) -> FastMCP:
109115
logging.basicConfig(level=logging.INFO, format="%(name)s | %(message)s")
110116

117+
## #421: --allow-host opt-in. When set, expose both transports to the
118+
## named LAN CIDR(s) — bind the WS server off loopback and hand the
119+
## networks to its rebinding guard. None/empty = unchanged loopback-only.
120+
ws_bind_host = bind_host_for_networks(allow_host_networks) or "127.0.0.1"
121+
111122
# Capture ws_port in the lifespan closure
112123
@asynccontextmanager
113124
async def _lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
114125
registry = SessionRegistry()
115-
ws_server = GodotWebSocketServer(registry, port=ws_port)
126+
ws_server = GodotWebSocketServer(
127+
registry,
128+
port=ws_port,
129+
host=ws_bind_host,
130+
allowed_networks=allow_host_networks,
131+
)
116132
client = GodotClient(ws_server, registry)
117133

118134
ws_task = asyncio.create_task(ws_server.start())
@@ -252,6 +268,10 @@ def _emit_startup() -> None:
252268
lifespan=_lifespan,
253269
)
254270

271+
## #421: stash the --allow-host CIDRs where http_app() reads them when it
272+
## installs the rebinding guard middleware. None = loopback-only (default).
273+
mcp._allow_host_networks = list(allow_host_networks) if allow_host_networks else None
274+
255275
## Middleware registration order is load-bearing — do not reorder
256276
## without reading the rationale below. Locked by
257277
## ``tests/unit/test_server_middleware_order.py``.

src/godot_ai/transport/origin_guard.py

Lines changed: 102 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,16 @@
4040

4141
from __future__ import annotations
4242

43+
import ipaddress
44+
from collections.abc import Iterable, Sequence
4345
from http import HTTPStatus
4446
from typing import Any
4547
from urllib.parse import urlsplit
4648

4749
from starlette.types import ASGIApp, Receive, Scope, Send
4850

51+
IPNetwork = ipaddress.IPv4Network | ipaddress.IPv6Network
52+
4953
LOOPBACK_HOSTNAMES: frozenset[str] = frozenset({"127.0.0.1", "localhost", "[::1]"})
5054
LOOPBACK_ORIGIN_SCHEMES: frozenset[str] = frozenset({"http", "https", "ws", "wss"})
5155

@@ -85,16 +89,87 @@ def _normalise_host(host: str) -> str:
8589
return normalised
8690

8791

88-
def is_allowed_host(host_header: str | None) -> bool:
89-
"""Whether ``host_header`` resolves to a loopback name.
92+
def parse_allow_hosts(values: Iterable[str]) -> list[IPNetwork]:
93+
"""Parse ``--allow-host`` CLI values into IP networks (issue #421).
94+
95+
Each value may be a CIDR (``192.168.1.0/24``), a bare IP
96+
(``192.168.1.50`` → a /32 or /128), or a comma-separated list of
97+
either. ``host_bits`` set on a CIDR are tolerated (``strict=False``)
98+
so ``192.168.1.5/24`` is accepted as ``192.168.1.0/24``. Raises
99+
``ValueError`` (with the offending token) on anything unparseable so
100+
a typo fails loudly at startup instead of silently exposing nothing.
101+
"""
102+
networks: list[IPNetwork] = []
103+
for raw in values:
104+
for token in str(raw).split(","):
105+
token = token.strip()
106+
if not token:
107+
continue
108+
try:
109+
networks.append(ipaddress.ip_network(token, strict=False))
110+
except ValueError as exc:
111+
raise ValueError(f"invalid --allow-host value {token!r}: {exc}") from exc
112+
return networks
113+
114+
115+
def bind_host_for_networks(networks: Sequence[IPNetwork] | None) -> str | None:
116+
"""Bind address that exposes the transports to ``networks`` (issue #421).
117+
118+
Returns ``None`` when no networks are named so callers keep their
119+
loopback default (the byte-for-byte unchanged path). Otherwise returns
120+
``"::"`` when any requested network is IPv6 — on a dual-stack host that
121+
also accepts IPv4 — and ``"0.0.0.0"`` for an IPv4-only allowlist, so an
122+
IPv6 ``--allow-host`` actually listens on IPv6 instead of silently
123+
binding IPv4-only. Centralized so the HTTP bind, the WebSocket bind, and
124+
the reload runner can't disagree about where to listen.
125+
"""
126+
if not networks:
127+
return None
128+
if any(isinstance(net, ipaddress.IPv6Network) for net in networks):
129+
return "::" # noqa: S104 — opt-in, and the guard still gates every request
130+
return "0.0.0.0" # noqa: S104 — same
131+
132+
133+
def _host_ip_in_networks(host_header: str, networks: Sequence[IPNetwork] | None) -> bool:
134+
"""Whether the Host header's IP literal falls inside one of ``networks``.
135+
136+
Only IP literals match — a DNS name (the shape a rebinding attack
137+
presents) never parses to an address, so it can't slip into an
138+
allowed network. Bracketed IPv6 (``[192.168..]`` form) is unwrapped
139+
by ``_normalise_host`` first.
140+
"""
141+
if not networks:
142+
return False
143+
candidate = _normalise_host(host_header.strip())
144+
if candidate.startswith("[") and candidate.endswith("]"):
145+
candidate = candidate[1:-1]
146+
try:
147+
ip = ipaddress.ip_address(candidate)
148+
except ValueError:
149+
return False
150+
return any(ip in net for net in networks)
151+
152+
153+
def is_allowed_host(
154+
host_header: str | None,
155+
allowed_networks: Sequence[IPNetwork] | None = None,
156+
) -> bool:
157+
"""Whether ``host_header`` resolves to a loopback name (or an allowed LAN IP).
90158
91159
Empty or missing returns False — a properly formed HTTP/1.1 request
92160
always carries a Host header, and refusing the request is safer than
93161
guessing. The WebSocket guard mirrors this.
162+
163+
When ``allowed_networks`` is supplied (the ``--allow-host`` opt-in,
164+
#421), a Host header whose IP literal falls inside one of those
165+
networks is also accepted. ``allowed_networks=None`` (the default)
166+
is byte-for-byte the original loopback-only behavior.
94167
"""
95168
if not host_header:
96169
return False
97-
return _normalise_host(host_header.strip()) in LOOPBACK_HOSTNAMES
170+
if _normalise_host(host_header.strip()) in LOOPBACK_HOSTNAMES:
171+
return True
172+
return _host_ip_in_networks(host_header, allowed_networks)
98173

99174

100175
def is_allowed_origin(origin_header: str | None) -> bool:
@@ -146,6 +221,7 @@ def evaluate_loopback(
146221
hosts: list[str],
147222
origins: list[str],
148223
sec_fetch_sites: list[str] | None = None,
224+
allowed_networks: Sequence[IPNetwork] | None = None,
149225
) -> bool:
150226
"""Return True iff the request's headers pass the allowlist.
151227
@@ -155,6 +231,13 @@ def evaluate_loopback(
155231
the Sec-Fetch-Site cross-origin reject rule are evaluated identically.
156232
A divergence between the two transports would be a security
157233
regression — this helper exists to prevent it.
234+
235+
``allowed_networks`` (the ``--allow-host`` opt-in, #421) only widens
236+
the *Host* allowlist to named LAN CIDRs. The Origin and Sec-Fetch-Site
237+
rules are deliberately left untouched: a browser on the LAN sends a
238+
non-loopback Origin (rejected) and a foreign Sec-Fetch-Site (rejected),
239+
so DNS-rebinding defense survives the opt-in. A native remote agent
240+
sends neither header, so it passes once its Host IP is allowed.
158241
"""
159242
if len(hosts) > 1 or len(origins) > 1:
160243
return False
@@ -164,7 +247,7 @@ def evaluate_loopback(
164247
origin = origins[0] if origins else None
165248
sec_fetch_site = sec_fetch_sites[0] if sec_fetch_sites else None
166249
return (
167-
is_allowed_host(host)
250+
is_allowed_host(host, allowed_networks)
168251
and is_allowed_origin(origin)
169252
and is_allowed_sec_fetch_site(sec_fetch_site)
170253
)
@@ -178,8 +261,14 @@ class LocalhostOnlyHTTPMiddleware:
178261
before any inner middleware. Non-HTTP scopes (lifespan) pass through.
179262
"""
180263

181-
def __init__(self, app: ASGIApp) -> None:
264+
def __init__(
265+
self,
266+
app: ASGIApp,
267+
allowed_networks: Sequence[IPNetwork] | None = None,
268+
) -> None:
182269
self.app = app
270+
# #421: empty/None keeps the loopback-only behavior byte-for-byte.
271+
self.allowed_networks = list(allowed_networks) if allowed_networks else None
183272

184273
def __getattr__(self, name: str) -> Any:
185274
# Mirror StaleMcpSessionDiagnosticMiddleware: FastMCP / uvicorn
@@ -203,7 +292,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
203292
elif key == b"sec-fetch-site":
204293
sec_fetch_sites.append(raw_value.decode("latin-1"))
205294

206-
if evaluate_loopback(hosts, origins, sec_fetch_sites):
295+
if evaluate_loopback(hosts, origins, sec_fetch_sites, self.allowed_networks):
207296
await self.app(scope, receive, send)
208297
return
209298
await _send_forbidden(send)
@@ -223,15 +312,20 @@ async def _send_forbidden(send: Send) -> None:
223312
await send({"type": "http.response.body", "body": FORBIDDEN_BODY, "more_body": False})
224313

225314

226-
def make_websocket_request_guard():
315+
def make_websocket_request_guard(allowed_networks: Sequence[IPNetwork] | None = None):
227316
"""Return a ``process_request`` hook for ``websockets.asyncio.server.serve``.
228317
229318
The hook fires before the WebSocket upgrade. When Host or Origin
230319
fails the loopback allowlist the hook synthesizes an HTTP 403 via
231320
``connection.respond(...)``; returning that response from
232321
``process_request`` aborts the upgrade without ever creating a
233322
Session.
323+
324+
``allowed_networks`` (the ``--allow-host`` opt-in, #421) widens the
325+
Host allowlist identically to the HTTP middleware so the two
326+
transports never diverge.
234327
"""
328+
networks = list(allowed_networks) if allowed_networks else None
235329

236330
async def guard(connection, request):
237331
## Use ``get_all`` so a smuggled duplicate (two ``Host:`` lines)
@@ -240,7 +334,7 @@ async def guard(connection, request):
240334
hosts = list(request.headers.get_all("Host"))
241335
origins = list(request.headers.get_all("Origin"))
242336
sec_fetch_sites = list(request.headers.get_all("Sec-Fetch-Site"))
243-
if evaluate_loopback(hosts, origins, sec_fetch_sites):
337+
if evaluate_loopback(hosts, origins, sec_fetch_sites, networks):
244338
return None
245339
return connection.respond(HTTPStatus.FORBIDDEN, FORBIDDEN_BODY_TEXT)
246340

0 commit comments

Comments
 (0)