Skip to content

Commit c0cde1a

Browse files
committed
Route history_update SSE events for anonymous sessions
HistoryAuditMonitor only dispatched history_update events via push_to_user, so anonymous-owned histories (user_id IS NULL) never produced events. With enable_sse_history_updates on, the client disables polling and waits for SSE events — leaving anonymous history panels frozen, which broke seven UI tests across the Playwright suites.

Extend the pipeline with a parallel galaxy_session.id-keyed route:

- SSEConnectionManager tracks a _session_connections map alongside _connections and exposes push_to_session; connect/disconnect/stream accept an optional galaxy_session_id.
- EventsService.open_stream forwards trans.galaxy_session.id so anonymous sessions register under their session key.
- SSEEventDispatcher.history_update and HistoryUpdatePayload gain an optional session_updates dict; the queue_worker handler fans out session-keyed events via push_to_session.
- HistoryAuditMonitor caches (user_id, session_ids) per history and performs one extra indexed lookup against GalaxySessionToHistoryAssociation only for anon histories.

galaxy_session.id never leaves the server — it's used only as an in-memory/AMQP dispatch key; the browser-visible event payload still contains just encoded history_ids.
1 parent fee5e5f commit c0cde1a

5 files changed

Lines changed: 119 additions & 28 deletions

File tree

lib/galaxy/managers/history_audit_monitor.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from galaxy.config import GalaxyAppConfiguration
3232
from galaxy.managers.sse_dispatch import SSEEventDispatcher
3333
from galaxy.model import (
34+
GalaxySessionToHistoryAssociation,
3435
History,
3536
HistoryAudit,
3637
)
@@ -119,8 +120,11 @@ def __init__(
119120
self._exit = threading.Event()
120121
self._thread: Optional[threading.Thread] = None
121122
self._active = False
122-
# Bounded LRU cache: history_id -> user_id, refreshed on miss.
123-
self._history_owner_cache: OrderedDict[int, int] = OrderedDict()
123+
# Bounded LRU cache: history_id -> (user_id, session_ids), refreshed on miss.
124+
# For registered-owned histories: (user_id, ()); for anonymous histories:
125+
# (None, (session_id, ...)) — a history can be associated with multiple
126+
# sessions via GalaxySessionToHistoryAssociation.
127+
self._history_owner_cache: OrderedDict[int, tuple[Optional[int], tuple[int, ...]]] = OrderedDict()
124128

125129
def start(self) -> None:
126130
if self._active:
@@ -227,7 +231,7 @@ def _poll_audit_table(self) -> None:
227231
# --- Common dispatch logic ---
228232

229233
def _dispatch_history_updates(self, history_ids: set[int]) -> None:
230-
"""Map history_ids to user_ids and send Kombu control task.
234+
"""Map history_ids to user_ids / session_ids and send Kombu control task.
231235
232236
Raw integer history IDs are sent across the control queue; encoding is
233237
deferred to the ``history_update`` task handler on the receiving side,
@@ -239,24 +243,58 @@ def _dispatch_history_updates(self, history_ids: set[int]) -> None:
239243
self._refresh_owner_cache(unknown)
240244

241245
user_updates: dict[str, list[int]] = defaultdict(list)
246+
session_updates: dict[str, list[int]] = defaultdict(list)
242247
for history_id in history_ids:
243-
user_id = self._history_owner_cache.get(history_id)
248+
entry = self._history_owner_cache.get(history_id)
249+
if entry is None:
250+
continue
251+
user_id, session_ids = entry
244252
if user_id is not None:
245253
user_updates[str(user_id)].append(history_id)
254+
else:
255+
for session_id in session_ids:
256+
session_updates[str(session_id)].append(history_id)
246257

247-
if not user_updates:
258+
if not user_updates and not session_updates:
248259
return
249260

250-
self._dispatcher.history_update(user_updates=dict(user_updates))
261+
self._dispatcher.history_update(
262+
user_updates=dict(user_updates),
263+
session_updates=dict(session_updates) if session_updates else None,
264+
)
251265

252266
def _refresh_owner_cache(self, history_ids: set[int]) -> None:
253-
"""Look up user_id for given history_ids and update the bounded cache."""
267+
"""Look up ownership for given history_ids and update the bounded cache.
268+
269+
Registered-owned histories resolve with just ``History.user_id``. For
270+
histories where ``user_id IS NULL`` we additionally fetch associated
271+
``galaxy_session.id`` values from ``GalaxySessionToHistoryAssociation``
272+
so the anonymous SSE dispatch path can target the right browser.
273+
"""
254274
try:
255-
stmt = sa_select(History.id, History.user_id).where(History.id.in_(history_ids))
256275
with self._model.new_session() as session:
276+
stmt = sa_select(History.id, History.user_id).where(History.id.in_(history_ids))
277+
anon_history_ids: set[int] = set()
257278
for row in session.execute(stmt):
258-
self._history_owner_cache[row[0]] = row[1]
259-
self._history_owner_cache.move_to_end(row[0])
279+
hid, uid = row[0], row[1]
280+
self._history_owner_cache[hid] = (uid, ())
281+
self._history_owner_cache.move_to_end(hid)
282+
if uid is None:
283+
anon_history_ids.add(hid)
284+
285+
if anon_history_ids:
286+
assoc_stmt = sa_select(
287+
GalaxySessionToHistoryAssociation.history_id,
288+
GalaxySessionToHistoryAssociation.session_id,
289+
).where(GalaxySessionToHistoryAssociation.history_id.in_(anon_history_ids))
290+
sessions_by_history: dict[int, list[int]] = defaultdict(list)
291+
for row in session.execute(assoc_stmt):
292+
hid, sid = row[0], row[1]
293+
if sid is not None:
294+
sessions_by_history[hid].append(sid)
295+
for hid, sids in sessions_by_history.items():
296+
self._history_owner_cache[hid] = (None, tuple(sids))
297+
260298
while len(self._history_owner_cache) > OWNER_CACHE_MAX:
261299
self._history_owner_cache.popitem(last=False)
262300
except Exception:

lib/galaxy/managers/sse.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class SSEConnectionManager:
8181

8282
def __init__(self, statsd_client: Optional[VanillaGalaxyStatsdClient] = None) -> None:
8383
self._connections: dict[int, set[asyncio.Queue]] = defaultdict(set)
84+
self._session_connections: dict[int, set[asyncio.Queue]] = defaultdict(set)
8485
self._broadcast_connections: set[asyncio.Queue] = set()
8586
self._loop: Optional[asyncio.AbstractEventLoop] = None
8687
self._statsd_client = statsd_client
@@ -92,27 +93,39 @@ def _ensure_loop(self) -> None:
9293

9394
# -- Called from ASYNC context (uvicorn event loop thread) --
9495

95-
def connect(self, user_id: Optional[int]) -> asyncio.Queue:
96+
def connect(self, user_id: Optional[int], galaxy_session_id: Optional[int] = None) -> asyncio.Queue:
9697
"""Register a new SSE connection. Returns a queue to await events from.
9798
9899
Called from the SSE endpoint handler (async context). A ``ready`` event is
99100
enqueued immediately so that clients (and tests) can synchronize on the
100101
server-side subscription rather than the underlying socket open event.
102+
103+
``galaxy_session_id`` is the dispatch key for events that target a
104+
specific browser session (e.g. history updates for anonymous users,
105+
whose ``user_id`` is ``None``).
101106
"""
102107
self._ensure_loop()
103108
queue: asyncio.Queue = asyncio.Queue(maxsize=64)
104109
if user_id is not None:
105110
self._connections[user_id].add(queue)
111+
if galaxy_session_id is not None:
112+
self._session_connections[galaxy_session_id].add(queue)
106113
self._broadcast_connections.add(queue)
107114
queue.put_nowait(SSEEvent(event="ready", data=""))
108115
log.debug(
109-
"SSE connection opened for user_id=%s (total=%d)",
116+
"SSE connection opened for user_id=%s session_id=%s (total=%d)",
110117
user_id,
118+
galaxy_session_id,
111119
len(self._broadcast_connections),
112120
)
113121
return queue
114122

115-
def disconnect(self, user_id: Optional[int], queue: asyncio.Queue) -> None:
123+
def disconnect(
124+
self,
125+
user_id: Optional[int],
126+
queue: asyncio.Queue,
127+
galaxy_session_id: Optional[int] = None,
128+
) -> None:
116129
"""Unregister an SSE connection.
117130
118131
Called from the SSE endpoint's ``finally`` block (async context).
@@ -121,10 +134,15 @@ def disconnect(self, user_id: Optional[int], queue: asyncio.Queue) -> None:
121134
self._connections[user_id].discard(queue)
122135
if not self._connections[user_id]:
123136
del self._connections[user_id]
137+
if galaxy_session_id is not None:
138+
self._session_connections[galaxy_session_id].discard(queue)
139+
if not self._session_connections[galaxy_session_id]:
140+
del self._session_connections[galaxy_session_id]
124141
self._broadcast_connections.discard(queue)
125142
log.debug(
126-
"SSE connection closed for user_id=%s (total=%d)",
143+
"SSE connection closed for user_id=%s session_id=%s (total=%d)",
127144
user_id,
145+
galaxy_session_id,
128146
len(self._broadcast_connections),
129147
)
130148

@@ -135,6 +153,15 @@ def push_to_user(self, user_id: int, event: SSEEvent) -> None:
135153
for queue in list(self._connections.get(user_id, [])):
136154
self._safe_put(queue, event)
137155

156+
def push_to_session(self, galaxy_session_id: int, event: SSEEvent) -> None:
157+
"""Thread-safe. Push an event to all SSE connections for a specific galaxy_session.
158+
159+
Used to route per-browser events (e.g. history updates for anonymous
160+
histories) when there is no registered ``user_id`` to key on.
161+
"""
162+
for queue in list(self._session_connections.get(galaxy_session_id, [])):
163+
self._safe_put(queue, event)
164+
138165
def push_broadcast(self, event: SSEEvent) -> None:
139166
"""Thread-safe. Push an event to ALL connected SSE clients."""
140167
for queue in list(self._broadcast_connections):
@@ -189,6 +216,7 @@ async def stream(
189216
user_id: Optional[int],
190217
catch_up: Optional[SSEEvent] = None,
191218
keepalive: float = 30.0,
219+
galaxy_session_id: Optional[int] = None,
192220
) -> AsyncIterator[str]:
193221
"""Yield SSE-framed strings for one connected client.
194222
@@ -198,7 +226,7 @@ async def stream(
198226
what the service passes in (typically ``request.is_disconnected`` from
199227
starlette) so the manager stays framework-agnostic.
200228
"""
201-
queue = self.connect(user_id)
229+
queue = self.connect(user_id, galaxy_session_id)
202230
if catch_up is not None:
203231
await queue.put(catch_up)
204232
try:
@@ -211,4 +239,4 @@ async def stream(
211239
except asyncio.TimeoutError:
212240
yield ": keepalive\n\n"
213241
finally:
214-
self.disconnect(user_id, queue)
242+
self.disconnect(user_id, queue, galaxy_session_id)

lib/galaxy/managers/sse_dispatch.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,21 @@ def notify_broadcast(self, payload: str, event_id: Optional[str] = None) -> None
123123
},
124124
)
125125

126-
def history_update(self, user_updates: dict[str, list[int]], event_id: Optional[str] = None) -> None:
127-
self._send(
128-
"history_update",
129-
{
130-
"user_updates": user_updates,
131-
"event_id": event_id or make_event_id(),
132-
},
133-
)
126+
def history_update(
127+
self,
128+
user_updates: dict[str, list[int]],
129+
event_id: Optional[str] = None,
130+
session_updates: Optional[dict[str, list[int]]] = None,
131+
) -> None:
132+
kwargs: dict[str, Any] = {
133+
"user_updates": user_updates,
134+
"event_id": event_id or make_event_id(),
135+
}
136+
if session_updates:
137+
# Only include when non-empty: anonymous histories are uncommon on
138+
# most deployments, and an empty dict is wasted wire payload.
139+
kwargs["session_updates"] = session_updates
140+
self._send("history_update", kwargs)
134141

135142
def entry_point_update(self, user_id: int, event_id: Optional[str] = None) -> None:
136143
"""Fan out a wake-up ``entry_point_update`` event for one user.

lib/galaxy/queue_worker/__init__.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,14 @@ class HistoryUpdatePayload(TypedDict, total=False):
6969
"""Wire contract for the ``history_update`` control-task kwargs.
7070
7171
``user_updates`` maps stringified user IDs to lists of (unencoded) history IDs.
72-
Stringified because AMQP JSON serialization coerces dict keys to strings.
72+
``session_updates`` is the parallel route for anonymous-owned histories,
73+
keyed by stringified ``galaxy_session.id`` (the dispatch key never leaves
74+
the server — browsers never see it). Stringified because AMQP JSON
75+
serialization coerces dict keys to strings.
7376
"""
7477

7578
user_updates: dict[str, list[int]]
79+
session_updates: dict[str, list[int]]
7680
event_id: Optional[str]
7781

7882

@@ -423,7 +427,9 @@ def history_update(app: "MinimalManagerApp", **kwargs) -> None:
423427
"""Push SSE history update events to connected users on this worker process.
424428
425429
Encodes integer history IDs here (not in the monitor) so the manager layer
426-
stays free of presentation/security concerns.
430+
stays free of presentation/security concerns. Handles both user-keyed
431+
routing (registered users) and galaxy_session-keyed routing (anonymous
432+
histories, which have ``user_id IS NULL``).
427433
"""
428434
payload = cast(HistoryUpdatePayload, kwargs)
429435
sse_manager = app[SSEConnectionManager]
@@ -435,6 +441,12 @@ def history_update(app: "MinimalManagerApp", **kwargs) -> None:
435441
data = json.dumps({"history_ids": encoded_ids})
436442
event = SSEEvent(event="history_update", data=data, id=event_id)
437443
sse_manager.push_to_user(user_id, event)
444+
for session_id_str, history_ids in payload.get("session_updates", {}).items():
445+
session_id = int(session_id_str)
446+
encoded_ids = [encode(hid) for hid in history_ids]
447+
data = json.dumps({"history_ids": encoded_ids})
448+
event = SSEEvent(event="history_update", data=data, id=event_id)
449+
sse_manager.push_to_session(session_id, event)
438450

439451

440452
def entry_point_update(app: "MinimalManagerApp", **kwargs) -> None:

lib/galaxy/webapps/galaxy/services/events.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@ def open_stream(
3232
last_event_id: Optional[str],
3333
is_disconnected: IsDisconnected,
3434
) -> AsyncIterator[str]:
35-
"""Open an SSE events stream; anonymous users receive only broadcasts."""
35+
"""Open an SSE events stream.
36+
37+
Anonymous users still register under their ``galaxy_session.id`` so the
38+
server can route per-session events (e.g. ``history_update`` for
39+
anonymous-owned histories) even when ``user_id`` is ``None``.
40+
"""
3641
user_id = user_context.user.id if not user_context.anonymous else None
42+
session_id = user_context.galaxy_session.id if user_context.galaxy_session else None
3743
catch_up = self.notifications.build_status_catchup(user_context, last_event_id)
38-
return self.sse_manager.stream(is_disconnected, user_id, catch_up=catch_up)
44+
return self.sse_manager.stream(is_disconnected, user_id, catch_up=catch_up, galaxy_session_id=session_id)

0 commit comments

Comments
 (0)