diff --git a/docs/api.rst b/docs/api.rst index 29d153dd..40c7fbc6 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -187,6 +187,11 @@ Registration API See `Serialization of natural keys `_ + ``object_id_field=None`` + The name of the model field to use as the version's ``object_id``. Defaults to the model's primary key field. Use this when you want versions to be keyed on a different unique field, such as a slug. + + The field must exist on the model (validated at registration time). ``"pk"`` is not a valid value — use the concrete field name instead (e.g. ``"id"``). + .. Hint:: By default, django-reversion will not register any parent classes of a model that uses multi-table inheritance. If you wish to also add parent models to your revision, you must explicitly add their ``parent_ptr`` fields to the ``follow`` parameter when you register the model. diff --git a/reversion/admin.py b/reversion/admin.py index 8b9f17d2..a6c3b7bb 100644 --- a/reversion/admin.py +++ b/reversion/admin.py @@ -18,7 +18,7 @@ from reversion.errors import RevertError from reversion.models import Version -from reversion.revisions import is_active, register, is_registered, set_comment, create_revision, set_user +from reversion.revisions import is_active, register, is_registered, set_comment, create_revision, set_user, _get_options from reversion.utils import mute_signals @@ -200,7 +200,18 @@ def _reversion_revisionform_view(self, request, version, template_name, extra_co version.revision.revert(delete=True) # Run the normal changeform view. with self.create_revision(request): - response = self.changeform_view(request, quote(version.object_id), request.path, extra_context) + opts = _get_options(version._model) + if opts.object_id_field == version._model._meta.pk.attname: + obj_pk = version.object_id + else: + obj = get_object_or_404( + version._model._default_manager.using(version.db), + **{opts.object_id_field: version.object_id}, + ) + obj_pk = str(obj.pk) + response = self.changeform_view( + request, quote(obj_pk), request.path, extra_context + ) # Decide on whether the keep the changes. if request.method == "POST" and response.status_code == 302: set_comment(_("Reverted to previous version, saved on %(datetime)s") % { @@ -305,6 +316,12 @@ def history_view(self, request, object_id, extra_context=None): if not self.has_change_permission(request): raise PermissionDenied + version_opts = _get_options(self.model) + if version_opts.object_id_field == self.model._meta.pk.attname: + reversion_object_id = unquote(object_id) + else: + obj = get_object_or_404(self.model, pk=unquote(object_id)) + reversion_object_id = str(getattr(obj, version_opts.object_id_field)) opts = self.model._meta action_list = [ { @@ -317,7 +334,7 @@ def history_view(self, request, object_id, extra_context=None): for version in self._reversion_order_version_queryset(request, Version.objects.get_for_object_reference( self.model, - unquote(object_id), # Underscores in primary key get quoted to "_5F" + reversion_object_id, ).select_related("revision", "revision__user")) ] # Compile the context. diff --git a/reversion/models.py b/reversion/models.py index c259a686..40e4c02a 100644 --- a/reversion/models.py +++ b/reversion/models.py @@ -19,8 +19,8 @@ from django.utils.translation import gettext_lazy as _ from reversion.errors import RevertError -from reversion.revisions import (_follow_relations_recursive, - _get_content_type, _get_options) +from reversion.revisions import (_follow_relations_recursive, _get_content_type, + _get_options) logger = logging.getLogger(__name__) @@ -91,7 +91,10 @@ def revert(self, delete=False): model = version._model try: # Load the model instance from the same DB as it was saved under. - old_revision.add(model._default_manager.using(version.db).get(pk=version.object_id)) + id_field = _get_options(model).object_id_field + old_revision.add( + model._default_manager.using(version.db).get(**{id_field: version.object_id}) + ) except model.DoesNotExist: pass # Calculate the set of all objects that are in the revision now. @@ -135,14 +138,17 @@ def get_for_object_reference(self, model, object_id, model_db=None): ) def get_for_object(self, obj, model_db=None): - return self.get_for_object_reference(obj.__class__, obj.pk, model_db=model_db) + opts = _get_options(obj.__class__) + return self.get_for_object_reference( + obj.__class__, getattr(obj, opts.object_id_field), model_db=model_db + ) def get_deleted(self, model, model_db=None): model_db = model_db or router.db_for_write(model) connection = connections[self.db] + object_id_field_name = _get_options(model).object_id_field if self.db == model_db and connection.vendor in ("sqlite", "postgresql", "oracle"): - pk_field_name = model._meta.pk.name - object_id_cast_target = model._meta.get_field(pk_field_name) + object_id_cast_target = model._meta.get_field(object_id_field_name) if django.VERSION >= (2, 1): # django 2.0 contains a critical bug that doesn't allow the code below to work, # fallback to casting primary keys then @@ -158,14 +164,14 @@ def get_deleted(self, model, model_db=None): model_qs = ( model._default_manager .using(model_db) - .filter(**{pk_field_name: casted_object_id}) + .filter(**{object_id_field_name: casted_object_id}) ) else: model_qs = ( model._default_manager .using(model_db) - .annotate(_pk_to_object_id=Cast("pk", Version._meta.get_field("object_id"))) - .filter(_pk_to_object_id=models.OuterRef("object_id")) + .annotate(_field_to_object_id=Cast(object_id_field_name, Version._meta.get_field("object_id"))) + .filter(_field_to_object_id=models.OuterRef("object_id")) ) # conditional expressions are being supported since django 3.0 # DISTINCT ON works only for Postgres DB @@ -190,7 +196,9 @@ def get_deleted(self, model, model_db=None): # We have to use a slow subquery. subquery = self.get_for_model(model, model_db=model_db).exclude( object_id__in=list( - model._default_manager.using(model_db).values_list("pk", flat=True).order_by().iterator() + model._default_manager.using(model_db).values_list( + object_id_field_name, flat=True + ).order_by().iterator() ), ).values_list("object_id").annotate( latest_pk=models.Max("pk") diff --git a/reversion/revisions.py b/reversion/revisions.py index 5e2ed4f6..1420d26d 100644 --- a/reversion/revisions.py +++ b/reversion/revisions.py @@ -21,6 +21,7 @@ "for_concrete_model", "ignore_duplicates", "use_natural_foreign_keys", + "object_id_field" )) @@ -168,7 +169,7 @@ def _add_to_revision(obj, using, model_db, explicit): return version_options = _get_options(obj.__class__) content_type = _get_content_type(obj.__class__, using) - object_id = force_str(obj.pk) + object_id = force_str(getattr(obj, version_options.object_id_field)) version_key = (content_type, object_id) # If the obj is already in the revision, stop now. db_versions = _current_frame().db_versions @@ -191,7 +192,9 @@ def _add_to_revision(obj, using, model_db, explicit): ) # If the version is a duplicate, stop now. if version_options.ignore_duplicates and explicit: - previous_version = Version.objects.using(using).get_for_object(obj, model_db=model_db).first() + previous_version = Version.objects.using(using).get_for_object_reference( + obj.__class__, object_id, model_db=model_db + ).first() if previous_version and previous_version._local_field_dict == version._local_field_dict: return # Store the version. @@ -221,7 +224,9 @@ def _save_revision(versions, user=None, comment="", meta=(), date_created=None, model: { db: frozenset(map( force_str, - model._base_manager.using(db).filter(pk__in=pks).values_list("pk", flat=True), + model._base_manager.using(db).filter( + **{f"{_get_options(model).object_id_field}__in": pks} + ).values_list(_get_options(model).object_id_field, flat=True), )) for db, pks in db_pks.items() } @@ -376,7 +381,7 @@ def _get_senders_and_signals(model): def register(model=None, fields=None, exclude=(), follow=(), format="json", - for_concrete_model=True, ignore_duplicates=False, use_natural_foreign_keys=False): + for_concrete_model=True, ignore_duplicates=False, use_natural_foreign_keys=False, object_id_field=None): def register(model): # Prevent multiple registration. if is_registered(model): @@ -385,6 +390,12 @@ def register(model): )) # Parse fields. opts = model._meta.concrete_model._meta + if object_id_field is None: + id_field = model._meta.pk.attname + else: + model._meta.get_field(object_id_field) + id_field = object_id_field + version_options = _VersionOptions( fields=tuple( field_name @@ -401,6 +412,7 @@ def register(model): for_concrete_model=for_concrete_model, ignore_duplicates=ignore_duplicates, use_natural_foreign_keys=use_natural_foreign_keys, + object_id_field=id_field, ) # Register the model. _registered_models[_get_registration_key(model)] = version_options diff --git a/tests/test_app/admin.py b/tests/test_app/admin.py index c6641231..00874c43 100644 --- a/tests/test_app/admin.py +++ b/tests/test_app/admin.py @@ -9,6 +9,4 @@ class TestModelAdmin(VersionAdmin): admin.site.register(TestModel, TestModelAdmin) - - admin.site.register(TestModelRelated, admin.ModelAdmin) diff --git a/tests/test_app/migrations/0003_testmodelcustomobjectid.py b/tests/test_app/migrations/0003_testmodelcustomobjectid.py new file mode 100644 index 00000000..37ea99ef --- /dev/null +++ b/tests/test_app/migrations/0003_testmodelcustomobjectid.py @@ -0,0 +1,21 @@ +# Generated by Django 6.0.4 on 2026-05-02 12:08 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('test_app', '0002_alter_testmodel_related_and_more'), + ] + + operations = [ + migrations.CreateModel( + name='TestModelCustomObjectId', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('slug', models.CharField(max_length=191, unique=True)), + ('name', models.CharField(default='v1', max_length=191)), + ], + ), + ] diff --git a/tests/test_app/models.py b/tests/test_app/models.py index bfb3a16c..b404b08d 100644 --- a/tests/test_app/models.py +++ b/tests/test_app/models.py @@ -150,3 +150,8 @@ class TestModelWithUniqueConstraint(models.Model): max_length=191, unique=True, ) + + +class TestModelCustomObjectId(models.Model): + slug = models.CharField(max_length=191, unique=True) + name = models.CharField(max_length=191, default="v1") diff --git a/tests/test_app/tests/test_admin.py b/tests/test_app/tests/test_admin.py index 8d992bf8..0d89e8ba 100644 --- a/tests/test_app/tests/test_admin.py +++ b/tests/test_app/tests/test_admin.py @@ -9,7 +9,9 @@ import reversion from reversion.admin import VersionAdmin from reversion.models import Version -from test_app.models import TestModel, TestModelParent, TestModelInline, TestModelGenericInline, TestModelEscapePK +from test_app.models import ( + TestModel, TestModelParent, TestModelInline, TestModelGenericInline, TestModelEscapePK, TestModelCustomObjectId +) from test_app.tests.base import TestBase, LoginMixin @@ -569,3 +571,70 @@ def testAutoRegisterInline(self): def testAutoRegisterGenericInline(self): self.assertTrue(reversion.is_registered(TestModelGenericInline)) + + +class TestModelCustomObjectIdAdmin(VersionAdmin): + + def reversion_register(self, model, **kwargs): + kwargs["object_id_field"] = "slug" + super().reversion_register(model, **kwargs) + + +class CustomObjectIdAdminMixin(TestBase): + + def setUp(self): + super().setUp() + admin.site.register(TestModelCustomObjectId, TestModelCustomObjectIdAdmin) + self.reloadUrls() + + def tearDown(self): + super().tearDown() + admin.site.unregister(TestModelCustomObjectId) + self.reloadUrls() + + +class CustomObjectIdAdminHistoryViewTest(LoginMixin, CustomObjectIdAdminMixin, TestBase): + + def testHistoryView(self): + with reversion.create_revision(): + obj = TestModelCustomObjectId.objects.create(slug="test", name="v1") + with reversion.create_revision(): + obj.name = "v2" + obj.save() + response = self.client.get(resolve_url("admin:test_app_testmodelcustomobjectid_history", obj.pk)) + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.context["action_list"]), 2) + + +class CustomObjectIdAdminRevisionViewTest(LoginMixin, CustomObjectIdAdminMixin, TestBase): + + def setUp(self): + super().setUp() + with reversion.create_revision(): + self.obj = TestModelCustomObjectId.objects.create(slug="test", name="v1") + with reversion.create_revision(): + self.obj.name = "v2" + self.obj.save() + + def testRevisionViewGet(self): + version = Version.objects.get_for_object_reference(TestModelCustomObjectId, "test")[1] + response = self.client.get(resolve_url( + "admin:test_app_testmodelcustomobjectid_revision", + self.obj.slug, + version.pk, + )) + self.assertEqual(response.status_code, 200) + self.assertContains(response, 'value="v1"') + # Verify the revert was rolled back (GET should not persist changes). + self.obj.refresh_from_db() + self.assertEqual(self.obj.name, "v2") + + def testRevisionViewRevert(self): + version = Version.objects.get_for_object_reference(TestModelCustomObjectId, "test")[1] + self.client.post(resolve_url( + "admin:test_app_testmodelcustomobjectid_revision", + self.obj.slug, + version.pk, + ), {"slug": "test", "name": "v1"}) + self.obj.refresh_from_db() + self.assertEqual(self.obj.name, "v1") diff --git a/tests/test_app/tests/test_models.py b/tests/test_app/tests/test_models.py index 2d2ca0a9..7169751a 100644 --- a/tests/test_app/tests/test_models.py +++ b/tests/test_app/tests/test_models.py @@ -5,6 +5,7 @@ TestModelNestedInline, TestModelInlineByNaturalKey, TestModelWithNaturalKey, TestModelWithUniqueConstraint, + TestModelCustomObjectId ) from test_app.tests.base import TestBase, TestModelMixin, TestModelParentMixin import json @@ -458,3 +459,74 @@ def testTransactionInRollbackState(self): TestModelWithUniqueConstraint.objects.create(name='A') except Exception: pass + + +class CustomObjectIdTest(TestBase): + + def setUp(self): + super().setUp() + reversion.register(TestModelCustomObjectId, object_id_field='slug') + + def testObjectIdStoredAsSlug(self): + with reversion.create_revision(): + TestModelCustomObjectId.objects.create(slug='custom-id', name='v1') + version = Version.objects.get_for_object_reference(TestModelCustomObjectId, 'custom-id').get() + self.assertEqual(version.object_id, 'custom-id') + + def testGetForObject(self): + with reversion.create_revision(): + obj = TestModelCustomObjectId.objects.create(slug='custom-id', name='v1') + self.assertEqual(Version.objects.get_for_object(obj).count(), 1) + + def testFieldDict(self): + with reversion.create_revision(): + obj = TestModelCustomObjectId.objects.create(slug='custom-id', name='v1') + version = Version.objects.get_for_object_reference(TestModelCustomObjectId, 'custom-id').get() + self.assertEqual(version.field_dict, {'id': obj.pk, 'slug': 'custom-id', 'name': 'v1'}) + + def testMultipleVersions(self): + with reversion.create_revision(): + obj = TestModelCustomObjectId.objects.create(slug='custom-id', name='v1') + with reversion.create_revision(): + obj.name = 'v2' + obj.save() + versions = Version.objects.get_for_object_reference(TestModelCustomObjectId, 'custom-id') + self.assertEqual(versions.count(), 2) + self.assertEqual(versions[0].field_dict['name'], 'v2') + self.assertEqual(versions[1].field_dict['name'], 'v1') + + def testGetDeleted(self): + with reversion.create_revision(): + obj = TestModelCustomObjectId.objects.create(slug='custom-id', name='v1') + obj.delete() + deleted = Version.objects.get_deleted(TestModelCustomObjectId) + self.assertEqual(deleted.count(), 1) + self.assertEqual(deleted[0].object_id, 'custom-id') + + def testGetDeletedNotIncludingExisting(self): + with reversion.create_revision(): + TestModelCustomObjectId.objects.create(slug='custom-id', name='v1') + self.assertEqual(Version.objects.get_deleted(TestModelCustomObjectId).count(), 0) + + +class CustomObjectIdIgnoreDuplicatesTest(TestBase): + + def setUp(self): + reversion.register(TestModelCustomObjectId, object_id_field='slug', ignore_duplicates=True) + + def testIgnoreDuplicates(self): + with reversion.create_revision(): + obj = TestModelCustomObjectId.objects.create(slug='custom-id', name='v1') + with reversion.create_revision(): + obj.save() + versions = Version.objects.get_for_object_reference(TestModelCustomObjectId, 'custom-id') + self.assertEqual(versions.count(), 1) + + def testIgnoreDuplicatesNewData(self): + with reversion.create_revision(): + obj = TestModelCustomObjectId.objects.create(slug='custom-id', name='v1') + with reversion.create_revision(): + obj.name = 'v2' + obj.save() + versions = Version.objects.get_for_object_reference(TestModelCustomObjectId, 'custom-id') + self.assertEqual(versions.count(), 2)