@@ -174,26 +174,60 @@ def compute_group_distances(
174174 return result
175175
176176
def build_dimension_weights(weights: dict[str, float]) -> np.ndarray:
    """
    Turn a mapping of per-group weights into one weight per vector dimension.

    The result depends only on ``weights``, so a caller scoring many pairs in a
    hot loop (e.g. MMR re-ranking) can build it once and hand the same array to
    every compute_weighted_distance_vec call instead of rebuilding it each time.

    :param weights: Weight overrides keyed by FEATURE_GROUPS name. A group with
        no entry keeps the default weight of 1.0; keys that are not
        FEATURE_GROUPS names are ignored.
    :returns: Float64 weight array of length VECTOR_DIMENSIONS.
    """
    per_dim = np.ones(VECTOR_DIMENSIONS, dtype=np.float64)
    for name, (lo, hi) in FEATURE_GROUPS.items():
        if name not in weights:
            # No override for this group: its slice stays at the default 1.0.
            continue
        # Broadcast the group's single weight over every dimension it owns.
        per_dim[lo:hi] = weights[name]
    return per_dim
193+
194+
def compute_weighted_distance(
    sig_a: list[float] | np.ndarray,
    sig_b: list[float] | np.ndarray,
    weights: dict[str, float],
) -> float:
    """
    Return the group-weighted Euclidean distance between two feature vectors.

    Convenience wrapper: expands ``weights`` into a per-dimension vector with
    build_dimension_weights and delegates to compute_weighted_distance_vec.
    The expansion is redone on every call.

    :param sig_a: First feature vector (list or numpy array).
    :param sig_b: Second feature vector (list or numpy array).
    :param weights: Per-group weight overrides keyed by FEATURE_GROUPS name.
    :returns: Weighted normalized distance as a float.
    """
    per_dim_weights = build_dimension_weights(weights)
    return compute_weighted_distance_vec(sig_a, sig_b, per_dim_weights)
209+
210+
211+ def compute_weighted_distance_vec (
212+ sig_a : list [float ] | np .ndarray ,
213+ sig_b : list [float ] | np .ndarray ,
214+ dim_weights : np .ndarray ,
215+ ) -> float :
216+ """
217+ Compute weighted Euclidean distance from a precomputed per-dimension weight vector.
218+
219+ :param sig_a: First feature vector (list or numpy array).
220+ :param sig_b: Second feature vector (list or numpy array).
221+ :param dim_weights: Per-dimension weights as built by build_dimension_weights.
222+ :returns: Weighted normalized distance as a float.
223+ """
224+ total_weighted_dims = float (dim_weights .sum ())
225+ if total_weighted_dims == 0.0 :
226+ return 0.0
189227 # np.asarray is a no-op when the caller already holds a float64 array (the
190228 # MMR hot path), avoiding the list round-trip the previous version forced.
191229 a = np .asarray (sig_a , dtype = np .float64 )
192230 b = np .asarray (sig_b , dtype = np .float64 )
193- dim_weights = _expand_group_weights (weights )
194- total_weighted_dims = float (dim_weights .sum ())
195- if total_weighted_dims == 0.0 :
196- return 0.0
197231 diff = a - b
198232 weighted_sq_sum = float (np .dot (dim_weights , diff * diff ))
199233 return math .sqrt (weighted_sq_sum / total_weighted_dims )
@@ -222,18 +256,3 @@ def build_debug_breakdown(
222256 for k , v in compute_group_distances (seed_normalized , cand_normalized ).items ()
223257 },
224258 }
225-
226-
227- def _expand_group_weights (weights : dict [str , float ]) -> np .ndarray :
228- """
229- Expand per-group weights into a per-dimension weight vector.
230-
231- :param weights: Per-group weight overrides keyed by FEATURE_GROUPS name. Groups
232- absent from the dict default to weight 1.0.
233- :returns: Float64 weight array of length VECTOR_DIMENSIONS.
234- """
235- dim_weights = np .ones (VECTOR_DIMENSIONS , dtype = np .float64 )
236- for group , (start , end ) in FEATURE_GROUPS .items ():
237- if group in weights :
238- dim_weights [start :end ] = weights [group ]
239- return dim_weights
0 commit comments