Skip to content

Commit 19d5404

Browse files
committed
Add fix for double encoding
Source: fullonic/brotli-asgi#24
1 parent 55f2422 commit 19d5404

2 files changed

Lines changed: 43 additions & 2 deletions

File tree

tests.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
1-
"""Main test for zstd middleware.
1+
"""Main tests for zstd middleware.
22
3-
This tests are the same as the ones from starlette.tests.middleware.test_gzip
3+
Some of these tests are the same as the ones from starlette.tests.middleware.test_gzip
44
but using zstd instead.
55
"""
66
import functools
7+
import gzip
8+
import io
79

810
import pytest
911

1012
from starlette.applications import Starlette
1113
from starlette.responses import (
1214
JSONResponse,
1315
PlainTextResponse,
16+
Response,
1417
StreamingResponse,
1518
)
1619
from starlette.testclient import TestClient
@@ -174,3 +177,33 @@ def homepage(request):
174177
assert response.text == "x" * 4000
175178
assert "Content-Encoding" not in response.headers
176179
assert int(response.headers["Content-Length"]) == 4000
180+
181+
182+
def test_zstd_avoids_double_encoding():
183+
# See https://github.com/encode/starlette/pull/1901
184+
185+
app = Starlette()
186+
187+
app.add_middleware(ZstdMiddleware, minimum_size=1)
188+
189+
@app.route("/")
190+
def homepage(request):
191+
gzip_buffer = io.BytesIO()
192+
gzip_file = gzip.GzipFile(mode="wb", fileobj=gzip_buffer)
193+
gzip_file.write(b"hello world" * 200)
194+
gzip_file.close()
195+
body = gzip_buffer.getvalue()
196+
return Response(
197+
body,
198+
headers={
199+
"content-encoding": "gzip",
200+
"x-gzipped-content-length": str(len(body))
201+
}
202+
)
203+
204+
client = TestClient(app)
205+
response = client.get("/", headers={"accept-encoding": "zstd"})
206+
assert response.status_code == 200
207+
assert response.text == "hello world" * 200
208+
assert response.headers["Content-Encoding"] == "gzip"
209+
assert response.headers["Content-Length"] == response.headers["x-gzipped-content-length"]

zstd_asgi/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(
103103
self.send = unattached_send # type: Send
104104
self.initial_message = {} # type: Message
105105
self.started = False
106+
self.content_encoding_set = False
106107
self.zstd_buffer = io.BytesIO()
107108
self.zstd_file = zstandard.ZstdCompressor(
108109
level=level,
@@ -124,6 +125,13 @@ async def send_with_zstd(self, message: Message) -> None:
124125
# Don't send the initial message until we've determined how to
125126
# modify the outgoing headers correctly.
126127
self.initial_message = message
128+
headers = Headers(raw=self.initial_message["headers"])
129+
self.content_encoding_set = "content-encoding" in headers
130+
elif message_type == "http.response.body" and self.content_encoding_set:
131+
if not self.started:
132+
self.started = True
133+
await self.send(self.initial_message)
134+
await self.send(message)
127135
elif message_type == "http.response.body" and not self.started:
128136
self.started = True
129137
body = message.get("body", b"")

0 commit comments

Comments
 (0)