|
12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
| 15 | +from typing import Optional |
15 | 16 |
|
16 | 17 | from mock import Mock |
17 | 18 |
|
| 19 | +from signedjson import key, sign |
| 20 | +from signedjson.types import BaseKey, SigningKey |
| 21 | + |
18 | 22 | from twisted.internet import defer |
19 | 23 |
|
20 | | -from synapse.types import ReadReceipt |
| 24 | +from synapse.rest import admin |
| 25 | +from synapse.rest.client.v1 import login |
| 26 | +from synapse.types import JsonDict, ReadReceipt |
21 | 27 |
|
22 | 28 | from tests.unittest import HomeserverTestCase, override_config |
23 | 29 |
|
24 | 30 |
|
25 | | -class FederationSenderTestCases(HomeserverTestCase): |
| 31 | +class FederationSenderReceiptsTestCases(HomeserverTestCase): |
26 | 32 | def make_homeserver(self, reactor, clock): |
27 | | - return super(FederationSenderTestCases, self).setup_test_homeserver( |
| 33 | + return self.setup_test_homeserver( |
28 | 34 | state_handler=Mock(spec=["get_current_hosts_in_room"]), |
29 | 35 | federation_transport_client=Mock(spec=["send_transaction"]), |
30 | 36 | ) |
@@ -147,3 +153,294 @@ def test_send_receipts_with_backoff(self): |
147 | 153 | } |
148 | 154 | ], |
149 | 155 | ) |
| 156 | + |
| 157 | + |
| 158 | +class FederationSenderDevicesTestCases(HomeserverTestCase): |
| 159 | + servlets = [ |
| 160 | + admin.register_servlets, |
| 161 | + login.register_servlets, |
| 162 | + ] |
| 163 | + |
| 164 | + def make_homeserver(self, reactor, clock): |
| 165 | + return self.setup_test_homeserver( |
| 166 | + state_handler=Mock(spec=["get_current_hosts_in_room"]), |
| 167 | + federation_transport_client=Mock(spec=["send_transaction"]), |
| 168 | + ) |
| 169 | + |
| 170 | + def default_config(self): |
| 171 | + c = super().default_config() |
| 172 | + c["send_federation"] = True |
| 173 | + return c |
| 174 | + |
| 175 | + def prepare(self, reactor, clock, hs): |
| 176 | + # stub out get_current_hosts_in_room |
| 177 | + mock_state_handler = hs.get_state_handler() |
| 178 | + mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] |
| 179 | + |
| 180 | + # stub out get_users_who_share_room_with_user so that it claims that |
| 181 | + # `@user2:host2` is in the room |
| 182 | + def get_users_who_share_room_with_user(user_id): |
| 183 | + return defer.succeed({"@user2:host2"}) |
| 184 | + |
| 185 | + hs.get_datastore().get_users_who_share_room_with_user = ( |
| 186 | + get_users_who_share_room_with_user |
| 187 | + ) |
| 188 | + |
| 189 | + # whenever send_transaction is called, record the edu data |
| 190 | + self.edus = [] |
| 191 | + self.hs.get_federation_transport_client().send_transaction.side_effect = ( |
| 192 | + self.record_transaction |
| 193 | + ) |
| 194 | + |
| 195 | + def record_transaction(self, txn, json_cb): |
| 196 | + data = json_cb() |
| 197 | + self.edus.extend(data["edus"]) |
| 198 | + return defer.succeed({}) |
| 199 | + |
| 200 | + def test_send_device_updates(self): |
| 201 | + """Basic case: each device update should result in an EDU""" |
| 202 | + # create a device |
| 203 | + u1 = self.register_user("user", "pass") |
| 204 | + self.login(u1, "pass", device_id="D1") |
| 205 | + |
| 206 | + # expect one edu |
| 207 | + self.assertEqual(len(self.edus), 1) |
| 208 | + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) |
| 209 | + |
| 210 | + # a second call should produce no new device EDUs |
| 211 | + self.hs.get_federation_sender().send_device_messages("host2") |
| 212 | + self.pump() |
| 213 | + self.assertEqual(self.edus, []) |
| 214 | + |
| 215 | + # a second device |
| 216 | + self.login("user", "pass", device_id="D2") |
| 217 | + |
| 218 | + self.assertEqual(len(self.edus), 1) |
| 219 | + self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) |
| 220 | + |
| 221 | + def test_upload_signatures(self): |
| 222 | + """Uploading signatures on some devices should produce updates for that user""" |
| 223 | + |
| 224 | + e2e_handler = self.hs.get_e2e_keys_handler() |
| 225 | + |
| 226 | + # register two devices |
| 227 | + u1 = self.register_user("user", "pass") |
| 228 | + self.login(u1, "pass", device_id="D1") |
| 229 | + self.login(u1, "pass", device_id="D2") |
| 230 | + |
| 231 | + # expect two edus |
| 232 | + self.assertEqual(len(self.edus), 2) |
| 233 | + stream_id = None |
| 234 | + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) |
| 235 | + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) |
| 236 | + |
| 237 | + # upload signing keys for each device |
| 238 | + device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1") |
| 239 | + device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2") |
| 240 | + |
| 241 | + # expect two more edus |
| 242 | + self.assertEqual(len(self.edus), 2) |
| 243 | + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) |
| 244 | + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) |
| 245 | + |
| 246 | + # upload master key and self-signing key |
| 247 | + master_signing_key = generate_self_id_key() |
| 248 | + master_key = { |
| 249 | + "user_id": u1, |
| 250 | + "usage": ["master"], |
| 251 | + "keys": {key_id(master_signing_key): encode_pubkey(master_signing_key)}, |
| 252 | + } |
| 253 | + |
| 254 | + # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 |
| 255 | + selfsigning_signing_key = generate_self_id_key() |
| 256 | + selfsigning_key = { |
| 257 | + "user_id": u1, |
| 258 | + "usage": ["self_signing"], |
| 259 | + "keys": { |
| 260 | + key_id(selfsigning_signing_key): encode_pubkey(selfsigning_signing_key) |
| 261 | + }, |
| 262 | + } |
| 263 | + sign.sign_json(selfsigning_key, u1, master_signing_key) |
| 264 | + |
| 265 | + cross_signing_keys = { |
| 266 | + "master_key": master_key, |
| 267 | + "self_signing_key": selfsigning_key, |
| 268 | + } |
| 269 | + |
| 270 | + self.get_success( |
| 271 | + e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys) |
| 272 | + ) |
| 273 | + |
| 274 | + # expect signing key update edu |
| 275 | + self.assertEqual(len(self.edus), 1) |
| 276 | + self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update") |
| 277 | + |
| 278 | + # sign the devices |
| 279 | + d1_json = build_device_dict(u1, "D1", device1_signing_key) |
| 280 | + sign.sign_json(d1_json, u1, selfsigning_signing_key) |
| 281 | + d2_json = build_device_dict(u1, "D2", device2_signing_key) |
| 282 | + sign.sign_json(d2_json, u1, selfsigning_signing_key) |
| 283 | + |
| 284 | + ret = self.get_success( |
| 285 | + e2e_handler.upload_signatures_for_device_keys( |
| 286 | + u1, {u1: {"D1": d1_json, "D2": d2_json}}, |
| 287 | + ) |
| 288 | + ) |
| 289 | + self.assertEqual(ret["failures"], {}) |
| 290 | + |
| 291 | + # expect two edus, in one or two transactions. We don't know what order the |
| 292 | + # devices will be updated. |
| 293 | + self.assertEqual(len(self.edus), 2) |
| 294 | + stream_id = None # FIXME: there is a discontinuity in the stream IDs: see #7142 |
| 295 | + for edu in self.edus: |
| 296 | + self.assertEqual(edu["edu_type"], "m.device_list_update") |
| 297 | + c = edu["content"] |
| 298 | + if stream_id is not None: |
| 299 | + self.assertEqual(c["prev_id"], [stream_id]) |
| 300 | + stream_id = c["stream_id"] |
| 301 | + devices = {edu["content"]["device_id"] for edu in self.edus} |
| 302 | + self.assertEqual({"D1", "D2"}, devices) |
| 303 | + |
| 304 | + def test_delete_devices(self): |
| 305 | + """If devices are deleted, that should result in EDUs too""" |
| 306 | + |
| 307 | + # create devices |
| 308 | + u1 = self.register_user("user", "pass") |
| 309 | + self.login("user", "pass", device_id="D1") |
| 310 | + self.login("user", "pass", device_id="D2") |
| 311 | + self.login("user", "pass", device_id="D3") |
| 312 | + |
| 313 | + # expect three edus |
| 314 | + self.assertEqual(len(self.edus), 3) |
| 315 | + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) |
| 316 | + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) |
| 317 | + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id) |
| 318 | + |
| 319 | + # delete them again |
| 320 | + self.get_success( |
| 321 | + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) |
| 322 | + ) |
| 323 | + |
| 324 | + # expect three edus, in an unknown order |
| 325 | + self.assertEqual(len(self.edus), 3) |
| 326 | + for edu in self.edus: |
| 327 | + self.assertEqual(edu["edu_type"], "m.device_list_update") |
| 328 | + c = edu["content"] |
| 329 | + self.assertGreaterEqual( |
| 330 | + c.items(), |
| 331 | + {"user_id": u1, "prev_id": [stream_id], "deleted": True}.items(), |
| 332 | + ) |
| 333 | + stream_id = c["stream_id"] |
| 334 | + devices = {edu["content"]["device_id"] for edu in self.edus} |
| 335 | + self.assertEqual({"D1", "D2", "D3"}, devices) |
| 336 | + |
| 337 | + def test_unreachable_server(self): |
| 338 | + """If the destination server is unreachable, all the updates should get sent on |
| 339 | + recovery |
| 340 | + """ |
| 341 | + mock_send_txn = self.hs.get_federation_transport_client().send_transaction |
| 342 | + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") |
| 343 | + |
| 344 | + # create devices |
| 345 | + u1 = self.register_user("user", "pass") |
| 346 | + self.login("user", "pass", device_id="D1") |
| 347 | + self.login("user", "pass", device_id="D2") |
| 348 | + self.login("user", "pass", device_id="D3") |
| 349 | + |
| 350 | + # delete them again |
| 351 | + self.get_success( |
| 352 | + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) |
| 353 | + ) |
| 354 | + |
| 355 | + self.assertGreaterEqual(mock_send_txn.call_count, 4) |
| 356 | + |
| 357 | + # recover the server |
| 358 | + mock_send_txn.side_effect = self.record_transaction |
| 359 | + self.hs.get_federation_sender().send_device_messages("host2") |
| 360 | + self.pump() |
| 361 | + |
| 362 | + # for each device, there should be a single update |
| 363 | + self.assertEqual(len(self.edus), 3) |
| 364 | + stream_id = None |
| 365 | + for edu in self.edus: |
| 366 | + self.assertEqual(edu["edu_type"], "m.device_list_update") |
| 367 | + c = edu["content"] |
| 368 | + self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else []) |
| 369 | + stream_id = c["stream_id"] |
| 370 | + devices = {edu["content"]["device_id"] for edu in self.edus} |
| 371 | + self.assertEqual({"D1", "D2", "D3"}, devices) |
| 372 | + |
| 373 | + def check_device_update_edu( |
| 374 | + self, |
| 375 | + edu: JsonDict, |
| 376 | + user_id: str, |
| 377 | + device_id: str, |
| 378 | + prev_stream_id: Optional[int], |
| 379 | + ) -> int: |
| 380 | + """Check that the given EDU is an update for the given device |
| 381 | + Returns the stream_id. |
| 382 | + """ |
| 383 | + self.assertEqual(edu["edu_type"], "m.device_list_update") |
| 384 | + content = edu["content"] |
| 385 | + |
| 386 | + expected = { |
| 387 | + "user_id": user_id, |
| 388 | + "device_id": device_id, |
| 389 | + "prev_id": [prev_stream_id] if prev_stream_id is not None else [], |
| 390 | + } |
| 391 | + |
| 392 | + self.assertLessEqual(expected.items(), content.items()) |
| 393 | + return content["stream_id"] |
| 394 | + |
| 395 | + def check_signing_key_update_txn(self, txn: JsonDict,) -> None: |
| 396 | + """Check that the txn has an EDU with a signing key update. |
| 397 | + """ |
| 398 | + edus = txn["edus"] |
| 399 | + self.assertEqual(len(edus), 1) |
| 400 | + |
| 401 | + def generate_and_upload_device_signing_key( |
| 402 | + self, user_id: str, device_id: str |
| 403 | + ) -> SigningKey: |
| 404 | + """Generate a signing keypair for the given device, and upload it""" |
| 405 | + sk = key.generate_signing_key(device_id) |
| 406 | + |
| 407 | + device_dict = build_device_dict(user_id, device_id, sk) |
| 408 | + |
| 409 | + self.get_success( |
| 410 | + self.hs.get_e2e_keys_handler().upload_keys_for_user( |
| 411 | + user_id, device_id, {"device_keys": device_dict}, |
| 412 | + ) |
| 413 | + ) |
| 414 | + return sk |
| 415 | + |
| 416 | + |
| 417 | +def generate_self_id_key() -> SigningKey: |
| 418 | + """generate a signing key whose version is its public key |
| 419 | +
|
| 420 | + ... as used by the cross-signing-keys. |
| 421 | + """ |
| 422 | + k = key.generate_signing_key("x") |
| 423 | + k.version = encode_pubkey(k) |
| 424 | + return k |
| 425 | + |
| 426 | + |
| 427 | +def key_id(k: BaseKey) -> str: |
| 428 | + return "%s:%s" % (k.alg, k.version) |
| 429 | + |
| 430 | + |
| 431 | +def encode_pubkey(sk: SigningKey) -> str: |
| 432 | + """Encode the public key corresponding to the given signing key as base64""" |
| 433 | + return key.encode_verify_key_base64(key.get_verify_key(sk)) |
| 434 | + |
| 435 | + |
| 436 | +def build_device_dict(user_id: str, device_id: str, sk: SigningKey): |
| 437 | + """Build a dict representing the given device""" |
| 438 | + return { |
| 439 | + "user_id": user_id, |
| 440 | + "device_id": device_id, |
| 441 | + "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], |
| 442 | + "keys": { |
| 443 | + "curve25519:" + device_id: "curve25519+key", |
| 444 | + key_id(sk): encode_pubkey(sk), |
| 445 | + }, |
| 446 | + } |
0 commit comments