You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1496 lines
62 KiB

# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/evolla/modular_evolla.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_evolla.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 Westlake Representational Learning Lab (Fajie Yuan Lab) team and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
from ...masking_utils import create_bidirectional_mask, create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPast,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithPast,
ModelOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
from .configuration_evolla import EvollaConfig, SaProtConfig
def create_position_ids_from_input_ids(input_ids, padding_idx):
    """
    Build position ids for a batch of token ids, leaving padding positions untouched.

    Each non-padding token receives its 1-based rank among the non-padding tokens of its row, shifted by
    `padding_idx` (so the first real token gets `padding_idx + 1`). Padding tokens receive `padding_idx` itself.
    Adapted from fairseq's `utils.make_positions`.

    Args:
        input_ids (`torch.Tensor`): Token ids; the running count is taken along dim 1.
        padding_idx (`int`): Id of the padding token.

    Returns:
        `torch.Tensor`: Position ids of dtype `torch.long`, same shape as `input_ids`.
    """
    # The exact int / type_as / long cast sequence is kept on purpose: it works with both ONNX export and XLA.
    not_padding = input_ids.ne(padding_idx).int()
    running_count = torch.cumsum(not_padding, dim=1).type_as(not_padding)
    positions = running_count * not_padding
    return positions.long() + padding_idx
class EvollaSaProtEmbeddings(nn.Module):
    """
    Word embeddings (plus optional absolute position embeddings) for the SaProt protein encoder.

    Follows BertEmbeddings, except that position ids start at `padding_idx + 1`.
    """

    def __init__(self, config):
        super().__init__()
        pad_token_id = config.pad_token_id
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=pad_token_id)

        # Optional LayerNorm applied to the summed embeddings in `forward`.
        self.layer_norm = (
            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.emb_layer_norm_before else None
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

        self.padding_idx = pad_token_id
        if self.position_embedding_type == "absolute":
            self.position_embeddings = nn.Embedding(
                config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
            )

        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id
        # The `position_ids` buffer registered above is not used by this module, so it is cleared right away.
        self.position_ids = None

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        inputs_embeds=None,
    ):
        """
        Embed tokens and, for absolute position embeddings, add the position embeddings.

        Args:
            input_ids: Token ids; required unless `inputs_embeds` is given.
            attention_mask: Optional mask; when given, the output is multiplied by it (broadcast over the last dim).
            position_ids: Optional explicit position ids; derived from the inputs when omitted.
            inputs_embeds: Optional precomputed token embeddings used instead of looking up `input_ids`.

        Returns:
            The embedded sequence, in the dtype of the token embeddings.
        """
        if position_ids is None:
            # Padded tokens keep the padding position when ids are available; with raw embeddings we cannot
            # tell which positions are padded, so sequential ids are used.
            position_ids = (
                create_position_ids_from_input_ids(input_ids, self.padding_idx)
                if input_ids is not None
                else self.create_position_ids_from_inputs_embeds(inputs_embeds)
            )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # Note that if we want to support EVOLLA_SA_PROT-1 (not 1b!) in future then we need to support an
        # embedding_scale factor here.
        embeddings = inputs_embeds

        # Matt: EVOLLA_SA_PROT can handle MLM masking in a slightly unusual way. With `token_dropout=False` it
        # behaves the same way as BERT/RoBERTa. With `token_dropout=True`, masked tokens are treated as if they
        # had been selected for input dropout and are zeroed out. This "mask-dropout" is compensated for when
        # masked tokens are not present by scaling the embeddings by
        # (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample),
        # analogous to how dropout layers rescale their outputs between training and evaluation.
        if self.token_dropout and input_ids is not None:
            is_mask_token = input_ids == self.mask_token_id
            embeddings = embeddings.masked_fill(is_mask_token.unsqueeze(-1), 0.0)
            mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all EVOLLA_SA_PROT model training runs
            if attention_mask is not None:
                src_lengths = attention_mask.sum(-1)
            else:
                src_lengths = input_ids.shape[1]
            mask_ratio_observed = is_mask_token.sum(-1).float() / src_lengths
            rescaled = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
            embeddings = rescaled.to(embeddings.dtype)

        if self.position_embedding_type == "absolute":
            embeddings = embeddings + self.position_embeddings(position_ids)

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)

        if attention_mask is not None:
            masked = embeddings * attention_mask.unsqueeze(-1)
            embeddings = masked.to(embeddings.dtype)

        # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
        # embeddings = self.dropout(embeddings)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        Generate sequential position ids when only embeddings are available.

        Padding cannot be inferred from embeddings, so every row gets the ids
        `padding_idx + 1 ... padding_idx + sequence_length`.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        batch_and_seq = inputs_embeds.size()[:-1]
        first_position = self.padding_idx + 1
        position_ids = torch.arange(
            first_position, batch_and_seq[1] + first_position, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(batch_and_seq)
def rotate_half_esm(x):
    """Split the last dimension of `x` into two chunks `(a, b)` and return them recombined as `(-b, a)`."""
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
def apply_rotary_pos_emb_esm(x, cos, sin):
    """
    Apply rotary position embeddings to `x`.

    `cos` and `sin` are 4-D tables that are truncated along their third dimension to `x.shape[-2]` before being
    combined with `x` and its half-rotated counterpart.
    """
    seq_len = x.shape[-2]
    cos_table = cos[:, :, :seq_len, :]
    sin_table = sin[:, :, :seq_len, :]

    # Half rotation: the two chunks (a, b) of the last dimension become (-b, a).
    first_half, second_half = x.chunk(2, dim=-1)
    rotated = torch.cat((-second_half, first_half), dim=-1)

    return (x * cos_table) + (rotated * sin_table)
class EvollaSaProtRotaryEmbedding(nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Queries and keys are rotated by
    matrices that depend on their relative positions.

    The cos/sin tables are cached and rebuilt only when the sequence length or the device changes.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        # Inverse frequencies are stored as a (non trainable) buffer.
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000**exponents))

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        """Return the cached `(cos, sin)` tables for `x`, rebuilding them first if they are stale."""
        seq_len = x.shape[seq_dimension]

        # The tables are stale if the sequence length changed, or if we are on a new device
        # (possibly due to tracing for instance).
        cache_is_stale = seq_len != self._seq_len_cached or self._cos_cached.device != x.device
        if cache_is_stale:
            self._seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            angles = torch.outer(positions, self.inv_freq)
            table = torch.cat((angles, angles), dim=-1).to(x.device)
            self._cos_cached = table.cos()[None, None, :, :]
            self._sin_cached = table.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate `q` and `k` using tables sized from the second-to-last dimension of `k`."""
        cos, sin = self._update_cos_sin_tables(k, seq_dimension=-2)
        rotated_q = apply_rotary_pos_emb_esm(q, cos, sin).to(dtype=q.dtype)
        rotated_k = apply_rotary_pos_emb_esm(k, cos, sin).to(dtype=k.dtype)
        return rotated_q, rotated_k
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain PyTorch scaled dot-product attention.

    Args:
        module: Module whose `training` flag decides whether dropout is active.
        query, key, value: 4-D tensors; `key` is transposed on dims 2 and 3 before the matmul with `query`.
        attention_mask: Optional additive mask, cropped on its last dim to the key length.
        scaling: Factor applied to the raw scores; defaults to `query.size(-1) ** -0.5`.
        dropout: Dropout probability applied to the attention probabilities.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has dims 1 and 2 swapped and is contiguous.
    """
    scale = query.size(-1) ** -0.5 if scaling is None else scaling

    # Raw attention scores: dot product between "query" and "key".
    scores = torch.matmul(query, key.transpose(2, 3)) * scale
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : key.shape[-2]]

    probs = nn.functional.softmax(scores, dim=-1)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
class EvollaSaProtSelfAttention(nn.Module):
    """
    Multi-head attention used by the SaProt layers, for either self-attention or cross-attention.

    The query is pre-scaled by `attention_head_size**-0.5` before the optional rotary embedding is applied, so
    the attention backend is always called with `scaling=1.0`.
    """

    def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False):
        super().__init__()
        self.config = config

        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = config.attention_probs_dropout_prob

        self.rotary_embeddings = None
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "rotary":
            self.rotary_embeddings = EvollaSaProtRotaryEmbedding(dim=self.attention_head_size)

        self.is_decoder = config.is_decoder
        self.layer_idx = layer_idx
        # The query is scaled explicitly in `forward`, so the attention backend must not scale again.
        self.scaling = 1.0
        self.is_causal = self.is_decoder and not is_cross_attention

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """
        Run attention over `hidden_states`, attending to `encoder_hidden_states` when they are given.

        Args:
            hidden_states: Query-side states of shape `(batch, seq_len, hidden)`.
            attention_mask: Mask used for self-attention.
            encoder_hidden_states: When given, keys and values are computed from these states (cross-attention).
                Assumed to be `(batch, src_len, hidden)` — `src_len` may differ from `seq_len`.
            encoder_attention_mask: Mask used instead of `attention_mask` for cross-attention.

        Returns:
            `(attn_output, attn_weights)` with `attn_output` reshaped to `(batch, seq_len, -1)`.
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.attention_head_size)

        query_layer = self.query(hidden_states).view(query_shape).transpose(1, 2)

        is_cross_attention = encoder_hidden_states is not None
        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

        # Keys and values must be reshaped with the length of the sequence they are computed from. Reusing the
        # query length here would mis-shape (or fail to reshape) them whenever the encoder sequence length
        # differs from the query sequence length. For self-attention both lengths are the same.
        kv_shape = (batch_size, current_states.shape[1], -1, self.attention_head_size)
        key_layer = self.key(current_states).view(kv_shape).transpose(1, 2)
        value_layer = self.value(current_states).view(kv_shape).transpose(1, 2)

        # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
        # EVOLLA_SA_PROT scales the query down by the same factor instead. Modulo numerical stability these are
        # equivalent, but not when rotary embeddings get involved. Therefore, we scale the query here to match
        # the original EVOLLA_SA_PROT code and fix rotary embeddings.
        query_layer = query_layer * self.attention_head_size**-0.5

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        return attn_output, attn_weights
class EvollaSaProtSelfOutput(nn.Module):
    """Output projection of the attention block: dense -> dropout -> residual add."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project `hidden_states`, apply dropout, and add the residual `input_tensor`."""
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
class EvollaSaProtAttention(nn.Module):
    """Pre-LayerNorm attention block: LayerNorm -> attention -> output projection with residual."""

    def __init__(self, config, layer_idx=None, is_cross_attention=False):
        super().__init__()
        self.self = EvollaSaProtSelfAttention(config, layer_idx=layer_idx, is_cross_attention=is_cross_attention)
        self.output = EvollaSaProtSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Normalize `hidden_states`, attend, then project and add the un-normalized input as residual.

        The attention weights returned by the inner attention module are discarded.
        """
        normed_states = self.LayerNorm(hidden_states)
        context, _attn_weights = self.self(
            normed_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            **kwargs,
        )
        # The residual connection uses the input from before the LayerNorm.
        return self.output(context, hidden_states)
def gelu(x):
    """
    Exact (erf-based) GELU, as implemented in the original EVOLLA_SA_PROT repo.

    Using F.gelu yields subtly wrong results, so the arithmetic is spelled out explicitly.
    """
    cdf_term = 1.0 + torch.erf(x / math.sqrt(2.0))
    return x * 0.5 * cdf_term
class EvollaSaProtIntermediate(nn.Module):
    """First half of the feed-forward block: expand to `intermediate_size`, then apply GELU."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `gelu(dense(hidden_states))`, using this module's own `gelu` implementation."""
        expanded = self.dense(hidden_states)
        return gelu(expanded)
class EvollaSaProtOutput(nn.Module):
    """Second half of the feed-forward block: project back to `hidden_size`, dropout, residual add."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project `hidden_states`, apply dropout, and add the residual `input_tensor`."""
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
class EvollaSaProtLayer(GradientCheckpointingLayer):
    """
    One SaProt transformer layer: self-attention, optional cross-attention, then a pre-LayerNorm feed-forward
    block with residual connection.
    """

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = EvollaSaProtAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention

        # Cross-attention is only available on decoder layers.
        if self.add_cross_attention and not self.is_decoder:
            raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
        if self.add_cross_attention:
            self.crossattention = EvollaSaProtAttention(config, is_cross_attention=True)

        self.intermediate = EvollaSaProtIntermediate(config)
        self.output = EvollaSaProtOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Apply self-attention, then cross-attention (decoder layers given `encoder_hidden_states`), then the
        feed-forward block.

        Raises:
            AttributeError: If `encoder_hidden_states` is passed to a decoder layer built without
                cross-attention.
        """
        attention_output = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )

        wants_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if wants_cross_attention:
            if not hasattr(self, "crossattention"):
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
                    " with cross-attention layers by setting `config.add_cross_attention=True`"
                )
            attention_output = self.crossattention(
                attention_output,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                **kwargs,
            )

        return self.feed_forward_chunk(attention_output)

    def feed_forward_chunk(self, attention_output):
        """LayerNorm -> intermediate -> output, with the un-normalized `attention_output` as residual."""
        normed_output = self.LayerNorm(attention_output)
        expanded = self.intermediate(normed_output)
        return self.output(expanded, attention_output)
class EvollaSaProtEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` SaProt layers followed by a final LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        layers = [EvollaSaProtLayer(config) for _ in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(layers)
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """Run every layer in order and return the normalized result as `last_hidden_state`."""
        for layer_module in self.layer:
            hidden_states = layer_module(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                **kwargs,
            )

        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        return BaseModelOutputWithCrossAttentions(last_hidden_state=hidden_states)
class EvollaSaProtPooler(nn.Module):
    """Pool a sequence by passing its first token's hidden state through a dense layer and tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `tanh(dense(hidden_states[:, 0]))`."""
        # "Pooling" here simply means taking the hidden state that corresponds to the first token.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
@auto_docstring
class EvollaSaProtPreTrainedModel(PreTrainedModel):
    # Base class for the SaProt protein-encoder models: declares the config type, the supported attention
    # backends, and which modules' outputs are recorded.
    config: SaProtConfig
    _no_split_modules = ["EvollaSaProtLayer"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    _can_record_outputs = {
        "hidden_states": EvollaSaProtLayer,
        "attentions": [OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="attention")],
        "cross_attentions": [
            OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="crossattention"),
        ],
    }

    def _init_weights(self, module):
        """Run the default initialization, then recompute `inv_freq` for rotary embedding modules."""
        super()._init_weights(module)
        if not isinstance(module, EvollaSaProtRotaryEmbedding):
            return
        # Same formula as in `EvollaSaProtRotaryEmbedding.__init__`.
        exponents = torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim
        init.copy_(module.inv_freq, 1.0 / (10000**exponents))
class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
    """SaProt protein encoder: token embeddings followed by the bidirectional encoder stack."""

    def __init__(self, config: SaProtConfig):
        super().__init__(config)
        self.embeddings = EvollaSaProtEmbeddings(config)
        self.encoder = EvollaSaProtEncoder(config)
        self.post_init()

    def get_input_embeddings(self):
        """Return the word embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word embedding module."""
        self.embeddings.word_embeddings = value

    @check_model_inputs
    def forward(
        self,
        input_ids: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
        """
        Encode a batch of protein token ids.

        Args:
            input_ids: 2-D tensor of token ids (unpacked as `(batch_size, seq_length)`).
            attention_mask: Optional 2-D mask; defaults to all ones.

        Returns:
            The encoder's last hidden state plus whatever `hidden_states`, `attentions` and `cross_attentions`
            the encoder output carries.
        """
        batch_size, seq_length = input_ids.size()

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length)), device=input_ids.device)

        # The 2-D mask is applied inside the embeddings before being expanded for the attention layers.
        inputs_embeds = self.embeddings(input_ids=input_ids, attention_mask=attention_mask)
        bidirectional_mask = create_bidirectional_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

        encoder_outputs = self.encoder(inputs_embeds, attention_mask=bidirectional_mask, **kwargs)

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
class EvollaSequenceCompressorAttention(nn.Module):
    """
    Attention layer of the sequence compressor: latent tokens attend to the input features concatenated with
    the latents themselves.
    """

    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents, mask):
        """
        Args:
            x (torch.Tensor): input features to attend over
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D); n2: num of latent tokens
            mask (torch.Tensor): validity mask over the concatenated keys (inputs followed by latents)
                shape (b, n1 + n2); positions where `1 - mask` is non-zero are masked out

        Returns:
            torch.Tensor: updated latents, shape (b, n2, D)
        """
        media = self.norm_media(x)
        latents = self.norm_latents(latents)
        num_heads = self.heads

        def split_heads(tensor):
            # [batch, seq, heads * features] -> [batch, heads, seq, features]
            return tensor.view(tensor.size(0), tensor.size(1), num_heads, -1).permute(0, 2, 1, 3)

        queries = self.to_q(latents)
        # Keys/values cover the inputs followed by the latents: batch_size, n1 + n2, dim_head * num_heads each.
        keys, values = self.to_kv(torch.cat((media, latents), dim=-2)).chunk(2, dim=-1)

        queries = split_heads(queries)
        keys = split_heads(keys)
        values = split_heads(values)
        queries = queries * self.scale  # batch_size, num_heads, num_latents, dim_head

        # attention
        scores = torch.matmul(queries, keys.transpose(-1, -2))
        scores = scores - scores.amax(dim=-1, keepdim=True).detach()

        # Expand the (batch, keys) mask to the full (batch, heads, queries, keys) score shape.
        _, head_count, query_count, _ = scores.shape
        ones = torch.ones(head_count, query_count).to(mask.device)
        keep = mask[:, None, None, :] * ones[None, :, :, None]
        scores = scores.masked_fill((1 - keep).bool(), -1e4)
        weights = scores.softmax(dim=-1)

        context = torch.matmul(weights, values).permute(0, 2, 1, 3)
        # [batch, seq, head, features] -> [batch, seq, head*features]
        context = context.reshape(context.size(0), context.size(1), -1)
        return self.to_out(context)
class EvollaFeedForward(nn.Module):
    """Bias-free feed-forward block: LayerNorm -> Linear(dim, dim * mult) -> GELU -> Linear(dim * mult, dim)."""

    def __init__(self, dim, mult=4):
        super().__init__()
        inner_dim = int(dim * mult)

        self.norm = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, inner_dim, bias=False)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x):
        """Apply the block to `x`; no residual connection is added here."""
        normed = self.norm(x)
        expanded = self.activation(self.fc1(normed))
        return self.fc2(expanded)
class EvollaSequenceCompressorResampler(nn.Module):
    """
    Compress a variable-length protein representation into `config.resampler_num_latents` tokens.

    Learned latents repeatedly attend to the protein embeddings (attention + feed-forward, each with a residual
    connection) and are finally projected to `config.hidden_size` and layer-normalized.
    """

    def __init__(self, config: EvollaConfig):
        super().__init__()
        protein_repr_dim = config.protein_encoder_config.hidden_size
        self.num_latents = config.resampler_num_latents
        self.latents = nn.Parameter(torch.randn(self.num_latents, protein_repr_dim), requires_grad=True)

        # Each entry is an (attention, feed-forward) pair.
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        EvollaSequenceCompressorAttention(
                            dim=protein_repr_dim, dim_head=config.resampler_dim_head, heads=config.resampler_heads
                        ),
                        EvollaFeedForward(dim=protein_repr_dim, mult=config.resampler_ff_mult),
                    ]
                )
                for _ in range(config.resampler_depth)
            ]
        )

        self.norm = nn.LayerNorm(config.hidden_size)
        self.protein_projector = nn.Linear(protein_repr_dim, config.hidden_size)

    def forward(self, embeds, mask):
        """
        Args:
            embeds: Protein embeddings; the first dimension is the batch.
            mask: 2-D mask of shape (bs, max_protein_length) over `embeds`.

        Returns:
            The projected and normalized latents, one set of `num_latents` tokens per batch element.
        """
        batch_size = embeds.shape[0]
        mask_batch_size, _ = mask.shape  # bs, max_protein_length

        # The latents are always visible, so the mask is extended with ones for them.
        latent_mask = torch.ones(mask_batch_size, self.num_latents).to(mask.device)
        full_mask = torch.cat((mask, latent_mask), dim=1)  # bs, max_protein_length + num_latents

        # Replicate the learned latents for every batch element: [b, n, d]
        batch_ones = torch.ones(batch_size).to(self.latents.device).view(-1, 1, 1)
        latents = (self.latents[None] * batch_ones).to(embeds.dtype)

        for attention, feed_forward in self.layers:
            latents = attention(embeds, latents, full_mask) + latents
            latents = feed_forward(latents) + latents

        return self.norm(self.protein_projector(latents))
@dataclass
@auto_docstring
class EvollaProteinEncoderModelOutput(ModelOutput):
    # Output container built by `EvollaProteinEncoder.forward`. Every field defaults to `None`.

    # Output of the sequence compressor resampler applied to the protein encoder's last hidden state.
    sequence_compressor_output: torch.FloatTensor | None = None
    # Last hidden state of the underlying protein encoder.
    last_hidden_state: torch.FloatTensor | None = None
    # Not set by `EvollaProteinEncoder.forward` in this file; presumably per-layer hidden states — TODO confirm.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Not set by `EvollaProteinEncoder.forward` in this file; presumably per-layer attention weights — TODO confirm.
    attentions: tuple[torch.FloatTensor, ...] | None = None
class EvollaProteinEncoder(nn.Module):
    """SaProt protein encoder followed by the sequence compressor resampler."""

    def __init__(self, config: EvollaConfig):
        super().__init__()
        self.model = EvollaSaProtProteinEncoder(config=config.protein_encoder_config)
        self.sequence_compressor_resampler = EvollaSequenceCompressorResampler(config=config)

    @can_return_tuple
    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor, **kwargs):
        """
        Encode the protein tokens, then compress the encoder's last hidden state with the resampler.

        Extra `kwargs` are accepted but not forwarded to the inner encoder.
        """
        encoder_output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = encoder_output.last_hidden_state
        compressed = self.sequence_compressor_resampler(last_hidden_state, attention_mask)

        return EvollaProteinEncoderModelOutput(
            sequence_compressor_output=compressed,
            last_hidden_state=last_hidden_state,
        )
class EvollaSequenceAlignerCrossAttention(nn.Module):
def __init__(
self,
config,
protein_encoder_dim: int | None = None,
structure_encoder_dim: int | None = None,
msa_encoder_dim: int | None = None,
):
super().__init__()
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.scale = self.num_attention_heads**-0.5
self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
attention_probs_dropout_prob = config.aligner_attention_probs_dropout_prob
enable_bias = config.aligner_enable_bias
ffn_mult = config.aligner_ffn_mult
self.query = nn.Linear(self.hidden_size, self.all_head_size)
if protein_encoder_dim is not None:
self.key_protein = nn.Linear(protein_encoder_dim, self.all_head_size)
self.value_protein = nn.Linear(protein_encoder_dim, self.all_head_size)
else:
self.key_protein = None
self.value_protein = None
if structure_encoder_dim is not None:
self.key_structure = nn.Linear(structure_encoder_dim, self.all_head_size)
self.value_structure = nn.Linear(structure_encoder_dim, self.all_head_size)
else:
self.key_structure = None
self.value_structure = None
if msa_encoder_dim is not None:
self.key_msa = nn.Linear(msa_encoder_dim, self.all_head_size)
self.value_msa = nn.Linear(msa_encoder_dim, self.all_head_size)
else:
self.key_msa = None
self.value_msa = None
self.attention_norm = EvollaRMSNorm(self.hidden_size)
self.dropout = nn.Dropout(attention_probs_dropout_prob)
self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=enable_bias)
self.ff = EvollaFeedForward(self.hidden_size, ffn_mult)
self.gate_attention = nn.Parameter(torch.tensor([0.0]))
self.gate_ffw = nn.Parameter(torch.tensor([0.0]))
def cross_attention(
self,
query_states,
protein_key_value_states,
structure_key_value_states,
msa_key_value_states,
query_attn_mask,
protein_kv_attn_mask,
structure_kv_attn_mask,
msa_kv_attn_mask,
):
"""
query_states: text
key_value_states: protein
query_states: [bs, query_seq_len, dim]
key_value_states: [bs, kv_seq_len, dim]
query_attn_mask: [bs, query_seq_len]
kv_attn_mask: [bs, kv_seq_len]
"""
# Concatenate protein and structure
kv_attn_mask = [protein_kv_attn_mask, structure_kv_attn_mask, msa_kv_attn_mask]
kv_attn_mask = [_ for _ in kv_attn_mask if _ is not None]
if not kv_attn_mask:
raise ValueError("At least one modality should be provided for cross attention.")
kv_attn_mask = torch.cat(kv_attn_mask, dim=1)
query_layer = self.attention_norm(query_states)
# Warning: This place might cause issues, refers to
# https://discuss.pytorch.org/t/cuda-error-cublas-status-not-supported-when-calling-cublasltmatmul-from-torch-nn-functional-linear/170214/13
# Solution: add `DISABLE_ADDMM_CUDA_LT=1` as environment variable
# Apply linear transformation to input_query, input_key, and input_value
query_layer = self.query(query_layer) # [bs, querylength, dim]
if self.key_protein is not None and self.value_protein is not None:
protein_key_value_states = protein_key_value_states.to(query_states)
key_layer_protein = self.key_protein(protein_key_value_states) # [bs, keylength, dim]
value_layer_protein = self.value_protein(protein_key_value_states) # [bs, keylength, dim]
else:
key_layer_protein = None
value_layer_protein = None
if self.key_structure is not None and self.value_structure is not None:
structure_key_value_states = structure_key_value_states.to(query_states)
key_layer_structure = self.key_structure(structure_key_value_states) # [bs, keylength, dim]
value_layer_structure = self.value_structure(structure_key_value_states) # [bs, keylength, dim]
else:
key_layer_structure = None
value_layer_structure = None
if self.key_msa is not None and self.value_msa is not None:
msa_key_value_states = msa_key_value_states.to(query_states)
key_layer_msa = self.key_msa(msa_key_value_states) # [bs, keylength, dim]
value_layer_msa = self.value_msa(msa_key_value_states) # [bs, keylength, dim]
else:
key_layer_msa = None
value_layer_msa = None
key_layer = [key_layer_protein, key_layer_structure, key_layer_msa]
key_layer = [_ for _ in key_layer if _ is not None]
key_layer = torch.cat(key_layer, dim=1)
value_layer = [value_layer_protein, value_layer_structure, value_layer_msa]
value_layer = [_ for _ in value_layer if _ is not None]
value_layer = torch.cat(value_layer, dim=1)
new_query_layer_shape = query_layer.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
query_layer = query_layer.view(*new_query_layer_shape).permute(0, 2, 1, 3)
new_key_layer_shape = key_layer.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
key_layer = key_layer.view(*new_key_layer_shape).permute(0, 2, 1, 3)
new_value_layer_shape = value_layer.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
value_layer = value_layer.view(*new_value_layer_shape).permute(0, 2, 1, 3)
query_layer = query_layer * self.scale
# attention_mask: [bs, 1, querylength, keylength]
if query_attn_mask is None:
query_attn_mask = torch.ones(query_states.size(0), query_states.size(1)).to(query_states.device)
attention_mask = query_attn_mask[:, None, :, None] * kv_attn_mask[:, None, None, :]
# Compute the scaled dot-product attention scores
attn_weights = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [bs, numheads, querylength, keylength]
attn_weights = attn_weights - attn_weights.amax(dim=-1, keepdim=True).detach() # To stabilize score
attention_scores = attn_weights.masked_fill(
(1 - attention_mask).bool(), torch.finfo(attn_weights.dtype).min
) # [bs, numheads, querylength, keylength]
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# attention_probs_dropped = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer) # [bs, numheads, querylength, dim/numheads]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = self.out_proj(context_layer)
return context_layer
def forward(
self,
query_states,
protein_kv_states,
structure_kv_states,
msa_kv_states,
query_attn_mask,
protein_kv_attn_mask=None,
structure_kv_attn_mask=None,
msa_kv_attn_mask=None,
protein_batch_mask=None,
structure_batch_mask=None,
msa_batch_mask=None,
past_key_values=None,
):
if protein_kv_states is not None:
bs, protein_kv_seq_len, dim = protein_kv_states.shape
if protein_kv_attn_mask is None:
protein_kv_attn_mask = (
torch.ones(bs, protein_kv_seq_len).to(protein_batch_mask.device)
* protein_batch_mask.expand(size=(protein_kv_seq_len, bs)).T
).to(protein_kv_states.device)
else:
protein_kv_attn_mask = None
if structure_kv_states is not None:
bs, structure_kv_seq_len, dim = structure_kv_states.shape
if structure_kv_attn_mask is None:
structure_kv_attn_mask = (
torch.ones(bs, structure_kv_seq_len).to(protein_batch_mask.device)
* structure_batch_mask.expand(size=(structure_kv_seq_len, bs)).T
).to(structure_kv_states.device)
else:
structure_kv_attn_mask = None
if msa_kv_states is not None:
bs, msa_kv_seq_len, dim = msa_kv_states.shape
if msa_kv_attn_mask is None:
msa_kv_attn_mask = (
torch.ones(bs, msa_kv_seq_len).to(protein_batch_mask.device)
* msa_batch_mask.expand(size=(msa_kv_seq_len, bs)).T
).to(msa_kv_states.device)
else:
msa_kv_attn_mask = None
hidden_states = query_states
# only when there's at least one valid modality, crossattention will be performed
if (
(protein_kv_states is not None and protein_kv_attn_mask.any())
or (structure_kv_states is not None and structure_kv_attn_mask.any())
or (msa_kv_states is not None and msa_kv_attn_mask.any())
):
residual = hidden_states
hidden_states = self.cross_attention(
query_states=hidden_states,
protein_key_value_states=protein_kv_states,
structure_key_value_states=structure_kv_states,
msa_key_value_states=msa_kv_states,
query_attn_mask=query_attn_mask,
protein_kv_attn_mask=protein_kv_attn_mask,
structure_kv_attn_mask=structure_kv_attn_mask,
msa_kv_attn_mask=msa_kv_attn_mask,
) # [bs, query_seq_len, dim]
# tanh gate
hidden_states = torch.tanh(self.gate_attention) * hidden_states
hidden_states = residual + hidden_states # input_query
residual = hidden_states
hidden_states = self.ff(hidden_states) * torch.tanh(self.gate_ffw)
hidden_states = residual + hidden_states
return hidden_states
@use_kernel_forward_from_hub("RMSNorm")
class EvollaRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        EvollaRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: Size handed to `torch.ones` to create the learned scale `weight`.
            eps: Constant added to the mean of squares before the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize `hidden_states` in float32 and return the result scaled by `weight`."""
        source_dtype = hidden_states.dtype
        # The statistics are computed in float32 regardless of the incoming dtype.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        # Cast back to the incoming dtype first, then apply the learned scale.
        return self.weight * normalized.to(source_dtype)

    def extra_repr(self):
        """Return the weight shape and epsilon for the module's printed representation."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class EvollaRotaryEmbedding(nn.Module):
    """Builds the cosine/sine tables used for rotary position embeddings."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: EvollaConfig, device=None):
        """
        Args:
            config: Model configuration; `max_position_embeddings` and `rope_parameters` are read from it.
            device: Forwarded to the RoPE initialization function.
        """
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        # The "default" type uses the static method below; any other type is looked up in the registry.
        if self.rope_type == "default":
            rope_init_fn: Callable = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Neither buffer is saved in the state dict (`persistent=False`); a separate clone is kept
        # under `original_inv_freq`.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: EvollaConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        rope_theta = config.rope_parameters["rope_theta"]
        # A missing or falsy `head_dim` falls back to the size derived from the hidden size.
        rotary_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        # Compute the inverse frequencies
        even_indices = torch.arange(0, rotary_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
        inv_freq = 1.0 / (rope_theta ** (even_indices / rotary_dim))

        attention_factor = 1.0  # Unused in this type of RoPE
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """
        Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`.

        `x` is only consulted for its device and dtype. `position_ids` is indexed as a 2-D tensor
        whose first dimension is the batch.
        """
        batch_size = position_ids.shape[0]
        freq_column = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        position_row = position_ids[:, None, :].float()

        # Autocast is disabled on the tensor's own device type, except that "mps" is mapped to "cpu".
        if isinstance(x.device.type, str) and x.device.type != "mps":
            device_type = x.device.type
        else:
            device_type = "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (freq_column.float() @ position_row.float()).transpose(1, 2)
            # The angles are repeated so the last dimension is twice the number of frequencies.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class EvollaMLP(nn.Module):
    """Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config):
        """
        Args:
            config: Supplies `hidden_size`, `intermediate_size`, `mlp_bias` and `hidden_act`.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # The gate and up projections share the same shape; the down projection maps back.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated projection to `x` and project the result back with `down_proj`."""
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is split at `x.shape[-1] // 2`; the result is the negated second part
    followed by the first part.
    """
    split_at = x.shape[-1] // 2
    first_part, second_part = x[..., :split_at], x[..., split_at:]
    return torch.cat((-second_part, first_part), dim=-1)
@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into `cos` and `sin` so that they broadcast against `q` and `k`. For example,
            when `cos` and `sin` have the shape [batch_size, seq_len, head_dim] and `q` and `k` have the shape
            [batch_size, heads, seq_len, head_dim], use unsqueeze_dim=1. If `q` and `k` instead have the shape
            [batch_size, seq_len, heads, head_dim], use unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    # Each output is `t * cos + rotate_half(t) * sin`, applied identically to the query and the key.
    q_rotated_term = rotate_half(q) * sin_b
    k_rotated_term = rotate_half(k) * sin_b
    return (q * cos_b) + q_rotated_term, (k * cos_b) + k_rotated_term
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        # Nothing to repeat: hand back the very same tensor.
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it to `n_rep`, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(-1, -1, n_rep, -1, -1)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
@use_kernelized_func(apply_rotary_pos_emb)
class EvollaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: EvollaConfig, layer_idx: int):
        """
        Args:
            config: Model configuration; sizes, bias flags and dropout are read from it.
            layer_idx: Index of the owning decoder layer, passed to the cache on update.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Fall back to the derived size when `head_dim` is missing *or* explicitly `None`, the same way
        # `EvollaRotaryEmbedding.compute_default_rope_parameters` resolves the rotary dimension. A plain
        # `getattr` default would return `None` here and break the scaling computation below.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run self-attention over `hidden_states`.

        Args:
            hidden_states: Input whose last dimension is projected to queries, keys and values.
            position_embeddings: `(cos, sin)` pair applied to queries and keys. It is unpacked
                unconditionally, so it must be provided despite the `None` default.
            attention_mask: Passed unchanged to the selected attention implementation.
            past_key_values: When given, its `update` method receives the new key/value states and
                returns the states actually used for attention.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded to the selected attention implementation.

        Returns:
            The output of `o_proj` and the attention weights as returned by the attention implementation.
        """
        input_shape = hidden_states.shape[:-1]
        # Split the projected features into heads of size `head_dim`; the head count is inferred.
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Use the implementation registered for the configured backend, defaulting to the eager one.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            # Dropout is only active in training mode.
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge the trailing head dimensions back into one feature dimension.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class EvollaDecoderLayer(GradientCheckpointingLayer):
    """Decoder layer: self-attention and MLP with residuals, optionally followed by an adapter."""

    def __init__(self, config: EvollaConfig, layer_idx: int):
        """
        Args:
            config: Model configuration.
            layer_idx: Zero-based index of this layer; it decides whether an adapter is attached.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = EvollaAttention(config=config, layer_idx=layer_idx)
        self.mlp = EvollaMLP(config)
        self.input_layernorm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Adapters sit on every `adapter_interval`-th layer, counting layers from 1. The interval is
        # clamped to at least 1, in which case every layer gets one.
        adapter_interval = max(config.num_hidden_layers // config.aligner_num_add_layers, 1)
        if (layer_idx + 1) % adapter_interval == 0:
            self.adapter = EvollaSequenceAlignerCrossAttention(
                config,
                protein_encoder_dim=config.hidden_size,
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        cache_position: torch.LongTensor | None = None,
        protein_kv_states: torch.Tensor | None = None,
        structure_kv_states: torch.Tensor | None = None,
        msa_kv_states: torch.Tensor | None = None,
        protein_batch_mask: torch.Tensor | None = None,
        structure_batch_mask: torch.Tensor | None = None,
        msa_batch_mask: torch.Tensor | None = None,
        query_attn_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Apply the layer to `hidden_states` and return the updated hidden states.

        The attention-related arguments and `**kwargs` are handed to the self-attention module. The
        `*_kv_states`, `*_batch_mask` and `query_attn_mask` arguments are only used when this layer
        owns an adapter.
        """
        # Self Attention
        attn_input = self.input_layernorm(hidden_states)
        attn_output, _ = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Fully Connected
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_output

        # Layers built without an adapter are done here.
        if not hasattr(self, "adapter"):
            return hidden_states

        return self.adapter(
            query_states=hidden_states,
            protein_kv_states=protein_kv_states,
            structure_kv_states=structure_kv_states,
            msa_kv_states=msa_kv_states,
            query_attn_mask=query_attn_mask,
            protein_batch_mask=protein_batch_mask,
            structure_batch_mask=structure_batch_mask,
            msa_batch_mask=msa_batch_mask,
        )
@auto_docstring
class EvollaPreTrainedModel(PreTrainedModel):
    # Shared base class for the Evolla models: declares capability flags and custom weight init.
    config: EvollaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "EvollaDecoderLayer",
        "EvollaSequenceCompressorResampler",
        "EvollaSequenceAlignerCrossAttention",
    ]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = False  # see dependency on `EvollaSequenceCompressorResampler`
    _supports_sdpa = True
    _supports_flex_attn = False  # see dependency on `EvollaSequenceCompressorResampler`

    _can_compile_fullgraph = True
    _supports_attention_backend = False
    # Module classes whose outputs are recorded under each output name.
    _can_record_outputs = {
        "hidden_states": EvollaDecoderLayer,
        "attentions": EvollaAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        # Run the parent initialization first, then apply the Evolla-specific overrides.
        std = self.config.initializer_range
        super()._init_weights(module)

        if isinstance(module, EvollaSequenceAlignerCrossAttention):
            # Both gates start at zero and the attention norm weight at one.
            init.zeros_(module.gate_attention)
            init.zeros_(module.gate_ffw)
            init.ones_(module.attention_norm.weight)
            return

        if isinstance(module, EvollaSequenceCompressorResampler):
            # Latents are drawn from a zero-mean normal with the configured standard deviation.
            init.normal_(module.latents, mean=0.0, std=std)
class EvollaModel(EvollaPreTrainedModel):
    """Text decoder stack whose layers can additionally consume protein encoder features."""

    def __init__(self, config: EvollaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx)
        self.protein_encoder = EvollaProteinEncoder(config=config)
        self.layers = nn.ModuleList(
            [EvollaDecoderLayer(config=config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
        )
        self.norm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)
        self.rotary_emb = EvollaRotaryEmbedding(config=config)

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with `value`."""
        self.embed_tokens = value

    @auto_docstring
    @check_model_inputs
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        protein_input_ids: torch.LongTensor | None = None,
        protein_attention_mask: torch.Tensor | None = None,
        structure_feats: torch.FloatTensor | None = None,
        msa_feats: torch.FloatTensor | None = None,
        structure_batch_mask: torch.Tensor | None = None,
        msa_batch_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple | BaseModelOutputWithPast:
        r"""
        protein_input_ids (torch.LongTensor):
            The input IDs for the protein sequence in structure-aware tokens. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
        protein_attention_mask (torch.Tensor):
            The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.
        structure_feats (torch.FloatTensor):
            The input IDs for purely structure-based features. Should be of shape `(batch_size, structure_seq_length, structure_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
        msa_feats (torch.FloatTensor):
            The input IDs for purely MSA-based features. Should be of shape `(batch_size, msa_seq_length, msa_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
        structure_batch_mask (torch.Tensor):
            The batch mask to decide which protein sequences are purely structure-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `structure_feats`. Dummy input for now.
        msa_batch_mask (torch.Tensor):
            The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummy input for now.
        """
        # Exactly one of the two text inputs must be supplied: reject both "neither" and "both".
        has_input_ids = input_ids is not None
        has_inputs_embeds = inputs_embeds is not None
        if has_input_ids == has_inputs_embeds:
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if not has_inputs_embeds:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            # Positions continue from whatever the cache already holds (0 without a cache).
            if past_key_values is not None:
                past_seen_tokens = past_key_values.get_seq_length()
            else:
                past_seen_tokens = 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            # One row of positions, broadcast over the batch.
            position_ids = cache_position.unsqueeze(0)

        # If provided, actually compute them
        if protein_input_ids is not None and protein_attention_mask is not None:
            protein_outputs = self.protein_encoder(
                input_ids=protein_input_ids,
                attention_mask=protein_attention_mask,
            )
            protein_feats = protein_outputs.sequence_compressor_output
            # Every protein in the batch is flagged as present.
            protein_batch_mask = torch.ones(
                protein_input_ids.shape[0],
                device=protein_input_ids.device,
                dtype=torch.bool,
            )
        else:
            # Protein conditioning needs both the ids and the mask; otherwise it is switched off.
            protein_feats = None
            protein_batch_mask = None

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
        )

        hidden_states = inputs_embeds
        # The rotary tables are computed once and shared by all layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                protein_kv_states=protein_feats,
                structure_kv_states=structure_feats,
                msa_kv_states=msa_feats,
                protein_batch_mask=protein_batch_mask,
                structure_batch_mask=structure_batch_mask,
                msa_batch_mask=msa_batch_mask,
                # The adapter receives the caller's mask, not the causal mask built above.
                query_attn_mask=attention_mask,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
class EvollaForProteinText2Text(EvollaPreTrainedModel, GenerationMixin):
    """`EvollaModel` topped with a bias-free linear language-modeling head."""

    def __init__(self, config):
        super().__init__(config)
        self.model = EvollaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        """Return the input embeddings of the wrapped model."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Delegate to the wrapped model's `set_input_embeddings`."""
        return self.model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,  # text input ids
        attention_mask: torch.Tensor | None = None,  # text attention mask
        inputs_embeds: torch.FloatTensor | None = None,  # text input embeddings
        labels: torch.LongTensor | None = None,
        protein_input_ids: torch.LongTensor | None = None,
        protein_attention_mask: torch.Tensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ):
        r"""
        protein_input_ids (torch.LongTensor):
            The input IDs for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
        protein_attention_mask (torch.Tensor):
            The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.

        Example:

        ```python
        >>> from transformers import EvollaProcessor, EvollaForProteinText2Text
        >>> model = EvollaForProteinText2Text.from_pretrained("westlake/Evolla-10B-hf")
        >>> processor = EvollaProcessor.from_pretrained("westlake/Evolla-10B-hf")

        >>> protein_information = {
        ...     "aa_seq": "your amino acid sequence",
        ...     "foldseek": "your foldseek sequence",
        ... }
        >>> question = "What is the function of this protein?"
        >>> message = [
        ...     {"role": "system", "content": "You are an AI expert that can answer any questions about protein."},
        ...     {"role": "user", "content": question},
        ... ]

        >>> inputs = processor(proteins=[protein_information], messages_list=[message], return_tensors="pt", padding="longest")
        >>> outputs = model.generate(**inputs)

        >>> print(processor.batch_decode(outputs, skip_special_tokens=True))
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            protein_input_ids=protein_input_ids,
            protein_attention_mask=protein_attention_mask,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        if isinstance(logits_to_keep, int):
            # An int keeps that many trailing positions; 0 gives `slice(0, None)`, i.e. all of them.
            slice_indices = slice(-logits_to_keep, None)
        else:
            # Anything else is used directly as the index along the sequence dimension.
            slice_indices = logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is None:
            loss = None
        else:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Names exported by this module.
__all__ = ["EvollaForProteinText2Text", "EvollaModel", "EvollaPreTrainedModel"]