33
44Run as a standalone process::
55
6- taskiq-nng-hub --control-addr ipc:///tmp/taskiq-nng.ipc \\
7- --task-db /var/lib/taskiq/tasks.db
6+ taskiq-nng-hub --control-addr ipc:///tmp/taskiq-nng.ipc
87
98Or embed it in an application for testing::
109
11- hub = NNGHub(HubConfig(control_addr="ipc:///tmp/h.ipc", task_db=":memory:" ))
10+ hub = NNGHub(HubConfig(control_addr="ipc:///tmp/h.ipc"))
1211 await hub.start()
1312 ...
1413 await hub.stop()
1817import argparse
1918import asyncio
2019import base64
21- import json
2220import logging
2321import os
2422import signal
2523import time
2624import uuid
27- from concurrent .futures import ThreadPoolExecutor
2825from contextlib import suppress
2926from dataclasses import dataclass , field
3027from typing import Any
3431except ImportError :
3532 pynng = None # type: ignore[assignment]
3633
37- from protocol import (
34+ from . protocol import (
3835 ControlMessage ,
3936 ControlResponse ,
4037 TaskEnvelope ,
4138 WorkerState ,
4239)
43- from storage import QueueFullError , SQLiteJournal , StoreConfig
40+ from . storage import InMemoryStore , QueueFullError , StoreConfig
4441
4542logger = logging .getLogger (__name__ )
4643
@@ -50,7 +47,7 @@ class HubConfig:
5047 """Configuration for :class:`NNGHub`."""
5148
5249 control_addr : str
53- task_db : str
50+ task_db : str = "" # kept for API compat; ignored by in-memory store
5451 max_pending : int = 10_000
5552 heartbeat_timeout : float = 15.0
5653 lease_timeout : float = 20.0
@@ -77,20 +74,18 @@ class NNGHub:
7774 independent ``nng_ctx`` contexts running concurrently. Each context
7875 handles one request-reply at a time, so N workers can
7976 register/heartbeat/ack simultaneously without queuing behind each other.
80- This is the key fix over the single-context (serial) Rep0 in v2.
8177
8278 **Data plane** — One ``Push0`` socket per registered worker, dialed to
8379 the worker's own ``Pull0`` listen address. The hub explicitly targets
84- the least-loaded worker instead of relying on NNG round-robin, giving
85- us load-aware routing.
80+ the least-loaded worker instead of relying on NNG round-robin.
8681
87- **Persistence** — :class:`~taskiq.brokers.nng_storage.SQLiteJournal` in
88- WAL mode. All storage calls are executed on a single-threaded
89- ``ThreadPoolExecutor`` so the asyncio event loop is never blocked and
90- SQLite write serialisation is guaranteed.
82+ **State** — :class:`~taskiq.brokers.nng.storage.InMemoryStore`. All
83+ store operations are synchronous and execute directly on the asyncio event
84+ loop without blocking (no I/O, no syscalls).
9185
92- **Recovery** — On startup, tasks leased to workers that died during the
93- previous hub session are automatically requeued.
86+ **Recovery** — On startup, any tasks that were leased before the hub last
87+ stopped (within the same process lifetime) are automatically requeued by
88+ :meth:`~InMemoryStore.recover_dead_workers`.
9489 """
9590
9691 def __init__ (self , config : HubConfig ) -> None :
@@ -105,9 +100,8 @@ def __init__(self, config: HubConfig) -> None:
105100 "Install it with: pip install taskiq[nng]"
106101 )
107102 self .config = config
108- self .store = SQLiteJournal (
103+ self .store = InMemoryStore (
109104 StoreConfig (
110- path = config .task_db ,
111105 max_pending = config .max_pending ,
112106 lease_timeout = config .lease_timeout ,
113107 backoff_cap = config .backoff_cap ,
@@ -117,16 +111,12 @@ def __init__(self, config: HubConfig) -> None:
117111 self ._ctrl_sock : Any = None # pynng.Rep0
118112 self ._worker_push : dict [str , Any ] = {} # worker_id -> pynng.Push0
119113 self ._tasks : list [asyncio .Task [None ]] = []
120- # Single-threaded executor: serialises all SQLite calls on one OS thread.
121- self ._db_exec = ThreadPoolExecutor (
122- max_workers = 1 , thread_name_prefix = "nng-db"
123- )
124114
125115 # ── lifecycle ─────────────────────────────────────────────────────────────
126116
127117 async def start (self ) -> None :
128118 """Start the hub: recover orphaned tasks, open sockets, spawn loops."""
129- await self ._db ( self . store .recover_dead_workers , self .config .heartbeat_timeout )
119+ self .store .recover_dead_workers ( self .config .heartbeat_timeout )
130120
131121 self ._ctrl_sock = pynng .Rep0 (listen = self .config .control_addr )
132122 self ._ctrl_sock .recv_timeout = self .config .recv_timeout_ms
@@ -142,11 +132,7 @@ async def start(self) -> None:
142132 self ._control_handler (ctx ), name = f"hub-ctrl-{ i } "
143133 ),
144134 )
145- logger .info (
146- "NNG hub started on %s (db=%s)" ,
147- self .config .control_addr ,
148- self .config .task_db ,
149- )
135+ logger .info ("NNG hub started on %s" , self .config .control_addr )
150136
151137 async def stop (self ) -> None :
152138 """Gracefully stop all hub loops and close sockets."""
@@ -163,17 +149,8 @@ async def stop(self) -> None:
163149 if self ._ctrl_sock is not None :
164150 with suppress (Exception ):
165151 self ._ctrl_sock .close ()
166- self ._db_exec .shutdown (wait = True )
167152 logger .info ("NNG hub stopped" )
168153
169- # ── DB helper ─────────────────────────────────────────────────────────────
170-
171- async def _db (self , fn : Any , * args : Any , ** kwargs : Any ) -> Any :
172- loop = asyncio .get_running_loop ()
173- return await loop .run_in_executor (
174- self ._db_exec , lambda : fn (* args , ** kwargs )
175- )
176-
177154 # ── control plane ─────────────────────────────────────────────────────────
178155
179156 async def _control_handler (self , ctx : Any ) -> None :
@@ -216,28 +193,26 @@ async def _handle(self, raw: bytes) -> ControlResponse: # noqa: PLR0911, C901
216193 return await self ._handle_register (msg .payload )
217194
218195 if msg .kind == "heartbeat" :
219- await self ._db ( self . store .heartbeat , msg .payload ["worker_id" ])
196+ self .store .heartbeat ( msg .payload ["worker_id" ])
220197 return ControlResponse (ok = True , payload = {"ok" : True })
221198
222199 if msg .kind == "unregister" :
223200 return await self ._handle_unregister (msg .payload ["worker_id" ])
224201
225202 if msg .kind == "drain" :
226- await self ._db ( self . store .mark_draining , msg .payload ["worker_id" ])
203+ self .store .mark_draining ( msg .payload ["worker_id" ])
227204 return ControlResponse (ok = True , payload = {"draining" : True })
228205
229206 if msg .kind == "ack" :
230- ok = await self ._db (
231- self .store .ack ,
207+ ok = self .store .ack (
232208 msg .payload ["task_id" ],
233209 msg .payload ["worker_id" ],
234210 msg .payload ["lease_id" ],
235211 )
236212 return ControlResponse (ok = ok , payload = {"acked" : ok })
237213
238214 if msg .kind == "nack" :
239- ok = await self ._db (
240- self .store .nack ,
215+ ok = self .store .nack (
241216 msg .payload ["task_id" ],
242217 msg .payload ["worker_id" ],
243218 msg .payload ["lease_id" ],
@@ -246,26 +221,25 @@ async def _handle(self, raw: bytes) -> ControlResponse: # noqa: PLR0911, C901
246221 return ControlResponse (ok = ok , payload = {"nacked" : ok })
247222
248223 if msg .kind == "status" :
249- task = await self ._db ( self . store .get_task , msg .payload ["task_id" ])
250- return ControlResponse (ok = bool (task ), payload = dict ( task ) if task else {})
224+ task = self .store .get_task ( msg .payload ["task_id" ])
225+ return ControlResponse (ok = bool (task ), payload = task or {})
251226
252227 if msg .kind == "stats" :
253- s = await self ._db (self .store .stats )
254- return ControlResponse (ok = True , payload = s )
228+ return ControlResponse (ok = True , payload = self .store .stats ())
255229
256230 return ControlResponse (ok = False , error = f"unknown kind: { msg .kind !r} " )
257231
258232 async def _handle_submit (self , payload : dict [str , Any ]) -> ControlResponse :
259233 envelope = TaskEnvelope (** payload )
260234 try :
261- await self ._db ( self . store .submit , envelope )
235+ self .store .submit ( envelope )
262236 return ControlResponse (ok = True , payload = {"task_id" : envelope .task_id })
263237 except QueueFullError :
264238 return ControlResponse (ok = False , error = "queue full" )
265239
266240 async def _handle_register (self , payload : dict [str , Any ]) -> ControlResponse :
267241 worker = WorkerState (** payload )
268- await self ._db ( self . store .register_worker , worker )
242+ self .store .register_worker ( worker )
269243 if worker .worker_id not in self ._worker_push :
270244 try :
271245 sock = pynng .Push0 (dial = worker .task_addr )
@@ -279,7 +253,7 @@ async def _handle_register(self, payload: dict[str, Any]) -> ControlResponse:
279253 return ControlResponse (ok = True , payload = {"registered" : True })
280254
281255 async def _handle_unregister (self , worker_id : str ) -> ControlResponse :
282- await self ._db ( self . store .unregister_worker , worker_id )
256+ self .store .unregister_worker ( worker_id )
283257 sock = self ._worker_push .pop (worker_id , None )
284258 if sock is not None :
285259 with suppress (Exception ):
@@ -302,13 +276,12 @@ async def _dispatch_loop(self) -> None:
302276
303277 async def _dispatch_once (self ) -> bool :
304278 """Dispatch up to ``dispatch_batch`` due tasks to available workers."""
305- due = await self ._db ( self . store .due_tasks , self .config .dispatch_batch )
279+ due = self .store .due_tasks ( self .config .dispatch_batch )
306280 if not due :
307281 return False
308282 sent_any = False
309283 for row in due :
310- worker = await self ._db (
311- self .store .choose_worker ,
284+ worker = self .store .choose_worker (
312285 self .config .routing_policy ,
313286 heartbeat_timeout = self .config .heartbeat_timeout ,
314287 )
@@ -319,8 +292,7 @@ async def _dispatch_once(self) -> bool:
319292 lease_id = uuid .uuid4 ().hex
320293 lease_until = time .time () + self .config .lease_timeout
321294
322- if not await self ._db (
323- self .store .mark_leased ,
295+ if not self .store .mark_leased (
324296 row ["task_id" ], worker_id , lease_id , lease_until ,
325297 ):
326298 continue # concurrent dispatch race; task already taken
@@ -331,19 +303,14 @@ async def _dispatch_once(self) -> bool:
331303 "No push socket for worker %s, requeueing %s" ,
332304 worker_id , row ["task_id" ],
333305 )
334- await self ._db (
335- self .store .nack ,
336- row ["task_id" ], worker_id , lease_id , "no socket" ,
337- )
306+ self .store .nack (row ["task_id" ], worker_id , lease_id , "no socket" )
338307 continue
339308
340- # Include the hub-generated lease_id so the worker can ack with
341- # the exact token. Omitting it was the core correctness bug in v2.
342309 envelope = TaskEnvelope (
343310 task_id = row ["task_id" ],
344311 task_name = row ["task_name" ],
345312 payload_b64 = base64 .b64encode (row ["payload" ]).decode ("ascii" ),
346- labels = json . loads ( row ["labels_json" ]) ,
313+ labels = row ["labels" ] ,
347314 lease_id = lease_id ,
348315 attempts = int (row ["attempts" ]) + 1 ,
349316 max_retries = int (row ["max_retries" ]),
@@ -360,8 +327,7 @@ async def _dispatch_once(self) -> bool:
360327 "Failed to deliver %s to worker %s: %s" ,
361328 row ["task_id" ], worker_id , exc ,
362329 )
363- await self ._db (
364- self .store .nack ,
330+ self .store .nack (
365331 row ["task_id" ], worker_id , lease_id ,
366332 f"dispatch send failed: { exc } " ,
367333 )
@@ -373,11 +339,10 @@ async def _reaper_loop(self) -> None:
373339 while not self ._stop .is_set ():
374340 try :
375341 await asyncio .sleep (self .config .reaper_interval )
376- reaped = await self ._db ( self . store .reap_expired_leases )
342+ reaped = self .store .reap_expired_leases ( )
377343 if reaped :
378344 logger .debug ("Reaped %d expired leases" , reaped )
379- recovered = await self ._db (
380- self .store .recover_dead_workers ,
345+ recovered = self .store .recover_dead_workers (
381346 self .config .heartbeat_timeout ,
382347 )
383348 if recovered :
@@ -400,11 +365,6 @@ def _build_config() -> HubConfig:
400365 default = os .getenv ("NNG_CONTROL_ADDR" , "ipc:///tmp/taskiq-nng.ipc" ),
401366 help = "NNG address the hub listens on. Env: NNG_CONTROL_ADDR" ,
402367 )
403- p .add_argument (
404- "--task-db" ,
405- default = os .getenv ("NNG_TASK_DB" , "/tmp/taskiq-nng-tasks.db" ), # noqa: S108
406- help = "Path to the SQLite WAL task journal. Env: NNG_TASK_DB" ,
407- )
408368 p .add_argument (
409369 "--max-pending" ,
410370 type = int ,
@@ -445,7 +405,6 @@ def _build_config() -> HubConfig:
445405 )
446406 return HubConfig (
447407 control_addr = args .control_addr ,
448- task_db = args .task_db ,
449408 max_pending = args .max_pending ,
450409 heartbeat_timeout = args .heartbeat_timeout ,
451410 lease_timeout = args .lease_timeout ,
0 commit comments