Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ max_supported_python = "3.14"
keep_full_version = true

[tool.pytest.ini_options]
addopts = "--tb=short --strict-markers -ra"
addopts = "--tb=short --strict-markers -ra --no-migrations"
testpaths = [ "tests" ]
markers = [
"requires_postgres: marks tests as requiring a PostgreSQL database backend",
Expand Down
21 changes: 20 additions & 1 deletion rest_framework/test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
# to make it harder for the user to import the wrong thing without realizing.
import io
from contextlib import contextmanager
from importlib import import_module

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.handlers.wsgi import WSGIHandler
from django.core.signals import request_finished, request_started
from django.db import close_old_connections
from django.test import override_settings, testcases
from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler
Expand All @@ -22,6 +25,21 @@ def force_authenticate(request, user=None, token=None):
request._force_auth_token = token


@contextmanager
def _keep_connections_open():
"""
Prevent Django from closing the database connection while a request
is dispatched, matching the behavior of Django's ClientHandler.
"""
request_started.disconnect(close_old_connections)
request_finished.disconnect(close_old_connections)
Comment on lines +34 to +35
Copy link

Copilot AI May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_keep_connections_open() disconnects close_old_connections and then unconditionally reconnects it in finally. This can mutate global signal state if the receiver was not connected to begin with (or if the context is nested), and it is not thread-safe because signals are process-global. Track whether each disconnect actually removed a receiver (and only reconnect when needed), and consider making the context re-entrant (eg, a counter) to avoid inner contexts re-enabling connection-closing while an outer context is still active.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Member Author

@browniebroke browniebroke May 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand when that would be a problem in practice... As far as I understand, this is mainly used when running tests, and Django --parrallel test option runs each worker in separate processes, not threads (source).

Django itself does this disconnect/connect in ClientHandler

try:
yield
finally:
request_started.connect(close_old_connections)
request_finished.connect(close_old_connections)


if requests is not None:
class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
def get_all(self, key, default):
Expand Down Expand Up @@ -90,7 +108,8 @@ def start_response(wsgi_status, wsgi_headers, exc_info=None):

# Make the outgoing request via WSGI.
environ = self.get_environ(request)
wsgi_response = self.app(environ, start_response)
with _keep_connections_open():
wsgi_response = self.app(environ, start_response)

# Build the underlying urllib3.HTTPResponse
raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
Expand Down
24 changes: 24 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,31 @@
import dj_database_url
import django
import pytest
from django.apps import apps
from django.core import management
from django.core.management.color import no_style
from django.db import connection


@pytest.fixture
def reset_sequences():
"""
Reset all database sequences so PKs start from 1.

PostgreSQL sequences are non-transactional and persist across
TestCase's transaction rollbacks. Apply this fixture to test
classes that rely on hardcoded PKs to keep them predictable
regardless of execution order. No-op on SQLite.
"""
if connection.vendor != 'postgresql':
return
table_names = set(connection.introspection.table_names())
models = [m for m in apps.get_models() if m._meta.db_table in table_names]
sql_list = connection.ops.sequence_reset_sql(no_style(), models)
if sql_list:
with connection.cursor() as cursor:
for sql in sql_list:
cursor.execute(sql)
Comment thread
browniebroke marked this conversation as resolved.


def pytest_addoption(parser):
Expand Down
5 changes: 4 additions & 1 deletion tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_filter_queryset_raises_error(self):


class SearchFilterModel(models.Model):
title = models.CharField(max_length=20)
title = models.CharField(max_length=25)
text = models.CharField(max_length=100)


Expand Down Expand Up @@ -459,6 +459,7 @@ class Meta:
fields = '__all__'


@pytest.mark.usefixtures("reset_sequences")
class SearchFilterM2MTests(TestCase):
def setUp(self):
# Sequence of title/text/attributes is:
Expand Down Expand Up @@ -657,6 +658,7 @@ class Meta:
fields = '__all__'


@pytest.mark.usefixtures("reset_sequences")
class OrderingFilterTests(TestCase):
def setUp(self):
# Sequence of title/text is:
Expand Down Expand Up @@ -974,6 +976,7 @@ class Meta:
fields = ('id', 'user')


@pytest.mark.usefixtures("reset_sequences")
class SensitiveOrderingFilterTests(TestCase):
def setUp(self):
for idx in range(3):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class SlugBasedInstanceView(InstanceView):


# Tests
@pytest.mark.usefixtures("reset_sequences")
class TestRootView(TestCase):
def setUp(self):
"""
Expand Down Expand Up @@ -171,6 +172,7 @@ def test_post_error_root_view(self):
EXPECTED_QUERIES_FOR_PUT = 2


@pytest.mark.usefixtures("reset_sequences")
class TestInstanceView(TestCase):
def setUp(self):
"""
Expand Down Expand Up @@ -334,6 +336,7 @@ def setUp(self):
self.view = FKInstanceView.as_view()


@pytest.mark.usefixtures("reset_sequences")
class TestOverriddenGetObject(TestCase):
"""
Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
Expand Down Expand Up @@ -477,6 +480,7 @@ class Meta:
return DynamicSerializer


@pytest.mark.usefixtures("reset_sequences")
class TestFilterBackendAppliedToViews(TestCase):
def setUp(self):
"""
Expand Down
1 change: 1 addition & 0 deletions tests/test_model_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@ class DisplayValueModel(models.Model):
color = models.ForeignKey(DisplayValueTargetModel, on_delete=models.CASCADE)


@pytest.mark.usefixtures("reset_sequences")
class TestRelationalFieldDisplayValue(TestCase):
def setUp(self):
DisplayValueTargetModel.objects.bulk_create([
Expand Down
1 change: 1 addition & 0 deletions tests/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,7 @@ class CursorPaginationModel(models.Model):
created = models.IntegerField()


@pytest.mark.usefixtures("reset_sequences")
class TestCursorPaginationWithValueQueryset(CursorPaginationTestsMixin, TestCase):
"""
Unit tests for `pagination.CursorPagination` for value querysets.
Expand Down
4 changes: 4 additions & 0 deletions tests/test_permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest
from unittest import mock

import pytest
from django.conf import settings
from django.contrib.auth.models import AnonymousUser, Group, Permission, User
from django.db import models
Expand Down Expand Up @@ -73,6 +74,7 @@ def basic_auth_header(username, password):
return 'Basic %s' % base64_credentials


@pytest.mark.usefixtures("reset_sequences")
class ModelPermissionsIntegrationTests(TestCase):
def setUp(self):
User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
Expand Down Expand Up @@ -325,6 +327,7 @@ def get_queryset(self):
get_queryset_object_permissions_view = GetQuerysetObjectPermissionInstanceView.as_view()


@pytest.mark.usefixtures("reset_sequences")
@unittest.skipUnless('guardian' in settings.INSTALLED_APPS, 'django-guardian not installed')
class ObjectPermissionsIntegrationTests(TestCase):
"""
Expand Down Expand Up @@ -504,6 +507,7 @@ class DeniedObjectViewWithDetail(PermissionInstanceView):
denied_object_view_with_detail = DeniedObjectViewWithDetail.as_view()


@pytest.mark.usefixtures("reset_sequences")
class CustomPermissionsTests(TestCase):
def setUp(self):
BasicModel(text='foo').save()
Expand Down
2 changes: 2 additions & 0 deletions tests/test_prefetch_related.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from django.contrib.auth.models import Group, User
from django.test import TestCase

Expand All @@ -18,6 +19,7 @@ class UserUpdate(generics.UpdateAPIView):
serializer_class = UserSerializer


@pytest.mark.usefixtures("reset_sequences")
class TestPrefetchRelatedUpdates(TestCase):
def setUp(self):
self.user = User.objects.create(username='tom', email='tom@example.com')
Expand Down
Loading
Loading