You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
164 lines
7.0 KiB
164 lines
7.0 KiB
|
4 days ago
|
# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
|
||
|
|
#
|
||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
# you may not use this file except in compliance with the License.
|
||
|
|
# You may obtain a copy of the License at
|
||
|
|
#
|
||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
#
|
||
|
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
# See the License for the specific language governing permissions and
|
||
|
|
# limitations under the License.
|
||
|
|
"""
|
||
|
|
Processor class for Chameleon.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
from ...feature_extraction_utils import BatchFeature
|
||
|
|
from ...image_utils import ImageInput
|
||
|
|
from ...processing_utils import (
|
||
|
|
MultiModalData,
|
||
|
|
ProcessingKwargs,
|
||
|
|
ProcessorMixin,
|
||
|
|
TextKwargs,
|
||
|
|
Unpack,
|
||
|
|
)
|
||
|
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||
|
|
from ...utils import auto_docstring
|
||
|
|
|
||
|
|
|
||
|
|
class ChameleonTextKwargs(TextKwargs, total=False):
    """
    return_for_text_completion (`bool`, *optional*, defaults to `False`):
        Whether the processed text is intended for text completion tasks. When `True`, the processor does not
        append the separator token (`sep_token`) to the end of the prompt, which is typically used for chat
        mode. When `False`, the separator token is appended for proper chat formatting.
    """

    # Chameleon-specific text option. `ChameleonProcessor.__call__` pops it out of the text kwargs before the
    # remaining text kwargs are forwarded to the tokenizer, so the tokenizer never receives it.
    return_for_text_completion: bool
|
||
|
|
|
||
|
|
|
||
|
|
class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
    # Kwargs schema passed to `self._merge_kwargs(...)` in `ChameleonProcessor.__call__`; the text group is
    # narrowed to the Chameleon-specific TypedDict so `return_for_text_completion` is a recognised key.
    text_kwargs: ChameleonTextKwargs
    # Processor-level defaults. NOTE(review): presumably these are overridden by caller-supplied kwargs and
    # tokenizer init kwargs inside `_merge_kwargs` -- confirm against `ProcessorMixin._merge_kwargs`.
    _defaults = {
        "text_kwargs": {
            # Do not pad by default; tokenized sequences in a batch may therefore differ in length.
            "padding": False,
            # Default to chat mode: `__call__` appends `sep_token` to every prompt unless this is True.
            "return_for_text_completion": False,
            # `__call__` only builds `mm_token_type_ids` when this is explicitly enabled.
            "return_mm_token_type_ids": False,
        },
        "common_kwargs": {
            # `__call__` reads `return_tensors` from the text kwargs; presumably `_merge_kwargs` propagates
            # common kwargs into each modality group -- confirm.
            "return_tensors": "pt",
        },
    }
|
||
|
|
|
||
|
|
|
||
|
|
@auto_docstring
class ChameleonProcessor(ProcessorMixin):
    def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
        r"""
        image_seq_length (`int`, *optional*, defaults to 1024):
            Sequence length of one image embedding.
        image_token (`str`, *optional*, defaults to `"<image>"`):
            The special token used to indicate image in the text.
        """
        self.image_seq_length = image_seq_length
        # Prefer the tokenizer's own special tokens when it defines them; otherwise fall back to the
        # fixed Chameleon placeholder / begin-of-image / end-of-image tokens.
        self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
        self.image_start_token = tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
        self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"

        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        self.image_start_token_id = tokenizer.convert_tokens_to_ids(self.image_start_token)
        self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token)
        # Every token id that belongs to an expanded image span; used to build `mm_token_type_ids`.
        self.image_ids = [self.image_token_id, self.image_start_token_id, self.image_end_token_id]

        super().__init__(image_processor, tokenizer)

    @auto_docstring
    def __call__(
        self,
        images: ImageInput | None = None,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
        **kwargs: Unpack[ChameleonProcessorKwargs],
    ) -> BatchFeature:
        r"""
        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **mm_token_type_ids** -- Per-token mask, 1 for image-span tokens (placeholder, begin- and end-of-image)
              and 0 otherwise. Returned when `text` is not `None` and `return_mm_token_type_ids=True`.
        """
        # Validate before touching `text`, so a missing input yields the intended ValueError instead of
        # an unrelated error from indexing `None`.
        if text is None and images is None:
            raise ValueError("You must provide either text or images")
        if isinstance(text, str):
            text = [text]
        elif text is not None and (not isinstance(text, list) or not all(isinstance(sample, str) for sample in text)):
            raise TypeError("Invalid input text. Please provide a string, or a list of strings")

        output_kwargs = self._merge_kwargs(
            ChameleonProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        text_kwargs = output_kwargs["text_kwargs"]
        # These options are consumed here and must not be forwarded to the tokenizer.
        return_for_text_completion = text_kwargs.pop("return_for_text_completion", False)
        return_mm_token_type_ids = text_kwargs.pop("return_mm_token_type_ids", False)
        return_tensors = text_kwargs.pop("return_tensors", None)

        image_inputs = {}
        if images is not None:
            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])

        text_inputs = {}
        if text is not None:
            prompt_strings = self._expand_image_tokens(text, return_for_text_completion)
            # Tokenize to plain Python lists; tensor conversion happens once, in BatchFeature below.
            text_inputs = self.tokenizer(prompt_strings, **text_kwargs, return_tensors=None)
            self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
            if return_mm_token_type_ids:
                text_inputs["mm_token_type_ids"] = self._build_mm_token_type_ids(text_inputs["input_ids"])

        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

    def _expand_image_tokens(self, text: list[str], return_for_text_completion: bool) -> list[str]:
        """Replace each image placeholder with BOI + `image_seq_length` placeholders + EOI in every prompt.

        Unless `return_for_text_completion` is True, `sep_token` is appended to each prompt (chat mode).
        """
        one_img_tokens = self.image_start_token + (self.image_token * self.image_seq_length) + self.image_end_token
        prompt_strings = []
        for sample in text:
            sample = sample.replace(self.image_token, one_img_tokens)
            if not return_for_text_completion:
                sample += self.tokenizer.sep_token  # special Chameleon treatment to add sep for chat mode
            prompt_strings.append(sample)
        return prompt_strings

    def _build_mm_token_type_ids(self, input_ids: list[list[int]]) -> list[list[int]]:
        """Return a 0/1 mask per sequence marking tokens whose id is in `self.image_ids`.

        Works row by row so unpadded batches (sequences of different lengths) are supported; converting the
        whole ragged batch with `np.array` would raise.
        """
        image_ids = np.asarray(self.image_ids)
        return [np.isin(np.asarray(ids, dtype=np.int64), image_ids).astype(np.int64).tolist() for ids in input_ids]

    def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
        """
        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

        Args:
            image_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (height, width) per each image.

        Returns:
            `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
            input modalities, along with other useful data.
        """
        vision_data = {}
        if image_sizes is not None:
            # Every image expands to a fixed-length span regardless of its size: `image_seq_length`
            # placeholders plus 2 for the BOI and EOI tokens.
            num_image_tokens = [self.image_seq_length + 2] * len(image_sizes)
            num_image_patches = [1] * len(image_sizes)

            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})

        return MultiModalData(**vision_data)
|
||
|
|
|
||
|
|
|
||
|
|
# Public API of this module: only the processor class is exported via `from ... import *`;
# the kwargs TypedDicts are not listed here.
__all__ = ["ChameleonProcessor"]
|