Skip to content

Commit b8c7212

Browse files
committed
Refactoring the NNG support solution v2: Update routing policy
1 parent 7fdb8eb commit b8c7212

4 files changed

Lines changed: 295 additions & 22 deletions

File tree

taskiq/brokers/nng/__init__.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,37 @@
88
WorkerState,
99
WorkerStatus,
1010
)
11-
from .storage import InMemoryStore, QueueFullError, StoreConfig
11+
from .storage import (
12+
InMemoryStore,
13+
LeastLoaded,
14+
PowerOfTwoChoices,
15+
QueueFullError,
16+
RoutingPolicy,
17+
RoundRobin,
18+
StoreConfig,
19+
WorkerView,
20+
make_routing_policy,
21+
)
1222

1323
__all__ = [
1424
"HubConfig",
1525
"NNGHub",
26+
# protocol
1627
"ControlMessage",
1728
"ControlResponse",
1829
"MessageKind",
1930
"TaskEnvelope",
2031
"WorkerState",
2132
"WorkerStatus",
33+
# store
2234
"QueueFullError",
2335
"InMemoryStore",
2436
"StoreConfig",
37+
# routing
38+
"WorkerView",
39+
"RoutingPolicy",
40+
"LeastLoaded",
41+
"PowerOfTwoChoices",
42+
"RoundRobin",
43+
"make_routing_policy",
2544
]

taskiq/brokers/nng/hub.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,13 @@
3737
TaskEnvelope,
3838
WorkerState,
3939
)
40-
from .storage import InMemoryStore, QueueFullError, StoreConfig
40+
from .storage import (
41+
InMemoryStore,
42+
QueueFullError,
43+
RoutingPolicy,
44+
StoreConfig,
45+
make_routing_policy,
46+
)
4147

4248
logger = logging.getLogger(__name__)
4349

@@ -53,7 +59,7 @@ class HubConfig:
5359
lease_timeout: float = 20.0
5460
dispatch_interval: float = 0.05
5561
reaper_interval: float = 0.5
56-
routing_policy: str = "least_loaded"
62+
routing_policy: RoutingPolicy | str = "least_loaded"
5763
backoff_cap: float = 60.0
5864
# Number of concurrent Rep0 contexts. Each context handles one req/rep
5965
# pair independently; N contexts ≈ N simultaneous control-plane clients.
@@ -107,6 +113,9 @@ def __init__(self, config: HubConfig) -> None:
107113
backoff_cap=config.backoff_cap,
108114
),
109115
)
116+
# Resolve once at construction so RoundRobin and similar stateful
117+
# policies maintain their counter across dispatch calls.
118+
self._routing: RoutingPolicy = make_routing_policy(config.routing_policy)
110119
self._stop = asyncio.Event()
111120
self._ctrl_sock: Any = None # pynng.Rep0
112121
self._worker_push: dict[str, Any] = {} # worker_id -> pynng.Push0
@@ -282,7 +291,7 @@ async def _dispatch_once(self) -> bool:
282291
sent_any = False
283292
for row in due:
284293
worker = self.store.choose_worker(
285-
self.config.routing_policy,
294+
self._routing,
286295
heartbeat_timeout=self.config.heartbeat_timeout,
287296
)
288297
if worker is None:
@@ -384,7 +393,7 @@ def _build_config() -> HubConfig:
384393
)
385394
p.add_argument(
386395
"--routing-policy",
387-
choices=["least_loaded", "p2c"],
396+
choices=["least_loaded", "p2c", "round_robin"],
388397
default=os.getenv("NNG_ROUTING_POLICY", "least_loaded"),
389398
)
390399
p.add_argument(

taskiq/brokers/nng/storage.py

Lines changed: 124 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import random
55
import time
66
from dataclasses import dataclass, field
7-
from typing import TYPE_CHECKING, Any
7+
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
88

99
if TYPE_CHECKING:
1010
from .protocol import TaskEnvelope, WorkerState
@@ -103,6 +103,113 @@ def as_dict(self) -> dict[str, Any]:
103103
}
104104

105105

106+
# ── routing policy abstraction ────────────────────────────────────────────────
107+
108+
109+
@dataclass(frozen=True)
class WorkerView:
    """
    Read-only snapshot of one worker, handed to :class:`RoutingPolicy` objects.

    The dataclass is frozen, so a policy cannot mutate the snapshot it is given.
    """

    worker_id: str
    inflight: int
    capacity: int

    @property
    def load(self) -> float:
        """Ratio of inflight to capacity: 0.0 idle → 1.0 at capacity."""
        # Clamp the divisor to at least 1 so a zero or negative capacity
        # never causes a division by zero.
        divisor = self.capacity if self.capacity > 1 else 1
        return self.inflight / divisor
121+
122+
123+
@runtime_checkable
class RoutingPolicy(Protocol):
    """
    Structural interface for worker-selection strategies.

    Any object that exposes a ``choose`` method satisfies the protocol; no
    inheritance is required.
    """

    def choose(self, workers: list[WorkerView]) -> WorkerView | None:
        """
        Pick a dispatch target from ``workers``.

        :param workers: candidate worker snapshots.
        :return: the chosen worker, or None to hold off dispatch.
        """
        ...
130+
131+
132+
class LeastLoaded:
    """Route to whichever worker reports the smallest ``load`` value."""

    def choose(self, workers: list[WorkerView]) -> WorkerView | None:
        """
        Select the worker with the minimum load.

        :param workers: candidate worker snapshots.
        :return: the lowest-load worker (the earliest one wins ties), or
            None when there are no candidates.
        """
        best = None
        best_load = None
        for candidate in workers or ():
            candidate_load = candidate.load
            # Strict "<" keeps the first of several equally loaded workers.
            if best is None or candidate_load < best_load:
                best, best_load = candidate, candidate_load
        return best
140+
141+
142+
class PowerOfTwoChoices:
    """
    Power-of-two-choices routing.

    Draws two distinct workers at random and keeps the one with the lower
    load. Reduces hot-spot probability under high concurrency compared to
    pure random.
    """

    def choose(self, workers: list[WorkerView]) -> WorkerView | None:
        """
        Pick the less loaded of two randomly sampled workers.

        :param workers: candidate worker snapshots.
        :return: None when there are no candidates, the sole candidate when
            there is exactly one, otherwise the less loaded of the two
            sampled workers (the first sampled wins ties).
        """
        count = len(workers) if workers else 0
        if count == 0:
            return None
        if count == 1:
            return workers[0]
        first, second = random.sample(workers, k=2)  # noqa: S311
        if first.load <= second.load:
            return first
        return second
158+
159+
160+
class RoundRobin:
    """
    Round-robin routing — cycle through workers in alphabetical ID order.

    Ignores load; useful when tasks are homogeneous and worker capacity is equal.

    The rotation is tracked by the ID of the last worker handed out rather
    than by a list index. When the candidate set changes between calls
    (workers join, leave, or drop out because they are full), the cycle
    continues with the next ID instead of repeating or skipping a worker.

    State lives on the instance, so callers that share one instance share
    one rotation.
    """

    def __init__(self) -> None:
        """Initialise the rotation state."""
        # Number of selections made so far; kept for introspection and
        # backward compatibility, not used to pick the next worker.
        self._idx: int = 0
        # worker_id returned by the previous call, or None before the first.
        self._last_id: str | None = None

    def choose(self, workers: list[WorkerView]) -> WorkerView | None:
        """
        Return the next worker in the cycle.

        :param workers: candidate worker snapshots, in any order.
        :return: the candidate with the smallest ID greater than the last
            one chosen, wrapping to the smallest ID overall; None when
            there are no candidates.
        """
        if not workers:
            return None
        # Sort locally so the documented alphabetical order holds even when
        # the caller passes an unsorted list.
        ordered = sorted(workers, key=lambda w: w.worker_id)
        last_id = self._last_id
        if last_id is None:
            chosen = ordered[0]
        else:
            # First ID strictly after the previous pick; wrap around if none.
            chosen = next(
                (w for w in ordered if w.worker_id > last_id),
                ordered[0],
            )
        self._last_id = chosen.worker_id
        self._idx += 1
        return chosen
179+
180+
181+
# Shared built-in instances, looked up by name. LeastLoaded and
# PowerOfTwoChoices keep no per-instance state. The RoundRobin entry does, so
# every caller that resolves "round_robin" by name shares one rotation; that is
# fine for single-hub processes. Callers needing an isolated rotation should
# pass their own RoundRobin instance instead of the name.
_BUILTIN_POLICIES: dict[str, RoutingPolicy] = {
    "least_loaded": LeastLoaded(),
    "p2c": PowerOfTwoChoices(),
    "round_robin": RoundRobin(),
}


def make_routing_policy(policy: "RoutingPolicy | str") -> RoutingPolicy:
    """
    Resolve a routing policy name or pass through an instance.

    :param policy: ``'least_loaded'``, ``'p2c'``, ``'round_robin'``, or a
        :class:`RoutingPolicy` instance.
    :return: concrete routing policy. Names resolve to the shared built-in
        instance; instances are returned unchanged.
    :raises ValueError: for unknown string names.
    :raises TypeError: if ``policy`` is neither a string nor an object that
        satisfies :class:`RoutingPolicy`.
    """
    if isinstance(policy, str):
        try:
            return _BUILTIN_POLICIES[policy]
        except KeyError:
            raise ValueError(
                f"Unknown routing policy {policy!r}; "
                f"available: {sorted(_BUILTIN_POLICIES)}"
            ) from None
    # Fail fast here rather than with an AttributeError on the first dispatch
    # when something like None or a config object is passed by mistake.
    if not isinstance(policy, RoutingPolicy):
        raise TypeError(
            "routing policy must be a policy name or an object with a "
            f"choose() method, got {type(policy).__name__}"
        )
    return policy
208+
209+
210+
# ── store ─────────────────────────────────────────────────────────────────────
211+
212+
106213
class InMemoryStore:
107214
"""
108215
Pure in-memory task store for the NNG hub.
@@ -384,17 +491,17 @@ def mark_draining(self, worker_id: str) -> None:
384491

385492
def choose_worker(
386493
self,
387-
routing_policy: str = "least_loaded",
494+
policy: "RoutingPolicy | str" = "least_loaded",
388495
*,
389496
heartbeat_timeout: float = 15.0,
390497
) -> dict[str, Any] | None:
391498
"""
392-
Select the best available worker according to ``routing_policy``.
499+
Select the best available worker using a routing policy.
393500
394-
``'least_loaded'`` picks the worker with the lowest inflight/capacity
395-
ratio. ``'p2c'`` samples two workers and picks the less loaded one.
501+
Accepts a :class:`RoutingPolicy` instance or a string name
502+
(``'least_loaded'``, ``'p2c'``, ``'round_robin'``).
396503
397-
:param routing_policy: ``'least_loaded'`` or ``'p2c'``.
504+
:param policy: routing policy or name.
398505
:param heartbeat_timeout: seconds before a worker is considered stale.
399506
:return: chosen worker dict, or None if no worker has capacity.
400507
"""
@@ -408,14 +515,17 @@ def choose_worker(
408515
]
409516
if not available:
410517
return None
411-
if routing_policy == "p2c" and len(available) >= 2:
412-
a, b = random.sample(available, k=2) # noqa: S311
413-
load_a = a.inflight / max(a.capacity, 1)
414-
load_b = b.inflight / max(b.capacity, 1)
415-
chosen = a if load_a <= load_b else b
416-
else:
417-
chosen = min(available, key=lambda w: w.inflight / max(w.capacity, 1))
418-
return chosen.as_dict()
518+
# Stable sort so RoundRobin cycles in a predictable, deterministic order.
519+
views = sorted(
520+
[WorkerView(w.worker_id, w.inflight, w.capacity) for w in available],
521+
key=lambda v: v.worker_id,
522+
)
523+
routing = make_routing_policy(policy)
524+
chosen = routing.choose(views)
525+
if chosen is None:
526+
return None
527+
worker = self._workers.get(chosen.worker_id)
528+
return worker.as_dict() if worker is not None else None
419529

420530
# ── observability ─────────────────────────────────────────────────────────
421531

0 commit comments

Comments
 (0)