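"""Indexing (``__getitem__``/``__setitem__``) for tensors with first-class dimensions.

Indices are flattened, ``...`` and dim packs are expanded against the tensor's
levels, and results are rebuilt with ``Tensor.from_positional``. Illustrative
sketch only (assumes the package exposes a ``dims`` factory, as functorch.dim
does):

    # i = dims()            # hypothetical factory; not defined in this module
    # x = torch.randn(3, 4)
    # row = x[i]            # binds i to size 3; result carries first-class dim i
"""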
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
from ._dim_entry import _match_levels, DimEntry
from ._tensor_info import TensorInfo
if TYPE_CHECKING:
from . import Dim
def _safe_index(lst: list, item: Any) -> Optional[int]:
"""
Helper function to find index of item in list.
For DimEntry objects, uses __eq__ comparison which properly handles
both positional and Dim entries.
Returns the index if found, None if not found.
"""
for i, list_item in enumerate(lst):
# Use == for DimEntry objects as they have proper __eq__ implementation
if isinstance(item, DimEntry) and isinstance(list_item, DimEntry):
if list_item == item:
return i
elif list_item is item:
return i
return None
@dataclass
class IndexingInfo:
can_call_original: bool = False
advanced_indexing: bool = False
self_tensor: Optional[torch.Tensor] = None
flat_inputs: list[Any] = field(default_factory=list)
result_levels: list[DimEntry] = field(default_factory=list)
has_device: bool = False
def has_dims(obj: Any) -> bool:
"""
Check if an object has first-class dimensions.
This function checks if the object is either a Dim or a functorch Tensor
that has first-class dimensions, using the proper check_exact methods.
"""
from . import Dim, Tensor
return Dim.check_exact(obj) or Tensor.check_exact(obj)
def _bind_dims_to_size(sz: int, sd: int, dims: list, nsz: list, nsd: list) -> None:
"""
Bind dimensions to size and calculate proper strides for dim packs.
"""
from . import DimensionBindError
rhs_prod = 1
for i, dim in enumerate(dims):
if not dim.is_bound:
# Check for multiple unbound dimensions
for j in range(i + 1, len(dims)):
if not dims[j].is_bound:
raise DimensionBindError(
f"cannot infer the sizes of two dimensions at once {dim!r} and {dims[j]!r}"
)
rhs_prod *= dims[j].size
# Calculate the size for this unbound dimension
if sz % rhs_prod != 0:
tup = tuple(dim.size if dim.is_bound else "?" for dim in dims)
raise DimensionBindError(
f"inferred dimension does not evenly fit into larger dimension: {sz} vs {tup}"
)
inferred_size = sz // rhs_prod
dim.size = inferred_size
rhs_prod = sz
break
else:
rhs_prod *= dim.size
# Final validation that dimensions match
if rhs_prod != sz:
tup = tuple(dims)
raise DimensionBindError(
f"Dimension sizes to do not match ({sz} != {rhs_prod}) when matching dimension pack {tup}"
)
# Calculate new sizes and strides for each dimension in the pack
# First calculate all strides by iterating in reverse
new_strides = [0] * len(dims)
current_stride = sd
for i in reversed(range(len(dims))):
new_strides[i] = current_stride
current_stride *= dims[i].size
# Then append sizes and strides in forward order
for i in range(len(dims)):
nsz.append(dims[i].size)
nsd.append(new_strides[i])
def slice_to_tuple(flat_inputs: list) -> tuple:
return tuple(flat_inputs)
def extractIndices(index: Any, indices: list) -> bool:
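    """Flatten ``index`` into ``indices``.

    Returns True if ``index`` was a sequence of multiple indices, False if it
    should be treated as a single index.
    """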
if isinstance(index, tuple): # mpy::tuple_view::check
indices.extend(index)
return True
elif isinstance(index, torch.Tensor): # THPVariable_Check
indices.append(index)
return False
elif not hasattr(index, "__iter__") or isinstance(
index, (str, bytes)
): # !mpy::is_sequence
indices.append(index)
return False
# Handle sequence case (list)
if isinstance(index, list):
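        # Very long lists are always treated as a sequence of indices
        # (mirrors NumPy's heuristic for sequence-as-tuple indexing).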
if len(index) >= 32:
indices.extend(index)
return True
# Check each item in the sequence
for item in index:
if (
isinstance(item, (torch.Tensor, slice))
or hasattr(item, "__iter__")
or item is ...
or item is None
or has_dims(item)
):
indices.extend(index)
return True
# If we got here, treat as single index
indices.append(index)
return False
# Default case
indices.append(index)
return False
def getitem(cls: Any, func: Any, types: Any, args: Any, kwargs: Any) -> Any:
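    # Follows the __torch_function__ handler signature; func, types, and
    # kwargs are required by that protocol but unused here.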
self = args[0]
index = args[1]
iinfo = getsetitem(self, index, has_dims(self))
if iinfo.can_call_original:
# Call original tensor __getitem__ directly, bypassing __torch_function__
return torch.Tensor.__getitem__(self, index)
return invoke_getitem(iinfo)
def setitem(self: Any, index: Any, rhs: Any) -> None:
"""Set values in tensor using first-class dimensions."""
    from . import DimensionBindError  # TensorInfo is already imported at module level
iinfo = getsetitem(self, index, has_dims(self) or has_dims(rhs))
if iinfo.can_call_original:
# Call original tensor __setitem__ directly, bypassing __torch_function__
torch._C.TensorBase.__setitem__(self, index, rhs)
return
# Handle RHS tensor with dimensions
rhs_info = TensorInfo.create(rhs, False, False)
if rhs_info:
# Check that rhs dimensions are compatible with result dimensions
for l in rhs_info.levels:
if not l.is_positional():
# Find this dimension in result levels
found = False
for result_level in iinfo.result_levels:
if (
not result_level.is_positional()
and result_level.dim() is l.dim()
):
found = True
break
if not found:
# Create tuple representation of result levels for error message
result_dims: list[Union[int, Dim]] = []
for rl in iinfo.result_levels:
if rl.is_positional():
result_dims.append(rl.position())
else:
result_dims.append(rl.dim())
raise DimensionBindError(
f"rhs of setitem contains dimension {l.dim()!r} which is not in the dimension on the left "
f"({tuple(result_dims)!r})"
)
# Match RHS tensor to result levels
assert rhs_info.tensor is not None, "Cannot match levels on None tensor"
matched_rhs = _match_levels(
rhs_info.tensor, rhs_info.levels, iinfo.result_levels
)
else:
matched_rhs = rhs
# For advanced indexing with dimensions, we need special handling
if iinfo.advanced_indexing:
# Use advanced indexing - the flat_inputs already contain matched tensors
tup = slice_to_tuple(iinfo.flat_inputs)
if iinfo.self_tensor is None:
raise RuntimeError("Cannot setitem on None tensor")
torch._C.TensorBase.__setitem__(iinfo.self_tensor, tup, matched_rhs)
else:
# Simple copy operation
if iinfo.self_tensor is None:
raise RuntimeError("Cannot copy to None tensor")
iinfo.self_tensor.copy_(matched_rhs)
def invoke_getitem(iinfo: IndexingInfo) -> Any:
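    """Execute the planned indexing and wrap the result with its dim levels."""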
if iinfo.advanced_indexing:
self_tensor = iinfo.self_tensor
tup = slice_to_tuple(iinfo.flat_inputs)
if self_tensor is None:
raise RuntimeError("Cannot getitem on None tensor")
rtensor = self_tensor[tup]
else:
rtensor = iinfo.self_tensor # type: ignore[assignment]
if rtensor is None:
raise RuntimeError("Cannot getitem on None tensor")
# rtensor is now guaranteed to be not None
# Create a Tensor with the proper dimensions using the class method
from . import Tensor
return Tensor.from_positional(rtensor, iinfo.result_levels, iinfo.has_device)
def getsetitem(self: Any, index: Any, tensors_have_dims: bool) -> IndexingInfo:
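    """Analyze an indexing expression and plan how to execute it.

    Returns an IndexingInfo that either defers to the original Tensor
    indexing or describes the flattened, dim-aware indexing to perform.
    """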
    from . import DimList  # needed at runtime for the isinstance check below
can_call_original_getitem = not tensors_have_dims
input_list = []
if has_dims(index):
input_list.append(index)
else:
is_sequence = extractIndices(index, input_list)
    # Nothing involving first-class dims here; fall back to the regular getitem.
if can_call_original_getitem and not is_sequence:
return IndexingInfo(can_call_original=True)
# Calculate how many dimensions have been indexed in order to compute the
# size of ... or expand a potentially unbound dimension list.
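    # e.g. for x[d, ...] on a 3-d tensor, '...' must expand to two ':' slices.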
dims_indexed = 0
expanding_object = -1
unbound_dim_list = None
dimlists = [] # Track DimList positions for later processing
def check_expanding(i: int) -> None:
nonlocal expanding_object
if expanding_object != -1:
from . import DimensionBindError
raise DimensionBindError(
f"at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets "
f"{expanding_object} and {i}"
)
expanding_object = i
def is_dimpack(s: Any) -> bool:
from . import Dim
return (
isinstance(s, (tuple, list))
and len(s) > 0
and all(Dim.check_exact(item) for item in s)
)
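    # A dim pack like x[(a, b)] splits one tensor dimension across the dims a and b.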
has_dimpacks_or_none = False
for i, s in enumerate(input_list):
if has_dims(s):
can_call_original_getitem = False
dims_indexed += 1
elif s is ...:
check_expanding(i)
elif isinstance(s, DimList):
can_call_original_getitem = False
if not s.is_bound:
check_expanding(i)
unbound_dim_list = s
else:
dims_indexed += len(s._dims)
dimlists.append(i)
elif s is None:
has_dimpacks_or_none = True
elif is_dimpack(s):
can_call_original_getitem = False
has_dimpacks_or_none = True
dims_indexed += 1
else:
dims_indexed += 1
# Early return if we can use original getitem
if can_call_original_getitem:
return IndexingInfo(can_call_original=True)
self_info = TensorInfo.create(self, False, True)
total_dims = len(self_info.levels) # Total dimensions (positional + named)
if dims_indexed > total_dims:
raise ValueError(
f"at least {dims_indexed} indices were supplied but the tensor only has {total_dims} dimensions"
)
# Expand any unbound dimension list, or expand ... into individual : slices.
expanding_dims = total_dims - dims_indexed
if expanding_object != -1:
if unbound_dim_list is not None:
# Bind unbound dimension list to the expanding dimensions
unbound_dim_list.bind_len(expanding_dims)
else:
# Expand ... into slice(None) objects
no_slices = [slice(None)] * expanding_dims
input_list = (
input_list[:expanding_object]
+ no_slices
+ input_list[expanding_object + 1 :]
)
# Flatten out any dimensions stored in dimlist elements directly into the inputs
# Process in reverse order to maintain indices
for i in range(len(dimlists) - 1, -1, -1):
idx = dimlists[i]
        # '...' expansion added entries to input_list, so shift the recorded
        # index to locate where the DimList now sits.
if (
unbound_dim_list is None
and expanding_object != -1
and idx > expanding_object
):
idx += expanding_dims
dl = input_list[idx]
        # Splice the DimList's dims (private _dims) directly into the input list.
input_list = input_list[:idx] + dl._dims + input_list[idx + 1 :]
return getsetitem_flat(self_info, input_list, [], [], has_dimpacks_or_none)
def getsetitem_flat(
self_info: TensorInfo,
input_list: list,
keys: list[DimEntry],
values: list,
has_dimpacks_or_none: bool,
) -> IndexingInfo:
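    """
    Core indexing pass: walk self's levels, consume index entries, bind dims,
    and build the flat index list plus the result levels.

    ``keys``/``values`` let callers pre-route an index expression to a specific
    level; both are empty for a plain ``t[index]`` call.
    """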
from . import Dim
# Track dimension usage
seen_dims: list[Any] = []
seen_dims_nuses: list[int] = []
def add_dim(dim: Any) -> None:
# Use safe indexing to avoid triggering __torch_function__ on Dim objects
idx = _safe_index(seen_dims, dim)
if idx is not None:
seen_dims_nuses[idx] += 1
else:
seen_dims.append(dim)
seen_dims_nuses.append(1)
flat_inputs = []
tensor_inputs: list[Any] = []
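    # Records the first tensor index that carries a device; other tensor
    # indices are later moved onto that device before matching levels.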
device_holding_tensor = None
def append_flat_handle(handle: Any) -> None:
flat_inputs.append(handle)
tensor_inputs.append(None)
def append_tensor_input(ti: TensorInfo) -> None:
flat_inputs.append(None)
tensor_inputs.append(ti)
nonlocal device_holding_tensor
if ti.has_device and device_holding_tensor is None:
device_holding_tensor = ti.tensor
nsz = []
nsd = []
if self_info.tensor is None:
raise RuntimeError("Cannot get size/stride on None tensor")
sz = self_info.tensor.size()
sd = self_info.tensor.stride()
def append_size(i: int) -> None:
if has_dimpacks_or_none:
nsz.append(sz[i])
nsd.append(sd[i])
input_it = input_list[:]
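    # A None in the index list introduces a new dimension of size 1 (stride 0).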
def parse_nones() -> None:
nonlocal input_it
while input_it and input_it[0] is None:
append_flat_handle(slice(None))
nsz.append(1)
nsd.append(0)
input_it = input_it[1:]
def append_item(i: int, arg: Any) -> None:
if Dim.check_exact(arg):
d = arg
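            # An unbound dim (size still -1) binds to the size of this tensor dimension.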
if d._size == -1:
d.size = sz[i]
add_dim(d)
append_size(i)
append_flat_handle(arg)
return
info = TensorInfo.create(arg, False, False)
if info:
append_size(i)
append_tensor_input(info)
for level in info.levels:
if not level.is_positional():
add_dim(level.dim())
return
if has_dimpacks_or_none:
if isinstance(arg, (tuple, list)) and all(Dim.check_exact(d) for d in arg):
# dim pack
dim_pack = list(arg)
for d in dim_pack:
add_dim(d)
append_flat_handle(d)
_bind_dims_to_size(sz[i], sd[i], dim_pack, nsz, nsd)
return
append_size(i)
append_flat_handle(arg)
# Match indexing expressions with tensor dimensions
for i, level in enumerate(self_info.levels):
# Use safe indexing to avoid triggering __torch_function__ on DimEntry comparisons
idx = _safe_index(keys, level)
if idx is not None:
append_item(i, values[idx])
else:
if level.is_positional():
parse_nones()
if not input_it:
append_flat_handle(slice(None))
append_size(i)
else:
arg = input_it[0]
input_it = input_it[1:]
append_item(i, arg)
else:
add_dim(level.dim())
append_flat_handle(level.dim())
append_size(i)
parse_nones()
# Restride tensor if needed
if has_dimpacks_or_none and nsz:
if self_info.tensor is None:
raise RuntimeError("Cannot restride None tensor")
self_tensor = self_info.tensor.as_strided(
nsz, nsd, self_info.tensor.storage_offset()
)
else:
self_tensor = self_info.tensor
# Determine result shape and indexing requirements
result_levels: list[Any] = []
index_levels = []
tensor_insert_point = -1
requires_getindex = False
    def mark_tensor_index() -> None:
        # Track where tensor-indexed dims land in the result. If tensor indices
        # are not adjacent, insert at the front (0), as NumPy-style advanced
        # indexing does.
        nonlocal tensor_insert_point
        if tensor_insert_point == -1:
            tensor_insert_point = len(result_levels)
        elif tensor_insert_point != len(result_levels):
            tensor_insert_point = 0
for i, inp in enumerate(flat_inputs):
if tensor_inputs[i] is not None:
requires_getindex = True
mark_tensor_index()
for level in tensor_inputs[i].levels:
if level not in index_levels:
index_levels.append(level)
elif Dim.check_exact(inp):
d = inp
# Use safe indexing to avoid triggering __torch_function__
dim_idx = _safe_index(seen_dims, d)
assert dim_idx is not None, f"Dim {d} not found in seen_dims"
if seen_dims_nuses[dim_idx] == 1:
flat_inputs[i] = slice(None)
result_levels.append(DimEntry(d))
else:
requires_getindex = True
flat_inputs[i] = None
tensor_inputs[i] = TensorInfo(
d._get_range(), [DimEntry(d)], False, None
)
if DimEntry(d) not in index_levels:
index_levels.append(DimEntry(d))
mark_tensor_index()
else:
if inp != slice(None):
requires_getindex = True
if not isinstance(inp, int):
result_levels.append(DimEntry(-1))
    # Insert indexing dimensions at the first tensor use point; iterating in
    # reverse preserves their original order.
if tensor_insert_point != -1:
for level in reversed(index_levels):
result_levels.insert(tensor_insert_point, level)
# Match tensors to indexing shape
if requires_getindex:
for i in range(len(flat_inputs)):
if tensor_inputs[i] is not None:
t = tensor_inputs[i].tensor
assert t is not None, "TensorInfo should have valid tensor data"
if (
not tensor_inputs[i].has_device
and device_holding_tensor is not None
):
t = t.to(device_holding_tensor.device)
flat_inputs[i] = _match_levels(t, tensor_inputs[i].levels, index_levels)
    # Renumber positional dimensions from the right (-1 is the last dimension).
seen_positionals = 0
for i in reversed(range(len(result_levels))):
if result_levels[i].is_positional():
seen_positionals += 1
result_levels[i] = DimEntry(-seen_positionals)
return IndexingInfo(
can_call_original=False,
advanced_indexing=requires_getindex,
self_tensor=self_tensor,
flat_inputs=flat_inputs,
result_levels=result_levels,
has_device=self_info.has_device,
)