You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
3.8 KiB

import bisect
from collections import defaultdict
from sympy.combinatorics import Permutation
from sympy.core.containers import Tuple
from sympy.core.numbers import Integer
def _get_mapping_from_subranks(subranks):
mapping = {}
counter = 0
for i, rank in enumerate(subranks):
for j in range(rank):
mapping[counter] = (i, j)
counter += 1
return mapping
def _get_contraction_links(args, subranks, *contraction_indices):
    """
    Build a lookup of which argument positions are contracted with each other.

    Every group in ``contraction_indices`` lists flat index positions; these
    are translated to ``(argument, local position)`` pairs using
    ``subranks``.  Only groups of exactly two positions produce links; any
    other group size is skipped.

    Returns ``args`` unchanged together with a dict ``links`` such that
    ``links[arg][pos]`` is the ``(other_arg, other_pos)`` pair that
    ``(arg, pos)`` is contracted with.
    """
    position_of = _get_mapping_from_subranks(subranks)
    groups = [[position_of[flat] for flat in group] for group in contraction_indices]
    dlinks = defaultdict(dict)
    for endpoints in groups:
        if len(endpoints) != 2:
            # Only pairwise contractions are recorded as links.
            continue
        (arg_a, pos_a), (arg_b, pos_b) = endpoints
        # Record the link in both directions.
        dlinks[arg_a][pos_a] = (arg_b, pos_b)
        dlinks[arg_b][pos_b] = (arg_a, pos_a)
    return args, dict(dlinks)
def _sort_contraction_indices(pairing_indices):
    """
    Return the contraction groups in a canonical order.

    Each group is sorted internally and wrapped in a ``Tuple``; the groups
    themselves are then ordered by their smallest element.
    """
    normalized = [Tuple(*sorted(group)) for group in pairing_indices]
    return sorted(normalized, key=min)
def _get_diagonal_indices(flattened_indices):
axes_contraction = defaultdict(list)
for i, ind in enumerate(flattened_indices):
if isinstance(ind, (int, Integer)):
# If the indices is a number, there can be no diagonal operation:
continue
axes_contraction[ind].append(i)
axes_contraction = {k: v for k, v in axes_contraction.items() if len(v) > 1}
# Put the diagonalized indices at the end:
ret_indices = [i for i in flattened_indices if i not in axes_contraction]
diag_indices = list(axes_contraction)
diag_indices.sort(key=lambda x: flattened_indices.index(x))
diagonal_indices = [tuple(axes_contraction[i]) for i in diag_indices]
ret_indices += diag_indices
ret_indices = tuple(ret_indices)
return diagonal_indices, ret_indices
def _get_argindex(subindices, ind):
for i, sind in enumerate(subindices):
if ind == sind:
return i
if isinstance(sind, (set, frozenset)) and ind in sind:
return i
raise IndexError("%s not found in %s" % (ind, subindices))
def _apply_recursively_over_nested_lists(func, arr):
    """
    Apply ``func`` to every leaf of an arbitrarily nested sequence.

    Any ``tuple``, ``list`` or ``Tuple`` encountered is traversed and rebuilt
    as a builtin ``tuple`` of the transformed items; anything else is treated
    as a leaf and replaced by ``func(leaf)``.

    ``Tuple`` containers are rebuilt as builtin tuples as well: they are
    matched by the same container check as ``tuple`` and ``list``, so no
    separate ``Tuple`` branch is needed.
    """
    if not isinstance(arr, (tuple, list, Tuple)):
        return func(arr)
    return tuple(_apply_recursively_over_nested_lists(func, item) for item in arr)
def _build_push_indices_up_func_transformation(flattened_contraction_indices):
shifts = {0: 0}
i = 0
cumulative = 0
while i < len(flattened_contraction_indices):
j = 1
while i+j < len(flattened_contraction_indices):
if flattened_contraction_indices[i] + j != flattened_contraction_indices[i+j]:
break
j += 1
cumulative += j
shifts[flattened_contraction_indices[i]] = cumulative
i += j
shift_keys = sorted(shifts.keys())
def func(idx):
return shifts[shift_keys[bisect.bisect_right(shift_keys, idx)-1]]
def transform(j):
if j in flattened_contraction_indices:
return None
else:
return j - func(j)
return transform
def _build_push_indices_down_func_transformation(flattened_contraction_indices):
N = flattened_contraction_indices[-1]+2
shifts = [i for i in range(N) if i not in flattened_contraction_indices]
def transform(j):
if j < len(shifts):
return shifts[j]
else:
return j + shifts[-1] - len(shifts) + 1
return transform
def _apply_permutation_to_list(perm: Permutation, target_list: list):
"""
Permute a list according to the given permutation.
"""
new_list = [None for i in range(perm.size)]
for i, e in enumerate(target_list):
new_list[perm(i)] = e
return new_list