diff --git a/candle-examples/examples/marian-mt/README.md b/candle-examples/examples/marian-mt/README.md
index e41bf007..eecaee32 100644
--- a/candle-examples/examples/marian-mt/README.md
+++ b/candle-examples/examples/marian-mt/README.md
@@ -17,3 +17,22 @@ cargo run --example marian-mt --release -- \
I know you are waiting for me. I will go through the forest, I will go through the
mountain. I cannot stay far from you any longer.
```
+
+## Generating the tokenizer.json files
+
+You can use the following script to generate the `tokenizer.json` config files
+from the hf-hub repos. It requires the `tokenizers` and `sentencepiece`
+packages to be installed, and it uses the `convert_slow_tokenizer.py` script
+from this directory.
+
+```python
+from convert_slow_tokenizer import MarianConverter
+from transformers import AutoTokenizer
+
+
+tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
+fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
+fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
+fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
+fast_tokenizer.save(f"tokenizer-marian-base-en.json")
+```
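+
+As a quick sanity check, you can load the generated files back with the
+`tokenizers` library and encode a sample sentence (the file names below are the
+ones written by the snippet above):
+
+```python
+from tokenizers import Tokenizer
+
+tokenizer = Tokenizer.from_file("tokenizer-marian-base-fr.json")
+print(tokenizer.encode("Je ne pouvais pas rester loin de toi plus longtemps.").tokens)
+```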
diff --git a/candle-examples/examples/marian-mt/convert_slow_tokenizer.py b/candle-examples/examples/marian-mt/convert_slow_tokenizer.py
new file mode 100644
index 00000000..8ae32f79
--- /dev/null
+++ b/candle-examples/examples/marian-mt/convert_slow_tokenizer.py
@@ -0,0 +1,1385 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Utilities to convert slow tokenizers into their fast tokenizer counterparts.
+
+All the conversions are grouped here to gather SentencePiece dependencies outside of the fast tokenizers files and
+to make our dependency on SentencePiece optional.
+"""
+
+import warnings
+from typing import Dict, List, Tuple
+
+from packaging import version
+from pathlib import Path
+from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
+from tokenizers.models import BPE, Unigram, WordPiece
+
+from transformers.utils import is_protobuf_available, requires_backends
+from transformers.utils.import_utils import PROTOBUF_IMPORT_ERROR
+
+
+def import_protobuf(error_message=""):
+ if is_protobuf_available():
+ import google.protobuf
+
+ if version.parse(google.protobuf.__version__) < version.parse("4.0.0"):
+ from transformers.utils import sentencepiece_model_pb2
+ else:
+ from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2
+ return sentencepiece_model_pb2
+ else:
+ raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
+
+
+class SentencePieceExtractor:
+ """
+ Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
+ """
+
+ def __init__(self, model: str):
+ requires_backends(self, "sentencepiece")
+ from sentencepiece import SentencePieceProcessor
+
+ self.sp = SentencePieceProcessor()
+ self.sp.Load(model)
+
+ def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
+ """
+ By default, returns the vocab and merges ordered by token index; if `vocab_scores` is provided, the merges
+ are ordered by piece score instead.
+ """
+ sp = self.sp
+ vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
+ if vocab_scores is not None:
+ vocab_scores, reverse = dict(vocab_scores), True
+ else:
+ vocab_scores, reverse = vocab, False
+
+ # Merges
+ merges = []
+ for merge, piece_score in vocab_scores.items():
+ local = []
+ for index in range(1, len(merge)):
+ piece_l, piece_r = merge[:index], merge[index:]
+ if piece_l in vocab and piece_r in vocab:
+ local.append((piece_l, piece_r, piece_score))
+ local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
+ merges.extend(local)
+
+ merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
+ merges = [(val[0], val[1]) for val in merges]
+ return vocab, merges
+
+
+def check_number_comma(piece: str) -> bool:
+ return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()
+
+
+class Converter:
+ def __init__(self, original_tokenizer):
+ self.original_tokenizer = original_tokenizer
+
+ def converted(self) -> Tokenizer:
+ raise NotImplementedError()
+
+
+class BertConverter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.vocab
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
+
+ tokenize_chinese_chars = False
+ strip_accents = False
+ do_lower_case = False
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
+
+ tokenizer.normalizer = normalizers.BertNormalizer(
+ clean_text=True,
+ handle_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ lowercase=do_lower_case,
+ )
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+
+ cls = str(self.original_tokenizer.cls_token)
+ sep = str(self.original_tokenizer.sep_token)
+ cls_token_id = self.original_tokenizer.cls_token_id
+ sep_token_id = self.original_tokenizer.sep_token_id
+
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{cls}:0 $A:0 {sep}:0",
+ pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
+ special_tokens=[
+ (cls, cls_token_id),
+ (sep, sep_token_id),
+ ],
+ )
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
+
+ return tokenizer
+
+
+class SplinterConverter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.vocab
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
+
+ tokenize_chinese_chars = False
+ strip_accents = False
+ do_lower_case = False
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
+
+ tokenizer.normalizer = normalizers.BertNormalizer(
+ clean_text=True,
+ handle_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ lowercase=do_lower_case,
+ )
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+
+ cls = str(self.original_tokenizer.cls_token)
+ sep = str(self.original_tokenizer.sep_token)
+ question = str(self.original_tokenizer.question_token)
+ dot = "."
+ cls_token_id = self.original_tokenizer.cls_token_id
+ sep_token_id = self.original_tokenizer.sep_token_id
+ question_token_id = self.original_tokenizer.question_token_id
+ dot_token_id = self.original_tokenizer.convert_tokens_to_ids(".")
+
+ if self.original_tokenizer.padding_side == "right":
+ pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
+ else:
+ pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1"
+
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{cls}:0 $A:0 {sep}:0",
+ pair=pair,
+ special_tokens=[
+ (cls, cls_token_id),
+ (sep, sep_token_id),
+ (question, question_token_id),
+ (dot, dot_token_id),
+ ],
+ )
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
+
+ return tokenizer
+
+
+class FunnelConverter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.vocab
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
+
+ tokenize_chinese_chars = False
+ strip_accents = False
+ do_lower_case = False
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
+
+ tokenizer.normalizer = normalizers.BertNormalizer(
+ clean_text=True,
+ handle_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ lowercase=do_lower_case,
+ )
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+
+ cls = str(self.original_tokenizer.cls_token)
+ sep = str(self.original_tokenizer.sep_token)
+ cls_token_id = self.original_tokenizer.cls_token_id
+ sep_token_id = self.original_tokenizer.sep_token_id
+
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{cls}:2 $A:0 {sep}:0", # token_type_id is 2 for Funnel transformer
+ pair=f"{cls}:2 $A:0 {sep}:0 $B:1 {sep}:1",
+ special_tokens=[
+ (cls, cls_token_id),
+ (sep, sep_token_id),
+ ],
+ )
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
+
+ return tokenizer
+
+
+class MPNetConverter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.vocab
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
+
+ tokenize_chinese_chars = False
+ strip_accents = False
+ do_lower_case = False
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
+
+ tokenizer.normalizer = normalizers.BertNormalizer(
+ clean_text=True,
+ handle_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ lowercase=do_lower_case,
+ )
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+
+ cls = str(self.original_tokenizer.cls_token)
+ sep = str(self.original_tokenizer.sep_token)
+ cls_token_id = self.original_tokenizer.cls_token_id
+ sep_token_id = self.original_tokenizer.sep_token_id
+
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{cls}:0 $A:0 {sep}:0",
+ pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1", # MPNet uses two [SEP] tokens
+ special_tokens=[
+ (cls, cls_token_id),
+ (sep, sep_token_id),
+ ],
+ )
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
+
+ return tokenizer
+
+
+class OpenAIGPTConverter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.encoder
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
+ unk_token = self.original_tokenizer.unk_token
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab=vocab,
+ merges=merges,
+ dropout=None,
+ unk_token=str(unk_token),
+ end_of_word_suffix="</w>",
+ fuse_unk=False,
+ )
+ )
+
+ if tokenizer.token_to_id(str(unk_token)) is not None:
+ tokenizer.add_special_tokens([str(unk_token)])
+
+ tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+ tokenizer.decoder = decoders.BPEDecoder(suffix="</w>")
+
+ return tokenizer
+
+
+class GPT2Converter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.encoder
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab=vocab,
+ merges=merges,
+ dropout=None,
+ continuing_subword_prefix="",
+ end_of_word_suffix="",
+ fuse_unk=False,
+ )
+ )
+
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
+ tokenizer.decoder = decoders.ByteLevel()
+ if self.original_tokenizer.add_bos_token:
+ bos = self.original_tokenizer.bos_token
+ bos_token_id = self.original_tokenizer.bos_token_id
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{bos}:0 $A:0",
+ pair=f"{bos}:0 $A:0 $B:1",
+ special_tokens=[
+ (bos, bos_token_id),
+ ],
+ )
+ else:
+ # XXX trim_offsets=False actually means this post_processor doesn't
+ # really do anything.
+ tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
+ return tokenizer
+
+
+class HerbertConverter(Converter):
+ def converted(self) -> Tokenizer:
+ tokenizer_info_str = "#version:"
+ token_suffix = "</w>"
+
+ vocab = self.original_tokenizer.encoder
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
+ if tokenizer_info_str in merges[0][0]:
+ merges = merges[1:]
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab,
+ merges,
+ dropout=None,
+ unk_token=self.original_tokenizer.unk_token,
+ end_of_word_suffix=token_suffix,
+ )
+ )
+
+ tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+ tokenizer.decoder = decoders.BPEDecoder(suffix=token_suffix)
+ tokenizer.post_processor = processors.BertProcessing(
+ sep=(self.original_tokenizer.sep_token, self.original_tokenizer.sep_token_id),
+ cls=(self.original_tokenizer.cls_token, self.original_tokenizer.cls_token_id),
+ )
+
+ return tokenizer
+
+
+class RobertaConverter(Converter):
+ def converted(self) -> Tokenizer:
+ ot = self.original_tokenizer
+ vocab = ot.encoder
+ merges = list(ot.bpe_ranks.keys())
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab=vocab,
+ merges=merges,
+ dropout=None,
+ continuing_subword_prefix="",
+ end_of_word_suffix="",
+ fuse_unk=False,
+ )
+ )
+
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
+ tokenizer.decoder = decoders.ByteLevel()
+ tokenizer.post_processor = processors.RobertaProcessing(
+ sep=(ot.sep_token, ot.sep_token_id),
+ cls=(ot.cls_token, ot.cls_token_id),
+ add_prefix_space=ot.add_prefix_space,
+ trim_offsets=True, # True by default on Roberta (historical)
+ )
+
+ return tokenizer
+
+
+class RoFormerConverter(Converter):
+ def converted(self) -> Tokenizer:
+ from transformers.models.roformer.tokenization_utils import JiebaPreTokenizer
+
+ vocab = self.original_tokenizer.vocab
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
+
+ strip_accents = False
+ do_lower_case = False
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
+
+ tokenizer.normalizer = normalizers.BertNormalizer(
+ clean_text=True,
+ handle_chinese_chars=False,
+ strip_accents=strip_accents,
+ lowercase=do_lower_case,
+ )
+ tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))
+
+ cls = str(self.original_tokenizer.cls_token)
+ sep = str(self.original_tokenizer.sep_token)
+ cls_token_id = self.original_tokenizer.cls_token_id
+ sep_token_id = self.original_tokenizer.sep_token_id
+
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{cls}:0 $A:0 {sep}:0",
+ pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
+ special_tokens=[
+ (cls, cls_token_id),
+ (sep, sep_token_id),
+ ],
+ )
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
+
+ return tokenizer
+
+
+class DebertaConverter(Converter):
+ def converted(self) -> Tokenizer:
+ ot = self.original_tokenizer
+ vocab = ot.encoder
+ merges = list(ot.bpe_ranks.keys())
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab=vocab,
+ merges=merges,
+ dropout=None,
+ continuing_subword_prefix="",
+ end_of_word_suffix="",
+ fuse_unk=False,
+ )
+ )
+
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
+ tokenizer.decoder = decoders.ByteLevel()
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single="[CLS]:0 $A:0 [SEP]:0",
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
+ special_tokens=[
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
+ ],
+ )
+
+ return tokenizer
+
+
+class SpmConverter(Converter):
+ def __init__(self, *args):
+ requires_backends(self, "protobuf")
+
+ super().__init__(*args)
+
+ # from .utils import sentencepiece_model_pb2 as model_pb2
+ model_pb2 = import_protobuf()
+
+ m = model_pb2.ModelProto()
+ with open(self.original_tokenizer.vocab_file, "rb") as f:
+ m.ParseFromString(f.read())
+ self.proto = m
+
+ if self.proto.trainer_spec.byte_fallback:
+ if not getattr(self, "handle_byte_fallback", None):
+ warnings.warn(
+ "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
+ " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
+ " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
+ "unknown tokens into a sequence of byte tokens matching the original piece of text."
+ )
+
+ def vocab(self, proto):
+ return [(piece.piece, piece.score) for piece in proto.pieces]
+
+ def unk_id(self, proto):
+ return proto.trainer_spec.unk_id
+
+ def tokenizer(self, proto):
+ model_type = proto.trainer_spec.model_type
+ vocab_scores = self.vocab(proto)
+ unk_id = self.unk_id(proto)
+
+ if model_type == 1:
+ tokenizer = Tokenizer(Unigram(vocab_scores, unk_id))
+ elif model_type == 2:
+ _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
+ bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
+ tokenizer = Tokenizer(
+ BPE(
+ bpe_vocab,
+ merges,
+ unk_token=proto.trainer_spec.unk_piece,
+ fuse_unk=True,
+ )
+ )
+ else:
+ raise Exception(
+ "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
+ )
+
+ return tokenizer
+
+ def normalizer(self, proto):
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
+ if not precompiled_charsmap:
+ return normalizers.Sequence([normalizers.Replace(Regex(" {2,}"), " ")])
+ else:
+ return normalizers.Sequence(
+ [normalizers.Precompiled(precompiled_charsmap), normalizers.Replace(Regex(" {2,}"), " ")]
+ )
+
+ def pre_tokenizer(self, replacement, add_prefix_space):
+ return pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
+
+ def post_processor(self):
+ return None
+
+ def decoder(self, replacement, add_prefix_space):
+ return decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
+
+ def converted(self) -> Tokenizer:
+ tokenizer = self.tokenizer(self.proto)
+
+ # Tokenizer assemble
+ normalizer = self.normalizer(self.proto)
+ if normalizer is not None:
+ tokenizer.normalizer = normalizer
+
+ replacement = "▁"
+ add_prefix_space = True
+ pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
+ if pre_tokenizer is not None:
+ tokenizer.pre_tokenizer = pre_tokenizer
+
+ tokenizer.decoder = self.decoder(replacement, add_prefix_space)
+ post_processor = self.post_processor()
+ if post_processor:
+ tokenizer.post_processor = post_processor
+
+ return tokenizer
+
+
+class AlbertConverter(SpmConverter):
+ def vocab(self, proto):
+ return [
+ (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
+ for piece in proto.pieces
+ ]
+
+ def normalizer(self, proto):
+ list_normalizers = [
+ normalizers.Replace("``", '"'),
+ normalizers.Replace("''", '"'),
+ ]
+ if not self.original_tokenizer.keep_accents:
+ list_normalizers.append(normalizers.NFKD())
+ list_normalizers.append(normalizers.StripAccents())
+ if self.original_tokenizer.do_lower_case:
+ list_normalizers.append(normalizers.Lowercase())
+
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
+
+ if precompiled_charsmap:
+ list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
+
+ list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
+ return normalizers.Sequence(list_normalizers)
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="[CLS]:0 $A:0 [SEP]:0",
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
+ special_tokens=[
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
+ ],
+ )
+
+
+class BarthezConverter(SpmConverter):
+ def unk_id(self, proto):
+ unk_id = 3
+ return unk_id
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="<s> $A </s>",
+ pair="<s> $A </s> </s> $B </s>",
+ special_tokens=[
+ ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
+ ],
+ )
+
+
+class CamembertConverter(SpmConverter):
+ def vocab(self, proto):
+ vocab = [
+ ("NOTUSED", 0.0),
+ ("", 0.0),
+ ("NOTUSED", 0.0),
+ ("", 0.0),
+ ("NOTUSED", -100),
+ ]
+ # We down-grade the original SentencePiece by -100 to avoid using it and use our added token instead
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[1:]]
+ vocab += [("", 0.0)]
+ return vocab
+
+ def unk_id(self, proto):
+ # See vocab unk position
+ return 3
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="<s> $A </s>",
+ pair="<s> $A </s> </s> $B </s>",
+ special_tokens=[
+ ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
+ ],
+ )
+
+
+class DebertaV2Converter(SpmConverter):
+ def pre_tokenizer(self, replacement, add_prefix_space):
+ list_pretokenizers = []
+ if self.original_tokenizer.split_by_punct:
+ list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
+ list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space))
+ return pre_tokenizers.Sequence(list_pretokenizers)
+
+ def normalizer(self, proto):
+ list_normalizers = []
+ if self.original_tokenizer.do_lower_case:
+ list_normalizers.append(normalizers.Lowercase())
+ list_normalizers.append(normalizers.Strip())
+
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
+ if precompiled_charsmap:
+ list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
+ list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
+
+ return normalizers.Sequence(list_normalizers)
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="[CLS]:0 $A:0 [SEP]:0",
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
+ special_tokens=[
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
+ ],
+ )
+
+
+class MBartConverter(SpmConverter):
+ def vocab(self, proto):
+ vocab = [
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ]
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+ vocab += [
+ ("ar_AR", 0.0),
+ ("cs_CZ", 0.0),
+ ("de_DE", 0.0),
+ ("en_XX", 0.0),
+ ("es_XX", 0.0),
+ ("et_EE", 0.0),
+ ("fi_FI", 0.0),
+ ("fr_XX", 0.0),
+ ("gu_IN", 0.0),
+ ("hi_IN", 0.0),
+ ("it_IT", 0.0),
+ ("ja_XX", 0.0),
+ ("kk_KZ", 0.0),
+ ("ko_KR", 0.0),
+ ("lt_LT", 0.0),
+ ("lv_LV", 0.0),
+ ("my_MM", 0.0),
+ ("ne_NP", 0.0),
+ ("nl_XX", 0.0),
+ ("ro_RO", 0.0),
+ ("ru_RU", 0.0),
+ ("si_LK", 0.0),
+ ("tr_TR", 0.0),
+ ("vi_VN", 0.0),
+ ("zh_CN", 0.0),
+ ]
+ vocab += [("", 0.0)]
+ return vocab
+
+ def unk_id(self, proto):
+ return 3
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="$A en_XX",
+ pair="$A $B en_XX",
+ special_tokens=[
+ ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
+ ("", self.original_tokenizer.convert_tokens_to_ids("")),
+ ],
+ )
+
+
+class MBart50Converter(SpmConverter):
+ def vocab(self, proto):
+ vocab = [
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ]
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+ # fmt: off
+ vocab += [("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0), ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0), ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0), ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0), ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0), ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0), ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0), ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)]
+ # fmt: on
+ vocab += [("", 0.0)]
+ return vocab
+
+ def unk_id(self, proto):
+ return 3
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="en_XX $A ",
+ pair="en_XX $A $B ",
+ special_tokens=[
+ ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
+ ("", self.original_tokenizer.convert_tokens_to_ids("")),
+ ],
+ )
+
+
+class NllbConverter(SpmConverter):
+ def vocab(self, proto):
+ vocab = [
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ]
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+ vocab += [
+ # fmt: off
+ ('ace_Arab', 0.0), ('ace_Latn', 0.0), ('acm_Arab', 0.0), ('acq_Arab', 0.0), ('aeb_Arab', 0.0), ('afr_Latn', 0.0), ('ajp_Arab', 0.0), ('aka_Latn', 0.0), ('amh_Ethi', 0.0), ('apc_Arab', 0.0), ('arb_Arab', 0.0), ('ars_Arab', 0.0), ('ary_Arab', 0.0), ('arz_Arab', 0.0), ('asm_Beng', 0.0), ('ast_Latn', 0.0), ('awa_Deva', 0.0), ('ayr_Latn', 0.0), ('azb_Arab', 0.0), ('azj_Latn', 0.0), ('bak_Cyrl', 0.0), ('bam_Latn', 0.0), ('ban_Latn', 0.0), ('bel_Cyrl', 0.0), ('bem_Latn', 0.0), ('ben_Beng', 0.0), ('bho_Deva', 0.0), ('bjn_Arab', 0.0), ('bjn_Latn', 0.0), ('bod_Tibt', 0.0), ('bos_Latn', 0.0), ('bug_Latn', 0.0), ('bul_Cyrl', 0.0), ('cat_Latn', 0.0), ('ceb_Latn', 0.0), ('ces_Latn', 0.0), ('cjk_Latn', 0.0), ('ckb_Arab', 0.0), ('crh_Latn', 0.0), ('cym_Latn', 0.0), ('dan_Latn', 0.0), ('deu_Latn', 0.0), ('dik_Latn', 0.0), ('dyu_Latn', 0.0), ('dzo_Tibt', 0.0), ('ell_Grek', 0.0), ('eng_Latn', 0.0), ('epo_Latn', 0.0), ('est_Latn', 0.0), ('eus_Latn', 0.0), ('ewe_Latn', 0.0), ('fao_Latn', 0.0), ('pes_Arab', 0.0), ('fij_Latn', 0.0), ('fin_Latn', 0.0), ('fon_Latn', 0.0), ('fra_Latn', 0.0), ('fur_Latn', 0.0), ('fuv_Latn', 0.0), ('gla_Latn', 0.0), ('gle_Latn', 0.0), ('glg_Latn', 0.0), ('grn_Latn', 0.0), ('guj_Gujr', 0.0), ('hat_Latn', 0.0), ('hau_Latn', 0.0), ('heb_Hebr', 0.0), ('hin_Deva', 0.0), ('hne_Deva', 0.0), ('hrv_Latn', 0.0), ('hun_Latn', 0.0), ('hye_Armn', 0.0), ('ibo_Latn', 0.0), ('ilo_Latn', 0.0), ('ind_Latn', 0.0), ('isl_Latn', 0.0), ('ita_Latn', 0.0), ('jav_Latn', 0.0), ('jpn_Jpan', 0.0), ('kab_Latn', 0.0), ('kac_Latn', 0.0), ('kam_Latn', 0.0), ('kan_Knda', 0.0), ('kas_Arab', 0.0), ('kas_Deva', 0.0), ('kat_Geor', 0.0), ('knc_Arab', 0.0), ('knc_Latn', 0.0), ('kaz_Cyrl', 0.0), ('kbp_Latn', 0.0), ('kea_Latn', 0.0), ('khm_Khmr', 0.0), ('kik_Latn', 0.0), ('kin_Latn', 0.0), ('kir_Cyrl', 0.0), ('kmb_Latn', 0.0), ('kon_Latn', 0.0), ('kor_Hang', 0.0), ('kmr_Latn', 0.0), ('lao_Laoo', 0.0), ('lvs_Latn', 0.0), ('lij_Latn', 0.0), ('lim_Latn', 0.0), ('lin_Latn', 0.0), ('lit_Latn', 0.0), ('lmo_Latn', 0.0), ('ltg_Latn', 0.0), ('ltz_Latn', 0.0), ('lua_Latn', 0.0), ('lug_Latn', 0.0), ('luo_Latn', 0.0), ('lus_Latn', 0.0), ('mag_Deva', 0.0), ('mai_Deva', 0.0), ('mal_Mlym', 0.0), ('mar_Deva', 0.0), ('min_Latn', 0.0), ('mkd_Cyrl', 0.0), ('plt_Latn', 0.0), ('mlt_Latn', 0.0), ('mni_Beng', 0.0), ('khk_Cyrl', 0.0), ('mos_Latn', 0.0), ('mri_Latn', 0.0), ('zsm_Latn', 0.0), ('mya_Mymr', 0.0), ('nld_Latn', 0.0), ('nno_Latn', 0.0), ('nob_Latn', 0.0), ('npi_Deva', 0.0), ('nso_Latn', 0.0), ('nus_Latn', 0.0), ('nya_Latn', 0.0), ('oci_Latn', 0.0), ('gaz_Latn', 0.0), ('ory_Orya', 0.0), ('pag_Latn', 0.0), ('pan_Guru', 0.0), ('pap_Latn', 0.0), ('pol_Latn', 0.0), ('por_Latn', 0.0), ('prs_Arab', 0.0), ('pbt_Arab', 0.0), ('quy_Latn', 0.0), ('ron_Latn', 0.0), ('run_Latn', 0.0), ('rus_Cyrl', 0.0), ('sag_Latn', 0.0), ('san_Deva', 0.0), ('sat_Beng', 0.0), ('scn_Latn', 0.0), ('shn_Mymr', 0.0), ('sin_Sinh', 0.0), ('slk_Latn', 0.0), ('slv_Latn', 0.0), ('smo_Latn', 0.0), ('sna_Latn', 0.0), ('snd_Arab', 0.0), ('som_Latn', 0.0), ('sot_Latn', 0.0), ('spa_Latn', 0.0), ('als_Latn', 0.0), ('srd_Latn', 0.0), ('srp_Cyrl', 0.0), ('ssw_Latn', 0.0), ('sun_Latn', 0.0), ('swe_Latn', 0.0), ('swh_Latn', 0.0), ('szl_Latn', 0.0), ('tam_Taml', 0.0), ('tat_Cyrl', 0.0), ('tel_Telu', 0.0), ('tgk_Cyrl', 0.0), ('tgl_Latn', 0.0), ('tha_Thai', 0.0), ('tir_Ethi', 0.0), ('taq_Latn', 0.0), ('taq_Tfng', 0.0), ('tpi_Latn', 0.0), ('tsn_Latn', 0.0), ('tso_Latn', 0.0), ('tuk_Latn', 0.0), ('tum_Latn', 0.0), ('tur_Latn', 0.0), ('twi_Latn', 0.0), ('tzm_Tfng', 0.0), ('uig_Arab', 0.0), ('ukr_Cyrl', 0.0), ('umb_Latn', 0.0), ('urd_Arab', 0.0), ('uzn_Latn', 0.0), ('vec_Latn', 0.0), ('vie_Latn', 0.0), ('war_Latn', 0.0), ('wol_Latn', 0.0), ('xho_Latn', 0.0), ('ydd_Hebr', 0.0), ('yor_Latn', 0.0), ('yue_Hant', 0.0), ('zho_Hans', 0.0), ('zho_Hant', 0.0), ('zul_Latn', 0.0)
+ # fmt: on
+ ]
+ vocab += [("", 0.0)]
+ return vocab
+
+ def unk_id(self, proto):
+ return 3
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="eng_Latn $A ",
+ pair="eng_Latn $A $B ",
+ special_tokens=[
+ ("eng_Latn", self.original_tokenizer.convert_tokens_to_ids("eng_Latn")),
+ ("", self.original_tokenizer.convert_tokens_to_ids("")),
+ ],
+ )
+
+
+class SeamlessM4TConverter(SpmConverter):
+ def vocab(self, proto):
+ vocab = [
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ]
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+ return vocab
+
+ def unk_id(self, proto):
+ return self.original_tokenizer.unk_token_id
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="__eng__ $A ",
+ pair="__eng__ $A $B ",
+ special_tokens=[
+ ("__eng__", self.original_tokenizer.convert_tokens_to_ids("__eng__")),
+ ("", self.original_tokenizer.convert_tokens_to_ids("")),
+ ],
+ )
+
+
+class XLMRobertaConverter(SpmConverter):
+ def vocab(self, proto):
+ vocab = [
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ]
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+ vocab += [("", 0.0)]
+ return vocab
+
+ def unk_id(self, proto):
+ unk_id = 3
+ return unk_id
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="<s> $A </s>",
+ pair="<s> $A </s> </s> $B </s>",
+ special_tokens=[
+ ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
+ ],
+ )
+
+
+class XLNetConverter(SpmConverter):
+ def vocab(self, proto):
+ return [
+ (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
+ for piece in proto.pieces
+ ]
+
+ def normalizer(self, proto):
+ list_normalizers = [
+ normalizers.Replace("``", '"'),
+ normalizers.Replace("''", '"'),
+ ]
+ if not self.original_tokenizer.keep_accents:
+ list_normalizers.append(normalizers.NFKD())
+ list_normalizers.append(normalizers.StripAccents())
+ if self.original_tokenizer.do_lower_case:
+ list_normalizers.append(normalizers.Lowercase())
+
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
+
+ if precompiled_charsmap:
+ list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
+
+ list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
+ return normalizers.Sequence(list_normalizers)
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="$A:0 :0 :2",
+ pair="$A:0 :0 $B:1 :1 :2",
+ special_tokens=[
+ ("", self.original_tokenizer.convert_tokens_to_ids("")),
+ ("", self.original_tokenizer.convert_tokens_to_ids("")),
+ ],
+ )
+
+
+class ReformerConverter(SpmConverter):
+ pass
+
+
+class RemBertConverter(SpmConverter):
+ # Inspired from AlbertConverter
+ def normalizer(self, proto):
+ list_normalizers = [
+ normalizers.Replace("``", '"'),
+ normalizers.Replace("''", '"'),
+ normalizers.Replace(Regex(" {2,}"), " "),
+ ]
+ if not self.original_tokenizer.keep_accents:
+ list_normalizers.append(normalizers.NFKD())
+ list_normalizers.append(normalizers.StripAccents())
+ if self.original_tokenizer.do_lower_case:
+ list_normalizers.append(normalizers.Lowercase())
+
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
+
+ if precompiled_charsmap:
+ list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
+
+ return normalizers.Sequence(list_normalizers)
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="[CLS]:0 $A:0 [SEP]:0",
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
+ special_tokens=[
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
+ ],
+ )
+
+
+class BertGenerationConverter(SpmConverter):
+ pass
+
+
+class PegasusConverter(SpmConverter):
+ def vocab(self, proto):
+ vocab = [
+ (self.original_tokenizer.pad_token, 0.0),
+ (self.original_tokenizer.eos_token, 0.0),
+ ]
+
+ if self.original_tokenizer.mask_token_sent is not None:
+ vocab += [(self.original_tokenizer.mask_token_sent, 0.0)]
+
+ if (
+ self.original_tokenizer.mask_token is not None
+ and self.original_tokenizer.mask_token_id < self.original_tokenizer.offset
+ ):
+ vocab += [(self.original_tokenizer.mask_token, 0.0)]
+
+ vocab += [(f"", -100.0) for i in range(2, self.original_tokenizer.offset)]
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
+ return vocab
+
+ def unk_id(self, proto):
+ return proto.trainer_spec.unk_id + self.original_tokenizer.offset
+
+ def pre_tokenizer(self, replacement, add_prefix_space):
+ return pre_tokenizers.Sequence(
+ [
+ pre_tokenizers.WhitespaceSplit(),
+ pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
+ ]
+ )
+
+ def post_processor(self):
+ eos = self.original_tokenizer.eos_token
+ special_tokens = [
+ (eos, self.original_tokenizer.eos_token_id),
+ ]
+ return processors.TemplateProcessing(single=["$A", eos], pair=["$A", "$B", eos], special_tokens=special_tokens)
+
+
+class T5Converter(SpmConverter):
+ def vocab(self, proto):
+ num_extra_ids = self.original_tokenizer._extra_ids
+ vocab = [(piece.piece, piece.score) for piece in proto.pieces]
+ vocab += [(f"", 0.0) for i in range(num_extra_ids - 1, -1, -1)]
+ return vocab
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single=["$A", ""],
+ pair=["$A", "", "$B", ""],
+ special_tokens=[
+ ("", self.original_tokenizer.convert_tokens_to_ids("")),
+ ],
+ )
+
+
+class WhisperConverter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.encoder
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab=vocab,
+ merges=merges,
+ dropout=None,
+ continuing_subword_prefix="",
+ end_of_word_suffix="",
+ fuse_unk=False,
+ )
+ )
+
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
+ tokenizer.decoder = decoders.ByteLevel()
+
+ prefix_token_ids = self.original_tokenizer.prefix_tokens
+ prefixes = self.original_tokenizer.convert_ids_to_tokens(prefix_token_ids)
+ eos = self.original_tokenizer.eos_token
+ eos_token_id = self.original_tokenizer.eos_token_id
+ prefix_template = " ".join([f"{token}:0" for token in prefixes])
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{prefix_template} $A:0 {eos}:0",
+ pair=f"{prefix_template} $A:0 $B:1 {eos}:1",
+ special_tokens=[
+ (eos, eos_token_id),
+ *zip(prefixes, prefix_token_ids),
+ ],
+ )
+
+ return tokenizer
+
+
+class BigBirdConverter(SpmConverter):
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="[CLS]:0 $A:0 [SEP]:0",
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
+ special_tokens=[
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
+ ],
+ )
+
+
+class CLIPConverter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.encoder
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
+ unk_token = self.original_tokenizer.unk_token
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab=vocab,
+ merges=merges,
+ dropout=None,
+ continuing_subword_prefix="",
+ end_of_word_suffix="",
+ fuse_unk=False,
+ unk_token=str(unk_token),
+ )
+ )
+
+ tokenizer.normalizer = normalizers.Sequence(
+ [normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()]
+ )
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
+ [
+ pre_tokenizers.Split(
+ Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""),
+ behavior="removed",
+ invert=True,
+ ),
+ pre_tokenizers.ByteLevel(add_prefix_space=False),
+ ]
+ )
+ tokenizer.decoder = decoders.ByteLevel()
+
+ # Hack to get both ByteLevel and TemplateProcessing-style post-processing
+ tokenizer.post_processor = processors.RobertaProcessing(
+ sep=(self.original_tokenizer.eos_token, self.original_tokenizer.eos_token_id),
+ cls=(self.original_tokenizer.bos_token, self.original_tokenizer.bos_token_id),
+ add_prefix_space=False,
+ trim_offsets=False,
+ )
+ return tokenizer
+
+
+class LayoutLMv2Converter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.vocab
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
+
+ tokenize_chinese_chars = False
+ strip_accents = False
+ do_lower_case = True
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
+
+ tokenizer.normalizer = normalizers.BertNormalizer(
+ clean_text=True,
+ handle_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ lowercase=do_lower_case,
+ )
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+
+ cls = str(self.original_tokenizer.cls_token)
+ sep = str(self.original_tokenizer.sep_token)
+ cls_token_id = self.original_tokenizer.cls_token_id
+ sep_token_id = self.original_tokenizer.sep_token_id
+
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{cls}:0 $A:0 {sep}:0",
+ pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
+ special_tokens=[
+ (cls, cls_token_id),
+ (sep, sep_token_id),
+ ],
+ )
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
+
+ return tokenizer
+
+
+class BlenderbotConverter(Converter):
+ def converted(self) -> Tokenizer:
+ ot = self.original_tokenizer
+ vocab = ot.encoder
+ merges = list(ot.bpe_ranks.keys())
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab=vocab,
+ merges=merges,
+ dropout=None,
+ continuing_subword_prefix="",
+ end_of_word_suffix="",
+ fuse_unk=False,
+ )
+ )
+
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
+ tokenizer.decoder = decoders.ByteLevel()
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"$A:0 {ot.eos_token}:0",
+ special_tokens=[
+ (ot.eos_token, ot.eos_token_id),
+ ],
+ )
+
+ return tokenizer
+
+
+class XGLMConverter(SpmConverter):
+ def vocab(self, proto):
+ vocab = [
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ]
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+ # fmt: off
+ vocab += [("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0)]
+ # fmt: on
+ return vocab
+
+ def unk_id(self, proto):
+ unk_id = 3
+ return unk_id
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="</s> $A",
+ pair="</s> $A </s> </s> $B",
+ special_tokens=[
+ ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
+ ],
+ )
+
+
+class LlamaConverter(SpmConverter):
+ handle_byte_fallback = True
+
+ def vocab(self, proto):
+ vocab = [
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ]
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+ return vocab
+
+ def unk_id(self, proto):
+ unk_id = 0
+ return unk_id
+
+ def decoder(self, replacement, add_prefix_space):
+ return decoders.Sequence(
+ [
+ decoders.Replace("▁", " "),
+ decoders.ByteFallback(),
+ decoders.Fuse(),
+ decoders.Strip(content=" ", left=1),
+ ]
+ )
+
+ def tokenizer(self, proto):
+ model_type = proto.trainer_spec.model_type
+ vocab_scores = self.vocab(proto)
+ if model_type == 1:
+ import tokenizers
+
+ if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
+ tokenizer = Tokenizer(Unigram(vocab_scores, 0))
+ else:
+ tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))
+
+ elif model_type == 2:
+ _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
+ bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
+ tokenizer = Tokenizer(
+ BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
+ )
+ tokenizer.add_special_tokens(
+ [
+ AddedToken("", normalized=False, special=True),
+ AddedToken("", normalized=False, special=True),
+ AddedToken("", normalized=False, special=True),
+ ]
+ )
+ else:
+ raise Exception(
+ "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
+ )
+
+ return tokenizer
+
+ def normalizer(self, proto):
+ return normalizers.Sequence(
+ [
+ normalizers.Prepend(prepend="▁"),
+ normalizers.Replace(pattern=" ", content="▁"),
+ ]
+ )
+
+ def pre_tokenizer(self, replacement, add_prefix_space):
+ return None
+
+ def post_processor(self):
+ # the processor is defined in the LlamaTokenizerFast class.
+ return None
+
+
+class MarkupLMConverter(Converter):
+ def converted(self) -> Tokenizer:
+ ot = self.original_tokenizer
+ vocab = ot.encoder
+ merges = list(ot.bpe_ranks.keys())
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab=vocab,
+ merges=merges,
+ dropout=None,
+ continuing_subword_prefix="",
+ end_of_word_suffix="",
+ fuse_unk=False,
+ unk_token=self.original_tokenizer.unk_token,
+ )
+ )
+
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
+ tokenizer.decoder = decoders.ByteLevel()
+
+ cls = str(self.original_tokenizer.cls_token)
+ sep = str(self.original_tokenizer.sep_token)
+ cls_token_id = self.original_tokenizer.cls_token_id
+ sep_token_id = self.original_tokenizer.sep_token_id
+
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{cls} $A {sep}",
+ pair=f"{cls} $A {sep} $B {sep}",
+ special_tokens=[
+ (cls, cls_token_id),
+ (sep, sep_token_id),
+ ],
+ )
+
+ return tokenizer
+
+class MarianConverter(SpmConverter):
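+ # Marian checkpoints ship one SentencePiece model per language (`spm_files`) plus a
+ # shared `vocab.json`; `index` selects the source (0) or target (1) model, and the
+ # pieces are re-indexed against `vocab.json` in `vocab` below.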
+ def __init__(self, *args, index: int = 0):
+ requires_backends(self, "protobuf")
+
+ super(SpmConverter, self).__init__(*args)
+
+ # from .utils import sentencepiece_model_pb2 as model_pb2
+ model_pb2 = import_protobuf()
+
+ m = model_pb2.ModelProto()
+ print(self.original_tokenizer.spm_files)
+ with open(self.original_tokenizer.spm_files[index], "rb") as f:
+ m.ParseFromString(f.read())
+ self.proto = m
+ print(self.original_tokenizer)
+ #with open(self.original_tokenizer.vocab_path, "r") as f:
+ dir_path = Path(self.original_tokenizer.spm_files[0]).parents[0]
+ with open(dir_path / "vocab.json", "r") as f:
+ import json
+ self._vocab = json.load(f)
+
+ if self.proto.trainer_spec.byte_fallback:
+ if not getattr(self, "handle_byte_fallback", None):
+ warnings.warn(
+ "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
+ " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
+ " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
+ "unknown tokens into a sequence of byte tokens matching the original piece of text."
+ )
+
+ def vocab(self, proto):
+ vocab_size = max(self._vocab.values()) + 1
+ vocab = [("", -100) for _ in range(vocab_size)]
+ for piece in proto.pieces:
+ try:
+ index = self._vocab[piece.piece]
+ except Exception:
+ print(f"Ignored missing piece {piece.piece}")
+ continue
+ vocab[index] = (piece.piece, piece.score)
+ return vocab
+
+SLOW_TO_FAST_CONVERTERS = {
+ "AlbertTokenizer": AlbertConverter,
+ "BartTokenizer": RobertaConverter,
+ "BarthezTokenizer": BarthezConverter,
+ "BertTokenizer": BertConverter,
+ "BigBirdTokenizer": BigBirdConverter,
+ "BlenderbotTokenizer": BlenderbotConverter,
+ "CamembertTokenizer": CamembertConverter,
+ "CLIPTokenizer": CLIPConverter,
+ "CodeGenTokenizer": GPT2Converter,
+ "ConvBertTokenizer": BertConverter,
+ "DebertaTokenizer": DebertaConverter,
+ "DebertaV2Tokenizer": DebertaV2Converter,
+ "DistilBertTokenizer": BertConverter,
+ "DPRReaderTokenizer": BertConverter,
+ "DPRQuestionEncoderTokenizer": BertConverter,
+ "DPRContextEncoderTokenizer": BertConverter,
+ "ElectraTokenizer": BertConverter,
+ "FNetTokenizer": AlbertConverter,
+ "FunnelTokenizer": FunnelConverter,
+ "GPT2Tokenizer": GPT2Converter,
+ "HerbertTokenizer": HerbertConverter,
+ "LayoutLMTokenizer": BertConverter,
+ "LayoutLMv2Tokenizer": BertConverter,
+ "LayoutLMv3Tokenizer": RobertaConverter,
+ "LayoutXLMTokenizer": XLMRobertaConverter,
+ "LongformerTokenizer": RobertaConverter,
+ "LEDTokenizer": RobertaConverter,
+ "LxmertTokenizer": BertConverter,
+ "MarkupLMTokenizer": MarkupLMConverter,
+ "MBartTokenizer": MBartConverter,
+ "MBart50Tokenizer": MBart50Converter,
+ "MPNetTokenizer": MPNetConverter,
+ "MobileBertTokenizer": BertConverter,
+ "MvpTokenizer": RobertaConverter,
+ "NllbTokenizer": NllbConverter,
+ "OpenAIGPTTokenizer": OpenAIGPTConverter,
+ "PegasusTokenizer": PegasusConverter,
+ "RealmTokenizer": BertConverter,
+ "ReformerTokenizer": ReformerConverter,
+ "RemBertTokenizer": RemBertConverter,
+ "RetriBertTokenizer": BertConverter,
+ "RobertaTokenizer": RobertaConverter,
+ "RoFormerTokenizer": RoFormerConverter,
+ "SeamlessM4TTokenizer": SeamlessM4TConverter,
+ "SqueezeBertTokenizer": BertConverter,
+ "T5Tokenizer": T5Converter,
+ "WhisperTokenizer": WhisperConverter,
+ "XLMRobertaTokenizer": XLMRobertaConverter,
+ "XLNetTokenizer": XLNetConverter,
+ "SplinterTokenizer": SplinterConverter,
+ "XGLMTokenizer": XGLMConverter,
+ "LlamaTokenizer": LlamaConverter,
+ "CodeLlamaTokenizer": LlamaConverter,
+}
+
+
+def convert_slow_tokenizer(transformer_tokenizer) -> Tokenizer:
+ """
+ Utilities to convert a slow tokenizer instance into a fast tokenizer instance.
+
+ Args:
+ transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
+ Instance of a slow tokenizer to convert into the backend tokenizer for
+ [`~tokenization_utils_base.PreTrainedTokenizerFast`].
+
+ Return:
+ An instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
+ [`~tokenization_utils_base.PreTrainedTokenizerFast`]
+ """
+
+ tokenizer_class_name = transformer_tokenizer.__class__.__name__
+
+ if tokenizer_class_name not in SLOW_TO_FAST_CONVERTERS:
+ raise ValueError(
+ f"An instance of tokenizer class {tokenizer_class_name} cannot be converted in a Fast tokenizer instance."
+ " No converter was found. Currently available slow->fast convertors:"
+ f" {list(SLOW_TO_FAST_CONVERTERS.keys())}"
+ )
+
+ converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
+
+ return converter_class(transformer_tokenizer).converted()