@@ -78,7 +78,7 @@ class DecodingOptions:
     sample_len: Optional[int] = None  # maximum number of tokens to sample
     best_of: Optional[int] = None     # number of independent samples to collect, when t > 0
     beam_size: Optional[int] = None   # number of beams in beam search, when t == 0
-    patience: float = 0.0             # patience in beam search (https://arxiv.org/abs/2204.05424)
+    patience: Optional[float] = None  # patience in beam search (https://arxiv.org/abs/2204.05424)
 
     # options for ranking generations (either beams or best-of-N samples)
     length_penalty: Optional[float] = None   # "alpha" in Google NMT, None defaults to length norm
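
This hunk makes patience opt-in: leaving it as None behaves like a patience of 1.0, i.e. conventional beam search. Below is a minimal usage sketch, assuming the openai/whisper Python API (whisper.load_model, whisper.DecodingOptions, whisper.decode) and a hypothetical input file audio.mp3; only beam_size and patience come from the hunk above.

import whisper

model = whisper.load_model("base")
audio = whisper.pad_or_trim(whisper.load_audio("audio.mp3"))  # hypothetical input file
mel = whisper.log_mel_spectrogram(audio).to(model.device)

# patience multiplies beam_size when deciding how many finished candidates
# to collect; omitting it (None) is equivalent to patience=1.0
options = whisper.DecodingOptions(beam_size=5, patience=2.0)
result = whisper.decode(model, mel, options)
print(result.text)
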
@@ -275,14 +275,16 @@ class GreedyDecoder(TokenDecoder):
 
 
 class BeamSearchDecoder(TokenDecoder):
-    def __init__(self, beam_size: int, eot: int, inference: Inference, patience: float = 0.0):
+    def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
         self.beam_size = beam_size
         self.eot = eot
         self.inference = inference
-        self.patience = patience
-        self.max_candidates: int = round(beam_size * (1.0 + patience))
+        self.patience = patience or 1.0
+        self.max_candidates: int = round(beam_size * self.patience)
         self.finished_sequences = None
 
+        assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
+
     def reset(self):
         self.finished_sequences = None
 
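
A quick check of the arithmetic behind the new max_candidates, with illustrative values: the removed formula treated patience as an additive bonus on top of beam_size, while the new one (following arXiv 2204.05424) treats it as a pure multiplier, and the added assert rejects settings that would leave no candidate slots.

beam_size = 5

# default: patience is None, and `patience or 1.0` falls back to 1.0
assert round(beam_size * (None or 1.0)) == 5   # identical to plain beam search

# an explicit patience multiplies beam_size
assert round(beam_size * 2.0) == 10            # keep decoding until 10 candidates finish

# values that round down to zero candidates are what the new assert guards against
assert round(beam_size * 0.05) == 0
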
@@ -496,8 +498,8 @@ class DecodingTask:
         if options.temperature == 0:
             if options.best_of is not None:
                 raise ValueError("best_of with greedy sampling (T=0) is not compatible")
-        if options.patience != 0.0 and options.beam_size is None:
-            raise ValueError("nonzero patience requires beam_size to be given")
+        if options.patience is not None and options.beam_size is None:
+            raise ValueError("patience requires beam_size to be given")
         if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
             raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
 
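
Finally, a hedged sketch of how the tightened validation surfaces to callers, reusing model and mel from the first sketch: the check runs when whisper.decode builds its DecodingTask, not when the DecodingOptions dataclass is constructed.

options = whisper.DecodingOptions(patience=2.0)  # no beam_size; constructing the options is fine
try:
    whisper.decode(model, mel, options)
except ValueError as exc:
    print(exc)  # "patience requires beam_size to be given"
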