Skip to content

Commit 1b2e20e

Browse files
authored
Simplify using custom token classes in serializers (#517)
For most cases this could be done by overriding get_token, which is simple enough. The exception was TokenRefreshSerializer.validate where the entire method needed to be copy-pasted to allow using a custom replacement for RefreshToken. The other cases are changed the same way mainly for consistency.
1 parent 92124cf commit 1b2e20e

1 file changed

Lines changed: 10 additions & 12 deletions

File tree

rest_framework_simplejwt/serializers.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(self, *args, **kwargs):
2424

2525
class TokenObtainSerializer(serializers.Serializer):
2626
username_field = get_user_model().USERNAME_FIELD
27+
token_class = None
2728

2829
default_error_messages = {
2930
"no_active_account": _("No active account found with the given credentials")
@@ -57,15 +58,11 @@ def validate(self, attrs):
5758

5859
@classmethod
5960
def get_token(cls, user):
60-
raise NotImplementedError(
61-
"Must implement `get_token` method for `TokenObtainSerializer` subclasses"
62-
)
61+
return cls.token_class.for_user(user)
6362

6463

6564
class TokenObtainPairSerializer(TokenObtainSerializer):
66-
@classmethod
67-
def get_token(cls, user):
68-
return RefreshToken.for_user(user)
65+
token_class = RefreshToken
6966

7067
def validate(self, attrs):
7168
data = super().validate(attrs)
@@ -82,9 +79,7 @@ def validate(self, attrs):
8279

8380

8481
class TokenObtainSlidingSerializer(TokenObtainSerializer):
85-
@classmethod
86-
def get_token(cls, user):
87-
return SlidingToken.for_user(user)
82+
token_class = SlidingToken
8883

8984
def validate(self, attrs):
9085
data = super().validate(attrs)
@@ -102,9 +97,10 @@ def validate(self, attrs):
10297
class TokenRefreshSerializer(serializers.Serializer):
10398
refresh = serializers.CharField()
10499
access = serializers.CharField(read_only=True)
100+
token_class = RefreshToken
105101

106102
def validate(self, attrs):
107-
refresh = RefreshToken(attrs["refresh"])
103+
refresh = self.token_class(attrs["refresh"])
108104

109105
data = {"access": str(refresh.access_token)}
110106

@@ -129,9 +125,10 @@ def validate(self, attrs):
129125

130126
class TokenRefreshSlidingSerializer(serializers.Serializer):
131127
token = serializers.CharField()
128+
token_class = SlidingToken
132129

133130
def validate(self, attrs):
134-
token = SlidingToken(attrs["token"])
131+
token = self.token_class(attrs["token"])
135132

136133
# Check that the timestamp in the "refresh_exp" claim has not
137134
# passed
@@ -163,9 +160,10 @@ def validate(self, attrs):
163160

164161
class TokenBlacklistSerializer(serializers.Serializer):
165162
refresh = serializers.CharField()
163+
token_class = RefreshToken
166164

167165
def validate(self, attrs):
168-
refresh = RefreshToken(attrs["refresh"])
166+
refresh = self.token_class(attrs["refresh"])
169167
try:
170168
refresh.blacklist()
171169
except AttributeError:

0 commit comments

Comments
 (0)