1616
1717import logging
1818from collections import namedtuple
19- from typing import Any , Awaitable , Callable , Iterable , List , Optional , Tuple
19+ from typing import Any , Awaitable , Callable , List , Optional , Tuple
2020
2121import attr
2222
5353#
5454# The arguments are:
5555#
56+ # * instance_name: the writer of the stream
5657# * from_token: the previous stream token: the starting point for fetching the
5758# updates
5859# * to_token: the new stream token: the point to get updates up to
6263# If there are more updates available, it should set `limited` in the result, and
6364# it will be called again to get the next batch.
6465#
65- UpdateFunction = Callable [[Token , Token , int ], Awaitable [StreamUpdateResult ]]
66+ UpdateFunction = Callable [[str , Token , Token , int ], Awaitable [StreamUpdateResult ]]
6667
6768
6869class Stream (object ):
@@ -93,6 +94,7 @@ def parse_row(cls, row: StreamRow):
9394
9495 def __init__ (
9596 self ,
97+ local_instance_name : str ,
9698 current_token_function : Callable [[], Token ],
9799 update_function : UpdateFunction ,
98100 ):
@@ -108,9 +110,11 @@ def __init__(
108110 stream tokens. See the UpdateFunction type definition for more info.
109111
110112 Args:
113+ local_instance_name: The instance name of the current process
111114 current_token_function: callback to get the current token, as above
112115 update_function: callback go get stream updates, as above
113116 """
117+ self .local_instance_name = local_instance_name
114118 self .current_token = current_token_function
115119 self .update_function = update_function
116120
@@ -135,14 +139,14 @@ async def get_updates(self) -> StreamUpdateResult:
135139 """
136140 current_token = self .current_token ()
137141 updates , current_token , limited = await self .get_updates_since (
138- self .last_token , current_token
142+ self .local_instance_name , self . last_token , current_token
139143 )
140144 self .last_token = current_token
141145
142146 return updates , current_token , limited
143147
144148 async def get_updates_since (
145- self , from_token : Token , upto_token : Token
149+ self , instance_name : str , from_token : Token , upto_token : Token
146150 ) -> StreamUpdateResult :
147151 """Like get_updates except allows specifying from when we should
148152 stream updates
@@ -160,19 +164,19 @@ async def get_updates_since(
160164 return [], upto_token , False
161165
162166 updates , upto_token , limited = await self .update_function (
163- from_token , upto_token , _STREAM_UPDATE_TARGET_ROW_COUNT ,
167+ instance_name , from_token , upto_token , _STREAM_UPDATE_TARGET_ROW_COUNT ,
164168 )
165169 return updates , upto_token , limited
166170
167171
168172def db_query_to_update_function (
169- query_function : Callable [[Token , Token , int ], Awaitable [Iterable [tuple ]]]
173+ query_function : Callable [[Token , Token , int ], Awaitable [List [tuple ]]]
170174) -> UpdateFunction :
171175 """Wraps a db query function which returns a list of rows to make it
172176 suitable for use as an `update_function` for the Stream class
173177 """
174178
175- async def update_function (from_token , upto_token , limit ):
179+ async def update_function (instance_name , from_token , upto_token , limit ):
176180 rows = await query_function (from_token , upto_token , limit )
177181 updates = [(row [0 ], row [1 :]) for row in rows ]
178182 limited = False
@@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
193197 client = ReplicationGetStreamUpdates .make_client (hs )
194198
195199 async def update_function (
196- from_token : int , upto_token : int , limit : int
200+ instance_name : str , from_token : int , upto_token : int , limit : int
197201 ) -> StreamUpdateResult :
198202 result = await client (
199- stream_name = stream_name , from_token = from_token , upto_token = upto_token ,
203+ instance_name = instance_name ,
204+ stream_name = stream_name ,
205+ from_token = from_token ,
206+ upto_token = upto_token ,
200207 )
201208 return result ["updates" ], result ["upto_token" ], result ["limited" ]
202209
@@ -226,6 +233,7 @@ class BackfillStream(Stream):
226233 def __init__ (self , hs ):
227234 store = hs .get_datastore ()
228235 super ().__init__ (
236+ hs .get_instance_name (),
229237 store .get_current_backfill_token ,
230238 db_query_to_update_function (store .get_all_new_backfill_event_rows ),
231239 )
@@ -261,7 +269,9 @@ def __init__(self, hs):
261269 # Query master process
262270 update_function = make_http_update_function (hs , self .NAME )
263271
264- super ().__init__ (store .get_current_presence_token , update_function )
272+ super ().__init__ (
273+ hs .get_instance_name (), store .get_current_presence_token , update_function
274+ )
265275
266276
267277class TypingStream (Stream ):
@@ -284,7 +294,9 @@ def __init__(self, hs):
284294 # Query master process
285295 update_function = make_http_update_function (hs , self .NAME )
286296
287- super ().__init__ (typing_handler .get_current_token , update_function )
297+ super ().__init__ (
298+ hs .get_instance_name (), typing_handler .get_current_token , update_function
299+ )
288300
289301
290302class ReceiptsStream (Stream ):
@@ -305,6 +317,7 @@ class ReceiptsStream(Stream):
305317 def __init__ (self , hs ):
306318 store = hs .get_datastore ()
307319 super ().__init__ (
320+ hs .get_instance_name (),
308321 store .get_max_receipt_stream_id ,
309322 db_query_to_update_function (store .get_all_updated_receipts ),
310323 )
@@ -322,14 +335,16 @@ class PushRulesStream(Stream):
322335 def __init__ (self , hs ):
323336 self .store = hs .get_datastore ()
324337 super (PushRulesStream , self ).__init__ (
325- self ._current_token , self ._update_function
338+ hs . get_instance_name (), self ._current_token , self ._update_function
326339 )
327340
328341 def _current_token (self ) -> int :
329342 push_rules_token , _ = self .store .get_push_rules_stream_token ()
330343 return push_rules_token
331344
332- async def _update_function (self , from_token : Token , to_token : Token , limit : int ):
345+ async def _update_function (
346+ self , instance_name : str , from_token : Token , to_token : Token , limit : int
347+ ):
333348 rows = await self .store .get_all_push_rule_updates (from_token , to_token , limit )
334349
335350 limited = False
@@ -356,6 +371,7 @@ def __init__(self, hs):
356371 store = hs .get_datastore ()
357372
358373 super ().__init__ (
374+ hs .get_instance_name (),
359375 store .get_pushers_stream_token ,
360376 db_query_to_update_function (store .get_all_updated_pushers_rows ),
361377 )
@@ -387,6 +403,7 @@ class CachesStreamRow:
387403 def __init__ (self , hs ):
388404 store = hs .get_datastore ()
389405 super ().__init__ (
406+ hs .get_instance_name (),
390407 store .get_cache_stream_token ,
391408 db_query_to_update_function (store .get_all_updated_caches ),
392409 )
@@ -412,6 +429,7 @@ class PublicRoomsStream(Stream):
412429 def __init__ (self , hs ):
413430 store = hs .get_datastore ()
414431 super ().__init__ (
432+ hs .get_instance_name (),
415433 store .get_current_public_room_stream_id ,
416434 db_query_to_update_function (store .get_all_new_public_rooms ),
417435 )
@@ -432,6 +450,7 @@ class DeviceListsStreamRow:
432450 def __init__ (self , hs ):
433451 store = hs .get_datastore ()
434452 super ().__init__ (
453+ hs .get_instance_name (),
435454 store .get_device_stream_token ,
436455 db_query_to_update_function (store .get_all_device_list_changes_for_remotes ),
437456 )
@@ -449,6 +468,7 @@ class ToDeviceStream(Stream):
449468 def __init__ (self , hs ):
450469 store = hs .get_datastore ()
451470 super ().__init__ (
471+ hs .get_instance_name (),
452472 store .get_to_device_stream_token ,
453473 db_query_to_update_function (store .get_all_new_device_messages ),
454474 )
@@ -468,6 +488,7 @@ class TagAccountDataStream(Stream):
468488 def __init__ (self , hs ):
469489 store = hs .get_datastore ()
470490 super ().__init__ (
491+ hs .get_instance_name (),
471492 store .get_max_account_data_stream_id ,
472493 db_query_to_update_function (store .get_all_updated_tags ),
473494 )
@@ -487,6 +508,7 @@ class AccountDataStream(Stream):
487508 def __init__ (self , hs ):
488509 self .store = hs .get_datastore ()
489510 super ().__init__ (
511+ hs .get_instance_name (),
490512 self .store .get_max_account_data_stream_id ,
491513 db_query_to_update_function (self ._update_function ),
492514 )
@@ -517,6 +539,7 @@ class GroupServerStream(Stream):
517539 def __init__ (self , hs ):
518540 store = hs .get_datastore ()
519541 super ().__init__ (
542+ hs .get_instance_name (),
520543 store .get_group_stream_token ,
521544 db_query_to_update_function (store .get_all_groups_changes ),
522545 )
@@ -534,6 +557,7 @@ class UserSignatureStream(Stream):
534557 def __init__ (self , hs ):
535558 store = hs .get_datastore ()
536559 super ().__init__ (
560+ hs .get_instance_name (),
537561 store .get_device_stream_token ,
538562 db_query_to_update_function (
539563 store .get_all_user_signature_changes_for_remotes
0 commit comments