# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata
import os
import re
from collections.abc import Callable
from types import ModuleType
from packaging import version as pkg_version
from ..utils import ENV_VARS_TRUE_VALUES, logging
from ..utils.import_utils import is_kernels_available
from .flash_attention import flash_attention_forward
logger = logging.get_logger(__name__)
try:
    # Core `kernels` API. These imports succeed only when the `kernels` package
    # is installed; the matching `except ImportError` below installs stubs so
    # the rest of transformers can still import this module.
    from kernels import (
        Device,
        LayerRepository,
        Mode,
        register_kernel_mapping,
        replace_kernel_forward_from_hub,
    )
    from kernels import (
        get_kernel as get_kernel_hub,
    )
    from kernels import (
        use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub,
    )

    # Try to import FuncRepository, fallback if not available (older `kernels` releases).
    try:
        from kernels import FuncRepository
    except ImportError:
        FuncRepository = None

    # Try to import use_kernel_func_from_hub, fallback if not available.
    try:
        from kernels import use_kernel_func_from_hub as _kernels_use_kernel_func_from_hub

        _has_use_kernel_func_from_hub = True
    except ImportError:
        _has_use_kernel_func_from_hub = False

    # Hub kernels are opt-out: enabled by default ("YES") unless the
    # USE_HUB_KERNELS env var is set to a value outside ENV_VARS_TRUE_VALUES.
    _TRANSFORMERS_USE_HUB_KERNELS = os.environ.get("USE_HUB_KERNELS", "YES").upper()
    _kernels_available = True
    _kernels_enabled = _TRANSFORMERS_USE_HUB_KERNELS in ENV_VARS_TRUE_VALUES
def use_kernel_forward_from_hub(layer_name: str):
if _kernels_enabled:
return _kernels_use_kernel_forward_from_hub(layer_name)
else:
logger.warning_once(
f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
)
return lambda cls: cls
def use_kernel_func_from_hub(func_name: str):
if _kernels_enabled and _has_use_kernel_func_from_hub:
return _kernels_use_kernel_func_from_hub(func_name)
else:
if not _has_use_kernel_func_from_hub:
logger.warning_once(
"use_kernel_func_from_hub is not available in the installed kernels version. "
"Please upgrade kernels to use this feature."
)
else:
logger.warning_once(
f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
)
return lambda func: func
    # Default mapping from transformers layer names to kernel-hub repositories.
    # Structure: {layer_name: {device: spec}} where `spec` is either a plain
    # `LayerRepository` (applies to every mode) or a `{Mode: LayerRepository}`
    # dict restricting the kernel to specific modes (inference/training/compile).
    _KERNEL_MAPPING: dict[str, dict[Device | str, LayerRepository | dict[Mode, LayerRepository]]] = {
        "MultiScaleDeformableAttention": {
            "cuda": LayerRepository(
                repo_id="kernels-community/deformable-detr",
                layer_name="MultiScaleDeformableAttention",
            )
        },
        "Llama4TextMoe": {
            "cuda": LayerRepository(
                repo_id="kernels-community/moe",
                layer_name="Llama4TextMoe",
            )
        },
        "RMSNorm": {
            "cuda": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/liger_kernels",
                    layer_name="LigerRMSNorm",
                    # revision="pure-layer-test",
                ),
            },
            "rocm": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/liger_kernels",
                    layer_name="LigerRMSNorm",
                )
            },
            "xpu": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/rmsnorm",
                    layer_name="RMSNorm",
                )
            },
            "mps": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/mlx_rmsnorm",
                    layer_name="RMSNorm",
                )
            },
            "npu": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/liger_kernels",
                    layer_name="LigerRMSNorm",
                )
            },
        },
        "MLP": {
            "cuda": LayerRepository(
                repo_id="medmekk/triton-llama-mlp",
                layer_name="TritonLlamaMLP",
            )
        },
        "MegaBlocksMoeMLP": {
            "cuda": {
                Mode.TRAINING: LayerRepository(
                    repo_id="kernels-community/megablocks",
                    layer_name="MegaBlocksMoeMLP",
                ),
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/megablocks",
                    layer_name="MegaBlocksMoeMLP",
                ),
            },
            "rocm": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="ahadnagy/megablocks",
                    layer_name="MegaBlocksMoeMLP",
                )
            },
            "xpu": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/megablocks",
                    layer_name="MegaBlocksMoeMLP",
                )
            },
        },
        # Activation kernels below are also enabled under torch.compile
        # (Mode.INFERENCE | Mode.TORCH_COMPILE is a combined-mode flag).
        "FastGELU": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation",
                    layer_name="FastGELU",
                    version=1,
                )
            }
        },
        "QuickGELU": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation",
                    layer_name="QuickGELU",
                    version=1,
                )
            }
        },
        "NewGELU": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation",
                    layer_name="NewGELU",
                    version=1,
                )
            }
        },
        "SiLU": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation", layer_name="Silu", version=1
                )
            }
        },
        "GeLU": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation", layer_name="Gelu", version=1
                )
            }
        },
        "GeluTanh": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation", layer_name="GeluTanh", version=1
                )
            }
        },
    }

    # Function (as opposed to layer) kernels require `FuncRepository`, which only
    # exists in newer `kernels` releases; register them only when it is available.
    if FuncRepository is not None:
        _KERNEL_MAPPING["rotary_pos_emb"] = {
            "xpu": {
                Mode.INFERENCE: FuncRepository(
                    repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
                )
            },
            "cuda": {
                Mode.INFERENCE: FuncRepository(
                    repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
                )
            },
        }
def has_key(d, key):
return key in d or any(isinstance(v, dict) and has_key(v, key) for v in d.values())
def register_kernel_mapping_transformers(mapping=None):
if mapping is None:
mapping = _KERNEL_MAPPING
if has_key(mapping, "xpu") and not is_kernels_available(MIN_VERSION="0.10.2"):
raise ImportError(
"kernels uses an incompatible version. Please install the latest version with `pip install -U kernels`."
)
register_kernel_mapping(mapping)
except ImportError:
    # `kernels` is not installed: disable the feature and install stubs so the
    # decorators in transformers still work (as no-ops) and the other entry
    # points fail with an actionable error instead of a NameError.
    _kernels_available = False
    _kernels_enabled = False

    # Stub to make decorators in transformers work when `kernels`
    # is not installed.
    def use_kernel_forward_from_hub(*args, **kwargs):
        def decorator(cls):
            return cls

        return decorator

    def use_kernel_func_from_hub(*args, **kwargs):
        def decorator(func):
            return func

        return decorator

    class LayerRepository:
        # Placeholder that fails loudly on instantiation.
        def __init__(self, *args, **kwargs):
            raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")

    def replace_kernel_forward_from_hub(*args, **kwargs):
        raise RuntimeError(
            "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
        )

    def register_kernel_mapping(*args, **kwargs):
        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")

    def register_kernel_mapping_transformers(*args, **kwargs):
        raise RuntimeError(
            "register_kernel_mapping_transformers requires `kernels` to be installed. Run `pip install kernels`."
        )
# Mapping from short kernel names to their hub repositories (and pinned
# revisions), consumed by `lazy_load_kernel`.
_HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = {
    "causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"},
    "mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "revision": "v0.0.4"},
    "falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "revision": "v0.0.4"},
}
# Shared cache of already-loaded kernel modules; `None` marks a failed attempt.
_KERNEL_MODULE_MAPPING: dict[str, ModuleType | None] = {}
def is_kernel(attn_implementation: str | None) -> bool:
"""Check whether `attn_implementation` matches a kernel pattern from the hub."""
return (
attn_implementation is not None
and re.search(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", attn_implementation) is not None
)
def load_and_register_attn_kernel(
    attn_implementation: str, attention_wrapper: Callable | None = None
) -> ModuleType | None:
    """
    Load and register the kernel associated to `attn_implementation`.

    Args:
        attn_implementation: A string, usually a kernel repo like "kernels-community/flash-mla".
            It may carry an optional revision ("repo@rev") and/or an explicit attention function
            name ("repo:func_name").
        attention_wrapper: a callable for the wrapper around the attention implementation. In
            `transformers` we have a wrapper around the `flash_attn_var_len` call, and the same
            goes for `sdpa` and `eager`. They just prepare the arguments properly. This is mostly
            used for continuous batching, where we want the `paged` wrapper, which calls the
            paged cache.

    Returns:
        The loaded kernel module, or `None` when `attn_implementation` is not a hub kernel
        pattern.

    Raises:
        ImportError: if `kernels` is not installed (or incompatible).
        ValueError: if the kernel cannot be loaded from the hub, or if no attention function
            can be resolved from the loaded kernel.
    """
    from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
    from ..modeling_utils import ALL_ATTENTION_FUNCTIONS

    # "wrapper|org/repo"-style strings carry a wrapper prefix before the pipe.
    actual_attn_name = attn_implementation.split("|")[1] if "|" in attn_implementation else attn_implementation
    if not is_kernel(actual_attn_name):
        return None
    if not _kernels_available:
        raise ImportError(
            "`kernels` is either not installed or uses an incompatible version. "
            "Please install the latest version with `pip install -U kernels`."
        )

    # Extract repo_id and kernel_name from the string
    if ":" in actual_attn_name:
        repo_id, kernel_name = actual_attn_name.split(":")
        kernel_name = kernel_name.strip()
    else:
        repo_id = actual_attn_name
        kernel_name = None
    repo_id = repo_id.strip()

    # extract the rev after the @ if it exists
    repo_id, _, rev = repo_id.partition("@")
    repo_id = repo_id.strip()
    rev = rev.strip() if rev else None

    # Load the kernel from hub
    try:
        kernel = get_kernel(repo_id, revision=rev)
    except Exception as e:
        raise ValueError(f"An error occurred while trying to load from '{repo_id}': {e}.") from e

    # correctly wrap the kernel
    if hasattr(kernel, "flash_attn_varlen_func"):
        if attention_wrapper is None:
            attention_wrapper = flash_attention_forward
        kernel_function = attention_wrapper
    elif kernel_name is not None:
        kernel_function = getattr(kernel, kernel_name)
    else:
        # Previously this path fell through and raised an opaque NameError on the
        # registration line below; fail with an actionable message instead.
        raise ValueError(
            f"Kernel '{repo_id}' does not provide `flash_attn_varlen_func` and no attention "
            f"function name was given. Use the '<repo_id>:<function_name>' syntax to select one."
        )

    # Register the kernel as a valid attention
    ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
    ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
    return kernel
def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _KERNEL_MODULE_MAPPING):
    """Load the hub kernel registered as `kernel_name`, caching the loaded module.

    The default `mapping` is the shared module-level cache (an intentional mutable
    default): a successfully loaded module is returned directly on later calls,
    while a cached `None` records a failed or unknown kernel. When `kernels` is
    not installed, falls back to a locally installed package of the same
    (underscored) name. Returns the kernel module or `None`.
    """
    # Cache hit: only an actual module short-circuits; a cached `None` falls through.
    if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType):
        return mapping[kernel_name]

    if kernel_name not in _HUB_KERNEL_MAPPING:
        logger.warning_once(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
        mapping[kernel_name] = None
        return None

    if _kernels_available:
        try:
            repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"]
            revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None)
            version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None)
            kernel = get_kernel(repo_id, revision=revision, version=version)
            mapping[kernel_name] = kernel
        except FileNotFoundError:
            # Repo/revision missing on the hub; record the failure.
            mapping[kernel_name] = None
        except AssertionError:
            # Happens when torch is built without an accelerator backend; fall back to slow path.
            mapping[kernel_name] = None
    else:
        # Try to import is_{kernel_name}_available from ..utils to detect a
        # locally installed package as a substitute for the hub kernel.
        import importlib

        new_kernel_name = kernel_name.replace("-", "_")
        func_name = f"is_{new_kernel_name}_available"
        try:
            utils_mod = importlib.import_module("..utils.import_utils", __package__)
            is_kernel_available = getattr(utils_mod, func_name, None)
        except Exception:
            is_kernel_available = None

        if callable(is_kernel_available) and is_kernel_available():
            # Import the package by its module name (absolute import, e.g. "mamba_ssm").
            try:
                module = importlib.import_module(f"{new_kernel_name}")
                mapping[kernel_name] = module
                return module
            except Exception:
                mapping[kernel_name] = None
        else:
            mapping[kernel_name] = None

    return mapping[kernel_name]
def get_kernel(kernel_name: str, revision: str | None = None, version: str | None = None) -> ModuleType:
    """Fetch a kernel module from the hub, tagging requests with a transformers user agent.

    Older `kernels` releases (< 0.10.4) do not accept `version`/`user_agent`,
    so those arguments are only forwarded when supported.
    """
    from .. import __version__

    if not _kernels_available:
        raise ImportError("kernels is not installed, please install it with `pip install kernels`")

    user_agent = {"framework": "transformers", "version": __version__, "repo_id": kernel_name}
    installed = pkg_version.parse(importlib.metadata.version("kernels"))
    if installed >= pkg_version.parse("0.10.4"):
        return get_kernel_hub(kernel_name, revision=revision, version=version, user_agent=user_agent)
    return get_kernel_hub(kernel_name, revision=revision)
def use_kernelized_func(module_names: list[Callable] | Callable):
"""
This decorator attaches the target function as an attribute of the module.
The function must already be decorated with @use_kernel_func_from_hub
this decorator then wraps it as an nn.Module internally.
When kernelize is later applied to the full model, the function can be accessed as a regular module attribute and kernelized just like any other layer.
The kernelization is performed in place, modifying the module directly.
"""
if isinstance(module_names, Callable):
module_names = [module_names]
def decorator(cls):
orig_init = cls.__init__
def new_init(self, *args, **kwargs):
orig_init(self, *args, **kwargs)
for fn in module_names:
# we hardcode the name of the function to "rotary_fn" for now
setattr(self, "rotary_fn", fn)
cls.__init__ = new_init
return cls
return decorator
# Public API of this module. The names below are exported in both the
# `kernels`-installed and stub (ImportError) branches above, so imports of
# this module succeed either way.
__all__ = [
    "LayerRepository",
    "use_kernel_forward_from_hub",
    "use_kernel_func_from_hub",
    "register_kernel_mapping",
    "register_kernel_mapping_transformers",
    "replace_kernel_forward_from_hub",
    "lazy_load_kernel",
    "get_kernel",
    "use_kernelized_func",
]  # type: ignore