| 
														
															@@ -197,7 +197,7 @@ class TextDecoder(nn.Module): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         """ 
														 | 
														
														 | 
														
															         """ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         x : torch.LongTensor, shape = (batch_size, <= n_ctx) 
														 | 
														
														 | 
														
															         x : torch.LongTensor, shape = (batch_size, <= n_ctx) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             the text tokens 
														 | 
														
														 | 
														
															             the text tokens 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             the encoded audio features to be attended on 
														 | 
														
														 | 
														
															             the encoded audio features to be attended on 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         """ 
														 | 
														
														 | 
														
															         """ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 
														 | 
														
														 | 
														
															         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 
														 |