You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
27 lines
753 B
27 lines
753 B
# mypy: allow-untyped-defs
|
|
|
|
|
|
def maybe_view(tensor, size, check_same_size=True):
|
|
if check_same_size and tensor.size() == size:
|
|
return tensor
|
|
return tensor.contiguous().view(size)
|
|
|
|
|
|
def maybe_unexpand(tensor, old_size, check_same_size=True):
|
|
if check_same_size and tensor.size() == old_size:
|
|
return tensor
|
|
num_unsqueezed = tensor.dim() - len(old_size)
|
|
expanded_dims = [
|
|
dim
|
|
for dim, (expanded, original) in enumerate(
|
|
zip(tensor.size()[num_unsqueezed:], old_size)
|
|
)
|
|
if expanded != original
|
|
]
|
|
|
|
for _ in range(num_unsqueezed):
|
|
tensor = tensor.sum(0, keepdim=False)
|
|
for dim in expanded_dims:
|
|
tensor = tensor.sum(dim, keepdim=True)
|
|
return tensor
|