diff --git a/candle-examples/examples/marian-mt/README.md b/candle-examples/examples/marian-mt/README.md
index eecaee32..8ebd7f34 100644
--- a/candle-examples/examples/marian-mt/README.md
+++ b/candle-examples/examples/marian-mt/README.md
@@ -18,21 +18,19 @@ I know you are waiting for me. I will go through the forest, I will go through t
mountain. I cannot stay far from you any longer.
```
+### Changing model and language pairs
+
+```bash
+$ cargo run --example marian-mt --release -- --text "hello, how are you." --which base --language-pair en-zh
+
+你好,你好吗?
+```
+
## Generating the tokenizer.json files
-You can use the following script to generate the `tokenizer.json` config files
-from the hf-hub repos. This requires the `tokenizers` and `sentencepiece`
-packages to be install and use the `convert_slow_tokenizer.py` script from this
-directory.
-
-```python
-from convert_slow_tokenizer import MarianConverter
-from transformers import AutoTokenizer
-
-
-tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
-fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
-fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
-fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
-fast_tokenizer.save(f"tokenizer-marian-base-en.json")
-```
+The tokenizer for each `marian-mt` model was trained independently,
+meaning each new model needs unique tokenizer encoders and decoders.
+You can use the `./python/convert_slow_tokenizer.py` script in this directory to generate
+the `tokenizer.json` config files from the hf-hub repos.
+The script requires all the packages in `./python/requirements.txt` or `./python/uv.lock`
+to be installed, and has only been tested for `python 3.12.7`.
diff --git a/candle-examples/examples/marian-mt/convert_slow_tokenizer.py b/candle-examples/examples/marian-mt/convert_slow_tokenizer.py
deleted file mode 100644
index 33a887b6..00000000
--- a/candle-examples/examples/marian-mt/convert_slow_tokenizer.py
+++ /dev/null
@@ -1,1397 +0,0 @@
-# coding=utf-8
-# Copyright 2018 The HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Utilities to convert slow tokenizers in their fast tokenizers counterparts.
-
-All the conversions are grouped here to gather SentencePiece dependencies outside of the fast tokenizers files and
-allow to make our dependency on SentencePiece optional.
-"""
-
-import warnings
-from typing import Dict, List, Tuple
-
-from packaging import version
-from pathlib import Path
-from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
-from tokenizers.models import BPE, Unigram, WordPiece
-
-from transformers.utils import is_protobuf_available, requires_backends
-from transformers.utils.import_utils import PROTOBUF_IMPORT_ERROR
-
-
-def import_protobuf(error_message=""):
- if is_protobuf_available():
- import google.protobuf
-
- if version.parse(google.protobuf.__version__) < version.parse("4.0.0"):
- from transformers.utils import sentencepiece_model_pb2
- else:
- from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2
- return sentencepiece_model_pb2
- else:
- raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
-
-def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
- if add_prefix_space:
- prepend_scheme = "always"
- if hasattr(original_tokenizer, "legacy") and not original_tokenizer.legacy:
- prepend_scheme = "first"
- else:
- prepend_scheme = "never"
- return prepend_scheme
-
-class SentencePieceExtractor:
- """
- Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
- """
-
- def __init__(self, model: str):
- requires_backends(self, "sentencepiece")
- from sentencepiece import SentencePieceProcessor
-
- self.sp = SentencePieceProcessor()
- self.sp.Load(model)
-
- def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
- """
- By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
- order the merges with respect to the piece scores instead.
- """
- sp = self.sp
- vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
- if vocab_scores is not None:
- vocab_scores, reverse = dict(vocab_scores), True
- else:
- vocab_scores, reverse = vocab, False
-
- # Merges
- merges = []
- for merge, piece_score in vocab_scores.items():
- local = []
- for index in range(1, len(merge)):
- piece_l, piece_r = merge[:index], merge[index:]
- if piece_l in vocab and piece_r in vocab:
- local.append((piece_l, piece_r, piece_score))
- local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
- merges.extend(local)
-
- merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
- merges = [(val[0], val[1]) for val in merges]
- return vocab, merges
-
-
-def check_number_comma(piece: str) -> bool:
- return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()
-
-
-class Converter:
- def __init__(self, original_tokenizer):
- self.original_tokenizer = original_tokenizer
-
- def converted(self) -> Tokenizer:
- raise NotImplementedError()
-
-
-class BertConverter(Converter):
- def converted(self) -> Tokenizer:
- vocab = self.original_tokenizer.vocab
- tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
-
- tokenize_chinese_chars = False
- strip_accents = False
- do_lower_case = False
- if hasattr(self.original_tokenizer, "basic_tokenizer"):
- tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
- strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
- do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
-
- tokenizer.normalizer = normalizers.BertNormalizer(
- clean_text=True,
- handle_chinese_chars=tokenize_chinese_chars,
- strip_accents=strip_accents,
- lowercase=do_lower_case,
- )
- tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
-
- cls = str(self.original_tokenizer.cls_token)
- sep = str(self.original_tokenizer.sep_token)
- cls_token_id = self.original_tokenizer.cls_token_id
- sep_token_id = self.original_tokenizer.sep_token_id
-
- tokenizer.post_processor = processors.TemplateProcessing(
- single=f"{cls}:0 $A:0 {sep}:0",
- pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
- special_tokens=[
- (cls, cls_token_id),
- (sep, sep_token_id),
- ],
- )
- tokenizer.decoder = decoders.WordPiece(prefix="##")
-
- return tokenizer
-
-
-class SplinterConverter(Converter):
- def converted(self) -> Tokenizer:
- vocab = self.original_tokenizer.vocab
- tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
-
- tokenize_chinese_chars = False
- strip_accents = False
- do_lower_case = False
- if hasattr(self.original_tokenizer, "basic_tokenizer"):
- tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
- strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
- do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
-
- tokenizer.normalizer = normalizers.BertNormalizer(
- clean_text=True,
- handle_chinese_chars=tokenize_chinese_chars,
- strip_accents=strip_accents,
- lowercase=do_lower_case,
- )
- tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
-
- cls = str(self.original_tokenizer.cls_token)
- sep = str(self.original_tokenizer.sep_token)
- question = str(self.original_tokenizer.question_token)
- dot = "."
- cls_token_id = self.original_tokenizer.cls_token_id
- sep_token_id = self.original_tokenizer.sep_token_id
- question_token_id = self.original_tokenizer.question_token_id
- dot_token_id = self.original_tokenizer.convert_tokens_to_ids(".")
-
- if self.original_tokenizer.padding_side == "right":
- pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
- else:
- pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1"
-
- tokenizer.post_processor = processors.TemplateProcessing(
- single=f"{cls}:0 $A:0 {sep}:0",
- pair=pair,
- special_tokens=[
- (cls, cls_token_id),
- (sep, sep_token_id),
- (question, question_token_id),
- (dot, dot_token_id),
- ],
- )
- tokenizer.decoder = decoders.WordPiece(prefix="##")
-
- return tokenizer
-
-
-class FunnelConverter(Converter):
- def converted(self) -> Tokenizer:
- vocab = self.original_tokenizer.vocab
- tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
-
- tokenize_chinese_chars = False
- strip_accents = False
- do_lower_case = False
- if hasattr(self.original_tokenizer, "basic_tokenizer"):
- tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
- strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
- do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
-
- tokenizer.normalizer = normalizers.BertNormalizer(
- clean_text=True,
- handle_chinese_chars=tokenize_chinese_chars,
- strip_accents=strip_accents,
- lowercase=do_lower_case,
- )
- tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
-
- cls = str(self.original_tokenizer.cls_token)
- sep = str(self.original_tokenizer.sep_token)
- cls_token_id = self.original_tokenizer.cls_token_id
- sep_token_id = self.original_tokenizer.sep_token_id
-
- tokenizer.post_processor = processors.TemplateProcessing(
- single=f"{cls}:2 $A:0 {sep}:0", # token_type_id is 2 for Funnel transformer
- pair=f"{cls}:2 $A:0 {sep}:0 $B:1 {sep}:1",
- special_tokens=[
- (cls, cls_token_id),
- (sep, sep_token_id),
- ],
- )
- tokenizer.decoder = decoders.WordPiece(prefix="##")
-
- return tokenizer
-
-
-class MPNetConverter(Converter):
- def converted(self) -> Tokenizer:
- vocab = self.original_tokenizer.vocab
- tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
-
- tokenize_chinese_chars = False
- strip_accents = False
- do_lower_case = False
- if hasattr(self.original_tokenizer, "basic_tokenizer"):
- tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
- strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
- do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
-
- tokenizer.normalizer = normalizers.BertNormalizer(
- clean_text=True,
- handle_chinese_chars=tokenize_chinese_chars,
- strip_accents=strip_accents,
- lowercase=do_lower_case,
- )
- tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
-
- cls = str(self.original_tokenizer.cls_token)
- sep = str(self.original_tokenizer.sep_token)
- cls_token_id = self.original_tokenizer.cls_token_id
- sep_token_id = self.original_tokenizer.sep_token_id
-
- tokenizer.post_processor = processors.TemplateProcessing(
- single=f"{cls}:0 $A:0 {sep}:0",
- pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1", # MPNet uses two [SEP] tokens
- special_tokens=[
- (cls, cls_token_id),
- (sep, sep_token_id),
- ],
- )
- tokenizer.decoder = decoders.WordPiece(prefix="##")
-
- return tokenizer
-
-
-class OpenAIGPTConverter(Converter):
- def converted(self) -> Tokenizer:
- vocab = self.original_tokenizer.encoder
- merges = list(self.original_tokenizer.bpe_ranks.keys())
- unk_token = self.original_tokenizer.unk_token
-
- tokenizer = Tokenizer(
- BPE(
- vocab=vocab,
- merges=merges,
- dropout=None,
- unk_token=str(unk_token),
- end_of_word_suffix="",
- fuse_unk=False,
- )
- )
-
- if tokenizer.token_to_id(str(unk_token)) is not None:
- tokenizer.add_special_tokens([str(unk_token)])
-
- tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
- tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
- tokenizer.decoder = decoders.BPEDecoder(suffix="")
-
- return tokenizer
-
-
-class GPT2Converter(Converter):
- def converted(self) -> Tokenizer:
- vocab = self.original_tokenizer.encoder
- merges = list(self.original_tokenizer.bpe_ranks.keys())
-
- tokenizer = Tokenizer(
- BPE(
- vocab=vocab,
- merges=merges,
- dropout=None,
- continuing_subword_prefix="",
- end_of_word_suffix="",
- fuse_unk=False,
- )
- )
-
- tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
- tokenizer.decoder = decoders.ByteLevel()
- if self.original_tokenizer.add_bos_token:
- bos = self.original_tokenizer.bos_token
- bos_token_id = self.original_tokenizer.bos_token_id
- tokenizer.post_processor = processors.TemplateProcessing(
- single=f"{bos}:0 $A:0",
- pair=f"{bos}:0 $A:0 $B:1",
- special_tokens=[
- (bos, bos_token_id),
- ],
- )
- else:
- # XXX trim_offsets=False actually means this post_processor doesn't
- # really do anything.
- tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
- return tokenizer
-
-
-class HerbertConverter(Converter):
- def converted(self) -> Tokenizer:
- tokenizer_info_str = "#version:"
- token_suffix = ""
-
- vocab = self.original_tokenizer.encoder
- merges = list(self.original_tokenizer.bpe_ranks.keys())
- if tokenizer_info_str in merges[0][0]:
- merges = merges[1:]
-
- tokenizer = Tokenizer(
- BPE(
- vocab,
- merges,
- dropout=None,
- unk_token=self.original_tokenizer.unk_token,
- end_of_word_suffix=token_suffix,
- )
- )
-
- tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
- tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
- tokenizer.decoder = decoders.BPEDecoder(suffix=token_suffix)
- tokenizer.post_processor = processors.BertProcessing(
- sep=(self.original_tokenizer.sep_token, self.original_tokenizer.sep_token_id),
- cls=(self.original_tokenizer.cls_token, self.original_tokenizer.cls_token_id),
- )
-
- return tokenizer
-
-
-class RobertaConverter(Converter):
- def converted(self) -> Tokenizer:
- ot = self.original_tokenizer
- vocab = ot.encoder
- merges = list(ot.bpe_ranks.keys())
-
- tokenizer = Tokenizer(
- BPE(
- vocab=vocab,
- merges=merges,
- dropout=None,
- continuing_subword_prefix="",
- end_of_word_suffix="",
- fuse_unk=False,
- )
- )
-
- tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
- tokenizer.decoder = decoders.ByteLevel()
- tokenizer.post_processor = processors.RobertaProcessing(
- sep=(ot.sep_token, ot.sep_token_id),
- cls=(ot.cls_token, ot.cls_token_id),
- add_prefix_space=ot.add_prefix_space,
- trim_offsets=True, # True by default on Roberta (historical)
- )
-
- return tokenizer
-
-
-class RoFormerConverter(Converter):
- def converted(self) -> Tokenizer:
- from .models.roformer.tokenization_utils import JiebaPreTokenizer
-
- vocab = self.original_tokenizer.vocab
- tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
-
- strip_accents = False
- do_lower_case = False
- if hasattr(self.original_tokenizer, "basic_tokenizer"):
- strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
- do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
-
- tokenizer.normalizer = normalizers.BertNormalizer(
- clean_text=True,
- handle_chinese_chars=False,
- strip_accents=strip_accents,
- lowercase=do_lower_case,
- )
- tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))
-
- cls = str(self.original_tokenizer.cls_token)
- sep = str(self.original_tokenizer.sep_token)
- cls_token_id = self.original_tokenizer.cls_token_id
- sep_token_id = self.original_tokenizer.sep_token_id
-
- tokenizer.post_processor = processors.TemplateProcessing(
- single=f"{cls}:0 $A:0 {sep}:0",
- pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
- special_tokens=[
- (cls, cls_token_id),
- (sep, sep_token_id),
- ],
- )
- tokenizer.decoder = decoders.WordPiece(prefix="##")
-
- return tokenizer
-
-
-class DebertaConverter(Converter):
- def converted(self) -> Tokenizer:
- ot = self.original_tokenizer
- vocab = ot.encoder
- merges = list(ot.bpe_ranks.keys())
-
- tokenizer = Tokenizer(
- BPE(
- vocab=vocab,
- merges=merges,
- dropout=None,
- continuing_subword_prefix="",
- end_of_word_suffix="",
- fuse_unk=False,
- )
- )
-
- tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
- tokenizer.decoder = decoders.ByteLevel()
- tokenizer.post_processor = processors.TemplateProcessing(
- single="[CLS]:0 $A:0 [SEP]:0",
- pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
- special_tokens=[
- ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
- ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
- ],
- )
-
- return tokenizer
-
-
-class SpmConverter(Converter):
- def __init__(self, *args):
- requires_backends(self, "protobuf")
-
- super().__init__(*args)
-
- # from .utils import sentencepiece_model_pb2 as model_pb2
- model_pb2 = import_protobuf()
-
- m = model_pb2.ModelProto()
- with open(self.original_tokenizer.vocab_file, "rb") as f:
- m.ParseFromString(f.read())
- self.proto = m
-
- if self.proto.trainer_spec.byte_fallback:
- if not getattr(self, "handle_byte_fallback", None):
- warnings.warn(
- "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
- " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
- " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
- "unknown tokens into a sequence of byte tokens matching the original piece of text."
- )
-
- def vocab(self, proto):
- return [(piece.piece, piece.score) for piece in proto.pieces]
-
- def unk_id(self, proto):
- return proto.trainer_spec.unk_id
-
- def tokenizer(self, proto):
- model_type = proto.trainer_spec.model_type
- vocab_scores = self.vocab(proto)
- unk_id = self.unk_id(proto)
-
- if model_type == 1:
- tokenizer = Tokenizer(Unigram(vocab_scores, unk_id))
- elif model_type == 2:
- _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
- bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
- tokenizer = Tokenizer(
- BPE(
- bpe_vocab,
- merges,
- unk_token=proto.trainer_spec.unk_piece,
- fuse_unk=True,
- )
- )
- else:
- raise Exception(
- "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
- )
-
- return tokenizer
-
- def normalizer(self, proto):
- precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
- if not precompiled_charsmap:
- return normalizers.Sequence([normalizers.Replace(Regex(" {2,}"), " ")])
- else:
- return normalizers.Sequence(
- [normalizers.Precompiled(precompiled_charsmap), normalizers.Replace(Regex(" {2,}"), " ")]
- )
-
- def pre_tokenizer(self, replacement, add_prefix_space):
- prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
- return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
-
- def post_processor(self):
- return None
-
- def decoder(self, replacement, add_prefix_space):
- prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
- return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
-
- def converted(self) -> Tokenizer:
- tokenizer = self.tokenizer(self.proto)
-
- # Tokenizer assemble
- normalizer = self.normalizer(self.proto)
- if normalizer is not None:
- tokenizer.normalizer = normalizer
-
- replacement = "▁"
- add_prefix_space = True
- pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
- if pre_tokenizer is not None:
- tokenizer.pre_tokenizer = pre_tokenizer
-
- tokenizer.decoder = self.decoder(replacement, add_prefix_space)
- post_processor = self.post_processor()
- if post_processor:
- tokenizer.post_processor = post_processor
-
- return tokenizer
-
-
-class AlbertConverter(SpmConverter):
- def vocab(self, proto):
- return [
- (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
- for piece in proto.pieces
- ]
-
- def normalizer(self, proto):
- list_normalizers = [
- normalizers.Replace("``", '"'),
- normalizers.Replace("''", '"'),
- ]
- if not self.original_tokenizer.keep_accents:
- list_normalizers.append(normalizers.NFKD())
- list_normalizers.append(normalizers.StripAccents())
- if self.original_tokenizer.do_lower_case:
- list_normalizers.append(normalizers.Lowercase())
-
- precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
-
- if precompiled_charsmap:
- list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
-
- list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
- return normalizers.Sequence(list_normalizers)
-
- def post_processor(self):
- return processors.TemplateProcessing(
- single="[CLS]:0 $A:0 [SEP]:0",
- pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
- special_tokens=[
- ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
- ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
- ],
- )
-
-
-class BarthezConverter(SpmConverter):
- def unk_id(self, proto):
- unk_id = 3
- return unk_id
-
- def post_processor(self):
- return processors.TemplateProcessing(
- single=" $A ",
- pair=" $A $B ",
- special_tokens=[
- ("", self.original_tokenizer.convert_tokens_to_ids("")),
- ("", self.original_tokenizer.convert_tokens_to_ids("")),
- ],
- )
-
-
-class CamembertConverter(SpmConverter):
- def vocab(self, proto):
- vocab = [
- ("NOTUSED", 0.0),
- ("", 0.0),
- ("NOTUSED", 0.0),
- ("", 0.0),
- ("NOTUSED", -100),
- ]
- # We down-grade the original SentencePiece by -100 to avoid using it and use our added token instead
- vocab += [(piece.piece, piece.score) for piece in proto.pieces[1:]]
- vocab += [("", 0.0)]
- return vocab
-
- def unk_id(self, proto):
- # See vocab unk position
- return 3
-
- def post_processor(self):
- return processors.TemplateProcessing(
- single=" $A ",
- pair=" $A $B ",
- special_tokens=[
- ("", self.original_tokenizer.convert_tokens_to_ids("")),
- ("", self.original_tokenizer.convert_tokens_to_ids("")),
- ],
- )
-
-
-class DebertaV2Converter(SpmConverter):
- def pre_tokenizer(self, replacement, add_prefix_space):
- list_pretokenizers = []
- if self.original_tokenizer.split_by_punct:
- list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
- prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
- list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme))
- return pre_tokenizers.Sequence(list_pretokenizers)
-
- def normalizer(self, proto):
- list_normalizers = []
- if self.original_tokenizer.do_lower_case:
- list_normalizers.append(normalizers.Lowercase())
- list_normalizers.append(normalizers.Strip())
-
- precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
- if precompiled_charsmap:
- list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
- list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
-
- return normalizers.Sequence(list_normalizers)
-
- def post_processor(self):
- return processors.TemplateProcessing(
- single="[CLS]:0 $A:0 [SEP]:0",
- pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
- special_tokens=[
- ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
- ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
- ],
- )
-
-
-class MBartConverter(SpmConverter):
- def vocab(self, proto):
- vocab = [
- ("", 0.0),
- ("", 0.0),
- ("", 0.0),
- ("", 0.0),
- ]
- vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
- vocab += [
- ("ar_AR", 0.0),
- ("cs_CZ", 0.0),
- ("de_DE", 0.0),
- ("en_XX", 0.0),
- ("es_XX", 0.0),
- ("et_EE", 0.0),
- ("fi_FI", 0.0),
- ("fr_XX", 0.0),
- ("gu_IN", 0.0),
- ("hi_IN", 0.0),
- ("it_IT", 0.0),
- ("ja_XX", 0.0),
- ("kk_KZ", 0.0),
- ("ko_KR", 0.0),
- ("lt_LT", 0.0),
- ("lv_LV", 0.0),
- ("my_MM", 0.0),
- ("ne_NP", 0.0),
- ("nl_XX", 0.0),
- ("ro_RO", 0.0),
- ("ru_RU", 0.0),
- ("si_LK", 0.0),
- ("tr_TR", 0.0),
- ("vi_VN", 0.0),
- ("zh_CN", 0.0),
- ]
- vocab += [("", 0.0)]
- return vocab
-
- def unk_id(self, proto):
- return 3
-
- def post_processor(self):
- return processors.TemplateProcessing(
- single="$A en_XX",
- pair="$A $B en_XX",
- special_tokens=[
- ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
- ("", self.original_tokenizer.convert_tokens_to_ids("")),
- ],
- )
-
-
-class MBart50Converter(SpmConverter):
- def vocab(self, proto):
- vocab = [
- ("", 0.0),
- ("", 0.0),
- ("", 0.0),
- ("", 0.0),
- ]
- vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
- # fmt: off
- vocab += [("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0), ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0), ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0), ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0), ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0), ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0), ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0), ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)]
- # fmt: on
- vocab += [("", 0.0)]
- return vocab
-
- def unk_id(self, proto):
- return 3
-
- def post_processor(self):
- return processors.TemplateProcessing(
- single="en_XX $A ",
- pair="en_XX $A $B ",
- special_tokens=[
- ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
- ("", self.original_tokenizer.convert_tokens_to_ids("")),
- ],
- )
-
-
-class NllbConverter(SpmConverter):
- def vocab(self, proto):
- vocab = [
- ("", 0.0),
- ("", 0.0),
- ("", 0.0),
- ("", 0.0),
- ]
- vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
- vocab += [
- # fmt: off
- ('ace_Arab', 0.0), ('ace_Latn', 0.0), ('acm_Arab', 0.0), ('acq_Arab', 0.0), ('aeb_Arab', 0.0), ('afr_Latn', 0.0), ('ajp_Arab', 0.0), ('aka_Latn', 0.0), ('amh_Ethi', 0.0), ('apc_Arab', 0.0), ('arb_Arab', 0.0), ('ars_Arab', 0.0), ('ary_Arab', 0.0), ('arz_Arab', 0.0), ('asm_Beng', 0.0), ('ast_Latn', 0.0), ('awa_Deva', 0.0), ('ayr_Latn', 0.0), ('azb_Arab', 0.0), ('azj_Latn', 0.0), ('bak_Cyrl', 0.0), ('bam_Latn', 0.0), ('ban_Latn', 0.0), ('bel_Cyrl', 0.0), ('bem_Latn', 0.0), ('ben_Beng', 0.0), ('bho_Deva', 0.0), ('bjn_Arab', 0.0), ('bjn_Latn', 0.0), ('bod_Tibt', 0.0), ('bos_Latn', 0.0), ('bug_Latn', 0.0), ('bul_Cyrl', 0.0), ('cat_Latn', 0.0), ('ceb_Latn', 0.0), ('ces_Latn', 0.0), ('cjk_Latn', 0.0), ('ckb_Arab', 0.0), ('crh_Latn', 0.0), ('cym_Latn', 0.0), ('dan_Latn', 0.0), ('deu_Latn', 0.0), ('dik_Latn', 0.0), ('dyu_Latn', 0.0), ('dzo_Tibt', 0.0), ('ell_Grek', 0.0), ('eng_Latn', 0.0), ('epo_Latn', 0.0), ('est_Latn', 0.0), ('eus_Latn', 0.0), ('ewe_Latn', 0.0), ('fao_Latn', 0.0), ('pes_Arab', 0.0), ('fij_Latn', 0.0), ('fin_Latn', 0.0), ('fon_Latn', 0.0), ('fra_Latn', 0.0), ('fur_Latn', 0.0), ('fuv_Latn', 0.0), ('gla_Latn', 0.0), ('gle_Latn', 0.0), ('glg_Latn', 0.0), ('grn_Latn', 0.0), ('guj_Gujr', 0.0), ('hat_Latn', 0.0), ('hau_Latn', 0.0), ('heb_Hebr', 0.0), ('hin_Deva', 0.0), ('hne_Deva', 0.0), ('hrv_Latn', 0.0), ('hun_Latn', 0.0), ('hye_Armn', 0.0), ('ibo_Latn', 0.0), ('ilo_Latn', 0.0), ('ind_Latn', 0.0), ('isl_Latn', 0.0), ('ita_Latn', 0.0), ('jav_Latn', 0.0), ('jpn_Jpan', 0.0), ('kab_Latn', 0.0), ('kac_Latn', 0.0), ('kam_Latn', 0.0), ('kan_Knda', 0.0), ('kas_Arab', 0.0), ('kas_Deva', 0.0), ('kat_Geor', 0.0), ('knc_Arab', 0.0), ('knc_Latn', 0.0), ('kaz_Cyrl', 0.0), ('kbp_Latn', 0.0), ('kea_Latn', 0.0), ('khm_Khmr', 0.0), ('kik_Latn', 0.0), ('kin_Latn', 0.0), ('kir_Cyrl', 0.0), ('kmb_Latn', 0.0), ('kon_Latn', 0.0), ('kor_Hang', 0.0), ('kmr_Latn', 0.0), ('lao_Laoo', 0.0), ('lvs_Latn', 0.0), ('lij_Latn', 0.0), ('lim_Latn', 0.0), ('lin_Latn', 0.0), ('lit_Latn', 0.0), ('lmo_Latn', 0.0), ('ltg_Latn', 0.0), ('ltz_Latn', 0.0), ('lua_Latn', 0.0), ('lug_Latn', 0.0), ('luo_Latn', 0.0), ('lus_Latn', 0.0), ('mag_Deva', 0.0), ('mai_Deva', 0.0), ('mal_Mlym', 0.0), ('mar_Deva', 0.0), ('min_Latn', 0.0), ('mkd_Cyrl', 0.0), ('plt_Latn', 0.0), ('mlt_Latn', 0.0), ('mni_Beng', 0.0), ('khk_Cyrl', 0.0), ('mos_Latn', 0.0), ('mri_Latn', 0.0), ('zsm_Latn', 0.0), ('mya_Mymr', 0.0), ('nld_Latn', 0.0), ('nno_Latn', 0.0), ('nob_Latn', 0.0), ('npi_Deva', 0.0), ('nso_Latn', 0.0), ('nus_Latn', 0.0), ('nya_Latn', 0.0), ('oci_Latn', 0.0), ('gaz_Latn', 0.0), ('ory_Orya', 0.0), ('pag_Latn', 0.0), ('pan_Guru', 0.0), ('pap_Latn', 0.0), ('pol_Latn', 0.0), ('por_Latn', 0.0), ('prs_Arab', 0.0), ('pbt_Arab', 0.0), ('quy_Latn', 0.0), ('ron_Latn', 0.0), ('run_Latn', 0.0), ('rus_Cyrl', 0.0), ('sag_Latn', 0.0), ('san_Deva', 0.0), ('sat_Beng', 0.0), ('scn_Latn', 0.0), ('shn_Mymr', 0.0), ('sin_Sinh', 0.0), ('slk_Latn', 0.0), ('slv_Latn', 0.0), ('smo_Latn', 0.0), ('sna_Latn', 0.0), ('snd_Arab', 0.0), ('som_Latn', 0.0), ('sot_Latn', 0.0), ('spa_Latn', 0.0), ('als_Latn', 0.0), ('srd_Latn', 0.0), ('srp_Cyrl', 0.0), ('ssw_Latn', 0.0), ('sun_Latn', 0.0), ('swe_Latn', 0.0), ('swh_Latn', 0.0), ('szl_Latn', 0.0), ('tam_Taml', 0.0), ('tat_Cyrl', 0.0), ('tel_Telu', 0.0), ('tgk_Cyrl', 0.0), ('tgl_Latn', 0.0), ('tha_Thai', 0.0), ('tir_Ethi', 0.0), ('taq_Latn', 0.0), ('taq_Tfng', 0.0), ('tpi_Latn', 0.0), ('tsn_Latn', 0.0), ('tso_Latn', 0.0), ('tuk_Latn', 0.0), ('tum_Latn', 0.0), ('tur_Latn', 0.0), ('twi_Latn', 0.0), ('tzm_Tfng', 0.0), ('uig_Arab', 0.0), ('ukr_Cyrl', 0.0), ('umb_Latn', 0.0), ('urd_Arab', 0.0), ('uzn_Latn', 0.0), ('vec_Latn', 0.0), ('vie_Latn', 0.0), ('war_Latn', 0.0), ('wol_Latn', 0.0), ('xho_Latn', 0.0), ('ydd_Hebr', 0.0), ('yor_Latn', 0.0), ('yue_Hant', 0.0), ('zho_Hans', 0.0), ('zho_Hant', 0.0), ('zul_Latn', 0.0)
- # fmt: on
- ]
- vocab += [("", 0.0)]
- return vocab
-
- def unk_id(self, proto):
- return 3
-
- def post_processor(self):
- return processors.TemplateProcessing(
- single="eng_Latn $A ",
- pair="eng_Latn $A $B ",
- special_tokens=[
- ("eng_Latn", self.original_tokenizer.convert_tokens_to_ids("eng_Latn")),
- ("", self.original_tokenizer.convert_tokens_to_ids("")),
- ],
- )
-
-
-class SeamlessM4TConverter(SpmConverter):
- def vocab(self, proto):
- vocab = [
- ("", 0.0),
- ("", 0.0),
- ("", 0.0),
- ("", 0.0),
- ]
- vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
- return vocab
-
- def unk_id(self, proto):
- return self.original_tokenizer.unk_token_id
-
- def post_processor(self):
- return processors.TemplateProcessing(
- single="__eng__ $A ",
- pair="__eng__ $A $B ",
- special_tokens=[
- ("__eng__", self.original_tokenizer.convert_tokens_to_ids("__eng__")),
- ("", self.original_tokenizer.convert_tokens_to_ids("")),
- ],
- )
-
-
-class XLMRobertaConverter(SpmConverter):
- def vocab(self, proto):
- vocab = [
- ("", 0.0),
- ("", 0.0),
- ("", 0.0),
- ("", 0.0),
- ]
- vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
- vocab += [("", 0.0)]
- return vocab
-
- def unk_id(self, proto):
- unk_id = 3
- return unk_id
-
- def post_processor(self):
- return processors.TemplateProcessing(
- single=" $A ",
- pair=" $A $B ",
- special_tokens=[
- ("", self.original_tokenizer.convert_tokens_to_ids("")),
- ("", self.original_tokenizer.convert_tokens_to_ids("")),
- ],
- )
-
-
-class XLNetConverter(SpmConverter):
- def vocab(self, proto):
- return [
- (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
- for piece in proto.pieces
- ]
-
- def normalizer(self, proto):
- list_normalizers = [
- normalizers.Replace("``", '"'),
- normalizers.Replace("''", '"'),
- ]
- if not self.original_tokenizer.keep_accents:
- list_normalizers.append(normalizers.NFKD())
- list_normalizers.append(normalizers.StripAccents())
- if self.original_tokenizer.do_lower_case:
- list_normalizers.append(normalizers.Lowercase())
-
- precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
-
- if precompiled_charsmap:
- list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
-
- list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
- return normalizers.Sequence(list_normalizers)
-
- def post_processor(self):
- return processors.TemplateProcessing(
- single="$A:0 :0 :2",
- pair="$A:0 :0 $B:1 :1 :2",
- special_tokens=[
- ("", self.original_tokenizer.convert_tokens_to_ids("")),
- ("", self.original_tokenizer.convert_tokens_to_ids("")),
- ],
- )
-
-
-class ReformerConverter(SpmConverter):
- pass
-
-
-class RemBertConverter(SpmConverter):
- # Inspired from AlbertConverter
- def normalizer(self, proto):
- list_normalizers = [
- normalizers.Replace("``", '"'),
- normalizers.Replace("''", '"'),
- normalizers.Replace(Regex(" {2,}"), " "),
- ]
- if not self.original_tokenizer.keep_accents:
- list_normalizers.append(normalizers.NFKD())
- list_normalizers.append(normalizers.StripAccents())
- if self.original_tokenizer.do_lower_case:
- list_normalizers.append(normalizers.Lowercase())
-
- precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
-
- if precompiled_charsmap:
- list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
-
- return normalizers.Sequence(list_normalizers)
-
- def post_processor(self):
- return processors.TemplateProcessing(
- single="[CLS]:0 $A:0 [SEP]:0",
- pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
- special_tokens=[
- ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
- ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
- ],
- )
-
-
-class BertGenerationConverter(SpmConverter):
- pass
-
-
-class PegasusConverter(SpmConverter):
- def vocab(self, proto):
- vocab = [
- (self.original_tokenizer.pad_token, 0.0),
- (self.original_tokenizer.eos_token, 0.0),
- ]
-
- if self.original_tokenizer.mask_token_sent is not None:
- vocab += [(self.original_tokenizer.mask_token_sent, 0.0)]
-
- if (
- self.original_tokenizer.mask_token is not None
- and self.original_tokenizer.mask_token_id < self.original_tokenizer.offset
- ):
- vocab += [(self.original_tokenizer.mask_token, 0.0)]
-
- vocab += [(f"", -100.0) for i in range(2, self.original_tokenizer.offset)]
- vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
- return vocab
-
- def unk_id(self, proto):
- return proto.trainer_spec.unk_id + self.original_tokenizer.offset
-
- def pre_tokenizer(self, replacement, add_prefix_space):
- prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
- return pre_tokenizers.Sequence(
- [
- pre_tokenizers.WhitespaceSplit(),
- pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme),
- ]
- )
-
- def post_processor(self):
- eos = self.original_tokenizer.eos_token
- special_tokens = [
- (eos, self.original_tokenizer.eos_token_id),
- ]
- return processors.TemplateProcessing(single=["$A", eos], pair=["$A", "$B", eos], special_tokens=special_tokens)
-
-
-class T5Converter(SpmConverter):
- def vocab(self, proto):
- num_extra_ids = self.original_tokenizer._extra_ids
- vocab = [(piece.piece, piece.score) for piece in proto.pieces]
- vocab += [(f"", 0.0) for i in range(num_extra_ids - 1, -1, -1)]
- return vocab
-
- def post_processor(self):
- return processors.TemplateProcessing(
- single=["$A", ""],
- pair=["$A", "", "$B", ""],
- special_tokens=[
- ("", self.original_tokenizer.convert_tokens_to_ids("")),
- ],
- )
-
-
-class WhisperConverter(Converter):
- def converted(self) -> Tokenizer:
- vocab = self.original_tokenizer.encoder
- merges = list(self.original_tokenizer.bpe_ranks.keys())
-
- tokenizer = Tokenizer(
- BPE(
- vocab=vocab,
- merges=merges,
- dropout=None,
- continuing_subword_prefix="",
- end_of_word_suffix="",
- fuse_unk=False,
- )
- )
-
- tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
- tokenizer.decoder = decoders.ByteLevel()
-
- prefix_token_ids = self.original_tokenizer.prefix_tokens
- prefixes = self.original_tokenizer.convert_ids_to_tokens(prefix_token_ids)
- eos = self.original_tokenizer.eos_token
- eos_token_id = self.original_tokenizer.eos_token_id
- prefix_template = " ".join([f"{token}:0" for token in prefixes])
- tokenizer.post_processor = processors.TemplateProcessing(
- single=f"{prefix_template} $A:0 {eos}:0",
- pair=f"{prefix_template} $A:0 $B:1 {eos}:1",
- special_tokens=[
- (eos, eos_token_id),
- *zip(prefixes, prefix_token_ids),
- ],
- )
-
- return tokenizer
-
-
-class BigBirdConverter(SpmConverter):
- def post_processor(self):
- return processors.TemplateProcessing(
- single="[CLS]:0 $A:0 [SEP]:0",
- pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
- special_tokens=[
- ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
- ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
- ],
- )
-
-
-class CLIPConverter(Converter):
- def converted(self) -> Tokenizer:
- vocab = self.original_tokenizer.encoder
- merges = list(self.original_tokenizer.bpe_ranks.keys())
- unk_token = self.original_tokenizer.unk_token
-
- tokenizer = Tokenizer(
- BPE(
- vocab=vocab,
- merges=merges,
- dropout=None,
- continuing_subword_prefix="",
- end_of_word_suffix="",
- fuse_unk=False,
- unk_token=str(unk_token),
- )
- )
-
- tokenizer.normalizer = normalizers.Sequence(
- [normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()]
- )
- tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
- [
- pre_tokenizers.Split(
- Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""),
- behavior="removed",
- invert=True,
- ),
- pre_tokenizers.ByteLevel(add_prefix_space=False),
- ]
- )
- tokenizer.decoder = decoders.ByteLevel()
-
- # Hack to have a ByteLevel and TemplaceProcessor
- tokenizer.post_processor = processors.RobertaProcessing(
- sep=(self.original_tokenizer.eos_token, self.original_tokenizer.eos_token_id),
- cls=(self.original_tokenizer.bos_token, self.original_tokenizer.bos_token_id),
- add_prefix_space=False,
- trim_offsets=False,
- )
- return tokenizer
-
-
-class LayoutLMv2Converter(Converter):
- def converted(self) -> Tokenizer:
- vocab = self.original_tokenizer.vocab
- tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
-
- tokenize_chinese_chars = False
- strip_accents = False
- do_lower_case = True
- if hasattr(self.original_tokenizer, "basic_tokenizer"):
- tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
- strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
- do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
-
- tokenizer.normalizer = normalizers.BertNormalizer(
- clean_text=True,
- handle_chinese_chars=tokenize_chinese_chars,
- strip_accents=strip_accents,
- lowercase=do_lower_case,
- )
- tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
-
- cls = str(self.original_tokenizer.cls_token)
- sep = str(self.original_tokenizer.sep_token)
- cls_token_id = self.original_tokenizer.cls_token_id
- sep_token_id = self.original_tokenizer.sep_token_id
-
- tokenizer.post_processor = processors.TemplateProcessing(
- single=f"{cls}:0 $A:0 {sep}:0",
- pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
- special_tokens=[
- (cls, cls_token_id),
- (sep, sep_token_id),
- ],
- )
- tokenizer.decoder = decoders.WordPiece(prefix="##")
-
- return tokenizer
-
-
-class BlenderbotConverter(Converter):
- def converted(self) -> Tokenizer:
- ot = self.original_tokenizer
- vocab = ot.encoder
- merges = list(ot.bpe_ranks.keys())
-
- tokenizer = Tokenizer(
- BPE(
- vocab=vocab,
- merges=merges,
- dropout=None,
- continuing_subword_prefix="",
- end_of_word_suffix="",
- fuse_unk=False,
- )
- )
-
- tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
- tokenizer.decoder = decoders.ByteLevel()
- tokenizer.post_processor = processors.TemplateProcessing(
- single=f"$A:0 {ot.eos_token}:0",
- special_tokens=[
- (ot.eos_token, ot.eos_token_id),
- ],
- )
-
- return tokenizer
-
-
-class XGLMConverter(SpmConverter):
- def vocab(self, proto):
- vocab = [
- ("", 0.0),
- ("", 0.0),
- ("", 0.0),
- ("", 0.0),
- ]
- vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
- # fmt: off
- vocab += [("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0)]
- # fmt: on
- return vocab
-
- def unk_id(self, proto):
- unk_id = 3
- return unk_id
-
- def post_processor(self):
- return processors.TemplateProcessing(
- single=" $A",
- pair=" $A $B",
- special_tokens=[
- ("", self.original_tokenizer.convert_tokens_to_ids("")),
- ("", self.original_tokenizer.convert_tokens_to_ids("")),
- ],
- )
-
-
-class LlamaConverter(SpmConverter):
- handle_byte_fallback = True
-
- def vocab(self, proto):
- vocab = [
- ("", 0.0),
- ("", 0.0),
- ("", 0.0),
- ]
- vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
- return vocab
-
- def unk_id(self, proto):
- unk_id = 0
- return unk_id
-
- def decoder(self, replacement, add_prefix_space):
- return decoders.Sequence(
- [
- decoders.Replace("▁", " "),
- decoders.ByteFallback(),
- decoders.Fuse(),
- decoders.Strip(content=" ", left=1),
- ]
- )
-
- def tokenizer(self, proto):
- model_type = proto.trainer_spec.model_type
- vocab_scores = self.vocab(proto)
- if model_type == 1:
- import tokenizers
-
- if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
- tokenizer = Tokenizer(Unigram(vocab_scores, 0))
- else:
- tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))
-
- elif model_type == 2:
- _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
- bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
- tokenizer = Tokenizer(
- BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
- )
- tokenizer.add_special_tokens(
- [
- AddedToken("", normalized=False, special=True),
- AddedToken("", normalized=False, special=True),
- AddedToken("", normalized=False, special=True),
- ]
- )
- else:
- raise Exception(
- "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
- )
-
- return tokenizer
-
- def normalizer(self, proto):
- return normalizers.Sequence(
- [
- normalizers.Prepend(prepend="▁"),
- normalizers.Replace(pattern=" ", content="▁"),
- ]
- )
-
- def pre_tokenizer(self, replacement, add_prefix_space):
- return None
-
- def post_processor(self):
- # the processor is defined in the LlamaTokenizerFast class.
- return None
-
-
-class MarkupLMConverter(Converter):
- def converted(self) -> Tokenizer:
- ot = self.original_tokenizer
- vocab = ot.encoder
- merges = list(ot.bpe_ranks.keys())
-
- tokenizer = Tokenizer(
- BPE(
- vocab=vocab,
- merges=merges,
- dropout=None,
- continuing_subword_prefix="",
- end_of_word_suffix="",
- fuse_unk=False,
- unk_token=self.original_tokenizer.unk_token,
- )
- )
-
- tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
- tokenizer.decoder = decoders.ByteLevel()
-
- cls = str(self.original_tokenizer.cls_token)
- sep = str(self.original_tokenizer.sep_token)
- cls_token_id = self.original_tokenizer.cls_token_id
- sep_token_id = self.original_tokenizer.sep_token_id
-
- tokenizer.post_processor = processors.TemplateProcessing(
- single=f"{cls} $A {sep}",
- pair=f"{cls} $A {sep} $B {sep}",
- special_tokens=[
- (cls, cls_token_id),
- (sep, sep_token_id),
- ],
- )
-
- return tokenizer
-
-class MarianConverter(SpmConverter):
- def __init__(self, *args, index: int = 0):
- requires_backends(self, "protobuf")
-
- super(SpmConverter, self).__init__(*args)
-
- # from .utils import sentencepiece_model_pb2 as model_pb2
- model_pb2 = import_protobuf()
-
- m = model_pb2.ModelProto()
- print(self.original_tokenizer.spm_files)
- with open(self.original_tokenizer.spm_files[index], "rb") as f:
- m.ParseFromString(f.read())
- self.proto = m
- print(self.original_tokenizer)
- #with open(self.original_tokenizer.vocab_path, "r") as f:
- dir_path = Path(self.original_tokenizer.spm_files[0]).parents[0]
- with open(dir_path / "vocab.json", "r") as f:
- import json
- self._vocab = json.load(f)
-
- if self.proto.trainer_spec.byte_fallback:
- if not getattr(self, "handle_byte_fallback", None):
- warnings.warn(
- "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
- " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
- " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
- "unknown tokens into a sequence of byte tokens matching the original piece of text."
- )
-
- def vocab(self, proto):
- vocab_size = max(self._vocab.values()) + 1
- vocab = [("", -100) for _ in range(vocab_size)]
- for piece in proto.pieces:
- try:
- index = self._vocab[piece.piece]
- except Exception:
- print(f"Ignored missing piece {piece.piece}")
- vocab[index] = (piece.piece, piece.score)
- return vocab
-
-SLOW_TO_FAST_CONVERTERS = {
- "AlbertTokenizer": AlbertConverter,
- "BartTokenizer": RobertaConverter,
- "BarthezTokenizer": BarthezConverter,
- "BertTokenizer": BertConverter,
- "BigBirdTokenizer": BigBirdConverter,
- "BlenderbotTokenizer": BlenderbotConverter,
- "CamembertTokenizer": CamembertConverter,
- "CLIPTokenizer": CLIPConverter,
- "CodeGenTokenizer": GPT2Converter,
- "ConvBertTokenizer": BertConverter,
- "DebertaTokenizer": DebertaConverter,
- "DebertaV2Tokenizer": DebertaV2Converter,
- "DistilBertTokenizer": BertConverter,
- "DPRReaderTokenizer": BertConverter,
- "DPRQuestionEncoderTokenizer": BertConverter,
- "DPRContextEncoderTokenizer": BertConverter,
- "ElectraTokenizer": BertConverter,
- "FNetTokenizer": AlbertConverter,
- "FunnelTokenizer": FunnelConverter,
- "GPT2Tokenizer": GPT2Converter,
- "HerbertTokenizer": HerbertConverter,
- "LayoutLMTokenizer": BertConverter,
- "LayoutLMv2Tokenizer": BertConverter,
- "LayoutLMv3Tokenizer": RobertaConverter,
- "LayoutXLMTokenizer": XLMRobertaConverter,
- "LongformerTokenizer": RobertaConverter,
- "LEDTokenizer": RobertaConverter,
- "LxmertTokenizer": BertConverter,
- "MarkupLMTokenizer": MarkupLMConverter,
- "MBartTokenizer": MBartConverter,
- "MBart50Tokenizer": MBart50Converter,
- "MPNetTokenizer": MPNetConverter,
- "MobileBertTokenizer": BertConverter,
- "MvpTokenizer": RobertaConverter,
- "NllbTokenizer": NllbConverter,
- "OpenAIGPTTokenizer": OpenAIGPTConverter,
- "PegasusTokenizer": PegasusConverter,
- "RealmTokenizer": BertConverter,
- "ReformerTokenizer": ReformerConverter,
- "RemBertTokenizer": RemBertConverter,
- "RetriBertTokenizer": BertConverter,
- "RobertaTokenizer": RobertaConverter,
- "RoFormerTokenizer": RoFormerConverter,
- "SeamlessM4TTokenizer": SeamlessM4TConverter,
- "SqueezeBertTokenizer": BertConverter,
- "T5Tokenizer": T5Converter,
- "WhisperTokenizer": WhisperConverter,
- "XLMRobertaTokenizer": XLMRobertaConverter,
- "XLNetTokenizer": XLNetConverter,
- "SplinterTokenizer": SplinterConverter,
- "XGLMTokenizer": XGLMConverter,
- "LlamaTokenizer": LlamaConverter,
- "CodeLlamaTokenizer": LlamaConverter,
-}
-
-
-def convert_slow_tokenizer(transformer_tokenizer) -> Tokenizer:
- """
- Utilities to convert a slow tokenizer instance in a fast tokenizer instance.
-
- Args:
- transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
- Instance of a slow tokenizer to convert in the backend tokenizer for
- [`~tokenization_utils_base.PreTrainedTokenizerFast`].
-
- Return:
- A instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
- [`~tokenization_utils_base.PreTrainedTokenizerFast`]
- """
-
- tokenizer_class_name = transformer_tokenizer.__class__.__name__
-
- if tokenizer_class_name not in SLOW_TO_FAST_CONVERTERS:
- raise ValueError(
- f"An instance of tokenizer class {tokenizer_class_name} cannot be converted in a Fast tokenizer instance."
- " No converter was found. Currently available slow->fast convertors:"
- f" {list(SLOW_TO_FAST_CONVERTERS.keys())}"
- )
-
- converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
-
- return converter_class(transformer_tokenizer).converted()
diff --git a/candle-examples/examples/marian-mt/main.rs b/candle-examples/examples/marian-mt/main.rs
index 89b3a9a3..76445bdb 100644
--- a/candle-examples/examples/marian-mt/main.rs
+++ b/candle-examples/examples/marian-mt/main.rs
@@ -20,6 +20,22 @@ enum Which {
Big,
}
+#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+enum LanguagePair {
+ #[value(name = "fr-en")]
+ FrEn,
+ #[value(name = "en-zh")]
+ EnZh,
+ #[value(name = "en-hi")]
+ EnHi,
+ #[value(name = "en-es")]
+ EnEs,
+ #[value(name = "en-fr")]
+ EnFr,
+ #[value(name = "en-ru")]
+ EnRu,
+}
+
// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
@@ -36,6 +52,10 @@ struct Args {
#[arg(long, default_value = "big")]
which: Which,
+ // Choose which language pair to use
+ #[arg(long, default_value = "fr-en")]
+ language_pair: LanguagePair,
+
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
@@ -53,21 +73,43 @@ pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();
- let config = match args.which {
- Which::Base => marian::Config::opus_mt_fr_en(),
- Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
+ let config = match (args.which, args.language_pair) {
+ (Which::Base, LanguagePair::FrEn) => marian::Config::opus_mt_fr_en(),
+ (Which::Big, LanguagePair::FrEn) => marian::Config::opus_mt_tc_big_fr_en(),
+ (Which::Base, LanguagePair::EnZh) => marian::Config::opus_mt_en_zh(),
+ (Which::Base, LanguagePair::EnHi) => marian::Config::opus_mt_en_hi(),
+ (Which::Base, LanguagePair::EnEs) => marian::Config::opus_mt_en_es(),
+ (Which::Base, LanguagePair::EnFr) => marian::Config::opus_mt_fr_en(),
+ (Which::Base, LanguagePair::EnRu) => marian::Config::opus_mt_en_ru(),
+ (Which::Big, lp) => anyhow::bail!("big is not supported for language pair {lp:?}"),
+ };
+ let tokenizer_default_repo = match args.language_pair {
+ LanguagePair::FrEn => "lmz/candle-marian",
+ LanguagePair::EnZh
+ | LanguagePair::EnHi
+ | LanguagePair::EnEs
+ | LanguagePair::EnFr
+ | LanguagePair::EnRu => "KeighBee/candle-marian",
};
let tokenizer = {
let tokenizer = match args.tokenizer {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
- let name = match args.which {
- Which::Base => "tokenizer-marian-base-fr.json",
- Which::Big => "tokenizer-marian-fr.json",
+ let filename = match (args.which, args.language_pair) {
+ (Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-fr.json",
+ (Which::Big, LanguagePair::FrEn) => "tokenizer-marian-fr.json",
+ (Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-en.json",
+ (Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-en.json",
+ (Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-en.json",
+ (Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-en.json",
+ (Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-en.json",
+ (Which::Big, lp) => {
+ anyhow::bail!("big is not supported for language pair {lp:?}")
+ }
};
Api::new()?
- .model("lmz/candle-marian".to_string())
- .get(name)?
+ .model(tokenizer_default_repo.to_string())
+ .get(filename)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
@@ -77,13 +119,21 @@ pub fn main() -> anyhow::Result<()> {
let tokenizer = match args.tokenizer_dec {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
- let name = match args.which {
- Which::Base => "tokenizer-marian-base-en.json",
- Which::Big => "tokenizer-marian-en.json",
+ let filename = match (args.which, args.language_pair) {
+ (Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-en.json",
+ (Which::Big, LanguagePair::FrEn) => "tokenizer-marian-en.json",
+ (Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-zh.json",
+ (Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-hi.json",
+ (Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-es.json",
+ (Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-fr.json",
+ (Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-ru.json",
+ (Which::Big, lp) => {
+ anyhow::bail!("big is not supported for language pair {lp:?}")
+ }
};
Api::new()?
- .model("lmz/candle-marian".to_string())
- .get(name)?
+ .model(tokenizer_default_repo.to_string())
+ .get(filename)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
@@ -94,18 +144,48 @@ pub fn main() -> anyhow::Result<()> {
let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
- None => match args.which {
- Which::Base => Api::new()?
- .repo(hf_hub::Repo::with_revision(
+ None => {
+ let api = Api::new()?;
+ let api = match (args.which, args.language_pair) {
+ (Which::Base, LanguagePair::FrEn) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-fr-en".to_string(),
hf_hub::RepoType::Model,
"refs/pr/4".to_string(),
- ))
- .get("model.safetensors")?,
- Which::Big => Api::new()?
- .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
- .get("model.safetensors")?,
- },
+ )),
+ (Which::Big, LanguagePair::FrEn) => {
+ api.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
+ }
+ (Which::Base, LanguagePair::EnZh) => api.repo(hf_hub::Repo::with_revision(
+ "Helsinki-NLP/opus-mt-en-zh".to_string(),
+ hf_hub::RepoType::Model,
+ "refs/pr/13".to_string(),
+ )),
+ (Which::Base, LanguagePair::EnHi) => api.repo(hf_hub::Repo::with_revision(
+ "Helsinki-NLP/opus-mt-en-hi".to_string(),
+ hf_hub::RepoType::Model,
+ "refs/pr/3".to_string(),
+ )),
+ (Which::Base, LanguagePair::EnEs) => api.repo(hf_hub::Repo::with_revision(
+ "Helsinki-NLP/opus-mt-en-es".to_string(),
+ hf_hub::RepoType::Model,
+ "refs/pr/4".to_string(),
+ )),
+ (Which::Base, LanguagePair::EnFr) => api.repo(hf_hub::Repo::with_revision(
+ "Helsinki-NLP/opus-mt-en-fr".to_string(),
+ hf_hub::RepoType::Model,
+ "refs/pr/9".to_string(),
+ )),
+ (Which::Base, LanguagePair::EnRu) => api.repo(hf_hub::Repo::with_revision(
+ "Helsinki-NLP/opus-mt-en-ru".to_string(),
+ hf_hub::RepoType::Model,
+ "refs/pr/7".to_string(),
+ )),
+ (Which::Big, lp) => {
+ anyhow::bail!("big is not supported for language pair {lp:?}")
+ }
+ };
+ api.get("model.safetensors")?
+ }
};
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
};
diff --git a/candle-examples/examples/marian-mt/python/convert_slow_tokenizer.py b/candle-examples/examples/marian-mt/python/convert_slow_tokenizer.py
new file mode 100644
index 00000000..7d2f3efb
--- /dev/null
+++ b/candle-examples/examples/marian-mt/python/convert_slow_tokenizer.py
@@ -0,0 +1,53 @@
+from pathlib import Path
+import warnings
+
+from transformers import AutoTokenizer
+from transformers.convert_slow_tokenizer import SpmConverter, requires_backends, import_protobuf
+
+class MarianConverter(SpmConverter):
+ def __init__(self, *args, index: int = 0):
+ requires_backends(self, "protobuf")
+
+ super(SpmConverter, self).__init__(*args)
+
+ # from .utils import sentencepiece_model_pb2 as model_pb2
+ model_pb2 = import_protobuf()
+
+ m = model_pb2.ModelProto()
+ print(self.original_tokenizer.spm_files)
+ with open(self.original_tokenizer.spm_files[index], "rb") as f:
+ m.ParseFromString(f.read())
+ self.proto = m
+ print(self.original_tokenizer)
+ #with open(self.original_tokenizer.vocab_path, "r") as f:
+ dir_path = Path(self.original_tokenizer.spm_files[0]).parents[0]
+ with open(dir_path / "vocab.json", "r") as f:
+ import json
+ self._vocab = json.load(f)
+
+ if self.proto.trainer_spec.byte_fallback:
+ if not getattr(self, "handle_byte_fallback", None):
+ warnings.warn(
+ "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
+ " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
+ " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
+ "unknown tokens into a sequence of byte tokens matching the original piece of text."
+ )
+
+ def vocab(self, proto):
+ vocab_size = max(self._vocab.values()) + 1
+ vocab = [("", -100) for _ in range(vocab_size)]
+ for piece in proto.pieces:
+ try:
+ index = self._vocab[piece.piece]
+ except Exception:
+ print(f"Ignored missing piece {piece.piece}")
+ vocab[index] = (piece.piece, piece.score)
+ return vocab
+
+
+tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
+fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
+fast_tokenizer.save("tokenizer-marian-base-fr.json")
+fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
+fast_tokenizer.save("tokenizer-marian-base-en.json")
\ No newline at end of file
diff --git a/candle-examples/examples/marian-mt/python/requirements.txt b/candle-examples/examples/marian-mt/python/requirements.txt
new file mode 100644
index 00000000..2eabc6d2
--- /dev/null
+++ b/candle-examples/examples/marian-mt/python/requirements.txt
@@ -0,0 +1,22 @@
+certifi==2025.1.31
+charset-normalizer==3.4.1
+click==8.1.8
+filelock==3.18.0
+fsspec==2025.3.2
+huggingface-hub==0.30.1
+idna==3.10
+joblib==1.4.2
+numpy==2.2.4
+packaging==24.2
+protobuf==6.30.2
+pyyaml==6.0.2
+regex==2024.11.6
+requests==2.32.3
+sacremoses==0.1.1
+safetensors==0.5.3
+sentencepiece==0.2.0
+tokenizers==0.21.1
+tqdm==4.67.1
+transformers==4.50.3
+typing-extensions==4.13.0
+urllib3==2.3.0
\ No newline at end of file
diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs
index c4ba0a15..313b48ed 100644
--- a/candle-transformers/src/models/marian.rs
+++ b/candle-transformers/src/models/marian.rs
@@ -81,6 +81,126 @@ impl Config {
vocab_size: 59514,
}
}
+
+ pub fn opus_mt_en_zh() -> Self {
+ Self {
+ activation_function: candle_nn::Activation::Swish,
+ d_model: 512,
+ decoder_attention_heads: 8,
+ decoder_ffn_dim: 2048,
+ decoder_layers: 6,
+ decoder_start_token_id: 65000,
+ decoder_vocab_size: Some(65001),
+ encoder_attention_heads: 8,
+ encoder_ffn_dim: 2048,
+ encoder_layers: 6,
+ eos_token_id: 0,
+ forced_eos_token_id: 0,
+ is_encoder_decoder: true,
+ max_position_embeddings: 512,
+ pad_token_id: 65000,
+ scale_embedding: true,
+ share_encoder_decoder_embeddings: true,
+ use_cache: true,
+ vocab_size: 65001,
+ }
+ }
+
+ pub fn opus_mt_en_hi() -> Self {
+ Self {
+ activation_function: candle_nn::Activation::Swish,
+ d_model: 512,
+ decoder_attention_heads: 8,
+ decoder_ffn_dim: 2048,
+ decoder_layers: 6,
+ decoder_start_token_id: 61949,
+ decoder_vocab_size: Some(61950),
+ encoder_attention_heads: 8,
+ encoder_ffn_dim: 2048,
+ encoder_layers: 6,
+ eos_token_id: 0,
+ forced_eos_token_id: 0,
+ is_encoder_decoder: true,
+ max_position_embeddings: 512,
+ pad_token_id: 61949,
+ scale_embedding: true,
+ share_encoder_decoder_embeddings: true,
+ use_cache: true,
+ vocab_size: 61950,
+ }
+ }
+
+ pub fn opus_mt_en_es() -> Self {
+ Self {
+ activation_function: candle_nn::Activation::Swish,
+ d_model: 512,
+ decoder_attention_heads: 8,
+ decoder_ffn_dim: 2048,
+ decoder_layers: 6,
+ decoder_start_token_id: 65000,
+ decoder_vocab_size: Some(65001),
+ encoder_attention_heads: 8,
+ encoder_ffn_dim: 2048,
+ encoder_layers: 6,
+ eos_token_id: 0,
+ forced_eos_token_id: 0,
+ is_encoder_decoder: true,
+ max_position_embeddings: 512,
+ pad_token_id: 65000,
+ scale_embedding: true,
+ share_encoder_decoder_embeddings: true,
+ use_cache: true,
+ vocab_size: 65001,
+ }
+ }
+
+ pub fn opus_mt_en_fr() -> Self {
+ Self {
+ activation_function: candle_nn::Activation::Swish,
+ d_model: 512,
+ decoder_attention_heads: 8,
+ decoder_ffn_dim: 2048,
+ decoder_layers: 6,
+ decoder_start_token_id: 59513,
+ decoder_vocab_size: Some(59514),
+ encoder_attention_heads: 8,
+ encoder_ffn_dim: 2048,
+ encoder_layers: 6,
+ eos_token_id: 0,
+ forced_eos_token_id: 0,
+ is_encoder_decoder: true,
+ max_position_embeddings: 512,
+ pad_token_id: 59513,
+ scale_embedding: true,
+ share_encoder_decoder_embeddings: true,
+ use_cache: true,
+ vocab_size: 59514,
+ }
+ }
+
+ pub fn opus_mt_en_ru() -> Self {
+ Self {
+ activation_function: candle_nn::Activation::Swish,
+ d_model: 512,
+ decoder_attention_heads: 8,
+ decoder_ffn_dim: 2048,
+ decoder_layers: 6,
+ decoder_start_token_id: 62517,
+ decoder_vocab_size: Some(62518),
+ encoder_attention_heads: 8,
+ encoder_ffn_dim: 2048,
+ encoder_layers: 6,
+ eos_token_id: 0,
+ forced_eos_token_id: 0,
+ is_encoder_decoder: true,
+ max_position_embeddings: 512,
+ pad_token_id: 62517,
+ scale_embedding: true,
+ share_encoder_decoder_embeddings: true,
+ use_cache: true,
+ vocab_size: 62518,
+ }
+ }
}
#[derive(Debug, Clone)]