You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

62 lines
2.4 KiB

from __future__ import annotations
import torch
import torch.distributed as dist
from transformers.utils import logging
# NOTE: Use the logging wrapper from transformers instead of the standard-library `logging` module:
# it provides extras such as `warning_once`, which this module relies on.
logger = logging.get_logger(__name__)
def all_gather(tensor: torch.Tensor, with_grad: bool = False) -> torch.Tensor:
    """
    Gathers a tensor from every distributed rank and concatenates the results along the first dimension.

    The local rank's slice of the result always stays connected to `tensor` in the autograd graph.
    Whether the slices received from the other ranks carry gradients is controlled by `with_grad`.

    Args:
        tensor (torch.Tensor): The tensor to gather from each rank. When `with_grad` is False the receive
            buffers are allocated with `torch.zeros_like(tensor)`, i.e. with the local shape and dtype.
        with_grad (bool, optional): If True, gather with `torch.distributed.nn.all_gather` so that the
            tensors received from the other ranks retain gradients as well. If False, gather with
            `torch.distributed.all_gather` and only re-insert the local tensor to retain its gradients.
            Defaults to False.

    Returns:
        torch.Tensor: A tensor containing the gathered tensors from all ranks, concatenated along the first
            dimension. If torch.distributed is not available or not initialized, returns the original tensor.
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Warn once about uninitialized or single-GPU usage.
        warning = (
            "Trying to gather while torch.distributed is not available or has not been initialized, "
            "returning the original (local) tensor. This is expected if you are "
            "only using one GPU; consider not using gathering to remove this warning."
        )
        logger.warning_once(warning)
        return tensor

    if with_grad:
        # `torch.distributed.nn` is a submodule that `import torch.distributed` does not necessarily load,
        # so import it explicitly instead of relying on attribute access. A `from` import is used on
        # purpose: `import torch.distributed.nn` would rebind `torch` as a local name in this function.
        from torch.distributed.nn import all_gather as differentiable_all_gather

        gathered_tensors = differentiable_all_gather(tensor)
    else:
        # One receive buffer per rank.
        gathered_tensors = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]

        # Perform all_gather.
        dist.all_gather(gathered_tensors, tensor)

        # Replace local rank's tensor with the original (retaining gradients).
        gathered_tensors[dist.get_rank()] = tensor

    return torch.cat(gathered_tensors, dim=0)
def all_gather_with_grad(tensor: torch.Tensor) -> torch.Tensor:
    """
    Gathers a tensor from every distributed rank with gradient retention enabled.

    Convenience wrapper that forwards to `all_gather` with `with_grad=True`.

    Args:
        tensor (torch.Tensor): The tensor to gather from each rank.

    Returns:
        torch.Tensor: Whatever `all_gather(tensor, with_grad=True)` returns: the gathered tensors from all
            ranks concatenated along the first dimension, or the original tensor if torch.distributed is
            not available or not initialized.
    """
    gathered = all_gather(tensor, with_grad=True)
    return gathered