from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Optional, TYPE_CHECKING, Union

import torch

from ._dim_entry import _match_levels, DimEntry
from ._tensor_info import TensorInfo


if TYPE_CHECKING:
    from . import Dim


def _safe_index(lst: list, item: Any) -> Optional[int]:
    """
    Find the index of ``item`` in ``lst``.

    DimEntry objects are compared with ``==``, which properly handles both
    positional and Dim entries; all other objects are compared by identity.

    Returns the index if found, None if not found.
    """
    for i, list_item in enumerate(lst):
        # Use == for DimEntry objects as they have a proper __eq__ implementation
        if isinstance(item, DimEntry) and isinstance(list_item, DimEntry):
            if list_item == item:
                return i
        elif list_item is item:
            return i
    return None


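# Illustrative sketch (not executed): positional DimEntry values compare by
# value, so _safe_index([DimEntry(-1)], DimEntry(-1)) returns 0 even though
# the two wrapper objects are distinct, while non-DimEntry items only match
# by identity.

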
@dataclass
class IndexingInfo:
    can_call_original: bool = False
    advanced_indexing: bool = False
    self_tensor: Optional[torch.Tensor] = None
    flat_inputs: list[Any] = field(default_factory=list)
    result_levels: list[DimEntry] = field(default_factory=list)
    has_device: bool = False


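# How callers consume IndexingInfo (an illustrative sketch mirroring getitem
# and invoke_getitem below): when can_call_original is set, the plain
# torch.Tensor.__getitem__ path handles the index; otherwise self_tensor is
# indexed with flat_inputs (if advanced_indexing) and the result is rewrapped
# with result_levels.

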
def has_dims(obj: Any) -> bool:
    """
    Check if an object has first-class dimensions.

    This function checks if the object is either a Dim or a functorch Tensor
    that has first-class dimensions, using the proper check_exact methods.
    """
    from . import Dim, Tensor

    return Dim.check_exact(obj) or Tensor.check_exact(obj)


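# Illustrative (assuming the usual first-class dim factory, e.g. i, j = dims(2)):
#   has_dims(i)                  -> True   (a Dim)
#   has_dims(torch.randn(3)[i])  -> True   (Tensor carrying first-class dims)
#   has_dims(torch.randn(3))     -> False  (plain tensor)

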
def _bind_dims_to_size(sz: int, sd: int, dims: list, nsz: list, nsd: list) -> None:
    """
    Bind the dims of a dim pack to a single tensor dimension of size ``sz``
    and stride ``sd``, appending the resulting sizes and strides to ``nsz``
    and ``nsd``.
    """
    from . import DimensionBindError

    rhs_prod = 1
    for i, dim in enumerate(dims):
        if not dim.is_bound:
            # Check for multiple unbound dimensions
            for j in range(i + 1, len(dims)):
                if not dims[j].is_bound:
                    raise DimensionBindError(
                        f"cannot infer the sizes of two dimensions at once {dim!r} and {dims[j]!r}"
                    )
                rhs_prod *= dims[j].size

            # Calculate the size for this unbound dimension
            if sz % rhs_prod != 0:
                tup = tuple(dim.size if dim.is_bound else "?" for dim in dims)
                raise DimensionBindError(
                    f"inferred dimension does not evenly fit into larger dimension: {sz} vs {tup}"
                )

            inferred_size = sz // rhs_prod
            dim.size = inferred_size
            rhs_prod = sz
            break
        else:
            rhs_prod *= dim.size

    # Final validation that dimensions match
    if rhs_prod != sz:
        tup = tuple(dims)
        raise DimensionBindError(
            f"Dimension sizes do not match ({sz} != {rhs_prod}) when matching dimension pack {tup}"
        )

    # Calculate new sizes and strides for each dimension in the pack.
    # First calculate all strides by iterating in reverse
    new_strides = [0] * len(dims)
    current_stride = sd
    for i in reversed(range(len(dims))):
        new_strides[i] = current_stride
        current_stride *= dims[i].size

    # Then append sizes and strides in forward order
    for i in range(len(dims)):
        nsz.append(dims[i].size)
        nsd.append(new_strides[i])


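# Worked example (not executed): binding the pack (i, j) with j.size == 3 to a
# dimension of size 6 and stride 1 infers i.size == 2 and appends
# nsz = [2, 3], nsd = [3, 1] -- the strides of viewing that dimension as a
# contiguous 2x3 block.

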
def slice_to_tuple(flat_inputs: list) -> tuple:
    return tuple(flat_inputs)


def extractIndices(index: Any, indices: list) -> bool:
    """Flatten ``index`` into ``indices``; return True if it was a sequence."""
    if isinstance(index, tuple):  # mpy::tuple_view::check
        indices.extend(index)
        return True
    elif isinstance(index, torch.Tensor):  # THPVariable_Check
        indices.append(index)
        return False
    elif not hasattr(index, "__iter__") or isinstance(
        index, (str, bytes)
    ):  # !mpy::is_sequence
        indices.append(index)
        return False

    # Handle sequence case (list)
    if isinstance(index, list):
        if len(index) >= 32:
            # Long lists are always treated as a sequence of indices
            indices.extend(index)
            return True

        # Check each item in the sequence
        for item in index:
            if (
                isinstance(item, (torch.Tensor, slice))
                or hasattr(item, "__iter__")
                or item is ...
                or item is None
                or has_dims(item)
            ):
                indices.extend(index)
                return True

        # If we got here, treat the whole list as a single index
        indices.append(index)
        return False

    # Default case: any other iterable is a single index
    indices.append(index)
    return False


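# Behavior sketch (illustrative):
#   extractIndices((a, b), out)    -> True,  out == [a, b]      (tuple unpacked)
#   extractIndices(t, out)         -> False, out == [t]         (single tensor)
#   extractIndices([0, 1, 2], out) -> False, out == [[0, 1, 2]] (a short list of
#       plain ints stays whole, as in NumPy-style fancy indexing)

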
def getitem(cls: Any, func: Any, types: Any, args: Any, kwargs: Any) -> Any:
    """__torch_function__-style handler for Tensor.__getitem__."""
    self = args[0]
    index = args[1]

    iinfo = getsetitem(self, index, has_dims(self))
    if iinfo.can_call_original:
        # Call original tensor __getitem__ directly, bypassing __torch_function__
        return torch.Tensor.__getitem__(self, index)

    return invoke_getitem(iinfo)


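# Illustrative use (assuming the functorch dims factory, e.g. i, j = dims(2)):
#   t = torch.randn(3, 4)
#   x = t[i, j]   # dispatches here; binds i to 3 and j to 4 and returns a
#                 # first-class-dim Tensor whose levels are (i, j)

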
def setitem(self: Any, index: Any, rhs: Any) -> None:
    """Set values in a tensor using first-class dimensions."""
    from . import DimensionBindError

    iinfo = getsetitem(self, index, has_dims(self) or has_dims(rhs))

    if iinfo.can_call_original:
        # Call original tensor __setitem__ directly, bypassing __torch_function__
        torch._C.TensorBase.__setitem__(self, index, rhs)
        return

    # Handle RHS tensor with dimensions
    rhs_info = TensorInfo.create(rhs, False, False)

    if rhs_info:
        # Check that rhs dimensions are compatible with result dimensions
        for l in rhs_info.levels:
            if not l.is_positional():
                # Find this dimension in result levels
                found = False
                for result_level in iinfo.result_levels:
                    if (
                        not result_level.is_positional()
                        and result_level.dim() is l.dim()
                    ):
                        found = True
                        break

                if not found:
                    # Create tuple representation of result levels for error message
                    result_dims: list[Union[int, Dim]] = []
                    for rl in iinfo.result_levels:
                        if rl.is_positional():
                            result_dims.append(rl.position())
                        else:
                            result_dims.append(rl.dim())

                    raise DimensionBindError(
                        f"rhs of setitem contains dimension {l.dim()!r} which is not in the dimensions on the left "
                        f"({tuple(result_dims)!r})"
                    )

        # Match RHS tensor to result levels
        assert rhs_info.tensor is not None, "Cannot match levels on None tensor"
        matched_rhs = _match_levels(
            rhs_info.tensor, rhs_info.levels, iinfo.result_levels
        )
    else:
        matched_rhs = rhs

    # For advanced indexing with dimensions, we need special handling
    if iinfo.advanced_indexing:
        # Use advanced indexing - the flat_inputs already contain matched tensors
        tup = slice_to_tuple(iinfo.flat_inputs)
        if iinfo.self_tensor is None:
            raise RuntimeError("Cannot setitem on None tensor")
        torch._C.TensorBase.__setitem__(iinfo.self_tensor, tup, matched_rhs)
    else:
        # Simple copy operation
        if iinfo.self_tensor is None:
            raise RuntimeError("Cannot copy to None tensor")
        iinfo.self_tensor.copy_(matched_rhs)


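# Failure-mode sketch (illustrative): with i, j = dims(2), assigning
# t[i] = x where x still carries a dim k absent from the indexed result
# raises DimensionBindError, because k has no dimension on the left to bind to.

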
def invoke_getitem(iinfo: IndexingInfo) -> Any:
    if iinfo.advanced_indexing:
        self_tensor = iinfo.self_tensor
        tup = slice_to_tuple(iinfo.flat_inputs)
        if self_tensor is None:
            raise RuntimeError("Cannot getitem on None tensor")
        rtensor = self_tensor[tup]
    else:
        rtensor = iinfo.self_tensor  # type: ignore[assignment]
        if rtensor is None:
            raise RuntimeError("Cannot getitem on None tensor")
        # rtensor is now guaranteed to be not None

    # Create a Tensor with the proper dimensions using the class method
    from . import Tensor

    return Tensor.from_positional(rtensor, iinfo.result_levels, iinfo.has_device)


def getsetitem(self: Any, index: Any, tensors_have_dims: bool) -> IndexingInfo:
    from . import DimList  # Import DimList for type checking

    can_call_original_getitem = not tensors_have_dims

    input_list = []
    if has_dims(index):
        input_list.append(index)
    else:
        is_sequence = extractIndices(index, input_list)
        # nothing about first-class dims here, fall back to plain getitem
        if can_call_original_getitem and not is_sequence:
            return IndexingInfo(can_call_original=True)

    # Calculate how many dimensions have been indexed in order to compute the
    # size of ... or expand a potentially unbound dimension list.
    dims_indexed = 0
    expanding_object = -1
    unbound_dim_list = None
    dimlists = []  # Track DimList positions for later processing

    def check_expanding(i: int) -> None:
        nonlocal expanding_object
        if expanding_object != -1:
            from . import DimensionBindError

            raise DimensionBindError(
                f"at most one ... or unbound dimension list can exist in an indexing list, "
                f"but found two at offsets {expanding_object} and {i}"
            )
        expanding_object = i

    def is_dimpack(s: Any) -> bool:
        from . import Dim

        return (
            isinstance(s, (tuple, list))
            and len(s) > 0
            and all(Dim.check_exact(item) for item in s)
        )

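    # A "dim pack" is a tuple or list whose entries are all first-class Dims,
    # e.g. (i, j); indexing one tensor dimension with a pack splits that
    # dimension across the packed dims (sizes and strides are computed by
    # _bind_dims_to_size).
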
    has_dimpacks_or_none = False
    for i, s in enumerate(input_list):
        if has_dims(s):
            can_call_original_getitem = False
            dims_indexed += 1
        elif s is ...:
            check_expanding(i)
        elif isinstance(s, DimList):
            can_call_original_getitem = False
            if not s.is_bound:
                check_expanding(i)
                unbound_dim_list = s
            else:
                dims_indexed += len(s._dims)
            dimlists.append(i)
        elif s is None:
            has_dimpacks_or_none = True
        elif is_dimpack(s):
            can_call_original_getitem = False
            has_dimpacks_or_none = True
            dims_indexed += 1
        else:
            dims_indexed += 1

    # Early return if we can use original getitem
    if can_call_original_getitem:
        return IndexingInfo(can_call_original=True)

    self_info = TensorInfo.create(self, False, True)
    total_dims = len(self_info.levels)  # Total dimensions (positional + named)
    if dims_indexed > total_dims:
        raise ValueError(
            f"at least {dims_indexed} indices were supplied but the tensor only has {total_dims} dimensions"
        )

    # Expand any unbound dimension list, or expand ... into individual : slices.
    expanding_dims = total_dims - dims_indexed
    if expanding_object != -1:
        if unbound_dim_list is not None:
            # Bind the unbound dimension list to the expanding dimensions
            unbound_dim_list.bind_len(expanding_dims)
        else:
            # Expand ... into slice(None) objects
            no_slices = [slice(None)] * expanding_dims
            input_list = (
                input_list[:expanding_object]
                + no_slices
                + input_list[expanding_object + 1 :]
            )

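    # Example (illustrative): for a 4-d tensor indexed as t[i, ..., j],
    # dims_indexed == 2, so `...` expands into two slice(None) entries,
    # yielding [i, slice(None), slice(None), j].
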
    # Flatten any dimensions stored in dimlist elements directly into the inputs.
    # Process in reverse order so earlier indices stay valid.
    for i in range(len(dimlists) - 1, -1, -1):
        idx = dimlists[i]

        # We added more elements to input_list when expanding ..., so adjust
        # the index to get back to where the dimlist originally was.
        if (
            unbound_dim_list is None
            and expanding_object != -1
            and idx > expanding_object
        ):
            idx += expanding_dims

        dl = input_list[idx]

        # NB: reaches into the private DimList._dims
        input_list = input_list[:idx] + dl._dims + input_list[idx + 1 :]

    return getsetitem_flat(self_info, input_list, [], [], has_dimpacks_or_none)


def getsetitem_flat(
    self_info: TensorInfo,
    input_list: list,
    keys: list[DimEntry],
    values: list,
    has_dimpacks_or_none: bool,
) -> IndexingInfo:
    from . import Dim

    # Track dimension usage
    seen_dims: list[Any] = []
    seen_dims_nuses: list[int] = []

    def add_dim(dim: Any) -> None:
        # Use safe indexing to avoid triggering __torch_function__ on Dim objects
        idx = _safe_index(seen_dims, dim)
        if idx is not None:
            seen_dims_nuses[idx] += 1
        else:
            seen_dims.append(dim)
            seen_dims_nuses.append(1)

    flat_inputs = []
    tensor_inputs: list[Any] = []
    device_holding_tensor = None

    def append_flat_handle(handle: Any) -> None:
        flat_inputs.append(handle)
        tensor_inputs.append(None)

    def append_tensor_input(ti: TensorInfo) -> None:
        nonlocal device_holding_tensor
        flat_inputs.append(None)
        tensor_inputs.append(ti)
        if ti.has_device and device_holding_tensor is None:
            device_holding_tensor = ti.tensor

    nsz = []
    nsd = []
    if self_info.tensor is None:
        raise RuntimeError("Cannot get size/stride on None tensor")
    sz = self_info.tensor.size()
    sd = self_info.tensor.stride()

    def append_size(i: int) -> None:
        if has_dimpacks_or_none:
            nsz.append(sz[i])
            nsd.append(sd[i])

    input_it = input_list[:]

    def parse_nones() -> None:
        nonlocal input_it
        while input_it and input_it[0] is None:
            append_flat_handle(slice(None))
            nsz.append(1)
            nsd.append(0)
            input_it = input_it[1:]

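    # None entries behave like unsqueeze: each one contributes a size-1,
    # stride-0 dimension to the restrided view and a ':' slice to flat_inputs.
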
    def append_item(i: int, arg: Any) -> None:
        if Dim.check_exact(arg):
            d = arg
            # An unbound Dim (size still -1) binds to the size of this dimension
            if d._size == -1:
                d.size = sz[i]
            add_dim(d)
            append_size(i)
            append_flat_handle(arg)
            return

        info = TensorInfo.create(arg, False, False)
        if info:
            append_size(i)
            append_tensor_input(info)
            for level in info.levels:
                if not level.is_positional():
                    add_dim(level.dim())
            return

        if has_dimpacks_or_none:
            if isinstance(arg, (tuple, list)) and all(Dim.check_exact(d) for d in arg):
                # dim pack: split this dimension across the packed dims
                dim_pack = list(arg)
                for d in dim_pack:
                    add_dim(d)
                    append_flat_handle(d)
                _bind_dims_to_size(sz[i], sd[i], dim_pack, nsz, nsd)
                return

        append_size(i)
        append_flat_handle(arg)

    # Match indexing expressions with tensor dimensions
    for i, level in enumerate(self_info.levels):
        # Use safe indexing to avoid triggering __torch_function__ on DimEntry comparisons
        idx = _safe_index(keys, level)
        if idx is not None:
            append_item(i, values[idx])
        else:
            if level.is_positional():
                parse_nones()
                if not input_it:
                    append_flat_handle(slice(None))
                    append_size(i)
                else:
                    arg = input_it[0]
                    input_it = input_it[1:]
                    append_item(i, arg)
            else:
                add_dim(level.dim())
                append_flat_handle(level.dim())
                append_size(i)

    parse_nones()

|
# Restride tensor if needed
|
|
if has_dimpacks_or_none and nsz:
|
|
if self_info.tensor is None:
|
|
raise RuntimeError("Cannot restride None tensor")
|
|
self_tensor = self_info.tensor.as_strided(
|
|
nsz, nsd, self_info.tensor.storage_offset()
|
|
)
|
|
else:
|
|
self_tensor = self_info.tensor
|
|
|
|
# Determine result shape and indexing requirements
|
|
result_levels: list[Any] = []
|
|
index_levels = []
|
|
tensor_insert_point = -1
|
|
requires_getindex = False
|
|
|
|
def mark_tensor_index() -> None:
|
|
nonlocal tensor_insert_point
|
|
if tensor_insert_point == -1:
|
|
tensor_insert_point = len(result_levels)
|
|
elif tensor_insert_point != len(result_levels):
|
|
tensor_insert_point = 0
|
|
|
|
for i, inp in enumerate(flat_inputs):
|
|
if tensor_inputs[i] is not None:
|
|
requires_getindex = True
|
|
mark_tensor_index()
|
|
for level in tensor_inputs[i].levels:
|
|
if level not in index_levels:
|
|
index_levels.append(level)
|
|
elif Dim.check_exact(inp):
|
|
d = inp
|
|
# Use safe indexing to avoid triggering __torch_function__
|
|
dim_idx = _safe_index(seen_dims, d)
|
|
assert dim_idx is not None, f"Dim {d} not found in seen_dims"
|
|
if seen_dims_nuses[dim_idx] == 1:
|
|
flat_inputs[i] = slice(None)
|
|
result_levels.append(DimEntry(d))
|
|
else:
|
|
requires_getindex = True
|
|
flat_inputs[i] = None
|
|
tensor_inputs[i] = TensorInfo(
|
|
d._get_range(), [DimEntry(d)], False, None
|
|
)
|
|
if DimEntry(d) not in index_levels:
|
|
index_levels.append(DimEntry(d))
|
|
mark_tensor_index()
|
|
else:
|
|
if inp != slice(None):
|
|
requires_getindex = True
|
|
if not isinstance(inp, int):
|
|
result_levels.append(DimEntry(-1))
|
|
|
|
# Insert indexing dimensions at first tensor use point
|
|
if tensor_insert_point != -1:
|
|
for level in reversed(index_levels):
|
|
result_levels.insert(tensor_insert_point, level)
|
|
|
|
# Match tensors to indexing shape
|
|
if requires_getindex:
|
|
for i in range(len(flat_inputs)):
|
|
if tensor_inputs[i] is not None:
|
|
t = tensor_inputs[i].tensor
|
|
assert t is not None, "TensorInfo should have valid tensor data"
|
|
if (
|
|
not tensor_inputs[i].has_device
|
|
and device_holding_tensor is not None
|
|
):
|
|
t = t.to(device_holding_tensor.device)
|
|
flat_inputs[i] = _match_levels(t, tensor_inputs[i].levels, index_levels)
|
|
|
|
# Number positional dimensions correctly
|
|
seen_positionals = 0
|
|
for i in reversed(range(len(result_levels))):
|
|
if result_levels[i].is_positional():
|
|
seen_positionals += 1
|
|
result_levels[i] = DimEntry(-seen_positionals)
|
|
|
|
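    # e.g. two positional entries among the result levels become DimEntry(-2)
    # and DimEntry(-1), numbered from the right to match negative-dim
    # conventions.
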
    return IndexingInfo(
        can_call_original=False,
        advanced_indexing=requires_getindex,
        self_tensor=self_tensor,
        flat_inputs=flat_inputs,
        result_levels=result_levels,
        has_device=self_info.has_device,
    )