Skip to content

Commit 7fdb8eb

Browse files
committed
Refactoring the NNG support solution v1.5: Simplify store
1 parent f74cce9 commit 7fdb8eb

5 files changed

Lines changed: 349 additions & 685 deletions

File tree

taskiq/brokers/nng/__init__.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
1-
from hub import HubConfig, NNGHub
2-
from protocol import (
1+
"""NNG broker package for taskiq."""
2+
from .hub import HubConfig, NNGHub
3+
from .protocol import (
34
ControlMessage,
45
ControlResponse,
56
MessageKind,
67
TaskEnvelope,
78
WorkerState,
89
WorkerStatus,
910
)
10-
from storage import QueueFullError, SQLiteJournal, StoreConfig
11+
from .storage import InMemoryStore, QueueFullError, StoreConfig
1112

1213
__all__ = [
13-
'HubConfig',
14-
'NNGHub',
15-
'ControlMessage',
16-
'ControlResponse',
17-
'MessageKind',
18-
'TaskEnvelope',
19-
'WorkerState',
20-
'WorkerStatus',
21-
'QueueFullError',
22-
'SQLiteJournal',
23-
'StoreConfig',
14+
"HubConfig",
15+
"NNGHub",
16+
"ControlMessage",
17+
"ControlResponse",
18+
"MessageKind",
19+
"TaskEnvelope",
20+
"WorkerState",
21+
"WorkerStatus",
22+
"QueueFullError",
23+
"InMemoryStore",
24+
"StoreConfig",
2425
]

taskiq/brokers/nng/broker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from taskiq.acks import AckableMessage
1818
from taskiq.message import BrokerMessage
1919

20-
from protocol import (
20+
from .protocol import (
2121
ControlMessage,
2222
ControlResponse,
2323
TaskEnvelope,

taskiq/brokers/nng/hub.py

Lines changed: 33 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
44
Run as a standalone process::
55
6-
taskiq-nng-hub --control-addr ipc:///tmp/taskiq-nng.ipc \\
7-
--task-db /var/lib/taskiq/tasks.db
6+
taskiq-nng-hub --control-addr ipc:///tmp/taskiq-nng.ipc
87
98
Or embed it in an application for testing::
109
11-
hub = NNGHub(HubConfig(control_addr="ipc:///tmp/h.ipc", task_db=":memory:"))
10+
hub = NNGHub(HubConfig(control_addr="ipc:///tmp/h.ipc"))
1211
await hub.start()
1312
...
1413
await hub.stop()
@@ -18,13 +17,11 @@
1817
import argparse
1918
import asyncio
2019
import base64
21-
import json
2220
import logging
2321
import os
2422
import signal
2523
import time
2624
import uuid
27-
from concurrent.futures import ThreadPoolExecutor
2825
from contextlib import suppress
2926
from dataclasses import dataclass, field
3027
from typing import Any
@@ -34,13 +31,13 @@
3431
except ImportError:
3532
pynng = None # type: ignore[assignment]
3633

37-
from protocol import (
34+
from .protocol import (
3835
ControlMessage,
3936
ControlResponse,
4037
TaskEnvelope,
4138
WorkerState,
4239
)
43-
from storage import QueueFullError, SQLiteJournal, StoreConfig
40+
from .storage import InMemoryStore, QueueFullError, StoreConfig
4441

4542
logger = logging.getLogger(__name__)
4643

@@ -50,7 +47,7 @@ class HubConfig:
5047
"""Configuration for :class:`NNGHub`."""
5148

5249
control_addr: str
53-
task_db: str
50+
task_db: str = "" # kept for API compat; ignored by in-memory store
5451
max_pending: int = 10_000
5552
heartbeat_timeout: float = 15.0
5653
lease_timeout: float = 20.0
@@ -77,20 +74,18 @@ class NNGHub:
7774
independent ``nng_ctx`` contexts running concurrently. Each context
7875
handles one request-reply at a time, so N workers can
7976
register/heartbeat/ack simultaneously without queuing behind each other.
80-
This is the key fix over the single-context (serial) Rep0 in v2.
8177
8278
**Data plane** — One ``Push0`` socket per registered worker, dialed to
8379
the worker's own ``Pull0`` listen address. The hub explicitly targets
84-
the least-loaded worker instead of relying on NNG round-robin, giving
85-
us load-aware routing.
80+
the least-loaded worker instead of relying on NNG round-robin.
8681
87-
**Persistence** — :class:`~taskiq.brokers.nng_storage.SQLiteJournal` in
88-
WAL mode. All storage calls are executed on a single-threaded
89-
``ThreadPoolExecutor`` so the asyncio event loop is never blocked and
90-
SQLite write serialisation is guaranteed.
82+
**State** — :class:`~taskiq.brokers.nng.storage.InMemoryStore`. All
83+
store operations are synchronous and execute directly on the asyncio event
84+
loop without blocking (no I/O, no syscalls).
9185
92-
**Recovery** — On startup, tasks leased to workers that died during the
93-
previous hub session are automatically requeued.
86+
**Recovery** — On startup, any tasks that were leased before the hub last
87+
stopped (within the same process lifetime) are automatically requeued by
88+
:meth:`~InMemoryStore.recover_dead_workers`.
9489
"""
9590

9691
def __init__(self, config: HubConfig) -> None:
@@ -105,9 +100,8 @@ def __init__(self, config: HubConfig) -> None:
105100
"Install it with: pip install taskiq[nng]"
106101
)
107102
self.config = config
108-
self.store = SQLiteJournal(
103+
self.store = InMemoryStore(
109104
StoreConfig(
110-
path=config.task_db,
111105
max_pending=config.max_pending,
112106
lease_timeout=config.lease_timeout,
113107
backoff_cap=config.backoff_cap,
@@ -117,16 +111,12 @@ def __init__(self, config: HubConfig) -> None:
117111
self._ctrl_sock: Any = None # pynng.Rep0
118112
self._worker_push: dict[str, Any] = {} # worker_id -> pynng.Push0
119113
self._tasks: list[asyncio.Task[None]] = []
120-
# Single-threaded executor: serialises all SQLite calls on one OS thread.
121-
self._db_exec = ThreadPoolExecutor(
122-
max_workers=1, thread_name_prefix="nng-db"
123-
)
124114

125115
# ── lifecycle ─────────────────────────────────────────────────────────────
126116

127117
async def start(self) -> None:
128118
"""Start the hub: recover orphaned tasks, open sockets, spawn loops."""
129-
await self._db(self.store.recover_dead_workers, self.config.heartbeat_timeout)
119+
self.store.recover_dead_workers(self.config.heartbeat_timeout)
130120

131121
self._ctrl_sock = pynng.Rep0(listen=self.config.control_addr)
132122
self._ctrl_sock.recv_timeout = self.config.recv_timeout_ms
@@ -142,11 +132,7 @@ async def start(self) -> None:
142132
self._control_handler(ctx), name=f"hub-ctrl-{i}"
143133
),
144134
)
145-
logger.info(
146-
"NNG hub started on %s (db=%s)",
147-
self.config.control_addr,
148-
self.config.task_db,
149-
)
135+
logger.info("NNG hub started on %s", self.config.control_addr)
150136

151137
async def stop(self) -> None:
152138
"""Gracefully stop all hub loops and close sockets."""
@@ -163,17 +149,8 @@ async def stop(self) -> None:
163149
if self._ctrl_sock is not None:
164150
with suppress(Exception):
165151
self._ctrl_sock.close()
166-
self._db_exec.shutdown(wait=True)
167152
logger.info("NNG hub stopped")
168153

169-
# ── DB helper ─────────────────────────────────────────────────────────────
170-
171-
async def _db(self, fn: Any, *args: Any, **kwargs: Any) -> Any:
172-
loop = asyncio.get_running_loop()
173-
return await loop.run_in_executor(
174-
self._db_exec, lambda: fn(*args, **kwargs)
175-
)
176-
177154
# ── control plane ─────────────────────────────────────────────────────────
178155

179156
async def _control_handler(self, ctx: Any) -> None:
@@ -216,28 +193,26 @@ async def _handle(self, raw: bytes) -> ControlResponse: # noqa: PLR0911, C901
216193
return await self._handle_register(msg.payload)
217194

218195
if msg.kind == "heartbeat":
219-
await self._db(self.store.heartbeat, msg.payload["worker_id"])
196+
self.store.heartbeat(msg.payload["worker_id"])
220197
return ControlResponse(ok=True, payload={"ok": True})
221198

222199
if msg.kind == "unregister":
223200
return await self._handle_unregister(msg.payload["worker_id"])
224201

225202
if msg.kind == "drain":
226-
await self._db(self.store.mark_draining, msg.payload["worker_id"])
203+
self.store.mark_draining(msg.payload["worker_id"])
227204
return ControlResponse(ok=True, payload={"draining": True})
228205

229206
if msg.kind == "ack":
230-
ok = await self._db(
231-
self.store.ack,
207+
ok = self.store.ack(
232208
msg.payload["task_id"],
233209
msg.payload["worker_id"],
234210
msg.payload["lease_id"],
235211
)
236212
return ControlResponse(ok=ok, payload={"acked": ok})
237213

238214
if msg.kind == "nack":
239-
ok = await self._db(
240-
self.store.nack,
215+
ok = self.store.nack(
241216
msg.payload["task_id"],
242217
msg.payload["worker_id"],
243218
msg.payload["lease_id"],
@@ -246,26 +221,25 @@ async def _handle(self, raw: bytes) -> ControlResponse: # noqa: PLR0911, C901
246221
return ControlResponse(ok=ok, payload={"nacked": ok})
247222

248223
if msg.kind == "status":
249-
task = await self._db(self.store.get_task, msg.payload["task_id"])
250-
return ControlResponse(ok=bool(task), payload=dict(task) if task else {})
224+
task = self.store.get_task(msg.payload["task_id"])
225+
return ControlResponse(ok=bool(task), payload=task or {})
251226

252227
if msg.kind == "stats":
253-
s = await self._db(self.store.stats)
254-
return ControlResponse(ok=True, payload=s)
228+
return ControlResponse(ok=True, payload=self.store.stats())
255229

256230
return ControlResponse(ok=False, error=f"unknown kind: {msg.kind!r}")
257231

258232
async def _handle_submit(self, payload: dict[str, Any]) -> ControlResponse:
259233
envelope = TaskEnvelope(**payload)
260234
try:
261-
await self._db(self.store.submit, envelope)
235+
self.store.submit(envelope)
262236
return ControlResponse(ok=True, payload={"task_id": envelope.task_id})
263237
except QueueFullError:
264238
return ControlResponse(ok=False, error="queue full")
265239

266240
async def _handle_register(self, payload: dict[str, Any]) -> ControlResponse:
267241
worker = WorkerState(**payload)
268-
await self._db(self.store.register_worker, worker)
242+
self.store.register_worker(worker)
269243
if worker.worker_id not in self._worker_push:
270244
try:
271245
sock = pynng.Push0(dial=worker.task_addr)
@@ -279,7 +253,7 @@ async def _handle_register(self, payload: dict[str, Any]) -> ControlResponse:
279253
return ControlResponse(ok=True, payload={"registered": True})
280254

281255
async def _handle_unregister(self, worker_id: str) -> ControlResponse:
282-
await self._db(self.store.unregister_worker, worker_id)
256+
self.store.unregister_worker(worker_id)
283257
sock = self._worker_push.pop(worker_id, None)
284258
if sock is not None:
285259
with suppress(Exception):
@@ -302,13 +276,12 @@ async def _dispatch_loop(self) -> None:
302276

303277
async def _dispatch_once(self) -> bool:
304278
"""Dispatch up to ``dispatch_batch`` due tasks to available workers."""
305-
due = await self._db(self.store.due_tasks, self.config.dispatch_batch)
279+
due = self.store.due_tasks(self.config.dispatch_batch)
306280
if not due:
307281
return False
308282
sent_any = False
309283
for row in due:
310-
worker = await self._db(
311-
self.store.choose_worker,
284+
worker = self.store.choose_worker(
312285
self.config.routing_policy,
313286
heartbeat_timeout=self.config.heartbeat_timeout,
314287
)
@@ -319,8 +292,7 @@ async def _dispatch_once(self) -> bool:
319292
lease_id = uuid.uuid4().hex
320293
lease_until = time.time() + self.config.lease_timeout
321294

322-
if not await self._db(
323-
self.store.mark_leased,
295+
if not self.store.mark_leased(
324296
row["task_id"], worker_id, lease_id, lease_until,
325297
):
326298
continue # concurrent dispatch race; task already taken
@@ -331,19 +303,14 @@ async def _dispatch_once(self) -> bool:
331303
"No push socket for worker %s, requeueing %s",
332304
worker_id, row["task_id"],
333305
)
334-
await self._db(
335-
self.store.nack,
336-
row["task_id"], worker_id, lease_id, "no socket",
337-
)
306+
self.store.nack(row["task_id"], worker_id, lease_id, "no socket")
338307
continue
339308

340-
# Include the hub-generated lease_id so the worker can ack with
341-
# the exact token. Omitting it was the core correctness bug in v2.
342309
envelope = TaskEnvelope(
343310
task_id=row["task_id"],
344311
task_name=row["task_name"],
345312
payload_b64=base64.b64encode(row["payload"]).decode("ascii"),
346-
labels=json.loads(row["labels_json"]),
313+
labels=row["labels"],
347314
lease_id=lease_id,
348315
attempts=int(row["attempts"]) + 1,
349316
max_retries=int(row["max_retries"]),
@@ -360,8 +327,7 @@ async def _dispatch_once(self) -> bool:
360327
"Failed to deliver %s to worker %s: %s",
361328
row["task_id"], worker_id, exc,
362329
)
363-
await self._db(
364-
self.store.nack,
330+
self.store.nack(
365331
row["task_id"], worker_id, lease_id,
366332
f"dispatch send failed: {exc}",
367333
)
@@ -373,11 +339,10 @@ async def _reaper_loop(self) -> None:
373339
while not self._stop.is_set():
374340
try:
375341
await asyncio.sleep(self.config.reaper_interval)
376-
reaped = await self._db(self.store.reap_expired_leases)
342+
reaped = self.store.reap_expired_leases()
377343
if reaped:
378344
logger.debug("Reaped %d expired leases", reaped)
379-
recovered = await self._db(
380-
self.store.recover_dead_workers,
345+
recovered = self.store.recover_dead_workers(
381346
self.config.heartbeat_timeout,
382347
)
383348
if recovered:
@@ -400,11 +365,6 @@ def _build_config() -> HubConfig:
400365
default=os.getenv("NNG_CONTROL_ADDR", "ipc:///tmp/taskiq-nng.ipc"),
401366
help="NNG address the hub listens on. Env: NNG_CONTROL_ADDR",
402367
)
403-
p.add_argument(
404-
"--task-db",
405-
default=os.getenv("NNG_TASK_DB", "/tmp/taskiq-nng-tasks.db"), # noqa: S108
406-
help="Path to the SQLite WAL task journal. Env: NNG_TASK_DB",
407-
)
408368
p.add_argument(
409369
"--max-pending",
410370
type=int,
@@ -445,7 +405,6 @@ def _build_config() -> HubConfig:
445405
)
446406
return HubConfig(
447407
control_addr=args.control_addr,
448-
task_db=args.task_db,
449408
max_pending=args.max_pending,
450409
heartbeat_timeout=args.heartbeat_timeout,
451410
lease_timeout=args.lease_timeout,

0 commit comments

Comments
 (0)