# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration with the `kernels` package (kernels hub).

Maps transformers layers/functions to optimized kernel implementations hosted on the
Hub, and provides graceful no-op fallbacks when `kernels` is not installed or when hub
kernel usage is disabled through the `USE_HUB_KERNELS` environment variable.
"""

import importlib.metadata
import os
import re
from collections.abc import Callable
from types import ModuleType

from packaging import version as pkg_version

from ..utils import ENV_VARS_TRUE_VALUES, logging
from ..utils.import_utils import is_kernels_available
from .flash_attention import flash_attention_forward


logger = logging.get_logger(__name__)

try:
    from kernels import (
        Device,
        LayerRepository,
        Mode,
        register_kernel_mapping,
        replace_kernel_forward_from_hub,
    )
    from kernels import (
        get_kernel as get_kernel_hub,
    )
    from kernels import (
        use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub,
    )

    # Try to import FuncRepository, fallback if not available (older `kernels` versions).
    try:
        from kernels import FuncRepository
    except ImportError:
        FuncRepository = None

    # Try to import use_kernel_func_from_hub, fallback if not available.
    try:
        from kernels import use_kernel_func_from_hub as _kernels_use_kernel_func_from_hub

        _has_use_kernel_func_from_hub = True
    except ImportError:
        _has_use_kernel_func_from_hub = False

    # Hub kernels are opt-out: enabled unless USE_HUB_KERNELS is set to a falsy value.
    _TRANSFORMERS_USE_HUB_KERNELS = os.environ.get("USE_HUB_KERNELS", "YES").upper()
    _kernels_available = True
    _kernels_enabled = _TRANSFORMERS_USE_HUB_KERNELS in ENV_VARS_TRUE_VALUES

    def use_kernel_forward_from_hub(layer_name: str):
        """Class decorator that swaps a layer's forward for the hub kernel `layer_name`.

        Returns the identity decorator (and warns once) when hub kernels are disabled.
        """
        if _kernels_enabled:
            return _kernels_use_kernel_forward_from_hub(layer_name)
        else:
            logger.warning_once(
                f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
            )
            return lambda cls: cls

    def use_kernel_func_from_hub(func_name: str):
        """Function decorator that swaps a function for the hub kernel `func_name`.

        Returns the identity decorator (and warns once) when hub kernels are disabled
        or when the installed `kernels` version does not support function kernels.
        """
        if _kernels_enabled and _has_use_kernel_func_from_hub:
            return _kernels_use_kernel_func_from_hub(func_name)
        else:
            if not _has_use_kernel_func_from_hub:
                logger.warning_once(
                    "use_kernel_func_from_hub is not available in the installed kernels version. "
                    "Please upgrade kernels to use this feature."
                )
            else:
                logger.warning_once(
                    f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
                )
            return lambda func: func

    # Default mapping from transformers layer names to hub kernel repositories, keyed by
    # device type and (optionally) by kernelize Mode.
    _KERNEL_MAPPING: dict[str, dict[Device | str, LayerRepository | dict[Mode, LayerRepository]]] = {
        "MultiScaleDeformableAttention": {
            "cuda": LayerRepository(
                repo_id="kernels-community/deformable-detr",
                layer_name="MultiScaleDeformableAttention",
            )
        },
        "Llama4TextMoe": {
            "cuda": LayerRepository(
                repo_id="kernels-community/moe",
                layer_name="Llama4TextMoe",
            )
        },
        "RMSNorm": {
            "cuda": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/liger_kernels",
                    layer_name="LigerRMSNorm",
                    # revision="pure-layer-test",
                ),
            },
            "rocm": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/liger_kernels",
                    layer_name="LigerRMSNorm",
                )
            },
            "xpu": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/rmsnorm",
                    layer_name="RMSNorm",
                )
            },
            "mps": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/mlx_rmsnorm",
                    layer_name="RMSNorm",
                )
            },
            "npu": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/liger_kernels",
                    layer_name="LigerRMSNorm",
                )
            },
        },
        "MLP": {
            "cuda": LayerRepository(
                repo_id="medmekk/triton-llama-mlp",
                layer_name="TritonLlamaMLP",
            )
        },
        "MegaBlocksMoeMLP": {
            "cuda": {
                Mode.TRAINING: LayerRepository(
                    repo_id="kernels-community/megablocks",
                    layer_name="MegaBlocksMoeMLP",
                ),
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/megablocks",
                    layer_name="MegaBlocksMoeMLP",
                ),
            },
            "rocm": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="ahadnagy/megablocks",
                    layer_name="MegaBlocksMoeMLP",
                )
            },
            "xpu": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/megablocks",
                    layer_name="MegaBlocksMoeMLP",
                )
            },
        },
        "FastGELU": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation",
                    layer_name="FastGELU",
                    version=1,
                )
            }
        },
        "QuickGELU": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation",
                    layer_name="QuickGELU",
                    version=1,
                )
            }
        },
        "NewGELU": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation",
                    layer_name="NewGELU",
                    version=1,
                )
            }
        },
        # NOTE: the hub-side layer names below ("Silu", "Gelu") intentionally differ in
        # casing from the transformers-side keys ("SiLU", "GeLU").
        "SiLU": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation", layer_name="Silu", version=1
                )
            }
        },
        "GeLU": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation", layer_name="Gelu", version=1
                )
            }
        },
        "GeluTanh": {
            "cuda": {
                Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                    repo_id="kernels-community/activation", layer_name="GeluTanh", version=1
                )
            }
        },
    }

    # Add function kernel mappings if FuncRepository is available
    if FuncRepository is not None:
        _KERNEL_MAPPING["rotary_pos_emb"] = {
            "xpu": {
                Mode.INFERENCE: FuncRepository(
                    repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
                )
            },
            "cuda": {
                Mode.INFERENCE: FuncRepository(
                    repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
                )
            },
        }

    def has_key(d, key):
        """Return True if `key` appears anywhere in the (possibly nested) dict `d`."""
        return key in d or any(isinstance(v, dict) and has_key(v, key) for v in d.values())

    def register_kernel_mapping_transformers(mapping=None):
        """Register `mapping` (default: `_KERNEL_MAPPING`) with the kernels package.

        Raises:
            ImportError: if the mapping targets "xpu" devices but the installed
                `kernels` version is older than 0.10.2 (xpu support requirement).
        """
        if mapping is None:
            mapping = _KERNEL_MAPPING
        if has_key(mapping, "xpu") and not is_kernels_available(MIN_VERSION="0.10.2"):
            raise ImportError(
                "kernels uses an incompatible version. Please install the latest version with `pip install -U kernels`."
            )
        register_kernel_mapping(mapping)

except ImportError:
    _kernels_available = False
    _kernels_enabled = False

    # Stubs to make decorators in transformers work when `kernels`
    # is not installed.
    def use_kernel_forward_from_hub(*args, **kwargs):
        def decorator(cls):
            return cls

        return decorator

    def use_kernel_func_from_hub(*args, **kwargs):
        def decorator(func):
            return func

        return decorator

    class LayerRepository:
        def __init__(self, *args, **kwargs):
            raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")

    def replace_kernel_forward_from_hub(*args, **kwargs):
        raise RuntimeError(
            "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
        )

    def register_kernel_mapping(*args, **kwargs):
        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")

    def register_kernel_mapping_transformers(*args, **kwargs):
        raise RuntimeError(
            "register_kernel_mapping_transformers requires `kernels` to be installed. Run `pip install kernels`."
        )


# Hub repositories for kernels that are lazily loaded by `lazy_load_kernel`.
_HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = {
    "causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"},
    "mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "revision": "v0.0.4"},
    "falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "revision": "v0.0.4"},
}

# Cache of loaded kernel modules; None marks a kernel that failed to load (negative cache).
_KERNEL_MODULE_MAPPING: dict[str, ModuleType | None] = {}


def is_kernel(attn_implementation: str | None) -> bool:
    """Check whether `attn_implementation` matches a kernel pattern from the hub.

    Accepted shape: "org/repo", optionally "@revision", optionally ":kernel_name",
    e.g. "kernels-community/flash-mla@main:flash_attn".
    """
    return (
        attn_implementation is not None
        and re.search(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", attn_implementation) is not None
    )


def load_and_register_attn_kernel(
    attn_implementation: str, attention_wrapper: Callable | None = None
) -> ModuleType | None:
    """
    Load and register the kernel associated to `attn_implementation`.

    Args:
        attn_implementation: A string, usually a kernel repo like "kernels-community/flash-mla".
        attention_wrapper: a callable for the wrapper around the attention implementation. In `transformers` we
            have a wrapper around the `flash_attn_var_len` call, and the same goes for `sdpa` and `eager`. They
            just prepare the arguments properly. This is mostly used for continuous batching, where we want the
            `paged` wrapper, which calls the paged cache.

    Returns:
        The loaded kernel module, or None if `attn_implementation` is not a hub kernel pattern.

    Raises:
        ImportError: if `kernels` is not installed.
        ValueError: if the kernel cannot be loaded, or no usable function can be resolved from it.
    """
    from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
    from ..modeling_utils import ALL_ATTENTION_FUNCTIONS

    # "paged|org/repo" style prefixes: the actual kernel reference is after the "|".
    actual_attn_name = attn_implementation.split("|")[1] if "|" in attn_implementation else attn_implementation
    if not is_kernel(actual_attn_name):
        return None
    if not _kernels_available:
        raise ImportError(
            "`kernels` is either not installed or uses an incompatible version. "
            "Please install the latest version with `pip install -U kernels`."
        )

    # Extract repo_id and kernel_name from the string
    if ":" in actual_attn_name:
        repo_id, kernel_name = actual_attn_name.split(":")
        kernel_name = kernel_name.strip()
    else:
        repo_id = actual_attn_name
        kernel_name = None
    repo_id = repo_id.strip()
    # extract the rev after the @ if it exists
    repo_id, _, rev = repo_id.partition("@")
    repo_id = repo_id.strip()
    rev = rev.strip() if rev else None

    # Load the kernel from hub
    try:
        kernel = get_kernel(repo_id, revision=rev)
    except Exception as e:
        raise ValueError(f"An error occurred while trying to load from '{repo_id}': {e}.") from e
    # correctly wrap the kernel
    if hasattr(kernel, "flash_attn_varlen_func"):
        if attention_wrapper is None:
            attention_wrapper = flash_attention_forward
        kernel_function = attention_wrapper
    elif kernel_name is not None:
        kernel_function = getattr(kernel, kernel_name)
    else:
        # Previously `kernel_function` was left unbound here, producing an opaque
        # UnboundLocalError at registration time; fail with a clear message instead.
        raise ValueError(
            f"Kernel '{repo_id}' does not provide `flash_attn_varlen_func` and no kernel function name was "
            f"given. Use the '<repo_id>:<kernel_name>' syntax to select the function to register."
        )

    # Register the kernel as a valid attention
    ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
    ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
    return kernel


def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _KERNEL_MODULE_MAPPING):
    """Load the hub kernel `kernel_name` on first use, caching the result in `mapping`.

    Falls back to a locally installed package of the same name (with "-" replaced by "_")
    when `kernels` is unavailable. Returns the module, or None when it cannot be loaded
    (the None is cached so the lookup is not retried).
    """
    if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType):
        return mapping[kernel_name]

    if kernel_name not in _HUB_KERNEL_MAPPING:
        logger.warning_once(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
        mapping[kernel_name] = None
        return None

    if _kernels_available:
        try:
            repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"]
            revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None)
            version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None)

            kernel = get_kernel(repo_id, revision=revision, version=version)
            mapping[kernel_name] = kernel
        except FileNotFoundError:
            mapping[kernel_name] = None
        except AssertionError:
            # Happens when torch is built without an accelerator backend; fall back to slow path.
            mapping[kernel_name] = None
    else:
        # No `kernels` package: probe for a local install via is_{kernel_name}_available.
        # The stdlib `importlib` is already imported at module level (via importlib.metadata);
        # previously this re-imported it through `..utils`, which is fragile.
        new_kernel_name = kernel_name.replace("-", "_")
        func_name = f"is_{new_kernel_name}_available"

        try:
            utils_mod = importlib.import_module("..utils.import_utils", __package__)
            is_kernel_available = getattr(utils_mod, func_name, None)
        except Exception:
            is_kernel_available = None

        if callable(is_kernel_available) and is_kernel_available():
            # Try to import the module "{kernel_name}" from parent package level
            try:
                module = importlib.import_module(f"{new_kernel_name}")
                mapping[kernel_name] = module
                return module
            except Exception:
                mapping[kernel_name] = None
        else:
            mapping[kernel_name] = None

    return mapping[kernel_name]


def get_kernel(kernel_name: str, revision: str | None = None, version: str | None = None) -> ModuleType:
    """Fetch a kernel from the hub, forwarding a transformers user-agent when supported.

    `version` and the user-agent are only forwarded on kernels >= 0.10.4; older versions
    receive only `revision`.

    Raises:
        ImportError: if `kernels` is not installed.
    """
    from .. import __version__

    user_agent = {"framework": "transformers", "version": __version__, "repo_id": kernel_name}
    if _kernels_available:
        kernels_version = importlib.metadata.version("kernels")
        if pkg_version.parse(kernels_version) >= pkg_version.parse("0.10.4"):
            return get_kernel_hub(kernel_name, revision=revision, version=version, user_agent=user_agent)
        else:
            # NOTE: older `kernels` cannot take `version`/`user_agent`; they are dropped here.
            return get_kernel_hub(kernel_name, revision=revision)
    else:
        raise ImportError("kernels is not installed, please install it with `pip install kernels`")


def use_kernelized_func(module_names: list[Callable] | Callable):
    """
    This decorator attaches the target function as an attribute of the module. The function must already be
    decorated with @use_kernel_func_from_hub this decorator then wraps it as an nn.Module internally. When
    kernelize is later applied to the full model, the function can be accessed as a regular module attribute
    and kernelized just like any other layer. The kernelization is performed in place, modifying the module
    directly.
    """
    if isinstance(module_names, Callable):
        module_names = [module_names]

    def decorator(cls):
        orig_init = cls.__init__

        def new_init(self, *args, **kwargs):
            orig_init(self, *args, **kwargs)
            for fn in module_names:
                # we hardcode the name of the function to "rotary_fn" for now
                setattr(self, "rotary_fn", fn)

        cls.__init__ = new_init
        return cls

    return decorator


__all__ = [
    "LayerRepository",
    "use_kernel_forward_from_hub",
    "use_kernel_func_from_hub",
    "register_kernel_mapping",
    "register_kernel_mapping_transformers",
    "replace_kernel_forward_from_hub",
    "lazy_load_kernel",
    "get_kernel",
    "use_kernelized_func",
]  # type: ignore