You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

775 lines
32 KiB

# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/ovis2/modular_ovis2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_ovis2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from dataclasses import dataclass
import torch
from torch import nn
from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
from ...utils.generic import check_model_inputs
from ..auto import AutoModel
from .configuration_ovis2 import Ovis2Config, Ovis2VisionConfig
@dataclass
@auto_docstring
class BaseModelOutputWithVisualIndicatorFeatures(BaseModelOutputWithPooling):
    r"""
    visual_indicator_features (`torch.FloatTensor` of shape `(num_visual_indicator_tokens, hidden_size)`):
        Embeddings of the visual indicator tokens, looked up from the last `num_visual_indicator_tokens` rows of the
        visual embedding table. `Ovis2Model` uses them to overwrite the input embeddings of the visual indicator
        placeholder tokens before calling the language model.
    """

    # Left as `None` by the vision tower itself; filled in by `Ovis2Model.get_image_features`.
    visual_indicator_features: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Ovis2 outputs, with hidden states and attentions.
    """
)
class Ovis2ModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(num_images, image_sequence_length, hidden_size)`, where `num_images` is the
        first dimension of `pixel_values`.
        image_hidden_states of the model produced by the vision encoder, after the visual token distributions have
        been embedded with the visual embedding table.
    """

    image_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Ovis2 causal language model (or autoregressive) outputs.
    """
)
class Ovis2CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(num_images, image_sequence_length, hidden_size)`, where `num_images` is the
        first dimension of `pixel_values`.
        image_hidden_states of the model produced by the vision encoder, after the visual token distributions have
        been embedded with the visual embedding table.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    image_hidden_states: torch.FloatTensor | None = None
@use_kernel_forward_from_hub("RMSNorm")
class Ovis2RMSNorm(nn.Module):
    """
    Root-mean-square normalization over the last dimension (no mean centering, no bias); equivalent to T5LayerNorm.

    The statistics are computed in float32. The normalized tensor is cast back to the input dtype before the
    learned per-channel scale is applied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Ovis2VisionMLP(nn.Module):
    """Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        use_bias = config.mlp_bias
        # Registration order (gate, up, down) is kept so parameter ordering is unchanged.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
class Ovis2VisionEmbeddings(nn.Module):
    """
    Turns pixels into patch tokens: a strided convolution patchifies the image, the patch features are
    RMS-normalized, and a learned absolute position embedding (one per patch) is added.
    """

    def __init__(self, config: Ovis2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
        self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        conv_dtype = self.patch_embedding.weight.dtype
        patch_grid = self.patch_embedding(pixel_values.to(dtype=conv_dtype))
        # assumes 4-D pixel_values, so patch_grid is (B, C, H', W'); flatten to (B, H'*W', C)
        tokens = self.rms_norm(patch_grid.flatten(2).transpose(1, 2))
        return tokens + self.position_embedding(self.position_ids)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class Ovis2VisionAttention(nn.Module):
    """
    Multi-headed, non-causal self-attention from 'Attention Is All You Need'.

    The attention kernel is chosen from `ALL_ATTENTION_FUNCTIONS` according to `config._attn_implementation`,
    falling back to `eager_attention_forward`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        # Registration order (k, v, q, out) is kept so parameter ordering is unchanged.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Input shape: Batch x Time x Channel"""
        bsz, tgt_len, embed_dim = hidden_states.shape
        per_head_shape = (bsz, tgt_len, self.num_heads, self.head_dim)

        # (B, T, C) -> (B, heads, T, head_dim)
        query_states = self.q_proj(hidden_states).view(per_head_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(per_head_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=self.dropout if self.training else 0.0,
        )

        merged = attn_output.reshape(bsz, tgt_len, embed_dim).contiguous()
        return self.out_proj(merged), attn_weights
class Ovis2MLP(nn.Module):
    """Gated (GLU-style) MLP used inside each vision encoder layer: `down(act(gate(x)) * up(x))`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        in_dim, mid_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(in_dim, mid_dim, bias=config.mlp_bias)
        self.up_proj = nn.Linear(in_dim, mid_dim, bias=config.mlp_bias)
        self.down_proj = nn.Linear(mid_dim, in_dim, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        activated_gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(activated_gate * self.up_proj(x))
class Ovis2VisionEncoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block: RMSNorm -> attention -> residual, then RMSNorm -> MLP -> residual."""

    def __init__(self, config: Ovis2VisionConfig):
        super().__init__()
        self.attention = Ovis2VisionAttention(config)
        self.ffn = Ovis2MLP(config)
        self.rms_norm1 = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.rms_norm2 = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        attn_output, _ = self.attention(
            hidden_states=self.rms_norm1(hidden_states), attention_mask=attention_mask, **kwargs
        )
        hidden_states = residual + attn_output

        residual = hidden_states
        hidden_states = residual + self.ffn(self.rms_norm2(hidden_states))
        return hidden_states
class Ovis2VisionEncoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` [`Ovis2VisionEncoderLayer`] blocks applied in sequence.

    Args:
        config: Ovis2VisionConfig
    """

    def __init__(self, config: Ovis2VisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(Ovis2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers))
        self.gradient_checkpointing = False

    # Ignore copy
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        inputs_embeds,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        states = inputs_embeds
        for layer in self.layers:
            states = layer(states, attention_mask, **kwargs)
        return BaseModelOutput(last_hidden_state=states)
class Ovis2VisionTransformer(nn.Module):
    """Vision backbone: patch embeddings, then the encoder stack, then a final RMSNorm."""

    def __init__(self, config: Ovis2VisionConfig):
        super().__init__()
        self.config = config
        self.embeddings = Ovis2VisionEmbeddings(config)
        self.encoder = Ovis2VisionEncoder(config)
        self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        pixel_values,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ):
        patch_tokens = self.embeddings(pixel_values)
        encoded: BaseModelOutput = self.encoder(
            inputs_embeds=patch_tokens,
            attention_mask=attention_mask,
            **kwargs,
        )
        normalized = self.rms_norm(encoded.last_hidden_state)
        return BaseModelOutput(last_hidden_state=normalized)
class Ovis2VisualEmbeddingTable(nn.Embedding):
    """
    Embedding table that accepts either integer token ids or per-token weight vectors.

    Integer inputs are looked up like a regular `nn.Embedding`. Any other dtype is treated as weights over the
    vocabulary (e.g. soft or one-hot token distributions) and mixed as `visual_tokens @ weight`.
    """

    def forward(self, visual_tokens: torch.Tensor) -> torch.Tensor:
        # torch.long is the same dtype object as torch.int64.
        is_index_tensor = visual_tokens.dtype in (torch.int8, torch.int16, torch.int32, torch.int64)
        if not is_index_tensor:
            return torch.matmul(visual_tokens, self.weight)
        return super().forward(visual_tokens)
class Ovis2PreTrainedModel(PreTrainedModel):
    """Shared base for Ovis2 models: declares capability flags and custom weight initialization."""

    config: Ovis2Config
    base_model_prefix = "model"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    _no_split_modules = ["Ovis2VisionAttention"]
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        super()._init_weights(module)
        if not isinstance(module, Ovis2VisionEmbeddings):
            return
        # The non-persistent `position_ids` buffer is re-filled with 0..N-1.
        num_positions = module.position_ids.shape[-1]
        init.copy_(module.position_ids, torch.arange(num_positions).expand((1, -1)))
def hard_softmax(logits: torch.Tensor, dim: int):
    """
    Straight-through argmax over `dim`.

    The forward value is a one-hot tensor marking the argmax of `softmax(logits)`. Gradients are those of the
    soft probabilities, because the `hard - soft.detach() + soft` trick cancels `soft` in value only.
    """
    soft = logits.softmax(dim)
    _, winner = soft.max(dim, keepdim=True)
    hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)
    hard.scatter_(dim, winner, 1.0)
    return hard - soft.detach() + soft
class Ovis2VisionModel(Ovis2PreTrainedModel):
    """
    Visual tokenizer of Ovis2.

    Encodes images with `Ovis2VisionTransformer`. When `config.hidden_stride > 1`, every
    `hidden_stride x hidden_stride` neighbourhood of patch features is merged into one token. Each token is then
    mapped to a distribution over the first `vocab_size - num_visual_indicator_tokens` visual vocabulary entries,
    using `config.tokenize_function`.
    """

    config: Ovis2VisionConfig
    _can_record_outputs = {
        "hidden_states": Ovis2VisionEncoderLayer,
        "attentions": Ovis2VisionAttention,
    }

    def __init__(self, config: Ovis2VisionConfig):
        super().__init__(config)
        self.config = config
        self.transformer = Ovis2VisionTransformer(config)
        self.num_visual_indicator_tokens = config.num_visual_indicator_tokens
        self.vocab_size = config.vocab_size
        # Input width covers the hidden_stride**2 patch features concatenated into each merged token.
        self.head_linear = nn.Linear(
            config.hidden_size * config.hidden_stride * config.hidden_stride,
            self.vocab_size - self.num_visual_indicator_tokens,
            bias=False,
        )
        self.head_norm = nn.LayerNorm(self.vocab_size - self.num_visual_indicator_tokens)
        self.post_init()

    @check_model_inputs
    def forward(
        self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithVisualIndicatorFeatures:
        """
        Returns:
            `BaseModelOutputWithVisualIndicatorFeatures` whose `last_hidden_state` holds the (stride-merged) patch
            features and whose `pooler_output` holds the per-token visual vocabulary distributions.

        Raises:
            ValueError: if `hidden_stride > 1` and the patch sequence length is not a perfect square, or if
                `config.tokenize_function` is not one of `"gumbel_argmax"`, `"st_argmax"`, `"softmax"`.
        """
        outputs = self.transformer(pixel_values, **kwargs)
        last_hidden_state = outputs[0]

        if self.config.hidden_stride > 1:
            last_hidden_state = self._merge_hidden_stride(last_hidden_state, self.config.hidden_stride)

        logits = self.head_linear(last_hidden_state)
        logits = self.head_norm(logits)
        prob_token = self._tokenize(logits)

        return BaseModelOutputWithVisualIndicatorFeatures(
            last_hidden_state=last_hidden_state,
            pooler_output=prob_token,
        )

    @staticmethod
    def _merge_hidden_stride(hidden_states: torch.Tensor, hidden_stride: int) -> torch.Tensor:
        """Merge each `hidden_stride x hidden_stride` block of the square patch grid into one token."""
        num_images, seq_len, hidden_dim = hidden_states.shape
        sqrt_l = int(math.sqrt(seq_len))
        if sqrt_l * sqrt_l != seq_len:
            raise ValueError("Token sequence length must be a perfect square")

        # Restore the 2-D grid first: the 6-value pad below must act on the two spatial axes
        # (dims -3 and -2), not on the image and sequence axes of the flat (n, L, d) tensor.
        hidden_states = hidden_states.reshape(num_images, sqrt_l, sqrt_l, hidden_dim)
        pad_size = (hidden_stride - (sqrt_l % hidden_stride)) % hidden_stride
        hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size, 0, pad_size), "constant", 0)
        sqrt_l += pad_size

        hidden_states = hidden_states.reshape(
            num_images, sqrt_l // hidden_stride, hidden_stride, sqrt_l // hidden_stride, hidden_stride, hidden_dim
        )
        hidden_states = hidden_states.permute(0, 1, 3, 2, 4, 5)
        return hidden_states.reshape(
            num_images, -1, hidden_stride * hidden_stride * hidden_dim
        )  # (n, (sqrt_l//hs)^2, hs^2*d)

    def _tokenize(self, logits: torch.Tensor) -> torch.Tensor:
        """Convert head logits into visual-token weights according to `config.tokenize_function`."""
        tokenize_function = self.config.tokenize_function
        if tokenize_function == "gumbel_argmax":
            return nn.functional.gumbel_softmax(logits, dim=-1, hard=True)
        if tokenize_function == "st_argmax":
            return hard_softmax(logits, dim=-1)
        if tokenize_function == "softmax":
            return nn.functional.softmax(logits, dim=-1)
        raise ValueError(
            f"Unsupported `tokenize_function`: {tokenize_function!r}. "
            "Expected one of 'gumbel_argmax', 'st_argmax', 'softmax'."
        )
@auto_docstring(
    custom_intro="""
    The Ovis2 model which consists of a vision backbone and a language model, without a language modeling head.
    """
)
class Ovis2Model(Ovis2PreTrainedModel):
    _checkpoint_conversion_mapping = {}

    def __init__(self, config: Ovis2Config):
        super().__init__(config)
        self.vision_tower = Ovis2VisionModel(config.vision_config)
        self.language_model = AutoModel.from_config(config.text_config)
        # Maps visual-token weights over the visual vocabulary to `config.hidden_size` embeddings.
        self.visual_embeddings_table = Ovis2VisualEmbeddingTable(config.vision_config.vocab_size, config.hidden_size)
        self.visual_vocab_size = config.vision_config.vocab_size
        self.vocab_size = config.vocab_size
        self.visual_indicator_token_ids = config.visual_indicator_token_ids
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring(
        custom_intro="Obtains visual token distributions from the vision tower and embeds them with the visual embedding table."
    )
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithVisualIndicatorFeatures:
        image_outputs = self.vision_tower(pixel_values, return_dict=True, **kwargs)
        image_features = image_outputs.pooler_output
        batch_size, img_seq_len, _ = image_features.shape

        # The vision head predicts only over `vocab_size - num_visual_indicator_tokens` entries; zero-pad so the
        # weights span the whole visual vocabulary (indicator slots receive no mass).
        padding_tensor = torch.zeros(
            (batch_size, img_seq_len, self.vision_tower.num_visual_indicator_tokens),
            dtype=image_features.dtype,
            device=image_features.device,
            requires_grad=False,
            layout=image_features.layout,
        )
        image_features = torch.cat([image_features, padding_tensor], dim=2)
        image_features = self.visual_embeddings_table(image_features)

        # The last `num_visual_indicator_tokens` ids of the visual vocabulary are the indicator tokens.
        # Build them directly on the target device to avoid a host allocation plus copy.
        visual_indicator = torch.arange(
            self.visual_vocab_size - self.vision_tower.num_visual_indicator_tokens,
            self.visual_vocab_size,
            dtype=torch.long,
            device=image_features.device,
        )
        image_outputs.pooler_output = image_features
        image_outputs.visual_indicator_features = self.visual_embeddings_table(visual_indicator)
        return image_outputs

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        """
        if input_ids is None:
            # Without ids, placeholders are found by matching the embedding of `image_token_id` exactly.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id

        n_image_tokens = special_image_mask.sum()
        n_image_features = image_features.shape[0] * image_features.shape[1]
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        torch_compilable_check(
            inputs_embeds[special_image_mask].numel() == image_features.numel(),
            f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
        )
        return special_image_mask

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> tuple | Ovis2ModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            image_outputs = self.get_image_features(pixel_values=pixel_values, return_dict=True)
            image_features = image_outputs.pooler_output
            visual_indicator_features = image_outputs.visual_indicator_features

            special_image_mask = self.get_placeholder_mask(
                input_ids,
                inputs_embeds=inputs_embeds,
                image_features=image_features,
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

            # Overwrite every occurrence of the i-th indicator token with the i-th indicator embedding.
            for i, visual_indicator_id in enumerate(self.visual_indicator_token_ids):
                if input_ids is None:
                    mask = inputs_embeds == self.get_input_embeddings()(
                        torch.tensor(visual_indicator_id, dtype=torch.long, device=inputs_embeds.device)
                    )
                    mask = mask.all(-1)
                else:
                    mask = (input_ids == visual_indicator_id).to(inputs_embeds.device)

                if mask.any():
                    inputs_embeds[mask] = (
                        visual_indicator_features[i]
                        .expand_as(inputs_embeds[mask])
                        .to(inputs_embeds.device, inputs_embeds.dtype)
                    )

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        return Ovis2ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )
@auto_docstring
class Ovis2ForConditionalGeneration(Ovis2PreTrainedModel, GenerationMixin):
    _checkpoint_conversion_mapping = {}
    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

    def __init__(self, config: Ovis2Config):
        super().__init__(config)
        self.model = Ovis2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    @auto_docstring
    def get_image_features(
        self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithVisualIndicatorFeatures:
        return self.model.get_image_features(pixel_values=pixel_values, **kwargs)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> tuple | Ovis2CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, Ovis2ForConditionalGeneration

        >>> model = Ovis2ForConditionalGeneration.from_pretrained("thisisiron/Ovis2-2B-hf")
        >>> processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-2B-hf")

        >>> prompt = "<|im_start|>user\n<image>\nDescribe the image.<|im_end|>\n<|im_start|>assistant\n"
        >>> url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
        "user\n\nDescribe the image.\nassistant\nThe image features a brown dog standing on a wooden floor, looking up with"
        ```"""
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        model_outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )
        sequence_output = model_outputs[0]

        # Project only the requested positions; an int keeps the last `logits_to_keep` (0 keeps all).
        if isinstance(logits_to_keep, int):
            kept_positions = slice(-logits_to_keep, None)
        else:
            kept_positions = logits_to_keep
        logits = self.lm_head(sequence_output[:, kept_positions, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

        return Ovis2CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
            image_hidden_states=model_outputs.image_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        is_first_iteration=False,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )

        # Image inputs are only needed on the first iteration (which need not be prefill: it can follow a cached
        # system prompt) or when no cache is used. Otherwise they are already merged into the cached states.
        needs_pixel_values = is_first_iteration or not kwargs.get("use_cache", True)
        if needs_pixel_values:
            model_inputs["pixel_values"] = pixel_values

        return model_inputs
# Public API of this module.
__all__ = [
    "Ovis2PreTrainedModel",
    "Ovis2Model",
    "Ovis2ForConditionalGeneration",
]