# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/glmasr/modular_glmasr.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_glmasr.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Copyright 2025 the HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import re import numpy as np from ...audio_utils import AudioInput, make_list_of_audio from ...feature_extraction_utils import BatchFeature from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import TextInput from ...utils import is_torch_available, logging if is_torch_available(): import torch logger = logging.get_logger(__name__) class GlmAsrProcessorKwargs(ProcessingKwargs, total=False): _defaults = { "text_kwargs": { "padding": True, }, "audio_kwargs": { "sampling_rate": 16000, "chunk_length": 30.0, "return_attention_mask": True, "padding": "max_length", }, "common_kwargs": { "return_tensors": "pt", "padding_side": "left", }, } class GlmAsrProcessor(ProcessorMixin): r""" Constructs an GlmAsr processor which wraps an GlmAsr feature extractor and an GlmAsr tokenizer into a single processor. [`GlmAsrProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`] and [`Qwen2TokenizerFast`]. See the [`~GlmAsrProcessor.__call__`] for more information. Args: feature_extractor ([`WhisperFeatureExtractor`]): The feature extractor is a required input. tokenizer ([`Qwen2TokenizerFast`]): The tokenizer is a required input. chat_template (`Optional[str]`, *optional*): The Jinja template to use for formatting the conversation. If not provided, the tokenizer's default chat template will be used. audio_token (`Optional[str]`, *optional*, defaults to `"<|pad|>`"): Special token used to represent audio inputs in the chat template. default_transcription_prompt (`str`, *optional*, defaults to `"Please transcribe this audio into text"`): Default prompt to use for transcription tasks when applying transcription requests. max_audio_len (`int`, *optional*, defaults to 655): Maximum length of audio sequences in seconds. Audio longer than this will be truncated. 655 gives approximately 8192 tokens, corresponding to the maximum sequence length of the text model. """ def __init__( self, feature_extractor, tokenizer, chat_template=None, audio_token="<|pad|>", default_transcription_prompt="Please transcribe this audio into text", max_audio_len=655, ): self.audio_token = audio_token self.audio_token_id = tokenizer.convert_tokens_to_ids(audio_token) self.default_transcription_prompt = default_transcription_prompt self.max_audio_len = max_audio_len super().__init__(feature_extractor, tokenizer, chat_template=chat_template) def _get_audio_token_length(self, audio_lengths: "torch.Tensor") -> "torch.Tensor": merge_factor = 4 for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]: audio_lengths = (audio_lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1 num_tokens = (audio_lengths - merge_factor) // merge_factor + 1 return num_tokens def __call__( self, text: TextInput | list[TextInput], audio: AudioInput | None = None, output_labels: bool | None = False, **kwargs: Unpack[GlmAsrProcessorKwargs], ) -> BatchFeature: r""" Main method to prepare one or several text sequence(s) and audio waveform(s) for the model. This method expands `` placeholders in the text based on the post-pool frame counts of the audio windows, then tokenizes the provided strings as-is, and extracts log-mel features with [`WhisperFeatureExtractor`]. If `audio` is `None`, no audio processing is performed and the text is tokenized as-is (LM-only behavior). Args: text (`str` or `list[str]`): Input sequence or batch of sequences. audio (`np.ndarray` or `list[np.ndarray]`): Input audio or batch of audios as NumPy arrays. If provided, there must be as many `text` inputs as `audio` inputs. output_labels (bool, *optional*, default=False): Whether to return labels for training. Returns: [`BatchFeature`]: A dictionary with tokenized text (`input_ids`, `attention_mask`) and audio features (`input_features`, `input_features_mask`). """ # Merge defaults with user kwargs call_kwargs = self._merge_kwargs( GlmAsrProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) text_kwargs = call_kwargs["text_kwargs"] audio_kwargs = call_kwargs["audio_kwargs"] return_tensors = text_kwargs.get("return_tensors") if return_tensors != "pt": raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.") if isinstance(text, str): text = [text] elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)): raise ValueError("Invalid input text. Please provide a string, or a list of strings") audio_inputs = {} if audio is not None: audio = make_list_of_audio(audio) if len(text) != len(audio): raise ValueError(f"Got {len(text)} text but {len(audio)} audios; they must match 1:1.") # Determine number of chunks per sample, and flatten window_size = int(audio_kwargs["sampling_rate"] * audio_kwargs["chunk_length"]) max_windows = int(self.max_audio_len // audio_kwargs["chunk_length"]) per_sample_windows: list[int] = [] flat_chunks: list[np.ndarray] = [] for audio_el in audio: n_samples = int(audio_el.shape[0]) n_win = max(1, (n_samples + window_size - 1) // window_size) if n_win > max_windows: logger.warning( f"Audio duration ({n_samples / audio_kwargs['sampling_rate']:.1f}s) exceeds {self.max_audio_len}s; truncating to first {self.max_audio_len}s." ) n_win = max_windows per_sample_windows.append(n_win) time_cap = min(n_samples, n_win * window_size) for i in range(n_win): start = i * window_size end = min((i + 1) * window_size, time_cap) flat_chunks.append(audio_el[start:end]) # Feature extraction audio_inputs = self.feature_extractor(flat_chunks, **audio_kwargs) padding_mask = audio_inputs.pop("attention_mask") audio_inputs["input_features_mask"] = padding_mask # Compute sequence lengths token counting audio_lengths = torch.stack([s.sum() for s in torch.split(padding_mask.sum(-1), per_sample_windows)]) audio_tokens_lengths = self._get_audio_token_length(audio_lengths) # expand audio tokens in text for i, audio_length in enumerate(audio_tokens_lengths): expanded = re.sub(re.escape(self.audio_token), self.audio_token * audio_length, text[i]) text[i] = expanded # Tokenize text_inputs = self.tokenizer(text, **text_kwargs) data = {**text_inputs, **audio_inputs} if output_labels: labels = data["input_ids"].clone() labels[labels == self.audio_token_id] = -100 labels[labels == self.tokenizer.pad_token_id] = -100 data["labels"] = labels return BatchFeature(data=data, tensor_type=return_tensors) @property def model_input_names(self) -> list[str]: tok_names = self.tokenizer.model_input_names fea_names = self.feature_extractor.model_input_names return list(dict.fromkeys(tok_names + fea_names + ["input_features_mask"])) def apply_transcription_request( self, audio: str | list[str] | AudioInput, prompt: str | list[str] | None = None, **kwargs: Unpack[GlmAsrProcessorKwargs], ) -> BatchFeature: """ Prepare inputs for automatic speech recognition without manually writing the default transcription prompt. Args: audio (`str`, `list[str]`, `np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`): Audio to transcribe. Strings are interpreted as local paths or URLs and will be loaded automatically by the chat template loader; NumPy arrays and PyTorch tensors are forwarded directly. prompt (`str` or `list[str]`, *optional*): Custom prompt(s) to include in the user turn. A list must be the same length as the batch. When `None`, each sample uses `"Transcribe the input speech."`. **kwargs: Additional keyword arguments forwarded to [`~AudioFlamingo3Processor.apply_chat_template`] (for example `text_kwargs`, `audio_kwargs`, ...). Returns: [`BatchFeature`]: Processor outputs ready to be passed to [`AudioFlamingo3ForConditionalGeneration.generate`]. """ if isinstance(audio, str): audio_items: list[str | np.ndarray] = [audio] elif isinstance(audio, (list, tuple)) and audio and all(isinstance(el, str) for el in audio): audio_items = list(audio) else: audio_items = list(make_list_of_audio(audio)) if is_torch_available(): audio_items = [el.detach().cpu().numpy() if isinstance(el, torch.Tensor) else el for el in audio_items] batch_size = len(audio_items) if batch_size == 0: raise ValueError("`audio` must contain at least one sample.") if prompt is None: prompts = [self.default_transcription_prompt] * batch_size elif isinstance(prompt, str): prompts = [prompt] * batch_size elif isinstance(prompt, (list, tuple)): if len(prompt) != batch_size: raise ValueError( f"Received {len(prompt)} prompt(s) for {batch_size} audio sample(s); counts must match." ) prompts = [] for item in prompt: if item is None: prompts.append(self.default_transcription_prompt) elif isinstance(item, str): prompts.append(item) else: raise TypeError("Each prompt must be a string or `None`.") else: raise TypeError("`prompt` must be a string, a sequence of strings, or `None`.") conversations = [ [ { "role": "user", "content": [ {"type": "audio", "path": audio_item} if isinstance(audio_item, str) else {"type": "audio", "audio": audio_item}, {"type": "text", "text": prompt_text}, ], } ] for prompt_text, audio_item in zip(prompts, audio_items) ] return self.apply_chat_template( conversations, tokenize=True, add_generation_prompt=True, return_dict=True, **kwargs, ) def batch_decode(self, *args, strip_prefix=False, **kwargs): """ Forward arguments to [`~PreTrainedTokenizer.batch_decode`] and optionally remove the assistant framing the model was trained to produce. AF3 transcription requests respond with sentences such as `"The spoken content of the audio is \"...\"."`. Setting `strip_prefix=True` trims the fixed prefix for just the transcription text. """ decoded = self.tokenizer.batch_decode(*args, **kwargs) if strip_prefix: decoded = [self._strip_assistant_prefix_and_quotes(text) for text in decoded] return decoded def _strip_assistant_prefix_and_quotes(self, text: str) -> str: """ Remove the assistant prefix and surrounding quotes from a decoded transcription string. """ stripped = text.strip() for prefix in ( "The spoken content of the audio is", "The transcription of the audio is", ): if stripped.startswith(prefix): stripped = stripped[len(prefix) :].strip() break if stripped.endswith("."): stripped = stripped[:-1].strip() if len(stripped) >= 2 and stripped[0] == stripped[-1] and stripped[0] in {"'", '"'}: stripped = stripped[1:-1].strip() return stripped __all__ = ["GlmAsrProcessor"]