You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
209 lines
8.3 KiB
209 lines
8.3 KiB
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
from tokenizers import Tokenizer, decoders, pre_tokenizers, processors
|
|
from tokenizers.models import Unigram
|
|
|
|
from ...tokenization_python import AddedToken
|
|
from ...tokenization_utils_tokenizers import TokenizersBackend
|
|
from ...utils import logging
|
|
|
|
|
|
# Module-level logger from the library's logging utilities, named after this module.
# NOTE(review): not referenced anywhere in the visible code of this file.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
# Mapping from vocabulary-file argument names to their default filenames
# (exposed as `MBartTokenizer.vocab_files_names`).
VOCAB_FILES_NAMES = dict(
    vocab_file="sentencepiece.bpe.model",
    tokenizer_file="tokenizer.json",
)


# The 25 fairseq mBART language codes. Order is significant: MBartTokenizer appends them in this
# order when it builds its default vocabulary, so the order fixes the ids they receive there.
# fmt: off
FAIRSEQ_LANGUAGE_CODES = [
    "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX",
    "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN",
    "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT",
    "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO",
    "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN",
]
# fmt: on
|
|
|
|
|
|
class MBartTokenizer(TokenizersBackend):
    """
    Construct an MBART tokenizer (backed by HuggingFace's *tokenizers* library). Based on
    [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models).

    This tokenizer inherits from [`TokenizersBackend`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Both source (input) and target (label) sequences are encoded as `<tokens> <eos> <language code>`; only the
    language code differs: `src_lang` is used for inputs and `tgt_lang` (falling back to `src_lang`) for targets.

    Examples:

    ```python
    >>> from transformers import MBartTokenizer

    >>> tokenizer = MBartTokenizer.from_pretrained(
    ...     "facebook/mbart-large-en-ro", src_lang="en_XX", tgt_lang="ro_RO"
    ... )
    >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
    >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
    >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_romanian, return_tensors="pt")
    ```"""

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    model = Unigram

    # Token ids placed before / after each encoded sequence; reassigned (never mutated in place)
    # every time the language setting changes.
    prefix_tokens: list[int] = []
    suffix_tokens: list[int] = []

    def __init__(
        self,
        vocab: str | dict | list | None = None,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        src_lang=None,
        tgt_lang=None,
        additional_special_tokens=None,
        **kwargs,
    ):
        """
        Args:
            vocab: Vocabulary handed to `tokenizers.models.Unigram`. When `None`, a minimal vocabulary is built:
                `bos`, `pad`, `eos`, `unk` (ids 0-3), the `"▁"` piece, the fairseq language codes, then the mask.
            bos_token, eos_token, sep_token, cls_token, unk_token, pad_token, mask_token: Special tokens forwarded
                to [`TokenizersBackend`]. A string `mask_token` is wrapped in an `AddedToken` with `lstrip=True`.
            src_lang: Source language code; defaults to `"en_XX"`.
            tgt_lang: Target language code; when `None`, target encoding falls back to the current `src_lang`.
            additional_special_tokens: Extra special tokens registered after the fairseq language codes. Each
                token is added at most once.
            **kwargs: Forwarded unchanged to [`TokenizersBackend`].
        """
        # A plain-string mask token absorbs the whitespace on its left (lstrip=True).
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        # Language codes are always special tokens; caller-supplied extras follow, each added once, in order.
        _additional_special_tokens = FAIRSEQ_LANGUAGE_CODES.copy()
        if additional_special_tokens is not None:
            for token in additional_special_tokens:
                if token not in _additional_special_tokens:
                    _additional_special_tokens.append(token)

        if vocab is None:
            # The position of <unk> here (index 3) must agree with unk_id=3 passed to Unigram below.
            vocab = [
                (str(bos_token), 0.0),
                (str(pad_token), 0.0),
                (str(eos_token), 0.0),
                (str(unk_token), 0.0),
                ("▁", -2.0),
            ]
            vocab.extend((lang_code, 0.0) for lang_code in FAIRSEQ_LANGUAGE_CODES)
            vocab.append((str(mask_token), 0.0))

        self._vocab = vocab
        # NOTE(review): unk_id=3 is hard-coded, so a caller-supplied vocab is assumed to put <unk> at index 3 too.
        self._tokenizer = Tokenizer(Unigram(self._vocab, unk_id=3, byte_fallback=False))

        self._tokenizer.normalizer = None

        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.WhitespaceSplit(),
                pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True),
            ]
        )

        self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
            additional_special_tokens=_additional_special_tokens,
            **kwargs,
        )

        self.lang_code_to_id = {
            lang_code: self.convert_tokens_to_ids(lang_code) for lang_code in FAIRSEQ_LANGUAGE_CODES
        }
        self.fairseq_offset = 1

        # Build fairseq token mappings for backward compatibility
        self.fairseq_tokens_to_ids = {
            "<s>": 0,
            "<pad>": 1,
            "</s>": 2,
            "<unk>": 3,
        }
        self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
        self.fairseq_tokens_to_ids["<mask>"] = self.convert_tokens_to_ids(str(mask_token))
        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}

        self._src_lang = src_lang if src_lang is not None else "en_XX"
        self.tgt_lang = tgt_lang
        # Sets self.cur_lang_code and installs the source-side post-processor.
        self.set_src_lang_special_tokens(self._src_lang)

    @property
    def src_lang(self) -> str:
        """Current source language code."""
        return self._src_lang

    @src_lang.setter
    def src_lang(self, new_src_lang: str) -> None:
        """Change the source language and reinstall the matching source-side post-processor."""
        self._src_lang = new_src_lang
        self.set_src_lang_special_tokens(self._src_lang)

    def _build_translation_inputs(
        self, raw_inputs, return_tensors: str | None, src_lang: str | None, tgt_lang: str | None, **extra_kwargs
    ):
        """Used by translation pipeline, to prepare inputs for the generate function.

        Sets `self.src_lang` (which reinstalls the source post-processor), encodes `raw_inputs` with special
        tokens, and stores the id of `tgt_lang` under the `"forced_bos_token_id"` key of the result.

        Raises:
            ValueError: If `src_lang` or `tgt_lang` is `None`.
        """
        if src_lang is None or tgt_lang is None:
            raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
        self.src_lang = src_lang
        inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
        tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
        inputs["forced_bos_token_id"] = tgt_lang_id
        return inputs

    def _switch_to_input_mode(self):
        """Install the source-side post-processor for the current `src_lang`."""
        return self.set_src_lang_special_tokens(self.src_lang)

    def _switch_to_target_mode(self):
        """Install the target-side post-processor, using `src_lang` when no `tgt_lang` is configured."""
        # Resolve the fallback locally rather than writing it into self.tgt_lang: latching it would pin
        # targets to whatever src_lang was at the first call, ignoring later src_lang changes.
        tgt_lang = self.tgt_lang if self.tgt_lang is not None else self._src_lang
        return self.set_tgt_lang_special_tokens(tgt_lang)

    def set_src_lang_special_tokens(self, src_lang) -> None:
        """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
        self._set_lang_special_tokens(src_lang)

    def set_tgt_lang_special_tokens(self, lang: str) -> None:
        """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
        self._set_lang_special_tokens(lang)

    def _set_lang_special_tokens(self, lang: str) -> None:
        """Point `cur_lang_code` at `lang` and rebuild the post-processor as `<tokens> <eos> <lang>`.

        Shared by the source and target setters, which use the identical layout.
        """
        self.cur_lang_code = self.convert_tokens_to_ids(lang)
        self.prefix_tokens = []
        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]

        prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
        suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)

        # Pairs are concatenated directly and share a single suffix at the very end.
        self._tokenizer.post_processor = processors.TemplateProcessing(
            single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
            pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
            special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
        )
|
|
|
|
|
|
# Public API of this module: `from ... import *` exports only the tokenizer class.
__all__ = ["MBartTokenizer"]
|