Fix truncated words list when the replacement character is decoded (#1089)

Guillaume Klein, 2 years ago, commit 5f9ac653b7
2 changed files with 21 additions and 1 deletion:

    tests/test_tokenizer.py   (+10 -0)
    whisper/tokenizer.py      (+11 -1)

tests/test_tokenizer.py  (+10 -0)

@@ -12,3 +12,13 @@ def test_tokenizer():
     assert gpt2_tokenizer.decode(gpt2_tokens) == text
     assert multilingual_tokenizer.decode(multilingual_tokens) == text
     assert len(gpt2_tokens) > len(multilingual_tokens)
+
+
+def test_split_on_unicode():
+    multilingual_tokenizer = get_tokenizer(multilingual=True)
+
+    tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
+    words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)
+
+    assert words == [" elle", " est", " l", "'", "�", "é", "rit", "oire"]
+    assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]
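
The new test feeds a token sequence whose fully decoded text genuinely contains U+FFFD (the "�" above). Under the old condition, `"\ufffd" not in decoded` never holds again once that character appears, so every word after it was silently dropped, which is the truncated words list named in the commit title. A quick repro sketch (assumes the whisper package from this repo is importable; the "before" output is inferred from the old condition, not re-run here):

    from whisper.tokenizer import get_tokenizer

    tokenizer = get_tokenizer(multilingual=True)
    tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
    words, _ = tokenizer.split_tokens_on_unicode(tokens)

    print(words)
    # with this fix:   [' elle', ' est', ' l', "'", '�', 'é', 'rit', 'oire']
    # before the fix:  [' elle', ' est', ' l', "'"]  (truncated at the first U+FFFD)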

whisper/tokenizer.py  (+11 -1)

@@ -279,17 +279,27 @@ class Tokenizer:
         return self.split_tokens_on_spaces(tokens)
 
     def split_tokens_on_unicode(self, tokens: List[int]):
+        decoded_full = self.decode_with_timestamps(tokens)
+        replacement_char = "\ufffd"
+
         words = []
         word_tokens = []
         current_tokens = []
+        unicode_offset = 0
 
         for token in tokens:
             current_tokens.append(token)
             decoded = self.decode_with_timestamps(current_tokens)
-            if "\ufffd" not in decoded:
+
+            if (
+                replacement_char not in decoded
+                or decoded_full[unicode_offset + decoded.index(replacement_char)]
+                == replacement_char
+            ):
                 words.append(decoded)
                 word_tokens.append(current_tokens)
                 current_tokens = []
+                unicode_offset += len(decoded)
 
         return words, word_tokens
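
The check works because decoding a partial multi-byte UTF-8 sequence yields U+FFFD: if the replacement character in the incremental decode lines up with a replacement character at the same offset of the full decode, it is real text and the word can be flushed; otherwise more tokens are needed to complete the character. Below is a minimal, self-contained sketch of the same technique, using hypothetical byte-string "tokens" and a stand-in decode in place of whisper's BPE ids and decode_with_timestamps:

    from typing import List


    def decode(chunks: List[bytes]) -> str:
        # Stand-in for decode_with_timestamps: join the byte pieces and
        # decode, mapping malformed sequences to U+FFFD.
        return b"".join(chunks).decode("utf-8", errors="replace")


    def split_tokens_on_unicode(tokens: List[bytes]):
        decoded_full = decode(tokens)
        replacement_char = "\ufffd"

        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = decode(current_tokens)

            # Flush only when the decode is clean, or when the U+FFFD we see
            # is also present at the same position of the full decode, i.e.
            # it is real text rather than a split multi-byte character.
            if (
                replacement_char not in decoded
                or decoded_full[unicode_offset + decoded.index(replacement_char)]
                == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens


    # "é" is 0xC3 0xA9; split across two tokens it decodes to U+FFFD
    # mid-stream and must not end a word, while the invalid byte 0xFF
    # decodes to a U+FFFD that is genuinely in the text and must be flushed.
    tokens = [b"caf", b"\xc3", b"\xa9", b"\xff", b" ok"]
    words, word_tokens = split_tokens_on_unicode(tokens)
    assert words == ["caf", "é", "\ufffd", " ok"]
    assert word_tokens == [[b"caf"], [b"\xc3", b"\xa9"], [b"\xff"], [b" ok"]]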