"""Runtime detection helpers for JAX and Pallas availability.

Each check is cached with ``functools.cache`` since package/backend
availability does not change within a process.
"""
import functools
import torch
@functools.cache
def has_jax_package() -> bool:
    """Return True if the JAX package can be imported."""
    try:
        import jax  # noqa: F401 # type: ignore[import-not-found]
    except ImportError:
        return False
    return True
@functools.cache
def has_pallas_package() -> bool:
    """Return True if Pallas (jax.experimental.pallas) can be imported."""
    # Without JAX itself there is no point attempting the Pallas import.
    if not has_jax_package():
        return False
    try:
        from jax.experimental import (  # noqa: F401 # type: ignore[import-not-found]
            pallas as pl,
        )
    except ImportError:
        return False
    return True
@functools.cache
def get_jax_version(fallback: tuple[int, int, int] = (0, 0, 0)) -> tuple[int, int, int]:
    """Get JAX version as a (major, minor, patch) tuple.

    Version strings with fewer than three numeric components
    (e.g. ``"0.5"``) are padded with zeros rather than falling back,
    which the previous strict three-way unpack did not handle.

    Args:
        fallback: Tuple returned when JAX is not installed or the
            leading version components are not plain integers
            (e.g. ``"0.4.20.dev1"`` in the first three slots).

    Returns:
        The (major, minor, patch) version tuple, or ``fallback``.
    """
    try:
        import jax  # type: ignore[import-not-found]

        parts = jax.__version__.split(".")[:3]
        # Pad short versions ("0.5" -> (0, 5, 0)); non-numeric parts
        # still raise ValueError and hit the fallback below.
        nums = [int(p) for p in parts] + [0] * (3 - len(parts))
        return (nums[0], nums[1], nums[2])
    except (ImportError, ValueError, AttributeError):
        return fallback
@functools.cache
def has_jax_cuda_backend() -> bool:
    """Return True if JAX reports at least one CUDA (GPU) device."""
    if not has_jax_package():
        return False
    try:
        import jax  # type: ignore[import-not-found]

        # jax.devices raises when the requested backend is unavailable;
        # any failure here means no usable CUDA backend.
        return bool(jax.devices("gpu"))
    except Exception:
        return False
@functools.cache
def has_jax_tpu_backend() -> bool:
    """Return True if JAX reports at least one TPU device."""
    if not has_jax_package():
        return False
    try:
        import jax  # type: ignore[import-not-found]

        # Any failure (missing backend, init error) counts as "no TPU".
        return bool(jax.devices("tpu"))
    except Exception:
        return False
@functools.cache
def has_cpu_pallas() -> bool:
    """Checks for a full Pallas-on-CPU environment."""
    # CPU needs no extra backend beyond the Pallas package itself.
    return has_pallas_package()
@functools.cache
def has_cuda_pallas() -> bool:
    """Checks for a full Pallas-on-CUDA environment."""
    # Both PyTorch and JAX must see a CUDA device, on top of Pallas
    # being importable. Short-circuit order matches the original.
    if not has_pallas_package():
        return False
    if not torch.cuda.is_available():
        return False
    return has_jax_cuda_backend()
@functools.cache
def has_tpu_pallas() -> bool:
    """Checks for a full Pallas-on-TPU environment."""
    if not has_pallas_package():
        return False
    return has_jax_tpu_backend()
@functools.cache
def has_pallas() -> bool:
    """
    Check if Pallas backend is fully available for use.

    Requirements:
    - JAX package installed
    - Pallas (jax.experimental.pallas) available

    A CPU environment qualifies with the Pallas package alone; CUDA
    additionally requires a GPU visible to both PyTorch and JAX, and
    TPU requires a JAX TPU backend.

    (Previous docstring claimed a CUDA/TPU backend was mandatory, which
    contradicted the `has_cpu_pallas()` branch below.)
    """
    return has_cpu_pallas() or has_cuda_pallas() or has_tpu_pallas()