You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1037 lines
37 KiB
1037 lines
37 KiB
"""Utility functions to use Python Array API compatible libraries.
|
|
|
|
For the context about the Array API see:
|
|
https://data-apis.org/array-api/latest/purpose_and_scope.html
|
|
|
|
The SciPy use case of the Array API is described on the following page:
|
|
https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy
|
|
"""
|
|
import operator
|
|
import dataclasses
|
|
import functools
|
|
import textwrap
|
|
|
|
from collections.abc import Generator
|
|
from contextlib import contextmanager
|
|
from contextvars import ContextVar
|
|
from types import ModuleType
|
|
from typing import Any, Literal, TypeAlias
|
|
from collections.abc import Iterable
|
|
|
|
import numpy as np
|
|
import numpy.typing as npt
|
|
|
|
from scipy._lib.array_api_compat import (
|
|
is_array_api_obj,
|
|
is_lazy_array,
|
|
is_numpy_array,
|
|
is_cupy_array,
|
|
is_torch_array,
|
|
is_jax_array,
|
|
is_dask_array,
|
|
size as xp_size,
|
|
numpy as np_compat,
|
|
device as xp_device,
|
|
is_numpy_namespace as is_numpy,
|
|
is_cupy_namespace as is_cupy,
|
|
is_torch_namespace as is_torch,
|
|
is_jax_namespace as is_jax,
|
|
is_dask_namespace as is_dask,
|
|
is_array_api_strict_namespace as is_array_api_strict,
|
|
)
|
|
from scipy._lib.array_api_compat.common._helpers import _compat_module_name
|
|
from scipy._lib.array_api_extra.testing import lazy_xp_function
|
|
from scipy._lib._array_api_override import (
|
|
array_namespace, SCIPY_ARRAY_API, SCIPY_DEVICE
|
|
)
|
|
from scipy._lib._docscrape import FunctionDoc
|
|
from scipy._lib import array_api_extra as xpx
|
|
|
|
|
|
__all__ = [
|
|
'_asarray', 'array_namespace', 'assert_almost_equal', 'assert_array_almost_equal',
|
|
'default_xp', 'eager_warns', 'is_lazy_array', 'is_marray',
|
|
'is_array_api_strict', 'is_complex', 'is_cupy', 'is_jax', 'is_numpy', 'is_torch',
|
|
'np_compat', 'get_native_namespace_name',
|
|
'SCIPY_ARRAY_API', 'SCIPY_DEVICE', 'scipy_namespace_for',
|
|
'xp_assert_close', 'xp_assert_equal', 'xp_assert_less',
|
|
'xp_copy', 'xp_device', 'xp_ravel', 'xp_size',
|
|
'xp_unsupported_param_msg', 'xp_vector_norm', 'xp_capabilities',
|
|
'xp_result_type', 'xp_promote',
|
|
'make_xp_test_case', 'make_xp_pytest_marks', 'make_xp_pytest_param',
|
|
]
|
|
|
|
|
|
Array: TypeAlias = Any # To be changed to a Protocol later (see array-api#589)
|
|
ArrayLike: TypeAlias = Array | npt.ArrayLike
|
|
|
|
|
|
def _check_finite(array: Array, xp: ModuleType) -> None:
    """Raise ``ValueError`` if `array` contains any NaN or infinity."""
    all_finite = xp.all(xp.isfinite(array))
    if not all_finite:
        raise ValueError("array must not contain infs or NaNs")
|
|
|
|
def _asarray(
    array: ArrayLike,
    dtype: Any = None,
    order: Literal['K', 'A', 'C', 'F'] | None = None,
    copy: bool | None = None,
    *,
    xp: ModuleType | None = None,
    check_finite: bool = False,
    subok: bool = False,
) -> Array:
    """SciPy-specific replacement for `np.asarray` with `order`, `check_finite`, and
    `subok`.

    Memory layout parameter `order` is not exposed in the Array API standard.
    `order` is only enforced if the input array implementation
    is NumPy based, otherwise `order` is just silently ignored.

    `check_finite` is also not a keyword in the array API standard; included
    here for convenience rather than that having to be a separate function
    call inside SciPy functions.

    `subok` is included to allow this function to preserve the behaviour of
    `np.asanyarray` for NumPy based inputs.
    """
    if xp is None:
        xp = array_namespace(array)
    if is_numpy(xp):
        # Use NumPy API to support order
        if copy is True:
            # np.array is the only NumPy entry point that combines a forced
            # copy with `subok`
            array = np.array(array, order=order, dtype=dtype, subok=subok)
        elif subok:
            array = np.asanyarray(array, order=order, dtype=dtype)
        else:
            array = np.asarray(array, order=order, dtype=dtype)
    else:
        try:
            array = xp.asarray(array, dtype=dtype, copy=copy)
        except TypeError:
            # Retry with the array-api-compat wrapper namespace for this
            # backend (inferred from a small probe array), which accepts a
            # wider range of inputs/keywords than some raw namespaces.
            coerced_xp = array_namespace(xp.asarray(3))
            array = coerced_xp.asarray(array, dtype=dtype, copy=copy)

    if check_finite:
        _check_finite(array, xp)

    return array
|
|
|
|
|
|
def xp_copy(x: Array, *, xp: ModuleType | None = None) -> Array:
    """
    Return a copy of an array.

    Parameters
    ----------
    x : array

    xp : array_namespace

    Returns
    -------
    copy : array
        Copied array

    Notes
    -----
    This copy function does not offer all the semantics of `np.copy`, i.e. the
    `subok` and `order` keywords are not used.
    """
    # `_asarray` is used (rather than `xp.asarray(..., copy=True)`) because
    # older NumPy versions lacked the `copy` kwarg on `np.asarray`.
    xp = array_namespace(x) if xp is None else xp
    return _asarray(x, copy=True, xp=xp)
|
|
|
|
|
|
def _xp_copy_to_numpy(x: Array) -> np.ndarray:
    """Copies a possibly on device array to a NumPy array.

    This function is intended only for converting alternative backend
    arrays to numpy arrays within test code, to make it easier for use
    of the alternative backend to be isolated only to the function being
    tested. `_xp_copy_to_numpy` should NEVER be used except in test code
    for the specific purpose mentioned above. In production code, attempts
    to copy device arrays to NumPy arrays should fail, or else functions
    may appear to be working on the GPU when they actually aren't.

    Parameters
    ----------
    x : array

    Returns
    -------
    ndarray
    """
    xp = array_namespace(x)
    # Backend-specific device-to-host transfer paths:
    if is_numpy(xp):
        return x.copy()
    if is_cupy(xp):
        # CuPy: .get() transfers from GPU to host
        return x.get()
    if is_torch(xp):
        # torch: move to CPU first, then view as NumPy
        return x.cpu().numpy()
    if is_array_api_strict(xp):
        # array api strict supports multiple devices, so need to
        # ensure x is on the cpu before copying to NumPy.
        return np.asarray(
            xp.asarray(x, device=xp.Device("CPU_DEVICE")), copy=True
        )
    # Fall back to np.asarray. This works for dask.array. It
    # currently works for jax.numpy, but hopefully JAX will make
    # the transfer guard workable enough for use in scipy tests, in
    # which case, JAX will have to be handled explicitly.
    # If new backends are added, they may require explicit handling as
    # well.
    return np.asarray(x, copy=True)
|
|
|
|
|
|
_default_xp_ctxvar: ContextVar[ModuleType] = ContextVar("_default_xp")
|
|
|
|
@contextmanager
def default_xp(xp: ModuleType) -> Generator[None, None, None]:
    """Pin the namespace that ``xp_assert_*`` / ``assert_*`` helpers check by default.

    While this context manager is active, those helpers verify that all arrays
    belong to the given `xp`, unless the caller overrides it explicitly with
    the ``xp=`` argument or disables the check with ``check_namespace=False``.

    Outside this context manager, the default `xp` is inferred from the
    desired array (the second argument of the assertions).
    """
    reset_token = _default_xp_ctxvar.set(xp)
    try:
        yield
    finally:
        # Always restore the previous value, even if the body raised.
        _default_xp_ctxvar.reset(reset_token)
|
|
|
|
|
|
def eager_warns(warning_type, *, match=None, xp):
    """pytest.warns context manager if arrays of specified namespace are always eager.

    Otherwise, context manager that *ignores* specified warning.
    """
    # Deferred imports: keep test-only dependencies out of module import time.
    import pytest
    from scipy._lib._util import ignore_warns
    # NumPy, array-api-strict, and CuPy are the eager backends here: the
    # warning is required to be emitted. For the remaining (lazy) backends
    # the warning is merely ignored rather than required.
    if is_numpy(xp) or is_array_api_strict(xp) or is_cupy(xp):
        return pytest.warns(warning_type, match=match)
    return ignore_warns(warning_type, match='' if match is None else match)
|
|
|
|
|
|
def _strict_check(actual, desired, xp, *,
                  check_namespace=True, check_dtype=True, check_shape=True,
                  check_0d=True):
    """Validate and normalize an (actual, desired) pair before an assertion.

    Returns ``(actual, desired, xp)`` with both operands converted to arrays
    of namespace `xp` and `desired` broadcast to `actual`'s shape.
    Raises AssertionError on namespace/array-ness/dtype/shape mismatch,
    according to the ``check_*`` flags.
    """
    __tracebackhide__ = True  # Hide traceback for py.test

    if xp is None:
        try:
            # Namespace pinned by the `default_xp` context manager, if active.
            xp = _default_xp_ctxvar.get()
        except LookupError:
            xp = array_namespace(desired)

    if check_namespace:
        _assert_matching_namespace(actual, desired, xp)

    # only NumPy distinguishes between scalars and arrays; we do if check_0d=True.
    # do this first so we can then cast to array (and thus use the array API) below.
    if is_numpy(xp) and check_0d:
        _msg = ("Array-ness does not match:\n Actual: "
                f"{type(actual)}\n Desired: {type(desired)}")
        assert ((xp.isscalar(actual) and xp.isscalar(desired))
                or (not xp.isscalar(actual) and not xp.isscalar(desired))), _msg

    actual = xp.asarray(actual)
    desired = xp.asarray(desired)

    if check_dtype:
        _msg = f"dtypes do not match.\nActual: {actual.dtype}\nDesired: {desired.dtype}"
        assert actual.dtype == desired.dtype, _msg

    if check_shape:
        if is_dask(xp):
            # Materialize Dask chunk sizes so `.shape` is fully defined below.
            actual.compute_chunk_sizes()
            desired.compute_chunk_sizes()
        _msg = f"Shapes do not match.\nActual: {actual.shape}\nDesired: {desired.shape}"
        assert actual.shape == desired.shape, _msg

    desired = xp.broadcast_to(desired, actual.shape)
    return actual, desired, xp
|
|
|
|
|
|
def _assert_matching_namespace(actual, desired, xp):
    """Assert that both arrays belong to the expected namespace `xp`.

    The desired array is checked first (against the expectation set by
    `default_xp` / the `xp` fixture), then the actual array.
    """
    __tracebackhide__ = True  # Hide traceback for py.test

    desired_arr_space = array_namespace(desired)
    assert desired_arr_space == xp, (
        "Namespace of desired array does not match expectations "
        "set by the `default_xp` context manager or by the `xp`"
        "pytest fixture.\n"
        f"Desired array's space: {desired_arr_space.__name__}\n"
        f"Expected namespace: {xp.__name__}"
    )

    actual_arr_space = array_namespace(actual)
    assert actual_arr_space == xp, (
        "Namespace of actual and desired arrays do not match.\n"
        f"Actual: {actual_arr_space.__name__}\n"
        f"Desired: {xp.__name__}"
    )
|
|
|
|
|
|
def xp_assert_equal(actual, desired, *, check_namespace=True, check_dtype=True,
                    check_shape=True, check_0d=True, err_msg='', xp=None):
    """Assert that `actual` and `desired` are exactly equal, elementwise.

    Namespace/dtype/shape/0d checks are delegated to `_strict_check`; the
    equality comparison dispatches to backend testing utilities where
    available (CuPy, PyTorch) and otherwise to `np.testing`.
    """
    __tracebackhide__ = True  # Hide traceback for py.test

    actual, desired, xp = _strict_check(
        actual, desired, xp, check_namespace=check_namespace,
        check_dtype=check_dtype, check_shape=check_shape,
        check_0d=check_0d
    )

    if is_cupy(xp):
        return xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
    elif is_torch(xp):
        # PyTorch recommends using `rtol=0, atol=0` like this
        # to test for exact equality
        err_msg = None if err_msg == '' else err_msg  # torch takes msg=None, not ''
        return xp.testing.assert_close(actual, desired, rtol=0, atol=0, equal_nan=True,
                                       check_dtype=False, msg=err_msg)
    # JAX uses `np.testing`
    return np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
|
|
|
|
|
|
def xp_assert_close(actual, desired, *, rtol=None, atol=0, check_namespace=True,
                    check_dtype=True, check_shape=True, check_0d=True,
                    err_msg='', xp=None):
    """Assert that `actual` and `desired` are elementwise close.

    For floating/complex data, `rtol` defaults to ``4 * sqrt(eps)`` of
    `actual`'s dtype; otherwise to 1e-7. Other checks are delegated to
    `_strict_check`; the comparison dispatches to backend testing utilities
    (CuPy, PyTorch) and otherwise to `np.testing`.
    """
    __tracebackhide__ = True  # Hide traceback for py.test

    actual, desired, xp = _strict_check(
        actual, desired, xp,
        check_namespace=check_namespace, check_dtype=check_dtype,
        check_shape=check_shape, check_0d=check_0d
    )

    floating = xp.isdtype(actual.dtype, ('real floating', 'complex floating'))
    if rtol is None and floating:
        # multiplier of 4 is used as for `np.float64` this puts the default `rtol`
        # roughly half way between sqrt(eps) and the default for
        # `numpy.testing.assert_allclose`, 1e-7
        rtol = xp.finfo(actual.dtype).eps**0.5 * 4
    elif rtol is None:
        rtol = 1e-7

    if is_cupy(xp):
        return xp.testing.assert_allclose(actual, desired, rtol=rtol,
                                          atol=atol, err_msg=err_msg)
    elif is_torch(xp):
        err_msg = None if err_msg == '' else err_msg  # torch takes msg=None, not ''
        return xp.testing.assert_close(actual, desired, rtol=rtol, atol=atol,
                                       equal_nan=True, check_dtype=False, msg=err_msg)
    # JAX uses `np.testing`
    return np.testing.assert_allclose(actual, desired, rtol=rtol,
                                      atol=atol, err_msg=err_msg)
|
|
|
|
|
|
def xp_assert_close_nulp(actual, desired, *, nulp=1, check_namespace=True,
                         check_dtype=True, check_shape=True, check_0d=True,
                         err_msg='', xp=None):
    """Assert closeness within `nulp` units in the last place.

    After the usual strict checks, both operands are copied to NumPy so that
    `np.testing.assert_array_almost_equal_nulp` can compare them.
    """
    __tracebackhide__ = True  # Hide traceback for py.test

    actual, desired, xp = _strict_check(
        actual, desired, xp,
        check_namespace=check_namespace, check_dtype=check_dtype,
        check_shape=check_shape, check_0d=check_0d
    )

    actual_np = _xp_copy_to_numpy(actual)
    desired_np = _xp_copy_to_numpy(desired)
    return np.testing.assert_array_almost_equal_nulp(actual_np, desired_np, nulp=nulp)
|
|
|
|
|
|
def xp_assert_less(actual, desired, *, check_namespace=True, check_dtype=True,
                   check_shape=True, check_0d=True, err_msg='', verbose=True, xp=None):
    """Assert that `actual` is strictly elementwise less than `desired`."""
    __tracebackhide__ = True  # Hide traceback for py.test

    actual, desired, xp = _strict_check(
        actual, desired, xp, check_namespace=check_namespace,
        check_dtype=check_dtype, check_shape=check_shape,
        check_0d=check_0d
    )

    if is_cupy(xp):
        return xp.testing.assert_array_less(actual, desired,
                                            err_msg=err_msg, verbose=verbose)
    elif is_torch(xp):
        # Move torch tensors to CPU so that the `np.testing` call below
        # can handle them.
        if actual.device.type != 'cpu':
            actual = actual.cpu()
        if desired.device.type != 'cpu':
            desired = desired.cpu()
    # JAX uses `np.testing`
    return np.testing.assert_array_less(actual, desired,
                                        err_msg=err_msg, verbose=verbose)
|
|
|
|
|
|
def assert_array_almost_equal(actual, desired, decimal=6, *args, **kwds):
    """Backwards compatible replacement. In new code, use xp_assert_close instead.
    """
    # Mirror np.testing semantics: absolute tolerance of 1.5 * 10**-decimal,
    # no relative tolerance.
    atol = 1.5 * 10**(-decimal)
    return xp_assert_close(actual, desired, atol=atol, rtol=0,
                           check_dtype=False, check_shape=False,
                           *args, **kwds)
|
|
|
|
|
|
def assert_almost_equal(actual, desired, decimal=7, *args, **kwds):
    """Backwards compatible replacement. In new code, use xp_assert_close instead.
    """
    # Mirror np.testing semantics: absolute tolerance of 1.5 * 10**-decimal,
    # no relative tolerance.
    atol = 1.5 * 10**(-decimal)
    return xp_assert_close(actual, desired, atol=atol, rtol=0,
                           check_dtype=False, check_shape=False,
                           *args, **kwds)
|
|
|
|
|
|
def xp_unsupported_param_msg(param: Any) -> str:
    """Return the error message for a parameter only supported with NumPy."""
    return 'Providing {!r} is only supported for numpy arrays.'.format(param)
|
|
|
|
|
|
def is_complex(x: Array, xp: ModuleType) -> bool:
    """Return True iff `x` has a complex floating-point dtype."""
    kind = 'complex floating'
    return xp.isdtype(x.dtype, kind)
|
|
|
|
|
|
def get_native_namespace_name(xp: ModuleType) -> str:
    """Return name for native namespace (without array_api_compat prefix)."""
    prefix = f"{_compat_module_name()}."
    return xp.__name__.removeprefix(prefix)
|
|
|
|
|
|
def scipy_namespace_for(xp: ModuleType) -> ModuleType | None:
    """Return the `scipy`-like namespace of a non-NumPy backend

    That is, return the namespace corresponding with backend `xp` that contains
    `scipy` sub-namespaces like `linalg` and `special`. If no such namespace
    exists, return ``None``. Useful for dispatching.
    """

    if is_cupy(xp):
        # deferred import: cupyx is only present when CuPy is installed
        import cupyx  # type: ignore[import-not-found,import-untyped]
        return cupyx.scipy

    if is_jax(xp):
        import jax  # type: ignore[import-not-found]
        return jax.scipy

    if is_torch(xp):
        # torch exposes its scipy-like sub-namespaces on the top-level module
        return xp

    return None
|
|
|
|
|
|
# maybe use `scipy.linalg` if/when array API support is added
|
|
def xp_vector_norm(x: Array, /, *,
                   axis: int | tuple[int] | None = None,
                   keepdims: bool = False,
                   ord: int | float = 2,
                   xp: ModuleType | None = None) -> Array:
    """Vector norm with Array API dispatch.

    Uses the backend's optional `linalg` extension when available; otherwise
    only the Euclidean norm (``ord=2``) is supported. With
    ``SCIPY_ARRAY_API`` disabled, falls back to `np.linalg.norm` for
    backwards compatibility.
    """
    xp = array_namespace(x) if xp is None else xp

    if SCIPY_ARRAY_API:
        # check for optional `linalg` extension
        if hasattr(xp, 'linalg'):
            return xp.linalg.vector_norm(x, axis=axis, keepdims=keepdims, ord=ord)
        else:
            if ord != 2:
                raise ValueError(
                    "only the Euclidean norm (`ord=2`) is currently supported in "
                    "`xp_vector_norm` for backends not implementing the `linalg` "
                    "extension."
                )
            # return (x @ x)**0.5
            # or to get the right behavior with nd, complex arrays
            # NOTE(review): for complex input this branch yields a
            # complex-dtype result (imaginary part zero), since the complex
            # dtype propagates through sum/`**0.5` — confirm callers expect
            # this rather than the real dtype `vector_norm` would return.
            return xp.sum(xp.conj(x) * x, axis=axis, keepdims=keepdims)**0.5
    else:
        # to maintain backwards compatibility
        return np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
|
|
|
|
|
|
def xp_ravel(x: Array, /, *, xp: ModuleType | None = None) -> Array:
    """Flatten `x` to 1-D — `np.ravel` expressed with the Array API.

    One line, but it comes up so often that a named helper reads better.
    """
    if xp is None:
        xp = array_namespace(x)
    return xp.reshape(x, (-1,))
|
|
|
|
|
|
def xp_swapaxes(a, axis1, axis2, xp=None):
    """Interchange two axes of `a` — `np.swapaxes` expressed with the Array API."""
    if xp is None:
        xp = array_namespace(a)
    order = list(range(a.ndim))
    order[axis1], order[axis2] = order[axis2], order[axis1]
    return xp.permute_dims(a, order)
|
|
|
|
|
|
# utility to find common dtype with option to force floating
|
|
def xp_result_type(*args, force_floating=False, xp):
    """
    Returns the dtype that results from applying type promotion rules
    (see Array API Standard Type Promotion Rules) to the arguments. Augments
    standard `result_type` in a few ways:

    - There is a `force_floating` argument that ensures that the result type
      is floating point, even when all args are integer.
    - When a TypeError is raised (e.g. due to an unsupported promotion)
      and `force_floating=True`, we define a custom rule: use the result type
      of the default float and any other floats passed. See
      https://github.com/scipy/scipy/pull/22695/files#r1997905891
      for rationale.
    - This function accepts array-like iterables, which are immediately converted
      to the namespace's arrays before result type calculation. Consequently, the
      result dtype may be different when an argument is `1.` vs `[1.]`.

    Typically, this function will be called shortly after `array_namespace`
    on a subset of the arguments passed to `array_namespace`.
    """
    # prevent double conversion of iterable to array
    # avoid `np.iterable` for torch arrays due to pytorch/pytorch#143334
    # don't use `array_api_compat.is_array_api_obj` as it returns True for NumPy scalars
    args = [(_asarray(arg, subok=True, xp=xp) if is_torch_array(arg) or np.iterable(arg)
             else arg) for arg in args]
    args_not_none = [arg for arg in args if arg is not None]
    if force_floating:
        # appending a Python float makes the promoted result floating point
        args_not_none.append(1.0)

    if is_numpy(xp) and xp.__version__ < '2.0':
        # Follow NEP 50 promotion rules anyway
        # (size-1 arrays are replaced by their dtype so they don't
        # participate in value-based casting)
        args_not_none = [arg.dtype if getattr(arg, 'size', 0) == 1 else arg
                         for arg in args_not_none]
        return xp.result_type(*args_not_none)

    try:  # follow library's preferred promotion rules
        return xp.result_type(*args_not_none)
    except TypeError:  # mixed type promotion isn't defined
        if not force_floating:
            raise
        # use `result_type` of default floating point type and any floats present
        # This can be revisited, but right now, the only backends that get here
        # are array-api-strict (which is not for production use) and PyTorch
        # (due to data-apis/array-api-compat#279).
        float_args = []
        for arg in args_not_none:
            # wrap scalars so `.dtype` is available below
            arg_array = xp.asarray(arg) if np.isscalar(arg) else arg
            dtype = getattr(arg_array, 'dtype', arg)
            if xp.isdtype(dtype, ('real floating', 'complex floating')):
                float_args.append(arg)
        return xp.result_type(*float_args, xp_default_dtype(xp))
|
|
|
|
|
|
def xp_promote(*args, broadcast=False, force_floating=False, xp):
    """
    Promotes elements of *args to result dtype, ignoring `None`s.
    Includes options for forcing promotion to floating point and
    broadcasting the arrays, again ignoring `None`s.
    Type promotion rules follow `xp_result_type` instead of `xp.result_type`.

    Typically, this function will be called shortly after `array_namespace`
    on a subset of the arguments passed to `array_namespace`.

    This function accepts array-like iterables, which are immediately converted
    to the namespace's arrays before result type calculation. Consequently, the
    result dtype may be different when an argument is `1.` vs `[1.]`.

    See Also
    --------
    xp_result_type
    """
    if not args:
        return args

    # prevent double conversion of iterable to array
    # avoid `np.iterable` for torch arrays due to pytorch/pytorch#143334
    # don't use `array_api_compat.is_array_api_obj` as it returns True for NumPy scalars
    args = [(_asarray(arg, subok=True, xp=xp) if is_torch_array(arg) or np.iterable(arg)
             else arg) for arg in args]

    dtype = xp_result_type(*args, force_floating=force_floating, xp=xp)

    args = [(_asarray(arg, dtype=dtype, subok=True, xp=xp) if arg is not None else arg)
            for arg in args]

    if not broadcast:
        # single argument unwrapped for caller convenience
        return args[0] if len(args)==1 else tuple(args)

    args_not_none = [arg for arg in args if arg is not None]

    # determine result shape
    shapes = {arg.shape for arg in args_not_none}
    try:
        shape = (np.broadcast_shapes(*shapes) if len(shapes) != 1
                 else args_not_none[0].shape)
    except ValueError as e:
        message = "Array shapes are incompatible for broadcasting."
        raise ValueError(message) from e

    out = []
    for arg in args:
        if arg is None:
            out.append(arg)
            continue

        # broadcast only if needed
        # Even if two arguments need broadcasting, this is faster than
        # `broadcast_arrays`, especially since we've already determined `shape`
        if arg.shape != shape:
            kwargs = {'subok': True} if is_numpy(xp) else {}
            arg = xp.broadcast_to(arg, shape, **kwargs)

        # This is much faster than xp.astype(arg, dtype, copy=False)
        if arg.dtype != dtype:
            arg = xp.astype(arg, dtype)

        out.append(arg)

    return out[0] if len(out)==1 else tuple(out)
|
|
|
|
|
|
def xp_float_to_complex(arr: Array, xp: ModuleType | None = None) -> Array:
    """Upcast real floating arrays to the matching complex dtype.

    float32 becomes complex64; float64 (and non-standard real floating
    dtypes) become complex128. Arrays of any other dtype pass through
    unchanged.
    """
    if xp is None:
        xp = array_namespace(arr)
    dtype = arr.dtype
    if xp.isdtype(dtype, xp.float32):
        return xp.astype(arr, xp.complex64)
    if xp.isdtype(dtype, 'real floating'):
        return xp.astype(arr, xp.complex128)
    return arr
|
|
|
|
|
|
def xp_default_dtype(xp):
    """Query the namespace-dependent default floating-point dtype."""
    # historically, we allow pytorch to keep its default of float32;
    # every other backend defaults to float64
    return xp.get_default_dtype() if is_torch(xp) else xp.float64
|
|
|
|
|
|
### MArray Helpers ###
|
|
def xp_result_device(*args):
    """Return the device of an array in `args`, for the purpose of
    input-output device propagation.
    If there are multiple devices, return an arbitrary one.
    If there are no arrays, return None (this typically happens only on NumPy).
    """
    # Do not do a duck-type test for the .device attribute, as many backends today
    # don't have it yet. See workarounds in array_api_compat.device().
    for candidate in args:
        if is_array_api_obj(candidate):
            return xp_device(candidate)
    return None
|
|
|
|
|
|
# np.r_ replacement
|
|
def concat_1d(xp: ModuleType | None, *arrays: Iterable[ArrayLike]) -> Array:
    """A replacement for `np.r_` as `xp.concat` does not accept python scalars
    or 0-D arrays.

    Parameters
    ----------
    xp : array_namespace or None
        Namespace to use. If None, it is inferred from `arrays` (previously
        passing None crashed despite the annotation allowing it).
    *arrays : array_like
        Scalars, 0-D, or higher-dimensional inputs; each is promoted to at
        least 1-D before concatenation.

    Returns
    -------
    array
        1-D concatenation of the inputs.
    """
    if xp is None:
        # consistent with the other helpers in this module
        xp = array_namespace(*arrays)
    arys = [xpx.atleast_nd(xp.asarray(a), ndim=1, xp=xp) for a in arrays]
    return xp.concat(arys)
|
|
|
|
|
|
def is_marray(xp):
    """Returns True if `xp` is an MArray namespace; False otherwise."""
    return xp.__name__.find("marray") >= 0
|
|
|
|
|
|
def _length_nonmasked(x, axis, keepdims=False, xp=None):
    """Number of elements of `x` participating in a reduction along `axis`.

    For MArray namespaces, returns the per-slice count of non-masked elements
    (an array, cast to ``x.dtype``); `axis` must then be an int or None.
    Otherwise returns a plain int: the total size when ``axis is None``,
    else the product of the reduced axes' lengths.
    """
    xp = array_namespace(x) if xp is None else xp
    if is_marray(xp):
        # MArray's `count` only supports int/None axis, not tuples
        if np.iterable(axis):
            message = '`axis` must be an integer or None for use with `MArray`.'
            raise NotImplementedError(message)
        return xp.astype(xp.count(x, axis=axis, keepdims=keepdims), x.dtype)
    return (xp_size(x) if axis is None else
            # compact way to deal with axis tuples or ints
            int(np.prod(np.asarray(x.shape)[np.asarray(axis)])))
|
|
|
|
|
|
def _share_masks(*args, xp):
    """Give every MArray argument the union (logical OR) of all input masks.

    For non-MArray namespaces the inputs pass through unchanged. A single
    argument is returned unwrapped; otherwise all arguments are returned.
    """
    if is_marray(xp):
        combined = args[0].mask
        for arg in args[1:]:
            combined = combined | arg.mask
        args = [xp.asarray(arg.data, mask=combined) for arg in args]
    return args[0] if len(args) == 1 else args
|
|
|
|
### End MArray Helpers ###
|
|
|
|
|
|
@dataclasses.dataclass(repr=False)
|
|
class _XPSphinxCapability:
|
|
cpu: bool | None # None if not applicable
|
|
gpu: bool | None
|
|
warnings: list[str] = dataclasses.field(default_factory=list)
|
|
|
|
def _render(self, value):
|
|
if value is None:
|
|
return "n/a"
|
|
if not value:
|
|
return "⛔"
|
|
if self.warnings:
|
|
res = "⚠️ " + '; '.join(self.warnings)
|
|
assert len(res) <= 20, "Warnings too long"
|
|
return res
|
|
return "✅"
|
|
|
|
def __str__(self):
|
|
cpu = self._render(self.cpu)
|
|
gpu = self._render(self.gpu)
|
|
return f"{cpu:20} {gpu:20}"
|
|
|
|
|
|
def _make_sphinx_capabilities(
    # lists of tuples [(module name, reason), ...]
    skip_backends=(), xfail_backends=(),
    # @pytest.mark.skip/xfail_xp_backends kwargs
    cpu_only=False, np_only=False, out_of_scope=False, exceptions=(),
    # xpx.lazy_xp_backends kwargs
    allow_dask_compute=False, jax_jit=True,
    # list of tuples [(module name, reason), ...]
    warnings = (),
    # unused in documentation
    reason=None,
):
    """Build the {module name -> _XPSphinxCapability} mapping rendered into
    the 'Array API Standard Support' docstring table.

    Starts from a fixed per-backend default and then downgrades entries
    according to the skip/xfail lists and the cpu_only/np_only flags.
    Returns ``{"out_of_scope": True}`` when the function is out of scope.
    """
    if out_of_scope:
        return {"out_of_scope": True}

    exceptions = set(exceptions)

    # Default capabilities
    capabilities = {
        "numpy": _XPSphinxCapability(cpu=True, gpu=None),
        "array_api_strict": _XPSphinxCapability(cpu=True, gpu=None),
        "cupy": _XPSphinxCapability(cpu=None, gpu=True),
        "torch": _XPSphinxCapability(cpu=True, gpu=True),
        "jax.numpy": _XPSphinxCapability(cpu=True, gpu=True,
            warnings=[] if jax_jit else ["no JIT"]),
        # Note: Dask+CuPy is currently untested and unsupported
        "dask.array": _XPSphinxCapability(cpu=True, gpu=None,
            warnings=["computes graph"] if allow_dask_compute else []),
    }

    # documentation doesn't display the reason
    for module, _ in list(skip_backends) + list(xfail_backends):
        backend = capabilities[module]
        if backend.cpu is not None:
            backend.cpu = False
        if backend.gpu is not None:
            backend.gpu = False

    for module, backend in capabilities.items():
        # np_only disables everything except numpy and explicit exceptions;
        # cpu_only only disables the GPU column (minus exceptions)
        if np_only and module not in exceptions | {"numpy"}:
            if backend.cpu is not None:
                backend.cpu = False
            if backend.gpu is not None:
                backend.gpu = False
        elif cpu_only and module not in exceptions and backend.gpu is not None:
            backend.gpu = False

    for module, warning in warnings:
        backend = capabilities[module]
        backend.warnings.append(warning)

    return capabilities
|
|
|
|
|
|
def _make_capabilities_note(fun_name, capabilities, extra_note=None):
    """Render the 'Array API Standard Support' note injected into docstrings.

    `capabilities` is the mapping from `_make_sphinx_capabilities`; the
    out-of-scope variant short-circuits to a minimal note.
    """
    if "out_of_scope" in capabilities:
        # It will be better to link to a section of the dev-arrayapi docs
        # that explains what is and isn't in-scope, but such a section
        # doesn't exist yet. Using :ref:`dev-arrayapi` as a placeholder.
        note = f"""
        **Array API Standard Support**

        `{fun_name}` is not in-scope for support of Python Array API Standard compatible
        backends other than NumPy.

        See :ref:`dev-arrayapi` for more information.
        """
        return textwrap.dedent(note)

    # Note: deliberately not documenting array-api-strict
    note = f"""
    **Array API Standard Support**

    `{fun_name}` has experimental support for Python Array API Standard compatible
    backends in addition to NumPy. Please consider testing these features
    by setting an environment variable ``SCIPY_ARRAY_API=1`` and providing
    CuPy, PyTorch, JAX, or Dask arrays as array arguments. The following
    combinations of backend and device (or other capability) are supported.

    ==================== ==================== ====================
    Library              CPU                  GPU
    ==================== ==================== ====================
    NumPy                {capabilities['numpy'] }
    CuPy                 {capabilities['cupy'] }
    PyTorch              {capabilities['torch'] }
    JAX                  {capabilities['jax.numpy'] }
    Dask                 {capabilities['dask.array'] }
    ==================== ==================== ====================

    """ + (extra_note or "") + " See :ref:`dev-arrayapi` for more information."

    return textwrap.dedent(note)
|
|
|
|
|
|
def xp_capabilities(
    *,
    # Alternative capabilities table.
    # Used only for testing this decorator.
    capabilities_table=None,
    # Generate pytest.mark.skip/xfail_xp_backends.
    # See documentation in conftest.py.
    # lists of tuples [(module name, reason), ...]
    skip_backends=(), xfail_backends=(),
    cpu_only=False, np_only=False, reason=None,
    out_of_scope=False, exceptions=(),
    # lists of tuples [(module name, reason), ...]
    warnings=(),
    # xpx.testing.lazy_xp_function kwargs.
    # Refer to array-api-extra documentation.
    allow_dask_compute=False, jax_jit=True,
    # Extra note to inject into the docstring
    extra_note=None,
):
    """Decorator for a function that states its support among various
    Array API compatible backends.

    This decorator has two effects:
    1. It allows tagging tests with ``@make_xp_test_case`` or
       ``make_xp_pytest_param`` (see below) to automatically generate
       SKIP/XFAIL markers and perform additional backend-specific
       testing, such as extra validation for Dask and JAX;
    2. It automatically adds a note to the function's docstring, containing
       a table matching what has been tested.

    See Also
    --------
    make_xp_test_case
    make_xp_pytest_param
    array_api_extra.testing.lazy_xp_function
    """
    capabilities_table = (xp_capabilities_table if capabilities_table is None
                          else capabilities_table)

    if out_of_scope:
        # out-of-scope implies the function is documented as NumPy-only
        np_only = True

    # raw kwargs recorded in the table for the test-mark generators
    capabilities = dict(
        skip_backends=skip_backends,
        xfail_backends=xfail_backends,
        cpu_only=cpu_only,
        np_only=np_only,
        out_of_scope=out_of_scope,
        reason=reason,
        exceptions=exceptions,
        allow_dask_compute=allow_dask_compute,
        jax_jit=jax_jit,
        warnings=warnings,
    )
    sphinx_capabilities = _make_sphinx_capabilities(**capabilities)

    def decorator(f):
        # Don't use a wrapper, as in some cases @xp_capabilities is
        # applied to a ufunc
        capabilities_table[f] = capabilities
        note = _make_capabilities_note(f.__name__, sphinx_capabilities, extra_note)
        doc = FunctionDoc(f)
        doc['Notes'].append(note)
        doc = str(doc).split("\n", 1)[1].lstrip(" \n")  # remove signature
        try:
            f.__doc__ = doc
        except AttributeError:
            # Can't update __doc__ on ufuncs if SciPy
            # was compiled against NumPy < 2.2.
            pass

        return f

    return decorator
|
|
|
|
|
|
def make_xp_test_case(*funcs, capabilities_table=None):
    """Generate pytest decorator for a test function that tests functionality
    of one or more Array API compatible functions.

    Read the parameters of the ``@xp_capabilities`` decorator applied to the
    listed functions and:

    - Generate the ``@pytest.mark.skip_xp_backends`` and
      ``@pytest.mark.xfail_xp_backends`` decorators
      for the decorated test function
    - Tag the function with `xpx.testing.lazy_xp_function`

    Example::

        @make_xp_test_case(f1)
        def test_f1(xp):
            ...

        @make_xp_test_case(f2)
        def test_f2(xp):
            ...

        @make_xp_test_case(f1, f2)
        def test_f1_and_f2(xp):
            ...

    The above is equivalent to::

        @pytest.mark.skip_xp_backends(...)
        @pytest.mark.skip_xp_backends(...)
        @pytest.mark.xfail_xp_backends(...)
        @pytest.mark.xfail_xp_backends(...)
        def test_f1(xp):
            ...

    etc., where the arguments of ``skip_xp_backends`` and ``xfail_xp_backends`` are
    determined by the ``@xp_capabilities`` decorator applied to the functions.

    See Also
    --------
    xp_capabilities
    make_xp_pytest_marks
    make_xp_pytest_param
    array_api_extra.testing.lazy_xp_function
    """
    # BUG FIX: this docstring used to appear *after* the assignment below,
    # making it a no-op string expression instead of the function's __doc__.
    capabilities_table = (xp_capabilities_table if capabilities_table is None
                          else capabilities_table)
    marks = make_xp_pytest_marks(*funcs, capabilities_table=capabilities_table)
    # apply all generated pytest marks, innermost first
    return lambda func: functools.reduce(lambda f, g: g(f), marks, func)
|
|
|
|
|
|
def make_xp_pytest_param(func, *args, capabilities_table=None):
    """Variant of ``make_xp_test_case`` that wraps a single function in a
    ``pytest.param``, carrying all the ``skip_xp_backends`` /
    ``xfail_xp_backends`` marks derived from its ``@xp_capabilities``
    decorator::

        @pytest.mark.parametrize(
            "func", [make_xp_pytest_param(f1), make_xp_pytest_param(f2)]
        )
        def test(func, xp):
            ...

    The above is equivalent to::

        @pytest.mark.parametrize(
            "func", [
                pytest.param(f1, marks=[
                    pytest.mark.skip_xp_backends(...),
                    pytest.mark.xfail_xp_backends(...), ...]),
                pytest.param(f2, marks=[
                    pytest.mark.skip_xp_backends(...),
                    pytest.mark.xfail_xp_backends(...), ...]),
        )
        def test(func, xp):
            ...

    Parameters
    ----------
    func : Callable
        Function to be tested. It must be decorated with ``@xp_capabilities``.
    *args : Any, optional
        Extra pytest parameters for the use case, e.g.::

            @pytest.mark.parametrize("func,verb", [
                make_xp_pytest_param(f1, "hello"),
                make_xp_pytest_param(f2, "world")])
            def test(func, verb, xp):
                # iterates on (func=f1, verb="hello")
                # and (func=f2, verb="world")

    See Also
    --------
    xp_capabilities
    make_xp_test_case
    make_xp_pytest_marks
    array_api_extra.testing.lazy_xp_function
    """
    # Local import: keep pytest out of the runtime (non-test) dependency set.
    import pytest

    xp_marks = make_xp_pytest_marks(func, capabilities_table=capabilities_table)
    # Use the function's name as the param id for readable test output.
    return pytest.param(func, *args, id=func.__name__, marks=xp_marks)
|
|
|
|
|
|
def make_xp_pytest_marks(*funcs, capabilities_table=None):
    """Variant of ``make_xp_test_case`` that returns a flat list of pytest
    marks, suitable for the module-level ``pytestmark`` variable::

        pytestmark = make_xp_pytest_marks(f1, f2)

        def test(xp):
            ...

    In this example, the whole test module is dedicated to testing `f1` or `f2`,
    and the two functions have the same capabilities, so it's unnecessary to
    cherry-pick which test tests which function.
    The above is equivalent to::

        pytestmark = [
            pytest.mark.skip_xp_backends(...),
            pytest.mark.xfail_xp_backends(...), ...]),
        ]

        def test(xp):
            ...

    See Also
    --------
    xp_capabilities
    make_xp_test_case
    make_xp_pytest_param
    array_api_extra.testing.lazy_xp_function
    """
    if capabilities_table is None:
        capabilities_table = xp_capabilities_table
    # Local import: keep pytest out of the runtime (non-test) dependency set.
    import pytest

    out = []
    for fn in funcs:
        caps = capabilities_table[fn]
        # Shared kwargs for the cpu_only / np_only skip marks below.
        common = {'exceptions': caps['exceptions'], 'reason': caps['reason']}

        if caps['cpu_only']:
            out.append(pytest.mark.skip_xp_backends(cpu_only=True, **common))
        if caps['np_only']:
            out.append(pytest.mark.skip_xp_backends(np_only=True, **common))

        # Per-backend skips/xfails each carry their own reason.
        out.extend(pytest.mark.skip_xp_backends(mod_name, reason=why)
                   for mod_name, why in caps['skip_backends'])
        out.extend(pytest.mark.xfail_xp_backends(mod_name, reason=why)
                   for mod_name, why in caps['xfail_backends'])

        # Tag the function for lazy-backend (Dask/JAX) testing support.
        lazy_xp_function(fn,
                         allow_dask_compute=caps['allow_dask_compute'],
                         jax_jit=caps['jax_jit'])

    return out
|
|
|
|
|
|
# Global registry mapping each function decorated with ``@xp_capabilities`` to
# its capabilities dict; consumed by ``make_xp_pytest_marks`` and friends as
# the default ``capabilities_table``.
# NOTE(review): mutated (once, at import time) from many modules — is it OK to
# rely on a single shared mutable dict here?
xp_capabilities_table = {} # type: ignore[var-annotated]
|
|
|
|
|
|
def xp_device_type(a: Array) -> Literal["cpu", "cuda", None]:
    """Best-effort device type of array ``a``.

    Returns ``"cpu"`` or ``"cuda"`` for recognized backends, and ``None``
    when the array's library is not special-cased here.
    """
    if is_numpy_array(a):
        return "cpu"
    if is_cupy_array(a):
        return "cuda"
    if is_torch_array(a):
        # TODO this can return other backends e.g. tpu but they're unsupported in scipy
        return a.device.type
    if is_jax_array(a):
        # JAX reports "gpu" where the other libraries say "cuda"; normalize.
        # TODO this can return other backends e.g. tpu but they're unsupported in scipy
        platform = a.device.platform
        return "cuda" if platform == "gpu" else platform
    if is_dask_array(a):
        # Dask wraps another array type; recurse into the meta array.
        return xp_device_type(a._meta)
    # array-api-strict is a stand-in for unknown libraries; don't special-case it
    return None
|