import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int


# LayerNorm, Linear, and Conv1d below override the stock modules so the model can run in
# fp16: LayerNorm computes in float32 for stability and casts back, while Linear and Conv1d
# cast their weights/biases to the input dtype.
class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )


class Conv1d(nn.Conv1d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
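

# For example, sinusoids(1500, 384) returns a (1500, 384) tensor whose first 192 columns are
# sines and last 192 are cosines, with timescales spaced geometrically between 1 and
# max_timescale; AudioEncoder below registers sinusoids(n_ctx, n_state) as its fixed
# (non-learned) positional embedding.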


class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        n_batch, n_ctx, n_state = q.shape
        # scaling both q and k by (d_head ** -0.25) is equivalent to the usual
        # 1 / sqrt(d_head) scaling of the attention logits
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()


class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x


class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x
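

# In typical use the mel input to AudioEncoder has a time axis of 2 * n_ctx frames
# (e.g. 3000 frames for a 30-second segment when n_audio_ctx = 1500); conv2's stride of 2
# halves it to n_ctx, so the assert on positional_embedding passes and the encoder returns
# features of shape (batch_size, n_ctx, n_state).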


class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits
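

# Notes on TextDecoder.forward: the output projection reuses the token embedding matrix
# (weight tying), so the returned logits have shape (batch_size, n_tokens, n_vocab). During
# incremental decoding, `offset` is the number of positions already held in `kv_cache`, so
# only the newly generated tokens need to be passed in.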


class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # use the last half of the decoder layers for time alignment by default;
        # to use a specific set of heads, see `set_alignment_heads()` below.
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
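
    # `dump` below is a base85-encoded, gzip-compressed boolean array with
    # n_text_layer * n_text_head entries; it marks which decoder heads to use for
    # time alignment, replacing the default set chosen in __init__ above.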

    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.dims.n_text_layer, self.dims.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder(tokens, audio_features)

    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to their caches
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects that can be used to remove the installed hooks
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks
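
    # A rough usage sketch for incremental decoding with the cache (names are illustrative):
    #
    #   cache, hooks = model.install_kv_cache_hooks()
    #   audio_features = model.embed_audio(mel)
    #   logits = model.decoder(tokens, audio_features, kv_cache=cache)       # full prompt
    #   logits = model.decoder(new_tokens, audio_features, kv_cache=cache)   # only new tokens
    #   for hook in hooks:
    #       hook.remove()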

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function
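

if __name__ == "__main__":
    # Minimal smoke test with small, made-up dimensions (not a released checkpoint); the
    # weights are untrained, so only the tensor shapes are meaningful. Because of the
    # relative imports above, run this as a module, e.g. `python -m whisper.model`.
    dims = ModelDimensions(
        n_mels=80,
        n_audio_ctx=100,
        n_audio_state=64,
        n_audio_head=4,
        n_audio_layer=2,
        n_vocab=1000,
        n_text_ctx=48,
        n_text_state=64,
        n_text_head=4,
        n_text_layer=2,
    )
    model = Whisper(dims).eval()

    mel = torch.randn(1, dims.n_mels, 2 * dims.n_audio_ctx)  # time axis = 2 * n_audio_ctx
    tokens = torch.randint(0, dims.n_vocab, (1, 5))

    with torch.no_grad():
        logits = model(mel, tokens)
    print(logits.shape)  # torch.Size([1, 5, 1000])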