Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions reversion/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from reversion.errors import RevertError
from reversion.models import Version
from reversion.revisions import is_active, register, is_registered, set_comment, create_revision, set_user
from reversion.revisions import is_active, register, is_registered, set_comment, create_revision, set_user, _get_options
from reversion.utils import mute_signals


Expand Down Expand Up @@ -200,7 +200,13 @@ def _reversion_revisionform_view(self, request, version, template_name, extra_co
version.revision.revert(delete=True)
# Run the normal changeform view.
with self.create_revision(request):
response = self.changeform_view(request, quote(version.object_id), request.path, extra_context)
obj = get_object_or_404(
version._model._default_manager.using(version.db),
**{_get_options(version._model).object_id_field: version.object_id},
)
response = self.changeform_view(
request, quote(str(obj.pk)), request.path, extra_context
)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels worth optimizing for the case of object_id_field being the primary key. It'll save one query for 99.9% of django-reversion users.

So a check like ._meta.pk.name == object_id_field

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great yes thanks for catching this

# Decide on whether the keep the changes.
if request.method == "POST" and response.status_code == 302:
set_comment(_("Reverted to previous version, saved on %(datetime)s") % {
Expand Down Expand Up @@ -305,6 +311,8 @@ def history_view(self, request, object_id, extra_context=None):
if not self.has_change_permission(request):
raise PermissionDenied

obj = get_object_or_404(self.model, pk=unquote(object_id))
reversion_object_id = str(getattr(obj, _get_options(self.model).object_id_field))
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above, this is worth optimizing for the common case of object_id_field being the primary key.

opts = self.model._meta
action_list = [
{
Expand All @@ -317,7 +325,7 @@ def history_view(self, request, object_id, extra_context=None):
for version
in self._reversion_order_version_queryset(request, Version.objects.get_for_object_reference(
self.model,
unquote(object_id), # Underscores in primary key get quoted to "_5F"
reversion_object_id,
).select_related("revision", "revision__user"))
]
# Compile the context.
Expand Down
26 changes: 17 additions & 9 deletions reversion/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from reversion.errors import RevertError
from reversion.revisions import (_follow_relations_recursive,
_get_content_type, _get_options)
_get_content_type, _get_options, _get_object_id_field)


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -91,7 +91,10 @@ def revert(self, delete=False):
model = version._model
try:
# Load the model instance from the same DB as it was saved under.
old_revision.add(model._default_manager.using(version.db).get(pk=version.object_id))
id_field = _get_object_id_field(model)
old_revision.add(
model._default_manager.using(version.db).get(**{id_field: version.object_id})
)
except model.DoesNotExist:
pass
# Calculate the set of all objects that are in the revision now.
Expand Down Expand Up @@ -135,14 +138,17 @@ def get_for_object_reference(self, model, object_id, model_db=None):
)

def get_for_object(self, obj, model_db=None):
return self.get_for_object_reference(obj.__class__, obj.pk, model_db=model_db)
opts = _get_options(obj.__class__)
return self.get_for_object_reference(
obj.__class__, getattr(obj, opts.object_id_field), model_db=model_db
)

def get_deleted(self, model, model_db=None):
model_db = model_db or router.db_for_write(model)
connection = connections[self.db]
object_id_field_name = _get_object_id_field(model)
if self.db == model_db and connection.vendor in ("sqlite", "postgresql", "oracle"):
pk_field_name = model._meta.pk.name
object_id_cast_target = model._meta.get_field(pk_field_name)
object_id_cast_target = model._meta.get_field(object_id_field_name)
if django.VERSION >= (2, 1):
# django 2.0 contains a critical bug that doesn't allow the code below to work,
# fallback to casting primary keys then
Expand All @@ -158,14 +164,14 @@ def get_deleted(self, model, model_db=None):
model_qs = (
model._default_manager
.using(model_db)
.filter(**{pk_field_name: casted_object_id})
.filter(**{object_id_field_name: casted_object_id})
)
else:
model_qs = (
model._default_manager
.using(model_db)
.annotate(_pk_to_object_id=Cast("pk", Version._meta.get_field("object_id")))
.filter(_pk_to_object_id=models.OuterRef("object_id"))
.annotate(_field_to_object_id=Cast(object_id_field_name, Version._meta.get_field("object_id")))
.filter(_field_to_object_id=models.OuterRef("object_id"))
)
# conditional expressions are being supported since django 3.0
# DISTINCT ON works only for Postgres DB
Expand All @@ -190,7 +196,9 @@ def get_deleted(self, model, model_db=None):
# We have to use a slow subquery.
subquery = self.get_for_model(model, model_db=model_db).exclude(
object_id__in=list(
model._default_manager.using(model_db).values_list("pk", flat=True).order_by().iterator()
model._default_manager.using(model_db).values_list(
object_id_field_name, flat=True
).order_by().iterator()
),
).values_list("object_id").annotate(
latest_pk=models.Max("pk")
Expand Down
19 changes: 15 additions & 4 deletions reversion/revisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"for_concrete_model",
"ignore_duplicates",
"use_natural_foreign_keys",
"object_id_field"
))


Expand Down Expand Up @@ -168,7 +169,7 @@ def _add_to_revision(obj, using, model_db, explicit):
return
version_options = _get_options(obj.__class__)
content_type = _get_content_type(obj.__class__, using)
object_id = force_str(obj.pk)
object_id = force_str(getattr(obj, version_options.object_id_field))
version_key = (content_type, object_id)
# If the obj is already in the revision, stop now.
db_versions = _current_frame().db_versions
Expand All @@ -191,7 +192,9 @@ def _add_to_revision(obj, using, model_db, explicit):
)
# If the version is a duplicate, stop now.
if version_options.ignore_duplicates and explicit:
previous_version = Version.objects.using(using).get_for_object(obj, model_db=model_db).first()
previous_version = Version.objects.using(using).get_for_object_reference(
obj.__class__, object_id, model_db=model_db
).first()
if previous_version and previous_version._local_field_dict == version._local_field_dict:
return
# Store the version.
Expand All @@ -209,6 +212,11 @@ def add_to_revision(obj, model_db=None):
_add_to_revision(obj, db, model_db, True)


def _get_object_id_field(model):
Comment thread
etianen marked this conversation as resolved.
Outdated
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some parts of the codebase using VersionOptions.object_id_field directly, and others using _get_object_id_field().

What if... we resolved object_id_field to a concrete field name in register, stored that as VersionOptions.object_id_field, and then used VersionOptions.object_id_field consistly everywhere, dropping this helper?

So make object_id_field default to None. If it's None, we resolve it to a concrete primary key field using ._meta.pk.name. If it's a string, we validate it using .get_field during register.

This means explicitly setting object_id_field="pk" becomes an error.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resolving object_id_field here using ._meta.pk.name resulted in errors, for models with multi-table inheritence, so instead i used meta.pk.attname.

if object_id_field is None:
            id_field = model._meta.pk.attname
        else:
            model._meta.get_field(object_id_field)
            id_field = object_id_field

field = _get_options(model).object_id_field
return model._meta.pk.name if field == "pk" else field


def _save_revision(versions, user=None, comment="", meta=(), date_created=None, using=None):
from reversion.models import Revision
from reversion.models import Version
Expand All @@ -221,7 +229,9 @@ def _save_revision(versions, user=None, comment="", meta=(), date_created=None,
model: {
db: frozenset(map(
force_str,
model._base_manager.using(db).filter(pk__in=pks).values_list("pk", flat=True),
model._base_manager.using(db).filter(
**{f"{_get_object_id_field(model)}__in": pks}
).values_list(_get_object_id_field(model), flat=True),
))
for db, pks in db_pks.items()
}
Expand Down Expand Up @@ -376,7 +386,7 @@ def _get_senders_and_signals(model):


def register(model=None, fields=None, exclude=(), follow=(), format="json",
for_concrete_model=True, ignore_duplicates=False, use_natural_foreign_keys=False):
for_concrete_model=True, ignore_duplicates=False, use_natural_foreign_keys=False, object_id_field="pk"):
def register(model):
# Prevent multiple registration.
if is_registered(model):
Expand All @@ -401,6 +411,7 @@ def register(model):
for_concrete_model=for_concrete_model,
ignore_duplicates=ignore_duplicates,
use_natural_foreign_keys=use_natural_foreign_keys,
object_id_field=object_id_field,
)
# Register the model.
_registered_models[_get_registration_key(model)] = version_options
Expand Down
2 changes: 0 additions & 2 deletions tests/test_app/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,4 @@ class TestModelAdmin(VersionAdmin):


admin.site.register(TestModel, TestModelAdmin)


admin.site.register(TestModelRelated, admin.ModelAdmin)
21 changes: 21 additions & 0 deletions tests/test_app/migrations/0003_testmodelcustomobjectid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Generated by Django 6.0.4 on 2026-05-02 12:08

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('test_app', '0002_alter_testmodel_related_and_more'),
]

operations = [
migrations.CreateModel(
name='TestModelCustomObjectId',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('slug', models.CharField(max_length=191, unique=True)),
('name', models.CharField(default='v1', max_length=191)),
],
),
]
5 changes: 5 additions & 0 deletions tests/test_app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,8 @@ class TestModelWithUniqueConstraint(models.Model):
max_length=191,
unique=True,
)


class TestModelCustomObjectId(models.Model):
slug = models.CharField(max_length=191, unique=True)
name = models.CharField(max_length=191, default="v1")
71 changes: 70 additions & 1 deletion tests/test_app/tests/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import reversion
from reversion.admin import VersionAdmin
from reversion.models import Version
from test_app.models import TestModel, TestModelParent, TestModelInline, TestModelGenericInline, TestModelEscapePK
from test_app.models import (
TestModel, TestModelParent, TestModelInline, TestModelGenericInline, TestModelEscapePK, TestModelCustomObjectId
)
from test_app.tests.base import TestBase, LoginMixin


Expand Down Expand Up @@ -569,3 +571,70 @@ def testAutoRegisterInline(self):

def testAutoRegisterGenericInline(self):
self.assertTrue(reversion.is_registered(TestModelGenericInline))


class TestModelCustomObjectIdAdmin(VersionAdmin):

def reversion_register(self, model, **kwargs):
kwargs["object_id_field"] = "slug"
super().reversion_register(model, **kwargs)


class CustomObjectIdAdminMixin(TestBase):

def setUp(self):
super().setUp()
admin.site.register(TestModelCustomObjectId, TestModelCustomObjectIdAdmin)
self.reloadUrls()

def tearDown(self):
super().tearDown()
admin.site.unregister(TestModelCustomObjectId)
self.reloadUrls()


class CustomObjectIdAdminHistoryViewTest(LoginMixin, CustomObjectIdAdminMixin, TestBase):

def testHistoryView(self):
with reversion.create_revision():
obj = TestModelCustomObjectId.objects.create(slug="test", name="v1")
with reversion.create_revision():
obj.name = "v2"
obj.save()
response = self.client.get(resolve_url("admin:test_app_testmodelcustomobjectid_history", obj.pk))
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.context["action_list"]), 2)


class CustomObjectIdAdminRevisionViewTest(LoginMixin, CustomObjectIdAdminMixin, TestBase):

def setUp(self):
super().setUp()
with reversion.create_revision():
self.obj = TestModelCustomObjectId.objects.create(slug="test", name="v1")
with reversion.create_revision():
self.obj.name = "v2"
self.obj.save()

def testRevisionViewGet(self):
version = Version.objects.get_for_object_reference(TestModelCustomObjectId, "test")[1]
response = self.client.get(resolve_url(
"admin:test_app_testmodelcustomobjectid_revision",
self.obj.slug,
version.pk,
))
self.assertEqual(response.status_code, 200)
self.assertContains(response, 'value="v1"')
# Verify the revert was rolled back (GET should not persist changes).
self.obj.refresh_from_db()
self.assertEqual(self.obj.name, "v2")

def testRevisionViewRevert(self):
version = Version.objects.get_for_object_reference(TestModelCustomObjectId, "test")[1]
self.client.post(resolve_url(
"admin:test_app_testmodelcustomobjectid_revision",
self.obj.slug,
version.pk,
), {"slug": "test", "name": "v1"})
self.obj.refresh_from_db()
self.assertEqual(self.obj.name, "v1")
72 changes: 72 additions & 0 deletions tests/test_app/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
TestModelNestedInline,
TestModelInlineByNaturalKey, TestModelWithNaturalKey,
TestModelWithUniqueConstraint,
TestModelCustomObjectId
)
from test_app.tests.base import TestBase, TestModelMixin, TestModelParentMixin
import json
Expand Down Expand Up @@ -458,3 +459,74 @@ def testTransactionInRollbackState(self):
TestModelWithUniqueConstraint.objects.create(name='A')
except Exception:
pass


class CustomObjectIdTest(TestBase):

def setUp(self):
super().setUp()
reversion.register(TestModelCustomObjectId, object_id_field='slug')

def testObjectIdStoredAsSlug(self):
with reversion.create_revision():
TestModelCustomObjectId.objects.create(slug='custom-id', name='v1')
version = Version.objects.get_for_object_reference(TestModelCustomObjectId, 'custom-id').get()
self.assertEqual(version.object_id, 'custom-id')

def testGetForObject(self):
with reversion.create_revision():
obj = TestModelCustomObjectId.objects.create(slug='custom-id', name='v1')
self.assertEqual(Version.objects.get_for_object(obj).count(), 1)

def testFieldDict(self):
with reversion.create_revision():
obj = TestModelCustomObjectId.objects.create(slug='custom-id', name='v1')
version = Version.objects.get_for_object_reference(TestModelCustomObjectId, 'custom-id').get()
self.assertEqual(version.field_dict, {'id': obj.pk, 'slug': 'custom-id', 'name': 'v1'})

def testMultipleVersions(self):
with reversion.create_revision():
obj = TestModelCustomObjectId.objects.create(slug='custom-id', name='v1')
with reversion.create_revision():
obj.name = 'v2'
obj.save()
versions = Version.objects.get_for_object_reference(TestModelCustomObjectId, 'custom-id')
self.assertEqual(versions.count(), 2)
self.assertEqual(versions[0].field_dict['name'], 'v2')
self.assertEqual(versions[1].field_dict['name'], 'v1')

def testGetDeleted(self):
with reversion.create_revision():
obj = TestModelCustomObjectId.objects.create(slug='custom-id', name='v1')
obj.delete()
deleted = Version.objects.get_deleted(TestModelCustomObjectId)
self.assertEqual(deleted.count(), 1)
self.assertEqual(deleted[0].object_id, 'custom-id')

def testGetDeletedNotIncludingExisting(self):
with reversion.create_revision():
TestModelCustomObjectId.objects.create(slug='custom-id', name='v1')
self.assertEqual(Version.objects.get_deleted(TestModelCustomObjectId).count(), 0)


class CustomObjectIdIgnoreDuplicatesTest(TestBase):

def setUp(self):
reversion.register(TestModelCustomObjectId, object_id_field='slug', ignore_duplicates=True)

def testIgnoreDuplicates(self):
with reversion.create_revision():
obj = TestModelCustomObjectId.objects.create(slug='custom-id', name='v1')
with reversion.create_revision():
obj.save()
versions = Version.objects.get_for_object_reference(TestModelCustomObjectId, 'custom-id')
self.assertEqual(versions.count(), 1)

def testIgnoreDuplicatesNewData(self):
with reversion.create_revision():
obj = TestModelCustomObjectId.objects.create(slug='custom-id', name='v1')
with reversion.create_revision():
obj.name = 'v2'
obj.save()
versions = Version.objects.get_for_object_reference(TestModelCustomObjectId, 'custom-id')
self.assertEqual(versions.count(), 2)
Loading