-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsse.py
More file actions
254 lines (210 loc) · 9.08 KB
/
Copy pathsse.py
File metadata and controls
254 lines (210 loc) · 9.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
"""
sse.py — Server-Sent Events stream management.
Provides:
- Per-request stream registries (logs, recommendations, status queues).
- QueueHandler: routes Python log records into the SSE logs queue.
- Helper functions for connection tracking and stream cleanup.
"""
import contextvars
import json
import os
import time
import queue
import logging
from threading import Lock
try:
    from flask import g
except ImportError:  # Flask is optional (e.g. tests that never touch SSE)
    # Without Flask, request ids come only from the log record itself or from
    # REQUEST_ID_CTX; QueueHandler skips the flask.g lookup when g is None.
    g = None  # type: ignore
logger = logging.getLogger("letterboxd-recommender")

# Request id for code that runs outside a Flask request context (background
# pipeline jobs, executor worker threads). QueueHandler consults this before
# falling back to flask.g.
REQUEST_ID_CTX: contextvars.ContextVar = contextvars.ContextVar('request_id', default=None)

# ---------------------------------------------------------------------------
# Stream registry
# ---------------------------------------------------------------------------

# request_id -> bundle of per-stream queues plus bookkeeping counters/flags.
REQUEST_STREAMS: dict = {}
REQUEST_STREAMS_LOCK = Lock()

# Both limits can be overridden through the environment.
STREAM_MAX_AGE_S = int(os.getenv('STREAM_MAX_AGE_S', '3600'))  # bundles older than this (default 1h) are evicted
SSE_QUEUE_MAXSIZE = int(os.getenv('SSE_QUEUE_MAXSIZE', '1000'))  # default capacity of each BoundedDropQueue

# Eviction sweeps are throttled to at most one per interval. The timestamp of
# the most recent sweep is only touched while REQUEST_STREAMS_LOCK is held.
_STREAM_EVICTION_INTERVAL_S = 30.0
_last_stream_eviction_s = 0.0
class BoundedDropQueue(queue.Queue):
    """Queue that evicts its oldest item instead of blocking when full.

    An abandoned SSE client must never block producers nor let messages pile
    up without bound. ``put`` therefore never blocks: when the queue is at
    capacity the oldest queued item is discarded to make room, and
    ``dropped`` counts how many items were lost that way. A ``maxsize`` of
    zero or less means unbounded (standard ``queue.Queue`` semantics), in
    which case nothing is ever dropped.
    """

    def __init__(self, maxsize: int = SSE_QUEUE_MAXSIZE):
        super().__init__(maxsize)
        self.dropped = 0
        # Kept for compatibility with code that may reference it; ``dropped``
        # is now updated under the queue's own mutex, which already
        # serialises every put and get.
        self._drop_lock = Lock()

    def put(self, item, block=True, timeout=None):
        """Enqueue *item*, discarding the oldest queued item if full.

        *block* and *timeout* are accepted for ``queue.Queue`` compatibility
        but ignored: this method never blocks.
        """
        # Evict-and-insert happens atomically under the queue mutex. A
        # try-put / get_nowait / retry loop races with consumers and other
        # producers and can drop an item even though space was just freed.
        with self.not_full:
            if getattr(self, 'is_shutdown', False):  # Queue.shutdown() (3.13+)
                raise queue.ShutDown
            if 0 < self.maxsize <= self._qsize():
                self._get()
                self.dropped += 1
                # The evicted item is replaced one-for-one by *item* and will
                # never reach a consumer (so never sees task_done()); leaving
                # unfinished_tasks unchanged keeps join() able to complete.
            else:
                self.unfinished_tasks += 1
            self._put(item)
            self.not_empty.notify()
def _get_or_create_streams(request_id: str) -> dict:
    """Return the stream-queue bundle for *request_id*, creating it if absent.

    A new bundle holds one BoundedDropQueue per stream ('logs',
    'recommendations', 'status'), a connection counter per stream, the two
    completion flags, and its creation time under '_created'.

    At most once every ``_STREAM_EVICTION_INTERVAL_S`` seconds the call also
    sweeps the registry and drops bundles older than ``STREAM_MAX_AGE_S``,
    reclaiming memory from orphaned entries (e.g. clients that disconnected
    without consuming their streams).
    """
    global _last_stream_eviction_s
    evicted = 0
    with REQUEST_STREAMS_LOCK:
        now = time.time()
        # Throttle the full-scan eviction so it doesn't run on every log
        # message. With dozens of threads logging concurrently this was the
        # dominant lock-contention source.
        if now - _last_stream_eviction_s >= _STREAM_EVICTION_INTERVAL_S:
            stale = [
                rid for rid, s in REQUEST_STREAMS.items()
                if now - s.get('_created', now) > STREAM_MAX_AGE_S
            ]
            for rid in stale:
                REQUEST_STREAMS.pop(rid, None)
            evicted = len(stale)
            _last_stream_eviction_s = now
        streams = REQUEST_STREAMS.get(request_id)
        if streams is None:
            streams = {
                'logs': BoundedDropQueue(),
                'recommendations': BoundedDropQueue(),
                'status': BoundedDropQueue(),
                'logs_connected': 0,
                'recommendations_connected': 0,
                'status_connected': 0,
                'recommendations_done': False,
                'status_done': False,
                '_created': now,
            }
            REQUEST_STREAMS[request_id] = streams
    # Log only after releasing the lock: an SSE QueueHandler forwards log
    # records through publish(), whose in-memory path calls back into this
    # function, and REQUEST_STREAMS_LOCK is a non-reentrant threading.Lock.
    if evicted:
        logger.debug("[stream-cleanup] evicted %d stale stream entries", evicted)
    return streams
def _cleanup_request_streams(request_id: str) -> None:
    """Forget the stream bundle of *request_id*; a no-op if none is registered."""
    with REQUEST_STREAMS_LOCK:
        registry = REQUEST_STREAMS
        registry.pop(request_id, None)
def _mark_recommendations_done(request_id: str) -> None:
    """Flag the recommendations stream of *request_id* as finished.

    Does nothing when no bundle is registered for the request.
    """
    with REQUEST_STREAMS_LOCK:
        bundle = REQUEST_STREAMS.get(request_id)
        if not bundle:
            return
        bundle['recommendations_done'] = True
def _mark_status_done(request_id: str) -> None:
    """Flag the status stream of *request_id* as finished.

    Does nothing when no bundle is registered for the request.
    """
    with REQUEST_STREAMS_LOCK:
        bundle = REQUEST_STREAMS.get(request_id)
        if not bundle:
            return
        bundle['status_done'] = True
def _track_stream_connection(request_id: str, stream_name: str, connected: bool) -> None:
    """Adjust the live-client counter of one stream of *request_id*.

    *connected* True counts an arriving client, False a departing one; the
    counter never goes below zero. Once both the recommendations and status
    streams are marked done and no stream has a client left, the whole bundle
    is removed from the registry. Unknown request ids are ignored.
    """
    with REQUEST_STREAMS_LOCK:
        bundle = REQUEST_STREAMS.get(request_id)
        if not bundle:
            return
        key = f"{stream_name}_connected"
        delta = 1 if connected else -1
        bundle[key] = max(0, bundle.get(key, 0) + delta)

        finished = bundle.get('recommendations_done') and bundle.get('status_done')
        idle = all(
            bundle.get(counter, 0) == 0
            for counter in ('logs_connected', 'recommendations_connected', 'status_connected')
        )
        if finished and idle:
            REQUEST_STREAMS.pop(request_id, None)
# ---------------------------------------------------------------------------
# Pub/sub delivery — Redis when REDIS_URL is set, in-memory queues otherwise
# ---------------------------------------------------------------------------
# With 2+ gunicorn workers a client's SSE connection may land on a different
# worker than the one running the pipeline; Redis pub/sub (one channel per
# request_id per stream) carries messages across workers. The in-memory
# queues remain the single-worker fallback. Note: pub/sub does not buffer —
# subscribers must connect before messages are published (the frontend opens
# its streams before POSTing /api/recommend).
_redis_client = None  # connected client, or None for in-memory delivery
_redis_attempted = False  # True once the one-time connection attempt has finished
_REDIS_INIT_LOCK = Lock()


def _get_redis():
    """Return the shared Redis client, or None to use in-memory queues.

    The first call reads REDIS_URL and, when it is set, connects and pings the
    server. The outcome (client or None) is cached for the rest of the
    process, so a failed attempt is not retried.
    """
    global _redis_client, _redis_attempted
    if _redis_attempted:
        return _redis_client
    with _REDIS_INIT_LOCK:
        if _redis_attempted:
            return _redis_client
        url = os.getenv('REDIS_URL')
        client = None
        error = None
        if url:
            try:
                import redis as _redis_lib
                client = _redis_lib.from_url(url, decode_responses=True)
                client.ping()
            except Exception as exc:
                client = None
                error = exc
        # Store the client before raising the flag: lock-free readers on the
        # fast path above must never observe "attempted" while the connection
        # is still being set up, or they would silently fall back to the
        # in-memory queues.
        _redis_client = client
        _redis_attempted = True
    # Log only after releasing the lock. An SSE QueueHandler forwards log
    # records through publish(), which calls this function; with the flag
    # already set, that nested call returns at once instead of blocking on
    # the non-reentrant init lock.
    if client is not None:
        logger.info("SSE delivery using Redis pub/sub backend")
    elif error is not None:
        logger.warning("SSE Redis unavailable; using in-memory queues: %s", error)
    return client
def _channel(request_id: str, stream_name: str) -> str:
    """Build the Redis pub/sub channel name for one stream of one request."""
    channel_name = f"sse:{request_id}:{stream_name}"
    return channel_name
def publish(request_id: str, stream_name: str, message) -> None:
    """Deliver *message* to the subscribers of one stream of *request_id*.

    With a Redis backend the message is JSON-encoded and published on the
    request's channel, which reaches subscribers in any worker. If there is no
    backend, or the Redis publish raises for any reason, the message goes to
    this process's in-memory queue instead. A falsy *request_id* is ignored.
    """
    if not request_id:
        return
    client = _get_redis()
    if client is not None:
        try:
            client.publish(_channel(request_id, stream_name), json.dumps(message))
        except Exception as exc:
            logger.debug("SSE publish via Redis failed, using memory: %s", exc)
        else:
            return
    # In-memory delivery: no Redis backend, or the Redis publish just failed.
    _get_or_create_streams(request_id)[stream_name].put(message)
class Subscription:
    """Consumer handle for one stream of one request.

    ``get(timeout)`` returns the next message or raises ``queue.Empty``;
    ``close()`` releases the Redis pub/sub connection, if one is held. The
    handle can also be used as a context manager, which calls ``close()`` on
    exit.
    """

    def __init__(self, request_id: str, stream_name: str):
        self._pubsub = None
        self._queue = None
        r = _get_redis()
        if r is not None:
            pubsub = r.pubsub()
            try:
                pubsub.subscribe(_channel(request_id, stream_name))
            except Exception as exc:
                # Mirror publish(): when Redis is unreachable, degrade to the
                # in-memory queue instead of failing the SSE connection, and
                # close the pub/sub object rather than leave it open.
                logger.debug("SSE subscribe via Redis failed, using memory: %s", exc)
                try:
                    pubsub.close()
                except Exception:
                    pass  # best-effort teardown of a connection that already failed
            else:
                self._pubsub = pubsub
                return
        self._queue = _get_or_create_streams(request_id)[stream_name]

    def get(self, timeout: float = 1):
        """Return the next message, or raise queue.Empty if none is available.

        On Redis, *timeout* bounds the wait in ``get_message`` and any
        non-data pub/sub message also yields queue.Empty; payloads are
        JSON-decoded. In memory, this waits on the queue for up to *timeout*.
        """
        if self._pubsub is not None:
            msg = self._pubsub.get_message(ignore_subscribe_messages=True, timeout=timeout)
            if not msg or msg.get('type') != 'message':
                raise queue.Empty
            return json.loads(msg['data'])
        return self._queue.get(timeout=timeout)

    def close(self) -> None:
        """Unsubscribe and close the Redis pub/sub connection (best effort).

        In-memory subscriptions hold nothing to release.
        """
        if self._pubsub is not None:
            try:
                self._pubsub.unsubscribe()
                self._pubsub.close()
            except Exception:
                pass  # teardown is best-effort; the connection may already be gone

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
def subscribe(request_id: str, stream_name: str) -> Subscription:
    """Open a Subscription to *stream_name* of *request_id* on the active backend."""
    subscription = Subscription(request_id, stream_name)
    return subscription
# ---------------------------------------------------------------------------
# QueueHandler — routes log records into the per-request SSE logs queue
# ---------------------------------------------------------------------------
class QueueHandler(logging.Handler):
    """Logging handler that forwards formatted records to the SSE logs stream.

    The target request id is taken, in order, from the record's
    ``request_id`` attribute, from ``REQUEST_ID_CTX``, and finally from
    ``flask.g.request_id`` when Flask is importable. Records with no
    resolvable request id are not forwarded.
    """

    # True while the current thread/context is inside emit(). publish() and
    # the stream registry log on their own (Redis fallbacks, eviction sweeps);
    # forwarding those nested records again would recurse
    # emit() -> publish() -> logger -> emit() ... (e.g. on every failing Redis
    # publish), so they are skipped by this handler (other handlers still
    # receive them).
    _forwarding: contextvars.ContextVar = contextvars.ContextVar(
        'sse_queue_handler_forwarding', default=False
    )

    def emit(self, record: logging.LogRecord) -> None:
        if self._forwarding.get():
            return
        token = self._forwarding.set(True)
        try:
            request_id = getattr(record, 'request_id', None)
            if not request_id:
                request_id = REQUEST_ID_CTX.get()
            if not request_id and g is not None:
                try:
                    request_id = getattr(g, 'request_id', None)
                except RuntimeError:
                    # Raised when flask.g is touched without an active context.
                    request_id = None
            if request_id:
                # Format only records that will actually be forwarded.
                publish(request_id, 'logs', self.format(record))
        except Exception:
            self.handleError(record)
        finally:
            self._forwarding.reset(token)