"""Utility functions to use Python Array API compatible libraries. For the context about the Array API see: https://data-apis.org/array-api/latest/purpose_and_scope.html The SciPy use case of the Array API is described on the following page: https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy """ import operator import dataclasses import functools import textwrap from collections.abc import Generator from contextlib import contextmanager from contextvars import ContextVar from types import ModuleType from typing import Any, Literal, TypeAlias from collections.abc import Iterable import numpy as np import numpy.typing as npt from scipy._lib.array_api_compat import ( is_array_api_obj, is_lazy_array, is_numpy_array, is_cupy_array, is_torch_array, is_jax_array, is_dask_array, size as xp_size, numpy as np_compat, device as xp_device, is_numpy_namespace as is_numpy, is_cupy_namespace as is_cupy, is_torch_namespace as is_torch, is_jax_namespace as is_jax, is_dask_namespace as is_dask, is_array_api_strict_namespace as is_array_api_strict, ) from scipy._lib.array_api_compat.common._helpers import _compat_module_name from scipy._lib.array_api_extra.testing import lazy_xp_function from scipy._lib._array_api_override import ( array_namespace, SCIPY_ARRAY_API, SCIPY_DEVICE ) from scipy._lib._docscrape import FunctionDoc from scipy._lib import array_api_extra as xpx __all__ = [ '_asarray', 'array_namespace', 'assert_almost_equal', 'assert_array_almost_equal', 'default_xp', 'eager_warns', 'is_lazy_array', 'is_marray', 'is_array_api_strict', 'is_complex', 'is_cupy', 'is_jax', 'is_numpy', 'is_torch', 'np_compat', 'get_native_namespace_name', 'SCIPY_ARRAY_API', 'SCIPY_DEVICE', 'scipy_namespace_for', 'xp_assert_close', 'xp_assert_equal', 'xp_assert_less', 'xp_copy', 'xp_device', 'xp_ravel', 'xp_size', 'xp_unsupported_param_msg', 'xp_vector_norm', 'xp_capabilities', 'xp_result_type', 'xp_promote', 'make_xp_test_case', 'make_xp_pytest_marks', 'make_xp_pytest_param', ] Array: TypeAlias = Any # To be changed to a Protocol later (see array-api#589) ArrayLike: TypeAlias = Array | npt.ArrayLike def _check_finite(array: Array, xp: ModuleType) -> None: """Check for NaNs or Infs.""" if not xp.all(xp.isfinite(array)): msg = "array must not contain infs or NaNs" raise ValueError(msg) def _asarray( array: ArrayLike, dtype: Any = None, order: Literal['K', 'A', 'C', 'F'] | None = None, copy: bool | None = None, *, xp: ModuleType | None = None, check_finite: bool = False, subok: bool = False, ) -> Array: """SciPy-specific replacement for `np.asarray` with `order`, `check_finite`, and `subok`. Memory layout parameter `order` is not exposed in the Array API standard. `order` is only enforced if the input array implementation is NumPy based, otherwise `order` is just silently ignored. `check_finite` is also not a keyword in the array API standard; included here for convenience rather than that having to be a separate function call inside SciPy functions. `subok` is included to allow this function to preserve the behaviour of `np.asanyarray` for NumPy based inputs. """ if xp is None: xp = array_namespace(array) if is_numpy(xp): # Use NumPy API to support order if copy is True: array = np.array(array, order=order, dtype=dtype, subok=subok) elif subok: array = np.asanyarray(array, order=order, dtype=dtype) else: array = np.asarray(array, order=order, dtype=dtype) else: try: array = xp.asarray(array, dtype=dtype, copy=copy) except TypeError: coerced_xp = array_namespace(xp.asarray(3)) array = coerced_xp.asarray(array, dtype=dtype, copy=copy) if check_finite: _check_finite(array, xp) return array def xp_copy(x: Array, *, xp: ModuleType | None = None) -> Array: """ Copies an array. Parameters ---------- x : array xp : array_namespace Returns ------- copy : array Copied array Notes ----- This copy function does not offer all the semantics of `np.copy`, i.e. the `subok` and `order` keywords are not used. """ # Note: for older NumPy versions, `np.asarray` did not support the `copy` kwarg, # so this uses our other helper `_asarray`. if xp is None: xp = array_namespace(x) return _asarray(x, copy=True, xp=xp) def _xp_copy_to_numpy(x: Array) -> np.ndarray: """Copies a possibly on device array to a NumPy array. This function is intended only for converting alternative backend arrays to numpy arrays within test code, to make it easier for use of the alternative backend to be isolated only to the function being tested. `_xp_copy_to_numpy` should NEVER be used except in test code for the specific purpose mentioned above. In production code, attempts to copy device arrays to NumPy arrays should fail, or else functions may appear to be working on the GPU when they actually aren't. Parameters ---------- x : array Returns ------- ndarray """ xp = array_namespace(x) if is_numpy(xp): return x.copy() if is_cupy(xp): return x.get() if is_torch(xp): return x.cpu().numpy() if is_array_api_strict(xp): # array api strict supports multiple devices, so need to # ensure x is on the cpu before copying to NumPy. return np.asarray( xp.asarray(x, device=xp.Device("CPU_DEVICE")), copy=True ) # Fall back to np.asarray. This works for dask.array. It # currently works for jax.numpy, but hopefully JAX will make # the transfer guard workable enough for use in scipy tests, in # which case, JAX will have to be handled explicitly. # If new backends are added, they may require explicit handling as # well. return np.asarray(x, copy=True) _default_xp_ctxvar: ContextVar[ModuleType] = ContextVar("_default_xp") @contextmanager def default_xp(xp: ModuleType) -> Generator[None, None, None]: """In all ``xp_assert_*`` and ``assert_*`` function calls executed within this context manager, test by default that the array namespace is the provided across all arrays, unless one explicitly passes the ``xp=`` parameter or ``check_namespace=False``. Without this context manager, the default value for `xp` is the namespace for the desired array (the second parameter of the tests). """ token = _default_xp_ctxvar.set(xp) try: yield finally: _default_xp_ctxvar.reset(token) def eager_warns(warning_type, *, match=None, xp): """pytest.warns context manager if arrays of specified namespace are always eager. Otherwise, context manager that *ignores* specified warning. """ import pytest from scipy._lib._util import ignore_warns if is_numpy(xp) or is_array_api_strict(xp) or is_cupy(xp): return pytest.warns(warning_type, match=match) return ignore_warns(warning_type, match='' if match is None else match) def _strict_check(actual, desired, xp, *, check_namespace=True, check_dtype=True, check_shape=True, check_0d=True): __tracebackhide__ = True # Hide traceback for py.test if xp is None: try: xp = _default_xp_ctxvar.get() except LookupError: xp = array_namespace(desired) if check_namespace: _assert_matching_namespace(actual, desired, xp) # only NumPy distinguishes between scalars and arrays; we do if check_0d=True. # do this first so we can then cast to array (and thus use the array API) below. if is_numpy(xp) and check_0d: _msg = ("Array-ness does not match:\n Actual: " f"{type(actual)}\n Desired: {type(desired)}") assert ((xp.isscalar(actual) and xp.isscalar(desired)) or (not xp.isscalar(actual) and not xp.isscalar(desired))), _msg actual = xp.asarray(actual) desired = xp.asarray(desired) if check_dtype: _msg = f"dtypes do not match.\nActual: {actual.dtype}\nDesired: {desired.dtype}" assert actual.dtype == desired.dtype, _msg if check_shape: if is_dask(xp): actual.compute_chunk_sizes() desired.compute_chunk_sizes() _msg = f"Shapes do not match.\nActual: {actual.shape}\nDesired: {desired.shape}" assert actual.shape == desired.shape, _msg desired = xp.broadcast_to(desired, actual.shape) return actual, desired, xp def _assert_matching_namespace(actual, desired, xp): __tracebackhide__ = True # Hide traceback for py.test desired_arr_space = array_namespace(desired) _msg = ("Namespace of desired array does not match expectations " "set by the `default_xp` context manager or by the `xp`" "pytest fixture.\n" f"Desired array's space: {desired_arr_space.__name__}\n" f"Expected namespace: {xp.__name__}") assert desired_arr_space == xp, _msg actual_arr_space = array_namespace(actual) _msg = ("Namespace of actual and desired arrays do not match.\n" f"Actual: {actual_arr_space.__name__}\n" f"Desired: {xp.__name__}") assert actual_arr_space == xp, _msg def xp_assert_equal(actual, desired, *, check_namespace=True, check_dtype=True, check_shape=True, check_0d=True, err_msg='', xp=None): __tracebackhide__ = True # Hide traceback for py.test actual, desired, xp = _strict_check( actual, desired, xp, check_namespace=check_namespace, check_dtype=check_dtype, check_shape=check_shape, check_0d=check_0d ) if is_cupy(xp): return xp.testing.assert_array_equal(actual, desired, err_msg=err_msg) elif is_torch(xp): # PyTorch recommends using `rtol=0, atol=0` like this # to test for exact equality err_msg = None if err_msg == '' else err_msg return xp.testing.assert_close(actual, desired, rtol=0, atol=0, equal_nan=True, check_dtype=False, msg=err_msg) # JAX uses `np.testing` return np.testing.assert_array_equal(actual, desired, err_msg=err_msg) def xp_assert_close(actual, desired, *, rtol=None, atol=0, check_namespace=True, check_dtype=True, check_shape=True, check_0d=True, err_msg='', xp=None): __tracebackhide__ = True # Hide traceback for py.test actual, desired, xp = _strict_check( actual, desired, xp, check_namespace=check_namespace, check_dtype=check_dtype, check_shape=check_shape, check_0d=check_0d ) floating = xp.isdtype(actual.dtype, ('real floating', 'complex floating')) if rtol is None and floating: # multiplier of 4 is used as for `np.float64` this puts the default `rtol` # roughly half way between sqrt(eps) and the default for # `numpy.testing.assert_allclose`, 1e-7 rtol = xp.finfo(actual.dtype).eps**0.5 * 4 elif rtol is None: rtol = 1e-7 if is_cupy(xp): return xp.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, err_msg=err_msg) elif is_torch(xp): err_msg = None if err_msg == '' else err_msg return xp.testing.assert_close(actual, desired, rtol=rtol, atol=atol, equal_nan=True, check_dtype=False, msg=err_msg) # JAX uses `np.testing` return np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, err_msg=err_msg) def xp_assert_close_nulp(actual, desired, *, nulp=1, check_namespace=True, check_dtype=True, check_shape=True, check_0d=True, err_msg='', xp=None): __tracebackhide__ = True # Hide traceback for py.test actual, desired, xp = _strict_check( actual, desired, xp, check_namespace=check_namespace, check_dtype=check_dtype, check_shape=check_shape, check_0d=check_0d ) actual, desired = map(_xp_copy_to_numpy, (actual, desired)) return np.testing.assert_array_almost_equal_nulp(actual, desired, nulp=nulp) def xp_assert_less(actual, desired, *, check_namespace=True, check_dtype=True, check_shape=True, check_0d=True, err_msg='', verbose=True, xp=None): __tracebackhide__ = True # Hide traceback for py.test actual, desired, xp = _strict_check( actual, desired, xp, check_namespace=check_namespace, check_dtype=check_dtype, check_shape=check_shape, check_0d=check_0d ) if is_cupy(xp): return xp.testing.assert_array_less(actual, desired, err_msg=err_msg, verbose=verbose) elif is_torch(xp): if actual.device.type != 'cpu': actual = actual.cpu() if desired.device.type != 'cpu': desired = desired.cpu() # JAX uses `np.testing` return np.testing.assert_array_less(actual, desired, err_msg=err_msg, verbose=verbose) def assert_array_almost_equal(actual, desired, decimal=6, *args, **kwds): """Backwards compatible replacement. In new code, use xp_assert_close instead. """ rtol, atol = 0, 1.5*10**(-decimal) return xp_assert_close(actual, desired, atol=atol, rtol=rtol, check_dtype=False, check_shape=False, *args, **kwds) def assert_almost_equal(actual, desired, decimal=7, *args, **kwds): """Backwards compatible replacement. In new code, use xp_assert_close instead. """ rtol, atol = 0, 1.5*10**(-decimal) return xp_assert_close(actual, desired, atol=atol, rtol=rtol, check_dtype=False, check_shape=False, *args, **kwds) def xp_unsupported_param_msg(param: Any) -> str: return f'Providing {param!r} is only supported for numpy arrays.' def is_complex(x: Array, xp: ModuleType) -> bool: return xp.isdtype(x.dtype, 'complex floating') def get_native_namespace_name(xp: ModuleType) -> str: """Return name for native namespace (without array_api_compat prefix).""" name = xp.__name__ return name.removeprefix(f"{_compat_module_name()}.") def scipy_namespace_for(xp: ModuleType) -> ModuleType | None: """Return the `scipy`-like namespace of a non-NumPy backend That is, return the namespace corresponding with backend `xp` that contains `scipy` sub-namespaces like `linalg` and `special`. If no such namespace exists, return ``None``. Useful for dispatching. """ if is_cupy(xp): import cupyx # type: ignore[import-not-found,import-untyped] return cupyx.scipy if is_jax(xp): import jax # type: ignore[import-not-found] return jax.scipy if is_torch(xp): return xp return None # maybe use `scipy.linalg` if/when array API support is added def xp_vector_norm(x: Array, /, *, axis: int | tuple[int] | None = None, keepdims: bool = False, ord: int | float = 2, xp: ModuleType | None = None) -> Array: xp = array_namespace(x) if xp is None else xp if SCIPY_ARRAY_API: # check for optional `linalg` extension if hasattr(xp, 'linalg'): return xp.linalg.vector_norm(x, axis=axis, keepdims=keepdims, ord=ord) else: if ord != 2: raise ValueError( "only the Euclidean norm (`ord=2`) is currently supported in " "`xp_vector_norm` for backends not implementing the `linalg` " "extension." ) # return (x @ x)**0.5 # or to get the right behavior with nd, complex arrays return xp.sum(xp.conj(x) * x, axis=axis, keepdims=keepdims)**0.5 else: # to maintain backwards compatibility return np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) def xp_ravel(x: Array, /, *, xp: ModuleType | None = None) -> Array: # Equivalent of np.ravel written in terms of array API # Even though it's one line, it comes up so often that it's worth having # this function for readability xp = array_namespace(x) if xp is None else xp return xp.reshape(x, (-1,)) def xp_swapaxes(a, axis1, axis2, xp=None): # Equivalent of np.swapaxes written in terms of array API xp = array_namespace(a) if xp is None else xp axes = list(range(a.ndim)) axes[axis1], axes[axis2] = axes[axis2], axes[axis1] a = xp.permute_dims(a, axes) return a # utility to find common dtype with option to force floating def xp_result_type(*args, force_floating=False, xp): """ Returns the dtype that results from applying type promotion rules (see Array API Standard Type Promotion Rules) to the arguments. Augments standard `result_type` in a few ways: - There is a `force_floating` argument that ensures that the result type is floating point, even when all args are integer. - When a TypeError is raised (e.g. due to an unsupported promotion) and `force_floating=True`, we define a custom rule: use the result type of the default float and any other floats passed. See https://github.com/scipy/scipy/pull/22695/files#r1997905891 for rationale. - This function accepts array-like iterables, which are immediately converted to the namespace's arrays before result type calculation. Consequently, the result dtype may be different when an argument is `1.` vs `[1.]`. Typically, this function will be called shortly after `array_namespace` on a subset of the arguments passed to `array_namespace`. """ # prevent double conversion of iterable to array # avoid `np.iterable` for torch arrays due to pytorch/pytorch#143334 # don't use `array_api_compat.is_array_api_obj` as it returns True for NumPy scalars args = [(_asarray(arg, subok=True, xp=xp) if is_torch_array(arg) or np.iterable(arg) else arg) for arg in args] args_not_none = [arg for arg in args if arg is not None] if force_floating: args_not_none.append(1.0) if is_numpy(xp) and xp.__version__ < '2.0': # Follow NEP 50 promotion rules anyway args_not_none = [arg.dtype if getattr(arg, 'size', 0) == 1 else arg for arg in args_not_none] return xp.result_type(*args_not_none) try: # follow library's preferred promotion rules return xp.result_type(*args_not_none) except TypeError: # mixed type promotion isn't defined if not force_floating: raise # use `result_type` of default floating point type and any floats present # This can be revisited, but right now, the only backends that get here # are array-api-strict (which is not for production use) and PyTorch # (due to data-apis/array-api-compat#279). float_args = [] for arg in args_not_none: arg_array = xp.asarray(arg) if np.isscalar(arg) else arg dtype = getattr(arg_array, 'dtype', arg) if xp.isdtype(dtype, ('real floating', 'complex floating')): float_args.append(arg) return xp.result_type(*float_args, xp_default_dtype(xp)) def xp_promote(*args, broadcast=False, force_floating=False, xp): """ Promotes elements of *args to result dtype, ignoring `None`s. Includes options for forcing promotion to floating point and broadcasting the arrays, again ignoring `None`s. Type promotion rules follow `xp_result_type` instead of `xp.result_type`. Typically, this function will be called shortly after `array_namespace` on a subset of the arguments passed to `array_namespace`. This function accepts array-like iterables, which are immediately converted to the namespace's arrays before result type calculation. Consequently, the result dtype may be different when an argument is `1.` vs `[1.]`. See Also -------- xp_result_type """ if not args: return args # prevent double conversion of iterable to array # avoid `np.iterable` for torch arrays due to pytorch/pytorch#143334 # don't use `array_api_compat.is_array_api_obj` as it returns True for NumPy scalars args = [(_asarray(arg, subok=True, xp=xp) if is_torch_array(arg) or np.iterable(arg) else arg) for arg in args] dtype = xp_result_type(*args, force_floating=force_floating, xp=xp) args = [(_asarray(arg, dtype=dtype, subok=True, xp=xp) if arg is not None else arg) for arg in args] if not broadcast: return args[0] if len(args)==1 else tuple(args) args_not_none = [arg for arg in args if arg is not None] # determine result shape shapes = {arg.shape for arg in args_not_none} try: shape = (np.broadcast_shapes(*shapes) if len(shapes) != 1 else args_not_none[0].shape) except ValueError as e: message = "Array shapes are incompatible for broadcasting." raise ValueError(message) from e out = [] for arg in args: if arg is None: out.append(arg) continue # broadcast only if needed # Even if two arguments need broadcasting, this is faster than # `broadcast_arrays`, especially since we've already determined `shape` if arg.shape != shape: kwargs = {'subok': True} if is_numpy(xp) else {} arg = xp.broadcast_to(arg, shape, **kwargs) # This is much faster than xp.astype(arg, dtype, copy=False) if arg.dtype != dtype: arg = xp.astype(arg, dtype) out.append(arg) return out[0] if len(out)==1 else tuple(out) def xp_float_to_complex(arr: Array, xp: ModuleType | None = None) -> Array: xp = array_namespace(arr) if xp is None else xp arr_dtype = arr.dtype # The standard float dtypes are float32 and float64. # Convert float32 to complex64, # and float64 (and non-standard real dtypes) to complex128 if xp.isdtype(arr_dtype, xp.float32): arr = xp.astype(arr, xp.complex64) elif xp.isdtype(arr_dtype, 'real floating'): arr = xp.astype(arr, xp.complex128) return arr def xp_default_dtype(xp): """Query the namespace-dependent default floating-point dtype. """ if is_torch(xp): # historically, we allow pytorch to keep its default of float32 return xp.get_default_dtype() else: # we default to float64 return xp.float64 ### MArray Helpers ### def xp_result_device(*args): """Return the device of an array in `args`, for the purpose of input-output device propagation. If there are multiple devices, return an arbitrary one. If there are no arrays, return None (this typically happens only on NumPy). """ for arg in args: # Do not do a duck-type test for the .device attribute, as many backends today # don't have it yet. See workarouunds in array_api_compat.device(). if is_array_api_obj(arg): return xp_device(arg) return None # np.r_ replacement def concat_1d(xp: ModuleType | None, *arrays: Iterable[ArrayLike]) -> Array: """A replacement for `np.r_` as `xp.concat` does not accept python scalars or 0-D arrays. """ arys = [xpx.atleast_nd(xp.asarray(a), ndim=1, xp=xp) for a in arrays] return xp.concat(arys) def is_marray(xp): """Returns True if `xp` is an MArray namespace; False otherwise.""" return "marray" in xp.__name__ def _length_nonmasked(x, axis, keepdims=False, xp=None): xp = array_namespace(x) if xp is None else xp if is_marray(xp): if np.iterable(axis): message = '`axis` must be an integer or None for use with `MArray`.' raise NotImplementedError(message) return xp.astype(xp.count(x, axis=axis, keepdims=keepdims), x.dtype) return (xp_size(x) if axis is None else # compact way to deal with axis tuples or ints int(np.prod(np.asarray(x.shape)[np.asarray(axis)]))) def _share_masks(*args, xp): if is_marray(xp): mask = functools.reduce(operator.or_, (arg.mask for arg in args)) args = [xp.asarray(arg.data, mask=mask) for arg in args] return args[0] if len(args) == 1 else args ### End MArray Helpers ### @dataclasses.dataclass(repr=False) class _XPSphinxCapability: cpu: bool | None # None if not applicable gpu: bool | None warnings: list[str] = dataclasses.field(default_factory=list) def _render(self, value): if value is None: return "n/a" if not value: return "⛔" if self.warnings: res = "⚠️ " + '; '.join(self.warnings) assert len(res) <= 20, "Warnings too long" return res return "✅" def __str__(self): cpu = self._render(self.cpu) gpu = self._render(self.gpu) return f"{cpu:20} {gpu:20}" def _make_sphinx_capabilities( # lists of tuples [(module name, reason), ...] skip_backends=(), xfail_backends=(), # @pytest.mark.skip/xfail_xp_backends kwargs cpu_only=False, np_only=False, out_of_scope=False, exceptions=(), # xpx.lazy_xp_backends kwargs allow_dask_compute=False, jax_jit=True, # list of tuples [(module name, reason), ...] warnings = (), # unused in documentation reason=None, ): if out_of_scope: return {"out_of_scope": True} exceptions = set(exceptions) # Default capabilities capabilities = { "numpy": _XPSphinxCapability(cpu=True, gpu=None), "array_api_strict": _XPSphinxCapability(cpu=True, gpu=None), "cupy": _XPSphinxCapability(cpu=None, gpu=True), "torch": _XPSphinxCapability(cpu=True, gpu=True), "jax.numpy": _XPSphinxCapability(cpu=True, gpu=True, warnings=[] if jax_jit else ["no JIT"]), # Note: Dask+CuPy is currently untested and unsupported "dask.array": _XPSphinxCapability(cpu=True, gpu=None, warnings=["computes graph"] if allow_dask_compute else []), } # documentation doesn't display the reason for module, _ in list(skip_backends) + list(xfail_backends): backend = capabilities[module] if backend.cpu is not None: backend.cpu = False if backend.gpu is not None: backend.gpu = False for module, backend in capabilities.items(): if np_only and module not in exceptions | {"numpy"}: if backend.cpu is not None: backend.cpu = False if backend.gpu is not None: backend.gpu = False elif cpu_only and module not in exceptions and backend.gpu is not None: backend.gpu = False for module, warning in warnings: backend = capabilities[module] backend.warnings.append(warning) return capabilities def _make_capabilities_note(fun_name, capabilities, extra_note=None): if "out_of_scope" in capabilities: # It will be better to link to a section of the dev-arrayapi docs # that explains what is and isn't in-scope, but such a section # doesn't exist yet. Using :ref:`dev-arrayapi` as a placeholder. note = f""" **Array API Standard Support** `{fun_name}` is not in-scope for support of Python Array API Standard compatible backends other than NumPy. See :ref:`dev-arrayapi` for more information. """ return textwrap.dedent(note) # Note: deliberately not documenting array-api-strict note = f""" **Array API Standard Support** `{fun_name}` has experimental support for Python Array API Standard compatible backends in addition to NumPy. Please consider testing these features by setting an environment variable ``SCIPY_ARRAY_API=1`` and providing CuPy, PyTorch, JAX, or Dask arrays as array arguments. The following combinations of backend and device (or other capability) are supported. ==================== ==================== ==================== Library CPU GPU ==================== ==================== ==================== NumPy {capabilities['numpy'] } CuPy {capabilities['cupy'] } PyTorch {capabilities['torch'] } JAX {capabilities['jax.numpy'] } Dask {capabilities['dask.array'] } ==================== ==================== ==================== """ + (extra_note or "") + " See :ref:`dev-arrayapi` for more information." return textwrap.dedent(note) def xp_capabilities( *, # Alternative capabilities table. # Used only for testing this decorator. capabilities_table=None, # Generate pytest.mark.skip/xfail_xp_backends. # See documentation in conftest.py. # lists of tuples [(module name, reason), ...] skip_backends=(), xfail_backends=(), cpu_only=False, np_only=False, reason=None, out_of_scope=False, exceptions=(), # lists of tuples [(module name, reason), ...] warnings=(), # xpx.testing.lazy_xp_function kwargs. # Refer to array-api-extra documentation. allow_dask_compute=False, jax_jit=True, # Extra note to inject into the docstring extra_note=None, ): """Decorator for a function that states its support among various Array API compatible backends. This decorator has two effects: 1. It allows tagging tests with ``@make_xp_test_case`` or ``make_xp_pytest_param`` (see below) to automatically generate SKIP/XFAIL markers and perform additional backend-specific testing, such as extra validation for Dask and JAX; 2. It automatically adds a note to the function's docstring, containing a table matching what has been tested. See Also -------- make_xp_test_case make_xp_pytest_param array_api_extra.testing.lazy_xp_function """ capabilities_table = (xp_capabilities_table if capabilities_table is None else capabilities_table) if out_of_scope: np_only = True capabilities = dict( skip_backends=skip_backends, xfail_backends=xfail_backends, cpu_only=cpu_only, np_only=np_only, out_of_scope=out_of_scope, reason=reason, exceptions=exceptions, allow_dask_compute=allow_dask_compute, jax_jit=jax_jit, warnings=warnings, ) sphinx_capabilities = _make_sphinx_capabilities(**capabilities) def decorator(f): # Don't use a wrapper, as in some cases @xp_capabilities is # applied to a ufunc capabilities_table[f] = capabilities note = _make_capabilities_note(f.__name__, sphinx_capabilities, extra_note) doc = FunctionDoc(f) doc['Notes'].append(note) doc = str(doc).split("\n", 1)[1].lstrip(" \n") # remove signature try: f.__doc__ = doc except AttributeError: # Can't update __doc__ on ufuncs if SciPy # was compiled against NumPy < 2.2. pass return f return decorator def make_xp_test_case(*funcs, capabilities_table=None): capabilities_table = (xp_capabilities_table if capabilities_table is None else capabilities_table) """Generate pytest decorator for a test function that tests functionality of one or more Array API compatible functions. Read the parameters of the ``@xp_capabilities`` decorator applied to the listed functions and: - Generate the ``@pytest.mark.skip_xp_backends`` and ``@pytest.mark.xfail_xp_backends`` decorators for the decorated test function - Tag the function with `xpx.testing.lazy_xp_function` Example:: @make_xp_test_case(f1) def test_f1(xp): ... @make_xp_test_case(f2) def test_f2(xp): ... @make_xp_test_case(f1, f2) def test_f1_and_f2(xp): ... The above is equivalent to:: @pytest.mark.skip_xp_backends(...) @pytest.mark.skip_xp_backends(...) @pytest.mark.xfail_xp_backends(...) @pytest.mark.xfail_xp_backends(...) def test_f1(xp): ... etc., where the arguments of ``skip_xp_backends`` and ``xfail_xp_backends`` are determined by the ``@xp_capabilities`` decorator applied to the functions. See Also -------- xp_capabilities make_xp_pytest_marks make_xp_pytest_param array_api_extra.testing.lazy_xp_function """ marks = make_xp_pytest_marks(*funcs, capabilities_table=capabilities_table) return lambda func: functools.reduce(lambda f, g: g(f), marks, func) def make_xp_pytest_param(func, *args, capabilities_table=None): """Variant of ``make_xp_test_case`` that returns a pytest.param for a function, with all necessary skip_xp_backends and xfail_xp_backends marks applied:: @pytest.mark.parametrize( "func", [make_xp_pytest_param(f1), make_xp_pytest_param(f2)] ) def test(func, xp): ... The above is equivalent to:: @pytest.mark.parametrize( "func", [ pytest.param(f1, marks=[ pytest.mark.skip_xp_backends(...), pytest.mark.xfail_xp_backends(...), ...]), pytest.param(f2, marks=[ pytest.mark.skip_xp_backends(...), pytest.mark.xfail_xp_backends(...), ...]), ) def test(func, xp): ... Parameters ---------- func : Callable Function to be tested. It must be decorated with ``@xp_capabilities``. *args : Any, optional Extra pytest parameters for the use case, e.g.:: @pytest.mark.parametrize("func,verb", [ make_xp_pytest_param(f1, "hello"), make_xp_pytest_param(f2, "world")]) def test(func, verb, xp): # iterates on (func=f1, verb="hello") # and (func=f2, verb="world") See Also -------- xp_capabilities make_xp_test_case make_xp_pytest_marks array_api_extra.testing.lazy_xp_function """ import pytest marks = make_xp_pytest_marks(func, capabilities_table=capabilities_table) return pytest.param(func, *args, marks=marks, id=func.__name__) def make_xp_pytest_marks(*funcs, capabilities_table=None): """Variant of ``make_xp_test_case`` that returns a list of pytest marks, which can be used with the module-level `pytestmark = ...` variable:: pytestmark = make_xp_pytest_marks(f1, f2) def test(xp): ... In this example, the whole test module is dedicated to testing `f1` or `f2`, and the two functions have the same capabilities, so it's unnecessary to cherry-pick which test tests which function. The above is equivalent to:: pytestmark = [ pytest.mark.skip_xp_backends(...), pytest.mark.xfail_xp_backends(...), ...]), ] def test(xp): ... See Also -------- xp_capabilities make_xp_test_case make_xp_pytest_param array_api_extra.testing.lazy_xp_function """ capabilities_table = (xp_capabilities_table if capabilities_table is None else capabilities_table) import pytest marks = [] for func in funcs: capabilities = capabilities_table[func] exceptions = capabilities['exceptions'] reason = capabilities['reason'] if capabilities['cpu_only']: marks.append(pytest.mark.skip_xp_backends( cpu_only=True, exceptions=exceptions, reason=reason)) if capabilities['np_only']: marks.append(pytest.mark.skip_xp_backends( np_only=True, exceptions=exceptions, reason=reason)) for mod_name, reason in capabilities['skip_backends']: marks.append(pytest.mark.skip_xp_backends(mod_name, reason=reason)) for mod_name, reason in capabilities['xfail_backends']: marks.append(pytest.mark.xfail_xp_backends(mod_name, reason=reason)) lazy_kwargs = {k: capabilities[k] for k in ('allow_dask_compute', 'jax_jit')} lazy_xp_function(func, **lazy_kwargs) return marks # Is it OK to have a dictionary that is mutated (once upon import) in many places? xp_capabilities_table = {} # type: ignore[var-annotated] def xp_device_type(a: Array) -> Literal["cpu", "cuda", None]: if is_numpy_array(a): return "cpu" if is_cupy_array(a): return "cuda" if is_torch_array(a): # TODO this can return other backends e.g. tpu but they're unsupported in scipy return a.device.type if is_jax_array(a): # TODO this can return other backends e.g. tpu but they're unsupported in scipy return "cuda" if (p := a.device.platform) == "gpu" else p if is_dask_array(a): return xp_device_type(a._meta) # array-api-strict is a stand-in for unknown libraries; don't special-case it return None