|
1 | | -"""Main test for zstd middleware. |
| 1 | +"""Main tests for zstd middleware. |
2 | 2 |
|
3 | | -This tests are the same as the ones from starlette.tests.middleware.test_gzip |
| 3 | +Some of these tests are the same as the ones from starlette.tests.middleware.test_gzip |
4 | 4 | but using zstd instead. |
5 | 5 | """ |
6 | 6 | import functools |
| 7 | +import gzip |
| 8 | +import io |
7 | 9 |
|
8 | 10 | import pytest |
9 | 11 |
|
10 | 12 | from starlette.applications import Starlette |
11 | 13 | from starlette.responses import ( |
12 | 14 | JSONResponse, |
13 | 15 | PlainTextResponse, |
| 16 | + Response, |
14 | 17 | StreamingResponse, |
15 | 18 | ) |
16 | 19 | from starlette.testclient import TestClient |
@@ -174,3 +177,33 @@ def homepage(request): |
174 | 177 | assert response.text == "x" * 4000 |
175 | 178 | assert "Content-Encoding" not in response.headers |
176 | 179 | assert int(response.headers["Content-Length"]) == 4000 |
| 180 | + |
| 181 | + |
| 182 | +def test_zstd_avoids_double_encoding(): |
| 183 | + # See https://github.com/encode/starlette/pull/1901 |
| 184 | + |
| 185 | + app = Starlette() |
| 186 | + |
| 187 | + app.add_middleware(ZstdMiddleware, minimum_size=1) |
| 188 | + |
| 189 | + @app.route("/") |
| 190 | + def homepage(request): |
| 191 | + gzip_buffer = io.BytesIO() |
| 192 | + gzip_file = gzip.GzipFile(mode="wb", fileobj=gzip_buffer) |
| 193 | + gzip_file.write(b"hello world" * 200) |
| 194 | + gzip_file.close() |
| 195 | + body = gzip_buffer.getvalue() |
| 196 | + return Response( |
| 197 | + body, |
| 198 | + headers={ |
| 199 | + "content-encoding": "gzip", |
| 200 | + "x-gzipped-content-length": str(len(body)) |
| 201 | + } |
| 202 | + ) |
| 203 | + |
| 204 | + client = TestClient(app) |
| 205 | + response = client.get("/", headers={"accept-encoding": "zstd"}) |
| 206 | + assert response.status_code == 200 |
| 207 | + assert response.text == "hello world" * 200 |
| 208 | + assert response.headers["Content-Encoding"] == "gzip" |
| 209 | + assert response.headers["Content-Length"] == response.headers["x-gzipped-content-length"] |
0 commit comments