|
"""Assign WhisperX diarization speakers to participant identities.

Uses per-stream VAD events to match generic SPEAKER_XX labels provided
by diarization to real user IDs by computing time overlap between
diarization segments and VAD intervals.

Multiple speakers can map to the same participant (e.g. two people sharing
one microphone). A participant with no matching speaker gets no assignment.
"""
| 10 | + |
| 11 | +from __future__ import annotations |
| 12 | + |
| 13 | +import logging |
| 14 | +from dataclasses import dataclass, field |
| 15 | +from datetime import datetime |
| 16 | +from typing import Any |
| 17 | + |
# Module-level logger; handler/level configuration is left to the application.
logger = logging.getLogger(__name__)

# Minimum fraction of a speaker's total duration that must overlap with a
# participant's VAD to accept the assignment.
DEFAULT_OVERLAP_THRESHOLD = 0.5
| 23 | + |
| 24 | + |
@dataclass
class Interval:
    """A time interval in seconds relative to recording start."""

    # Start time in seconds from the recording's t=0 reference.
    start: float
    # End time in seconds from the recording's t=0 reference.
    end: float
| 31 | + |
| 32 | + |
@dataclass
class SpeakerAssignment:
    """Maps a diarization speaker label to a participant."""

    # Generic diarization label, e.g. "SPEAKER_00".
    speaker_label: str
    # Matched participant's stable identifier.
    participant_id: str
    # Matched participant's display name (falls back to the id upstream).
    participant_name: str
    # Overlap score (overlap duration / speaker duration) that won the match.
    score: float
| 41 | + |
| 42 | + |
@dataclass
class AssignmentResult:
    """Result of speaker-to-participant assignment."""

    # Accepted speaker-to-participant mappings.
    assignments: list[SpeakerAssignment] = field(default_factory=list)
    # Diarization labels that could not be matched to any participant.
    unassigned_speakers: list[str] = field(default_factory=list)

    def apply(self, diarization: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of diarization with speaker labels replaced by names.

        Replaces `"speaker"` fields in segments and word_segments with the
        assigned participant name. Unassigned speakers are left as-is.
        The input dict and its nested segment dicts are never mutated.

        Args:
            diarization: WhisperX dict with `segments` and optionally
                `word_segments`.

        Returns:
            New dict with speaker labels replaced.
        """
        speaker_to_name = {
            a.speaker_label: a.participant_name for a in self.assignments
        }

        def _replace_speaker(item: dict[str, Any]) -> dict[str, Any]:
            # Returns a shallow copy with the name substituted, or the
            # original item untouched when there is nothing to replace.
            if "speaker" in item and item["speaker"] in speaker_to_name:
                return {**item, "speaker": speaker_to_name[item["speaker"]]}
            return item

        result: dict[str, Any] = {}
        for key, value in diarization.items():
            if key in ("segments", "word_segments"):
                new_items = []
                for item in value:
                    new_item = _replace_speaker(item)
                    if key == "segments" and "words" in item:
                        # Bug fix: when the segment's own speaker was not
                        # assigned, _replace_speaker returned the input item
                        # itself; writing "words" on it would mutate the
                        # caller's diarization dict. Copy first.
                        if new_item is item:
                            new_item = dict(item)
                        new_item["words"] = [
                            _replace_speaker(w) for w in item["words"]
                        ]
                    new_items.append(new_item)
                result[key] = new_items
            else:
                result[key] = value
        return result
| 85 | + |
| 86 | + |
def _merge_intervals(intervals: list[Interval]) -> list[Interval]:
    """Collapse a list of intervals into sorted, non-overlapping ones.

    Intervals that touch or overlap are merged into a single interval.
    """
    if not intervals:
        return []
    ordered = sorted(intervals, key=lambda iv: iv.start)
    result: list[Interval] = []
    for iv in ordered:
        if result and iv.start <= result[-1].end:
            # Extends (or is contained in) the current run: widen it.
            last = result[-1]
            result[-1] = Interval(last.start, max(last.end, iv.end))
        else:
            # Gap before this interval: start a new run.
            result.append(Interval(iv.start, iv.end))
    return result
| 101 | + |
| 102 | + |
| 103 | +def _total_duration(intervals: list[Interval]) -> float: |
| 104 | + """Return the sum of all interval durations.""" |
| 105 | + return sum(interval.end - interval.start for interval in intervals) |
| 106 | + |
| 107 | + |
| 108 | +def _overlap_duration( |
| 109 | + a_intervals: list[Interval], |
| 110 | + b_intervals: list[Interval], |
| 111 | +) -> float: |
| 112 | + """Compute total overlap between two sorted, merged interval lists. |
| 113 | +
|
| 114 | + Uses a sweep-line approach in O(n + m). |
| 115 | + """ |
| 116 | + overlap = 0.0 |
| 117 | + i = j = 0 |
| 118 | + while i < len(a_intervals) and j < len(b_intervals): |
| 119 | + a = a_intervals[i] |
| 120 | + b = b_intervals[j] |
| 121 | + lo = max(a.start, b.start) |
| 122 | + hi = min(a.end, b.end) |
| 123 | + if lo < hi: |
| 124 | + overlap += hi - lo |
| 125 | + if a.end <= b.end: |
| 126 | + i += 1 |
| 127 | + else: |
| 128 | + j += 1 |
| 129 | + return overlap |
| 130 | + |
| 131 | + |
| 132 | +def _parse_iso(ts: str) -> datetime: |
| 133 | + """Parse an ISO-formatted timestamp string.""" |
| 134 | + return datetime.fromisoformat(ts) |
| 135 | + |
| 136 | + |
def _build_participant_timelines(
    metadata: dict[str, Any],
    recording_start: str,
) -> tuple[dict[str, list[Interval]], dict[str, str]]:
    """Build VAD interval timelines for each participant.

    Args:
        metadata: Dict with `events` and `participants` keys.
        recording_start: ISO timestamp used as t=0 reference.

    Returns:
        Tuple of (participant_id → intervals, participant_id → name).
        Intervals are in seconds relative to recording_start.
        Events before recording start are clamped to 0.
    """
    name_by_pid = {
        entry["participantId"]: entry.get("name", entry["participantId"])
        for entry in metadata.get("participants", [])
    }

    zero_epoch = _parse_iso(recording_start).timestamp()

    # speech_start times awaiting their matching speech_end, per participant.
    pending_starts: dict[str, float] = {}
    timelines: dict[str, list[Interval]] = {}

    for event in metadata.get("events", []):
        participant = event["participant_id"]
        offset = _parse_iso(event["timestamp"]).timestamp() - zero_epoch
        kind = event["type"]

        if kind == "speech_start":
            pending_starts[participant] = max(offset, 0.0)
            continue
        if kind != "speech_end":
            continue

        begin = pending_starts.pop(participant, None)
        if begin is None:
            # speech_end without a prior speech_start: ignore.
            continue
        finish = max(offset, 0.0)
        if finish > begin:
            timelines.setdefault(participant, []).append(
                Interval(begin, finish)
            )

    merged = {
        pid: _merge_intervals(pid_intervals)
        for pid, pid_intervals in timelines.items()
    }
    return merged, name_by_pid
| 181 | + |
| 182 | + |
def _build_speaker_timelines(
    diarization: dict[str, Any],
) -> dict[str, list[Interval]]:
    """Build interval timelines from WhisperX diarization segments."""
    raw: dict[str, list[Interval]] = {}

    for seg in diarization.get("segments", []):
        label = seg.get("speaker")
        if label is None:
            # Segment was never attributed to a speaker; skip it.
            continue
        raw.setdefault(label, []).append(Interval(seg["start"], seg["end"]))

    return {
        label: _merge_intervals(label_intervals)
        for label, label_intervals in raw.items()
    }
| 201 | + |
| 202 | + |
def assign_speakers(
    metadata: dict[str, Any],
    diarization: dict[str, Any],
    recording_start_timestamp: str,
    overlap_threshold: float = DEFAULT_OVERLAP_THRESHOLD,
) -> AssignmentResult:
    """Match WhisperX speaker labels to participants.

    Args:
        metadata: User metadata with `events` and `participants`.
        diarization: WhisperX JSON output containing `segments`.
        recording_start_timestamp: ISO timestamp for t=0 reference.
        overlap_threshold: Minimum overlap/speaker_duration to accept.

    Returns:
        AssignmentResult with per-speaker assignments and unassigned
        speakers.
    """
    vad_timelines, names = _build_participant_timelines(
        metadata, recording_start_timestamp
    )
    diar_timelines = _build_speaker_timelines(diarization)

    outcome = AssignmentResult()

    for label, label_intervals in diar_timelines.items():
        total = _total_duration(label_intervals)
        if total == 0:
            # Zero-length speaker: no meaningful overlap ratio exists.
            outcome.unassigned_speakers.append(label)
            continue

        # Find the participant whose VAD best covers this speaker's speech.
        top_pid: str | None = None
        top_score = 0.0
        for pid, vad_intervals in vad_timelines.items():
            ratio = _overlap_duration(label_intervals, vad_intervals) / total
            if ratio > top_score:
                top_pid, top_score = pid, ratio

        if top_pid is None or top_score < overlap_threshold:
            outcome.unassigned_speakers.append(label)
            logger.info(
                "Speaker %s unassigned (best=%.3f, threshold=%.3f)",
                label,
                top_score,
                overlap_threshold,
            )
            continue

        display_name = names.get(top_pid, top_pid)
        outcome.assignments.append(
            SpeakerAssignment(
                speaker_label=label,
                participant_id=top_pid,
                participant_name=display_name,
                score=top_score,
            )
        )
        logger.info(
            "Assigned %s -> %s (score=%.3f)",
            label,
            display_name,
            top_score,
        )

    return outcome
0 commit comments