2121
2222from typing import Dict , List , Set , Tuple , cast
2323
24+ from parameterized import parameterized
25+
2426from twisted .test .proto_helpers import MemoryReactor
2527from twisted .trial import unittest
2628
@@ -45,14 +47,16 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
4547 self .store = hs .get_datastores ().main
4648 self ._next_stream_ordering = 1
4749
48- def test_simple (self ) -> None :
50+ @parameterized .expand ([(False ,), (True ,)])
51+ def test_simple (self , batched : bool ) -> None :
4952 """Test that the example in `docs/auth_chain_difference_algorithm.md`
5053 works.
5154 """
5255
5356 event_factory = self .hs .get_event_builder_factory ()
5457 bob = "@creator:test"
5558 alice = "@alice:test"
59+ charlie = "@charlie:test"
5660 room_id = "!room:test"
5761
5862 # Ensure that we have a rooms entry so that we generate the chain index.
@@ -191,6 +195,26 @@ def test_simple(self) -> None:
191195 )
192196 )
193197
198+ charlie_invite = self .get_success (
199+ event_factory .for_room_version (
200+ RoomVersions .V6 ,
201+ {
202+ "type" : EventTypes .Member ,
203+ "state_key" : charlie ,
204+ "sender" : alice ,
205+ "room_id" : room_id ,
206+ "content" : {"tag" : "charlie_invite" },
207+ },
208+ ).build (
209+ prev_event_ids = [],
210+ auth_event_ids = [
211+ create .event_id ,
212+ alice_join2 .event_id ,
213+ power_2 .event_id ,
214+ ],
215+ )
216+ )
217+
194218 events = [
195219 create ,
196220 bob_join ,
@@ -200,33 +224,41 @@ def test_simple(self) -> None:
200224 bob_join_2 ,
201225 power_2 ,
202226 alice_join2 ,
227+ charlie_invite ,
203228 ]
204229
205230 expected_links = [
206231 (bob_join , create ),
207- (power , create ),
208232 (power , bob_join ),
209- (alice_invite , create ),
210233 (alice_invite , power ),
211- (alice_invite , bob_join ),
212234 (bob_join_2 , power ),
213235 (alice_join2 , power_2 ),
236+ (charlie_invite , alice_join2 ),
214237 ]
215238
216- self .persist (events )
239+ # We either persist as a batch or one-by-one depending on test
240+ # parameter.
241+ if batched :
242+ self .persist (events )
243+ else :
244+ for event in events :
245+ self .persist ([event ])
246+
217247 chain_map , link_map = self .fetch_chains (events )
218248
219249 # Check that the expected links and only the expected links have been
220250 # added.
221- self .assertEqual (len (expected_links ), len (list (link_map .get_additions ())))
222-
223- for start , end in expected_links :
224- start_id , start_seq = chain_map [start .event_id ]
225- end_id , end_seq = chain_map [end .event_id ]
251+ event_map = {e .event_id : e for e in events }
252+ reverse_chain_map = {v : event_map [k ] for k , v in chain_map .items ()}
226253
227- self .assertIn (
228- (start_seq , end_seq ), list (link_map .get_links_between (start_id , end_id ))
229- )
254+ self .maxDiff = None
255+ self .assertCountEqual (
256+ expected_links ,
257+ [
258+ (reverse_chain_map [(s1 , s2 )], reverse_chain_map [(t1 , t2 )])
259+ for s1 , s2 , t1 , t2 in link_map .get_additions ()
260+ ],
261+ )
230262
231263 # Test that everything can reach the create event, but the create event
232264 # can't reach anything.
@@ -368,24 +400,23 @@ def test_out_of_order_events(self) -> None:
368400
369401 expected_links = [
370402 (bob_join , create ),
371- (power , create ),
372403 (power , bob_join ),
373- (alice_invite , create ),
374404 (alice_invite , power ),
375- (alice_invite , bob_join ),
376405 ]
377406
378407 # Check that the expected links and only the expected links have been
379408 # added.
380- self .assertEqual (len (expected_links ), len (list (link_map .get_additions ())))
409+ event_map = {e .event_id : e for e in events }
410+ reverse_chain_map = {v : event_map [k ] for k , v in chain_map .items ()}
381411
382- for start , end in expected_links :
383- start_id , start_seq = chain_map [start .event_id ]
384- end_id , end_seq = chain_map [end .event_id ]
385-
386- self .assertIn (
387- (start_seq , end_seq ), list (link_map .get_links_between (start_id , end_id ))
388- )
412+ self .maxDiff = None
413+ self .assertCountEqual (
414+ expected_links ,
415+ [
416+ (reverse_chain_map [(s1 , s2 )], reverse_chain_map [(t1 , t2 )])
417+ for s1 , s2 , t1 , t2 in link_map .get_additions ()
418+ ],
419+ )
389420
390421 def persist (
391422 self ,
@@ -489,8 +520,6 @@ def test_simple(self) -> None:
489520 link_map = _LinkMap ()
490521
491522 link_map .add_link ((1 , 1 ), (2 , 1 ), new = False )
492- self .assertCountEqual (link_map .get_links_between (1 , 2 ), [(1 , 1 )])
493- self .assertCountEqual (link_map .get_links_from ((1 , 1 )), [(2 , 1 )])
494523 self .assertCountEqual (link_map .get_additions (), [])
495524 self .assertTrue (link_map .exists_path_from ((1 , 5 ), (2 , 1 )))
496525 self .assertFalse (link_map .exists_path_from ((1 , 5 ), (2 , 2 )))
@@ -499,18 +528,31 @@ def test_simple(self) -> None:
499528
500529 # Attempting to add a redundant link is ignored.
501530 self .assertFalse (link_map .add_link ((1 , 4 ), (2 , 1 )))
502- self .assertCountEqual (link_map .get_links_between ( 1 , 2 ), [( 1 , 1 ) ])
531+ self .assertCountEqual (link_map .get_additions ( ), [])
503532
504533 # Adding new non-redundant links works
505534 self .assertTrue (link_map .add_link ((1 , 3 ), (2 , 3 )))
506- self .assertCountEqual (link_map .get_links_between ( 1 , 2 ), [(1 , 1 ), ( 3 , 3 )])
535+ self .assertCountEqual (link_map .get_additions ( ), [(1 , 3 , 2 , 3 )])
507536
508537 self .assertTrue (link_map .add_link ((2 , 5 ), (1 , 3 )))
509- self .assertCountEqual (link_map .get_links_between (2 , 1 ), [(5 , 3 )])
510- self .assertCountEqual (link_map .get_links_between (1 , 2 ), [(1 , 1 ), (3 , 3 )])
511-
512538 self .assertCountEqual (link_map .get_additions (), [(1 , 3 , 2 , 3 ), (2 , 5 , 1 , 3 )])
513539
540+ def test_exists_path_from (self ) -> None :
541+ "Check that `exists_path_from` can handle non-direct links"
542+ link_map = _LinkMap ()
543+
544+ link_map .add_link ((1 , 1 ), (2 , 1 ), new = False )
545+ link_map .add_link ((2 , 1 ), (3 , 1 ), new = False )
546+
547+ self .assertTrue (link_map .exists_path_from ((1 , 4 ), (3 , 1 )))
548+ self .assertFalse (link_map .exists_path_from ((1 , 4 ), (3 , 2 )))
549+
550+ link_map .add_link ((1 , 5 ), (2 , 3 ), new = False )
551+ link_map .add_link ((2 , 2 ), (3 , 3 ), new = False )
552+
553+ self .assertTrue (link_map .exists_path_from ((1 , 6 ), (3 , 2 )))
554+ self .assertFalse (link_map .exists_path_from ((1 , 4 ), (3 , 2 )))
555+
514556
515557class EventChainBackgroundUpdateTestCase (HomeserverTestCase ):
516558 servlets = [
0 commit comments