@@ -78,7 +78,7 @@ class DecodingOptions:
     sample_len: Optional[int] = None  # maximum number of tokens to sample
     best_of: Optional[int] = None     # number of independent samples to collect, when t > 0
     beam_size: Optional[int] = None   # number of beams in beam search, when t == 0
-    patience: float = 0.0             # patience in beam search (https://arxiv.org/abs/2204.05424)
+    patience: Optional[float] = None  # patience in beam search (https://arxiv.org/abs/2204.05424)
 
     # options for ranking generations (either beams or best-of-N samples)
     length_penalty: Optional[float] = None   # "alpha" in Google NMT, None defaults to length norm
@@ -275,14 +275,16 @@ class GreedyDecoder(TokenDecoder):
 
 
 class BeamSearchDecoder(TokenDecoder):
-    def __init__(self, beam_size: int, eot: int, inference: Inference, patience: float = 0.0):
+    def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
         self.beam_size = beam_size
         self.eot = eot
         self.inference = inference
-        self.patience = patience
-        self.max_candidates: int = round(beam_size * (1.0 + patience))
+        self.patience = patience or 1.0
+        self.max_candidates: int = round(beam_size * self.patience)
         self.finished_sequences = None
 
+        assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
+
     def reset(self):
         self.finished_sequences = None
 
@@ -496,8 +498,8 @@ class DecodingTask:
         if options.temperature == 0:
             if options.best_of is not None:
                 raise ValueError("best_of with greedy sampling (T=0) is not compatible")
-        if options.patience != 0.0 and options.beam_size is None:
-            raise ValueError("nonzero patience requires beam_size to be given")
+        if options.patience is not None and options.beam_size is None:
+            raise ValueError("patience requires beam_size to be given")
         if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
            raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
 
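
For reference, a minimal sketch of the new semantics (the max_candidates helper below is illustrative, not part of the library): patience now acts as a multiplier on beam_size rather than an added fraction, and None falls back to 1.0, which reproduces plain beam search.

from typing import Optional

def max_candidates(beam_size: int, patience: Optional[float] = None) -> int:
    # Mirrors the new BeamSearchDecoder logic: patience scales the number of
    # finished sequences to collect (https://arxiv.org/abs/2204.05424).
    effective_patience = patience or 1.0
    candidates = round(beam_size * effective_patience)
    assert candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
    return candidates

print(max_candidates(5))       # 5  -> default behaves like standard beam search
print(max_candidates(5, 2.0))  # 10 -> keep decoding until 10 finished sequences

Callers would then request it through the options dataclass, e.g. DecodingOptions(beam_size=5, patience=2.0); leaving patience=None (the default) keeps the previous behaviour.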