
patience definition to match the paper

Jong Wook Kim, 2 years ago · parent commit 62fe7f1009

2 changed files with 9 additions and 7 deletions:
  whisper/decoding.py    +8 -6
  whisper/transcribe.py  +1 -1

whisper/decoding.py  +8 -6

@@ -78,7 +78,7 @@ class DecodingOptions:
     sample_len: Optional[int] = None  # maximum number of tokens to sample
     best_of: Optional[int] = None     # number of independent samples to collect, when t > 0
     beam_size: Optional[int] = None   # number of beams in beam search, when t == 0
-    patience: float = 0.0             # patience in beam search (https://arxiv.org/abs/2204.05424)
+    patience: Optional[float] = None  # patience in beam search (https://arxiv.org/abs/2204.05424)
 
     # options for ranking generations (either beams or best-of-N samples)
     length_penalty: Optional[float] = None   # "alpha" in Google NMT, None defaults to length norm
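With this change, patience follows the definition in the paper: it is the multiplier on the number of finished candidates to collect, and leaving it unset (None) behaves like conventional beam search. A minimal usage sketch, assuming the whisper package is installed (the beam_size and patience values are illustrative):

    from whisper.decoding import DecodingOptions

    # patience defaults to None, which the decoder treats as 1.0 (plain beam search)
    default_opts = DecodingOptions(beam_size=5)

    # patience=2.0 keeps decoding until 2 * beam_size finished candidates are collected
    patient_opts = DecodingOptions(beam_size=5, patience=2.0)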
@@ -275,14 +275,16 @@ class GreedyDecoder(TokenDecoder):
 
 
 class BeamSearchDecoder(TokenDecoder):
-    def __init__(self, beam_size: int, eot: int, inference: Inference, patience: float = 0.0):
+    def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
         self.beam_size = beam_size
         self.eot = eot
         self.inference = inference
-        self.patience = patience
-        self.max_candidates: int = round(beam_size * (1.0 + patience))
+        self.patience = patience or 1.0
+        self.max_candidates: int = round(beam_size * self.patience)
         self.finished_sequences = None
 
+        assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
+
     def reset(self):
         self.finished_sequences = None
 
@@ -496,8 +498,8 @@ class DecodingTask:
         if options.temperature == 0:
             if options.best_of is not None:
                 raise ValueError("best_of with greedy sampling (T=0) is not compatible")
-        if options.patience != 0.0 and options.beam_size is None:
-            raise ValueError("nonzero patience requires beam_size to be given")
+        if options.patience is not None and options.beam_size is None:
+            raise ValueError("patience requires beam_size to be given")
         if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
             raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
 
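Because the sentinel moved from 0.0 to None, the option check also changes: any explicitly given patience now requires beam_size. A small sketch mirroring the new check (illustrative, not part of the commit):

    from whisper.decoding import DecodingOptions

    opts = DecodingOptions(patience=2.0)  # beam_size left unset
    if opts.patience is not None and opts.beam_size is None:
        # this is the condition DecodingTask now rejects
        raise ValueError("patience requires beam_size to be given")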

whisper/transcribe.py  +1 -1

@@ -263,7 +263,7 @@ def cli():
     parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
     parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
     parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
-    parser.add_argument("--patience", type=float, default=0.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (0.0) is equivalent to not using patience")
+    parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
     parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
 
     parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")