You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

143 lines
6.1 KiB

# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Nougat.
"""
from typing import Optional, Union
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput, TruncationStrategy
from ...processing_utils import ProcessorMixin
from ...utils import PaddingStrategy, TensorType, auto_docstring
# Processor bundling a Nougat image processor and a tokenizer behind one callable.
# Images are routed to `self.image_processor`, text to `self.tokenizer`; when both are
# supplied, the tokenized `input_ids` are attached to the image features as `labels`.
@auto_docstring
class NougatProcessor(ProcessorMixin):
    def __init__(self, image_processor, tokenizer):
        # Both components are handed to ProcessorMixin; the methods below rely on it
        # exposing them as `self.image_processor` and `self.tokenizer`.
        super().__init__(image_processor, tokenizer)

    @auto_docstring
    def __call__(
        self,
        images=None,
        text=None,
        do_crop_margin: bool | None = None,
        do_resize: bool | None = None,
        size: dict[str, int] | None = None,
        resample: "PILImageResampling" = None,  # noqa: F821
        do_thumbnail: bool | None = None,
        do_align_long_axis: bool | None = None,
        do_pad: bool | None = None,
        do_rescale: bool | None = None,
        rescale_factor: int | float | None = None,
        do_normalize: bool | None = None,
        image_mean: float | list[float] | None = None,
        image_std: float | list[float] | None = None,
        data_format: Optional["ChannelDimension"] = "channels_first",  # noqa: F821
        input_data_format: Union[str, "ChannelDimension"] | None = None,  # noqa: F821
        text_pair: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
        text_target: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
        text_pair_target: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
        add_special_tokens: bool = True,
        padding: bool | str | PaddingStrategy = False,
        truncation: bool | str | TruncationStrategy | None = None,
        max_length: int | None = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: int | None = None,
        return_tensors: str | TensorType | None = None,
        return_token_type_ids: bool | None = None,
        return_attention_mask: bool | None = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
    ):
        r"""
        do_crop_margin (`bool`, *optional*):
            Whether the image processor should crop away the blank margins surrounding the page content before the
            remaining preprocessing steps. Forwarded unchanged to the image processor.
        do_thumbnail (`bool`, *optional*):
            Whether the image processor should apply its thumbnail resize step. This transforms the image itself; it
            does not produce an additional output. Forwarded unchanged to the image processor.
        do_align_long_axis (`bool`, *optional*):
            Whether the image processor should rotate the image so that its long axis lines up with the long axis of
            the target `size`. Forwarded unchanged to the image processor.
        """
        if images is None and text is None:
            raise ValueError("You need to specify either an `images` or `text` input to process.")

        # Every image-related option is passed through verbatim (including `None`), so defaults
        # are decided by the image processor, not here.
        if images is not None:
            inputs = self.image_processor(
                images,
                do_crop_margin=do_crop_margin,
                do_resize=do_resize,
                size=size,
                resample=resample,
                do_thumbnail=do_thumbnail,
                do_align_long_axis=do_align_long_axis,
                do_pad=do_pad,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                return_tensors=return_tensors,
                data_format=data_format,
                input_data_format=input_data_format,
            )

        # Likewise, every tokenizer option is forwarded verbatim.
        if text is not None:
            encodings = self.tokenizer(
                text,
                text_pair=text_pair,
                text_target=text_target,
                text_pair_target=text_pair_target,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                is_split_into_words=is_split_into_words,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
            )

        # Single-modality calls return that modality's output untouched.
        if text is None:
            return inputs
        if images is None:
            return encodings

        # Both modalities: the text is treated as the target sequence, so only its `input_ids`
        # are kept and stored on the image features under `labels`; the other tokenizer outputs
        # (e.g. attention mask) are dropped.
        inputs["labels"] = encodings["input_ids"]
        return inputs

    def post_process_generation(self, *args, **kwargs):
        """
        Forward all positional and keyword arguments to `self.tokenizer.post_process_generation` and return its
        result unchanged. Refer to that tokenizer method's docstring for the accepted arguments.
        """
        return self.tokenizer.post_process_generation(*args, **kwargs)
# Public API of this module: the only name exported by `from ... import *`.
__all__ = ["NougatProcessor"]