|
4 | 4 | import json |
5 | 5 | import subprocess |
6 | 6 | import tempfile |
| 7 | +from unittest.mock import AsyncMock, MagicMock, patch |
7 | 8 |
|
8 | 9 | import pytest |
9 | 10 |
|
10 | 11 | from vllm.assets.audio import AudioAsset |
11 | | -from vllm.entrypoints.openai.run_batch import BatchRequestOutput |
| 12 | +from vllm.entrypoints.openai.run_batch import ( |
| 13 | + BatchRequestOutput, |
| 14 | + download_bytes_from_url, |
| 15 | +) |
12 | 16 |
|
13 | 17 | CHAT_MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" |
14 | 18 | EMBEDDING_MODEL_NAME = "intfloat/multilingual-e5-small" |
@@ -746,3 +750,131 @@ def test_tool_calling(): |
746 | 750 | assert "arguments" in tool_call["function"] |
747 | 751 | # Verify the tool name matches our tool definition |
748 | 752 | assert tool_call["function"]["name"] == "get_current_weather" |
| 753 | + |
| 754 | + |
| 755 | +# --------------------------------------------------------------------------- |
| 756 | +# Unit tests for download_bytes_from_url SSRF protection |
| 757 | +# --------------------------------------------------------------------------- |
| 758 | + |
| 759 | + |
| 760 | +def _make_aiohttp_mocks(response_data: bytes = b"fake-data", status: int = 200): |
| 761 | + """Create mock objects that simulate aiohttp.ClientSession context managers.""" |
| 762 | + mock_resp = MagicMock() |
| 763 | + mock_resp.status = status |
| 764 | + mock_resp.read = AsyncMock(return_value=response_data) |
| 765 | + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) |
| 766 | + mock_resp.__aexit__ = AsyncMock(return_value=False) |
| 767 | + |
| 768 | + mock_session = MagicMock() |
| 769 | + mock_session.get = MagicMock(return_value=mock_resp) |
| 770 | + mock_session.__aenter__ = AsyncMock(return_value=mock_session) |
| 771 | + mock_session.__aexit__ = AsyncMock(return_value=False) |
| 772 | + return mock_session |
| 773 | + |
| 774 | + |
@pytest.mark.asyncio
async def test_download_bytes_data_url_bypasses_domain_check():
    """Check that a ``data:`` URL is accepted even when an allowlist is set.

    The allowlist names only ``example.com``; the call is still expected to
    return a non-empty ``bytes`` object instead of raising.
    """
    inline_audio_url = f"data:audio/wav;base64,{MINIMAL_WAV_BASE64}"
    allowlist = ["example.com"]

    payload = await download_bytes_from_url(
        inline_audio_url, allowed_media_domains=allowlist
    )

    assert isinstance(payload, bytes)
    assert len(payload) > 0
| 784 | + |
| 785 | + |
@pytest.mark.asyncio
async def test_download_bytes_rejects_disallowed_domain():
    """Reject an HTTPS URL whose host is absent from the allowlist.

    Only ``example.com`` is permitted, so fetching from ``evil.internal``
    must raise a ``ValueError`` whose message matches "allowed domains".
    """
    allowlist = ["example.com"]
    expect_rejection = pytest.raises(ValueError, match="allowed domains")

    with expect_rejection:
        await download_bytes_from_url(
            "https://evil.internal/secret", allowed_media_domains=allowlist
        )
| 792 | + |
| 793 | + |
@pytest.mark.asyncio
async def test_download_bytes_rejects_cloud_metadata_ip():
    """Reject the cloud metadata address when an allowlist is configured.

    ``169.254.169.254`` is not in the ``example.com``-only allowlist, so the
    call must fail with a ``ValueError`` matching "allowed domains".
    """
    metadata_endpoint = "http://169.254.169.254/latest/meta-data/"
    allowlist = ["example.com"]

    with pytest.raises(ValueError, match="allowed domains"):
        await download_bytes_from_url(
            metadata_endpoint, allowed_media_domains=allowlist
        )
| 800 | + |
| 801 | + |
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "internal_url",
    [
        "http://10.0.0.1/secret",
        "http://192.168.1.1/admin",
        "http://127.0.0.1:8080/internal",
    ],
)
async def test_download_bytes_rejects_internal_ip(internal_url):
    """Private-range and loopback IPs must be blocked when an allowlist is set.

    Each address is its own parametrized case, so a failure for one URL is
    reported by name and does not stop the remaining addresses from being
    checked (a plain loop would stop at the first failing URL).

    Args:
        internal_url: An HTTP URL whose host is an internal IP address that
            is not in the ``example.com``-only allowlist.
    """
    with pytest.raises(ValueError, match="allowed domains"):
        await download_bytes_from_url(
            internal_url, allowed_media_domains=["example.com"]
        )
| 814 | + |
| 815 | + |
@pytest.mark.asyncio
async def test_download_bytes_allows_permitted_domain():
    """Fetch succeeds when the URL's host is listed in the allowlist.

    ``aiohttp.ClientSession`` is patched in the ``run_batch`` module so no
    network traffic happens; the mocked response body must come back as-is.
    """
    body = b"audio-bytes"
    fake_session = _make_aiohttp_mocks(body)
    session_patch = patch(
        "vllm.entrypoints.openai.run_batch.aiohttp.ClientSession",
        return_value=fake_session,
    )

    with session_patch:
        fetched = await download_bytes_from_url(
            "https://example.com/audio.wav",
            allowed_media_domains=["example.com"],
        )

    assert fetched == body
| 831 | + |
| 832 | + |
@pytest.mark.asyncio
async def test_download_bytes_no_allowlist_permits_any_domain():
    """With ``allowed_media_domains=None`` any HTTP(S) host is fetched.

    This pins the backward-compatible behavior: no allowlist means no domain
    filtering. ``aiohttp.ClientSession`` is patched in the ``run_batch``
    module, and the mocked response body must be returned unchanged.
    """
    body = b"some-data"
    fake_session = _make_aiohttp_mocks(body)
    session_patch = patch(
        "vllm.entrypoints.openai.run_batch.aiohttp.ClientSession",
        return_value=fake_session,
    )

    with session_patch:
        fetched = await download_bytes_from_url(
            "https://any-domain.example.org/file.wav",
            allowed_media_domains=None,
        )

    assert fetched == body
| 846 | + |
| 847 | + |
@pytest.mark.asyncio
async def test_download_bytes_empty_allowlist_denies_all():
    """An empty allowlist (``[]``) rejects every HTTP(S) URL.

    Unlike ``None``, an explicitly empty list permits no host at all, so the
    call must raise a ``ValueError`` matching "allowed domains".
    """
    no_domains_allowed = []
    target = "https://any-domain.example.org/file.wav"

    with pytest.raises(ValueError, match="allowed domains"):
        await download_bytes_from_url(
            target, allowed_media_domains=no_domains_allowed
        )
| 854 | + |
| 855 | + |
@pytest.mark.asyncio
async def test_download_bytes_unsupported_scheme():
    """An ``ftp://`` URL is rejected with and without an allowlist.

    Both calls must raise a ``ValueError`` matching "Unsupported URL scheme",
    including the one where the URL's host (``example.com``) is allowlisted.
    """
    ftp_url = "ftp://example.com/file.wav"

    # No allowlist argument at all.
    with pytest.raises(ValueError, match="Unsupported URL scheme"):
        await download_bytes_from_url(ftp_url)

    # Host is allowlisted, but the scheme is still unsupported.
    with pytest.raises(ValueError, match="Unsupported URL scheme"):
        await download_bytes_from_url(
            ftp_url, allowed_media_domains=["example.com"]
        )
| 867 | + |
| 868 | + |
@pytest.mark.asyncio
async def test_download_bytes_backslash_bypass():
    """Backslash-before-``@`` URL confusion must not defeat the allowlist.

    urllib3.parse_url() and aiohttp/yarl disagree on how to split a URL that
    has a backslash before ``@``; the fix normalizes through urllib3 before
    handing the URL to aiohttp.

    The allowlist here contains only ``evil.internal`` and the call is still
    expected to be rejected.
    NOTE(review): presumably the host that gets validated is the part before
    the backslash (``allowed.example.com``) -- confirm against the
    implementation of ``download_bytes_from_url``.
    """
    confusing_url = "http://allowed.example.com\\@evil.internal/secret"
    allowlist = ["evil.internal"]

    with pytest.raises(ValueError, match="allowed domains"):
        await download_bytes_from_url(
            confusing_url, allowed_media_domains=allowlist
        )
0 commit comments