# Copyright 2024 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
import operator
|
|
import os
|
|
import re
|
|
from functools import reduce
|
|
|
|
from ..distributed import DistributedConfig
|
|
from ..utils import is_torch_greater_or_equal, logging
|
|
from ..utils.generic import GeneralInterface
|
|
from ..utils.import_utils import is_torch_available
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch import nn
|
|
|
|
    # Cache this result as it's a C FFI call which can be pretty time-consuming
    _torch_distributed_available = torch.distributed.is_available()
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
def initialize_tensor_parallelism(
|
|
tp_plan: str | dict[str, str] | None, tp_size: int | None = None, device_mesh=None, device_map=None
|
|
):
|
|
r"""
|
|
Sets up the device mesh and initialized the backend for tensor parallelism.
|
|
This function is called when the model is loaded and the TP plan is set to 'auto'.
|
|
"""
|
|
if tp_size is not None and tp_plan is None:
|
|
raise ValueError("tp_plan has to be set when tp_size is passed.")
|
|
if tp_plan is not None and device_map is not None:
|
|
raise ValueError("`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization.")
|
|
if device_mesh is None:
|
|
if not is_torch_greater_or_equal("2.5"):
|
|
raise OSError("Tensor parallel is only supported for `torch>=2.5`.")
|
|
|
|
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
|
|
device_type = torch._C._get_accelerator().type
|
|
if device_type == "mps":
|
|
device_type = "cpu" # fallback
|
|
current_device = getattr(torch, device_type)
|
|
if not torch.distributed.is_initialized():
|
|
try:
|
|
rank = int(os.environ["RANK"])
|
|
local_rank = int(os.environ["LOCAL_RANK"])
|
|
world_size = int(os.environ["WORLD_SIZE"])
|
|
|
|
backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
|
|
backend = backend_map.get(device_type)
|
|
|
|
torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
|
current_device = getattr(torch, device_type)
|
|
if device_type != "cpu":
|
|
current_device.set_device(local_rank)
|
|
|
|
except Exception as e:
|
|
raise OSError(
|
|
"We tried to initialize torch.distributed for you, but it failed. Make "
|
|
"sure you init torch distributed in your script to use `tp_plan`."
|
|
) from e
|
|
|
|
if device_type != "cpu":
|
|
current_device.set_device(int(os.environ["LOCAL_RANK"]))
|
|
index = current_device.current_device()
|
|
tp_device = torch.device(device_type, index)
|
|
device_map = tp_device
|
|
# Silence output for non-primary ranks
|
|
if index > 0:
|
|
import sys
|
|
|
|
sys.stdout = open(os.devnull, "w")
|
|
sys.stderr = open(os.devnull, "w")
|
|
|
|
else:
|
|
tp_device = torch.device(device_type)
|
|
device_map = device_type or {}
|
|
|
|
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
|
|
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
|
|
else:
|
|
if device_mesh.ndim > 1:
|
|
if "tp" not in device_mesh.mesh_dim_names:
|
|
raise ValueError(
|
|
"When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
|
|
"Please provide a valid `device_mesh`."
|
|
)
|
|
device_mesh = device_mesh["tp"]
|
|
tp_size = device_mesh.size()
|
|
device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
|
|
|
|
return device_map, device_mesh, tp_size
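# Illustrative usage (a minimal sketch): `from_pretrained(..., tp_plan="auto")` calls this function, and it assumes
# the script was launched with `torchrun`, which sets the RANK / LOCAL_RANK / WORLD_SIZE environment variables read
# above, e.g.:
#
#   torchrun --nproc-per-node 4 run_tp.py
#
# where run_tp.py does something like (the model id is a placeholder):
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("<model-id>", tp_plan="auto")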
|
|
|
|
|
|
def replace_layer_number_by_wildcard(name: str) -> str:
|
|
"""
|
|
Replace the numbers in the `name` by wildcards, only if they are in-between dots (`.`) or if they are between
|
|
a dot (`.`) and the end of the string.
|
|
This matches how modules are named/numbered when using a nn.ModuleList or nn.Sequential, but will NOT match
|
|
numbers in a parameter name itself, e.g. if the param is named `"w1"` or `"w2"`.
|
|
"""
|
|
return re.sub(r"\.\d+(\.|$)", lambda m: ".*" + m.group(1), name)
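# Illustrative behavior of the wildcard substitution (module-list indices become "*", digits inside a
# parameter name such as "w1" are left untouched):
#
#   replace_layer_number_by_wildcard("model.layers.3.mlp.w1")      -> "model.layers.*.mlp.w1"
#   replace_layer_number_by_wildcard("model.layers.12.experts.0")  -> "model.layers.*.experts.*"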
|
|
|
|
|
|
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
|
|
"""
|
|
Get the TP style for a parameter from the TP plan.
|
|
|
|
The TP plan is a dictionary that maps parameter names to TP styles.
|
|
The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight").
|
|
|
|
The `is_weight` is important because for weights, we want to support `.weights` and `.bias` cases seamlessly! but
|
|
not parent classes for `post_init` calls
|
|
"""
|
|
generic_param_name = replace_layer_number_by_wildcard(parameter_name)
|
|
if generic_param_name in tp_plan:
|
|
return tp_plan[generic_param_name]
|
|
elif is_weight and "." in generic_param_name and (module_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
|
|
return tp_plan[module_name]
|
|
return None
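# Illustrative lookups (hypothetical plan; keys are generic names with wildcards):
#
#   plan = {"model.layers.*.mlp.down_proj": "rowwise"}
#   _get_parameter_tp_plan("model.layers.7.mlp.down_proj.weight", plan)               # -> "rowwise" (falls back to the module)
#   _get_parameter_tp_plan("model.layers.7.mlp.down_proj", plan, is_weight=False)     # -> "rowwise" (exact generic match)
#   _get_parameter_tp_plan("model.norm.weight", plan)                                 # -> None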
|
|
|
|
|
|
# =============================================================================
|
|
# Tensor Sharding Utilities
|
|
# =============================================================================
|
|
|
|
|
|
if is_torch_available():
|
|
str_to_dtype = {
|
|
"BOOL": torch.bool,
|
|
"U8": torch.uint8,
|
|
"I8": torch.int8,
|
|
"I16": torch.int16,
|
|
"F16": torch.float16,
|
|
"BF16": torch.bfloat16,
|
|
"I32": torch.int32,
|
|
"F32": torch.float32,
|
|
"F64": torch.float64,
|
|
"I64": torch.int64,
|
|
"F8_E4M3": torch.float8_e4m3fn,
|
|
}
|
|
|
|
|
|
def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
|
|
"""
|
|
Convert block count or proportions to block sizes.
|
|
|
|
This function accepts
|
|
|
|
- The number of blocks (int), in which case the block size is
|
|
total_size//blocks; or
|
|
- A list of block sizes (list[int]).
|
|
|
|
In the second case, if sum(blocks) < total_size, the ratios between
|
|
the block sizes will be preserved. For instance, if blocks is
|
|
[2, 1, 1] and total_size is 1024, the returned block sizes are
|
|
[512, 256, 256].
|
|
"""
|
|
if isinstance(blocks, list):
|
|
total_blocks = sum(blocks)
|
|
assert total_size % total_blocks == 0, f"Cannot split {total_size} in proportional blocks: {blocks}"
|
|
part_size = total_size // total_blocks
|
|
return [part_size * block for block in blocks]
|
|
else:
|
|
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
|
|
single_size = total_size // blocks
|
|
return [single_size] * blocks
|
|
|
|
|
|
def get_packed_weights(param, empty_param, device_mesh, rank, dim):
|
|
"""
|
|
When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
|
|
So if you have: gate_proj ( 16, 5120, 8190)
|
|
and up_proj ( 16, 5120, 8190)
|
|
packed as gate_up_proj ( 16, 5120, 2 * 8190)
|
|
And you shard along the last dimension, you need to interleave the gate and up values:
|
|
|
|
Now, if we shard along the last dimension across TP_size (Tensor Parallelism size), we must interleave the values from gate and up projections correctly.
|
|
|
|
Let's take TP_size = 4 for an example:
|
|
|
|
Packed tensor `gate_up_proj`
|
|
---------------------------------------------------------------
|
|
[ G0 G1 G2 G3 | G4 G5 G6 G7 | ... | U0 U1 U2 U3 | U4 U5 U6 U7 | ... ]
|
|
↑─────────────↑ ↑─────────────↑ ↑─────────────↑ ↑─────────────↑
|
|
Gate Slice 0 Gate Slice 1 Up Slice 0 Up Slice 1
|
|
|
|
Explanation:
|
|
- The first half of the tensor (left of the center) holds the gate_proj values.
|
|
- The second half (right of the center) holds the up_proj values.
|
|
- For TP=4, we divide each half into 4 slices. In this example, we show two slices for brevity.
|
|
- Each shard receives one slice from the gate part and the corresponding slice from the up part.
|
|
|
|
For instance:
|
|
• Shard 0 gets: [ Gate Slice 0, Up Slice 0 ] = [ G0, G1, G2, G3, U0, U1, U2, U3 ]
|
|
• Shard 1 gets: [ Gate Slice 1, Up Slice 1 ] = [ G4, G5, G6, G7, U4, U5, U6, U7 ]
|
|
• … and so on.
|
|
|
|
This ensures that each shard receives an equal portion of both gate and up projections, maintaining consistency across tensor parallelism.
|
|
"""
|
|
slice_ = param
|
|
total_size = empty_param.shape[dim]
|
|
world_size = device_mesh.size()
|
|
block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=2)
|
|
|
|
tensors_slices = []
|
|
block_offset = 0
|
|
for block_size in block_sizes:
|
|
shard_block_size = block_size // world_size
|
|
start = rank * shard_block_size
|
|
stop = (rank + 1) * shard_block_size
|
|
tensors_slices += range(block_offset + start, block_offset + stop)
|
|
block_offset += block_size
|
|
|
|
slice_dtype = slice_.get_dtype()
|
|
# Handle F8_E4M3 dtype by converting to float16 before slicing
|
|
# Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
|
|
casted = False
|
|
if slice_dtype == "F8_E4M3" or slice_dtype == "F8_E5M2":
|
|
slice_ = slice_[...].to(torch.float16)
|
|
casted = True
|
|
|
|
if dim == 0:
|
|
tensor = slice_[tensors_slices, ...]
|
|
elif dim == 1 or dim == -2:
|
|
tensor = slice_[:, tensors_slices, ...]
|
|
elif dim == 2 or dim == -1:
|
|
tensor = slice_[..., tensors_slices]
|
|
else:
|
|
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
|
|
|
|
if casted:
|
|
return tensor
|
|
else:
|
|
return tensor.to(str_to_dtype[slice_dtype])
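# Tiny worked illustration of the index selection above (made-up sizes): for a packed dimension of
# size 8 (two blocks of 4) and world_size = 2, the indices selected per rank are
#
#   rank 0 -> [0, 1, 4, 5]   (first half of the gate block, first half of the up block)
#   rank 1 -> [2, 3, 6, 7]   (second half of each block)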
|
|
|
|
|
|
def repack_weights(
|
|
packed_parameter: torch.Tensor,
|
|
sharded_dim: int, # The dimension index in the global tensor that was sharded
|
|
world_size: int,
|
|
num_blocks: int = 2,
|
|
) -> torch.Tensor:
|
|
"""
|
|
Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
|
|
|
|
For example, if a weight was packed (e.g., gate_proj and up_proj) and then sharded,
|
|
DTensor.full_tensor() might produce an interleaved layout like [G0, U0, G1, U1, ...]
|
|
along the sharded dimension. This function reorders it to [G0, G1, ..., U0, U1, ...].
|
|
This is an inverse operation to get_packed_weights.
|
|
|
|
Args:
|
|
reconstructed_tensor: The tensor reconstructed from DTensor (e.g., via .full_tensor().contiguous()).
|
|
sharded_dim: The dimension index in the reconstructed_tensor that was originally sharded.
|
|
world_size: The tensor parallel world size.
|
|
num_packed_projs: The number of projections that were packed together (e.g., 2 for gate_up_proj).
|
|
|
|
Returns:
|
|
The reordered tensor in canonical packed format.
|
|
"""
|
|
|
|
if num_blocks != 2:
|
|
raise ValueError(
|
|
"Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation as we only pack gate and up projections together."
|
|
)
|
|
|
|
actual_sharded_dim = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
|
|
total_size_on_sharded_dim = packed_parameter.shape[actual_sharded_dim]
|
|
original_block_size_on_dim = total_size_on_sharded_dim // num_blocks
|
|
shard_chunk_size = original_block_size_on_dim // world_size
|
|
|
|
prefix_shape = packed_parameter.shape[:actual_sharded_dim]
|
|
suffix_shape = packed_parameter.shape[actual_sharded_dim + 1 :]
|
|
|
|
tensor_view = packed_parameter.view(
|
|
*prefix_shape,
|
|
world_size,
|
|
num_blocks,
|
|
shard_chunk_size,
|
|
*suffix_shape,
|
|
)
|
|
|
|
# Permute to bring num_packed_projs first, then world_size, then shard_chunk_size
|
|
# This groups all chunks of G together, then all chunks of U together.
|
|
# Target order of these middle dimensions: (num_packed_projs, world_size, shard_chunk_size)
|
|
# Current order of view's middle dimensions: (world_size, num_packed_projs, shard_chunk_size)
|
|
# Absolute indices of the dimensions to be permuted (world_size, num_packed_projs)
|
|
axis_ws_abs = len(prefix_shape)
|
|
axis_npp_abs = len(prefix_shape) + 1
|
|
|
|
permute_order = list(range(tensor_view.ndim))
|
|
permute_order[axis_ws_abs], permute_order[axis_npp_abs] = permute_order[axis_npp_abs], permute_order[axis_ws_abs]
|
|
|
|
tensor_permuted = tensor_view.permute(*permute_order)
|
|
|
|
# Reshape back to the original tensor's ndim, with the sharded dimension now correctly ordered as [G_all, U_all].
|
|
# The final shape should be the same as reconstructed_tensor.
|
|
final_ordered_tensor = tensor_permuted.reshape_as(packed_parameter)
|
|
|
|
return final_ordered_tensor
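# Minimal sketch of what repack_weights undoes (assumed layout, written 1-D for readability):
# with world_size = 2 and num_blocks = 2, a gathered packed weight laid out as [G0, U0, G1, U1]
# along the sharded dim is reordered back to [G0, G1, U0, U1], i.e. all gate chunks first, then
# all up chunks.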
|
|
|
|
|
|
def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: int | None = None):
|
|
"""
|
|
Generalized tensor sharding across a multi-dimensional device mesh.
|
|
Extract only the fraction of the parameter owned by the given `rank` when the parameter would have gone sharding at provided `dim`.
|
|
Extraction follows the pytorch `Shard` placement so that sharding and materializing back to full tensor follows `Shard` semantics.
|
|
`Shard` follows torch.chunk style sharding of the tensor. We demonstrate some cases below on how sharding happens including some edge cases
|
|
such as some ranks having an empty tensor as shard. Below implementation is robut to all these cases.
|
|
|
|
Case (1)
|
|
empty_param (16, 5120, 8190)
|
|
dim 0
|
|
device_mesh.size() 4
|
|
rank 0 gets (4, 5120, 8190) (0 ... 4, 5120, 8190)
|
|
rank 1 gets (4, 5120, 8190) (4 ... 8, 5120, 8190)
|
|
rank 2 gets (4, 5120, 8190) (8 ... 12, 5120, 8190)
|
|
rank 3 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
|
|
|
|
Case (2)
|
|
empty_param (16, 5120, 8190)
|
|
dim 0
|
|
device_mesh.size() 14
|
|
rank 0 gets (2, 5120, 8190) (0 ... 2, 5120, 8190)
|
|
rank 1 gets (2, 5120, 8190) (2 ... 4, 5120, 8190)
|
|
rank 2 gets (2, 5120, 8190) (4 ... 6, 5120, 8190)
|
|
rank 3 gets (2, 5120, 8190) (6 ... 8, 5120, 8190)
|
|
rank 4 gets (2, 5120, 8190) (8 ... 10, 5120, 8190)
|
|
rank 5 gets (2, 5120, 8190) (10 ... 12, 5120, 8190)
|
|
rank 6 gets (2, 5120, 8190) (12 ... 14, 5120, 8190)
|
|
rank 7 gets (2, 5120, 8190) (14 ... 16, 5120, 8190)
|
|
rank 8 gets (0, 5120, 8190)
|
|
rank 9 gets (0, 5120, 8190)
|
|
rank 10 gets (0, 5120, 8190)
|
|
rank 11 gets (0, 5120, 8190)
|
|
rank 12 gets (0, 5120, 8190)
|
|
rank 13 gets (0, 5120, 8190)
|
|
|
|
Case (3)
|
|
empty_param (16, 5120, 8190)
|
|
dim 0
|
|
device_mesh.size() 3
|
|
rank 0 gets (6, 5120, 8190) (0 ... 6, 5120, 8190)
|
|
rank 1 gets (6, 5120, 8190) (6 ... 12, 5120, 8190)
|
|
rank 2 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
|
|
|
|
In case (2), empty shards are returned with appropriate dimension to allow for operations to work smoothly.
|
|
    Args:
        param (torch.Tensor): The tensor to shard.
        empty_param (torch.Tensor): A meta/empty tensor used as the full-shape reference.
        device_mesh (torch.distributed.device_mesh.DeviceMesh): Device mesh of shape [d_0, ..., d_n] over which to shard.
        rank (int): Global rank of the current process/device.
        dim (int): Dimension along which to shard the tensor.
        tensor_idx (int, optional): Index used when whole sub-tensors (e.g. individual experts) are routed to a single rank instead of being sliced.
    """
|
|
param_dim = empty_param.ndim
|
|
mesh_shape = device_mesh.shape
|
|
world_size = reduce(operator.mul, mesh_shape)
|
|
# Get param shape: works for both torch.Tensor and safetensors TensorInfo
|
|
param_shape = list(param.shape) if isinstance(param, torch.Tensor) else param.get_shape()
|
|
if dim < 0:
|
|
dim = param_dim + dim
|
|
if empty_param.dim() == 3 and dim == 1 and len(param_shape) == 2:
|
|
dim = 0
|
|
elif empty_param.dim() == 3 and dim == 2 and len(param_shape) == 2:
|
|
dim = 1
|
|
|
|
shard_size = math.ceil(param_shape[dim] / world_size)
|
|
start = rank * shard_size
|
|
end = min(start + shard_size, param_shape[dim])
|
|
|
|
if dim >= param_dim:
|
|
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
|
|
|
|
if rank >= world_size:
|
|
raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
|
|
|
|
    # Here we have the full tensor, not just one part of it.
    # In that case, we assume that the weight was saved properly, and since we are doing TP, a colwise layer should
    # not go through this path; it should be declared as packed_colwise so the loader knows it has to read from a
    # packed tensor (that also covers the module-list case). Below we handle the potential chunking / layer splitting.
    # The only "hard" case would be collecting q, k, v and merging them into qkv, but even then we still shard along
    # dim=0, so nothing changes. The remaining special case is when the empty param has 3 dims and the shard dim is 0:
    # then we place the whole tensor on a given device (using the provided tensor_idx).
|
|
if tensor_idx is not None and empty_param.dim() == 3 and dim == 0 and len(param_shape) == 2:
|
|
# special case we don't "shard" just send this entire tensor to the correct rank.
|
|
if start <= tensor_idx < end:
|
|
# this tensor does need to be materialized on this device:
|
|
return param[:]
|
|
else:
|
|
return torch.empty([], dtype=torch.int64, device=rank)
|
|
|
|
slice_indices = [slice(None)] * len(param_shape)
|
|
|
|
if start < param_shape[dim]:
|
|
slice_indices[dim] = slice(start, end)
|
|
param = param[tuple(slice_indices)]
|
|
if isinstance(param, list): # TODO handle the modulelist case!
|
|
param = [p[:] for p in param]
|
|
return param
|
|
|
|
param_shape[dim] = 0
|
|
return torch.empty(tuple(param_shape), dtype=torch.int64) # empty allocates memory....
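# Quick sanity illustration of the torch.chunk-style split used above (made-up shape): for a (16, 5120)
# weight, dim=0 and a mesh of size 4, shard_size = ceil(16 / 4) = 4, so rank r receives rows
# [4*r : 4*(r+1)); with a mesh of size 3, ranks get 6, 6 and 4 rows, and any rank whose start index
# falls past the end receives an empty shard.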
|
|
|
|
|
|
def _split_along_last_dim(x, world_size):
|
|
"""Split tensor along last dimension into world_size chunks."""
|
|
return torch.chunk(x, world_size, dim=-1)
|
|
|
|
|
|
# =============================================================================
|
|
# Distributed Communication Primitives
|
|
# =============================================================================
|
|
#
|
|
# Naming convention:
|
|
# - Functions describe their FORWARD behavior
|
|
# - Backward behavior is the "conjugate" operation for gradient flow
|
|
#
|
|
# Available operations:
|
|
# ┌────────────────────┬─────────────────────┬─────────────────────┐
|
|
# │ Function │ Forward │ Backward │
|
|
# ├────────────────────┼─────────────────────┼─────────────────────┤
|
|
# │ all_reduce │ all-reduce (sum) │ identity │
|
|
# │ all_reduce_backward│ identity │ all-reduce (sum) │
|
|
# │ all_gather │ all-gather │ split (local chunk) │
|
|
# │ split │ split (local chunk) │ all-gather │
|
|
# │ reduce_scatter │ reduce-scatter │ all-gather │
|
|
# └────────────────────┴─────────────────────┴─────────────────────┘
|
|
# ===================
|
|
|
|
|
|
class _AllReduceBackward(torch.autograd.Function):
|
|
"""Identity forward, all-reduce backward. Used before colwise layers (f in Megatron)."""
|
|
|
|
@staticmethod
|
|
def forward(ctx, x, device_mesh):
|
|
ctx.device_mesh = device_mesh
|
|
return x
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_output):
|
|
device_mesh = ctx.device_mesh
|
|
if device_mesh.size() == 1:
|
|
return grad_output, None
|
|
dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
|
|
return grad_output, None
|
|
|
|
|
|
class _AllReduceForward(torch.autograd.Function):
|
|
"""All-reduce forward, identity backward. Used after rowwise layers (g in Megatron)."""
|
|
|
|
@staticmethod
|
|
def forward(ctx, x, device_mesh):
|
|
if device_mesh.size() == 1:
|
|
return x
|
|
dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
|
|
return x
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_output):
|
|
return grad_output, None
|
|
|
|
|
|
class _AllGather(torch.autograd.Function):
|
|
"""All-gather forward, split backward. Gathers sharded outputs."""
|
|
|
|
@staticmethod
|
|
def forward(ctx, x, device_mesh):
|
|
ctx.device_mesh = device_mesh
|
|
world_size = device_mesh.size()
|
|
|
|
if world_size == 1:
|
|
return x
|
|
|
|
last_dim = x.dim() - 1
|
|
rank = device_mesh.get_local_rank()
|
|
group = device_mesh.get_group()
|
|
|
|
x = x.contiguous()
|
|
tensor_list = [torch.empty_like(x) for _ in range(world_size)]
|
|
tensor_list[rank] = x
|
|
dist.all_gather(tensor_list, x, group=group)
|
|
return torch.cat(tensor_list, dim=last_dim).contiguous()
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_output):
|
|
device_mesh = ctx.device_mesh
|
|
world_size = device_mesh.size()
|
|
|
|
if world_size == 1:
|
|
return grad_output, None
|
|
|
|
rank = device_mesh.get_local_rank()
|
|
chunks = _split_along_last_dim(grad_output, world_size)
|
|
return chunks[rank].contiguous(), None
|
|
|
|
|
|
class _Split(torch.autograd.Function):
|
|
"""Split forward, all-gather backward. Scatters replicated input."""
|
|
|
|
@staticmethod
|
|
def forward(ctx, x, device_mesh):
|
|
ctx.device_mesh = device_mesh
|
|
world_size = device_mesh.size()
|
|
|
|
if world_size == 1:
|
|
return x
|
|
|
|
rank = device_mesh.get_local_rank()
|
|
chunks = _split_along_last_dim(x, world_size)
|
|
return chunks[rank].contiguous()
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_output):
|
|
device_mesh = ctx.device_mesh
|
|
world_size = device_mesh.size()
|
|
|
|
if world_size == 1:
|
|
return grad_output, None
|
|
|
|
last_dim = grad_output.dim() - 1
|
|
rank = device_mesh.get_local_rank()
|
|
group = device_mesh.get_group()
|
|
|
|
grad_output = grad_output.contiguous()
|
|
tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)]
|
|
tensor_list[rank] = grad_output
|
|
dist.all_gather(tensor_list, grad_output, group=group)
|
|
return torch.cat(tensor_list, dim=last_dim).contiguous(), None
|
|
|
|
|
|
class _ReduceScatter(torch.autograd.Function):
|
|
"""Reduce-scatter forward, all-gather backward. For sequence parallel."""
|
|
|
|
@staticmethod
|
|
def forward(ctx, x, device_mesh):
|
|
ctx.device_mesh = device_mesh
|
|
world_size = device_mesh.size()
|
|
|
|
if world_size == 1:
|
|
return x
|
|
|
|
last_dim = x.dim() - 1
|
|
group = device_mesh.get_group()
|
|
|
|
input_chunks = list(x.chunk(world_size, dim=last_dim))
|
|
output_shape = list(x.shape)
|
|
output_shape[last_dim] //= world_size
|
|
output = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
|
|
|
dist.reduce_scatter(output, input_chunks, op=dist.ReduceOp.SUM, group=group)
|
|
return output
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_output):
|
|
device_mesh = ctx.device_mesh
|
|
world_size = device_mesh.size()
|
|
|
|
if world_size == 1:
|
|
return grad_output, None
|
|
|
|
last_dim = grad_output.dim() - 1
|
|
rank = device_mesh.get_local_rank()
|
|
group = device_mesh.get_group()
|
|
|
|
grad_output = grad_output.contiguous()
|
|
tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)]
|
|
tensor_list[rank] = grad_output
|
|
dist.all_gather(tensor_list, grad_output, group=group)
|
|
return torch.cat(tensor_list, dim=last_dim).contiguous(), None
|
|
|
|
|
|
# =============================================================================
|
|
# Convenience wrappers
|
|
# =============================================================================
|
|
|
|
|
|
def all_reduce_backward(x, device_mesh):
|
|
"""Identity forward, all-reduce backward. Use before colwise layers."""
|
|
return _AllReduceBackward.apply(x, device_mesh)
|
|
|
|
|
|
def all_reduce_forward(x, device_mesh):
|
|
"""All-reduce forward, identity backward. Use after rowwise layers."""
|
|
return _AllReduceForward.apply(x, device_mesh)
|
|
|
|
|
|
def all_gather(x, device_mesh):
|
|
"""All-gather forward, split backward."""
|
|
return _AllGather.apply(x, device_mesh)
|
|
|
|
|
|
def split(x, device_mesh):
|
|
"""Split forward, all-gather backward."""
|
|
return _Split.apply(x, device_mesh)
|
|
|
|
|
|
def reduce_scatter(x, device_mesh):
|
|
"""Reduce-scatter forward, all-gather backward."""
|
|
return _ReduceScatter.apply(x, device_mesh)
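# Minimal sketch of how these primitives pair up in a Megatron-style MLP (illustrative only;
# `colwise_linear` / `rowwise_linear` are hypothetical stand-ins for the sharded layers):
#
#   x = all_reduce_backward(x, device_mesh)   # "f": identity fwd, all-reduce grads before the colwise layer
#   h = colwise_linear(x)                     # output sharded on the last dim
#   y = rowwise_linear(h)                     # partial outputs per rank
#   y = all_reduce_forward(y, device_mesh)    # "g": sum partial outputs fwd, identity bwd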
|
|
|
|
|
|
def distribute_module(
|
|
module: nn.Module,
|
|
device_mesh=None,
|
|
input_fn=None,
|
|
output_fn=None,
|
|
) -> nn.Module:
|
|
"""
|
|
Copy pasted from torch's function but we remove the communications (partitioning)
|
|
as well as buffer registering that is similarly not efficient.
|
|
"""
|
|
if input_fn is not None:
|
|
module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
|
|
if output_fn is not None:
|
|
module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
|
|
return module
|
|
|
|
|
|
class TensorParallelLayer:
|
|
"""General tensor parallel layer for transformers"""
|
|
|
|
device_mesh = None
|
|
rank = None
|
|
empty_param = None
|
|
|
|
def __init__(self, device_mesh=None, rank=None, empty_param=None):
|
|
self.rank = rank
|
|
self.device_mesh = device_mesh
|
|
self.empty_param = empty_param
|
|
|
|
@staticmethod
|
|
def _prepare_input_fn(mod, inputs, device_mesh): ...
|
|
|
|
@staticmethod
|
|
def _prepare_output_fn(mod, outputs, device_mesh): ...
|
|
|
|
def shard_tensor(
|
|
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
) -> torch.Tensor:
|
|
raise NotImplementedError
|
|
|
|
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
|
|
distribute_module(
|
|
module,
|
|
device_mesh,
|
|
self._prepare_input_fn,
|
|
self._prepare_output_fn,
|
|
)
|
|
|
|
def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
|
|
"""
|
|
Compute the expected shape after TP sharding for a given full shape.
|
|
|
|
Args:
|
|
full_shape: The full (unsharded) parameter shape
|
|
|
|
Returns:
|
|
The expected sharded shape for this rank
|
|
"""
|
|
# Default: no sharding, return full shape
|
|
return tuple(full_shape)
|
|
|
|
|
|
class ColwiseParallel(TensorParallelLayer):
|
|
"""
|
|
Column-wise parallel: weight is sharded on dim -2 (output features).
|
|
Forward: input replicated -> output sharded on last dim.
|
|
If gather_output=True, output is all-gathered to produce full tensor.
|
|
"""
|
|
|
|
def __init__(self, gather_output: bool = False, **kwargs):
|
|
super().__init__(**kwargs)
|
|
self.gather_output = gather_output
|
|
|
|
def _prepare_input_fn(self, mod, inputs, device_mesh):
|
|
input_tensor = inputs[0] if inputs else inputs
|
|
return all_reduce_backward(input_tensor, device_mesh)
|
|
|
|
def _prepare_output_fn(self, mod, outputs, device_mesh):
|
|
if self.gather_output:
|
|
return all_gather(outputs, device_mesh)
|
|
return outputs
|
|
|
|
def shard_tensor(
|
|
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
) -> torch.Tensor:
|
|
# If only 1 dim, shard this one (usually it's a `bias`)
|
|
dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
|
|
if dim == 1:
|
|
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
|
|
else:
|
|
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -2)
|
|
return parameter.to(device=device, dtype=dtype)
|
|
|
|
def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
|
|
world_size = self.device_mesh.size()
|
|
shape = list(full_shape)
|
|
# Colwise shards dim -2, but 1D tensors (bias) shard on dim -1
|
|
dim = -1 if len(shape) == 1 else -2
|
|
dim = len(shape) + dim if dim < 0 else dim
|
|
shard_size = math.ceil(shape[dim] / world_size)
|
|
start = self.rank * shard_size
|
|
end = min(start + shard_size, shape[dim])
|
|
shape[dim] = end - start
|
|
return tuple(shape)
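# Example shapes for ColwiseParallel (illustrative sizes): with world_size = 2 and rank 0,
#   weight (out=8, in=4)  -> sharded on dim -2 -> (4, 4)
#   bias   (8,)           -> sharded on dim -1 -> (4,)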
|
|
|
|
|
|
class RowwiseParallel(TensorParallelLayer):
|
|
"""
|
|
Row-wise parallel: weight is sharded on dim -1 (input features).
|
|
Forward: input (optionally split) -> output partial -> all-reduce to replicate.
|
|
|
|
Args:
|
|
split_input: If True, splits replicated input before matmul. Use when input
|
|
comes from a non-parallelizable operation (chunk/slice).
|
|
Default False (expects pre-sharded input from colwise layer).
|
|
"""
|
|
|
|
def __init__(self, split_input: bool = False, **kwargs):
|
|
super().__init__(**kwargs)
|
|
self.split_input = split_input
|
|
|
|
def _prepare_input_fn(self, mod, inputs, device_mesh):
|
|
if hasattr(mod, "bias") and mod.bias is not None:
|
|
mod._bias = mod.bias
|
|
mod.bias = None
|
|
|
|
input_tensor = inputs[0] if inputs else inputs
|
|
|
|
if self.split_input:
|
|
# Input is replicated, split it to match sharded weight
|
|
return split(input_tensor, device_mesh)
|
|
return input_tensor
|
|
|
|
def _prepare_output_fn(self, mod, outputs, device_mesh):
|
|
outputs = all_reduce_forward(outputs, device_mesh)
|
|
if hasattr(mod, "_bias") and mod._bias is not None:
|
|
outputs = outputs + mod._bias
|
|
return outputs
|
|
|
|
def shard_tensor(
|
|
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
) -> torch.Tensor:
|
|
# If only 1 dim, it should not be sharded (usually it's a `bias`)
|
|
dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
|
|
if dim == 1:
|
|
parameter = param[...]
|
|
else:
|
|
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
|
|
return parameter.to(device=device, dtype=dtype)
|
|
|
|
def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
|
|
# 1D tensors (bias) are NOT sharded in rowwise
|
|
if len(full_shape) == 1:
|
|
return tuple(full_shape)
|
|
world_size = self.device_mesh.size()
|
|
shape = list(full_shape)
|
|
dim = -1
|
|
dim = len(shape) + dim if dim < 0 else dim
|
|
shard_size = math.ceil(shape[dim] / world_size)
|
|
start = self.rank * shard_size
|
|
end = min(start + shard_size, shape[dim])
|
|
shape[dim] = end - start
|
|
return tuple(shape)
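# Example shapes for RowwiseParallel (illustrative sizes): with world_size = 2 and rank 0,
#   weight (out=8, in=4)  -> sharded on dim -1 -> (8, 2)
#   bias   (8,)           -> replicated        -> (8,)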
|
|
|
|
|
|
class PackedColwiseParallel(ColwiseParallel):
|
|
"""Packed column-wise parallel for fused weights like gate_up_proj."""
|
|
|
|
def shard_tensor(
|
|
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
) -> torch.Tensor:
|
|
# If only 1 dim, shard this one (usually it's a `bias`)
|
|
dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
|
|
if dim == 1:
|
|
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
|
|
else:
|
|
expected_shape = self.get_expected_sharded_shape(self.empty_param.shape)
|
|
if dim < len(expected_shape):
|
|
# Input is unpacked (e.g., gate_proj that will be concatenated to gate_up_proj)
|
|
# Use regular tensor shard - concatenation will happen after
|
|
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -2)
|
|
else:
|
|
# Input is already packed, use packed sharding
|
|
parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -2)
|
|
return parameter.to(device=device, dtype=dtype)
|
|
|
|
|
|
class PackedRowwiseParallel(RowwiseParallel):
|
|
"""Packed row-wise parallel for fused weights like gate_up_proj."""
|
|
|
|
def shard_tensor(
|
|
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
) -> torch.Tensor:
|
|
# If only 1 dim, it should not be sharded (usually it's a `bias`)
|
|
dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
|
|
if dim == 1:
|
|
parameter = param[...]
|
|
else:
|
|
# Check if input tensor is unpacked (shape mismatch with expected packed size)
|
|
# This happens when using MergeModulelist + Concatenate for fused weights like gate_up_proj
|
|
param_shape = param.shape if isinstance(param, torch.Tensor) else param.get_shape()
|
|
expected_packed_dim = self.empty_param.shape[-1] if self.empty_param.dim() >= 1 else 0
|
|
actual_dim = param_shape[-1] if len(param_shape) >= 1 else 0
|
|
|
|
if actual_dim < expected_packed_dim:
|
|
# Input is unpacked, use regular tensor shard
|
|
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
|
|
else:
|
|
# Input is already packed, use packed sharding
|
|
parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -1)
|
|
return parameter.to(device=device, dtype=dtype)
|
|
|
|
|
|
class EmbeddingParallel(TensorParallelLayer):
|
|
"""EmbeddingParallel: shards embedding table, handles masked lookups for vocab parallelism."""
|
|
|
|
def __init__(self, *, embedding_dim_sharding: int = 0, **kwargs):
|
|
super().__init__(**kwargs)
|
|
self.embedding_dim_sharding = embedding_dim_sharding
|
|
|
|
def _prepare_input_fn(self, mod, inputs, device_mesh):
|
|
input_tensor = inputs[0] if inputs else inputs
|
|
|
|
# For vocab-parallel (dim 0), we need to handle masking and offsetting
|
|
if self.embedding_dim_sharding == 0:
|
|
rank = device_mesh.get_local_rank()
|
|
|
|
# Get vocab range for this rank
|
|
# Use weight.shape[0] to get the actual local (sharded) size, not num_embeddings
|
|
# which may not be updated after sharding
|
|
per_partition_size = mod.weight.shape[0]
|
|
vocab_start_index = rank * per_partition_size
|
|
vocab_end_index = vocab_start_index + per_partition_size
|
|
|
|
# Build mask for out-of-vocabulary tokens
|
|
input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
|
|
mod._input_mask = input_mask
|
|
|
|
# Offset input to local indices and mask invalid ones
|
|
masked_input = input_tensor.clone() - vocab_start_index
|
|
masked_input[input_mask] = 0 # Set to valid local index
|
|
|
|
return masked_input
|
|
|
|
return input_tensor
|
|
|
|
def _prepare_output_fn(self, mod, outputs, device_mesh):
|
|
# For vocab-parallel (dim 0), zero out embeddings for out-of-range tokens before all-reduce
|
|
if self.embedding_dim_sharding == 0 and hasattr(mod, "_input_mask"):
|
|
input_mask = mod._input_mask
|
|
# Use multiplication instead of in-place assignment to preserve gradients
|
|
mask_expanded = input_mask.unsqueeze(-1).expand_as(outputs)
|
|
outputs = outputs * (~mask_expanded).float()
|
|
del mod._input_mask
|
|
|
|
return all_reduce_forward(outputs, device_mesh)
|
|
|
|
def shard_tensor(
|
|
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
) -> torch.Tensor:
|
|
# If only 1 dim, shard this one (usually it's a `bias`)
|
|
dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
|
|
if dim == 1:
|
|
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
|
|
else:
|
|
parameter = get_tensor_shard(
|
|
param,
|
|
self.empty_param,
|
|
self.device_mesh,
|
|
self.rank,
|
|
self.embedding_dim_sharding,
|
|
)
|
|
return parameter.to(device=device, dtype=dtype)
|
|
|
|
def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
|
|
world_size = self.device_mesh.size()
|
|
shape = list(full_shape)
|
|
# EmbeddingParallel shards on self.embedding_dim_sharding (default 0)
|
|
# 1D tensors (bias) shard on dim -1
|
|
dim = -1 if len(shape) == 1 else self.embedding_dim_sharding
|
|
dim = len(shape) + dim if dim < 0 else dim
|
|
shard_size = math.ceil(shape[dim] / world_size)
|
|
start = self.rank * shard_size
|
|
end = min(start + shard_size, shape[dim])
|
|
shape[dim] = end - start
|
|
return tuple(shape)
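# Vocab-parallel lookup sketch (illustrative numbers): with a vocab of 8 and world_size = 2,
# rank 0 owns token ids [0, 4) and rank 1 owns [4, 8). For input id 5 on rank 0 the mask is True,
# the local index is reset to 0, and the resulting embedding row is zeroed out before the
# all-reduce, so only rank 1 contributes the real embedding for that token.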
|
|
|
|
|
|
class SequenceParallel(TensorParallelLayer):
|
|
"""
|
|
Sequence Parallel: input/output sharded on sequence dimension.
|
|
Weights are replicated.
|
|
"""
|
|
|
|
def __init__(self, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False, **kwargs):
|
|
super().__init__(**kwargs)
|
|
self.sequence_dim = sequence_dim
|
|
|
|
def _prepare_input_fn(self, mod, inputs, device_mesh):
|
|
input_tensor = inputs[0] if inputs else inputs
|
|
# For sequence parallel, input is sharded on sequence dim
|
|
# All-gather for the layer, then reduce-scatter after
|
|
return all_gather(input_tensor, device_mesh)
|
|
|
|
def _prepare_output_fn(self, mod, outputs, device_mesh):
|
|
return reduce_scatter(outputs, device_mesh)
|
|
|
|
def shard_tensor(
|
|
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
) -> torch.Tensor:
|
|
return param[...].to(device=device, dtype=dtype)
|
|
|
|
|
|
class GroupedGemmParallel(TensorParallelLayer):
|
|
"""
|
|
Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
|
|
"""
|
|
|
|
def __init__(self, **kwargs):
|
|
super().__init__(**kwargs)
|
|
|
|
def shard_tensor(
|
|
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
) -> torch.Tensor:
|
|
global_num_experts = self.empty_param.shape[0]
|
|
if global_num_experts % self.device_mesh.size() != 0:
|
|
raise ValueError(
|
|
f"Global number of experts must be divisible by number of devices: {global_num_experts} % {self.device_mesh.size()} != 0"
|
|
)
|
|
local_num_experts = global_num_experts // self.device_mesh.size()
|
|
shard_size = local_num_experts
|
|
if isinstance(device, torch.device):
|
|
device = device.index if device.index is not None else 0
|
|
start = device * shard_size
|
|
end = (device + 1) * shard_size
|
|
# special case we don't "shard" just send this entire tensor to the correct rank.
|
|
shape = param.get_shape() if not isinstance(param, torch.Tensor) else param.shape
|
|
if tensor_idx is not None and start <= tensor_idx < end:
|
|
# this tensor does need to be materialized on this device:
|
|
return param[:].to(device=device)
|
|
elif tensor_idx is None: # a bias or a weight, but already merged
|
|
return param[start:end].to(device=device, dtype=dtype)
|
|
elif len(shape) >= 1 and tensor_idx is not None:
|
|
return None
|
|
else: # bias case
|
|
return param[:].to(device=device, dtype=dtype)
|
|
|
|
def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
|
|
# GroupedGemm shards on dim 0 (experts dimension)
|
|
world_size = self.device_mesh.size()
|
|
shape = list(full_shape)
|
|
local_num_experts = shape[0] // world_size
|
|
shape[0] = local_num_experts
|
|
return tuple(shape)
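# Expert-parallel sketch (illustrative numbers): with 128 global experts and a mesh of size 8, each
# device keeps 128 // 8 = 16 experts, so the device with index d loads the expert slice
# [16 * d, 16 * (d + 1)).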
|
|
|
|
|
|
class RouterParallel(TensorParallelLayer):
|
|
"""
|
|
Allows to reshape the router scores to support running expert parallel.
|
|
"""
|
|
|
|
def __init__(self, **kwargs):
|
|
super().__init__(**kwargs)
|
|
|
|
@staticmethod
|
|
def _prepare_input_fn(mod, inputs, device_mesh):
|
|
return inputs[0] if inputs else inputs
|
|
|
|
@staticmethod
|
|
def _prepare_output_fn(mod, outputs, device_mesh):
|
|
"""
|
|
Imagine if you had 4 tokens, top_k = 4, and 128experts.
|
|
With EP = 8. The num_local_expert should be 128/8 = 16
|
|
Imagine router_indices being:
|
|
[ 52, 42, 119, 67],
|
|
[102, 89, 61, 40],
|
|
[ 82, 103, 4, 34],
|
|
[ 93, 23, 109, 11],
|
|
|
|
then you can map which rank should be getting which values
|
|
|
|
[3, 2, 7, 4],
|
|
[6, 5, 3, 2],
|
|
[5, 6, 0, 2],
|
|
[5, 1, 6, 0],
|
|
|
|
        Then, for rank 0 say, you fill the index tensor with 16 (num_local_experts) wherever the expert does not belong to that rank:
|
|
|
|
[ 16, 16, 16, 16],
|
|
[ 16, 16, 16, 16],
|
|
[ 16, 16, 4, 16],
|
|
[ 16, 16, 16, 11],
|
|
|
|
        This works well. For any other rank you need to reduce the indices modulo num_local_experts, because the next
        operation will one-hot encode the router index vector.

        This lets us know directly which local expert is hit.
        Similarly, the scores are indexed with a tensor derived from router_indices.

        The somewhat naive training loop that we use for device_map="auto" uses similar logic: each rank behaves as if
        it were alone and computes its own part of the hidden states. Invalid indices are masked with
        num_local_experts so that the one-hot encoding skips them.
        """
|
|
ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size()
|
|
if mod.num_experts % ep_size != 0:
|
|
raise ValueError(
|
|
f"The number of experts must be divisible by number of ep_size: {mod.num_experts} % {ep_size} != 0"
|
|
)
|
|
num_local_experts = mod.num_experts // ep_size
|
|
router_logits, router_scores, router_indices = outputs
|
|
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_scores)
|
|
router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]
|
|
router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, -1)
|
|
# As -1 % 1 is 0, we can only use mask fill when num_local_experts is 1
|
|
if num_local_experts > 1:
|
|
router_indices = torch.fmod(router_indices, num_local_experts)
|
|
else:
|
|
router_indices = router_indices.masked_fill(router_indices > 0, 0).masked_fill(router_indices < 0, -1)
|
|
router_indices = router_indices.masked_fill(router_indices == -1, num_local_experts)
|
|
return router_logits, router_scores, router_indices
|
|
|
|
def shard_tensor(
|
|
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
) -> torch.Tensor:
|
|
return param[...].to(device=device, dtype=dtype)
|
|
|
|
|
|
class MoeTensorParalellExperts(TensorParallelLayer):
|
|
"""
|
|
Note: For tensor parallel, the MoEExpertsParallel TP layer handles gradient sync:
|
|
- all_reduce_backward on hidden_states (for colwise gate_up_proj gradient)
|
|
- all_reduce_backward on top_k_weights (for router gradient)
|
|
- all_reduce_forward on output (for partial expert outputs)
|
|
"""
|
|
|
|
def __init__(self, **kwargs):
|
|
super().__init__(**kwargs)
|
|
|
|
@staticmethod
|
|
def _prepare_input_fn(mod, inputs, device_mesh):
|
|
# inputs = (hidden_states, top_k_index, top_k_weights)
|
|
hidden_states = inputs[0]
|
|
top_k_index = inputs[1]
|
|
top_k_weights = inputs[2]
|
|
|
|
# all_reduce_backward on hidden_states for correct colwise (gate_up_proj) gradient
|
|
hidden_states = all_reduce_backward(hidden_states, device_mesh)
|
|
|
|
# all_reduce_backward on routing weights for correct router gradient
|
|
# This is needed because ∂L/∂routing_weights = ∂L/∂output * partial_expert_output
|
|
# and partial_expert_output is different on each GPU before all-reduce
|
|
top_k_weights = all_reduce_backward(top_k_weights, device_mesh)
|
|
|
|
return (hidden_states, top_k_index, top_k_weights)
|
|
|
|
@staticmethod
|
|
def _prepare_output_fn(mod, outputs, device_mesh):
|
|
# all_reduce_forward to sum partial expert outputs across GPUs
|
|
return all_reduce_forward(outputs, device_mesh)
|
|
|
|
def shard_tensor(
|
|
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
) -> torch.Tensor:
|
|
# This class doesn't shard tensors - sharding is handled by packed_colwise/rowwise
|
|
# on the individual weight tensors (gate_up_proj/down_proj)
|
|
return param[...].to(device=device, dtype=dtype)
|
|
|
|
|
|
class ParallelInterface(GeneralInterface):
|
|
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
|
|
# a new instance is created (in order to locally override a given entry)
|
|
_global_mapping = (
|
|
{
|
|
"embedding_rowwise": EmbeddingParallel(embedding_dim_sharding=0),
|
|
"colwise_gather_output": ColwiseParallel(gather_output=True),
|
|
"colwise": ColwiseParallel(),
|
|
"rowwise": RowwiseParallel(),
|
|
"rowwise_split_input": RowwiseParallel(split_input=True),
|
|
"packed_colwise": PackedColwiseParallel(),
|
|
"packed_rowwise": PackedRowwiseParallel(),
|
|
"sequence_parallel": SequenceParallel(),
|
|
"grouped_gemm": GroupedGemmParallel(),
|
|
"ep_router": RouterParallel(),
|
|
"moe_tp_experts": MoeTensorParalellExperts(),
|
|
}
|
|
if is_torch_available() and _torch_distributed_available
|
|
else {}
|
|
)
|
|
|
|
# Map plan names to sharding dimensions for weights
|
|
# For weights: colwise shards dim -2, rowwise shards dim -1
|
|
# For embedding: rowwise shards dim 0 (vocab), colwise shards dim -2 (hidden)
|
|
plan_to_weight_dim: dict[str, int | None] = {
|
|
"colwise": -2,
|
|
"colwise_gather_output": -2,
|
|
"packed_colwise": -2,
|
|
"rowwise": -1,
|
|
"rowwise_split_input": -1,
|
|
"packed_rowwise": -1,
|
|
"embedding_rowwise": 0,
|
|
"sequence_parallel": None,
|
|
}
|
|
|
|
# Bias sharding: colwise shards bias, rowwise doesn't (bias is replicated and all-reduced)
|
|
plan_to_bias_dim: dict[str, int | None] = {
|
|
"colwise": -1,
|
|
"colwise_gather_output": -1,
|
|
"packed_colwise": -1,
|
|
"rowwise": None,
|
|
"rowwise_split_input": None,
|
|
"packed_rowwise": None,
|
|
"embedding_rowwise": None,
|
|
"sequence_parallel": None,
|
|
}
|
|
|
|
|
|
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
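# Minimal sketch of a tp_plan using the style names registered above (the module names are
# hypothetical; a real plan comes from the model's `base_model_tp_plan` or is passed by the user):
#
#   tp_plan = {
#       "model.embed_tokens": "embedding_rowwise",
#       "model.layers.*.self_attn.q_proj": "colwise",
#       "model.layers.*.self_attn.o_proj": "rowwise",
#       "model.layers.*.mlp.gate_up_proj": "packed_colwise",
#       "model.layers.*.mlp.down_proj": "rowwise",
#       "lm_head": "colwise_gather_output",
#   }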
|
|
|
|
|
|
# =============================================================================
|
|
# High-Level API Functions
|
|
# =============================================================================
|
|
|
|
|
|
def gather_full_tensor(local_tensor: torch.Tensor, shard_dim: int, device_mesh) -> torch.Tensor:
|
|
"""
|
|
All-gather a sharded tensor along the specified dimension to reconstruct the full tensor.
|
|
|
|
Args:
|
|
local_tensor: The local shard of the tensor on this rank
|
|
shard_dim: The dimension along which the tensor was sharded
|
|
device_mesh: The device mesh for distributed communication
|
|
|
|
Returns:
|
|
The full reconstructed tensor (same on all ranks)
|
|
"""
|
|
world_size = device_mesh.size()
|
|
|
|
# Normalize negative dimension
|
|
if shard_dim < 0:
|
|
shard_dim = local_tensor.ndim + shard_dim
|
|
|
|
# Gather all shards
|
|
gathered_tensors = [torch.empty_like(local_tensor) for _ in range(world_size)]
|
|
dist.all_gather(gathered_tensors, local_tensor.contiguous())
|
|
|
|
# Concatenate along the shard dimension
|
|
return torch.cat(gathered_tensors, dim=shard_dim)
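# Example: if each of 2 ranks holds a (4, 4) shard that was split along dim 0, the gathered result is
# the original (8, 4) tensor, identical on all ranks. Note the all-gather above passes no group, so it
# runs over the default process group and assumes the mesh spans that group.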
|
|
|
|
|
|
def gather_state_dict_for_save(
|
|
state_dict: dict[str, torch.Tensor],
|
|
tp_plan: dict[str, str],
|
|
device_mesh,
|
|
tp_size: int,
|
|
) -> dict[str, torch.Tensor]:
|
|
"""
|
|
Gather sharded tensors to reconstruct full tensors for saving.
|
|
|
|
This function all-gathers each sharded tensor along its shard dimension
|
|
to reconstruct the full unsharded tensor for checkpoint saving.
|
|
|
|
Args:
|
|
state_dict: The model state dict with local sharded tensors
|
|
tp_plan: The tensor parallel plan mapping layer patterns to shard styles
|
|
device_mesh: The device mesh for distributed communication
|
|
tp_size: The tensor parallel world size
|
|
|
|
Returns:
|
|
State dict with full (gathered) tensors
|
|
"""
|
|
# Use the global mappings from ParallelInterface (can be extended by users)
|
|
plan_to_weight_dim = ALL_PARALLEL_STYLES.plan_to_weight_dim
|
|
plan_to_bias_dim = ALL_PARALLEL_STYLES.plan_to_bias_dim
|
|
|
|
result = {}
|
|
for key, tensor in state_dict.items():
|
|
# Find the matching TP plan for this parameter
|
|
param_name = key.rsplit(".", 1)[0] if "." in key else key
|
|
param_type = key.rsplit(".", 1)[1] if "." in key else None
|
|
generic_param_name = re.sub(r"\d+", "*", param_name)
|
|
# Also check the full key for nn.Parameter (e.g., MoE experts without .weight suffix)
|
|
generic_full_key = re.sub(r"\d+", "*", key)
|
|
|
|
# Check if this parameter has a TP plan
|
|
current_plan = None
|
|
if generic_full_key in tp_plan:
|
|
# Full key match (e.g., "model.layers.*.mlp.experts.gate_up_proj" for MoE experts)
|
|
current_plan = tp_plan[generic_full_key]
|
|
elif generic_param_name in tp_plan:
|
|
current_plan = tp_plan[generic_param_name]
|
|
elif "." in generic_param_name:
|
|
parent_param_name = generic_param_name.rsplit(".", 1)[0]
|
|
if parent_param_name in tp_plan:
|
|
current_plan = tp_plan[parent_param_name]
|
|
|
|
if current_plan is None or current_plan not in plan_to_weight_dim:
|
|
# Not sharded, keep as-is
|
|
result[key] = tensor
|
|
continue
|
|
|
|
# Determine sharding dimension based on param type
|
|
if param_type == "bias":
|
|
shard_dim = plan_to_bias_dim.get(current_plan)
|
|
else:
|
|
shard_dim = plan_to_weight_dim.get(current_plan)
|
|
|
|
if shard_dim is None:
|
|
# Replicated, keep as-is
|
|
result[key] = tensor
|
|
continue
|
|
|
|
# Gather full tensor and handle packed weights repacking
|
|
full_tensor = gather_full_tensor(tensor, shard_dim, device_mesh)
|
|
if current_plan in ("packed_colwise", "packed_rowwise"):
|
|
full_tensor = repack_weights(full_tensor, shard_dim, tp_size, 2)
|
|
result[key] = full_tensor.contiguous()
|
|
|
|
return result
|
|
|
|
|
|
def add_tensor_parallel_hooks_to_module(
|
|
model, module, tp_plan, layer_name, current_module_plan, device_mesh, parameter_name=None
|
|
):
|
|
r"""
|
|
This function is called in `PretrainedModel.post_init()`. It is responsible of adding hooks
|
|
to the modules of the `model`, based on the `PretrainedModel._tp_plan`.
|
|
|
|
This is the place where we add the `pre_forward` and `post_forwards` hooks. These are defined
|
|
for each `TensorParallelLayer` as `_prepare_input_fn` and `_prepare_output_fn`.
|
|
|
|
"""
|
|
if current_module_plan is not None:
|
|
tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
|
|
try:
|
|
tp_layer.prepare_module_tp(module, device_mesh)
|
|
except NotImplementedError as e:
|
|
print(
|
|
f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
|
|
)
|
|
|
|
    module._hf_tp_plan = current_module_plan
    original_repr = module.__repr__
    module.__repr__ = lambda: f"{original_repr()}\nTP Plan: {current_module_plan}"
|
|
|
|
|
|
def shard_and_distribute_module(
|
|
model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
|
|
):
|
|
r"""
|
|
This function is called in `from_pretrained` when loading a model's checkpoints.
|
|
It receives the pointer to the parameter (or the parameter itself) and takes care of "sharding".
|
|
All process run this function, so they just load the partition of the tensor that they require.
|
|
|
|
Main uses cases:
|
|
- column / rowise parallelism, you just shard all the weights of the layer (weight and bias)
|
|
- packed layers: you slice the weights, then shard like above
|
|
- custom operation:
|
|
- you want to add an all-gather at the end of a local layer.
|
|
- you want to have a layer that is isolated from the rest of the world (because torch.DTensor does not work well with `.view` for instance)
|
|
|
|
"""
|
|
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
|
|
tp_plan = model.tp_plan or {}
|
|
module_to_tp = model.get_submodule(param_name)
|
|
rank = int(rank)
|
|
current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
|
|
|
|
if dist.get_rank() == 0:
|
|
if current_shard_plan is None:
|
|
logger.info(f"Tensor sharding plan for {param_name} not found, using default 'replicate' plan.")
|
|
else:
|
|
logger.info(f"Tensor sharding plan for {param_name}: {current_shard_plan}")
|
|
|
|
if current_shard_plan is not None:
|
|
try:
|
|
tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
|
|
tp_layer.empty_param = empty_param
|
|
tp_layer.device_mesh = device_mesh
|
|
tp_layer.rank = rank
|
|
param = tp_layer.shard_tensor(param, tensor_idx=None, dtype=param_casting_dtype, device=rank)
|
|
if is_contiguous:
|
|
param = param.contiguous()
|
|
except NotImplementedError as e:
|
|
print(
|
|
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
|
|
)
|
|
else:
|
|
param = param[:].to(param_casting_dtype)
|
|
|
|
# SUPER IMPORTANT we have to use setattr
|
|
# otherwise loading is crazy slow
|
|
if not isinstance(param, torch.nn.Parameter):
|
|
param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point())
|
|
setattr(module_to_tp, param_type, param)
|
|
return param
|
|
|
|
|
|
def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
|
|
"""
|
|
Verify the TP plan of the model, log a warning if the layers that were not sharded and the rules that were not applied.
|
|
"""
|
|
|
|
if tp_plan is None:
|
|
return
|
|
|
|
generic_keys = {replace_layer_number_by_wildcard(key) for key in expected_keys}
|
|
unsharded_layers = set(generic_keys)
|
|
unused_rules = tp_plan.copy()
|
|
|
|
for key in generic_keys:
|
|
param_name = key.rsplit(".", 1)[0] if "." in key else key
|
|
generic_param_name = re.sub(r"\d+", "*", param_name)
|
|
|
|
if generic_param_name in tp_plan:
|
|
unused_rules.pop(generic_param_name, None)
|
|
unsharded_layers.discard(key)
|
|
elif "." in generic_param_name and (parent_param_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
|
|
unused_rules.pop(parent_param_name, None)
|
|
unsharded_layers.discard(key)
|
|
|
|
if len(unused_rules) > 0:
|
|
logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}")
|
|
if len(unsharded_layers) > 0:
|
|
logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")
|
|
|
|
|
|
def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size):
|
|
"""Distribute a model according to the TP plan."""
|
|
model._tp_size = tp_size
|
|
model._device_mesh = device_mesh
|
|
if distributed_config is not None:
|
|
if isinstance(distributed_config, dict):
|
|
distributed_config = DistributedConfig.from_dict(distributed_config)
|
|
model.config.distributed_config = distributed_config
|
|
# Set the new requested tp_plan on the model
|
|
if isinstance(tp_plan, dict):
|
|
model.tp_plan = tp_plan
|
|
model_plan = model.tp_plan
|
|
if model_plan is not None and _torch_distributed_available:
|
|
for v in model_plan.values():
|
|
if v not in ALL_PARALLEL_STYLES:
|
|
raise ValueError(f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}")
|
|
for name, module in model.named_modules():
|
|
if not getattr(module, "_is_hooked", False):
|
|
plan = _get_parameter_tp_plan(parameter_name=name, tp_plan=model_plan, is_weight=False)
|
|
add_tensor_parallel_hooks_to_module(
|
|
model=model,
|
|
module=module,
|
|
tp_plan=model_plan,
|
|
layer_name="",
|
|
current_module_plan=plan,
|
|
device_mesh=device_mesh,
|
|
)
|
|
module._is_hooked = True
|
|
return model
|