# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import PushToHubMixin, is_torch_available


# torch is an optional dependency: import it lazily so this module can still be
# imported (e.g. for config serialization) in environments without torch.
if is_torch_available():
    import torch
def infer_device(model):
    """
    Infer the accelerator device type from the model's parameters.

    Args:
        model: The model instance whose parameters are inspected.

    Returns:
        The device type string of the first parameter (e.g. ``"cpu"``,
        ``"xpu"``); ``"cuda"`` is refined to ``"rocm"`` when torch is a
        HIP/ROCm build.

    Raises:
        ValueError: If the model has no parameters, so no device can be
            inferred.
    """
    EXAMPLE_MAPPING = """
    {
        "RMSNorm": {
            "cuda":
                "kernels-community/layer_norm:LlamaRMSNorm",
            ...
        },
        ...
    }
    """
    try:
        param = next(model.parameters())
    except StopIteration:
        # `from None` suppresses the uninformative StopIteration chain in the
        # traceback; the ValueError message alone tells the user what to do.
        raise ValueError(
            f"Cannot determine model device, please provide a device to the mapping. Example: {EXAMPLE_MAPPING}"
        ) from None
    dev_type = param.device.type
    if dev_type == "cuda":
        # torch reports "cuda" for both NVIDIA and AMD GPUs; torch.version.hip
        # is only set on ROCm builds, so use it to refine the answer.
        if torch.version.hip is not None:
            return "rocm"
    return dev_type
def add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping):
    """Register a hub-hosted kernel layer for `layer_name` under `device` and `mode`.

    `repo_name` has the form 'org/repo:LayerName'; the entry is written into
    `compatible_mapping` in the nested layout the kernels library expects.
    """
    from kernels import LayerRepository

    if device not in ["cuda", "rocm", "xpu", "npu"]:
        raise ValueError(f"Only cuda, rocm, xpu and npu devices supported, got: {device}")
    pieces = repo_name.split(":")
    repository = LayerRepository(
        repo_id=pieces[0],
        layer_name=pieces[1],
    )
    compatible_mapping[layer_name] = {device: {mode: repository}}
def add_to_mapping_local(layer_name, device, repo_name, mode, compatible_mapping):
    """Register a locally checked-out kernel layer for `layer_name` under `device` and `mode`.

    `repo_name` has the form '/abs/path:LayerName'; the last path component is
    used as the package name for the local repository.
    """
    from pathlib import Path

    from kernels import LocalLayerRepository

    if device not in ["cuda", "rocm", "xpu", "npu"]:
        raise ValueError(f"Only cuda, rocm, xpu and npu devices supported, got: {device}")
    pieces = repo_name.split(":")
    local_path = pieces[0]
    repository = LocalLayerRepository(
        repo_path=Path(local_path),
        package_name=local_path.split("/")[-1],
        layer_name=pieces[1],
    )
    compatible_mapping[layer_name] = {device: {mode: repository}}
class KernelConfig(PushToHubMixin):
    """
    Kernel configuration class. This class is used to configure the kernel mapping for a model.
    """

    # Single source of truth for the device keys accepted in a kernel mapping.
    _SUPPORTED_DEVICES = ("cuda", "rocm", "xpu", "npu")

    def __init__(self, kernel_mapping=None, use_local_kernel=False):
        """
        Args:
            kernel_mapping: Optional dict mapping registered layer names to kernel
                repo specs. Defaults to a fresh empty dict per instance.
            use_local_kernel: If True, kernel specs are treated as local
                '/abs/path:LayerName' paths instead of hub repos.
        """
        # NOTE: a mutable default (`kernel_mapping={}`) would be shared across
        # every KernelConfig instance; use a None sentinel instead.
        self.kernel_mapping = {} if kernel_mapping is None else kernel_mapping
        # Filled by store_registered_layer_names: module path -> kernel_layer_name.
        self.registered_layer_names = {}
        self.use_local_kernel = use_local_kernel

    def update_kernel(self, repo_id, registered_name, layer_name, device, mode, revision=None):
        """Point `registered_name` at `repo_id`:`layer_name` for the given device and mode."""
        from kernels import LayerRepository

        self.kernel_mapping[registered_name] = {
            device: {
                mode: LayerRepository(
                    repo_id=repo_id,
                    layer_name=layer_name,
                    revision=revision,
                )
            }
        }

    def store_registered_layer_names(self, model):
        """Collect the `kernel_layer_name` of every registered module in `model`."""
        for name, module in model.named_modules():
            if hasattr(module, "kernel_layer_name"):
                self.registered_layer_names[name] = module.kernel_layer_name

    def sanitize_kernel_mapping(self, model):
        """
        Validates the kernel_mapping to ensure that:

        1. Each layer_name in the mapping is registered in the model (i.e., the model contains
           a module with a matching kernel_layer_name).
        2. Each kernel value is either a string of the form 'org/repo:layer_name' or a dict
           mapping device types ("cuda", "rocm", "xpu", "npu") to such strings.
        3. Each device key in a dict is one of "cuda", "rocm", "xpu", or "npu".
        4. Each repo_name is a valid repository and layer name in the format
           'org/repo:layer_name' (i.e., a string containing both a slash and a colon).
        5. If a local path is detected, it should be in the format '/abs/path:layer_name'.
           The absolute path must include the `package_name`, like "/home/user/layer_norm".

        Args:
            model: The model instance whose modules are checked for registered
                kernel_layer_name attributes.

        Raises:
            ValueError: If a layer_name is not registered in the model, if a device is not
                supported, or if a repo_name is not a valid 'org/repo:layer_name' string.
        """
        MAPPING_FORMAT = """
        For single device form remote
        {
            "RMSNorm":
                "kernels-community/layer_norm:LlamaRMSNorm",
            ...
        },
        For multiple devices form remote
        {
            "RMSNorm": {
                "cuda":
                    "kernels-community/layer_norm:LlamaRMSNorm",
                "rocm":
                    "kernels-community/layer_norm:LlamaRMSNorm",
                ...
            },
            ...
        }
        For single device form local
        {
            "RMSNorm":
                "/abs/path:LlamaRMSNorm",
            ...
        },
        For multiple devices form local
        {
            "RMSNorm": {
                "cuda":
                    "/abs/path:LlamaRMSNorm",
                "rocm":
                    "/abs/path:LlamaRMSNorm",
                ...
            },
            ...
        }
        """
        self.store_registered_layer_names(model)
        # Validate that the kernel mapping is a dict
        if not isinstance(self.kernel_mapping, dict):
            raise ValueError(
                f"Kernel mapping must be a dict of the following format: {MAPPING_FORMAT}, got: {type(self.kernel_mapping)}"
            )
        for layer_name, kernel in self.kernel_mapping.items():
            if layer_name not in self.registered_layer_names.values():
                raise ValueError(
                    f"Layer {layer_name} is not registered in the model, please register it first using register_kernel_forward_from_hub"
                )
            if isinstance(kernel, str):
                # Both hub ('org/repo:Layer') and local ('/abs/path:Layer') forms
                # contain a slash and a colon.
                if "/" not in kernel or ":" not in kernel:
                    raise ValueError(
                        f"Kernel mapping for '{layer_name}' must be a valid repo name with a layer name (e.g., 'org/repo:layer_name' or '/abs/path:layer_name'), got: {kernel}"
                    )
            elif isinstance(kernel, dict):
                for device, repo_name in kernel.items():
                    if device not in self._SUPPORTED_DEVICES:
                        raise ValueError(f"Only cuda, rocm, xpu and npu devices supported, got: {device}")
                    if not isinstance(repo_name, str) or "/" not in repo_name or ":" not in repo_name:
                        raise ValueError(
                            f"Kernel mapping for '{layer_name}' must be a valid repo name with a layer name (e.g., 'org/repo:layer_name' or '/abs/path:layer_name'), got: {repo_name}"
                        )
            else:
                raise ValueError(f"Kernel mapping must follow the format: {MAPPING_FORMAT}, got: {kernel}")

    def create_compatible_mapping(self, model, compile=False):
        """
        Transforms a simple kernel_mapping of the form:
        {
            "RMSNorm":
                "kernels-community/layer_norm:LlamaRMSNorm",
            ...
        },
        or for local path:
        {
            "RMSNorm":
                "/home/user/liger_kernels:LigerRMSNorm",
            ...
        },
        into a nested mapping:
        {
            "RMSNorm": {
                "cuda": {
                    Mode.INFERENCE: LayerRepository(
                        repo_id="kernels-community/layer_norm",
                        layer_name="LlamaRMSNorm",
                    )
                }
            }
        }
        or for local path:
        {
            "RMSNorm": {
                "cuda": {
                    Mode.INFERENCE: LocalLayerRepository(
                        repo_path=Path("/home/user/liger_kernels"),
                        package_name="liger_kernels",
                        layer_name="LigerRMSNorm",
                    )
                }
            }
        }
        that's compatible with the kernels library.

        The device is inferred from the model's parameters if not provided.
        The Mode is inferred from the model's training state.
        """
        from kernels import Mode

        compatible_mapping = {}
        current_device = infer_device(model)
        # Mode depends only on the model's training state and the compile flag,
        # not on any individual layer, so compute it once outside the loop.
        mode = Mode.TRAINING if model.training else Mode.INFERENCE
        if compile:
            mode = mode | Mode.TORCH_COMPILE
        add = add_to_mapping_local if self.use_local_kernel else add_to_mapping
        for layer_name, kernel in self.kernel_mapping.items():
            if isinstance(kernel, str):
                add(layer_name, current_device, kernel, mode, compatible_mapping)
            elif isinstance(kernel, dict):
                for device, repo_name in kernel.items():
                    # Only materialize the entry for the device the model is on.
                    if device == current_device:
                        add(layer_name, device, repo_name, mode, compatible_mapping)
        self.kernel_mapping = compatible_mapping