You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

215 lines
6.8 KiB

from __future__ import annotations
from typing import Any, TYPE_CHECKING, Union
if TYPE_CHECKING:
from collections.abc import Sequence
import torch # noqa: TC002
from ._dim_entry import _match_levels, DimEntry, ndim_of_levels
def _wrap_dim(arg: Any, orig_ndim: int, allow_none: bool = True) -> DimEntry:
    """
    Normalize a user-supplied dimension specifier into a DimEntry.

    Args:
        arg: A first-class Dim, a Python int index, or any other value.
        orig_ndim: Dimension count that non-negative integer indices are
            measured against; such indices are shifted by ``-orig_ndim``.
        allow_none: Accepted for interface compatibility. ``None`` yields an
            empty ``DimEntry()`` whether or not this flag is set, because
            unrecognized values fall through to the same empty entry.

    Returns:
        ``DimEntry(arg)`` for a Dim; for an int, ``DimEntry(arg)`` when ``arg``
        is negative and ``DimEntry(arg - orig_ndim)`` otherwise (no range
        check is performed here); an empty ``DimEntry()`` for anything else.
    """
    from . import Dim

    if arg is None and allow_none:
        # Explicit "no dimension" request.
        return DimEntry()
    if isinstance(arg, Dim):
        return DimEntry(arg)
    if not isinstance(arg, int):
        # Unrecognized input (this also covers None when allow_none is False).
        return DimEntry()
    # Negative indices are kept as-is; non-negative ones are re-expressed
    # relative to the right-hand end.
    return DimEntry(arg if arg < 0 else arg - orig_ndim)
def order(
    tensor_or_dim: Union[torch.Tensor, Any], *dims: Union[Any, Sequence[Any]]
) -> torch.Tensor:
    """
    Reorder the dimensions of a tensor or create a tensor from a dimension.

    It allows reordering tensor dimensions using first-class dimensions and
    positional indices. The dimensions named in ``dims`` become new positional
    dimensions. They are inserted where the first remaining positional
    dimension of the tensor is, or at the end when no positional dimension
    remains.

    Args:
        tensor_or_dim: Input tensor with first-class dimensions, or a Dim object
        *dims: Dimensions or sequences of dimensions specifying the new order.
            Each argument may be:

            - a Dim or an int positional index;
            - a DimList, where each member becomes its own positional dim;
            - a sequence of Dims/ints, flattened into a single positional dim
              whose size is the product of its members' sizes (1 for an empty
              sequence).

    Returns:
        Tensor with reordered dimensions

    Raises:
        ValueError: If ``tensor_or_dim`` is neither a Tensor nor a Dim; if an
            argument of ``dims`` is not a Dim, int, DimList or iterable; if a
            sequence holds something other than a Dim or int; or if a
            requested dimension is absent or requested twice.

    Examples:
        >>> import torch
        >>> from functorch.dim import dims
        >>> batch, channel, height, width = dims(4)
        >>> x = torch.randn(2, 3, 4, 5)[batch, channel, height, width]
        >>> # Reorder to [height, width, batch, channel]
        >>> y = order(x, height, width, batch, channel)
    """
    from . import Dim, DimList, Tensor

    # Handle first argument - tensor or dimension
    if isinstance(tensor_or_dim, Tensor):
        # First-class tensor
        orig_levels = tensor_or_dim._levels[:]
        data = tensor_or_dim._tensor
        has_device = tensor_or_dim._has_device
    elif isinstance(tensor_or_dim, Dim):
        # Single dimension - create range tensor
        orig_levels = [DimEntry(tensor_or_dim)]
        data = tensor_or_dim._get_range()
        has_device = False
    else:
        raise ValueError("First argument must be a Tensor or Dim object")

    # Requested levels, in the order they will appear as new positional dims.
    flat_positional_dims: list[DimEntry] = []
    # (start index into flat_positional_dims, group length) per sequence argument.
    to_flatten: list[tuple[int, int]] = []
    # Working copy: each requested level is blanked to DimEntry() once taken,
    # so asking for the same level twice fails the lookup in append_dim.
    levels = orig_levels[:]
    orig_ndim = ndim_of_levels(levels)

    def append_dim(d: DimEntry) -> None:
        """Add a dimension to the reordering, removing it from available levels."""
        try:
            idx = levels.index(d)
        except ValueError:
            idx = None
        if idx is None:
            if d.is_positional():
                raise ValueError(
                    f"tensor has {orig_ndim} positional dimensions, but {d.position() + orig_ndim} specified, "
                    f"or it was specified twice"
                )
            else:
                raise ValueError(
                    f"tensor does not contain dim {d.dim()} or it was specified twice"
                )
        levels[idx] = DimEntry()
        flat_positional_dims.append(d)

    # Count of positional dims the result gains: one per Dim/int argument,
    # one per DimList member, and one per flatten group.
    n_new_positional = 0
    for arg in dims:
        entry = _wrap_dim(arg, orig_ndim, False)
        if not entry.is_none():
            append_dim(entry)
            n_new_positional += 1
        elif isinstance(arg, DimList):
            for dim in arg._dims:
                append_dim(DimEntry(dim))
                n_new_positional += 1
        else:
            # Any other iterable is a group of dims to flatten into one.
            n_new_positional += 1
            if not hasattr(arg, "__iter__"):
                raise ValueError("expected a Dim, List[Dim], or Sequence[Dim]")
            seq = list(arg)
            to_flatten.append((len(flat_positional_dims), len(seq)))
            for item in seq:
                entry = _wrap_dim(item, orig_ndim, False)
                if entry.is_none():
                    raise ValueError("expected a Dim or int")
                append_dim(entry)

    # Build the target level order: the remaining levels, with the requested
    # ones spliced in just before the first remaining positional level.
    insert_point = -1
    new_levels: list[DimEntry] = []
    for level in levels:
        if level.is_none():
            continue
        if level.is_positional() and insert_point == -1:
            insert_point = len(new_levels)
            new_levels.extend(flat_positional_dims)
        new_levels.append(level)
    # No positional level remains: the requested dims go at the end.
    if insert_point == -1:
        insert_point = len(new_levels)
        new_levels.extend(flat_positional_dims)

    # Lay the data out in new_levels order.
    assert data is not None, "Cannot reorder None tensor"
    ndata = _match_levels(data, orig_levels, new_levels)

    if to_flatten:
        view_shape = _flattened_view_shape(ndata.size(), insert_point, to_flatten)
        # The requested block spans len(flat_positional_dims) dims before the
        # reshape and n_new_positional dims after it. Resize new_levels to
        # match:
        #   - drop entries when groups merged several dims;
        #   - insert placeholders when empty groups added size-1 dims.
        # Without the second case, the level list would be shorter than the
        # reshaped data. Every entry in the block is renumbered as positional
        # below, so it does not matter which entries are dropped or what the
        # placeholders hold.
        n_to_remove = len(flat_positional_dims) - n_new_positional
        if n_to_remove > 0:
            new_levels = (
                new_levels[:insert_point] + new_levels[insert_point + n_to_remove :]
            )
        elif n_to_remove < 0:
            new_levels = (
                new_levels[:insert_point]
                + [DimEntry() for _ in range(-n_to_remove)]
                + new_levels[insert_point:]
            )
        ndata = ndata.reshape(view_shape)

    # Renumber positional levels right-to-left as -1, -2, ... . The requested
    # block [insert_point, insert_point + n_new_positional) also becomes
    # positional, even where it currently holds first-class dims.
    seen = 0
    for i in reversed(range(len(new_levels))):
        if new_levels[i].is_positional() or (
            insert_point <= i < insert_point + n_new_positional
        ):
            seen -= 1
            new_levels[i] = DimEntry(seen)

    result = Tensor.from_positional(ndata, new_levels, has_device)
    return result  # type: ignore[return-value]


def _flattened_view_shape(
    sizes: Sequence[int],
    insert_point: int,
    to_flatten: list[tuple[int, int]],
) -> list[int]:
    """
    Compute the reshape target that collapses each flatten group to one dim.

    Args:
        sizes: Shape of the reordered tensor; the dims requested by ``order``
            start at index ``insert_point``.
        insert_point: Index of the first requested dim in ``sizes``.
        to_flatten: ``(start, length)`` pairs. ``start`` is relative to
            ``insert_point``, and the pairs are in increasing,
            non-overlapping order.

    Returns:
        ``sizes`` as a list, with every group replaced by the product of its
        sizes (1 for an empty group); all other dims are kept unchanged.
    """
    view_shape = list(sizes[:insert_point])
    pos = insert_point
    for start, length in to_flatten:
        group_begin = insert_point + start
        # Requested dims between the previous group and this one stay as-is.
        view_shape.extend(sizes[pos:group_begin])
        group_size = 1
        for size in sizes[group_begin : group_begin + length]:
            group_size *= size
        view_shape.append(group_size)
        pos = group_begin + length
    # Remaining requested dims, then every dim after the requested block.
    view_shape.extend(sizes[pos:])
    return view_shape