You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
862 lines
40 KiB
862 lines
40 KiB
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# This file was automatically generated from src/transformers/models/glm46v/modular_glm46v.py.
|
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
# the file from the modular. If any change should be done, please apply the change to the
|
|
# modular_glm46v.py file directly. One of our CI enforces this.
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# Copyright 2025 the HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import itertools
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from ...cache_utils import Cache
|
|
from ...generation import GenerationMixin
|
|
from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput
|
|
from ...modeling_utils import PreTrainedModel
|
|
from ...processing_utils import Unpack
|
|
from ...utils import (
|
|
TransformersKwargs,
|
|
auto_docstring,
|
|
can_return_tuple,
|
|
is_torchdynamo_compiling,
|
|
torch_compilable_check,
|
|
)
|
|
from ..auto import AutoModel
|
|
from .configuration_glm46v import Glm46VConfig
|
|
|
|
|
|
@auto_docstring
class Glm46VPreTrainedModel(PreTrainedModel):
    # --- configuration / naming -------------------------------------------------
    config: Glm46VConfig
    base_model_prefix = "model"
    input_modalities = ("image", "video", "text")

    # --- device placement and gradient checkpointing ---------------------------
    _no_split_modules = None
    _skip_keys_device_placement = "past_key_values"
    supports_gradient_checkpointing = True

    # --- attention backends and compilation ------------------------------------
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_attention_backend = True
    _can_compile_fullgraph = True

    # --- output recording: none declared at this level -------------------------
    _can_record_outputs = None
|
|
|
|
|
|
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Glm46V outputs, with hidden states and attentions.
    """
)
class Glm46VModelOutputWithPast(ModelOutput):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    """

    # Field order is significant: ModelOutput exposes fields positionally (e.g. via to_tuple()).
    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    rope_deltas: torch.LongTensor | None = None
|
|
|
|
|
|
@auto_docstring
class Glm46VModel(Glm46VPreTrainedModel):
    base_model_prefix = "model"
    _checkpoint_conversion_mapping = {}
    # Reference: fix gemma3 grad acc #37208
    accepts_loss_kwargs = False
    config: Glm46VConfig
    _no_split_modules = None

    def __init__(self, config):
        super().__init__(config)
        self.visual = AutoModel.from_config(config.vision_config)
        self.language_model = AutoModel.from_config(config.text_config)
        self.rope_deltas = None  # cache rope_deltas here

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_rope_index(
        self,
        input_ids: torch.LongTensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        video_grid_thw: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate the 3D rope index based on image and video's temporal, height and width in LLM.

        Explanation:
            Each embedding sequence contains vision embedding and text embedding or just contains text embedding.

            For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
            Examples:
                input_ids: [T T T T T], here T is for text.
                temporal position_ids: [0, 1, 2, 3, 4]
                height position_ids: [0, 1, 2, 3, 4]
                width position_ids: [0, 1, 2, 3, 4]

            For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
            and 1D rotary position embedding for text part.

            Tokens of each (unpadded) sequence are classified as follows: `image_token_id` tokens lying between
            `video_start_token_id` and `video_end_token_id` are video tokens, other `image_token_id` tokens are
            image tokens, and everything else is text. Each contiguous run of one type is laid out as a unit whose
            positions start at (max position id of the previous run) + 1:

            - An image run uses `(t, h, w)` from the next row of `image_grid_thw` (with `h` and `w` divided by
              `config.vision_config.spatial_merge_size`) and gets a full temporal x height x width index grid.
            - A video run is laid out as a single frame: temporal index 0 and a height x width grid built from the
              current row of `video_grid_thw`. After `video_grid_thw[k][0]` such runs, the next video's row is used.
              Successive frames therefore advance only through the text positions that separate them.
            - A text run gets the same 1D index on all three axes.

            Examples:
                One image with grid (t=1, h=4, w=4) and spatial_merge_size=2, followed by 3 text tokens:
                input_ids: [V V V V T T T], here V is for vision.
                vision temporal position_ids: [0, 0, 0, 0]
                vision height position_ids: [0, 0, 1, 1]
                vision width position_ids: [0, 1, 0, 1]
                text temporal position_ids: [2, 3, 4]
                text height position_ids: [2, 3, 4]
                text width position_ids: [2, 3, 4]
                Here we calculate the text start position_ids as the max vision position_ids plus 1.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
                it.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width of feature shape of each image in LLM.
            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                The temporal, height and width of feature shape of each video in LLM.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

        Returns:
            position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
            mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`)
        """

        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_start_token_id = self.config.video_start_token_id
        video_end_token_id = self.config.video_end_token_id

        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            # Padding positions keep the filler value 1; real tokens are overwritten below.
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            # Grid indices run across the whole batch: image/video grids are consumed in order, sample after sample.
            image_index, video_index = 0, 0
            video_group_index = 0
            attention_mask = attention_mask.to(total_input_ids.device)
            # A distinct loop name avoids shadowing the `input_ids` argument.
            for i, sample_input_ids in enumerate(total_input_ids):
                sample_input_ids = sample_input_ids[attention_mask[i] == 1]
                input_tokens = sample_input_ids.tolist()

                # Classify each token: image placeholders inside a video span are video tokens.
                input_token_type = []
                video_check_flg = False
                for token in input_tokens:
                    if token == video_start_token_id:
                        video_check_flg = True
                    elif token == video_end_token_id:
                        video_check_flg = False

                    if token == image_token_id and not video_check_flg:
                        input_token_type.append("image")
                    elif token == image_token_id and video_check_flg:
                        input_token_type.append("video")
                    else:
                        input_token_type.append("text")

                # Collapse the per-token types into contiguous (type, start, end) runs.
                input_type_group = []
                for key, group in itertools.groupby(enumerate(input_token_type), lambda x: x[1]):
                    group = list(group)
                    start_index = group[0][0]
                    end_index = group[-1][0] + 1
                    input_type_group.append((key, start_index, end_index))

                llm_pos_ids_list = []
                video_frame_num = 1
                for modality_type, start_idx, end_idx in input_type_group:
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0

                    if modality_type == "image":
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        llm_grid_t, llm_grid_h, llm_grid_w = (
                            t.item(),
                            h.item() // spatial_merge_size,
                            w.item() // spatial_merge_size,
                        )

                        t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                        h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                        w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                        llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

                        image_index += 1
                        video_frame_num = 1

                    elif modality_type == "video":
                        t, h, w = (
                            video_frame_num,
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )

                        llm_grid_t, llm_grid_h, llm_grid_w = (
                            t,
                            h.item() // spatial_merge_size,
                            w.item() // spatial_merge_size,
                        )

                        for t_idx in range(llm_grid_t):
                            t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()

                            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
                            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
                            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

                        # Each video run counts as one frame; move to the next video once all its frames are used.
                        video_group_index += 1

                        if video_group_index >= video_grid_thw[video_index][0]:
                            video_index += 1
                            video_group_index = 0

                        video_frame_num += 1

                    else:
                        text_len = end_idx - start_idx
                        llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                        video_frame_num = 1

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas

    @can_return_tuple
    @auto_docstring
    def get_video_features(
        self,
        pixel_values_videos: torch.FloatTensor,
        video_grid_thw: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input videos.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        """
        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
        # so `self.visual` receives one `[1, h, w]` grid row per frame instead of one `[t, h, w]` row per video.
        temp_frames_hw = []
        for t, h, w in video_grid_thw:
            repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
            temp_frames_hw.append(repeated_row)
        flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
        vision_outputs = self.visual(
            pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs
        )
        # Split the pooled embeddings back into one chunk per video (token count = t*h*w / merge_size**2).
        split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        video_embeds = torch.split(vision_outputs.pooler_output, split_sizes)
        vision_outputs.pooler_output = video_embeds

        return vision_outputs

    @can_return_tuple
    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        """
        pixel_values = pixel_values.type(self.visual.dtype)
        vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs)
        # Split the pooled embeddings into one chunk per image (token count = t*h*w / merge_size**2).
        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(vision_outputs.pooler_output, split_sizes)
        vision_outputs.pooler_output = image_embeds

        return vision_outputs

    def get_placeholder_mask(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        image_features: torch.FloatTensor | None = None,
        video_features: torch.FloatTensor | None = None,
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.

        Both returned masks are expanded to the shape of `inputs_embeds`.

        NOTE(review): when `input_ids` is given, both masks select every `image_token_id` position (video frames reuse
        the image placeholder), so the feature-count check only passes when the features cover all placeholders —
        confirm whether batches mixing images and videos are meant to be supported.
        """
        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
            special_video_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_video_mask = special_video_mask.all(-1)
        else:
            # GLM-4.1V and GLM-4.5V special_video_mask is special_image_mask
            special_image_mask = input_ids == self.config.image_token_id
            special_video_mask = input_ids == self.config.image_token_id

        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if image_features is not None:
            torch_compilable_check(
                inputs_embeds[special_image_mask].numel() == image_features.numel(),
                f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}",
            )

        n_video_tokens = special_video_mask.sum()
        special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if video_features is not None:
            torch_compilable_check(
                inputs_embeds[special_video_mask].numel() == video_features.numel(),
                f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}",
            )
        return special_image_mask, special_video_mask

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.Tensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        video_grid_thw: torch.LongTensor | None = None,
        rope_deltas: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Glm46VModelOutputWithPast:
        r"""
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        rope_deltas (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
            The rope index difference between sequence length and multimodal rope.
        """
        # NOTE: the `rope_deltas` argument is not read below; the deltas cached on `self.rope_deltas` are used instead.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Replace image placeholder embeddings with the vision tower's image features.
        if pixel_values is not None:
            image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output
            image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        # Replace video placeholder embeddings with the vision tower's per-frame features.
        if pixel_values_videos is not None:
            video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output
            video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        if position_ids is None:
            attention_mask_tensor = (
                attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
            )
            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
                # Recover a 2D (batch, seq) padding mask from the diagonal of the 4D mask.
                attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
                # Only apply conversion for floating point tensors (inverted masks)
                if attention_mask_tensor.dtype.is_floating_point:
                    attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
                    attention_mask_tensor = (1.0 - attention_mask_tensor).int()

            # Calculate RoPE index once per generation in the pre-fill stage only.
            # When compiling, we can't check tensor values thus we check only input length
            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
            # models currently cannot do assisted decoding
            prefill_compiled_stage = is_torchdynamo_compiling() and (
                (input_ids is not None and input_ids.shape[1] != 1)
                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
            )
            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
                (cache_position is not None and cache_position[0] == 0)
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            )
            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    attention_mask=attention_mask_tensor,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    # Repeat per-sample deltas when the batch was expanded (e.g. beam search / multiple samples).
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )

        return Glm46VModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )
|
|
|
|
|
|
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Glm46V causal language model (or autoregressive) outputs.
    """
)
class Glm46VCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    """

    # Field order is significant: ModelOutput exposes fields positionally (e.g. via to_tuple()).
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    rope_deltas: torch.LongTensor | None = None
|
|
|
|
|
|
class Glm46VForConditionalGeneration(Glm46VPreTrainedModel, GenerationMixin):
|
|
_checkpoint_conversion_mapping = {}
|
|
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
|
# Reference: fix gemma3 grad acc #37208
|
|
accepts_loss_kwargs = False
|
|
|
|
def __init__(self, config):
    # Multimodal backbone first, then a bias-free head projecting text hidden states to vocabulary logits.
    super().__init__(config)
    text_config = config.text_config
    self.model = Glm46VModel(config)
    self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)

    self.post_init()
|
|
|
|
def get_input_embeddings(self):
    """Return the token embedding module owned by the wrapped `Glm46VModel`."""
    backbone = self.model
    return backbone.get_input_embeddings()
|
|
|
|
def set_input_embeddings(self, value):
    """Replace the token embedding module of the wrapped `Glm46VModel` with `value`."""
    backbone = self.model
    backbone.set_input_embeddings(value)
|
|
|
|
@auto_docstring
def get_video_features(
    self,
    pixel_values_videos: torch.FloatTensor,
    video_grid_thw: torch.LongTensor | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | BaseModelOutputWithPooling:
    r"""
    pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        The tensors corresponding to the input videos.
    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
        The temporal, height and width of feature shape of each video in LLM.
    """
    # Thin delegation: the backbone owns the vision tower and does all the work.
    backbone = self.model
    return backbone.get_video_features(
        video_grid_thw=video_grid_thw,
        pixel_values_videos=pixel_values_videos,
        **kwargs,
    )
|
|
|
|
@auto_docstring
def get_image_features(
    self,
    pixel_values: torch.FloatTensor,
    image_grid_thw: torch.LongTensor | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | BaseModelOutputWithPooling:
    r"""
    pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        The tensors corresponding to the input images.
    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
        The temporal, height and width of feature shape of each image in LLM.
    """
    # Thin delegation: the backbone owns the vision tower and does all the work.
    backbone = self.model
    return backbone.get_image_features(
        image_grid_thw=image_grid_thw,
        pixel_values=pixel_values,
        **kwargs,
    )
|
|
|
|
@can_return_tuple
@auto_docstring
def forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    labels: torch.LongTensor | None = None,
    pixel_values: torch.Tensor | None = None,
    pixel_values_videos: torch.FloatTensor | None = None,
    image_grid_thw: torch.LongTensor | None = None,
    video_grid_thw: torch.LongTensor | None = None,
    cache_position: torch.LongTensor | None = None,
    logits_to_keep: int | torch.Tensor = 0,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | Glm46VCausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
        The temporal, height and width of feature shape of each image in LLM.
    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
        The temporal, height and width of feature shape of each video in LLM.

    Example:

    ```python
    >>> from io import BytesIO

    >>> import httpx
    >>> from PIL import Image
    >>> from transformers import AutoProcessor, Glm46VForConditionalGeneration

    >>> model = Glm46VForConditionalGeneration.from_pretrained("zai-org/GLM-4.1V-9B-Thinking")
    >>> processor = AutoProcessor.from_pretrained("zai-org/GLM-4.1V-9B-Thinking")

    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> messages = [
    ...     {
    ...         "role": "user",
    ...         "content": [
    ...             {"type": "image", "url": url},
    ...             {"type": "text", "text": "What is shown in this image?"},
    ...         ],
    ...     },
    ... ]
    >>> with httpx.stream("GET", url) as response:
    ...     image = Image.open(BytesIO(response.read()))

    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    >>> inputs = processor(text=[text], images=[image], return_tensors="pt")

    >>> # Generate (pass all processor outputs so the image reaches the model)
    >>> generate_ids = model.generate(**inputs, max_new_tokens=30)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
    ```"""
    outputs = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    # An int keeps the last `logits_to_keep` positions (0 keeps all); a tensor selects explicit positions.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    logits = self.lm_head(hidden_states[:, slice_indices, :])

    loss = None
    if labels is not None:
        loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)

    return Glm46VCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=outputs.rope_deltas,
    )
|
|
|
|
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    pixel_values=None,
    pixel_values_videos=None,
    image_grid_thw=None,
    video_grid_thw=None,
    is_first_iteration=False,
    **kwargs,
):
    # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
    model_inputs = super().prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        position_ids=position_ids,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        use_cache=use_cache,
        is_first_iteration=is_first_iteration,
        **kwargs,
    )

    # Position ids are rebuilt inside `forward` from the cached rope deltas, so never pass them through.
    model_inputs["position_ids"] = None

    # Pixel inputs are only needed on the first step, or on every step when no cache is kept.
    keep_visual_inputs = is_first_iteration or not use_cache
    if not keep_visual_inputs:
        for visual_key in ("pixel_values", "pixel_values_videos"):
            model_inputs[visual_key] = None

    return model_inputs
|
|
|
|
def _get_image_nums_and_video_nums(
    self,
    input_ids: torch.LongTensor | None,
    inputs_embeds: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
    These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.

    Images are counted by their `image_start_token_id` markers that lie outside any
    `video_start_token_id` ... `video_end_token_id` span; videos are counted by their `video_start_token_id` markers.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary. Ignored when `inputs_embeds` is given.
        inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Input embeddings. Special tokens are located by comparing each position against the token's
            full embedding vector.

    Returns:
        image_nums (`torch.LongTensor` of shape `(batch_size,)`)
        video_nums (`torch.LongTensor` of shape `(batch_size,)`)
    """

    if inputs_embeds is not None:
        embed_tokens = self.get_input_embeddings()

        def _matches_token(token_id):
            # Require the whole embedding vector to match (as `get_placeholder_mask` does); comparing only
            # the first channel can report false positives for unrelated tokens.
            token_embedding = embed_tokens(torch.tensor(token_id, dtype=torch.long, device=inputs_embeds.device))
            return (inputs_embeds == token_embedding).all(-1)

        is_image = _matches_token(self.config.image_start_token_id)
        is_video_start = _matches_token(self.config.video_start_token_id)
        is_video_end = _matches_token(self.config.video_end_token_id)
    else:
        is_image = input_ids == self.config.image_start_token_id
        is_video_start = input_ids == self.config.video_start_token_id
        is_video_end = input_ids == self.config.video_end_token_id

    # Cumulative sum to track if we're inside a video span
    # We'll assume well-formed video tags (i.e. matching starts and ends)
    video_level = torch.cumsum(is_video_start.int() - is_video_end.int(), dim=1)
    inside_video = video_level > 0  # shape (batch_size, seq_length)

    # Mask out image tokens that are inside video spans
    standalone_images = is_image & (~inside_video)

    # Count per batch
    image_counts = standalone_images.sum(dim=1)
    video_counts = is_video_start.sum(dim=1)

    return image_counts, video_counts
|
|
|
|
def _expand_inputs_for_generation(
|
|
self,
|
|
expand_size: int = 1,
|
|
is_encoder_decoder: bool = False,
|
|
input_ids: torch.LongTensor | None = None,
|
|
**model_kwargs,
|
|
) -> tuple[torch.LongTensor, dict[str, Any]]:
|
|
# Overwritten -- Support for expanding tensors without a batch size dimension
|
|
# e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
|
|
# pixel_values.shape[0] is sum(seqlen_images for samples)
|
|
# image_grid_thw.shape[0] is sum(num_images for samples)
|
|
|
|
if expand_size == 1:
|
|
return input_ids, model_kwargs
|
|
|
|
visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
|
|
|
|
def _expand_dict_for_generation_visual(dict_to_expand):
|
|
image_grid_thw = model_kwargs.get("image_grid_thw", None)
|
|
video_grid_thw = model_kwargs.get("video_grid_thw", None)
|
|
image_nums, video_nums = self._get_image_nums_and_video_nums(
|
|
input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
|
|
)
|
|
|
|
def _repeat_interleave_samples(x, lengths, repeat_times):
|
|
samples = torch.split(x, lengths)
|
|
repeat_args = [repeat_times] + [1] * (x.dim() - 1)
|
|
result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
|
|
return result
|
|
|
|
for key in dict_to_expand:
|
|
if key == "pixel_values":
|
|
# split images into samples
|
|
samples = torch.split(image_grid_thw, list(image_nums))
|
|
# compute the sequence length of images for each sample
|
|
lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
|
|
dict_to_expand[key] = _repeat_interleave_samples(
|
|
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
|
|
)
|
|
elif key == "image_grid_thw":
|
|
# get the num of images for each sample
|
|
lengths = list(image_nums)
|
|
dict_to_expand[key] = _repeat_interleave_samples(
|
|
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
|
|
)
|
|
elif key == "pixel_values_videos":
|
|
samples = torch.split(video_grid_thw, list(video_nums))
|
|
lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
|
|
dict_to_expand[key] = _repeat_interleave_samples(
|
|
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
|
|
)
|
|
elif key == "video_grid_thw":
|
|
lengths = list(video_nums)
|
|
dict_to_expand[key] = _repeat_interleave_samples(
|
|
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
|
|
)
|
|
elif key == "second_per_grid_ts":
|
|
dict_to_expand[key] = _repeat_interleave_samples(
|
|
dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
|
|
)
|
|
return dict_to_expand
|
|
|
|
def _expand_dict_for_generation(dict_to_expand):
|
|
for key in dict_to_expand:
|
|
if (
|
|
key != "cache_position"
|
|
and dict_to_expand[key] is not None
|
|
and isinstance(dict_to_expand[key], torch.Tensor)
|
|
and key not in visual_keys
|
|
):
|
|
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
|
|
return dict_to_expand
|
|
|
|
model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
|
|
|
|
if input_ids is not None:
|
|
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
|
|
|
|
model_kwargs = _expand_dict_for_generation(model_kwargs)
|
|
|
|
if is_encoder_decoder:
|
|
if model_kwargs.get("encoder_outputs") is None:
|
|
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
|
|
model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
|
|
|
|
return input_ids, model_kwargs
|
|
|
|
|
|
# Public API of this module: the names exported by `from ... import *` and picked up by the lazy model loader.
__all__ = [
    "Glm46VModel",
    "Glm46VPreTrainedModel",
    "Glm46VForConditionalGeneration",
]
|