You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

312 lines
14 KiB

4 days ago
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processor class for Mllama."""
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, make_nested_list_of_images
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import auto_docstring
class MllamaProcessorKwargs(ProcessingKwargs, total=False):
    # Keyword-argument schema handed to `_merge_kwargs` by `MllamaProcessor.__call__`,
    # together with the model-specific default values declared below.
    #
    # NOTE(review): the defaults are registered under the key "image_kwargs", whereas
    # `MllamaProcessor.__call__` reads `output_kwargs["images_kwargs"]`. Confirm whether
    # this spelling is intentional before relying on `max_image_tiles=4` being applied.
    _defaults = {"image_kwargs": {"max_image_tiles": 4}}
def get_cross_attention_token_mask(input_ids: list[int], image_token_id: int) -> list[list[int]]:
    """
    Compute, for every image token in a sequence, the token span it is paired with.

    Each occurrence of `image_token_id` yields one `[start, end]` pair, where `start` is the
    position of the image token and `end` is the (exclusive) position at which its span stops.

    Args:
        input_ids (list[int]): Token ids of a single sequence.
        image_token_id (int): Id of the token that stands for an image.

    Returns:
        list[list[int]]: One `[start, end]` pair per image token, in order of appearance.

    Notes:
        - No image token: the result is an empty list.
        - Exactly one image token: the pair is `[position, -1]`; `-1` means "until the end of
          the sequence".
        - Several image tokens: each span stops at the next image token, and the last one stops
          at `len(input_ids)`.
        - Directly adjacent image tokens share the end of the image token that follows them, so
          a run of consecutive image tokens covers the same trailing text.
    """
    image_positions = [position for position, token in enumerate(input_ids) if token == image_token_id]

    if not image_positions:
        return []

    # A lone image is marked with the -1 sentinel instead of an explicit end index.
    if len(image_positions) == 1:
        return [[image_positions[0], -1]]

    sequence_end = len(input_ids)

    # Every image span stops where the next image starts; the final one runs to the sequence end.
    span_stops = image_positions[1:] + [sequence_end]
    spans = [[begin, stop] for begin, stop in zip(image_positions, span_stops)]

    # Walk backwards so that an image immediately followed by another image inherits the
    # (possibly already extended) end of its successor.
    inherited_stop = sequence_end
    for span in reversed(spans):
        if span[1] - span[0] == 1:
            span[1] = inherited_stop
        inherited_stop = span[1]

    return spans
def convert_sparse_cross_attention_mask_to_dense(
    cross_attention_token_mask: list[list[list[int]]],
    num_tiles: list[list[int]],
    max_num_tiles: int,
    length: int,
) -> np.ndarray:
    """
    Expand sparse `[start, end]` cross-attention ranges into a dense 4D mask.

    Args:
        cross_attention_token_mask (list[list[list[int]]]): Nested list where
            - the outer list is the batch dimension,
            - the middle list holds one entry per image of that batch item,
            - each inner list is a `[start, end]` pair of token positions for that image.
        num_tiles (list[list[int]]): Number of tiles of each image in each batch item, aligned
            with `cross_attention_token_mask`.
        max_num_tiles (int): Size of the tile dimension of the output.
        length (int): Size of the sequence dimension of the output.

    Returns:
        np.ndarray: `int64` array of shape `(batch_size, length, max_num_images, max_num_tiles)`
        holding `1` where attention is allowed and `0` elsewhere. An empty batch yields an
        array with zero-sized batch and image dimensions.

    Note:
        - An `end` of `-1` means "until the end of the sequence" and is replaced by `length`.
        - Any other `end` is clamped to `length`.
        - Entries that are not a two-element pair are ignored.
    """
    batch_size = len(cross_attention_token_mask)
    # `default=0` keeps an empty batch from raising inside `max()`.
    max_num_images = max((len(masks) for masks in cross_attention_token_mask), default=0)

    cross_attention_mask = np.zeros(
        shape=(batch_size, length, max_num_images, max_num_tiles),
        dtype=np.int64,
    )

    for sample_idx, (sample_masks, sample_num_tiles) in enumerate(zip(cross_attention_token_mask, num_tiles)):
        for mask_idx, (locations, mask_num_tiles) in enumerate(zip(sample_masks, sample_num_tiles)):
            if len(locations) != 2:
                continue
            start, end = locations
            # Resolve the -1 sentinel first, then clamp explicit ends to the sequence length.
            end = length if end == -1 else min(end, length)
            cross_attention_mask[sample_idx, start:end, mask_idx, :mask_num_tiles] = 1

    return cross_attention_mask
def build_string_from_input(prompt: str, bos_token: str, image_token: str) -> str:
    """
    Builds a string from the input prompt by adding `bos_token` if not already present.

    When the prompt starts with one or more image tokens, `bos_token` is inserted right after
    that leading run of image tokens rather than at the very beginning.

    Args:
        prompt (`str`):
            The input prompt string.
        bos_token (`str`):
            The beginning of sentence token to be added.
        image_token (`str`):
            The image token used to identify the start of an image sequence. An empty string is
            treated as "no image token" so the prefix scan always terminates.

    Returns:
        str: The modified prompt string with the `bos_token` added if necessary.

    Examples:
        >>> build_string_from_input("Hello world", "<begin_of_text>", "<|image|>")
        '<begin_of_text>Hello world'

        >>> build_string_from_input("<|image|>Hello world", "<begin_of_text>", "<|image|>")
        '<|image|><begin_of_text>Hello world'

        >>> build_string_from_input("<begin_of_text>Hello world", "<begin_of_text>", "<|image|>")
        '<begin_of_text>Hello world'
    """
    if bos_token in prompt:
        return prompt

    num_image_tokens_on_start = 0
    # An empty token is a prefix of every string, so scanning for it would never stop.
    if image_token:
        token_length = len(image_token)
        offset = 0
        # Scan with an offset instead of re-slicing the prompt on every match.
        while prompt.startswith(image_token, offset):
            offset += token_length
            num_image_tokens_on_start += 1
        prompt = prompt[offset:]

    return f"{image_token * num_image_tokens_on_start}{bos_token}{prompt}"
@auto_docstring
class MllamaProcessor(ProcessorMixin):
    def __init__(self, image_processor, tokenizer, chat_template=None):
        # Use the tokenizer's own image token when it exposes one; otherwise fall back to the
        # literal "<|image|>" and look its id up in the vocabulary.
        if hasattr(tokenizer, "image_token"):
            self.image_token = tokenizer.image_token
            self.image_token_id = tokenizer.image_token_id
        else:
            self.image_token = "<|image|>"
            self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)

        self.python_token = "<|python_tag|>"
        self.python_token_id = tokenizer.convert_tokens_to_ids(self.python_token)
        self.bos_token = tokenizer.bos_token
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    @auto_docstring
    def __call__(
        self,
        images: ImageInput | None = None,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
        **kwargs: Unpack[MllamaProcessorKwargs],
    ) -> BatchFeature:
        r"""
        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            TODO: add aspect_ratio_ids and aspect_ratio_mask and cross_attention_mask
        """
        if text is None and images is None:
            raise ValueError("You must specify either text or images.")

        output_kwargs = self._merge_kwargs(
            MllamaProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        text_kwargs = output_kwargs["text_kwargs"]
        # Tensor conversion is deferred to the final `BatchFeature`, so the tokenizer is called
        # without `return_tensors`.
        return_tensors = text_kwargs.pop("return_tensors", None)

        data = {}

        # ---- text: normalise to a list of strings, add BOS, tokenize ----
        if text is not None:
            if isinstance(text, str):
                text = [text]
            elif not isinstance(text, (list, tuple)) or not all(isinstance(t, str) for t in text):
                raise ValueError("Invalid input text. Please provide a string, or a list of strings")

            n_images_in_text = [prompt.count(self.image_token) for prompt in text]
            text = [build_string_from_input(prompt, self.bos_token, self.image_token) for prompt in text]
            encoding = self.tokenizer(text, **text_kwargs)
            self._check_special_mm_tokens(text, encoding, modalities=["image"])
            n_images_in_ids = [ids.count(self.image_token_id) for ids in encoding["input_ids"]]
            data.update(encoding)

        # ---- images: fetch and arrange as one list of images per sample ----
        n_images_in_images = [0]
        if images is not None:
            images = self.image_processor.fetch_images(images)
            images = make_nested_list_of_images(images)
            n_images_in_images = [len(sample_images) for sample_images in images]

        # ---- consistency checks between image tokens and provided images ----
        if text is not None:
            sample_has_no_image = [count == 0 for count in n_images_in_text]
            if any(sample_has_no_image) and not all(sample_has_no_image):
                raise ValueError(
                    "If a batch of text is provided, there should be either no images or at least one image per sample"
                )

            total_images_in_text = sum(n_images_in_text)
            differs_from_text = n_images_in_images != n_images_in_text
            differs_from_ids = n_images_in_ids != n_images_in_images
            if total_images_in_text > 0 and (differs_from_text or differs_from_ids):
                if images is None:
                    raise ValueError("No image were provided, but there are image tokens in the prompt")

                # Pick the hint that matches the kind of mismatch, if any applies.
                if sum(n_images_in_images) == total_images_in_text and differs_from_text:
                    add_message = "Make sure to pass your images as a nested list, where each sub-list holds images per batch"
                elif differs_from_ids:
                    add_message = "If you activated truncation with `max_length`, increase the `max_length` so image tokens aren't cropped."
                else:
                    add_message = ""
                raise ValueError(
                    f"The number of image tokens in each text ({n_images_in_text}) should be the same as the "
                    f"number of provided images per batch ({n_images_in_images}). {add_message}"
                )

        # ---- image features and cross-attention mask ----
        if images is not None:
            image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
            # `num_tiles` is only needed to build the cross-attention mask; it is not returned.
            num_tiles = image_features.pop("num_tiles")
            data.update(image_features)

            if text is not None:
                batch_input_ids = encoding["input_ids"]
                sparse_masks = [
                    get_cross_attention_token_mask(ids, self.image_token_id) for ids in batch_input_ids
                ]
                data["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense(
                    sparse_masks,
                    num_tiles=num_tiles,
                    max_num_tiles=self.image_processor.max_image_tiles,
                    length=max(len(ids) for ids in batch_input_ids),
                )

        return BatchFeature(data=data, tensor_type=return_tensors)

    def post_process_image_text_to_text(
        self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
    ):
        """
        Decode generated token ids back into text by delegating to the tokenizer's `batch_decode`.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape
                `(batch_size, sequence_length)` or `(sequence_length,)`.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether or not to remove special tokens in the output. Forwarded to `batch_decode`.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the tokenization spaces. Forwarded to `batch_decode`.
            **kwargs:
                Additional arguments forwarded to the tokenizer's `batch_decode` method.

        Returns:
            `list[str]`: The decoded text.
        """
        decode = self.tokenizer.batch_decode
        return decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def model_input_names(self):
        tokenizer_names = self.tokenizer.model_input_names
        # `num_tiles` is popped in `__call__` and never returned, so it is filtered out here.
        # The comprehension builds a new list, leaving `self.image_processor.model_input_names`
        # untouched.
        image_names = [name for name in self.image_processor.model_input_names if name != "num_tiles"]
        return list(tokenizer_names + image_names + ["cross_attention_mask"])
# Only the processor class is listed for `from ... import *`; the helper functions and the
# kwargs class defined in this module are not part of that export list.
__all__ = ["MllamaProcessor"]