You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
77 lines
2.3 KiB
77 lines
2.3 KiB
from __future__ import annotations

import importlib
import importlib.util
import logging
import os
from importlib.metadata import PackageNotFoundError, metadata

import torch

from transformers import is_torch_npu_available
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def get_device_name() -> str:
    """
    Returns the name of the device where this module is running on.

    This function only supports single device or basic distributed training setups.
    In distributed mode for cuda device, it uses the rank to assign a specific CUDA device.

    Returns:
        str: Device name, like 'cuda:2', 'mps', 'npu', 'xpu', 'hpu', or 'cpu'
    """
    if torch.cuda.is_available():
        if "LOCAL_RANK" in os.environ:
            # Launchers such as torchrun/accelerate export LOCAL_RANK; trust it first.
            local_rank = int(os.environ["LOCAL_RANK"])
        elif torch.distributed.is_initialized() and torch.cuda.device_count() > torch.distributed.get_rank():
            # Fall back to the global rank, but only when it maps to an existing device
            # (global rank can exceed the local device count in multi-node setups).
            local_rank = torch.distributed.get_rank()
        else:
            local_rank = 0
        return f"cuda:{local_rank}"
    elif torch.backends.mps.is_available():
        return "mps"
    elif is_torch_npu_available():
        return "npu"
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        # hasattr guard: older torch builds do not ship the xpu backend at all.
        return "xpu"
    elif importlib.util.find_spec("habana_frameworks") is not None:
        # Imported lazily: habana_frameworks only exists on Gaudi (HPU) machines.
        # NOTE: requires `import importlib.util` at the top of the file; plain
        # `import importlib` does not guarantee the `util` submodule is bound.
        import habana_frameworks.torch.hpu as hthpu

        if hthpu.is_available():
            return "hpu"
    return "cpu"
|
|
|
|
|
|
def check_package_availability(package_name: str, owner: str) -> bool:
    """
    Checks if a package is available from the correct owner.

    Args:
        package_name: Distribution name to look up, e.g. "accelerate".
        owner: Substring expected in the installed package's Home-page URL,
            e.g. "huggingface".

    Returns:
        bool: True if the package is installed under exactly `package_name`
        and its Home-page metadata mentions `owner`; False otherwise
        (including when the package is not installed at all).
    """
    try:
        meta = metadata(package_name)
        # "Home-page" may be missing (None) for packages that only declare
        # Project-URL entries; treat that as "owner not found" instead of
        # raising TypeError on `owner in None`.
        return meta["Name"] == package_name and owner in (meta["Home-page"] or "")
    except PackageNotFoundError:
        return False
|
|
|
|
|
|
def is_accelerate_available() -> bool:
    """
    Returns True if the Huggingface accelerate library is available.
    """
    available = check_package_availability("accelerate", "huggingface")
    return available
|
|
|
|
|
|
def is_datasets_available() -> bool:
    """
    Returns True if the Huggingface datasets library is available.
    """
    available = check_package_availability("datasets", "huggingface")
    return available
|
|
|
|
|
|
def is_training_available() -> bool:
    """
    Returns True if we have the required dependencies for training Sentence
    Transformers models, i.e. Huggingface datasets and Huggingface accelerate.
    """
    # Guard clause keeps the original short-circuit: `datasets` is only
    # checked when `accelerate` is present.
    if not is_accelerate_available():
        return False
    return is_datasets_available()
|