| 
					
				 | 
			
			
				@@ -189,7 +189,7 @@ class TextDecoder(nn.Module): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             x = block(x, xa, mask=self.mask, kv_cache=kv_cache) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         x = self.ln(x) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        logits = (x @ self.token_embedding.weight.to(x.dtype).T).float() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return logits 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 |