You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

147 lines
5.4 KiB

# Copyright 2024 The GLM & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
from ...utils import logging
from ..llama.modeling_llama import (
LlamaAttention,
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaRotaryEmbedding,
)
from ..phi3.modeling_phi3 import Phi3MLP
from .configuration_glm import GlmConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b"
class GlmMLP(Phi3MLP):
pass
class GlmRotaryEmbedding(LlamaRotaryEmbedding):
@staticmethod
def compute_default_rope_parameters(
config: GlmConfig | None = None,
device: Optional["torch.device"] = None,
seq_len: int | None = None,
) -> tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies according to the original RoPE implementation
Args:
config ([`~transformers.PreTrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
base = config.rope_parameters["rope_theta"]
partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
dim = int(head_dim * partial_rotary_factor)
attention_factor = 1.0 # Unused in this type of RoPE
# Compute the inverse frequencies
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
)
return inv_freq, attention_factor
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.stack((-x2, x1), dim=-1).flatten(-2)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
# Interleave them instead of usual shape
cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
# Keep half or full tensor for later concatenation
rotary_dim = cos.shape[-1]
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
# Apply rotary embeddings on the first half or full tensor
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
# Concatenate back to full shape
q_embed = torch.cat([q_embed, q_pass], dim=-1)
k_embed = torch.cat([k_embed, k_pass], dim=-1)
return q_embed, k_embed
class GlmAttention(LlamaAttention):
def __init__(self, config: GlmConfig, layer_idx: int | None = None):
super().__init__(config, layer_idx)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
class GlmForCausalLM(LlamaForCausalLM):
pass
class GlmForSequenceClassification(LlamaForSequenceClassification):
pass
class GlmForTokenClassification(LlamaForTokenClassification):
pass
__all__ = [
"GlmPreTrainedModel", # noqa: F822
"GlmModel", # noqa: F822
"GlmForCausalLM",
"GlmForSequenceClassification",
"GlmForTokenClassification",
]