You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
101 lines
4.4 KiB
101 lines
4.4 KiB
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# This file was automatically generated from src/transformers/models/lasr/modular_lasr.py.
|
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
# the file from the modular. If any change should be done, please apply the change to the
|
|
# modular_lasr.py file directly. One of our CI enforces this.
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from ...audio_utils import AudioInput, make_list_of_audio
|
|
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
|
from ...utils import auto_docstring, logging
|
|
|
|
|
|
# Module-level logger; `LasrProcessor.__call__` uses it to warn (once) when audio is passed without a sampling rate.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class LasrProcessorKwargs(ProcessingKwargs, total=False):
    # Default keyword arguments for `LasrProcessor.__call__`. The processor hands this class to
    # `_merge_kwargs`, then forwards the resulting "audio_kwargs" group to the feature extractor
    # and the "text_kwargs" group to the tokenizer.
    _defaults = {
        # Audio: 16 kHz input, pad each batch to its longest item, and return an attention mask.
        "audio_kwargs": {"sampling_rate": 16000, "padding": "longest", "return_attention_mask": True},
        # Text: right-padded, and no special tokens are added during tokenization.
        "text_kwargs": {"padding": True, "padding_side": "right", "add_special_tokens": False},
        # Shared defaults; how these are combined with the groups above is decided by `_merge_kwargs`.
        "common_kwargs": {"return_tensors": "pt"},
    }
|
|
|
|
|
|
@auto_docstring
class LasrProcessor(ProcessorMixin):
    # Pairs a feature extractor (audio -> model inputs) with a tokenizer (text -> label ids).
    # Documentation for the class and for `__call__` is produced by `@auto_docstring`, so the
    # explanatory notes here are deliberately kept as comments rather than docstrings.

    def __init__(self, feature_extractor, tokenizer):
        super().__init__(feature_extractor, tokenizer)

    @auto_docstring
    def __call__(
        self,
        audio: AudioInput,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
        sampling_rate: int | None = None,
        **kwargs: Unpack[LasrProcessorKwargs],
    ):
        r"""
        sampling_rate (`int`, *optional*):
            The sampling rate of the input audio in Hz. This should match the sampling rate expected by the feature
            extractor (defaults to 16000 Hz). If provided, it will be validated against the processor's expected
            sampling rate, and an error will be raised if they don't match. If not provided, a warning will be
            issued and the default sampling rate will be assumed.
        """
        audio = make_list_of_audio(audio)

        output_kwargs = self._merge_kwargs(
            LasrProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Look the expected rate up once; it is used by both the warning and the error below.
        expected_sampling_rate = output_kwargs["audio_kwargs"]["sampling_rate"]
        if sampling_rate is None:
            logger.warning_once(
                "You've provided audio without specifying the sampling rate. "
                f"It will be assumed to be {expected_sampling_rate}, which can result in silent errors."
            )
        elif sampling_rate != expected_sampling_rate:
            raise ValueError(
                f"The sampling rate of the audio ({sampling_rate}) does not match the sampling rate of the "
                f"processor ({expected_sampling_rate}). Please resample the audio to the expected sampling rate."
            )

        # `audio` is a required argument and has already been normalized by `make_list_of_audio`,
        # so features are always extracted. Guarding this call with `audio is not None` could only
        # ever leave `inputs` unbound at the return below.
        inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])

        if text is not None:
            # When text is supplied, its token ids are attached to the audio features as "labels".
            encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
            inputs["labels"] = encodings["input_ids"]

        return inputs

    @property
    def model_input_names(self):
        """Input names reported by the feature extractor, followed by `"labels"`."""
        return self.feature_extractor.model_input_names + ["labels"]
|
|
|
|
|
|
# Public API of this module: only the processor class is exported (`LasrProcessorKwargs` is not listed).
__all__ = ["LasrProcessor"]
|