from __future__ import annotations

import torch
import torch.distributed as dist
import torch.distributed.nn  # not pulled in by `import torch`; needed for the differentiable all_gather
from transformers.utils import logging

# NOTE: transformers wraps the regular logging module for e.g. warning_once
logger = logging.get_logger(__name__)


def all_gather(tensor: torch.Tensor, with_grad: bool = False) -> torch.Tensor:
    """
    Gathers a tensor from every distributed rank and concatenates the results along the first dimension.
    The local rank's tensor always retains its gradients; if `with_grad` is True, gradients are retained
    for the tensors gathered from all ranks as well.

    Args:
        tensor (torch.Tensor): The tensor to gather from each rank.
        with_grad (bool, optional): If True, the gather is autograd-aware and gradients flow through the
            tensors gathered from every rank, not just the local one. Defaults to False.

    Returns:
        torch.Tensor: A tensor containing the gathered tensors from all ranks, concatenated along the first
            dimension. If torch.distributed is not available or not initialized, the original tensor is
            returned instead.
    """
    if dist.is_available() and dist.is_initialized():
        if with_grad:
            # Differentiable gather: gradients flow back to the inputs on every rank.
            gathered_tensors = torch.distributed.nn.all_gather(tensor)
        else:
            world_size = dist.get_world_size()
            gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]

            # Perform all_gather.
            dist.all_gather(gathered_tensors, tensor)

            # Replace local rank's tensor with the original (retaining gradients).
            local_rank = dist.get_rank()
            gathered_tensors[local_rank] = tensor
        return torch.cat(gathered_tensors, dim=0)

    # Warn once about uninitialized or single-GPU usage.
    warning = (
        "Trying to gather while torch.distributed is not available or has not been initialized, "
        "returning the original (local) tensor. This is expected if you are "
        "only using one GPU; consider not using gathering to remove this warning."
    )
    logger.warning_once(warning)
    return tensor


def all_gather_with_grad(tensor: torch.Tensor) -> torch.Tensor:
    """
    Gathers a tensor from every distributed rank and concatenates the results along the first dimension,
    retaining gradients for the tensors gathered from all ranks. Equivalent to `all_gather(tensor, with_grad=True)`.

    Args:
        tensor (torch.Tensor): The tensor to gather from each rank.

    Returns:
        torch.Tensor: A tensor containing the gathered tensors from all ranks, concatenated along the first
            dimension. If torch.distributed is not available or not initialized, the original tensor is
            returned instead.
    """
    return all_gather(tensor, with_grad=True)
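

# ---------------------------------------------------------------------------
# Usage sketch (an illustration, not part of the original module): how these
# helpers are typically called. Run as a plain script, torch.distributed is
# not initialized, so both calls fall back to returning the local tensor and
# log a one-time warning; under `torchrun --nproc_per_node=N` you would first
# call dist.init_process_group(...) and the result would have N times the rows.
if __name__ == "__main__":
    x = torch.randn(4, 8, requires_grad=True)

    gathered = all_gather(x)                 # gradients only for the local slice
    gathered_grad = all_gather_with_grad(x)  # gradients flow through all slices

    # Backprop through the gathered-with-grad result reaches the local input.
    gathered_grad.sum().backward()
    print(gathered.shape, gathered_grad.shape, x.grad.shape)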