44import random
55import time
66from dataclasses import dataclass , field
7- from typing import TYPE_CHECKING , Any
7+ from typing import TYPE_CHECKING , Any , Protocol , runtime_checkable
88
99if TYPE_CHECKING :
1010 from .protocol import TaskEnvelope , WorkerState
@@ -103,6 +103,113 @@ def as_dict(self) -> dict[str, Any]:
103103 }
104104
105105
106+ # ── routing policy abstraction ────────────────────────────────────────────────
107+
108+
@dataclass(frozen=True)
class WorkerView:
    """
    Read-only snapshot of one worker, handed to :class:`RoutingPolicy` objects.

    Instances are frozen, so a policy cannot modify the values it is given.
    """

    worker_id: str
    inflight: int
    capacity: int

    @property
    def load(self) -> float:
        """Return ``inflight`` divided by ``capacity`` (capacity floored at 1)."""
        # Flooring the divisor at 1 keeps a zero-capacity worker from raising
        # ZeroDivisionError.
        effective_capacity = max(self.capacity, 1)
        return self.inflight / effective_capacity
121+
122+
@runtime_checkable
class RoutingPolicy(Protocol):
    """
    Structural interface for objects that pick a dispatch target.

    Any object exposing a ``choose`` method satisfies this protocol; no
    inheritance is required.
    """

    def choose(self, workers: list[WorkerView]) -> WorkerView | None:
        """
        Pick one worker from ``workers``.

        :param workers: candidate worker snapshots.
        :return: the selected worker, or None to hold off dispatch.
        """
        ...
130+
131+
class LeastLoaded:
    """Routing policy that favours the worker with the smallest ``load``."""

    def choose(self, workers: list[WorkerView]) -> WorkerView | None:
        """
        Return the worker whose ``load`` is lowest.

        :param workers: candidate worker snapshots.
        :return: the minimum-load worker (the earliest one on ties), or None
            when ``workers`` is empty.
        """
        # ``default`` covers the empty-list case without a separate guard.
        return min(workers, key=lambda candidate: candidate.load, default=None)
140+
141+
142+ class PowerOfTwoChoices :
143+ """
144+ Power-of-two-choices routing.
145+
146+ Samples two workers uniformly at random and returns the less loaded one.
147+ Reduces hot-spot probability under high concurrency compared to pure random.
148+ """
149+
150+ def choose (self , workers : list [WorkerView ]) -> WorkerView | None :
151+ """Return the less loaded of two randomly sampled workers."""
152+ if not workers :
153+ return None
154+ if len (workers ) == 1 :
155+ return workers [0 ]
156+ a , b = random .sample (workers , k = 2 ) # noqa: S311
157+ return a if a .load <= b .load else b
158+
159+
class RoundRobin:
    """
    Round-robin routing — cycle through workers in alphabetical ID order.

    Ignores load; useful when tasks are homogeneous and worker capacity is equal.
    The counter lives on the instance, so every caller that shares one instance
    advances the same cycle. Callers that need an independent cycle must create
    and pass their own instance.
    """

    def __init__(self) -> None:
        """Initialise the cycle counter."""
        self._idx: int = 0

    def choose(self, workers: list[WorkerView]) -> WorkerView | None:
        """
        Return the next worker in the cycle.

        :param workers: candidate worker snapshots, in any order.
        :return: the worker at the current cycle position after ordering by
            ``worker_id``, or None when ``workers`` is empty.
        """
        if not workers:
            return None
        # Order here rather than trusting the caller, so the documented
        # alphabetical cycle holds for any input order. The input list is not
        # modified.
        ordered = sorted(workers, key=lambda view: view.worker_id)
        # The modulo keeps the position valid when the number of candidates
        # changes between calls.
        chosen = ordered[self._idx % len(ordered)]
        self._idx += 1
        return chosen
179+
180+
# Shared instances returned for string policy names. LeastLoaded and
# PowerOfTwoChoices keep no state; the RoundRobin entry carries one counter for
# the whole process, which is fine when a process runs a single hub. Callers
# that need an isolated counter should pass their own RoundRobin instance.
_BUILTIN_POLICIES: dict[str, RoutingPolicy] = {
    name: policy_cls()
    for name, policy_cls in (
        ("least_loaded", LeastLoaded),
        ("p2c", PowerOfTwoChoices),
        ("round_robin", RoundRobin),
    )
}
188+
189+
def make_routing_policy(policy: "RoutingPolicy | str") -> RoutingPolicy:
    """
    Resolve a routing policy name or pass through an instance.

    :param policy: ``'least_loaded'``, ``'p2c'``, ``'round_robin'``, or a
        :class:`RoutingPolicy` instance.
    :return: concrete routing policy. A name resolves to the shared
        module-level instance registered under it; an instance is returned
        unchanged.
    :raises ValueError: for unknown string names.
    :raises TypeError: if ``policy`` is neither a string nor an object with a
        callable ``choose`` attribute.
    """
    if isinstance(policy, str):
        resolved = _BUILTIN_POLICIES.get(policy)
        if resolved is None:
            raise ValueError(
                f"Unknown routing policy {policy!r}; "
                f"available: {sorted(_BUILTIN_POLICIES)}"
            )
        return resolved
    # Fail fast on bad input (e.g. None) instead of surfacing later as an
    # AttributeError at dispatch time. The check is duck-typed rather than
    # isinstance(policy, RoutingPolicy) so objects that provide ``choose``
    # dynamically (mocks, proxies) are still accepted.
    if not callable(getattr(policy, "choose", None)):
        raise TypeError(
            "routing policy must be a policy name or an object with a "
            f"callable 'choose' method; got {type(policy).__name__}"
        )
    return policy
208+
209+
210+ # ── store ─────────────────────────────────────────────────────────────────────
211+
212+
106213class InMemoryStore :
107214 """
108215 Pure in-memory task store for the NNG hub.
@@ -384,17 +491,17 @@ def mark_draining(self, worker_id: str) -> None:
384491
385492 def choose_worker (
386493 self ,
387- routing_policy : str = "least_loaded" ,
494+ policy : "RoutingPolicy | str" = "least_loaded" ,
388495 * ,
389496 heartbeat_timeout : float = 15.0 ,
390497 ) -> dict [str , Any ] | None :
391498 """
392- Select the best available worker according to ``routing_policy`` .
499+ Select the best available worker using a routing policy .
393500
394- ``'least_loaded'`` picks the worker with the lowest inflight/capacity
395- ratio. ``'p2c'`` samples two workers and picks the less loaded one .
501+ Accepts a :class:`RoutingPolicy` instance or a string name
502+ (``'least_loaded'``, ``'p2c'``, ``'round_robin'``) .
396503
397- :param routing_policy: ``'least_loaded'`` or ``'p2c'`` .
504+ :param policy: routing policy or name .
398505 :param heartbeat_timeout: seconds before a worker is considered stale.
399506 :return: chosen worker dict, or None if no worker has capacity.
400507 """
@@ -408,14 +515,17 @@ def choose_worker(
408515 ]
409516 if not available :
410517 return None
411- if routing_policy == "p2c" and len (available ) >= 2 :
412- a , b = random .sample (available , k = 2 ) # noqa: S311
413- load_a = a .inflight / max (a .capacity , 1 )
414- load_b = b .inflight / max (b .capacity , 1 )
415- chosen = a if load_a <= load_b else b
416- else :
417- chosen = min (available , key = lambda w : w .inflight / max (w .capacity , 1 ))
418- return chosen .as_dict ()
518+ # Stable sort so RoundRobin cycles in a predictable, deterministic order.
519+ views = sorted (
520+ [WorkerView (w .worker_id , w .inflight , w .capacity ) for w in available ],
521+ key = lambda v : v .worker_id ,
522+ )
523+ routing = make_routing_policy (policy )
524+ chosen = routing .choose (views )
525+ if chosen is None :
526+ return None
527+ worker = self ._workers .get (chosen .worker_id )
528+ return worker .as_dict () if worker is not None else None
419529
420530 # ── observability ─────────────────────────────────────────────────────────
421531
0 commit comments