|
1 | 1 | import os |
2 | 2 | import pathlib |
3 | 3 | import stat |
| 4 | +import tempfile |
4 | 5 | import time |
5 | 6 |
|
6 | 7 | import anyio |
@@ -448,3 +449,70 @@ def mock_timeout(*args, **kwargs): |
448 | 449 | response = client.get("/example.txt") |
449 | 450 | assert response.status_code == 500 |
450 | 451 | assert response.text == "Internal Server Error" |
| 452 | + |
| 453 | + |
| 454 | +def test_staticfiles_follows_symlinks(tmpdir, test_client_factory): |
| 455 | + statics_path = os.path.join(tmpdir, "statics") |
| 456 | + os.mkdir(statics_path) |
| 457 | + |
| 458 | + source_path = tempfile.mkdtemp() |
| 459 | + source_file_path = os.path.join(source_path, "page.html") |
| 460 | + with open(source_file_path, "w") as file: |
| 461 | + file.write("<h1>Hello</h1>") |
| 462 | + |
| 463 | + statics_file_path = os.path.join(statics_path, "index.html") |
| 464 | + os.symlink(source_file_path, statics_file_path) |
| 465 | + |
| 466 | + app = StaticFiles(directory=statics_path, follow_symlink=True) |
| 467 | + client = test_client_factory(app) |
| 468 | + |
| 469 | + response = client.get("/index.html") |
| 470 | + assert response.url == "http://testserver/index.html" |
| 471 | + assert response.status_code == 200 |
| 472 | + assert response.text == "<h1>Hello</h1>" |
| 473 | + |
| 474 | + |
| 475 | +def test_staticfiles_follows_symlink_directories(tmpdir, test_client_factory): |
| 476 | + statics_path = os.path.join(tmpdir, "statics") |
| 477 | + statics_html_path = os.path.join(statics_path, "html") |
| 478 | + os.mkdir(statics_path) |
| 479 | + |
| 480 | + source_path = tempfile.mkdtemp() |
| 481 | + source_file_path = os.path.join(source_path, "page.html") |
| 482 | + with open(source_file_path, "w") as file: |
| 483 | + file.write("<h1>Hello</h1>") |
| 484 | + |
| 485 | + os.symlink(source_path, statics_html_path) |
| 486 | + |
| 487 | + app = StaticFiles(directory=statics_path, follow_symlink=True) |
| 488 | + client = test_client_factory(app) |
| 489 | + |
| 490 | + response = client.get("/html/page.html") |
| 491 | + assert response.url == "http://testserver/html/page.html" |
| 492 | + assert response.status_code == 200 |
| 493 | + assert response.text == "<h1>Hello</h1>" |
| 494 | + |
| 495 | + |
| 496 | +def test_staticfiles_disallows_path_traversal_with_symlinks(tmpdir): |
| 497 | + statics_path = os.path.join(tmpdir, "statics") |
| 498 | + |
| 499 | + root_source_path = tempfile.mkdtemp() |
| 500 | + source_path = os.path.join(root_source_path, "statics") |
| 501 | + os.mkdir(source_path) |
| 502 | + |
| 503 | + source_file_path = os.path.join(root_source_path, "index.html") |
| 504 | + with open(source_file_path, "w") as file: |
| 505 | + file.write("<h1>Hello</h1>") |
| 506 | + |
| 507 | + os.symlink(source_path, statics_path) |
| 508 | + |
| 509 | + app = StaticFiles(directory=statics_path, follow_symlink=True) |
| 510 | + # We can't test this with 'httpx', so we test the app directly here. |
| 511 | + path = app.get_path({"path": "/../index.html"}) |
| 512 | + scope = {"method": "GET"} |
| 513 | + |
| 514 | + with pytest.raises(HTTPException) as exc_info: |
| 515 | + anyio.run(app.get_response, path, scope) |
| 516 | + |
| 517 | + assert exc_info.value.status_code == 404 |
| 518 | + assert exc_info.value.detail == "Not Found" |
0 commit comments