1414# limitations under the License.
1515import logging
1616
17- from six import iteritems , itervalues
17+ from six import iteritems
1818
1919from canonicaljson import json
2020
@@ -72,67 +72,146 @@ def get_devices_by_user(self, user_id):
7272
7373 defer .returnValue ({d ["device_id" ]: d for d in devices })
7474
@defer.inlineCallbacks
def get_devices_by_remote(self, destination, from_stream_id, limit):
    """Get a stream of device list updates to send to a remote server.

    Args:
        destination (str): The host the device updates are intended for
        from_stream_id (int): Lower bound (exclusive) on update stream ids
        limit (int): Maximum number of device updates to return

    Returns:
        Deferred[tuple[int, list[dict]]]:
            current stream id (ie, the stream id of the last update included
            in the response), and the list of updates
    """
    now_stream_id = self._device_list_id_gen.get_current_token()

    has_changed = self._device_list_federation_stream_cache.has_entity_changed(
        destination, int(from_stream_id)
    )
    if not has_changed:
        defer.returnValue((now_stream_id, []))

    # Fetch one row more than the caller asked for. If the extra row shares
    # a stream_id with the last in-limit row, we drop every row with that
    # stream_id and only send updates with a strictly lower stream_id.
    #
    # Should that culling leave us with nothing at all, the update at that
    # stream_id is deemed pathologically large (likely an error) and the
    # stream_id is simply skipped.
    rows = yield self.runInteraction(
        "get_devices_by_remote",
        self._get_devices_by_remote_txn,
        destination,
        from_stream_id,
        now_stream_id,
        limit + 1,
    )

    # Nothing pending for this destination.
    if not rows:
        defer.returnValue((now_stream_id, []))

    # Over the limit: exclude any results sharing the final row's stream_id.
    if len(rows) > limit:
        stream_id_cutoff = rows[-1][2]
        now_stream_id = stream_id_cutoff - 1
    else:
        stream_id_cutoff = None

    # Equivalent of a SQL GROUP BY: collapse duplicate (user_id, device_id)
    # pairs, keeping the highest stream_id seen for each pair, while
    # discarding everything at or beyond the cutoff.
    #
    # maps (user_id, device_id) -> stream_id
    query_map = {}
    for user_id, device_id, stream_id in rows:
        if stream_id_cutoff is not None and stream_id >= stream_id_cutoff:
            # rows are ordered by stream_id, so nothing further can qualify
            break

        key = (user_id, device_id)
        query_map[key] = max(query_map.get(key, 0), stream_id)

    # If every row fell at or beyond the cutoff, then more than `limit`
    # updates all share a single stream_id.
    #
    # That should only happen if a client is spamming the server with new
    # devices, in which case E2E isn't going to work well anyway. We'll just
    # skip that stream_id and return an empty list, and continue with the
    # next stream_id next time.
    if not query_map:
        defer.returnValue((stream_id_cutoff, []))

    results = yield self._get_device_update_edus_by_remote(
        destination,
        from_stream_id,
        query_map,
    )

    defer.returnValue((now_stream_id, results))
97159 def _get_devices_by_remote_txn (
98- self , txn , destination , from_stream_id , now_stream_id
160+ self , txn , destination , from_stream_id , now_stream_id , limit
99161 ):
162+ """Return device update information for a given remote destination
163+
164+ Args:
165+ txn (LoggingTransaction): The transaction to execute
166+ destination (str): The host the device updates are intended for
167+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
168+ now_stream_id (int): The maximum stream_id to filter updates by, inclusive
169+ limit (int): Maximum number of device updates to return
170+
171+ Returns:
172+ List: List of device updates
173+ """
100174 sql = """
101- SELECT user_id, device_id, max( stream_id) FROM device_lists_outbound_pokes
175+ SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes
102176 WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
103- GROUP BY user_id, device_id
104- LIMIT 20
177+ ORDER BY stream_id
178+ LIMIT ?
105179 """
106- txn .execute (sql , (destination , from_stream_id , now_stream_id , False ))
180+ txn .execute (sql , (destination , from_stream_id , now_stream_id , False , limit ))
107181
108- # maps (user_id, device_id) -> stream_id
109- query_map = {(r [0 ], r [1 ]): r [2 ] for r in txn }
110- if not query_map :
111- return (now_stream_id , [])
182+ return list (txn )
112183
113- if len (query_map ) >= 20 :
114- now_stream_id = max (stream_id for stream_id in itervalues (query_map ))
184+ @defer .inlineCallbacks
185+ def _get_device_update_edus_by_remote (
186+ self , destination , from_stream_id , query_map ,
187+ ):
188+ """Returns a list of device update EDUs as well as E2EE keys
115189
116- devices = self ._get_e2e_device_keys_txn (
117- txn ,
190+ Args:
191+ destination (str): The host the device updates are intended for
192+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
193+ query_map (Dict[Tuple[str, str], int]): Dictionary mapping
194+ user_id/device_id to update stream_id
195+
196+ Returns:
197+ List[Dict]: List of objects representing a device update EDU
198+
199+ """
200+ devices = yield self .runInteraction (
201+ "_get_e2e_device_keys_txn" ,
202+ self ._get_e2e_device_keys_txn ,
118203 query_map .keys (),
119204 include_all_devices = True ,
120205 include_deleted_devices = True ,
121206 )
122207
123- prev_sent_id_sql = """
124- SELECT coalesce(max(stream_id), 0) as stream_id
125- FROM device_lists_outbound_last_success
126- WHERE destination = ? AND user_id = ? AND stream_id <= ?
127- """
128-
129208 results = []
130209 for user_id , user_devices in iteritems (devices ):
131210 # The prev_id for the first row is always the last row before
132211 # `from_stream_id`
133- txn . execute ( prev_sent_id_sql , ( destination , user_id , from_stream_id ))
134- rows = txn . fetchall ()
135- prev_id = rows [ 0 ][ 0 ]
212+ prev_id = yield self . _get_last_device_update_for_remote_user (
213+ destination , user_id , from_stream_id ,
214+ )
136215 for device_id , device in iteritems (user_devices ):
137216 stream_id = query_map [(user_id , device_id )]
138217 result = {
@@ -156,7 +235,22 @@ def _get_devices_by_remote_txn(
156235
157236 results .append (result )
158237
159- return (now_stream_id , results )
238+ defer .returnValue (results )
239+
240+ def _get_last_device_update_for_remote_user (
241+ self , destination , user_id , from_stream_id ,
242+ ):
243+ def f (txn ):
244+ prev_sent_id_sql = """
245+ SELECT coalesce(max(stream_id), 0) as stream_id
246+ FROM device_lists_outbound_last_success
247+ WHERE destination = ? AND user_id = ? AND stream_id <= ?
248+ """
249+ txn .execute (prev_sent_id_sql , (destination , user_id , from_stream_id ))
250+ rows = txn .fetchall ()
251+ return rows [0 ][0 ]
252+
253+ return self .runInteraction ("get_last_device_update_for_remote_user" , f )
160254
161255 def mark_as_sent_devices_by_remote (self , destination , stream_id ):
162256 """Mark that updates have successfully been sent to the destination.
0 commit comments