Skip to content

Commit fdd85f6

Browse files
committed
account for UniqueTogetherValidator having a code attribute starting drf 3.17
1 parent b5c953e commit fdd85f6

2 files changed

Lines changed: 53 additions & 1 deletion

File tree

drf_standardized_errors/openapi_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass, field as dataclass_field
22
from typing import Any, Dict, List, Optional, Set, Type, Union
33

4+
import django
5+
import rest_framework
46
from django import forms
57
from django.core.validators import (
68
DecimalValidator,
@@ -230,12 +232,18 @@ def add_unique_together_error_codes(
230232
sfields_with_error_codes: "List[InputDataField]",
231233
) -> None:
232234
for sfield in sfields_with_unique_together_validators:
233-
sfield.error_codes.add("unique")
234235
unique_together_validators = [
235236
validator
236237
for validator in sfield.field.validators
237238
if isinstance(validator, UniqueTogetherValidator)
238239
]
240+
if _drf_version() >= (3, 17) and django.VERSION >= (5, 0):
241+
# drf 3.17 passes the `custom_violation_error` added in django 5.0
242+
# to `drf.UniqueTogetherValidator`. Before that, the error code was
243+
# hardcoded as `"unique"`
244+
sfield.error_codes.update(v.code for v in unique_together_validators)
245+
else:
246+
sfield.error_codes.add("unique")
239247
# fields involved in a unique together constraint have an implied
240248
# "required" state, so we're adding the "required" error code to them
241249
implicitly_required_fields = set()
@@ -501,3 +509,9 @@ def get_example_from_exception(exc: exceptions.APIException) -> OpenApiExample:
501509
response_only=True,
502510
status_codes=[str(exc.status_code)],
503511
)
512+
513+
514+
def _drf_version():
515+
# we just care about major and minor drf versions
516+
parts = rest_framework.VERSION.split(".")
517+
return int(parts[0]), int(parts[1])

tests/test_openapi_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from drf_standardized_errors.openapi_utils import (
1818
InputDataField,
19+
_drf_version,
1920
get_django_filter_backends,
2021
get_error_serializer,
2122
get_filter_forms,
@@ -372,6 +373,43 @@ def test_unique_together_error_codes(unique_together):
372373
assert "required" in model.error_codes
373374

374375

376+
@pytest.fixture
377+
def unique_together_with_violation_code():
378+
from django.db import models
379+
380+
class SomeModel(models.Model):
381+
app_label = models.CharField(max_length=100)
382+
model = models.CharField(max_length=100)
383+
384+
class Meta:
385+
constraints = [
386+
models.UniqueConstraint(
387+
fields=["app_label", "model"],
388+
name="unique_model",
389+
violation_error_code="custom_violation_code",
390+
)
391+
]
392+
393+
class SomeSerializer(serializers.ModelSerializer):
394+
class Meta:
395+
model = SomeModel
396+
fields = ["app_label", "model"]
397+
398+
return get_flat_serializer_fields(SomeSerializer())
399+
400+
401+
@pytest.mark.skipif(
402+
_drf_version() < (3, 17) or django.VERSION < (5, 0),
403+
reason="django added violation_error_code in v5 and drf supported it in v3.17",
404+
)
405+
def test_unique_together_new(unique_together_with_violation_code):
406+
non_field_errors, _, __ = get_serializer_fields_with_error_codes(
407+
unique_together_with_violation_code
408+
)
409+
410+
assert "custom_violation_code" in non_field_errors.error_codes
411+
412+
375413
class PostSerializer(serializers.ModelSerializer):
376414
"""
377415
Intentional required=False to test that the 'required' error code is added

0 commit comments

Comments
 (0)