|
@@ -197,7 +197,7 @@ class TextDecoder(nn.Module):
|
|
|
"""
|
|
|
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
|
|
the text tokens
|
|
|
- xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
|
|
|
+ xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
|
|
the encoded audio features to be attended on
|
|
|
"""
|
|
|
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|