| 
					
				 | 
			
			
				@@ -197,7 +197,7 @@ class TextDecoder(nn.Module): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         x : torch.LongTensor, shape = (batch_size, <= n_ctx) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             the text tokens 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             the encoded audio features to be attended on 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 
			 |