Skip to content

Commit 57861ae

Browse files
(security) Fix SSRF in batch runner download_bytes_from_url (#38482)
Signed-off-by: jperezde <jperezde@redhat.com>
1 parent ac30a83 commit 57861ae

3 files changed

Lines changed: 183 additions & 8 deletions

File tree

docs/usage/security.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ Restrict domains that vLLM can access for media URLs by setting
6666
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
6767
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
6868

69+
This protection applies to both the online serving API (multimodal inputs) and
70+
the **batch runner** (`vllm run-batch`), where `file_url` values in batch
71+
transcription/translation requests are validated against the same allowlist.
72+
6973
Without domain restrictions, a malicious user could supply URLs that:
7074

7175
- **Target internal services**: Access internal network endpoints, cloud metadata

tests/entrypoints/openai/test_run_batch.py

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@
44
import json
55
import subprocess
66
import tempfile
7+
from unittest.mock import AsyncMock, MagicMock, patch
78

89
import pytest
910

1011
from vllm.assets.audio import AudioAsset
11-
from vllm.entrypoints.openai.run_batch import BatchRequestOutput
12+
from vllm.entrypoints.openai.run_batch import (
13+
BatchRequestOutput,
14+
download_bytes_from_url,
15+
)
1216

1317
CHAT_MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
1418
EMBEDDING_MODEL_NAME = "intfloat/multilingual-e5-small"
@@ -746,3 +750,131 @@ def test_tool_calling():
746750
assert "arguments" in tool_call["function"]
747751
# Verify the tool name matches our tool definition
748752
assert tool_call["function"]["name"] == "get_current_weather"
753+
754+
755+
# ---------------------------------------------------------------------------
756+
# Unit tests for download_bytes_from_url SSRF protection
757+
# ---------------------------------------------------------------------------
758+
759+
760+
def _make_aiohttp_mocks(response_data: bytes = b"fake-data", status: int = 200):
761+
"""Create mock objects that simulate aiohttp.ClientSession context managers."""
762+
mock_resp = MagicMock()
763+
mock_resp.status = status
764+
mock_resp.read = AsyncMock(return_value=response_data)
765+
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
766+
mock_resp.__aexit__ = AsyncMock(return_value=False)
767+
768+
mock_session = MagicMock()
769+
mock_session.get = MagicMock(return_value=mock_resp)
770+
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
771+
mock_session.__aexit__ = AsyncMock(return_value=False)
772+
return mock_session
773+
774+
775+
@pytest.mark.asyncio
776+
async def test_download_bytes_data_url_bypasses_domain_check():
777+
"""data: URLs must work regardless of the domain allowlist."""
778+
data_url = f"data:audio/wav;base64,{MINIMAL_WAV_BASE64}"
779+
result = await download_bytes_from_url(
780+
data_url, allowed_media_domains=["example.com"]
781+
)
782+
assert isinstance(result, bytes)
783+
assert len(result) > 0
784+
785+
786+
@pytest.mark.asyncio
787+
async def test_download_bytes_rejects_disallowed_domain():
788+
"""HTTP URLs whose hostname is not in the allowlist must be rejected."""
789+
url = "https://evil.internal/secret"
790+
with pytest.raises(ValueError, match="allowed domains"):
791+
await download_bytes_from_url(url, allowed_media_domains=["example.com"])
792+
793+
794+
@pytest.mark.asyncio
795+
async def test_download_bytes_rejects_cloud_metadata_ip():
796+
"""Cloud metadata endpoints must be blocked when an allowlist is set."""
797+
url = "http://169.254.169.254/latest/meta-data/"
798+
with pytest.raises(ValueError, match="allowed domains"):
799+
await download_bytes_from_url(url, allowed_media_domains=["example.com"])
800+
801+
802+
@pytest.mark.asyncio
803+
async def test_download_bytes_rejects_internal_ip():
804+
"""Private-range IPs must be blocked when an allowlist is set."""
805+
for internal_url in [
806+
"http://10.0.0.1/secret",
807+
"http://192.168.1.1/admin",
808+
"http://127.0.0.1:8080/internal",
809+
]:
810+
with pytest.raises(ValueError, match="allowed domains"):
811+
await download_bytes_from_url(
812+
internal_url, allowed_media_domains=["example.com"]
813+
)
814+
815+
816+
@pytest.mark.asyncio
817+
async def test_download_bytes_allows_permitted_domain():
818+
"""HTTP URLs whose hostname IS in the allowlist must be fetched."""
819+
url = "https://example.com/audio.wav"
820+
expected = b"audio-bytes"
821+
mock_session = _make_aiohttp_mocks(expected)
822+
823+
with patch(
824+
"vllm.entrypoints.openai.run_batch.aiohttp.ClientSession",
825+
return_value=mock_session,
826+
):
827+
result = await download_bytes_from_url(
828+
url, allowed_media_domains=["example.com"]
829+
)
830+
assert result == expected
831+
832+
833+
@pytest.mark.asyncio
834+
async def test_download_bytes_no_allowlist_permits_any_domain():
835+
"""Without an allowlist all HTTP URLs must be attempted (backward compat)."""
836+
url = "https://any-domain.example.org/file.wav"
837+
expected = b"some-data"
838+
mock_session = _make_aiohttp_mocks(expected)
839+
840+
with patch(
841+
"vllm.entrypoints.openai.run_batch.aiohttp.ClientSession",
842+
return_value=mock_session,
843+
):
844+
result = await download_bytes_from_url(url, allowed_media_domains=None)
845+
assert result == expected
846+
847+
848+
@pytest.mark.asyncio
849+
async def test_download_bytes_empty_allowlist_denies_all():
850+
"""An empty allowlist must deny all HTTP URLs (least privilege)."""
851+
url = "https://any-domain.example.org/file.wav"
852+
with pytest.raises(ValueError, match="allowed domains"):
853+
await download_bytes_from_url(url, allowed_media_domains=[])
854+
855+
856+
@pytest.mark.asyncio
857+
async def test_download_bytes_unsupported_scheme():
858+
"""Unsupported URL schemes must be rejected regardless of allowlist."""
859+
with pytest.raises(ValueError, match="Unsupported URL scheme"):
860+
await download_bytes_from_url("ftp://example.com/file.wav")
861+
862+
with pytest.raises(ValueError, match="Unsupported URL scheme"):
863+
await download_bytes_from_url(
864+
"ftp://example.com/file.wav",
865+
allowed_media_domains=["example.com"],
866+
)
867+
868+
869+
@pytest.mark.asyncio
870+
async def test_download_bytes_backslash_bypass():
871+
"""Backslash-@ URL confusion must not bypass the allowlist.
872+
873+
urllib3.parse_url() and aiohttp/yarl disagree on backslash-before-@.
874+
The fix normalizes through urllib3 before handing to aiohttp.
875+
"""
876+
bypass_url = "http://allowed.example.com\\@evil.internal/secret"
877+
with pytest.raises(ValueError, match="allowed domains"):
878+
await download_bytes_from_url(
879+
bypass_url, allowed_media_domains=["evil.internal"]
880+
)

vllm/entrypoints/openai/run_batch.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
from pydantic_core.core_schema import ValidationInfo
2121
from starlette.datastructures import State
2222
from tqdm import tqdm
23+
from urllib3.util import parse_url
2324

25+
import vllm.envs as envs
2426
from vllm.config import config
2527
from vllm.engine.arg_utils import AsyncEngineArgs
2628
from vllm.engine.protocol import EngineClient
@@ -439,19 +441,25 @@ async def write_file(
439441
await write_local_file(path_or_url, batch_outputs)
440442

441443

442-
async def download_bytes_from_url(url: str) -> bytes:
444+
async def download_bytes_from_url(
445+
url: str,
446+
allowed_media_domains: list[str] | None = None,
447+
) -> bytes:
443448
"""
444449
Download data from a URL or decode from a data URL.
445450
446451
Args:
447452
url: Either an HTTP/HTTPS URL or a data URL (data:...;base64,...)
453+
allowed_media_domains: If set, only HTTP/HTTPS URLs whose hostname
454+
is in this list are permitted. data: URLs are not subject to
455+
this restriction.
448456
449457
Returns:
450458
Data as bytes
451459
"""
452460
parsed = urlparse(url)
453461

454-
# Handle data URLs (base64 encoded)
462+
# Handle data URLs (base64 encoded) - not subject to domain restrictions
455463
if parsed.scheme == "data":
456464
# Format: data:...;base64,<base64_data>
457465
if "," in url:
@@ -465,9 +473,24 @@ async def download_bytes_from_url(url: str) -> bytes:
465473

466474
# Handle HTTP/HTTPS URLs
467475
elif parsed.scheme in ("http", "https"):
476+
if allowed_media_domains is not None:
477+
url_spec = parse_url(url)
478+
if url_spec.hostname not in allowed_media_domains:
479+
raise ValueError(
480+
f"The URL must be from one of the allowed domains: "
481+
f"{allowed_media_domains}. Input URL domain: "
482+
f"{url_spec.hostname}"
483+
)
484+
# Use the normalized URL to prevent parsing discrepancies
485+
# between urllib3 and aiohttp (e.g. backslash-@ attacks).
486+
url = url_spec.url
487+
468488
async with (
469489
aiohttp.ClientSession() as session,
470-
session.get(url) as resp,
490+
session.get(
491+
url,
492+
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
493+
) as resp,
471494
):
472495
if resp.status != 200:
473496
raise Exception(
@@ -593,7 +616,10 @@ def handle_endpoint_request(
593616
return run_request(handler_fn, request, tracker)
594617

595618

596-
def make_transcription_wrapper(is_translation: bool) -> WrapperFn:
619+
def make_transcription_wrapper(
620+
is_translation: bool,
621+
allowed_media_domains: list[str] | None = None,
622+
) -> WrapperFn:
597623
"""
598624
Factory function to create a wrapper for transcription/translation handlers.
599625
The wrapper converts BatchTranscriptionRequest or BatchTranslationRequest
@@ -602,6 +628,8 @@ def make_transcription_wrapper(is_translation: bool) -> WrapperFn:
602628
Args:
603629
is_translation: If True, process as translation; otherwise process
604630
as transcription
631+
allowed_media_domains: If set, only URLs from these domains are
632+
permitted for HTTP/HTTPS fetches.
605633
606634
Returns:
607635
A function that takes a handler and returns a wrapped handler
@@ -619,7 +647,10 @@ async def transcription_wrapper(
619647
):
620648
try:
621649
# Download data from URL
622-
audio_data = await download_bytes_from_url(batch_request_body.file_url)
650+
audio_data = await download_bytes_from_url(
651+
batch_request_body.file_url,
652+
allowed_media_domains=allowed_media_domains,
653+
)
623654

624655
# Create a mock file from the downloaded audio data
625656
mock_file = UploadFile(
@@ -691,6 +722,8 @@ async def build_endpoint_registry(
691722
serving_embedding = getattr(state, "serving_embedding", None)
692723
serving_scores = getattr(state, "serving_scores", None)
693724

725+
allowed_media_domains = getattr(args, "allowed_media_domains", None)
726+
694727
# Registry of endpoint configurations
695728
endpoint_registry: dict[str, dict[str, Any]] = {
696729
"completions": {
@@ -730,7 +763,10 @@ async def build_endpoint_registry(
730763
if openai_serving_transcription is not None
731764
else None
732765
),
733-
"wrapper_fn": make_transcription_wrapper(is_translation=False),
766+
"wrapper_fn": make_transcription_wrapper(
767+
is_translation=False,
768+
allowed_media_domains=allowed_media_domains,
769+
),
734770
},
735771
"translations": {
736772
"url_matcher": lambda url: url == "/v1/audio/translations",
@@ -739,7 +775,10 @@ async def build_endpoint_registry(
739775
if openai_serving_translation is not None
740776
else None
741777
),
742-
"wrapper_fn": make_transcription_wrapper(is_translation=True),
778+
"wrapper_fn": make_transcription_wrapper(
779+
is_translation=True,
780+
allowed_media_domains=allowed_media_domains,
781+
),
743782
},
744783
}
745784

0 commit comments

Comments
 (0)