You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

122 lines
3.5 KiB

import contextlib
import shutil
import tempfile
from pathlib import Path
import numpy
import pytest
from thinc.api import ArgsKwargs, Linear, Padded, Ragged
from thinc.util import has_cupy, is_cupy_array, is_numpy_array
@contextlib.contextmanager
def make_tempdir():
d = Path(tempfile.mkdtemp())
yield d
shutil.rmtree(str(d))
def get_model(W_b_input, cls=Linear):
W, b, input_ = W_b_input
nr_out, nr_in = W.shape
model = cls(nr_out, nr_in)
model.set_param("W", W)
model.set_param("b", b)
model.initialize()
return model
def get_shape(W_b_input):
W, b, input_ = W_b_input
return input_.shape[0], W.shape[0], W.shape[1]
def get_data_checker(inputs):
if isinstance(inputs, Ragged):
return assert_raggeds_match
elif isinstance(inputs, Padded):
return assert_paddeds_match
elif isinstance(inputs, list):
return assert_lists_match
elif isinstance(inputs, tuple) and len(inputs) == 4:
return assert_padded_data_match
elif isinstance(inputs, tuple) and len(inputs) == 2:
return assert_ragged_data_match
else:
return assert_arrays_match
def assert_arrays_match(X, Y):
assert X.dtype == Y.dtype
# Transformations are allowed to change last dimension, but not batch size.
assert X.shape[0] == Y.shape[0]
return True
def assert_lists_match(X, Y):
assert isinstance(X, list)
assert isinstance(Y, list)
assert len(X) == len(Y)
for x, y in zip(X, Y):
assert_arrays_match(x, y)
return True
def assert_raggeds_match(X, Y):
assert isinstance(X, Ragged)
assert isinstance(Y, Ragged)
assert_arrays_match(X.lengths, Y.lengths)
assert_arrays_match(X.data, Y.data)
return True
def assert_paddeds_match(X, Y):
assert isinstance(X, Padded)
assert isinstance(Y, Padded)
assert_arrays_match(X.size_at_t, Y.size_at_t)
assert assert_arrays_match(X.lengths, Y.lengths)
assert assert_arrays_match(X.indices, Y.indices)
assert X.data.dtype == Y.data.dtype
assert X.data.shape[1] == Y.data.shape[1]
assert X.data.shape[0] == Y.data.shape[0]
return True
def assert_padded_data_match(X, Y):
return assert_paddeds_match(Padded(*X), Padded(*Y))
def assert_ragged_data_match(X, Y):
return assert_raggeds_match(Ragged(*X), Ragged(*Y))
def check_input_converters(Y, backprop, data, n_args, kwargs_keys, type_):
assert isinstance(Y, ArgsKwargs)
assert len(Y.args) == n_args
assert list(Y.kwargs.keys()) == kwargs_keys
assert all(isinstance(arg, type_) for arg in Y.args)
assert all(isinstance(arg, type_) for arg in Y.kwargs.values())
dX = backprop(Y)
def is_supported_backend_array(arr):
return is_cupy_array(arr) or is_numpy_array(arr)
input_type = type(data) if not isinstance(data, list) else tuple
assert isinstance(dX, input_type) or is_supported_backend_array(dX)
if isinstance(data, dict):
assert list(dX.keys()) == kwargs_keys
assert all(is_supported_backend_array(arr) for arr in dX.values())
elif isinstance(data, (list, tuple)):
assert isinstance(dX, tuple)
assert all(is_supported_backend_array(arr) for arr in dX)
elif isinstance(data, ArgsKwargs):
assert len(dX.args) == n_args
assert list(dX.kwargs.keys()) == kwargs_keys
assert all(is_supported_backend_array(arg) for arg in dX.args)
assert all(is_supported_backend_array(arg) for arg in dX.kwargs.values())
elif not isinstance(data, numpy.ndarray):
pytest.fail(f"Bad data type: {dX}")