import subprocess
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, List

import numba
import numpy as np
import torch
import torch.nn.functional as F

from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer

if TYPE_CHECKING:
    from .model import Whisper


def median_filter(x: torch.Tensor, filter_width: int):
    """Apply a median filter of width `filter_width` along the last dimension of `x`"""
    pad_width = filter_width // 2
    if x.shape[-1] <= pad_width:
        # F.pad requires the padding width to be smaller than the input dimension
        return x

    if (ndim := x.ndim) <= 2:
        # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
        x = x[None, None, :]

    assert (
        filter_width > 0 and filter_width % 2 == 1
    ), "`filter_width` should be an odd number"

    result = None
    x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
    if x.is_cuda:
        try:
            from .triton_ops import median_filter_cuda

            result = median_filter_cuda(x, filter_width)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower median kernel implementation..."
            )

    if result is None:
        # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
        result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]

    if ndim <= 2:
        result = result[0, 0]

    return result
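
# A minimal sanity check (illustrative, not part of the original module): the
# filter runs along the last dimension and preserves the input shape.
#
#     weights = torch.randn(8, 12, 1500)  # heads x tokens x frames
#     smoothed = median_filter(weights, filter_width=7)
#     assert smoothed.shape == weights.shape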


@numba.jit
def backtrace(trace: np.ndarray):
    # walk the DTW trace matrix from the bottom-right corner back to the
    # origin, collecting the alignment path along the way
    i = trace.shape[0] - 1
    j = trace.shape[1] - 1
    trace[0, :] = 2
    trace[:, 0] = 1

    result = []
    while i > 0 or j > 0:
        result.append((i - 1, j - 1))

        if trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1:
            i -= 1
        elif trace[i, j] == 2:
            j -= 1
        else:
            raise ValueError("Unexpected trace[i, j]")

    result = np.array(result)
    return result[::-1, :].T


@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
    N, M = x.shape
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)

    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]
            c1 = cost[i - 1, j]
            c2 = cost[i, j - 1]

            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    return backtrace(trace)
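
# The loop above is the standard dynamic time warping recurrence:
#
#     cost[i, j] = x[i-1, j-1] + min(cost[i-1, j-1], cost[i-1, j], cost[i, j-1])
#
# with trace[i, j] recording which predecessor won (0 = diagonal, 1 = up,
# 2 = left), so backtrace() can recover the minimum-cost monotonic path that
# pairs each token index with an audio-frame index.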


def dtw_cuda(x, BLOCK_SIZE=1024):
    from .triton_ops import dtw_kernel

    M, N = x.shape
    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"

    x_skew = (
        F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
    )
    x_skew = x_skew.T.contiguous()
    cost = torch.ones(N + M + 2, M + 2) * np.inf
    cost[0, 0] = 0
    cost = cost.cuda()
    trace = torch.zeros_like(cost, dtype=torch.int32)

    dtw_kernel[(1,)](
        cost,
        trace,
        x_skew,
        x_skew.stride(0),
        cost.stride(0),
        trace.stride(0),
        N,
        M,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
        :, : N + 1
    ]
    return backtrace(trace.cpu().numpy())
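
# A note on the skewed layout (my reading of the code above, not upstream
# documentation): padding each row of `x` with M + 1 infinities and
# re-flattening shifts row i right by i columns, so each anti-diagonal of the
# cost matrix lines up as a column; after the transpose it is a contiguous
# row. Cells on an anti-diagonal depend only on the two previous diagonals,
# so the Triton kernel can update an entire wavefront per step.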


def dtw(x: torch.Tensor) -> np.ndarray:
    if x.is_cuda:
        try:
            return dtw_cuda(x)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower DTW implementation..."
            )

    return dtw_cpu(x.double().cpu().numpy())
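
# Typical use (mirroring find_alignment below): negate the attention matrix so
# DTW minimizes cost where attention is strongest, then unpack the path:
#
#     text_indices, time_indices = dtw(-matrix)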


@dataclass
class WordTiming:
    word: str
    tokens: List[int]
    start: float
    end: float
    probability: float


def find_alignment(
    model: "Whisper",
    tokenizer: Tokenizer,
    text_tokens: List[int],
    mel: torch.Tensor,
    num_frames: int,
    *,
    medfilt_width: int = 7,
    qk_scale: float = 1.0,
) -> List[WordTiming]:
    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.no_timestamps,
            *text_tokens,
            tokenizer.eot,
        ]
    ).to(model.device)

    # install hooks on the cross attention layers to retrieve the attention weights
    QKs = [None] * model.dims.n_text_layer
    hooks = [
        block.cross_attn.register_forward_hook(
            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
        )
        for i, block in enumerate(model.decoder.blocks)
    ]

    with torch.no_grad():
        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
        sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
        token_probs = sampled_logits.softmax(dim=-1)
        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
        text_token_probs = text_token_probs.tolist()

    for hook in hooks:
        hook.remove()

    # heads * tokens * frames
    weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T])
    weights = weights[:, :, : num_frames // 2]
    weights = (weights * qk_scale).softmax(dim=-1)
    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
    weights = (weights - mean) / std
    weights = median_filter(weights, medfilt_width)

    matrix = weights.mean(axis=0)
    matrix = matrix[len(tokenizer.sot_sequence) : -1]
    text_indices, time_indices = dtw(-matrix)

    words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))

    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
    jump_times = time_indices[jumps] / TOKENS_PER_SECOND
    start_times = jump_times[word_boundaries[:-1]]
    end_times = jump_times[word_boundaries[1:]]
    word_probabilities = [
        np.mean(text_token_probs[i:j])
        for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
    ]
    # hack: ensure the first and second words are not longer than twice the
    # median word duration. a better segmentation algorithm based on VAD
    # should be able to replace this.
    word_durations = end_times - start_times
    word_durations = word_durations[word_durations.nonzero()]
    if len(word_durations) > 0:
        median_duration = np.median(word_durations)
        max_duration = median_duration * 2
        if len(word_durations) >= 2 and word_durations[1] > max_duration:
            boundary = max(end_times[2] / 2, end_times[2] - max_duration)
            end_times[0] = start_times[1] = boundary
        if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
            start_times[0] = max(0, end_times[0] - max_duration)

    return [
        WordTiming(word, tokens, start, end, probability)
        for word, tokens, start, end, probability in zip(
            words, word_tokens, start_times, end_times, word_probabilities
        )
    ]
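
# Sketch of a direct call (hypothetical inputs; in practice this is invoked
# through add_word_timestamps below):
#
#     timings = find_alignment(model, tokenizer, text_tokens, mel, num_frames)
#     for t in timings:
#         print(f"{t.word!r}: {t.start:.2f}-{t.end:.2f}s (p={t.probability:.2f})")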


def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
    # merge prepended punctuations
    i = len(alignment) - 2
    j = len(alignment) - 1
    while i >= 0:
        previous = alignment[i]
        following = alignment[j]
        if previous.word.startswith(" ") and previous.word.strip() in prepended:
            # prepend it to the following word
            following.word = previous.word + following.word
            following.tokens = previous.tokens + following.tokens
            previous.word = ""
            previous.tokens = []
        else:
            j = i
        i -= 1

    # merge appended punctuations
    i = 0
    j = 1
    while j < len(alignment):
        previous = alignment[i]
        following = alignment[j]
        if not previous.word.endswith(" ") and following.word in appended:
            # append it to the previous word
            previous.word = previous.word + following.word
            previous.tokens = previous.tokens + following.tokens
            following.word = ""
            following.tokens = []
        else:
            i = j
        j += 1
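
# Illustrative effect (hypothetical word split): with prepended='"' and
# appended='.,', the alignment words [' "', 'Hello', ',', ' world', '.']
# become [' "Hello,', ' world.'], and the merged-away entries are left with
# word="" and tokens=[] so callers can skip them.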


def add_word_timestamps(
    *,
    segments: List[dict],
    model: "Whisper",
    tokenizer: Tokenizer,
    mel: torch.Tensor,
    num_frames: int,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,,!!??::”)]}、",
    **kwargs,
):
    if len(segments) == 0:
        return

    text_tokens = [t for segment in segments for t in segment["tokens"]]
    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
    merge_punctuations(alignment, prepend_punctuations, append_punctuations)

    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
    segment_lengths = [len(s["tokens"]) for s in segments]
    token_sources = np.repeat(np.arange(len(segments)), segment_lengths)

    for segment in segments:
        segment["words"] = []

    word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
    for i, timing in enumerate(alignment):
        if timing.word:
            segment = segments[token_sources[word_boundaries[i]]]
            start = round(time_offset + timing.start, 2)
            end = round(time_offset + timing.end, 2)
            segment["words"].append(
                dict(
                    word=timing.word,
                    start=start,
                    end=end,
                    probability=timing.probability,
                )
            )

    for segment in segments:
        if len(words := segment["words"]) > 0:
            # adjust the segment-level timestamps based on the word-level timestamps
            segment["start"] = words[0]["start"]
            segment["end"] = words[-1]["end"]
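
# Illustrative call site (hypothetical variable names; the actual caller is
# the transcription loop):
#
#     add_word_timestamps(
#         segments=result_segments,  # segment dicts with "seek" and "tokens"
#         model=model,
#         tokenizer=tokenizer,
#         mel=mel_segment,
#         num_frames=segment_frames,
#     )
#
# Afterwards each segment carries a "words" list of
# {word, start, end, probability} dicts, and its "start"/"end" are snapped to
# the first and last word timestamps.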