Skip to content

Commit 2c6e322

Browse files
kirawarepre-commit-ci[bot]
authored andcommitted
Improve testing (jazzband#688)
* Support `override_api_settings` as decorator * Update test_authentication * black formatting test_authentication * Use drf status instead of literal status * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_integration * Update test_serializers * Update test_integration * Update test_token_blacklist * Update test_tokens * Update test_views * add `setUpTestData` to `TestToken` * fix typo `self` should be `cls` --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent bb0b928 commit 2c6e322

7 files changed

Lines changed: 154 additions & 155 deletions

File tree

tests/test_authentication.py

Lines changed: 51 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,19 @@ def test_get_header(self):
4141
)
4242
self.assertEqual(self.backend.get_header(request), self.fake_header)
4343

44-
# Should work with the x_access_token
45-
with override_api_settings(AUTH_HEADER_NAME="HTTP_X_ACCESS_TOKEN"):
46-
# Should pull correct header off request when using X_ACCESS_TOKEN
47-
request = self.factory.get(
48-
"/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header
49-
)
50-
self.assertEqual(self.backend.get_header(request), self.fake_header)
51-
52-
# Should work for unicode headers when using
53-
request = self.factory.get(
54-
"/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header.decode("utf-8")
55-
)
56-
self.assertEqual(self.backend.get_header(request), self.fake_header)
44+
@override_api_settings(AUTH_HEADER_NAME="HTTP_X_ACCESS_TOKEN")
45+
def test_get_header_x_access_token(self):
46+
# Should pull correct header off request when using X_ACCESS_TOKEN
47+
request = self.factory.get("/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header)
48+
self.assertEqual(self.backend.get_header(request), self.fake_header)
49+
50+
# Should work for unicode headers when using
51+
request = self.factory.get(
52+
"/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header.decode("utf-8")
53+
)
54+
self.assertEqual(self.backend.get_header(request), self.fake_header)
5755

5856
def test_get_raw_token(self):
59-
# Should return None if header lacks correct type keyword
60-
with override_api_settings(AUTH_HEADER_TYPES="JWT"):
61-
reload(authentication)
62-
self.assertIsNone(self.backend.get_raw_token(self.fake_header))
6357
reload(authentication)
6458

6559
# Should return None if an empty AUTHORIZATION header is sent
@@ -75,14 +69,21 @@ def test_get_raw_token(self):
7569
# Otherwise, should return unvalidated token in header
7670
self.assertEqual(self.backend.get_raw_token(self.fake_header), self.fake_token)
7771

72+
@override_api_settings(AUTH_HEADER_TYPES="JWT")
73+
def test_get_raw_token_incorrect_header_keyword(self):
74+
# Should return None if header lacks correct type keyword
75+
# AUTH_HEADER_TYPES is "JWT", but header is "Bearer"
76+
reload(authentication)
77+
self.assertIsNone(self.backend.get_raw_token(self.fake_header))
78+
79+
@override_api_settings(AUTH_HEADER_TYPES=("JWT", "Bearer"))
80+
def test_get_raw_token_multi_header_keyword(self):
7881
# Should return token if header has one of many valid token types
79-
with override_api_settings(AUTH_HEADER_TYPES=("JWT", "Bearer")):
80-
reload(authentication)
81-
self.assertEqual(
82-
self.backend.get_raw_token(self.fake_header),
83-
self.fake_token,
84-
)
8582
reload(authentication)
83+
self.assertEqual(
84+
self.backend.get_raw_token(self.fake_header),
85+
self.fake_token,
86+
)
8687

8788
def test_get_validated_token(self):
8889
# Should raise InvalidToken if token not valid
@@ -97,36 +98,39 @@ def test_get_validated_token(self):
9798
self.backend.get_validated_token(str(token)).payload, token.payload
9899
)
99100

101+
@override_api_settings(
102+
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",),
103+
)
104+
def test_get_validated_token_reject_unknown_token(self):
100105
# Should not accept tokens not included in AUTH_TOKEN_CLASSES
101106
sliding_token = SlidingToken()
102-
with override_api_settings(
103-
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
104-
):
105-
with self.assertRaises(InvalidToken) as e:
106-
self.backend.get_validated_token(str(sliding_token))
107-
108-
messages = e.exception.detail["messages"]
109-
self.assertEqual(1, len(messages))
110-
self.assertEqual(
111-
{
112-
"token_class": "AccessToken",
113-
"token_type": "access",
114-
"message": "Token has wrong type",
115-
},
116-
messages[0],
117-
)
107+
with self.assertRaises(InvalidToken) as e:
108+
self.backend.get_validated_token(str(sliding_token))
109+
110+
messages = e.exception.detail["messages"]
111+
self.assertEqual(1, len(messages))
112+
self.assertEqual(
113+
{
114+
"token_class": "AccessToken",
115+
"token_type": "access",
116+
"message": "Token has wrong type",
117+
},
118+
messages[0],
119+
)
118120

121+
@override_api_settings(
122+
AUTH_TOKEN_CLASSES=(
123+
"rest_framework_simplejwt.tokens.AccessToken",
124+
"rest_framework_simplejwt.tokens.SlidingToken",
125+
),
126+
)
127+
def test_get_validated_token_accept_known_token(self):
119128
# Should accept tokens included in AUTH_TOKEN_CLASSES
120129
access_token = AccessToken()
121130
sliding_token = SlidingToken()
122-
with override_api_settings(
123-
AUTH_TOKEN_CLASSES=(
124-
"rest_framework_simplejwt.tokens.AccessToken",
125-
"rest_framework_simplejwt.tokens.SlidingToken",
126-
)
127-
):
128-
self.backend.get_validated_token(str(access_token))
129-
self.backend.get_validated_token(str(sliding_token))
131+
132+
self.backend.get_validated_token(str(access_token))
133+
self.backend.get_validated_token(str(sliding_token))
130134

131135
def test_get_user(self):
132136
payload = {"some_other_id": "foo"}

tests/test_integration.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from django.contrib.auth import get_user_model
44
from django.urls import reverse
5+
from rest_framework.status import HTTP_200_OK, HTTP_401_UNAUTHORIZED
56

67
from rest_framework_simplejwt.settings import api_settings
78
from rest_framework_simplejwt.tokens import AccessToken
@@ -26,7 +27,7 @@ def setUp(self):
2627
def test_no_authorization(self):
2728
res = self.view_get()
2829

29-
self.assertEqual(res.status_code, 401)
30+
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED)
3031
self.assertIn("credentials were not provided", res.data["detail"])
3132

3233
def test_wrong_auth_type(self):
@@ -43,9 +44,12 @@ def test_wrong_auth_type(self):
4344

4445
res = self.view_get()
4546

46-
self.assertEqual(res.status_code, 401)
47+
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED)
4748
self.assertIn("credentials were not provided", res.data["detail"])
4849

50+
@override_api_settings(
51+
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",),
52+
)
4953
def test_expired_token(self):
5054
old_lifetime = AccessToken.lifetime
5155
AccessToken.lifetime = timedelta(seconds=0)
@@ -63,14 +67,14 @@ def test_expired_token(self):
6367
access = res.data["access"]
6468
self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access)
6569

66-
with override_api_settings(
67-
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
68-
):
69-
res = self.view_get()
70+
res = self.view_get()
7071

71-
self.assertEqual(res.status_code, 401)
72+
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED)
7273
self.assertEqual("token_not_valid", res.data["code"])
7374

75+
@override_api_settings(
76+
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.SlidingToken",),
77+
)
7478
def test_user_can_get_sliding_token_and_use_it(self):
7579
res = self.client.post(
7680
reverse("token_obtain_sliding"),
@@ -83,14 +87,14 @@ def test_user_can_get_sliding_token_and_use_it(self):
8387
token = res.data["token"]
8488
self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], token)
8589

86-
with override_api_settings(
87-
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.SlidingToken",)
88-
):
89-
res = self.view_get()
90+
res = self.view_get()
9091

91-
self.assertEqual(res.status_code, 200)
92+
self.assertEqual(res.status_code, HTTP_200_OK)
9293
self.assertEqual(res.data["foo"], "bar")
9394

95+
@override_api_settings(
96+
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",),
97+
)
9498
def test_user_can_get_access_and_refresh_tokens_and_use_them(self):
9599
res = self.client.post(
96100
reverse("token_obtain_pair"),
@@ -105,12 +109,9 @@ def test_user_can_get_access_and_refresh_tokens_and_use_them(self):
105109

106110
self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access)
107111

108-
with override_api_settings(
109-
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
110-
):
111-
res = self.view_get()
112+
res = self.view_get()
112113

113-
self.assertEqual(res.status_code, 200)
114+
self.assertEqual(res.status_code, HTTP_200_OK)
114115
self.assertEqual(res.data["foo"], "bar")
115116

116117
res = self.client.post(
@@ -122,10 +123,7 @@ def test_user_can_get_access_and_refresh_tokens_and_use_them(self):
122123

123124
self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access)
124125

125-
with override_api_settings(
126-
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
127-
):
128-
res = self.view_get()
126+
res = self.view_get()
129127

130-
self.assertEqual(res.status_code, 200)
128+
self.assertEqual(res.status_code, HTTP_200_OK)
131129
self.assertEqual(res.data["foo"], "bar")

tests/test_serializers.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,10 @@ def test_it_should_return_access_token_if_everything_ok(self):
285285
access["exp"], datetime_to_epoch(now + api_settings.ACCESS_TOKEN_LIFETIME)
286286
)
287287

288+
@override_api_settings(
289+
ROTATE_REFRESH_TOKENS=True,
290+
BLACKLIST_AFTER_ROTATION=False,
291+
)
288292
def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):
289293
refresh = RefreshToken()
290294

@@ -298,14 +302,9 @@ def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):
298302

299303
now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2
300304

301-
with override_api_settings(
302-
ROTATE_REFRESH_TOKENS=True, BLACKLIST_AFTER_ROTATION=False
303-
):
304-
with patch(
305-
"rest_framework_simplejwt.tokens.aware_utcnow"
306-
) as fake_aware_utcnow:
307-
fake_aware_utcnow.return_value = now
308-
self.assertTrue(ser.is_valid())
305+
with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
306+
fake_aware_utcnow.return_value = now
307+
self.assertTrue(ser.is_valid())
309308

310309
access = AccessToken(ser.validated_data["access"])
311310
new_refresh = RefreshToken(ser.validated_data["refresh"])
@@ -324,6 +323,10 @@ def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):
324323
datetime_to_epoch(now + api_settings.REFRESH_TOKEN_LIFETIME),
325324
)
326325

326+
@override_api_settings(
327+
ROTATE_REFRESH_TOKENS=True,
328+
BLACKLIST_AFTER_ROTATION=True,
329+
)
327330
def test_it_should_blacklist_refresh_token_if_tokens_should_be_rotated_and_blacklisted(
328331
self,
329332
):
@@ -342,14 +345,9 @@ def test_it_should_blacklist_refresh_token_if_tokens_should_be_rotated_and_black
342345

343346
now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2
344347

345-
with override_api_settings(
346-
ROTATE_REFRESH_TOKENS=True, BLACKLIST_AFTER_ROTATION=True
347-
):
348-
with patch(
349-
"rest_framework_simplejwt.tokens.aware_utcnow"
350-
) as fake_aware_utcnow:
351-
fake_aware_utcnow.return_value = now
352-
self.assertTrue(ser.is_valid())
348+
with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
349+
fake_aware_utcnow.return_value = now
350+
self.assertTrue(ser.is_valid())
353351

354352
access = AccessToken(ser.validated_data["access"])
355353
new_refresh = RefreshToken(ser.validated_data["refresh"])

tests/test_token_blacklist.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -237,25 +237,25 @@ def setUp(self):
237237

238238
super().setUp()
239239

240+
@override_api_settings(BLACKLIST_AFTER_ROTATION=True)
240241
def test_token_verify_serializer_should_honour_blacklist_if_blacklisting_enabled(
241242
self,
242243
):
243-
with override_api_settings(BLACKLIST_AFTER_ROTATION=True):
244-
refresh_token = RefreshToken.for_user(self.user)
245-
refresh_token.blacklist()
244+
refresh_token = RefreshToken.for_user(self.user)
245+
refresh_token.blacklist()
246246

247-
serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
248-
self.assertFalse(serializer.is_valid())
247+
serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
248+
self.assertFalse(serializer.is_valid())
249249

250+
@override_api_settings(BLACKLIST_AFTER_ROTATION=False)
250251
def test_token_verify_serializer_should_not_honour_blacklist_if_blacklisting_not_enabled(
251252
self,
252253
):
253-
with override_api_settings(BLACKLIST_AFTER_ROTATION=False):
254-
refresh_token = RefreshToken.for_user(self.user)
255-
refresh_token.blacklist()
254+
refresh_token = RefreshToken.for_user(self.user)
255+
refresh_token.blacklist()
256256

257-
serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
258-
self.assertTrue(serializer.is_valid())
257+
serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
258+
self.assertTrue(serializer.is_valid())
259259

260260

261261
class TestBigAutoFieldIDMigration(MigrationTestCase):

tests/test_tokens.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ class TestToken(TestCase):
3131
def setUp(self):
3232
self.token = MyToken()
3333

34+
@classmethod
35+
def setUpTestData(cls):
36+
cls.username = "test_user"
37+
cls.user = User.objects.create_user(
38+
username=cls.username,
39+
password="test_password",
40+
)
41+
3442
def test_init_no_token_type_or_lifetime(self):
3543
class MyTestToken(Token):
3644
pass
@@ -225,14 +233,14 @@ def test_set_jti(self):
225233
self.assertIn("jti", token)
226234
self.assertNotEqual(old_jti, token["jti"])
227235

236+
@override_api_settings(JTI_CLAIM=None)
228237
def test_optional_jti(self):
229-
with override_api_settings(JTI_CLAIM=None):
230-
token = MyToken()
238+
token = MyToken()
231239
self.assertNotIn("jti", token)
232240

241+
@override_api_settings(TOKEN_TYPE_CLAIM=None)
233242
def test_optional_type_token(self):
234-
with override_api_settings(TOKEN_TYPE_CLAIM=None):
235-
token = MyToken()
243+
token = MyToken()
236244
self.assertNotIn("type", token)
237245

238246
def test_set_exp(self):
@@ -355,25 +363,19 @@ def test_check_token_if_wrong_type_leeway(self):
355363
token.token_backend.leeway = 0
356364

357365
def test_for_user(self):
358-
username = "test_user"
359-
user = User.objects.create_user(
360-
username=username,
361-
password="test_password",
362-
)
366+
token = MyToken.for_user(self.user)
363367

364-
token = MyToken.for_user(user)
365-
366-
user_id = getattr(user, api_settings.USER_ID_FIELD)
368+
user_id = getattr(self.user, api_settings.USER_ID_FIELD)
367369
if not isinstance(user_id, int):
368370
user_id = str(user_id)
369371

370372
self.assertEqual(token[api_settings.USER_ID_CLAIM], user_id)
371373

374+
@override_api_settings(USER_ID_FIELD="username")
375+
def test_for_user_with_username(self):
372376
# Test with non-int user id
373-
with override_api_settings(USER_ID_FIELD="username"):
374-
token = MyToken.for_user(user)
375-
376-
self.assertEqual(token[api_settings.USER_ID_CLAIM], username)
377+
token = MyToken.for_user(self.user)
378+
self.assertEqual(token[api_settings.USER_ID_CLAIM], self.username)
377379

378380
def test_get_token_backend(self):
379381
token = MyToken()

0 commit comments

Comments
 (0)