|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import re |
| 16 | +from http import HTTPStatus |
| 17 | +from typing import Tuple |
16 | 18 |
|
17 | 19 | from twisted.internet.defer import Deferred |
18 | 20 | from twisted.web.resource import Resource |
19 | 21 |
|
20 | 22 | from synapse.api.errors import Codes, RedirectException, SynapseError |
21 | 23 | from synapse.config.server import parse_listener_def |
22 | | -from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource |
23 | | -from synapse.http.site import SynapseSite |
| 24 | +from synapse.http.server import ( |
| 25 | + DirectServeHtmlResource, |
| 26 | + DirectServeJsonResource, |
| 27 | + JsonResource, |
| 28 | + OptionsResource, |
| 29 | + cancellable, |
| 30 | +) |
| 31 | +from synapse.http.site import SynapseRequest, SynapseSite |
24 | 32 | from synapse.logging.context import make_deferred_yieldable |
| 33 | +from synapse.types import JsonDict |
25 | 34 | from synapse.util import Clock |
26 | 35 |
|
27 | 36 | from tests import unittest |
| 37 | +from tests.http.server._base import EndpointCancellationTestHelperMixin |
28 | 38 | from tests.server import ( |
29 | 39 | FakeSite, |
30 | 40 | ThreadedMemoryReactorClock, |
@@ -363,3 +373,100 @@ async def callback(request): |
363 | 373 |
|
364 | 374 | self.assertEqual(channel.result["code"], b"200") |
365 | 375 | self.assertNotIn("body", channel.result) |
| 376 | + |
| 377 | + |
| 378 | +class CancellableDirectServeJsonResource(DirectServeJsonResource): |
| 379 | + def __init__(self, clock: Clock): |
| 380 | + super().__init__() |
| 381 | + self.clock = clock |
| 382 | + |
| 383 | + @cancellable |
| 384 | + async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: |
| 385 | + await self.clock.sleep(1.0) |
| 386 | + return HTTPStatus.OK, {"result": True} |
| 387 | + |
| 388 | + async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: |
| 389 | + await self.clock.sleep(1.0) |
| 390 | + return HTTPStatus.OK, {"result": True} |
| 391 | + |
| 392 | + |
| 393 | +class CancellableDirectServeHtmlResource(DirectServeHtmlResource): |
| 394 | + ERROR_TEMPLATE = "{code} {msg}" |
| 395 | + |
| 396 | + def __init__(self, clock: Clock): |
| 397 | + super().__init__() |
| 398 | + self.clock = clock |
| 399 | + |
| 400 | + @cancellable |
| 401 | + async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, bytes]: |
| 402 | + await self.clock.sleep(1.0) |
| 403 | + return HTTPStatus.OK, b"ok" |
| 404 | + |
| 405 | + async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, bytes]: |
| 406 | + await self.clock.sleep(1.0) |
| 407 | + return HTTPStatus.OK, b"ok" |
| 408 | + |
| 409 | + |
| 410 | +class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin): |
| 411 | + """Tests for `DirectServeJsonResource` cancellation.""" |
| 412 | + |
| 413 | + def setUp(self): |
| 414 | + self.reactor = ThreadedMemoryReactorClock() |
| 415 | + self.clock = Clock(self.reactor) |
| 416 | + self.resource = CancellableDirectServeJsonResource(self.clock) |
| 417 | + self.site = FakeSite(self.resource, self.reactor) |
| 418 | + |
| 419 | + def test_cancellable_disconnect(self) -> None: |
| 420 | + """Test that handlers with the `@cancellable` flag can be cancelled.""" |
| 421 | + channel = make_request( |
| 422 | + self.reactor, self.site, "GET", "/sleep", await_result=False |
| 423 | + ) |
| 424 | + self._test_disconnect( |
| 425 | + self.reactor, |
| 426 | + channel, |
| 427 | + expect_cancellation=True, |
| 428 | + expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN}, |
| 429 | + ) |
| 430 | + |
| 431 | + def test_uncancellable_disconnect(self) -> None: |
| 432 | + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" |
| 433 | + channel = make_request( |
| 434 | + self.reactor, self.site, "POST", "/sleep", await_result=False |
| 435 | + ) |
| 436 | + self._test_disconnect( |
| 437 | + self.reactor, |
| 438 | + channel, |
| 439 | + expect_cancellation=False, |
| 440 | + expected_body={"result": True}, |
| 441 | + ) |
| 442 | + |
| 443 | + |
| 444 | +class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin): |
| 445 | + """Tests for `DirectServeHtmlResource` cancellation.""" |
| 446 | + |
| 447 | + def setUp(self): |
| 448 | + self.reactor = ThreadedMemoryReactorClock() |
| 449 | + self.clock = Clock(self.reactor) |
| 450 | + self.resource = CancellableDirectServeHtmlResource(self.clock) |
| 451 | + self.site = FakeSite(self.resource, self.reactor) |
| 452 | + |
| 453 | + def test_cancellable_disconnect(self) -> None: |
| 454 | + """Test that handlers with the `@cancellable` flag can be cancelled.""" |
| 455 | + channel = make_request( |
| 456 | + self.reactor, self.site, "GET", "/sleep", await_result=False |
| 457 | + ) |
| 458 | + self._test_disconnect( |
| 459 | + self.reactor, |
| 460 | + channel, |
| 461 | + expect_cancellation=True, |
| 462 | + expected_body=b"499 Request cancelled", |
| 463 | + ) |
| 464 | + |
| 465 | + def test_uncancellable_disconnect(self) -> None: |
| 466 | + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" |
| 467 | + channel = make_request( |
| 468 | + self.reactor, self.site, "POST", "/sleep", await_result=False |
| 469 | + ) |
| 470 | + self._test_disconnect( |
| 471 | + self.reactor, channel, expect_cancellation=False, expected_body=b"ok" |
| 472 | + ) |
0 commit comments