You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

111 lines
2.2 KiB

5 days ago
""" Test functions for linalg module using the matrix class."""
import pytest
import numpy as np
from numpy.linalg.tests.test_linalg import (
CondCases,
DetCases,
EigCases,
EigvalsCases,
InvCases,
LinalgCase,
LinalgTestCase,
LstsqCases,
PinvCases,
SolveCases,
SVDCases,
TestQR as _TestQR,
_TestNorm2D,
_TestNormDoubleBase,
_TestNormInt64Base,
_TestNormSingleBase,
apply_tag,
)
# Matrix-class inputs for the shared linalg test mixins. Every case is tagged
# via ``apply_tag``: the first three as 'square', the last as 'hermitian'.
CASES = apply_tag('square', [
    # Empty 0x0 system viewed as np.matrix; carries the extra 'size-0' tag.
    LinalgCase("0x0_matrix",
               np.empty((0, 0), dtype=np.double).view(np.matrix),
               np.empty((0, 1), dtype=np.double).view(np.matrix),
               tags={'size-0'}),
    # Plain ndarray coefficient paired with a column np.matrix right-hand side.
    LinalgCase("matrix_b_only",
               np.array([[1., 2.], [3., 4.]]),
               np.matrix([2., 1.]).T),
    # Both operands are np.matrix instances.
    LinalgCase("matrix_a_and_b",
               np.matrix([[1., 2.], [3., 4.]]),
               np.matrix([2., 1.]).T),
]) + apply_tag('hermitian', [
    # Symmetric 2x2 np.matrix; no right-hand side is supplied (b is None).
    LinalgCase("hmatrix_a_and_b",
               np.matrix([[1., 2.], [2., 1.]]),
               None),
])
# Matrices get no generalized or strided variants of CASES; the plain list
# is used as-is.
class MatrixTestCase(LinalgTestCase):
    """Common base for the matrix-class tests, feeding them ``CASES``."""

    TEST_CASES = CASES
class TestSolveMatrix(SolveCases, MatrixTestCase):
    """Apply the ``SolveCases`` checks to the matrix cases in ``CASES``."""
class TestInvMatrix(InvCases, MatrixTestCase):
    """Apply the ``InvCases`` checks to the matrix cases in ``CASES``."""
class TestEigvalsMatrix(EigvalsCases, MatrixTestCase):
    """Apply the ``EigvalsCases`` checks to the matrix cases in ``CASES``."""
class TestEigMatrix(EigCases, MatrixTestCase):
    """Apply the ``EigCases`` checks to the matrix cases in ``CASES``."""
class TestSVDMatrix(SVDCases, MatrixTestCase):
    """Apply the ``SVDCases`` checks to the matrix cases in ``CASES``."""
class TestCondMatrix(CondCases, MatrixTestCase):
    """Apply the ``CondCases`` checks to the matrix cases in ``CASES``."""
class TestPinvMatrix(PinvCases, MatrixTestCase):
    """Apply the ``PinvCases`` checks to the matrix cases in ``CASES``."""
class TestDetMatrix(DetCases, MatrixTestCase):
    """Apply the ``DetCases`` checks to the matrix cases in ``CASES``."""
@pytest.mark.thread_unsafe(
    reason="residuals not calculated properly for square tests (gh-29851)",
)
class TestLstsqMatrix(LstsqCases, MatrixTestCase):
    """Apply the ``LstsqCases`` checks to the matrix cases in ``CASES``.

    Carries the ``thread_unsafe`` marker; see its ``reason`` (gh-29851).
    """
class _TestNorm2DMatrix(_TestNorm2D):
    """2-D norm test mixin whose ``array`` hook builds ``np.matrix`` objects."""

    # Overrides the inherited ``array`` attribute so the 2-D norm tests
    # construct their inputs with the matrix class.
    array = np.matrix
class TestNormDoubleMatrix(_TestNorm2DMatrix, _TestNormDoubleBase):
    """Matrix-class 2-D norm tests using the ``_TestNormDoubleBase`` settings."""
class TestNormSingleMatrix(_TestNorm2DMatrix, _TestNormSingleBase):
    """Matrix-class 2-D norm tests using the ``_TestNormSingleBase`` settings."""
class TestNormInt64Matrix(_TestNorm2DMatrix, _TestNormInt64Base):
    """Matrix-class 2-D norm tests using the ``_TestNormInt64Base`` settings."""
class TestQRMatrix(_TestQR):
    """Re-run the shared QR tests with inputs built as ``np.matrix``."""

    # Overrides the inherited ``array`` attribute with the matrix class.
    array = np.matrix