# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers Xcodec model."""
import math
from dataclasses import dataclass
from functools import lru_cache
import torch
import torch.nn as nn
import torch.nn.functional as F
from ... import initialization as init
from ...audio_utils import conv1d_output_length
from ...modeling_utils import PreTrainedAudioTokenizerBase
from ...utils import ModelOutput, auto_docstring
from ..auto import AutoModel
from .configuration_xcodec import XcodecConfig
@dataclass
class XcodecOutput(ModelOutput):
"""
Args:
audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
Discrete code indices computed using `model.encode`.
        audio_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`, *optional*):
Decoded audio values obtained using the decoder part of Xcodec.
"""
audio_codes: torch.LongTensor | None = None
audio_values: torch.FloatTensor | None = None
@dataclass
class XcodecEncoderOutput(ModelOutput):
"""
Args:
audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
Discrete code indices computed using `model.encode`.
"""
audio_codes: torch.LongTensor | None = None
@dataclass
class XcodecDecoderOutput(ModelOutput):
"""
Args:
audio_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`, *optional*):
Decoded audio values obtained using the decoder part of Xcodec.
"""
audio_values: torch.FloatTensor | None = None
class ResidualUnit(nn.Module):
"""Residual block for SemanticEncoder and SemanticDecoder used in Xcodec."""
def __init__(self, config: XcodecConfig, in_channels: int, out_channels: int, dilation: int):
super().__init__()
self.activation = nn.ELU()
padding = ((config.unit_kernel_size - 1) // 2) * dilation
self.conv1 = nn.Conv1d(
in_channels,
out_channels,
config.unit_kernel_size,
stride=1,
padding=padding,
dilation=dilation,
groups=1,
bias=False,
)
self.conv2 = nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=1, bias=False)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
output_tensor = self.activation(hidden_state)
output_tensor = self.conv1(output_tensor)
output_tensor = self.activation(output_tensor)
output_tensor = self.conv2(output_tensor)
return hidden_state + output_tensor
class SemanticEncoderBlock(nn.Module):
def __init__(self, config: XcodecConfig, in_channels: int, out_channels: int, stride: int):
super().__init__()
self.res_units = nn.ModuleList(
[ResidualUnit(config, in_channels, in_channels, dilation) for dilation in config.block_dilations]
)
        # special case: for stride=1 use a kernel of 3 instead of 2 * stride
kernel = 3 if stride == 1 else (2 * stride)
padding = (kernel - 1) // 2
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=padding, bias=True)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
for unit in self.res_units:
hidden_state = unit(hidden_state)
hidden_state = self.conv(hidden_state)
return hidden_state
class SemanticEncoder(nn.Module):
def __init__(self, config):
super().__init__()
if len(config.strides) != len(config.channel_ratios):
raise ValueError("Number of strides must match the number of channel_ratios.")
self.conv = nn.Conv1d(
config.semantic_hidden_size,
config.semantic_hidden_size,
config.kernel_size,
1,
config.kernel_size // 2,
bias=False,
)
in_channels = config.semantic_hidden_size
conv_blocks = []
for i, stride in enumerate(config.strides):
out_channels = int(config.semantic_hidden_size * config.channel_ratios[i])
conv_blocks += [SemanticEncoderBlock(config, in_channels, out_channels, stride)]
in_channels = out_channels
self.conv_blocks = nn.ModuleList(conv_blocks)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.conv(hidden_state)
for block in self.conv_blocks:
hidden_state = block(hidden_state)
return hidden_state
class SemanticDecoderBlock(nn.Module):
def __init__(self, config: XcodecConfig, in_channels: int, out_channels: int, stride: int):
super().__init__()
if stride == 1:
self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=True,
)
else:
kernel_size = 2 * stride
padding = (stride + 1) // 2
output_padding = 1 if stride % 2 == 1 else 0
self.conv = nn.ConvTranspose1d(
in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=False
)
self.res_units = nn.ModuleList(
[ResidualUnit(config, out_channels, out_channels, dilation) for dilation in config.block_dilations]
)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.conv(hidden_state)
for unit in self.res_units:
hidden_state = unit(hidden_state)
return hidden_state
class SemanticDecoder(nn.Module):
def __init__(self, config):
super().__init__()
self.conv1 = nn.Conv1d(
in_channels=config.semantic_hidden_size,
out_channels=int(config.semantic_hidden_size * config.channel_ratios[0]),
kernel_size=config.kernel_size,
stride=1,
padding=config.kernel_size // 2,
bias=False,
)
conv_blocks = []
for i, stride in enumerate(config.strides):
in_channels = int(config.semantic_hidden_size * config.channel_ratios[i])
if i < (len(config.channel_ratios) - 1):
out_channels = int(config.semantic_hidden_size * config.channel_ratios[i + 1])
else:
out_channels = config.semantic_hidden_size
conv_blocks += [SemanticDecoderBlock(config, in_channels, out_channels, stride)]
self.conv_blocks = nn.ModuleList(conv_blocks)
self.conv2 = nn.Conv1d(
config.semantic_hidden_size,
config.semantic_hidden_size,
config.kernel_size,
stride=1,
padding=config.kernel_size // 2,
bias=False,
)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.conv1(hidden_state)
for block in self.conv_blocks:
hidden_state = block(hidden_state)
hidden_state = self.conv2(hidden_state)
return hidden_state
class XcodecEuclideanCodebook(nn.Module):
"""Codebook with Euclidean distance."""
def __init__(self, config):
super().__init__()
embed = torch.zeros(config.codebook_size, config.codebook_dim)
self.codebook_size = config.codebook_size
self.register_buffer("inited", torch.Tensor([True]))
self.register_buffer("cluster_size", torch.zeros(config.codebook_size))
self.register_buffer("embed", embed)
self.register_buffer("embed_avg", embed.clone())
# Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.quantize
def quantize(self, hidden_states):
embed = self.embed.t()
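        # nearest-neighbor lookup via the expansion ||x - e||^2 = ||x||^2 - 2*x.e + ||e||^2;
        # argmax over the negated distances selects the closest codebook entry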
scaled_states = hidden_states.pow(2).sum(1, keepdim=True)
dist = -(scaled_states - 2 * hidden_states @ embed + embed.pow(2).sum(0, keepdim=True))
embed_ind = dist.max(dim=-1).indices
return embed_ind
def encode(self, hidden_states):
shape = hidden_states.shape
hidden_states = hidden_states.reshape((-1, shape[-1]))
embed_ind = self.quantize(hidden_states)
embed_ind = embed_ind.view(*shape[:-1])
return embed_ind
def decode(self, embed_ind):
quantized = F.embedding(embed_ind, self.embed)
return quantized
class XcodecVectorQuantization(nn.Module):
"""
    Vector quantization implementation. Currently supports only Euclidean distance.
"""
def __init__(self, config: XcodecConfig):
super().__init__()
self.codebook = XcodecEuclideanCodebook(config)
# Copied from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization.encode
def encode(self, hidden_states):
hidden_states = hidden_states.permute(0, 2, 1)
embed_in = self.codebook.encode(hidden_states)
return embed_in
# Copied from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization.decode
def decode(self, embed_ind):
quantize = self.codebook.decode(embed_ind)
quantize = quantize.permute(0, 2, 1)
return quantize
class XcodecResidualVectorQuantization(nn.Module):
"""
Residual vector quantization implementation. Follows Algorithm 1 in https://huggingface.co/papers/2107.03312
"""
def __init__(self, config: XcodecConfig):
super().__init__()
self.quantizers = nn.ModuleList([XcodecVectorQuantization(config) for _ in range(config.num_quantizers)])
self.frame_rate = config.frame_rate
self.codebook_size = config.codebook_size
self.num_quantizers = config.num_quantizers
def get_bandwidth_per_quantizer(self):
"""Return bandwidth per quantizer."""
return math.log2(self.codebook_size) * self.frame_rate / 1000
def get_num_quantizers_for_bandwidth(self, bandwidth=None) -> int:
"""Return num_quantizers based on specified target bandwidth."""
bw_per_q = self.get_bandwidth_per_quantizer()
num_quantizers = self.num_quantizers
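        # e.g. a 4.0 kbps target at 0.5 kbps per quantizer selects floor(4.0 / 0.5) = 8 codebooks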
if bandwidth is not None and bandwidth > 0.0:
num_quantizers = int(max(1, math.floor(bandwidth / bw_per_q)))
return num_quantizers
def encode(self, embeddings: torch.Tensor, bandwidth=None) -> torch.Tensor:
"""
Encode the input tensor into discrete indices using RVQ, with the number of quantizers selected based on the given bandwidth.
        Each quantizer/codebook residually quantizes the input and returns the indices of the nearest entries in terms of Euclidean distance.
"""
num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth)
residual = embeddings
all_indices = []
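        # each quantizer encodes the residual left by the previous stage and subtracts
        # its reconstruction before passing the residual on (Algorithm 1 of SoundStream)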
for quantizer in self.quantizers[:num_quantizers]:
indices = quantizer.encode(residual)
quantized = quantizer.decode(indices)
residual = residual - quantized
all_indices.append(indices)
out_indices = torch.stack(all_indices)
return out_indices
def decode(self, codes: torch.Tensor) -> torch.Tensor:
"""Decode the given codes to their quantized representation."""
quantized_out = torch.tensor(0.0, device=codes.device)
for i, indices in enumerate(codes):
quantizer = self.quantizers[i]
quantized = quantizer.decode(indices)
quantized_out = quantized_out + quantized
return quantized_out
@auto_docstring
class XcodecPreTrainedModel(PreTrainedAudioTokenizerBase):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = XcodecConfig
base_model_prefix = "xcodec"
main_input_name = "input_values"
input_modalities = "audio"
@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
init.zeros_(module.bias)
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
init.zeros_(module.bias)
init.ones_(module.weight)
elif isinstance(module, nn.Conv1d):
init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
init.uniform_(module.bias, a=-k, b=k)
elif module.__class__.__name__ == "Snake1d":
init.ones_(module.alpha)
elif isinstance(module, nn.ConvTranspose1d):
module.reset_parameters()
elif isinstance(module, nn.Embedding):
init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, XcodecModel):
            # The Conv1d layers are not handled correctly: `self.acoustic_encoder/decoder` are initialized from a
            # PreTrainedModel, but only their submodules are kept (and those are not PreTrainedModels themselves),
            # so we reinitialize them here as in DacModel
for submodule in module.acoustic_encoder.modules():
if isinstance(submodule, nn.Conv1d):
init.trunc_normal_(submodule.weight, std=0.02)
init.constant_(submodule.bias, 0)
for submodule in module.acoustic_decoder.modules():
if isinstance(submodule, nn.Conv1d):
init.trunc_normal_(submodule.weight, std=0.02)
init.constant_(submodule.bias, 0)
elif isinstance(module, XcodecEuclideanCodebook):
init.copy_(module.inited, torch.Tensor([True]))
init.zeros_(module.cluster_size)
init.zeros_(module.embed)
init.zeros_(module.embed_avg)
def apply_weight_norm(self):
"""Apply weight norm in the acoustic encoder and decoder because the original checkpoint has weight norm applied."""
weight_norm = torch.nn.utils.weight_norm
if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
weight_norm = torch.nn.utils.parametrizations.weight_norm
weight_norm(self.acoustic_encoder.conv1)
weight_norm(self.acoustic_encoder.conv2)
for block in self.acoustic_encoder.block:
weight_norm(block.conv1)
for res_unit in (block.res_unit1, block.res_unit2, block.res_unit3):
weight_norm(res_unit.conv1)
weight_norm(res_unit.conv2)
weight_norm(self.acoustic_decoder.conv1, name="weight")
weight_norm(self.acoustic_decoder.conv2, name="weight")
for block in self.acoustic_decoder.block:
weight_norm(block.conv_t1, name="weight")
for res_unit in (block.res_unit1, block.res_unit2, block.res_unit3):
weight_norm(res_unit.conv1, name="weight")
weight_norm(res_unit.conv2, name="weight")
def remove_weight_norm(self):
"""Remove the weight norm from the acoustic encoder and decoder."""
for module in (self.acoustic_encoder, self.acoustic_decoder):
for m in module.modules():
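                # handle both the legacy `torch.nn.utils.weight_norm` hooks and the newer
                # parametrization-based weight norm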
try:
torch.nn.utils.remove_weight_norm(m, name="weight")
except (ValueError, AttributeError):
pass
if hasattr(m, "parametrizations") and "weight" in m.parametrizations:
torch.nn.utils.parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)
@lru_cache
def _get_conv1d_layers(self, module):
"""
Recursively iterate to fetch all Conv1d layers.
"""
def get_conv1d_layers_recursive(module: nn.Module):
params_list = []
if isinstance(module, nn.Conv1d):
params_list.append(module)
# Recursively check all child modules
for child in module.children():
params_list.extend(get_conv1d_layers_recursive(child))
return params_list
return tuple(get_conv1d_layers_recursive(module))
def _get_conv1d_output_lengths(self, input_length, module=None):
"""
For a given module, compute the output length that would be obtained after all Conv1d layers.
"""
if module is None:
module = self
conv1d_layers = self._get_conv1d_layers(module)
for layer in conv1d_layers:
input_length = conv1d_output_length(layer, input_length)
return input_length
@auto_docstring(custom_intro="""The Xcodec neural audio codec model.""")
class XcodecModel(XcodecPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.pad = config.hop_length // 2
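        # half a hop of padding on each side; applied before the semantic model and,
        # when frame counts would otherwise mismatch, before the acoustic encoder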
acoustic_model = AutoModel.from_config(config.acoustic_model_config)
self.acoustic_encoder = acoustic_model.encoder
self.acoustic_decoder = acoustic_model.decoder
self._adjust_dac_decoder(self.acoustic_decoder)
self.encoder_semantic = SemanticEncoder(config)
self.decoder_semantic = SemanticDecoder(config)
self.semantic_model = AutoModel.from_config(config.semantic_model_config).eval()
self.fc = nn.Linear(config.hidden_size, config.hidden_size)
self.fc1 = nn.Linear(config.hidden_size, config.semantic_model_config.hidden_size)
self.fc2 = nn.Linear(config.hidden_size, config.acoustic_model_config.hidden_size)
self.quantizer = XcodecResidualVectorQuantization(config)
# Initialize weights and apply final processing
self.post_init()
@staticmethod
def _adjust_dac_decoder(decoder: nn.Module):
r"""
DAC implemented in Xcodec is slightly different from the HF version.
DAC in Xcodec adjusts the output padding in every ConvTranspose1d in the decoder and removes
the final `nn.Tanh` activation function.
"""
for module in decoder.modules():
if isinstance(module, nn.ConvTranspose1d):
stride = module.stride[0] if isinstance(module.stride, tuple) else module.stride
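                # odd strides get one extra output sample so upsampled lengths match
                # the original Xcodec DAC implementation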
module.output_padding = (stride % 2,)
if hasattr(decoder, "tanh") and isinstance(decoder.tanh, nn.Tanh):
decoder.tanh = nn.Identity()
def _extract_semantic_features(self, input_values: torch.FloatTensor) -> torch.FloatTensor:
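        # take the mono channel, pad by half a hop on each side, then average the
        # semantic model's hidden states across all layers (run under `torch.no_grad()`)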
input_values = input_values[:, 0, :]
input_values = F.pad(input_values, (self.pad, self.pad))
with torch.no_grad():
outputs = self.semantic_model(input_values, output_hidden_states=True)
hidden_states = outputs.hidden_states
stacked = torch.stack(hidden_states, dim=1)
return stacked.mean(dim=1)
@auto_docstring
def encode(
self,
input_values: torch.Tensor,
bandwidth: float | None = None,
return_dict: bool | None = None,
) -> torch.Tensor | XcodecEncoderOutput:
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
Float values of the input audio waveform.
bandwidth (`float`, *optional*):
            The target bandwidth in kbps. Only values in `config.target_bandwidths` are supported.
            Defaults to the highest available bandwidth, `4.0` kbps.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`].
Returns:
`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)` containing the discrete encoded audio codes.
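
        Example (a minimal sketch; the checkpoint name mirrors the `forward` example below, and the 16 kHz dummy waveform is an illustrative assumption):

        ```python
        >>> import torch
        >>> from transformers import XcodecModel

        >>> model = XcodecModel.from_pretrained("hf-audio/xcodec-hubert-librispeech")
        >>> dummy_audio = torch.randn(1, 1, 16000)  # mono waveform of shape (batch_size, channels, num_samples)
        >>> audio_codes = model.encode(dummy_audio, return_dict=False)  # (batch_size, num_quantizers, codes_length)
        ```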
"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
channels = input_values.shape[1]
if channels != 1:
raise ValueError(f"Audio must be mono, but got {channels}")
if bandwidth is None:
bandwidth = self.config.target_bandwidths[-1]
elif bandwidth not in self.config.target_bandwidths:
raise ValueError(
f"This model doesn't support the bandwidth {bandwidth}. Select one of {self.config.target_bandwidths}."
)
e_semantic_input = self._extract_semantic_features(input_values).detach()
e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
        # the original codebase runs a forward pass to get the output length, but we can
        # compute it directly from the model's Conv1d layers and decide whether to pad
if self._get_conv1d_output_lengths(input_values.shape[2], self.acoustic_encoder) != e_semantic.shape[2]:
e_acoustic = self.acoustic_encoder(F.pad(input_values, (self.pad, self.pad)))
else:
e_acoustic = self.acoustic_encoder(input_values)
embeddings = torch.cat([e_acoustic, e_semantic], dim=1)
embeddings = self.fc(embeddings.transpose(1, 2)).transpose(1, 2)
audio_codes = self.quantizer.encode(embeddings, bandwidth)
audio_codes = audio_codes.transpose(0, 1)
if not return_dict:
return audio_codes
return XcodecEncoderOutput(audio_codes)
@auto_docstring
def decode(
self,
audio_codes: torch.Tensor,
return_dict: bool | None = None,
) -> torch.Tensor | XcodecDecoderOutput:
r"""
audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`):
Discrete code indices computed using `model.encode`.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`]
Returns:
Decoded audio values of shape `(batch_size, channels, num_samples)` obtained using the decoder part of
Xcodec.
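
        Example (continuing the `encode` sketch above):

        ```python
        >>> audio_values = model.decode(audio_codes, return_dict=False)  # (batch_size, channels, num_samples)
        ```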
"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
audio_codes = audio_codes.transpose(0, 1)
quantized = self.quantizer.decode(audio_codes)
quantized_acoustic = self.fc2(quantized.transpose(1, 2)).transpose(1, 2)
audio_values = self.acoustic_decoder(quantized_acoustic)
if not return_dict:
return audio_values
return XcodecDecoderOutput(audio_values)
@auto_docstring
def forward(
self,
input_values: torch.Tensor,
audio_codes: torch.Tensor | None = None,
bandwidth: float | None = None,
return_dict: bool | None = None,
) -> tuple[torch.Tensor, torch.Tensor] | XcodecOutput:
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
The raw float values of the input audio waveform.
        audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
            Discrete code indices computed using `model.encode`. If not provided, they are computed from `input_values`.
bandwidth (`float`, *optional*):
Target bandwidth in kbps. Must be one of `config.target_bandwidths`. Defaults to the highest available bandwidth.
return_dict (`bool`, *optional*):
Whether to return a [`XcodecOutput`] instead of a plain tuple.
Returns:
`XcodecOutput` or tuple `(audio_codes, audio_values)`:
- `audio_codes` of shape `(batch_size, num_quantizers, codes_length)`: the quantized discrete codes.
- `audio_values` of shape `(batch_size, channels, num_samples)`: the reconstructed audio waveform given the codes.
Example:
```python
        >>> from datasets import Audio, load_dataset
>>> from transformers import AutoFeatureExtractor, XcodecModel
>>> model_id = "hf-audio/xcodec-hubert-librispeech"
>>> model = XcodecModel.from_pretrained(model_id)
>>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> dataset = dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
>>> audio_sample = dataset[0]['audio']['array']
>>> inputs = feature_extractor(raw_audio=audio_sample, return_tensors="pt")
>>> outputs = model(**inputs)
>>> audio_codes = outputs.audio_codes
>>> audio_values = outputs.audio_values
```
"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
length = input_values.shape[-1]
if audio_codes is None:
audio_codes = self.encode(input_values, bandwidth, return_dict=False)
        audio_values = self.decode(audio_codes, return_dict=False)[..., :length]
if not return_dict:
return (audio_codes, audio_values)
return XcodecOutput(audio_codes=audio_codes, audio_values=audio_values)
__all__ = ["XcodecModel", "XcodecPreTrainedModel"]