# model.py

from dataclasses import dataclass
from typing import Dict, Iterable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .transcribe import transcribe as transcribe_function
from .decoding import detect_language as detect_language_function, decode as decode_function


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int
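
# Illustrative sketch (not part of the original module): constructing ModelDimensions by hand.
# The values below are hypothetical, chosen only to show the expected field types; real
# checkpoints carry their own dimensions in the saved state dict.
#
#     dims = ModelDimensions(
#         n_mels=80, n_audio_ctx=1500, n_audio_state=384, n_audio_head=6, n_audio_layer=4,
#         n_vocab=51865, n_text_ctx=448, n_text_state=384, n_text_head=6, n_text_layer=4,
#     )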


class LayerNorm(nn.LayerNorm):
    # Compute the normalization in float32 for numerical stability, then cast the result
    # back to the input dtype (e.g. float16 when the model runs in half precision).
    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):
    # Cast the weight and bias to the input dtype so inputs of a different precision
    # can be projected without an explicit conversion by the caller.
    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
        )


class Conv1d(nn.Conv1d):
    # Same dtype-matching as Linear above, applied to the convolution weights.
    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
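
# Illustrative sketch (not part of the original module): the returned table has one row per
# position and `channels` columns, sine features in the first half and cosine features in the
# second half. The shape below is an example only.
#
#     pe = sinusoids(length=1500, channels=384)
#     assert pe.shape == (1500, 384)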


class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        q = self.query(x)

        if kv_cache is None or xa is None:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache.get(self.key, self.key(xa))
            v = kv_cache.get(self.value, self.value(xa))

        wv = self.qkv_attention(q, k, v, mask)
        return self.out(wv)

    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
        n_batch, n_ctx, n_state = q.shape
        # The scale is applied to both q and k, so their product qk ends up scaled by
        # (n_state // n_head) ** -0.5, the usual scaled dot-product attention factor.
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]

        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
        x = x + self.mlp(self.mlp_ln(x))
        return x
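
# Illustrative sketch (not part of the original module): a single pre-norm residual block with
# cross-attention. All shapes below are hypothetical examples.
#
#     block = ResidualAttentionBlock(n_state=384, n_head=6, cross_attention=True)
#     x = torch.randn(1, 3, 384)       # decoder hidden states
#     xa = torch.randn(1, 1500, 384)   # encoder output to cross-attend to
#     y = block(x, xa)                 # same shape as x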


class AudioEncoder(nn.Module):
    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x
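
# Illustrative sketch (not part of the original module): conv2 has stride 2, so the encoder
# expects roughly twice as many mel frames as n_audio_ctx. The concrete numbers below are
# assumptions for illustration, not values taken from this file.
#
#     encoder = AudioEncoder(n_mels=80, n_ctx=1500, n_state=384, n_head=6, n_layer=4)
#     mel = torch.randn(1, 80, 3000)   # (batch_size, n_mels, n_frames), n_frames assumed 2 * n_ctx
#     features = encoder(mel)          # shape (1, 1500, 384)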


class TextDecoder(nn.Module):
    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (x @ self.token_embedding.weight.to(x.dtype).T).float()

        return logits
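
# Illustrative sketch (not part of the original module): the decoder maps token ids plus encoded
# audio features to per-token vocabulary logits. Shapes below are hypothetical examples.
#
#     decoder = TextDecoder(n_vocab=51865, n_ctx=448, n_state=384, n_head=6, n_layer=4)
#     tokens = torch.randint(0, 51865, (1, 3))   # (batch_size, n_tokens) of token ids
#     xa = torch.randn(1, 1500, 384)             # encoder output, (batch, n_audio_ctx, n_audio_state)
#     logits = decoder(tokens, xa)               # shape (1, 3, 51865)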


class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder.forward(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder.forward(tokens, audio_features)

    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab == 51865

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations. See the usage sketch at the
        end of this file.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to their cached tensors
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks from being called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
                cache[module] = output  # save as-is, for the first token or cross attention
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function
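

# Illustrative sketch (not part of the original module): one way the kv_cache returned by
# `install_kv_cache_hooks` could be threaded through incremental greedy decoding. `dims`,
# `mel`, `tokens`, and `max_new_tokens` are assumed to exist; this is a hand-rolled loop,
# not the library's `decode` API.
#
#     model = Whisper(dims)
#     cache, hooks = model.install_kv_cache_hooks()
#     audio_features = model.embed_audio(mel)    # (batch, n_audio_ctx, n_audio_state)
#     # first call passes the full prompt so its keys/values get cached
#     logits = model.decoder(tokens, audio_features, kv_cache=cache)
#     for _ in range(max_new_tokens):
#         next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
#         tokens = torch.cat([tokens, next_token], dim=-1)
#         # later calls only pass the newly appended token; cached keys/values cover the rest
#         logits = model.decoder(next_token, audio_features, kv_cache=cache)
#     for hook in hooks:
#         hook.remove()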