1 change: 1 addition & 0 deletions changelog.d/7010.misc
@@ -0,0 +1 @@
Change device list streams to have one row per ID.
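The changelog line captures the core idea: previously each device-list replication row was a (user_id, destination) pair, with destination None when no remote server needed poking, and a whole batch of updates could share one stream ID. After this PR every row carries exactly one ID in a single entity field. Matrix user IDs always start with '@', so consumers separate "a user's devices changed" from "this server needs poking" with a prefix check. A minimal sketch of the convention (the attrs row class is from the PR; is_user_entity is our illustrative helper):

import attr

@attr.s
class DeviceListsStreamRow:
    # Either a user ID (starting with '@') whose devices changed, or a
    # remote destination that needs to be told about changes.
    entity = attr.ib(type=str)

def is_user_entity(row):
    # Hypothetical helper: Matrix user IDs always begin with '@'.
    return row.entity.startswith("@")

rows = [
    DeviceListsStreamRow(entity="@alice:example.com"),
    DeviceListsStreamRow(entity="remote.example.org"),
]
users = [r.entity for r in rows if is_user_entity(r)]
hosts = [r.entity for r in rows if not is_user_entity(r)]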
10 changes: 7 additions & 3 deletions synapse/app/generic_worker.py
@@ -676,8 +676,9 @@ async def process_and_notify(self, stream_name, token, rows):
elif stream_name == "device_lists":
all_room_ids = set()
for row in rows:
room_ids = await self.store.get_rooms_for_user(row.user_id)
all_room_ids.update(room_ids)
if row.entity.startswith("@"):
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)
self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
elif stream_name == "presence":
await self.presence_handler.process_replication_rows(token, rows)
@@ -774,7 +775,10 @@ def process_replication_rows(self, stream_name, token, rows):

# ... as well as device updates and messages
elif stream_name == DeviceListsStream.NAME:
hosts = {row.destination for row in rows}
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
hosts = {row.entity for row in rows if not row.entity.startswith("@")}
for host in hosts:
self.federation_sender.send_device_messages(host)

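The two hunks above apply the same prefix check on opposite sides: user entities wake up local sync streams via the notifier, while host entities are handed to the federation sender. A condensed sketch combining the two branches (method names are taken from the diff; the combined function itself is our illustration, not code from the PR):

async def handle_device_list_rows(store, notifier, federation_sender, token, rows):
    all_room_ids = set()
    hosts = set()
    for row in rows:
        if row.entity.startswith("@"):
            # A user's devices changed: wake up everyone who shares a room.
            all_room_ids.update(await store.get_rooms_for_user(row.entity))
        else:
            # A remote destination has pending updates to be sent.
            hosts.add(row.entity)
    notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
    for host in hosts:
        federation_sender.send_device_messages(host)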
36 changes: 23 additions & 13 deletions synapse/replication/slave/storage/devices.py
@@ -29,7 +29,13 @@ def __init__(self, database: Database, db_conn, hs):
self.hs = hs

self._device_list_id_gen = SlavedIdTracker(
db_conn, "device_lists_stream", "stream_id"
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
],
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
@@ -55,23 +61,27 @@ def stream_positions(self):
def process_replication_rows(self, stream_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
for row in rows:
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)

def _invalidate_caches_for_devices(self, token, user_id, destination):
self._device_list_stream_cache.entity_has_changed(user_id, token)

if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, token
)
def _invalidate_caches_for_devices(self, token, rows):
for row in rows:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
if row.entity.startswith("@"):
self._device_list_stream_cache.entity_has_changed(row.entity, token)
self.get_cached_devices_for_user.invalidate((row.entity,))
self._get_cached_user_device.invalidate_many((row.entity,))
self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))

self.get_cached_devices_for_user.invalidate((user_id,))
self._get_cached_user_device.invalidate_many((user_id,))
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
else:
self._device_list_federation_stream_cache.entity_has_changed(
row.entity, token
)
13 changes: 9 additions & 4 deletions synapse/replication/tcp/streams/_base.py
@@ -94,9 +94,13 @@ class CachesStreamRow:
"network_id", # str, optional
),
)
DeviceListsStreamRow = namedtuple(
"DeviceListsStreamRow", ("user_id", "destination") # str # str
)


@attr.s
class DeviceListsStreamRow:
entity = attr.ib(type=str)


ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
@@ -363,7 +367,8 @@ def __init__(self, hs):


class DeviceListsStream(Stream):
"""Someone added/changed/removed a device
"""Either a user has updated their devices or a remote server needs to be
told about a device update.
"""

NAME = "device_lists"
5 changes: 4 additions & 1 deletion synapse/storage/data_stores/main/__init__.py
@@ -144,7 +144,10 @@ def __init__(self, database: Database, db_conn, hs):
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[("user_signature_stream", "stream_id")],
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
],
)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
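Because outbound pokes now consume their own stream IDs, both the slave-side SlavedIdTracker and the master-side StreamIdGenerator must resume from the maximum position across all three tables, which is what the extra_tables argument requests. A rough self-contained illustration of that maximum-over-tables idea (table and column names from the diff; the current_token helper is ours, not Synapse's API):

import sqlite3

def current_token(conn, table, column, extra_tables=()):
    # Take the max stream position over the main table plus any extras,
    # mirroring what extra_tables asks the real generators to do.
    token = 0
    for tbl, col in [(table, column), *extra_tables]:
        (maximum,) = conn.execute("SELECT MAX(%s) FROM %s" % (col, tbl)).fetchone()
        token = max(token, maximum or 0)
    return token

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE device_lists_stream (stream_id INTEGER)")
conn.execute("CREATE TABLE user_signature_stream (stream_id INTEGER)")
conn.execute("CREATE TABLE device_lists_outbound_pokes (stream_id INTEGER)")
conn.execute("INSERT INTO device_lists_stream VALUES (3)")
conn.execute("INSERT INTO device_lists_outbound_pokes VALUES (7)")

print(current_token(conn, "device_lists_stream", "stream_id",
                    extra_tables=[("user_signature_stream", "stream_id"),
                                  ("device_lists_outbound_pokes", "stream_id")]))
# -> 7: resuming from 3 would hand out IDs that collide with the pokes row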
132 changes: 68 additions & 64 deletions synapse/storage/data_stores/main/devices.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Tuple

from six import iteritems

@@ -31,7 +32,7 @@
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.database import Database, LoggingTransaction
from synapse.types import Collection, get_verify_key_from_cross_signing_key
from synapse.util.caches.descriptors import (
Cache,
@@ -112,23 +113,13 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
if not has_changed:
return now_stream_id, []

# We retrieve n+1 devices from the list of outbound pokes where n is
# our outbound device update limit. We then check if the very last
# device has the same stream_id as the second-to-last device. If so,
# then we ignore all devices with that stream_id and only send the
# devices with a lower stream_id.
#
# If when culling the list we end up with no devices afterwards, we
# consider the device update to be too large, and simply skip the
# stream_id; the rationale being that such a large device list update
# is likely an error.
updates = yield self.db.runInteraction(
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
limit + 1,
limit,
)

# Return an empty list if there are no updates
@@ -166,14 +157,6 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
"device_id": verify_key.version,
}

# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
if len(updates) > limit:
stream_id_cutoff = updates[-1][2]
now_stream_id = stream_id_cutoff - 1
else:
stream_id_cutoff = None

# Perform the equivalent of a GROUP BY
#
# Iterate through the updates list and copy non-duplicate
@@ -192,10 +175,6 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
query_map = {}
cross_signing_keys_by_user = {}
for user_id, device_id, update_stream_id, update_context in updates:
if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
# Stop processing updates
break

if (
user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"]
@@ -218,17 +197,6 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)

# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
# steam_id.

# That should only happen if a client is spamming the server with new
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
if not query_map and not cross_signing_keys_by_user:
return stream_id_cutoff, []

results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
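A note on the deletions above: the limit + 1 fetch and the stream_id cutoff existed because a single stream_id used to cover an entire batch of device updates, so a plain SQL LIMIT could split a batch in half with no safe position to resume from. With one stream ID per row, resuming from the last-seen stream_id is always correct and the plain limit suffices. A toy illustration with invented data:

# Old scheme: a whole batch shared one stream_id, so a LIMIT could split
# the batch and there was no safe stream_id to resume from.
old_rows = [(5, "dev1"), (5, "dev2"), (5, "dev3")]  # one batch, id 5
# Fetching LIMIT 2 returns dev1/dev2; resuming with stream_id > 5 skips
# dev3, and stream_id > 4 resends dev1/dev2 -- hence the old limit + 1
# fetch and the cutoff culling.

# New scheme: one stream_id per row makes resumption unambiguous.
new_rows = [(5, "dev1"), (6, "dev2"), (7, "dev3")]
page = new_rows[:2]          # plain LIMIT 2
resume_from = page[-1][0]    # 6
remaining = [r for r in new_rows if r[0] > resume_from]
assert remaining == [(7, "dev3")]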
@@ -607,21 +575,26 @@ def get_users_whose_signatures_changed(self, user_id, from_key):
else:
return set()

def get_all_device_list_changes_for_remotes(self, from_key, to_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
async def get_all_device_list_changes_for_remotes(
self, from_key: int, to_key: int
) -> List[Tuple[int, str]]:
"""Return a list of `(stream_id, entity)` which is the combined list of
changes to devices and which destinations need to be poked. Entity is
either a user ID (starting with '@') or a remote destination.
"""
# We do a group by here as there can be a large number of duplicate
# entries, since we throw away device IDs.

# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
SELECT MAX(stream_id) AS stream_id, user_id, destination
FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
SELECT stream_id, entity FROM (
SELECT stream_id, user_id AS entity FROM device_lists_stream
UNION ALL
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination
"""
return self.db.execute(

return await self.db.execute(
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
)
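Where the old query joined the two tables and used GROUP BY to collapse duplicates, the new one simply unions the two streams into a single (stream_id, entity) relation, relying on the planner to push the stream_id bounds into both branches (the "Does The Right Thing" comment). A self-contained sqlite3 demonstration, with the schemas cut down to just the columns the query touches:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE device_lists_stream (stream_id INTEGER, user_id TEXT)")
conn.execute(
    "CREATE TABLE device_lists_outbound_pokes (stream_id INTEGER, destination TEXT)"
)
conn.execute("INSERT INTO device_lists_stream VALUES (1, '@alice:example.com')")
conn.execute("INSERT INTO device_lists_outbound_pokes VALUES (2, 'remote.example.org')")
conn.execute("INSERT INTO device_lists_stream VALUES (3, '@bob:example.com')")

sql = """
    SELECT stream_id, entity FROM (
        SELECT stream_id, user_id AS entity FROM device_lists_stream
        UNION ALL
        SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
    ) AS e
    WHERE ? < stream_id AND stream_id <= ?
"""
# Ordering across the two branches is not guaranteed without an ORDER BY;
# each row is either a user ID or a destination, told apart by the '@'.
print(conn.execute(sql, (0, 3)).fetchall())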

@@ -1017,29 +990,49 @@ def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
with self._device_list_id_gen.get_next() as stream_id:
if not device_ids:
return

with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
yield self.db.runInteraction(
"add_device_change_to_streams",
self._add_device_change_txn,
"add_device_change_to_stream",
self._add_device_change_to_stream_txn,
user_id,
device_ids,
stream_ids,
)

if not hosts:
return stream_ids[-1]

context = get_active_span_text_map()
with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
yield self.db.runInteraction(
"add_device_outbound_poke_to_stream",
self._add_device_outbound_poke_to_stream_txn,
user_id,
device_ids,
hosts,
stream_id,
stream_ids,
context,
)
return stream_id

def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
now = self._clock.time_msec()
return stream_ids[-1]

def _add_device_change_to_stream_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_ids: Collection[str],
stream_ids: List[str],
):
txn.call_after(
self._device_list_stream_cache.entity_has_changed, user_id, stream_id
self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
)
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host,
stream_id,
)

min_stream_id = stream_ids[0]

# Delete older entries in the table, as we really only care about
# when the latest change happened.
@@ -1048,27 +1041,38 @@ def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
""",
[(user_id, device_id, stream_id) for device_id in device_ids],
[(user_id, device_id, min_stream_id) for device_id in device_ids],
)

self.db.simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
{"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
for device_id in device_ids
for stream_id, device_id in zip(stream_ids, device_ids)
],
)

context = get_active_span_text_map()
def _add_device_outbound_poke_to_stream_txn(
self, txn, user_id, device_ids, hosts, stream_ids, context,
):
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host,
stream_ids[-1],
)

now = self._clock.time_msec()
next_stream_id = iter(stream_ids)

self.db.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
values=[
{
"destination": destination,
"stream_id": stream_id,
"stream_id": next(next_stream_id),
"user_id": user_id,
"device_id": device_id,
"sent": False,
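The two helpers above depend on get_next_mult handing back a block of fresh, ordered stream IDs: one per device for device_lists_stream, and one per (host, device) pair for the pokes table, consumed in order through an iterator. A plain-Python sketch of that allocation (get_next_mult is simulated with a counter since the real one is a context manager tied to the database, and the poke loop order is our guess, as the diff is truncated before it):

import itertools

_counter = itertools.count(1)

def get_next_mult(n):
    # Stand-in for StreamIdGenerator.get_next_mult: n fresh, ordered IDs.
    return [next(_counter) for _ in range(n)]

user_id = "@alice:example.com"
device_ids = ["DEV1", "DEV2"]
hosts = ["remote.example.org"]

# One row, and one stream ID, per changed device ...
stream_ids = get_next_mult(len(device_ids))
stream_rows = [
    {"stream_id": sid, "user_id": user_id, "device_id": did}
    for sid, did in zip(stream_ids, device_ids)
]

# ... and one row per (host, device) pair for the outbound pokes.
next_stream_id = iter(get_next_mult(len(hosts) * len(device_ids)))
poke_rows = [
    {
        "destination": host,
        "stream_id": next(next_stream_id),
        "user_id": user_id,
        "device_id": did,
        "sent": False,
    }
    for host in hosts
    for did in device_ids
]

print(stream_rows)  # stream IDs 1 and 2
print(poke_rows)    # stream IDs 3 and 4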