Skip to content

Commit 10a37ce

Browse files
authored
Merge branch 'main' into feat/event-driven-messaging
2 parents c5920b3 + 2a86bef commit 10a37ce

7 files changed

Lines changed: 270 additions & 13 deletions

File tree

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,13 @@ For ready-to-use examples, see [`examples/cross-provider/`](examples/cross-provi
408408

409409
## Security
410410

411-
See [SECURITY.md](SECURITY.md) for vulnerability reporting, security scanning, and best practices.
411+
### DNS Rebinding Protection
412+
413+
The CAO server validates HTTP `Host` headers to prevent [DNS rebinding attacks](https://owasp.org/www-community/attacks/DNS_Rebinding). Only `localhost` and `127.0.0.1` are accepted by default — requests with other hostnames are rejected with `400 Bad Request`.
414+
415+
**Note:** If you need to expose the server on a network (not recommended for development use), be aware that the Host header validation will reject requests unless the hostname matches the allowed list.
416+
417+
For more details, see [SECURITY.md](SECURITY.md) for vulnerability reporting, security scanning, and best practices.
412418

413419
## Contributing
414420

src/cli_agent_orchestrator/api/main.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Annotated, Dict, List, Optional
77

88
from fastapi import FastAPI, HTTPException, Path, Query, status
9+
from fastapi.middleware.trustedhost import TrustedHostMiddleware
910
from pydantic import BaseModel, Field, field_validator
1011

1112
from cli_agent_orchestrator.clients.database import (
@@ -14,6 +15,8 @@
1415
init_db,
1516
)
1617
from cli_agent_orchestrator.constants import (
18+
ALLOWED_HOSTS,
19+
INBOX_POLLING_INTERVAL,
1720
SERVER_HOST,
1821
SERVER_PORT,
1922
SERVER_VERSION,
@@ -126,6 +129,14 @@ async def lifespan(app: FastAPI):
126129
lifespan=lifespan,
127130
)
128131

132+
# Security: DNS Rebinding Protection
133+
# Validate Host header to prevent DNS rebinding attacks (CVE mitigation)
134+
# Only allow requests with localhost Host headers
135+
app.add_middleware(
136+
TrustedHostMiddleware,
137+
allowed_hosts=ALLOWED_HOSTS,
138+
)
139+
129140

130141
@app.get("/health")
131142
async def health_check():

src/cli_agent_orchestrator/constants.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,14 @@
101101

102102
# CORS allowed origins for web-based clients
103103
CORS_ORIGINS = ["http://localhost:3000", "http://127.0.0.1:3000"]
104+
105+
# Allowed Host headers for DNS rebinding protection (CVE mitigation)
106+
# Only localhost connections permitted - CAO is a local-only service
107+
# These hosts are validated by TrustedHostMiddleware to prevent DNS rebinding attacks
108+
# Note: IPv6 (::1) is not included as CAO is accessed via IPv4 localhost in practice
109+
# Future extension point: To allow additional hosts, add --allowed-hosts CLI flag
110+
# or CAO_ALLOWED_HOSTS env var (comma-separated) that modifies this list
111+
ALLOWED_HOSTS = [
112+
"localhost",
113+
"127.0.0.1",
114+
]

test/api/conftest.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Shared fixtures for API tests."""
2+
3+
import pytest
4+
from fastapi.testclient import TestClient
5+
6+
from cli_agent_orchestrator.api.main import app
7+
8+
9+
class TestClientWithHost(TestClient):
10+
"""TestClient that always sends correct Host header for TrustedHostMiddleware."""
11+
12+
def request(self, method, url, **kwargs):
13+
# Ensure Host header is always set to localhost
14+
if "headers" not in kwargs or kwargs["headers"] is None:
15+
kwargs["headers"] = {}
16+
17+
# Check if Host header is already present (case-insensitive)
18+
headers_dict = kwargs["headers"]
19+
has_host = any(k.lower() == "host" for k in headers_dict.keys())
20+
21+
if not has_host:
22+
headers_dict["Host"] = "localhost"
23+
24+
return super().request(method, url, **kwargs)
25+
26+
27+
@pytest.fixture
28+
def client():
29+
"""Test client with proper Host header for security middleware."""
30+
return TestClientWithHost(app)

test/api/test_inbox_messages.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,6 @@
1111
from cli_agent_orchestrator.models.inbox import InboxMessage, MessageStatus
1212

1313

14-
@pytest.fixture
15-
def client():
16-
"""Create a test client."""
17-
return TestClient(app)
18-
19-
2014
@pytest.fixture
2115
def sample_inbox_messages():
2216
"""Create sample inbox messages for testing."""

test/api/test_security.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
"""Security tests for DNS rebinding protection and host validation."""
2+
3+
import pytest
4+
from fastapi.testclient import TestClient
5+
6+
from cli_agent_orchestrator.api.main import app
7+
8+
client = TestClient(app)
9+
10+
11+
class TestDNSRebindingProtection:
12+
"""Test suite for DNS rebinding attack prevention via TrustedHostMiddleware."""
13+
14+
def test_localhost_hostname_allowed(self):
15+
"""Legitimate requests with 'localhost' Host header should be accepted."""
16+
response = client.get("/health", headers={"Host": "localhost"})
17+
assert response.status_code == 200
18+
assert response.json()["status"] == "ok"
19+
20+
def test_localhost_hostname_with_port_allowed(self):
21+
"""Requests with 'localhost:9889' Host header should be accepted."""
22+
response = client.get("/health", headers={"Host": "localhost:9889"})
23+
assert response.status_code == 200
24+
25+
def test_ipv4_loopback_allowed(self):
26+
"""IPv4 loopback address '127.0.0.1' should be allowed."""
27+
response = client.get("/health", headers={"Host": "127.0.0.1"})
28+
assert response.status_code == 200
29+
30+
def test_ipv4_loopback_with_port_allowed(self):
31+
"""IPv4 loopback with port '127.0.0.1:9889' should be allowed."""
32+
response = client.get("/health", headers={"Host": "127.0.0.1:9889"})
33+
assert response.status_code == 200
34+
35+
def test_ipv6_loopback_with_brackets_blocked(self):
36+
"""IPv6 loopback '[::1]' should be blocked (not in ALLOWED_HOSTS)."""
37+
response = client.get("/health", headers={"Host": "[::1]"})
38+
assert response.status_code == 400
39+
40+
def test_ipv6_loopback_without_brackets_blocked(self):
41+
"""IPv6 loopback '::1' should be blocked (not in ALLOWED_HOSTS)."""
42+
response = client.get("/health", headers={"Host": "::1"})
43+
assert response.status_code == 400
44+
45+
def test_arbitrary_domain_rejected(self):
46+
"""Requests with arbitrary domain Host header should be blocked."""
47+
response = client.get("/health", headers={"Host": "attack.poc"})
48+
assert response.status_code == 400
49+
50+
def test_external_domain_rejected(self):
51+
"""External domains like 'example.com' should be rejected."""
52+
response = client.get("/health", headers={"Host": "example.com"})
53+
assert response.status_code == 400
54+
55+
def test_malicious_domain_rejected(self):
56+
"""Malicious domains should be rejected."""
57+
response = client.get("/health", headers={"Host": "malicious-site.com"})
58+
assert response.status_code == 400
59+
60+
def test_dns_rebinding_attack_simulation(self):
61+
"""Simulate DNS rebinding attack - attacker's domain after rebind should be blocked."""
62+
# After DNS rebinding, attacker's domain points to 127.0.0.1
63+
# But Host header still says "attack.poc"
64+
response = client.post(
65+
"/sessions",
66+
headers={"Host": "attack.poc"},
67+
params={"provider": "kiro_cli", "agent_profile": "developer"},
68+
)
69+
# Should be blocked before reaching the endpoint
70+
assert response.status_code == 400
71+
72+
def test_subdomain_of_localhost_rejected(self):
73+
"""Subdomains of localhost should be rejected (e.g., 'evil.localhost')."""
74+
response = client.get("/health", headers={"Host": "evil.localhost"})
75+
assert response.status_code == 400
76+
77+
def test_localhost_lookalike_rejected(self):
78+
"""Domains that look like localhost should be rejected."""
79+
response = client.get("/health", headers={"Host": "localhost.attacker.com"})
80+
assert response.status_code == 400
81+
82+
def test_ip_lookalike_rejected(self):
83+
"""Domains that look like IP addresses should be rejected."""
84+
response = client.get("/health", headers={"Host": "127.0.0.2"})
85+
assert response.status_code == 400
86+
87+
def test_missing_host_header_rejected(self):
88+
"""Requests without Host header should be rejected."""
89+
# Note: TestClient automatically adds Host header, so we test with empty string
90+
response = client.get("/health", headers={"Host": ""})
91+
assert response.status_code == 400
92+
93+
94+
class TestCriticalEndpointProtection:
95+
"""Test that critical endpoints are protected from DNS rebinding."""
96+
97+
def test_create_session_protected(self):
98+
"""POST /sessions endpoint should reject malicious Host headers."""
99+
response = client.post(
100+
"/sessions",
101+
headers={"Host": "malicious.com"},
102+
params={"provider": "kiro_cli", "agent_profile": "developer"},
103+
)
104+
assert response.status_code == 400
105+
106+
def test_send_terminal_input_protected(self):
107+
"""POST /terminals/{id}/input should reject malicious Host headers."""
108+
response = client.post(
109+
"/terminals/fake-id/input",
110+
headers={"Host": "attacker.poc"},
111+
params={"message": "malicious command"},
112+
)
113+
assert response.status_code == 400
114+
115+
def test_get_terminal_output_protected(self):
116+
"""GET /terminals/{id}/output should reject malicious Host headers."""
117+
response = client.get(
118+
"/terminals/fake-id/output",
119+
headers={"Host": "evil.example.com"},
120+
params={"mode": "full"},
121+
)
122+
assert response.status_code == 400
123+
124+
def test_delete_session_protected(self):
125+
"""DELETE /sessions/{name} should reject malicious Host headers."""
126+
response = client.delete("/sessions/fake-session", headers={"Host": "attacker.com"})
127+
assert response.status_code == 400
128+
129+
130+
class TestRealWorldAttackScenarios:
131+
"""Test scenarios from the actual CVE report."""
132+
133+
def test_cao_terminal_injection_poc_blocked(self):
134+
"""
135+
Simulate the exact attack from the security report PoC.
136+
137+
The attacker's JavaScript tries to:
138+
1. Enumerate sessions: GET /sessions with Host: attack.poc
139+
2. List terminals: GET /sessions/{name}/terminals with Host: attack.poc
140+
3. Inject prompt: POST /terminals/{id}/input with Host: attack.poc
141+
4. Read output: GET /terminals/{id}/output with Host: attack.poc
142+
143+
All should be blocked by TrustedHostMiddleware.
144+
"""
145+
# Step 1: Enumerate sessions (should be blocked)
146+
response = client.get("/sessions", headers={"Host": "attack.poc"})
147+
assert response.status_code == 400
148+
149+
# Step 2: List terminals (should be blocked)
150+
response = client.get(
151+
"/sessions/cao-fake-session/terminals", headers={"Host": "attack.poc"}
152+
)
153+
assert response.status_code == 400
154+
155+
# Step 3: Inject malicious prompt (should be blocked)
156+
response = client.post(
157+
"/terminals/fake-terminal-id/input",
158+
headers={"Host": "attack.poc"},
159+
params={"message": "launch the calculator"}, # From actual PoC
160+
)
161+
assert response.status_code == 400
162+
163+
# Step 4: Read terminal output (should be blocked)
164+
response = client.get(
165+
"/terminals/fake-terminal-id/output",
166+
headers={"Host": "attack.poc"},
167+
params={"mode": "full"},
168+
)
169+
assert response.status_code == 400
170+
171+
def test_singularity_dns_rebinding_blocked(self):
172+
"""
173+
Test against Singularity DNS rebinding tool configuration.
174+
175+
From the PoC, attacker uses:
176+
- attackHostDomain: "attack.poc"
177+
- targetHostIPAddress: "127.0.0.1"
178+
179+
After rebinding, attack.poc points to 127.0.0.1, but Host header
180+
still says "attack.poc" - this should be blocked.
181+
"""
182+
response = client.get("/health", headers={"Host": "attack.poc"})
183+
assert response.status_code == 400
184+
185+
# Even with port
186+
response = client.get("/health", headers={"Host": "attack.poc:9889"})
187+
assert response.status_code == 400
188+
189+
190+
class TestLegitimateUseCases:
191+
"""Ensure legitimate CAO usage patterns still work."""
192+
193+
def test_cao_cli_can_connect(self):
194+
"""CAO CLI connecting to localhost should work."""
195+
response = client.get("/health", headers={"Host": "localhost:9889"})
196+
assert response.status_code == 200
197+
198+
def test_mcp_server_can_connect(self):
199+
"""MCP server connecting to localhost should work."""
200+
response = client.get("/health", headers={"Host": "127.0.0.1:9889"})
201+
assert response.status_code == 200
202+
203+
def test_browser_localhost_access(self):
204+
"""Browser accessing http://localhost:9889 should work."""
205+
response = client.get("/health", headers={"Host": "localhost"})
206+
assert response.status_code == 200
207+
208+
def test_curl_localhost_access(self):
209+
"""curl http://127.0.0.1:9889/health should work."""
210+
response = client.get("/health", headers={"Host": "127.0.0.1"})
211+
assert response.status_code == 200

test/api/test_terminals.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,6 @@
99
from cli_agent_orchestrator.models.terminal import Terminal
1010

1111

12-
@pytest.fixture
13-
def client():
14-
"""Create a test client."""
15-
return TestClient(app)
16-
17-
1812
class TestWorkingDirectoryEndpoint:
1913
"""Test GET /terminals/{terminal_id}/working-directory endpoint."""
2014

0 commit comments

Comments
 (0)