@@ -41,25 +41,19 @@ def test_get_header(self):
4141 )
4242 self .assertEqual (self .backend .get_header (request ), self .fake_header )
4343
44- # Should work with the x_access_token
45- with override_api_settings (AUTH_HEADER_NAME = "HTTP_X_ACCESS_TOKEN" ):
46- # Should pull correct header off request when using X_ACCESS_TOKEN
47- request = self .factory .get (
48- "/test-url/" , HTTP_X_ACCESS_TOKEN = self .fake_header
49- )
50- self .assertEqual (self .backend .get_header (request ), self .fake_header )
51-
52- # Should work for unicode headers when using
53- request = self .factory .get (
54- "/test-url/" , HTTP_X_ACCESS_TOKEN = self .fake_header .decode ("utf-8" )
55- )
56- self .assertEqual (self .backend .get_header (request ), self .fake_header )
44+ @override_api_settings (AUTH_HEADER_NAME = "HTTP_X_ACCESS_TOKEN" )
45+ def test_get_header_x_access_token (self ):
46+ # Should pull correct header off request when using X_ACCESS_TOKEN
47+ request = self .factory .get ("/test-url/" , HTTP_X_ACCESS_TOKEN = self .fake_header )
48+ self .assertEqual (self .backend .get_header (request ), self .fake_header )
49+
50+ # Should work for unicode headers when using
51+ request = self .factory .get (
52+ "/test-url/" , HTTP_X_ACCESS_TOKEN = self .fake_header .decode ("utf-8" )
53+ )
54+ self .assertEqual (self .backend .get_header (request ), self .fake_header )
5755
5856 def test_get_raw_token (self ):
59- # Should return None if header lacks correct type keyword
60- with override_api_settings (AUTH_HEADER_TYPES = "JWT" ):
61- reload (authentication )
62- self .assertIsNone (self .backend .get_raw_token (self .fake_header ))
6357 reload (authentication )
6458
6559 # Should return None if an empty AUTHORIZATION header is sent
@@ -75,14 +69,21 @@ def test_get_raw_token(self):
7569 # Otherwise, should return unvalidated token in header
7670 self .assertEqual (self .backend .get_raw_token (self .fake_header ), self .fake_token )
7771
72+ @override_api_settings (AUTH_HEADER_TYPES = "JWT" )
73+ def test_get_raw_token_incorrect_header_keyword (self ):
74+ # Should return None if header lacks correct type keyword
75+ # AUTH_HEADER_TYPES is "JWT", but header is "Bearer"
76+ reload (authentication )
77+ self .assertIsNone (self .backend .get_raw_token (self .fake_header ))
78+
79+ @override_api_settings (AUTH_HEADER_TYPES = ("JWT" , "Bearer" ))
80+ def test_get_raw_token_multi_header_keyword (self ):
7881 # Should return token if header has one of many valid token types
79- with override_api_settings (AUTH_HEADER_TYPES = ("JWT" , "Bearer" )):
80- reload (authentication )
81- self .assertEqual (
82- self .backend .get_raw_token (self .fake_header ),
83- self .fake_token ,
84- )
8582 reload (authentication )
83+ self .assertEqual (
84+ self .backend .get_raw_token (self .fake_header ),
85+ self .fake_token ,
86+ )
8687
8788 def test_get_validated_token (self ):
8889 # Should raise InvalidToken if token not valid
@@ -97,36 +98,39 @@ def test_get_validated_token(self):
9798 self .backend .get_validated_token (str (token )).payload , token .payload
9899 )
99100
101+ @override_api_settings (
102+ AUTH_TOKEN_CLASSES = ("rest_framework_simplejwt.tokens.AccessToken" ,),
103+ )
104+ def test_get_validated_token_reject_unknown_token (self ):
100105 # Should not accept tokens not included in AUTH_TOKEN_CLASSES
101106 sliding_token = SlidingToken ()
102- with override_api_settings (
103- AUTH_TOKEN_CLASSES = ("rest_framework_simplejwt.tokens.AccessToken" ,)
104- ):
105- with self .assertRaises (InvalidToken ) as e :
106- self .backend .get_validated_token (str (sliding_token ))
107-
108- messages = e .exception .detail ["messages" ]
109- self .assertEqual (1 , len (messages ))
110- self .assertEqual (
111- {
112- "token_class" : "AccessToken" ,
113- "token_type" : "access" ,
114- "message" : "Token has wrong type" ,
115- },
116- messages [0 ],
117- )
107+ with self .assertRaises (InvalidToken ) as e :
108+ self .backend .get_validated_token (str (sliding_token ))
109+
110+ messages = e .exception .detail ["messages" ]
111+ self .assertEqual (1 , len (messages ))
112+ self .assertEqual (
113+ {
114+ "token_class" : "AccessToken" ,
115+ "token_type" : "access" ,
116+ "message" : "Token has wrong type" ,
117+ },
118+ messages [0 ],
119+ )
118120
121+ @override_api_settings (
122+ AUTH_TOKEN_CLASSES = (
123+ "rest_framework_simplejwt.tokens.AccessToken" ,
124+ "rest_framework_simplejwt.tokens.SlidingToken" ,
125+ ),
126+ )
127+ def test_get_validated_token_accept_known_token (self ):
119128 # Should accept tokens included in AUTH_TOKEN_CLASSES
120129 access_token = AccessToken ()
121130 sliding_token = SlidingToken ()
122- with override_api_settings (
123- AUTH_TOKEN_CLASSES = (
124- "rest_framework_simplejwt.tokens.AccessToken" ,
125- "rest_framework_simplejwt.tokens.SlidingToken" ,
126- )
127- ):
128- self .backend .get_validated_token (str (access_token ))
129- self .backend .get_validated_token (str (sliding_token ))
131+
132+ self .backend .get_validated_token (str (access_token ))
133+ self .backend .get_validated_token (str (sliding_token ))
130134
131135 def test_get_user (self ):
132136 payload = {"some_other_id" : "foo" }
0 commit comments