# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
import operator
import os
import re
from functools import reduce
from ..distributed import DistributedConfig
from ..utils import is_torch_greater_or_equal, logging
from ..utils.generic import GeneralInterface
from ..utils.import_utils import is_torch_available
if is_torch_available():
import torch
import torch.distributed as dist
from torch import nn
# Cache this result as it's a C FFI call which can be pretty time-consuming
_torch_distributed_available = torch.distributed.is_available()
logger = logging.get_logger(__name__)
def initialize_tensor_parallelism(
tp_plan: str | dict[str, str] | None, tp_size: int | None = None, device_mesh=None, device_map=None
):
r"""
Sets up the device mesh and initializes the backend for tensor parallelism.
This function is called when the model is loaded and the TP plan is set to 'auto'.
"""
if tp_size is not None and tp_plan is None:
raise ValueError("tp_plan has to be set when tp_size is passed.")
if tp_plan is not None and device_map is not None:
raise ValueError("`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization.")
if device_mesh is None:
if not is_torch_greater_or_equal("2.5"):
raise OSError("Tensor parallel is only supported for `torch>=2.5`.")
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
if device_type == "mps":
device_type = "cpu" # fallback
current_device = getattr(torch, device_type)
if not torch.distributed.is_initialized():
try:
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
backend = backend_map.get(device_type)
torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
current_device = getattr(torch, device_type)
if device_type != "cpu":
current_device.set_device(local_rank)
except Exception as e:
raise OSError(
"We tried to initialize torch.distributed for you, but it failed. Make "
"sure you init torch distributed in your script to use `tp_plan`."
) from e
if device_type != "cpu":
current_device.set_device(int(os.environ["LOCAL_RANK"]))
index = current_device.current_device()
tp_device = torch.device(device_type, index)
device_map = tp_device
# Silence output for non-primary ranks
if index > 0:
import sys
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
else:
tp_device = torch.device(device_type)
device_map = device_type or {}
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
else:
if device_mesh.ndim > 1:
if "tp" not in device_mesh.mesh_dim_names:
raise ValueError(
"When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
"Please provide a valid `device_mesh`."
)
device_mesh = device_mesh["tp"]
tp_size = device_mesh.size()
device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
return device_map, device_mesh, tp_size
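# Usage sketch (hedged): this is normally driven by `from_pretrained(..., tp_plan="auto")` under a
# torchrun launch, which provides the RANK / LOCAL_RANK / WORLD_SIZE environment variables, e.g.
#   torchrun --nproc-per-node 4 train.py
# inside of which something equivalent to the following runs:
#   device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan="auto")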
def replace_layer_number_by_wildcard(name: str) -> str:
"""
Replace the numbers in `name` with wildcards, but only if they are between two dots (`.`) or between
a dot (`.`) and the end of the string.
This matches how modules are named/numbered when using a nn.ModuleList or nn.Sequential, but will NOT match
numbers in a parameter name itself, e.g. if the param is named `"w1"` or `"w2"`.
"""
return re.sub(r"\.\d+(\.|$)", lambda m: ".*" + m.group(1), name)
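# Illustrative examples (a sketch of the expected behavior, not part of the public API):
#   replace_layer_number_by_wildcard("model.layers.11.self_attn.q_proj.weight")
#   -> "model.layers.*.self_attn.q_proj.weight"
#   replace_layer_number_by_wildcard("mlp.experts.w1")  # digit inside a name is kept
#   -> "mlp.experts.w1"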
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
"""
Get the TP style for a parameter from the TP plan.
The TP plan is a dictionary that maps parameter names to TP styles.
The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight").
`is_weight` is important because for weights we want to resolve the `.weight` and `.bias` cases seamlessly,
but we do not want to match parent modules during `post_init` calls.
"""
generic_param_name = replace_layer_number_by_wildcard(parameter_name)
if generic_param_name in tp_plan:
return tp_plan[generic_param_name]
elif is_weight and "." in generic_param_name and (module_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
return tp_plan[module_name]
return None
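# Illustrative example (hedged; the plan keys are hypothetical):
#   tp_plan = {"model.layers.*.self_attn.q_proj": "colwise"}
#   _get_parameter_tp_plan("model.layers.3.self_attn.q_proj.weight", tp_plan)  -> "colwise"
#   _get_parameter_tp_plan("model.layers.3.self_attn.q_proj.bias", tp_plan)    -> "colwise"
#   _get_parameter_tp_plan("model.layers.3.mlp.up_proj.weight", tp_plan)       -> None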
# =============================================================================
# Tensor Sharding Utilities
# =============================================================================
if is_torch_available():
str_to_dtype = {
"BOOL": torch.bool,
"U8": torch.uint8,
"I8": torch.int8,
"I16": torch.int16,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I32": torch.int32,
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn,
}
def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
"""
Convert block count or proportions to block sizes.
This function accepts either:
- the number of blocks (int), in which case the block size is
total_size // blocks; or
- a list of block proportions (list[int]).
In the second case, if sum(blocks) < total_size, the ratios between
the blocks are preserved. For instance, if blocks is
[2, 1, 1] and total_size is 1024, the returned block sizes are
[512, 256, 256].
"""
if isinstance(blocks, list):
total_blocks = sum(blocks)
assert total_size % total_blocks == 0, f"Cannot split {total_size} in proportional blocks: {blocks}"
part_size = total_size // total_blocks
return [part_size * block for block in blocks]
else:
assert total_size % blocks == 0, f"Cannot split {total_size} into {blocks} equal blocks"
single_size = total_size // blocks
return [single_size] * blocks
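# Worked examples (sketch):
#   _blocks_to_block_sizes(1024, 4)          -> [256, 256, 256, 256]
#   _blocks_to_block_sizes(1024, [2, 1, 1])  -> [512, 256, 256]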
def get_packed_weights(param, empty_param, device_mesh, rank, dim):
"""
When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
So if you have: gate_proj ( 16, 5120, 8190)
and up_proj ( 16, 5120, 8190)
packed as gate_up_proj ( 16, 5120, 2 * 8190)
If we shard along the last dimension across TP_size (tensor parallelism size), we must interleave the values from the gate and up projections correctly.
Let's take TP_size = 4 as an example:
Packed tensor `gate_up_proj`
---------------------------------------------------------------
[ G0 G1 G2 G3 | G4 G5 G6 G7 | ... | U0 U1 U2 U3 | U4 U5 U6 U7 | ... ]
Gate Slice 0 Gate Slice 1 Up Slice 0 Up Slice 1
Explanation:
- The first half of the tensor (left of the center) holds the gate_proj values.
- The second half (right of the center) holds the up_proj values.
- For TP=4, we divide each half into 4 slices. In this example, we show two slices for brevity.
- Each shard receives one slice from the gate part and the corresponding slice from the up part.
For instance:
Shard 0 gets: [ Gate Slice 0, Up Slice 0 ] = [ G0, G1, G2, G3, U0, U1, U2, U3 ]
Shard 1 gets: [ Gate Slice 1, Up Slice 1 ] = [ G4, G5, G6, G7, U4, U5, U6, U7 ]
and so on.
This ensures that each shard receives an equal portion of both gate and up projections, maintaining consistency across tensor parallelism.
"""
slice_ = param
total_size = empty_param.shape[dim]
world_size = device_mesh.size()
block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=2)
tensors_slices = []
block_offset = 0
for block_size in block_sizes:
shard_block_size = block_size // world_size
start = rank * shard_block_size
stop = (rank + 1) * shard_block_size
tensors_slices += range(block_offset + start, block_offset + stop)
block_offset += block_size
slice_dtype = slice_.get_dtype()
# Handle float8 dtypes (F8_E4M3 / F8_E5M2) by converting to float16 before slicing.
# Without upcasting, the slicing raises: RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
casted = False
if slice_dtype == "F8_E4M3" or slice_dtype == "F8_E5M2":
slice_ = slice_[...].to(torch.float16)
casted = True
if dim == 0:
tensor = slice_[tensors_slices, ...]
elif dim == 1 or dim == -2:
tensor = slice_[:, tensors_slices, ...]
elif dim == 2 or dim == -1:
tensor = slice_[..., tensors_slices]
else:
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
if casted:
return tensor
else:
return tensor.to(str_to_dtype[slice_dtype])
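# Worked example (hedged sketch): for a packed last dim of size 16 (gate: columns 0..7, up: columns 8..15)
# and world_size=4, block_sizes=[8, 8] and shard_block_size=2, so
#   rank 0 keeps columns [0, 1, 8, 9]
#   rank 1 keeps columns [2, 3, 10, 11]
# i.e. each rank gets matching gate and up slices rather than a contiguous chunk.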
def repack_weights(
packed_parameter: torch.Tensor,
sharded_dim: int, # The dimension index in the global tensor that was sharded
world_size: int,
num_blocks: int = 2,
) -> torch.Tensor:
"""
Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
For example, if a weight was packed (e.g., gate_proj and up_proj) and then sharded,
DTensor.full_tensor() might produce an interleaved layout like [G0, U0, G1, U1, ...]
along the sharded dimension. This function reorders it to [G0, G1, ..., U0, U1, ...].
This is an inverse operation to get_packed_weights.
Args:
packed_parameter: The tensor reconstructed from DTensor (e.g., via .full_tensor().contiguous()).
sharded_dim: The dimension index in `packed_parameter` that was originally sharded.
world_size: The tensor parallel world size.
num_blocks: The number of projections that were packed together (e.g., 2 for gate_up_proj).
Returns:
The reordered tensor in canonical packed format.
"""
if num_blocks != 2:
raise ValueError(
"Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation as we only pack gate and up projections together."
)
actual_sharded_dim = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
total_size_on_sharded_dim = packed_parameter.shape[actual_sharded_dim]
original_block_size_on_dim = total_size_on_sharded_dim // num_blocks
shard_chunk_size = original_block_size_on_dim // world_size
prefix_shape = packed_parameter.shape[:actual_sharded_dim]
suffix_shape = packed_parameter.shape[actual_sharded_dim + 1 :]
tensor_view = packed_parameter.view(
*prefix_shape,
world_size,
num_blocks,
shard_chunk_size,
*suffix_shape,
)
# Permute to bring num_packed_projs first, then world_size, then shard_chunk_size
# This groups all chunks of G together, then all chunks of U together.
# Target order of these middle dimensions: (num_packed_projs, world_size, shard_chunk_size)
# Current order of view's middle dimensions: (world_size, num_packed_projs, shard_chunk_size)
# Absolute indices of the dimensions to be permuted (world_size, num_packed_projs)
axis_ws_abs = len(prefix_shape)
axis_npp_abs = len(prefix_shape) + 1
permute_order = list(range(tensor_view.ndim))
permute_order[axis_ws_abs], permute_order[axis_npp_abs] = permute_order[axis_npp_abs], permute_order[axis_ws_abs]
tensor_permuted = tensor_view.permute(*permute_order)
# Reshape back to the original tensor's ndim, with the sharded dimension now correctly ordered as [G_all, U_all].
# The final shape should be the same as reconstructed_tensor.
final_ordered_tensor = tensor_permuted.reshape_as(packed_parameter)
return final_ordered_tensor
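# Worked example (hedged sketch): with world_size=2 and a packed dim of size 8,
# DTensor.full_tensor() concatenates the rank shards as [G0 G1 U0 U1 | G2 G3 U2 U3];
# repack_weights(..., sharded_dim=-1, world_size=2) reorders this back to
# [G0 G1 G2 G3 U0 U1 U2 U3], the canonical gate_up_proj layout.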
def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: int | None = None):
"""
Generalized tensor sharding across a multi-dimensional device mesh.
Extract only the fraction of the parameter owned by the given `rank` when the parameter is sharded along the provided `dim`.
Extraction follows the pytorch `Shard` placement, so that sharding and materializing back to the full tensor follow `Shard` semantics.
`Shard` follows torch.chunk-style sharding of the tensor. We demonstrate some cases below, including edge cases
such as some ranks getting an empty tensor as their shard. The implementation below is robust to all these cases.
Case (1)
empty_param (16, 5120, 8190)
dim 0
device_mesh.size() 4
rank 0 gets (4, 5120, 8190) (0 ... 4, 5120, 8190)
rank 1 gets (4, 5120, 8190) (4 ... 8, 5120, 8190)
rank 2 gets (4, 5120, 8190) (8 ... 12, 5120, 8190)
rank 3 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
Case (2)
empty_param (16, 5120, 8190)
dim 0
device_mesh.size() 14
rank 0 gets (2, 5120, 8190) (0 ... 2, 5120, 8190)
rank 1 gets (2, 5120, 8190) (2 ... 4, 5120, 8190)
rank 2 gets (2, 5120, 8190) (4 ... 6, 5120, 8190)
rank 3 gets (2, 5120, 8190) (6 ... 8, 5120, 8190)
rank 4 gets (2, 5120, 8190) (8 ... 10, 5120, 8190)
rank 5 gets (2, 5120, 8190) (10 ... 12, 5120, 8190)
rank 6 gets (2, 5120, 8190) (12 ... 14, 5120, 8190)
rank 7 gets (2, 5120, 8190) (14 ... 16, 5120, 8190)
rank 8 gets (0, 5120, 8190)
rank 9 gets (0, 5120, 8190)
rank 10 gets (0, 5120, 8190)
rank 11 gets (0, 5120, 8190)
rank 12 gets (0, 5120, 8190)
rank 13 gets (0, 5120, 8190)
Case (3)
empty_param (16, 5120, 8190)
dim 0
device_mesh.size() 3
rank 0 gets (6, 5120, 8190) (0 ... 6, 5120, 8190)
rank 1 gets (6, 5120, 8190) (6 ... 12, 5120, 8190)
rank 2 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
In case (2), empty shards are returned with appropriate dimension to allow for operations to work smoothly.
Args:
param (torch.Tensor): The tensor to shard.
empty_param (torch.Tensor): A tensor used for shape reference.
device_mesh (DeviceMesh): The device mesh of shape [d_0, ..., d_n] over which the tensor is sharded.
rank (int): Global rank of the current process/device.
dim (int): Dimension along which to shard the tensor.
tensor_idx (int, optional): Index of the tensor within a merged module list; when the empty parameter is
3D and sharded on dim 0, the whole tensor is sent to the rank whose shard range contains this index.
"""
param_dim = empty_param.ndim
mesh_shape = device_mesh.shape
world_size = reduce(operator.mul, mesh_shape)
# Get param shape: works for both torch.Tensor and safetensors TensorInfo
param_shape = list(param.shape) if isinstance(param, torch.Tensor) else param.get_shape()
if dim < 0:
dim = param_dim + dim
if empty_param.dim() == 3 and dim == 1 and len(param_shape) == 2:
dim = 0
elif empty_param.dim() == 3 and dim == 2 and len(param_shape) == 2:
dim = 1
shard_size = math.ceil(param_shape[dim] / world_size)
start = rank * shard_size
end = min(start + shard_size, param_shape[dim])
if dim >= param_dim:
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
if rank >= world_size:
raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
# We have the full tensor, not just one part of it.
# In that case, we assume the weight was saved properly; with TP, a colwise layer should not take this
# path but should instead be marked packed_colwise, which signals that it reads from a packed tensor
# (and also handles the module-list case).
# Here we take care of potential chunking / layer splitting.
# The only "hard" case would be collecting q, k, v and merging them into qkv, but even then we still
# shard along dim=0, so nothing changes.
# The remaining special case is an empty param of dim 3 sharded on dim 0: we place the whole
# tensor on a single rank (the one owning the given tensor_idx).
if tensor_idx is not None and empty_param.dim() == 3 and dim == 0 and len(param_shape) == 2:
# special case we don't "shard" just send this entire tensor to the correct rank.
if start <= tensor_idx < end:
# this tensor does need to be materialized on this device:
return param[:]
else:
return torch.empty([], dtype=torch.int64, device=rank)
slice_indices = [slice(None)] * len(param_shape)
if start < param_shape[dim]:
slice_indices[dim] = slice(start, end)
param = param[tuple(slice_indices)]
if isinstance(param, list): # TODO handle the modulelist case!
param = [p[:] for p in param]
return param
param_shape[dim] = 0
return torch.empty(tuple(param_shape), dtype=torch.int64) # empty allocates memory....
def _split_along_last_dim(x, world_size):
"""Split tensor along last dimension into world_size chunks."""
return torch.chunk(x, world_size, dim=-1)
# =============================================================================
# Distributed Communication Primitives
# =============================================================================
#
# Naming convention:
# - Functions describe their FORWARD behavior
# - Backward behavior is the "conjugate" operation for gradient flow
#
# Available operations:
# ┌────────────────────┬─────────────────────┬─────────────────────┐
# │ Function │ Forward │ Backward │
# ├────────────────────┼─────────────────────┼─────────────────────┤
# │ all_reduce │ all-reduce (sum) │ identity │
# │ all_reduce_backward│ identity │ all-reduce (sum) │
# │ all_gather │ all-gather │ split (local chunk) │
# │ split │ split (local chunk) │ all-gather │
# │ reduce_scatter │ reduce-scatter │ all-gather │
# └────────────────────┴─────────────────────┴─────────────────────┘
# ===================
class _AllReduceBackward(torch.autograd.Function):
"""Identity forward, all-reduce backward. Used before colwise layers (f in Megatron)."""
@staticmethod
def forward(ctx, x, device_mesh):
ctx.device_mesh = device_mesh
return x
@staticmethod
def backward(ctx, grad_output):
device_mesh = ctx.device_mesh
if device_mesh.size() == 1:
return grad_output, None
dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
return grad_output, None
class _AllReduceForward(torch.autograd.Function):
"""All-reduce forward, identity backward. Used after rowwise layers (g in Megatron)."""
@staticmethod
def forward(ctx, x, device_mesh):
if device_mesh.size() == 1:
return x
dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
class _AllGather(torch.autograd.Function):
"""All-gather forward, split backward. Gathers sharded outputs."""
@staticmethod
def forward(ctx, x, device_mesh):
ctx.device_mesh = device_mesh
world_size = device_mesh.size()
if world_size == 1:
return x
last_dim = x.dim() - 1
rank = device_mesh.get_local_rank()
group = device_mesh.get_group()
x = x.contiguous()
tensor_list = [torch.empty_like(x) for _ in range(world_size)]
tensor_list[rank] = x
dist.all_gather(tensor_list, x, group=group)
return torch.cat(tensor_list, dim=last_dim).contiguous()
@staticmethod
def backward(ctx, grad_output):
device_mesh = ctx.device_mesh
world_size = device_mesh.size()
if world_size == 1:
return grad_output, None
rank = device_mesh.get_local_rank()
chunks = _split_along_last_dim(grad_output, world_size)
return chunks[rank].contiguous(), None
class _Split(torch.autograd.Function):
"""Split forward, all-gather backward. Scatters replicated input."""
@staticmethod
def forward(ctx, x, device_mesh):
ctx.device_mesh = device_mesh
world_size = device_mesh.size()
if world_size == 1:
return x
rank = device_mesh.get_local_rank()
chunks = _split_along_last_dim(x, world_size)
return chunks[rank].contiguous()
@staticmethod
def backward(ctx, grad_output):
device_mesh = ctx.device_mesh
world_size = device_mesh.size()
if world_size == 1:
return grad_output, None
last_dim = grad_output.dim() - 1
rank = device_mesh.get_local_rank()
group = device_mesh.get_group()
grad_output = grad_output.contiguous()
tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)]
tensor_list[rank] = grad_output
dist.all_gather(tensor_list, grad_output, group=group)
return torch.cat(tensor_list, dim=last_dim).contiguous(), None
class _ReduceScatter(torch.autograd.Function):
"""Reduce-scatter forward, all-gather backward. For sequence parallel."""
@staticmethod
def forward(ctx, x, device_mesh):
ctx.device_mesh = device_mesh
world_size = device_mesh.size()
if world_size == 1:
return x
last_dim = x.dim() - 1
group = device_mesh.get_group()
input_chunks = list(x.chunk(world_size, dim=last_dim))
output_shape = list(x.shape)
output_shape[last_dim] //= world_size
output = torch.empty(output_shape, dtype=x.dtype, device=x.device)
dist.reduce_scatter(output, input_chunks, op=dist.ReduceOp.SUM, group=group)
return output
@staticmethod
def backward(ctx, grad_output):
device_mesh = ctx.device_mesh
world_size = device_mesh.size()
if world_size == 1:
return grad_output, None
last_dim = grad_output.dim() - 1
rank = device_mesh.get_local_rank()
group = device_mesh.get_group()
grad_output = grad_output.contiguous()
tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)]
tensor_list[rank] = grad_output
dist.all_gather(tensor_list, grad_output, group=group)
return torch.cat(tensor_list, dim=last_dim).contiguous(), None
# =============================================================================
# Convenience wrappers
# =============================================================================
def all_reduce_backward(x, device_mesh):
"""Identity forward, all-reduce backward. Use before colwise layers."""
return _AllReduceBackward.apply(x, device_mesh)
def all_reduce_forward(x, device_mesh):
"""All-reduce forward, identity backward. Use after rowwise layers."""
return _AllReduceForward.apply(x, device_mesh)
def all_gather(x, device_mesh):
"""All-gather forward, split backward."""
return _AllGather.apply(x, device_mesh)
def split(x, device_mesh):
"""Split forward, all-gather backward."""
return _Split.apply(x, device_mesh)
def reduce_scatter(x, device_mesh):
"""Reduce-scatter forward, all-gather backward."""
return _ReduceScatter.apply(x, device_mesh)
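# Usage sketch (hedged; `up_proj` / `down_proj` are illustrative colwise / rowwise sharded linears):
#   x = all_reduce_backward(x, device_mesh)            # "f": identity fwd, all-reduce grads
#   h = torch.nn.functional.silu(up_proj(x))           # output sharded on the last dim
#   y = all_reduce_forward(down_proj(h), device_mesh)  # "g": sum the partial outputs
# The ColwiseParallel / RowwiseParallel hooks below insert these calls around the layers.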
def distribute_module(
module: nn.Module,
device_mesh=None,
input_fn=None,
output_fn=None,
) -> nn.Module:
"""
Copied from torch's `distribute_module`, but with the communications (partitioning) removed,
as well as the buffer registration, which is similarly inefficient.
"""
if input_fn is not None:
module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
if output_fn is not None:
module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
return module
class TensorParallelLayer:
"""General tensor parallel layer for transformers"""
device_mesh = None
rank = None
empty_param = None
def __init__(self, device_mesh=None, rank=None, empty_param=None):
self.rank = rank
self.device_mesh = device_mesh
self.empty_param = empty_param
@staticmethod
def _prepare_input_fn(mod, inputs, device_mesh): ...
@staticmethod
def _prepare_output_fn(mod, outputs, device_mesh): ...
def shard_tensor(
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
) -> torch.Tensor:
raise NotImplementedError
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
distribute_module(
module,
device_mesh,
self._prepare_input_fn,
self._prepare_output_fn,
)
def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
"""
Compute the expected shape after TP sharding for a given full shape.
Args:
full_shape: The full (unsharded) parameter shape
Returns:
The expected sharded shape for this rank
"""
# Default: no sharding, return full shape
return tuple(full_shape)
class ColwiseParallel(TensorParallelLayer):
"""
Column-wise parallel: weight is sharded on dim -2 (output features).
Forward: input replicated -> output sharded on last dim.
If gather_output=True, output is all-gathered to produce full tensor.
"""
def __init__(self, gather_output: bool = False, **kwargs):
super().__init__(**kwargs)
self.gather_output = gather_output
def _prepare_input_fn(self, mod, inputs, device_mesh):
input_tensor = inputs[0] if inputs else inputs
return all_reduce_backward(input_tensor, device_mesh)
def _prepare_output_fn(self, mod, outputs, device_mesh):
if self.gather_output:
return all_gather(outputs, device_mesh)
return outputs
def shard_tensor(
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
) -> torch.Tensor:
# If only 1 dim, shard this one (usually it's a `bias`)
dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
if dim == 1:
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
else:
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -2)
return parameter.to(device=device, dtype=dtype)
def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
world_size = self.device_mesh.size()
shape = list(full_shape)
# Colwise shards dim -2, but 1D tensors (bias) shard on dim -1
dim = -1 if len(shape) == 1 else -2
dim = len(shape) + dim if dim < 0 else dim
shard_size = math.ceil(shape[dim] / world_size)
start = self.rank * shard_size
end = min(start + shard_size, shape[dim])
shape[dim] = end - start
return tuple(shape)
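# Shape sketch (hedged, illustrative numbers): for a colwise nn.Linear weight of shape
# (out_features, in_features) = (8192, 4096) with world_size=4, get_expected_sharded_shape
# returns (2048, 4096) on every rank, and a 1D bias of shape (8192,) becomes (2048,).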
class RowwiseParallel(TensorParallelLayer):
"""
Row-wise parallel: weight is sharded on dim -1 (input features).
Forward: input (optionally split) -> output partial -> all-reduce to replicate.
Args:
split_input: If True, splits replicated input before matmul. Use when input
comes from a non-parallelizable operation (chunk/slice).
Default False (expects pre-sharded input from colwise layer).
"""
def __init__(self, split_input: bool = False, **kwargs):
super().__init__(**kwargs)
self.split_input = split_input
def _prepare_input_fn(self, mod, inputs, device_mesh):
if hasattr(mod, "bias") and mod.bias is not None:
mod._bias = mod.bias
mod.bias = None
input_tensor = inputs[0] if inputs else inputs
if self.split_input:
# Input is replicated, split it to match sharded weight
return split(input_tensor, device_mesh)
return input_tensor
def _prepare_output_fn(self, mod, outputs, device_mesh):
outputs = all_reduce_forward(outputs, device_mesh)
if hasattr(mod, "_bias") and mod._bias is not None:
outputs = outputs + mod._bias
return outputs
def shard_tensor(
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
) -> torch.Tensor:
# If only 1 dim, it should not be sharded (usually it's a `bias`)
dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
if dim == 1:
parameter = param[...]
else:
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
return parameter.to(device=device, dtype=dtype)
def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
# 1D tensors (bias) are NOT sharded in rowwise
if len(full_shape) == 1:
return tuple(full_shape)
world_size = self.device_mesh.size()
shape = list(full_shape)
dim = -1
dim = len(shape) + dim if dim < 0 else dim
shard_size = math.ceil(shape[dim] / world_size)
start = self.rank * shard_size
end = min(start + shard_size, shape[dim])
shape[dim] = end - start
return tuple(shape)
class PackedColwiseParallel(ColwiseParallel):
"""Packed column-wise parallel for fused weights like gate_up_proj."""
def shard_tensor(
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
) -> torch.Tensor:
# If only 1 dim, shard this one (usually it's a `bias`)
dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
if dim == 1:
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
else:
expected_shape = self.get_expected_sharded_shape(self.empty_param.shape)
if dim < len(expected_shape):
# Input is unpacked (e.g., gate_proj that will be concatenated to gate_up_proj)
# Use regular tensor shard - concatenation will happen after
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -2)
else:
# Input is already packed, use packed sharding
parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -2)
return parameter.to(device=device, dtype=dtype)
class PackedRowwiseParallel(RowwiseParallel):
"""Packed row-wise parallel for fused weights like gate_up_proj."""
def shard_tensor(
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
) -> torch.Tensor:
# If only 1 dim, it should not be sharded (usually it's a `bias`)
dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
if dim == 1:
parameter = param[...]
else:
# Check if input tensor is unpacked (shape mismatch with expected packed size)
# This happens when using MergeModulelist + Concatenate for fused weights like gate_up_proj
param_shape = param.shape if isinstance(param, torch.Tensor) else param.get_shape()
expected_packed_dim = self.empty_param.shape[-1] if self.empty_param.dim() >= 1 else 0
actual_dim = param_shape[-1] if len(param_shape) >= 1 else 0
if actual_dim < expected_packed_dim:
# Input is unpacked, use regular tensor shard
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
else:
# Input is already packed, use packed sharding
parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -1)
return parameter.to(device=device, dtype=dtype)
class EmbeddingParallel(TensorParallelLayer):
"""EmbeddingParallel: shards embedding table, handles masked lookups for vocab parallelism."""
def __init__(self, *, embedding_dim_sharding: int = 0, **kwargs):
super().__init__(**kwargs)
self.embedding_dim_sharding = embedding_dim_sharding
def _prepare_input_fn(self, mod, inputs, device_mesh):
input_tensor = inputs[0] if inputs else inputs
# For vocab-parallel (dim 0), we need to handle masking and offsetting
if self.embedding_dim_sharding == 0:
rank = device_mesh.get_local_rank()
# Get vocab range for this rank
# Use weight.shape[0] to get the actual local (sharded) size, not num_embeddings
# which may not be updated after sharding
per_partition_size = mod.weight.shape[0]
vocab_start_index = rank * per_partition_size
vocab_end_index = vocab_start_index + per_partition_size
# Build mask for out-of-vocabulary tokens
input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
mod._input_mask = input_mask
# Offset input to local indices and mask invalid ones
masked_input = input_tensor.clone() - vocab_start_index
masked_input[input_mask] = 0 # Set to valid local index
return masked_input
return input_tensor
def _prepare_output_fn(self, mod, outputs, device_mesh):
# For vocab-parallel (dim 0), zero out embeddings for out-of-range tokens before all-reduce
if self.embedding_dim_sharding == 0 and hasattr(mod, "_input_mask"):
input_mask = mod._input_mask
# Use multiplication instead of in-place assignment to preserve gradients
mask_expanded = input_mask.unsqueeze(-1).expand_as(outputs)
outputs = outputs * (~mask_expanded).float()
del mod._input_mask
return all_reduce_forward(outputs, device_mesh)
def shard_tensor(
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
) -> torch.Tensor:
# If only 1 dim, shard this one (usually it's a `bias`)
dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
if dim == 1:
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
else:
parameter = get_tensor_shard(
param,
self.empty_param,
self.device_mesh,
self.rank,
self.embedding_dim_sharding,
)
return parameter.to(device=device, dtype=dtype)
def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
world_size = self.device_mesh.size()
shape = list(full_shape)
# EmbeddingParallel shards on self.embedding_dim_sharding (default 0)
# 1D tensors (bias) shard on dim -1
dim = -1 if len(shape) == 1 else self.embedding_dim_sharding
dim = len(shape) + dim if dim < 0 else dim
shard_size = math.ceil(shape[dim] / world_size)
start = self.rank * shard_size
end = min(start + shard_size, shape[dim])
shape[dim] = end - start
return tuple(shape)
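# Vocab-parallel sketch (hedged, illustrative numbers): with a (32000, hidden) embedding table and
# world_size=4, each rank holds an (8000, hidden) shard. Rank 1 owns token ids [8000, 16000): id 12005
# becomes local index 4005 there, and is masked (looked up as index 0, then zeroed) on every other rank,
# so the final all-reduce reconstructs the full embedding output.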
class SequenceParallel(TensorParallelLayer):
"""
Sequence Parallel: input/output sharded on sequence dimension.
Weights are replicated.
"""
def __init__(self, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False, **kwargs):
super().__init__(**kwargs)
self.sequence_dim = sequence_dim
def _prepare_input_fn(self, mod, inputs, device_mesh):
input_tensor = inputs[0] if inputs else inputs
# For sequence parallel, input is sharded on sequence dim
# All-gather for the layer, then reduce-scatter after
return all_gather(input_tensor, device_mesh)
def _prepare_output_fn(self, mod, outputs, device_mesh):
return reduce_scatter(outputs, device_mesh)
def shard_tensor(
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
) -> torch.Tensor:
return param[...].to(device=device, dtype=dtype)
class GroupedGemmParallel(TensorParallelLayer):
"""
Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def shard_tensor(
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
) -> torch.Tensor:
global_num_experts = self.empty_param.shape[0]
if global_num_experts % self.device_mesh.size() != 0:
raise ValueError(
f"Global number of experts must be divisible by number of devices: {global_num_experts} % {self.device_mesh.size()} != 0"
)
local_num_experts = global_num_experts // self.device_mesh.size()
shard_size = local_num_experts
if isinstance(device, torch.device):
device = device.index if device.index is not None else 0
start = device * shard_size
end = (device + 1) * shard_size
# special case we don't "shard" just send this entire tensor to the correct rank.
shape = param.get_shape() if not isinstance(param, torch.Tensor) else param.shape
if tensor_idx is not None and start <= tensor_idx < end:
# this tensor does need to be materialized on this device:
return param[:].to(device=device)
elif tensor_idx is None: # a bias or a weight, but already merged
return param[start:end].to(device=device, dtype=dtype)
elif len(shape) >= 1 and tensor_idx is not None:
return None
else: # bias case
return param[:].to(device=device, dtype=dtype)
def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
# GroupedGemm shards on dim 0 (experts dimension)
world_size = self.device_mesh.size()
shape = list(full_shape)
local_num_experts = shape[0] // world_size
shape[0] = local_num_experts
return tuple(shape)
class RouterParallel(TensorParallelLayer):
"""
Reshapes the router scores to support running expert parallelism.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
@staticmethod
def _prepare_input_fn(mod, inputs, device_mesh):
return inputs[0] if inputs else inputs
@staticmethod
def _prepare_output_fn(mod, outputs, device_mesh):
"""
Imagine you have 4 tokens, top_k = 4, and 128 experts.
With EP = 8, num_local_experts is 128 / 8 = 16.
Imagine router_indices being:
[ 52, 42, 119, 67],
[102, 89, 61, 40],
[ 82, 103, 4, 34],
[ 93, 23, 109, 11],
then you can map which rank should be getting which values:
[3, 2, 7, 4],
[6, 5, 3, 2],
[5, 6, 0, 2],
[5, 1, 6, 0],
Thus for, say, rank 0, you fill the index tensor with 16 (num_local_experts):
[ 16, 16, 16, 16],
[ 16, 16, 16, 16],
[ 16, 16, 4, 16],
[ 16, 16, 16, 11],
This works well. For other ranks you need to take the index modulo num_local_experts,
because the next operation will one-hot encode the router index vector.
This lets us know directly which local expert is hit.
Similarly, the scores are indexed with something created from router_indices.
The somewhat naive training loop that we use for device_map="auto" uses a similar logic.
Here we just make each rank believe that it is alone, and it computes its part of the hidden states.
Invalid indices are masked with num_local_experts so that the one-hot encoding skips them.
"""
ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size()
if mod.num_experts % ep_size != 0:
raise ValueError(
f"The number of experts must be divisible by number of ep_size: {mod.num_experts} % {ep_size} != 0"
)
num_local_experts = mod.num_experts // ep_size
router_logits, router_scores, router_indices = outputs
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_scores)
router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]
router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, -1)
# As -1 % 1 == 0, fmod would destroy the -1 mask when num_local_experts is 1, so we use masked_fill instead in that case
if num_local_experts > 1:
router_indices = torch.fmod(router_indices, num_local_experts)
else:
router_indices = router_indices.masked_fill(router_indices > 0, 0).masked_fill(router_indices < 0, -1)
router_indices = router_indices.masked_fill(router_indices == -1, num_local_experts)
return router_logits, router_scores, router_indices
def shard_tensor(
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
) -> torch.Tensor:
return param[...].to(device=device, dtype=dtype)
class MoeTensorParalellExperts(TensorParallelLayer):
"""
Note: For tensor parallel, this TP layer handles gradient sync for the MoE experts:
- all_reduce_backward on hidden_states (for colwise gate_up_proj gradient)
- all_reduce_backward on top_k_weights (for router gradient)
- all_reduce_forward on output (for partial expert outputs)
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
@staticmethod
def _prepare_input_fn(mod, inputs, device_mesh):
# inputs = (hidden_states, top_k_index, top_k_weights)
hidden_states = inputs[0]
top_k_index = inputs[1]
top_k_weights = inputs[2]
# all_reduce_backward on hidden_states for correct colwise (gate_up_proj) gradient
hidden_states = all_reduce_backward(hidden_states, device_mesh)
# all_reduce_backward on routing weights for correct router gradient
# This is needed because ∂L/∂routing_weights = ∂L/∂output * partial_expert_output
# and partial_expert_output is different on each GPU before all-reduce
top_k_weights = all_reduce_backward(top_k_weights, device_mesh)
return (hidden_states, top_k_index, top_k_weights)
@staticmethod
def _prepare_output_fn(mod, outputs, device_mesh):
# all_reduce_forward to sum partial expert outputs across GPUs
return all_reduce_forward(outputs, device_mesh)
def shard_tensor(
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
) -> torch.Tensor:
# This class doesn't shard tensors - sharding is handled by packed_colwise/rowwise
# on the individual weight tensors (gate_up_proj/down_proj)
return param[...].to(device=device, dtype=dtype)
class ParallelInterface(GeneralInterface):
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
# a new instance is created (in order to locally override a given entry)
_global_mapping = (
{
"embedding_rowwise": EmbeddingParallel(embedding_dim_sharding=0),
"colwise_gather_output": ColwiseParallel(gather_output=True),
"colwise": ColwiseParallel(),
"rowwise": RowwiseParallel(),
"rowwise_split_input": RowwiseParallel(split_input=True),
"packed_colwise": PackedColwiseParallel(),
"packed_rowwise": PackedRowwiseParallel(),
"sequence_parallel": SequenceParallel(),
"grouped_gemm": GroupedGemmParallel(),
"ep_router": RouterParallel(),
"moe_tp_experts": MoeTensorParalellExperts(),
}
if is_torch_available() and _torch_distributed_available
else {}
)
# Map plan names to sharding dimensions for weights
# For weights: colwise shards dim -2, rowwise shards dim -1
# For embedding: rowwise shards dim 0 (vocab), colwise shards dim -2 (hidden)
plan_to_weight_dim: dict[str, int | None] = {
"colwise": -2,
"colwise_gather_output": -2,
"packed_colwise": -2,
"rowwise": -1,
"rowwise_split_input": -1,
"packed_rowwise": -1,
"embedding_rowwise": 0,
"sequence_parallel": None,
}
# Bias sharding: colwise shards bias, rowwise doesn't (bias is replicated and all-reduced)
plan_to_bias_dim: dict[str, int | None] = {
"colwise": -1,
"colwise_gather_output": -1,
"packed_colwise": -1,
"rowwise": None,
"rowwise_split_input": None,
"packed_rowwise": None,
"embedding_rowwise": None,
"sequence_parallel": None,
}
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
# =============================================================================
# High-Level API Functions
# =============================================================================
def gather_full_tensor(local_tensor: torch.Tensor, shard_dim: int, device_mesh) -> torch.Tensor:
"""
All-gather a sharded tensor along the specified dimension to reconstruct the full tensor.
Args:
local_tensor: The local shard of the tensor on this rank
shard_dim: The dimension along which the tensor was sharded
device_mesh: The device mesh for distributed communication
Returns:
The full reconstructed tensor (same on all ranks)
"""
world_size = device_mesh.size()
# Normalize negative dimension
if shard_dim < 0:
shard_dim = local_tensor.ndim + shard_dim
# Gather all shards
gathered_tensors = [torch.empty_like(local_tensor) for _ in range(world_size)]
dist.all_gather(gathered_tensors, local_tensor.contiguous())
# Concatenate along the shard dimension
return torch.cat(gathered_tensors, dim=shard_dim)
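# Usage sketch (hedged, illustrative numbers): with world_size=4 and a local colwise shard of shape
# (2048, 4096) sharded on dim 0, gather_full_tensor(local, shard_dim=0, device_mesh) returns the full
# (8192, 4096) weight on every rank.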
def gather_state_dict_for_save(
state_dict: dict[str, torch.Tensor],
tp_plan: dict[str, str],
device_mesh,
tp_size: int,
) -> dict[str, torch.Tensor]:
"""
Gather sharded tensors to reconstruct full tensors for saving.
This function all-gathers each sharded tensor along its shard dimension
to reconstruct the full unsharded tensor for checkpoint saving.
Args:
state_dict: The model state dict with local sharded tensors
tp_plan: The tensor parallel plan mapping layer patterns to shard styles
device_mesh: The device mesh for distributed communication
tp_size: The tensor parallel world size
Returns:
State dict with full (gathered) tensors
"""
# Use the global mappings from ParallelInterface (can be extended by users)
plan_to_weight_dim = ALL_PARALLEL_STYLES.plan_to_weight_dim
plan_to_bias_dim = ALL_PARALLEL_STYLES.plan_to_bias_dim
result = {}
for key, tensor in state_dict.items():
# Find the matching TP plan for this parameter
param_name = key.rsplit(".", 1)[0] if "." in key else key
param_type = key.rsplit(".", 1)[1] if "." in key else None
generic_param_name = re.sub(r"\d+", "*", param_name)
# Also check the full key for nn.Parameter (e.g., MoE experts without .weight suffix)
generic_full_key = re.sub(r"\d+", "*", key)
# Check if this parameter has a TP plan
current_plan = None
if generic_full_key in tp_plan:
# Full key match (e.g., "model.layers.*.mlp.experts.gate_up_proj" for MoE experts)
current_plan = tp_plan[generic_full_key]
elif generic_param_name in tp_plan:
current_plan = tp_plan[generic_param_name]
elif "." in generic_param_name:
parent_param_name = generic_param_name.rsplit(".", 1)[0]
if parent_param_name in tp_plan:
current_plan = tp_plan[parent_param_name]
if current_plan is None or current_plan not in plan_to_weight_dim:
# Not sharded, keep as-is
result[key] = tensor
continue
# Determine sharding dimension based on param type
if param_type == "bias":
shard_dim = plan_to_bias_dim.get(current_plan)
else:
shard_dim = plan_to_weight_dim.get(current_plan)
if shard_dim is None:
# Replicated, keep as-is
result[key] = tensor
continue
# Gather full tensor and handle packed weights repacking
full_tensor = gather_full_tensor(tensor, shard_dim, device_mesh)
if current_plan in ("packed_colwise", "packed_rowwise"):
full_tensor = repack_weights(full_tensor, shard_dim, tp_size, 2)
result[key] = full_tensor.contiguous()
return result
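# Usage sketch (hedged; the saving logic is illustrative):
#   full_sd = gather_state_dict_for_save(model.state_dict(), model.tp_plan, device_mesh, tp_size)
#   if dist.get_rank() == 0:
#       torch.save(full_sd, "checkpoint.bin")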
def add_tensor_parallel_hooks_to_module(
model, module, tp_plan, layer_name, current_module_plan, device_mesh, parameter_name=None
):
r"""
This function is called in `PretrainedModel.post_init()`. It is responsible for adding hooks
to the modules of the `model`, based on the `PretrainedModel._tp_plan`.
This is where we add the `pre_forward` and `post_forward` hooks. These are defined
for each `TensorParallelLayer` as `_prepare_input_fn` and `_prepare_output_fn`.
"""
if current_module_plan is not None:
tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
try:
tp_layer.prepare_module_tp(module, device_mesh)
except NotImplementedError as e:
print(
f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
)
module._hf_tp_plan = current_module_plan
original_repr = module.__repr__
module.__repr__ = lambda: f"{original_repr()}\nTP Plan: {current_module_plan}"
def shard_and_distribute_module(
model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
):
r"""
This function is called in `from_pretrained` when loading a model's checkpoints.
It receives the pointer to the parameter (or the parameter itself) and takes care of the "sharding".
All processes run this function, so each one only loads the partition of the tensor that it requires.
Main use cases:
- column / row-wise parallelism: you just shard all the weights of the layer (weight and bias)
- packed layers: you slice the weights, then shard like above
- custom operation:
- you want to add an all-gather at the end of a local layer.
- you want to have a layer that is isolated from the rest of the world (because torch.DTensor does not work well with `.view`, for instance)
"""
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
tp_plan = model.tp_plan or {}
module_to_tp = model.get_submodule(param_name)
rank = int(rank)
current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
if dist.get_rank() == 0:
if current_shard_plan is None:
logger.info(f"Tensor sharding plan for {param_name} not found, using default 'replicate' plan.")
else:
logger.info(f"Tensor sharding plan for {param_name}: {current_shard_plan}")
if current_shard_plan is not None:
try:
tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
tp_layer.empty_param = empty_param
tp_layer.device_mesh = device_mesh
tp_layer.rank = rank
param = tp_layer.shard_tensor(param, tensor_idx=None, dtype=param_casting_dtype, device=rank)
if is_contiguous:
param = param.contiguous()
except NotImplementedError as e:
print(
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
)
else:
param = param[:].to(param_casting_dtype)
# SUPER IMPORTANT we have to use setattr
# otherwise loading is crazy slow
if not isinstance(param, torch.nn.Parameter):
param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point())
setattr(module_to_tp, param_type, param)
return param
def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
"""
Verify the TP plan of the model, logging a warning for the layers that were not sharded and for the rules that were not applied.
"""
if tp_plan is None:
return
generic_keys = {replace_layer_number_by_wildcard(key) for key in expected_keys}
unsharded_layers = set(generic_keys)
unused_rules = tp_plan.copy()
for key in generic_keys:
param_name = key.rsplit(".", 1)[0] if "." in key else key
generic_param_name = re.sub(r"\d+", "*", param_name)
if generic_param_name in tp_plan:
unused_rules.pop(generic_param_name, None)
unsharded_layers.discard(key)
elif "." in generic_param_name and (parent_param_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
unused_rules.pop(parent_param_name, None)
unsharded_layers.discard(key)
if len(unused_rules) > 0:
logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}")
if len(unsharded_layers) > 0:
logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")
def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size):
"""Distribute a model according to the TP plan."""
model._tp_size = tp_size
model._device_mesh = device_mesh
if distributed_config is not None:
if isinstance(distributed_config, dict):
distributed_config = DistributedConfig.from_dict(distributed_config)
model.config.distributed_config = distributed_config
# Set the new requested tp_plan on the model
if isinstance(tp_plan, dict):
model.tp_plan = tp_plan
model_plan = model.tp_plan
if model_plan is not None and _torch_distributed_available:
for v in model_plan.values():
if v not in ALL_PARALLEL_STYLES:
raise ValueError(f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}")
for name, module in model.named_modules():
if not getattr(module, "_is_hooked", False):
plan = _get_parameter_tp_plan(parameter_name=name, tp_plan=model_plan, is_weight=False)
add_tensor_parallel_hooks_to_module(
model=model,
module=module,
tp_plan=model_plan,
layer_name="",
current_module_plan=plan,
device_mesh=device_mesh,
)
module._is_hooked = True
return model