@@ -328,6 +328,61 @@ def nonce() -> str:
328328 self .assertEqual (400 , channel .code , msg = channel .json_body )
329329 self .assertEqual ("Invalid user type" , channel .json_body ["error" ])
330330
331+ @override_config (
332+ {
333+ "user_types" : {
334+ "extra_user_types" : ["extra1" , "extra2" ],
335+ }
336+ }
337+ )
338+ def test_extra_user_type (self ) -> None :
339+ """
340+ Check that the extra user type can be used when registering a user.
341+ """
342+
343+ def nonce_mac (user_type : str ) -> tuple [str , str ]:
344+ """
345+ Get a nonce and the expected HMAC for that nonce.
346+ """
347+ channel = self .make_request ("GET" , self .url )
348+ nonce = channel .json_body ["nonce" ]
349+
350+ want_mac = hmac .new (key = b"shared" , digestmod = hashlib .sha1 )
351+ want_mac .update (
352+ nonce .encode ("ascii" )
353+ + b"\x00 alice\x00 abc123\x00 notadmin\x00 "
354+ + user_type .encode ("ascii" )
355+ )
356+ want_mac_str = want_mac .hexdigest ()
357+
358+ return nonce , want_mac_str
359+
360+ nonce , mac = nonce_mac ("extra1" )
361+ # Valid user_type
362+ body = {
363+ "nonce" : nonce ,
364+ "username" : "alice" ,
365+ "password" : "abc123" ,
366+ "user_type" : "extra1" ,
367+ "mac" : mac ,
368+ }
369+ channel = self .make_request ("POST" , self .url , body )
370+ self .assertEqual (200 , channel .code , msg = channel .json_body )
371+
372+ nonce , mac = nonce_mac ("extra3" )
373+ # Invalid user_type
374+ body = {
375+ "nonce" : nonce ,
376+ "username" : "alice" ,
377+ "password" : "abc123" ,
378+ "user_type" : "extra3" ,
379+ "mac" : mac ,
380+ }
381+ channel = self .make_request ("POST" , self .url , body )
382+
383+ self .assertEqual (400 , channel .code , msg = channel .json_body )
384+ self .assertEqual ("Invalid user type" , channel .json_body ["error" ])
385+
331386 def test_displayname (self ) -> None :
332387 """
333388 Test that displayname of new user is set
@@ -1186,6 +1241,80 @@ def test_user_type(
11861241 not_user_types = ["custom" ],
11871242 )
11881243
1244+ @override_config (
1245+ {
1246+ "user_types" : {
1247+ "extra_user_types" : ["extra1" , "extra2" ],
1248+ }
1249+ }
1250+ )
1251+ def test_filter_not_user_types_with_extra (self ) -> None :
1252+ """Tests that the endpoint handles the not_user_types param when extra_user_types are configured"""
1253+
1254+ regular_user_id = self .register_user ("normalo" , "secret" )
1255+
1256+ extra1_user_id = self .register_user ("extra1" , "secret" )
1257+ self .make_request (
1258+ "PUT" ,
1259+ "/_synapse/admin/v2/users/" + urllib .parse .quote (extra1_user_id ),
1260+ {"user_type" : "extra1" },
1261+ access_token = self .admin_user_tok ,
1262+ )
1263+
1264+ def test_user_type (
1265+ expected_user_ids : List [str ], not_user_types : Optional [List [str ]] = None
1266+ ) -> None :
1267+ """Runs a test for the not_user_types param
1268+ Args:
1269+ expected_user_ids: Ids of the users that are expected to be returned
1270+ not_user_types: List of values for the not_user_types param
1271+ """
1272+
1273+ user_type_query = ""
1274+
1275+ if not_user_types is not None :
1276+ user_type_query = "&" .join (
1277+ [f"not_user_type={ u } " for u in not_user_types ]
1278+ )
1279+
1280+ test_url = f"{ self .url } ?{ user_type_query } "
1281+ channel = self .make_request (
1282+ "GET" ,
1283+ test_url ,
1284+ access_token = self .admin_user_tok ,
1285+ )
1286+
1287+ self .assertEqual (200 , channel .code )
1288+ self .assertEqual (channel .json_body ["total" ], len (expected_user_ids ))
1289+ self .assertEqual (
1290+ expected_user_ids ,
1291+ [u ["name" ] for u in channel .json_body ["users" ]],
1292+ )
1293+
1294+ # Request without user_types → all users expected
1295+ test_user_type ([self .admin_user , extra1_user_id , regular_user_id ])
1296+
1297+ # Request and exclude extra1 user type
1298+ test_user_type (
1299+ [self .admin_user , regular_user_id ],
1300+ not_user_types = ["extra1" ],
1301+ )
1302+
1303+ # Request and exclude extra1 and extra2 user types
1304+ test_user_type (
1305+ [self .admin_user , regular_user_id ],
1306+ not_user_types = ["extra1" , "extra2" ],
1307+ )
1308+
1309+ # Request and exclude empty user types → only expected the extra1 user
1310+ test_user_type ([extra1_user_id ], not_user_types = ["" ])
1311+
1312+ # Request and exclude an unregistered type → expect all users
1313+ test_user_type (
1314+ [self .admin_user , extra1_user_id , regular_user_id ],
1315+ not_user_types = ["extra3" ],
1316+ )
1317+
11891318 def test_erasure_status (self ) -> None :
11901319 # Create a new user.
11911320 user_id = self .register_user ("eraseme" , "eraseme" )
@@ -2977,22 +3106,18 @@ def test_set_user_as_admin(self) -> None:
29773106 self .assertEqual ("@user:test" , channel .json_body ["name" ])
29783107 self .assertTrue (channel .json_body ["admin" ])
29793108
2980- def test_set_user_type (self ) -> None :
2981- """
2982- Test changing user type.
2983- """
2984-
2985- # Set to support type
3109+ def set_user_type (self , user_type : Optional [str ]) -> None :
3110+ # Set to user_type
29863111 channel = self .make_request (
29873112 "PUT" ,
29883113 self .url_other_user ,
29893114 access_token = self .admin_user_tok ,
2990- content = {"user_type" : UserTypes . SUPPORT },
3115+ content = {"user_type" : user_type },
29913116 )
29923117
29933118 self .assertEqual (200 , channel .code , msg = channel .json_body )
29943119 self .assertEqual ("@user:test" , channel .json_body ["name" ])
2995- self .assertEqual (UserTypes . SUPPORT , channel .json_body ["user_type" ])
3120+ self .assertEqual (user_type , channel .json_body ["user_type" ])
29963121
29973122 # Get user
29983123 channel = self .make_request (
@@ -3003,30 +3128,44 @@ def test_set_user_type(self) -> None:
30033128
30043129 self .assertEqual (200 , channel .code , msg = channel .json_body )
30053130 self .assertEqual ("@user:test" , channel .json_body ["name" ])
3006- self .assertEqual (UserTypes .SUPPORT , channel .json_body ["user_type" ])
3131+ self .assertEqual (user_type , channel .json_body ["user_type" ])
3132+
3133+ def test_set_user_type (self ) -> None :
3134+ """
3135+ Test changing user type.
3136+ """
3137+
3138+ # Set to support type
3139+ self .set_user_type (UserTypes .SUPPORT )
30073140
30083141 # Change back to a regular user
3009- channel = self .make_request (
3010- "PUT" ,
3011- self .url_other_user ,
3012- access_token = self .admin_user_tok ,
3013- content = {"user_type" : None },
3014- )
3142+ self .set_user_type (None )
30153143
3016- self .assertEqual (200 , channel .code , msg = channel .json_body )
3017- self .assertEqual ("@user:test" , channel .json_body ["name" ])
3018- self .assertIsNone (channel .json_body ["user_type" ])
3144+ @override_config ({"user_types" : {"extra_user_types" : ["extra1" , "extra2" ]}})
3145+ def test_set_user_type_with_extras (self ) -> None :
3146+ """
3147+ Test changing user type with extra_user_types configured.
3148+ """
30193149
3020- # Get user
3150+ # Check that we can still set to support type
3151+ self .set_user_type (UserTypes .SUPPORT )
3152+
3153+ # Check that we can set to an extra user type
3154+ self .set_user_type ("extra2" )
3155+
3156+ # Change back to a regular user
3157+ self .set_user_type (None )
3158+
3159+ # Try setting to invalid type
30213160 channel = self .make_request (
3022- "GET " ,
3161+ "PUT " ,
30233162 self .url_other_user ,
30243163 access_token = self .admin_user_tok ,
3164+ content = {"user_type" : "extra3" },
30253165 )
30263166
3027- self .assertEqual (200 , channel .code , msg = channel .json_body )
3028- self .assertEqual ("@user:test" , channel .json_body ["name" ])
3029- self .assertIsNone (channel .json_body ["user_type" ])
3167+ self .assertEqual (400 , channel .code , msg = channel .json_body )
3168+ self .assertEqual ("Invalid user type" , channel .json_body ["error" ])
30303169
30313170 def test_accidental_deactivation_prevention (self ) -> None :
30323171 """
0 commit comments