You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

108 lines
3.3 KiB

"""
Matrix square root for general matrices and for upper triangular matrices.
This module exists to avoid cyclic imports.
"""
__all__ = []
import numpy as np
# Local imports
from .lapack import ztrsyl, dtrsyl
class SqrtmError(np.linalg.LinAlgError):
    """Linear-algebra error raised when a matrix square root computation fails."""
from ._matfuncs_sqrtm_triu import within_block_loop # noqa: E402
def _sqrtm_triu(T, blocksize=64):
    """
    Compute the matrix square root of an upper triangular matrix.

    Helper used by `sqrtm` and `logm`.

    Parameters
    ----------
    T : (N, N) array_like upper triangular
        Matrix whose square root is wanted.
    blocksize : int, optional
        Target block size of the blocked algorithm; the matrix is split
        into ``max(N // blocksize, 1)`` diagonal blocks. (Default: 64)

    Returns
    -------
    sqrtm : (N, N) ndarray
        Value of the sqrt function at `T`. The result is float64 when `T`
        is a real object whose diagonal has no negative entry, and
        complex128 otherwise.

    Raises
    ------
    SqrtmError
        If the within-block computation raises a RuntimeError.

    References
    ----------
    .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
           "Blocked Schur Algorithms for Computing the Matrix Square Root",
           Lecture Notes in Computer Science, 7782. pp. 171-182.

    """
    diagonal = np.diag(T)
    is_real = np.isrealobj(T) and np.min(diagonal, initial=0.) >= 0

    # Always work in double precision; switch to complex when the input is
    # complex or a negative diagonal entry would need a complex root.
    work_dtype = np.float64 if is_real else np.complex128
    T = np.asarray(T, dtype=work_dtype, order="C")
    diagonal = np.asarray(diagonal, dtype=work_dtype)

    # The diagonal of the result is the elementwise root of the diagonal.
    R = np.diag(np.sqrt(diagonal))

    # Number of diagonal blocks: never fewer than one.
    _, n = T.shape
    nblocks = max(n // blocksize, 1)

    # Spread n rows over nblocks blocks: the first `n_small` blocks have
    # `small` rows, the remaining `n_large` blocks have one row more.
    small, n_large = divmod(n, nblocks)
    n_small = nblocks - n_large
    if n_small * small + n_large * (small + 1) != n:
        raise Exception('internal inconsistency')
    block_sizes = [small] * n_small + [small + 1] * n_large

    # Half-open (start, stop) index range of every block, in order.
    start_stop_pairs = []
    offset = 0
    for width in block_sizes:
        start_stop_pairs.append((offset, offset + width))
        offset += width

    # Fill the diagonal blocks of R (compiled helper).
    try:
        within_block_loop(R, T, start_stop_pairs, nblocks)
    except RuntimeError as err:
        raise SqrtmError(*err.args) from err

    # Fill the off-diagonal blocks column by column, walking each block
    # column upwards from the block just above the diagonal.
    sylvester_solver = dtrsyl if is_real else ztrsyl
    for j, (col_lo, col_hi) in enumerate(start_stop_pairs):
        for i in reversed(range(j)):
            row_lo, row_hi = start_stop_pairs[i]
            rhs = T[row_lo:row_hi, col_lo:col_hi]
            if i + 1 < j:
                # Remove the contribution of the blocks lying strictly
                # between block row i and block column j.
                rhs = rhs - np.dot(R[row_lo:row_hi, row_hi:col_lo],
                                   R[row_hi:col_lo, col_lo:col_hi])

            # Solve for the block with the LAPACK triangular Sylvester
            # routine; see the solve_sylvester implementation and the
            # Fortran dtrsyl / ztrsyl documentation for details.
            # NOTE(review): LAPACK documents ?trsyl as solving
            # A*X + X*B = scale*C -- confirm that multiplying the result by
            # `scale` (as done here) is the intended rescaling.
            diag_i = R[row_lo:row_hi, row_lo:row_hi]
            diag_j = R[col_lo:col_hi, col_lo:col_hi]
            solution, scale, _info = sylvester_solver(diag_i, diag_j, rhs)
            R[row_lo:row_hi, col_lo:col_hi] = solution * scale

    return R