You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
128 lines
3.5 KiB
128 lines
3.5 KiB
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Union
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Sequence
|
|
|
|
from . import Dim
|
|
|
|
import torch # noqa: TC002
|
|
|
|
|
|
# NB: The old code represented positional dimensions as negative numbers (counted from the rhs), so we
|
|
# follow this convention even though it shouldn't be necessary now
|
|
class DimEntry:
    """One level of a tensor's layout: a positional dim, a first-class Dim, or "none".

    ``data`` holds exactly one of:

    * a negative ``int`` -- a positional dimension, counted from the rhs;
    * ``0`` -- the "none" entry (what ``DimEntry()`` produces);
    * a ``Dim`` object -- a first-class dimension (FCD).
    """

    # The dimension this entry refers to: a negative int (positional, counted
    # from the rhs), 0 for "none", or a first-class Dim (FCD).
    data: Union[Dim, int]

    def __init__(self, data: Union[Dim, int, None] = None) -> None:
        """Wrap ``data``; passing ``None`` (the default) yields the "none" entry (0).

        Raises:
            AssertionError: if ``data`` is an int that is not negative, or is
                neither an int, ``None``, nor a ``Dim`` instance.
        """
        from . import Dim

        if type(data) is int:
            # Only negative ints are valid positional levels; 0 is reserved
            # for "none" and must be requested via None.
            assert data < 0
        elif data is None:
            data = 0
        else:
            assert isinstance(data, Dim)
        self.data = data

    def __eq__(self, other: object) -> bool:
        """Return True if both entries refer to the same level.

        Int payloads (positional levels and the "none" entry) compare by
        value. Dim payloads compare by identity (``is``) so that Dim's own
        ``__eq__`` / ``__torch_function__`` is never triggered. An int never
        equals a Dim. Non-DimEntry operands compare unequal (False).
        """
        if not isinstance(other, DimEntry):
            return False
        mine, theirs = self.data, other.data
        if type(mine) is int and type(theirs) is int:
            # Compare by value; do not rely on small-int identity for 0.
            return mine == theirs
        if type(mine) is int or type(theirs) is int:
            # One is an int level, the other a Dim: they can't be equal.
            return False
        # Both are Dim objects - use 'is' to avoid Dim.__eq__.
        return mine is theirs

    def __hash__(self) -> int:
        """Hash consistently with ``__eq__``: by value for ints, by identity for Dims."""
        data = self.data
        # id() matches the identity-based equality used for Dims and avoids
        # depending on whatever Dim.__hash__ does.
        return hash(data) if type(data) is int else id(data)

    def is_positional(self) -> bool:
        """Return True if this entry is a positional level (a negative int)."""
        return type(self.data) is int and self.data < 0

    def is_none(self) -> bool:
        """Return True if this is the "none" entry (``data == 0``)."""
        # Use isinstance to check for Dim objects, avoid triggering __torch_function__
        from . import Dim

        if isinstance(self.data, Dim):
            # A Dim object can't be "none" (which is represented by 0).
            return False
        return self.data == 0

    def position(self) -> int:
        """Return the raw int payload (negative for positional levels, 0 for "none")."""
        assert isinstance(self.data, int)
        return self.data

    def dim(self) -> Dim:
        """Return the Dim payload; asserts this entry is not an int level."""
        assert not isinstance(self.data, int)
        return self.data

    def __repr__(self) -> str:
        return repr(self.data)
|
|
|
|
|
|
def ndim_of_levels(levels: Sequence[DimEntry]) -> int:
    """Return how many entries of ``levels`` are positional levels.

    First-class Dim entries and the "none" entry are not counted; only
    entries whose ``is_positional()`` is true contribute.
    """
    return sum(1 for entry in levels if entry.is_positional())
|
|
|
|
|
|
def _match_levels(
    tensor: torch.Tensor,
    from_levels: list[DimEntry],
    to_levels: list[DimEntry],
    drop_levels: bool = False,
) -> torch.Tensor:
    """
    Re-view a tensor so its dimensions line up with target levels, using as_strided.

    The i-th dimension of ``tensor`` corresponds to ``from_levels[i]``. The
    result has exactly one dimension per entry of ``to_levels``:

    * if the level also appears in ``from_levels``, the matching input
      dimension's size and stride are reused (so levels can be permuted);
    * otherwise a broadcast dimension with stride 0 is inserted, of size 1
      for a positional level or ``level.dim().size`` for a first-class Dim.

    No data is copied: the result is an ``as_strided`` view sharing the input's
    storage and storage offset. If the level lists are already equal, ``tensor``
    itself is returned.

    Args:
        tensor: Input tensor whose dims are described by ``from_levels``.
        from_levels: Current levels of the tensor.
        to_levels: Target levels to match.
        drop_levels: If False (default), ``from_levels`` must not be longer
            than ``to_levels``. If True, that length check is skipped, so
            levels present in ``from_levels`` but absent from ``to_levels``
            are omitted from the result. Because the storage offset is kept,
            the view reads index 0 along any omitted dimension. Note that the
            check is on length only: an unmatched source level is also omitted
            when ``drop_levels`` is False, as long as the length check passes.

    Returns:
        A strided view of ``tensor`` laid out according to ``to_levels``.

    Raises:
        AssertionError: if ``drop_levels`` is False and ``from_levels`` is
            longer than ``to_levels``.
    """
    if from_levels == to_levels:
        return tensor

    if not drop_levels:
        assert len(from_levels) <= len(to_levels), (
            "Cannot expand dimensions without drop_levels"
        )

    sizes = tensor.size()
    strides = tensor.stride()

    new_sizes = []
    new_strides = []
    for level in to_levels:
        try:
            idx = from_levels.index(level)
        except ValueError:
            # Level absent from the input: insert a broadcast (stride-0) dim.
            new_sizes.append(1 if level.is_positional() else level.dim().size)
            new_strides.append(0)
        else:
            new_sizes.append(sizes[idx])
            new_strides.append(strides[idx])

    return tensor.as_strided(new_sizes, new_strides, tensor.storage_offset())
|