mazesmazes committed on
Commit eda17fb · verified · 1 Parent(s): 14dee2f

Update custom model files, README, and requirements

Files changed (3)
  1. alignment.py +292 -0
  2. asr_modeling.py +2 -0
  3. asr_pipeline.py +3 -296
alignment.py ADDED
@@ -0,0 +1,292 @@
+ """Forced alignment for word-level timestamps using Wav2Vec2."""
+
+ import numpy as np
+ import torch
+
+
+ def _get_device() -> str:
+     """Get best available device for non-transformers models."""
+     if torch.cuda.is_available():
+         return "cuda"
+     if torch.backends.mps.is_available():
+         return "mps"
+     return "cpu"
+
+
+ class ForcedAligner:
+     """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.
+
+     Uses Viterbi trellis algorithm for optimal alignment path finding.
+     """
+
+     _bundle = None
+     _model = None
+     _labels = None
+     _dictionary = None
+
+     @classmethod
+     def get_instance(cls, device: str = "cuda"):
+         """Get or create the forced alignment model (singleton).
+
+         Args:
+             device: Device to run model on ("cuda" or "cpu")
+
+         Returns:
+             Tuple of (model, labels, dictionary)
+         """
+         if cls._model is None:
+             import torchaudio
+
+             cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
+             cls._model = cls._bundle.get_model().to(device)
+             cls._model.eval()
+             cls._labels = cls._bundle.get_labels()
+             cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
+         return cls._model, cls._labels, cls._dictionary
+
+     @staticmethod
+     def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
+         """Build trellis for forced alignment using forward algorithm.
+
+         The trellis[t, j] represents the log probability of the best path that
+         aligns the first j tokens to the first t frames.
+
+         Args:
+             emission: Log-softmax emission matrix of shape (num_frames, num_classes)
+             tokens: List of target token indices
+             blank_id: Index of the blank/CTC token (default 0)
+
+         Returns:
+             Trellis matrix of shape (num_frames + 1, num_tokens + 1)
+         """
+         num_frames = emission.size(0)
+         num_tokens = len(tokens)
+
+         trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
+         trellis[0, 0] = 0
+
+         for t in range(num_frames):
+             for j in range(num_tokens + 1):
+                 # Stay: emit blank and stay at j tokens
+                 stay = trellis[t, j] + emission[t, blank_id]
+
+                 # Move: emit token j and advance to j+1 tokens
+                 move = trellis[t, j - 1] + emission[t, tokens[j - 1]] if j > 0 else -float("inf")
+
+                 trellis[t + 1, j] = max(stay, move)  # Viterbi: take best path
+
+         return trellis
+
+     @staticmethod
+     def _backtrack(
+         trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
+     ) -> list[tuple[int, float, float]]:
+         """Backtrack through trellis to find optimal forced monotonic alignment.
+
+         Guarantees:
+         - All tokens are emitted exactly once
+         - Strictly monotonic: each token's frames come after previous token's
+         - No frame skipping or token teleporting
+
+         Returns list of (token_id, start_frame, end_frame) for each token.
+         """
+         num_frames = emission.size(0)
+         num_tokens = len(tokens)
+
+         if num_tokens == 0:
+             return []
+
+         # Find the best ending point (should be at num_tokens)
+         # But verify trellis reached a valid state
+         if trellis[num_frames, num_tokens] == -float("inf"):
+             # Alignment failed - fall back to uniform distribution
+             frames_per_token = num_frames / num_tokens
+             return [
+                 (tokens[i], i * frames_per_token, (i + 1) * frames_per_token)
+                 for i in range(num_tokens)
+             ]
+
+         # Backtrack: find where each token transition occurred
+         # path[i] = frame where token i was first emitted
+         token_frames: list[list[int]] = [[] for _ in range(num_tokens)]
+
+         t = num_frames
+         j = num_tokens
+
+         while t > 0 and j > 0:
+             # Check: did we transition from j-1 to j at frame t-1?
+             stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
+             move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
+
+             if move_score >= stay_score:
+                 # Token j-1 was emitted at frame t-1
+                 token_frames[j - 1].insert(0, t - 1)
+                 j -= 1
+             # Always decrement time (monotonic)
+             t -= 1
+
+         # Handle any remaining tokens at the start (edge case)
+         while j > 0:
+             token_frames[j - 1].insert(0, 0)
+             j -= 1
+
+         # Convert to spans with sub-frame interpolation
+         token_spans: list[tuple[int, float, float]] = []
+         for token_idx, frames in enumerate(token_frames):
+             if not frames:
+                 # Token never emitted - assign minimal span after previous
+                 if token_spans:
+                     prev_end = token_spans[-1][2]
+                     frames = [int(prev_end)]
+                 else:
+                     frames = [0]
+
+             token_id = tokens[token_idx]
+             frame_probs = emission[frames, token_id]
+             peak_idx = int(torch.argmax(frame_probs).item())
+             peak_frame = frames[peak_idx]
+
+             # Sub-frame interpolation using quadratic fit around peak
+             if len(frames) >= 3 and 0 < peak_idx < len(frames) - 1:
+                 y0 = frame_probs[peak_idx - 1].item()
+                 y1 = frame_probs[peak_idx].item()
+                 y2 = frame_probs[peak_idx + 1].item()
+
+                 denom = y0 - 2 * y1 + y2
+                 if abs(denom) > 1e-10:
+                     offset = 0.5 * (y0 - y2) / denom
+                     offset = max(-0.5, min(0.5, offset))
+                 else:
+                     offset = 0.0
+                 refined_frame = peak_frame + offset
+             else:
+                 refined_frame = float(peak_frame)
+
+             token_spans.append((token_id, refined_frame, refined_frame + 1.0))
+
+         return token_spans
+
+     # Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
+     # Calibrated on librispeech-alignments dataset
+     START_OFFSET = 0.06  # Subtract from start times (shift earlier)
+     END_OFFSET = -0.03  # Add to end times (shift later)
+
+     @classmethod
+     def align(
+         cls,
+         audio: np.ndarray,
+         text: str,
+         sample_rate: int = 16000,
+         _language: str = "eng",
+         _batch_size: int = 16,
+     ) -> list[dict]:
+         """Align transcript to audio and return word-level timestamps.
+
+         Uses Viterbi trellis algorithm for optimal forced alignment.
+
+         Args:
+             audio: Audio waveform as numpy array
+             text: Transcript text to align
+             sample_rate: Audio sample rate (default 16000)
+             _language: ISO-639-3 language code (default "eng" for English, unused)
+             _batch_size: Batch size for alignment model (unused)
+
+         Returns:
+             List of dicts with 'word', 'start', 'end' keys
+         """
+         import torchaudio
+
+         device = _get_device()
+         model, labels, dictionary = cls.get_instance(device)
+
+         # Convert audio to tensor (copy to ensure array is writable)
+         if isinstance(audio, np.ndarray):
+             waveform = torch.from_numpy(audio.copy()).float()
+         else:
+             waveform = audio.clone().float()
+
+         # Ensure 2D (channels, time)
+         if waveform.dim() == 1:
+             waveform = waveform.unsqueeze(0)
+
+         # Resample if needed (wav2vec2 expects 16kHz)
+         if sample_rate != cls._bundle.sample_rate:
+             waveform = torchaudio.functional.resample(
+                 waveform, sample_rate, cls._bundle.sample_rate
+             )
+
+         waveform = waveform.to(device)
+
+         # Get emissions from model
+         with torch.inference_mode():
+             emissions, _ = model(waveform)
+             emissions = torch.log_softmax(emissions, dim=-1)
+
+         emission = emissions[0].cpu()
+
+         # Normalize text: uppercase, keep only valid characters
+         transcript = text.upper()
+
+         # Build tokens from transcript (including word separators)
+         tokens = []
+         for char in transcript:
+             if char in dictionary:
+                 tokens.append(dictionary[char])
+             elif char == " ":
+                 tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
+
+         if not tokens:
+             return []
+
+         # Build Viterbi trellis and backtrack for optimal path
+         trellis = cls._get_trellis(emission, tokens, blank_id=0)
+         alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)
+
+         # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
+         frame_duration = 320 / cls._bundle.sample_rate
+
+         # Apply separate offset compensation for start/end (Wav2Vec2 systematic bias)
+         start_offset = cls.START_OFFSET
+         end_offset = cls.END_OFFSET
+
+         # Group aligned tokens into words based on pipe separator
+         words = text.split()
+         word_timestamps = []
+         current_word_start = None
+         current_word_end = None
+         word_idx = 0
+         separator_id = dictionary.get("|", dictionary.get(" ", 0))
+
+         for token_id, start_frame, end_frame in alignment_path:
+             if token_id == separator_id:  # Word separator
+                 if current_word_start is not None and word_idx < len(words):
+                     start_time = max(0.0, current_word_start * frame_duration - start_offset)
+                     end_time = max(0.0, current_word_end * frame_duration - end_offset)
+                     word_timestamps.append(
+                         {
+                             "word": words[word_idx],
+                             "start": start_time,
+                             "end": end_time,
+                         }
+                     )
+                     word_idx += 1
+                     current_word_start = None
+                     current_word_end = None
+             else:
+                 if current_word_start is None:
+                     current_word_start = start_frame
+                 current_word_end = end_frame
+
+         # Don't forget the last word
+         if current_word_start is not None and word_idx < len(words):
+             start_time = max(0.0, current_word_start * frame_duration - start_offset)
+             end_time = max(0.0, current_word_end * frame_duration - end_offset)
+             word_timestamps.append(
+                 {
+                     "word": words[word_idx],
+                     "start": start_time,
+                     "end": end_time,
+                 }
+             )
+
+         return word_timestamps
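
The new module's public entry point is ForcedAligner.align(), which takes a waveform and its transcript and returns word-level timestamps. A minimal usage sketch follows; loading the audio with soundfile and the file name example.wav are illustrative assumptions, not part of this commit.

import numpy as np
import soundfile as sf  # assumed audio loader; any mono float32 array at a known sample rate works

from alignment import ForcedAligner

audio, sample_rate = sf.read("example.wav", dtype="float32")  # hypothetical input file
if audio.ndim > 1:
    audio = audio.mean(axis=1)  # downmix to mono before alignment

word_timestamps = ForcedAligner.align(
    audio=np.asarray(audio),
    text="hello world",  # transcript to align (illustrative)
    sample_rate=sample_rate,  # align() resamples internally if this is not 16 kHz
)
for w in word_timestamps:
    print(f"{w['word']}: {w['start']:.2f}s to {w['end']:.2f}s")
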
asr_modeling.py CHANGED
@@ -869,6 +869,8 @@ class ASRModel(PreTrainedModel, GenerationMixin):
  shutil.copy(asr_file, save_dir / asr_file.name)
  # Copy projectors module
  shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
+ # Copy alignment module
+ shutil.copy(src_dir / "alignment.py", save_dir / "alignment.py")
  # Copy diarization module
  shutil.copy(src_dir / "diarization.py", save_dir / "diarization.py")
 
asr_pipeline.py CHANGED
@@ -9,305 +9,12 @@ import torch
  import transformers
 
  try:
+     from .alignment import ForcedAligner
      from .asr_modeling import ASRModel
- except ImportError:
-     from asr_modeling import ASRModel  # type: ignore[no-redef]
[… old lines 15-307 removed here: the _get_device() helper and the full ForcedAligner class, the same code that was added in alignment.py above …]
- try:
      from .diarization import SpeakerDiarizer
  except ImportError:
+     from alignment import ForcedAligner  # type: ignore[no-redef]
+     from asr_modeling import ASRModel  # type: ignore[no-redef]
      from diarization import SpeakerDiarizer  # type: ignore[no-redef]
 
  # Re-export for backwards compatibility
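
Because asr_pipeline.py now pulls ForcedAligner in from the new alignment module, and the comment above notes it is re-exported for backwards compatibility, existing callers that imported the class from the pipeline module should keep resolving to the same object. A small sketch, assuming both files are importable as top-level modules and the pipeline's other dependencies (transformers, torchaudio, etc.) are installed:

from alignment import ForcedAligner
import asr_pipeline

# Both names refer to the single class object defined in alignment.py.
assert asr_pipeline.ForcedAligner is ForcedAligner
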