Skip to content

Commit 714264b

Browse files
committed
init
1 parent e9ff12f commit 714264b

2 files changed

Lines changed: 714 additions & 0 deletions

File tree

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
"""Assign WhisperX diarization speakers to participant identities.
2+
3+
Uses per-stream VAD events to match generic SPEAKER_XX labels provided
4+
by diarization to real user IDs by computing time overlap between
5+
diarization segments and VAD intervals.
6+
7+
Multiple speakers can map to the same participant (e.g. two people sharing
8+
one microphone). A participant with no matching speaker gets no assignment.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
import logging
14+
from dataclasses import dataclass, field
15+
from datetime import datetime
16+
from typing import Any
17+
18+
logger = logging.getLogger(__name__)
19+
20+
# A speaker is only matched to a participant when at least this fraction of
# the speaker's total speaking time overlaps that participant's VAD activity.
DEFAULT_OVERLAP_THRESHOLD = 0.5
23+
24+
25+
@dataclass
class Interval:
    """A span of time, in seconds, measured from the recording start."""

    # Offset in seconds at which the span begins.
    start: float
    # Offset in seconds at which the span ends.
    end: float
31+
32+
33+
@dataclass
class SpeakerAssignment:
    """Link between one diarization speaker label and one participant."""

    # Generic label produced by diarization (e.g. "SPEAKER_00").
    speaker_label: str
    # Identifier of the participant the label was matched to.
    participant_id: str
    # Display name of that participant.
    participant_name: str
    # Score under which the match was accepted.
    score: float
41+
42+
43+
@dataclass
class AssignmentResult:
    """Result of speaker-to-participant assignment."""

    # Accepted speaker -> participant matches.
    assignments: list[SpeakerAssignment] = field(default_factory=list)
    # Diarization labels for which no match was accepted.
    unassigned_speakers: list[str] = field(default_factory=list)

    def apply(self, diarization: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of diarization with speaker labels replaced by names.

        Replaces `"speaker"` fields in segments, in the `words` nested inside
        segments, and in word_segments with the assigned participant name.
        Unassigned speakers are left as-is.

        The input is never mutated: every segment / word dict is shallow-copied
        before it is changed. Values under other top-level keys are passed
        through by reference.

        Args:
            diarization: WhisperX dict with `segments` and optionally
                `word_segments`.

        Returns:
            New dict with speaker labels replaced.
        """
        speaker_to_name = {
            a.speaker_label: a.participant_name for a in self.assignments
        }

        def _replace_speaker(item: dict[str, Any]) -> dict[str, Any]:
            # Always copy, even when nothing is replaced: the caller may go
            # on to set "words" on the result, and that must not write
            # through to the caller's original segment dict.
            relabelled = dict(item)
            if "speaker" in item and item["speaker"] in speaker_to_name:
                relabelled["speaker"] = speaker_to_name[item["speaker"]]
            return relabelled

        result: dict[str, Any] = {}
        for key, value in diarization.items():
            if key not in ("segments", "word_segments"):
                result[key] = value
                continue

            new_items = []
            for item in value:
                new_item = _replace_speaker(item)
                # Only segments carry a nested per-word list.
                if key == "segments" and "words" in item:
                    new_item["words"] = [
                        _replace_speaker(word) for word in item["words"]
                    ]
                new_items.append(new_item)
            result[key] = new_items
        return result
85+
86+
87+
def _merge_intervals(intervals: list[Interval]) -> list[Interval]:
    """Coalesce intervals into a start-ordered list with no overlaps.

    Intervals that overlap, or that merely touch end-to-start, are fused
    into one. The returned list holds fresh ``Interval`` objects; the ones
    passed in are not modified.
    """
    merged: list[Interval] = []
    for current in sorted(intervals, key=lambda iv: iv.start):
        if merged and current.start <= merged[-1].end:
            # Overlaps or touches the last kept interval: extend it if needed.
            last = merged[-1]
            if current.end > last.end:
                last.end = current.end
        else:
            merged.append(Interval(current.start, current.end))
    return merged
101+
102+
103+
def _total_duration(intervals: list[Interval]) -> float:
104+
"""Return the sum of all interval durations."""
105+
return sum(interval.end - interval.start for interval in intervals)
106+
107+
108+
def _overlap_duration(
109+
a_intervals: list[Interval],
110+
b_intervals: list[Interval],
111+
) -> float:
112+
"""Compute total overlap between two sorted, merged interval lists.
113+
114+
Uses a sweep-line approach in O(n + m).
115+
"""
116+
overlap = 0.0
117+
i = j = 0
118+
while i < len(a_intervals) and j < len(b_intervals):
119+
a = a_intervals[i]
120+
b = b_intervals[j]
121+
lo = max(a.start, b.start)
122+
hi = min(a.end, b.end)
123+
if lo < hi:
124+
overlap += hi - lo
125+
if a.end <= b.end:
126+
i += 1
127+
else:
128+
j += 1
129+
return overlap
130+
131+
132+
def _parse_iso(ts: str) -> datetime:
133+
"""Parse an ISO-formatted timestamp string."""
134+
return datetime.fromisoformat(ts)
135+
136+
137+
def _build_participant_timelines(
    metadata: dict[str, Any],
    recording_start: str,
) -> tuple[dict[str, list[Interval]], dict[str, str]]:
    """Turn per-participant VAD events into merged speech timelines.

    Each `speech_start` event opens a span for its participant and the next
    `speech_end` for that participant closes it. A `speech_end` with no open
    span is ignored, as is a span that is still open when the events run
    out. A repeated `speech_start` replaces the pending start time.

    Args:
        metadata: Dict with `events` and `participants` keys.
        recording_start: ISO timestamp used as t=0 reference.

    Returns:
        Tuple of (participant_id → intervals, participant_id → name).
        Intervals are in seconds relative to recording_start and are merged
        per participant. Timestamps before recording start are clamped to 0;
        spans that end up empty after clamping are dropped.
    """
    events = metadata.get("events", [])

    # Fall back to the participant id when no display name is provided.
    names: dict[str, str] = {}
    for participant in metadata.get("participants", []):
        participant_id = participant["participantId"]
        names[participant_id] = participant.get("name", participant_id)

    origin = _parse_iso(recording_start).timestamp()

    pending: dict[str, float] = {}
    raw: dict[str, list[Interval]] = {}

    for event in events:
        pid = event["participant_id"]
        offset = _parse_iso(event["timestamp"]).timestamp() - origin
        kind = event["type"]

        if kind == "speech_start":
            pending[pid] = max(offset, 0.0)
            continue
        if kind != "speech_end" or pid not in pending:
            continue

        began = pending.pop(pid)
        ended = max(offset, 0.0)
        if ended > began:
            raw.setdefault(pid, []).append(Interval(began, ended))

    timelines = {pid: _merge_intervals(spans) for pid, spans in raw.items()}
    return timelines, names
181+
182+
183+
def _build_speaker_timelines(
    diarization: dict[str, Any],
) -> dict[str, list[Interval]]:
    """Group WhisperX diarization segments into one timeline per speaker.

    Segments with no `speaker` value are skipped. Each speaker's intervals
    are merged before being returned.
    """
    grouped: dict[str, list[Interval]] = {}

    for seg in diarization.get("segments", []):
        label = seg.get("speaker")
        if label is not None:
            if label not in grouped:
                grouped[label] = []
            grouped[label].append(Interval(seg["start"], seg["end"]))

    return {label: _merge_intervals(spans) for label, spans in grouped.items()}
201+
202+
203+
def assign_speakers(
    metadata: dict[str, Any],
    diarization: dict[str, Any],
    recording_start_timestamp: str,
    overlap_threshold: float = DEFAULT_OVERLAP_THRESHOLD,
) -> AssignmentResult:
    """Pair each WhisperX speaker label with the best-matching participant.

    For every speaker, the score against a participant is the overlap
    between their timelines divided by the speaker's total duration. The
    highest-scoring participant wins, provided the score reaches
    `overlap_threshold`. Several speakers may resolve to the same
    participant.

    Args:
        metadata: User metadata with `events` and `participants`.
        diarization: WhisperX JSON output containing `segments`.
        recording_start_timestamp: ISO timestamp for t=0 reference.
        overlap_threshold: Minimum overlap/speaker_duration to accept.

    Returns:
        AssignmentResult with per-speaker assignments and unassigned
        speakers.
    """
    participant_timelines, participant_names = _build_participant_timelines(
        metadata, recording_start_timestamp
    )
    speaker_timelines = _build_speaker_timelines(diarization)

    outcome = AssignmentResult()

    for label, label_spans in speaker_timelines.items():
        spoken = _total_duration(label_spans)
        if spoken == 0:
            # Nothing to score against; avoids dividing by zero below.
            outcome.unassigned_speakers.append(label)
            continue

        # Strict ">" means the first participant wins ties and a zero
        # overlap never selects anyone.
        winner: str | None = None
        top_score: float = 0.0
        for pid, pid_spans in participant_timelines.items():
            candidate = _overlap_duration(label_spans, pid_spans) / spoken
            if candidate > top_score:
                winner, top_score = pid, candidate

        accepted = winner is not None and top_score >= overlap_threshold
        if not accepted:
            outcome.unassigned_speakers.append(label)
            logger.info(
                "Speaker %s unassigned (best=%.3f, threshold=%.3f)",
                label,
                top_score,
                overlap_threshold,
            )
            continue

        display_name = participant_names.get(winner, winner)
        outcome.assignments.append(
            SpeakerAssignment(
                speaker_label=label,
                participant_id=winner,
                participant_name=display_name,
                score=top_score,
            )
        )
        logger.info(
            "Assigned %s -> %s (score=%.3f)",
            label,
            display_name,
            top_score,
        )

    return outcome

0 commit comments

Comments
 (0)