Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.

Commit aae89ae

Browse files
committed
Fix trace for PyTorch backend
1 parent 95861b2 commit aae89ae

File tree

1 file changed

+2
-31
lines changed

1 file changed

+2
-31
lines changed

tensornetwork/backends/pytorch/pytorch_backend.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -410,45 +410,16 @@ def trace(self, tensor: Tensor, offset: int = 0, axis1: int = -2,
410410
axis1 and axis2 are used to determine the 2-D sub-array whose diagonal is
411411
summed.
412412
413-
In the PyTorch backend the trace is always over the main diagonal of the
414-
last two entries.
415-
416413
Args:
417414
tensor: A tensor.
418415
offset: Offset of the diagonal from the main diagonal.
419-
This argument is not supported by the PyTorch
420-
backend and an error will be raised if they are
421-
specified.
422416
axis1, axis2: Axis to be used as the first/second axis of the 2D
423417
sub-arrays from which the diagonals should be taken.
424-
Defaults to first/second axis.
425-
These arguments are not supported by the PyTorch
426-
backend and an error will be raised if they are
427-
specified.
418+
Defaults to second-last/last axis.
428419
Returns:
429420
array_of_diagonals: The batched summed diagonals.
430421
"""
431-
if offset != 0:
432-
errstr = (f"offset = {offset} must be 0 (the default)"
433-
f"with PyTorch backend.")
434-
raise NotImplementedError(errstr)
435-
if axis1 == axis2:
436-
raise ValueError(f"axis1 = {axis1} cannot equal axis2 = {axis2}")
437-
N = len(tensor.shape)
438-
if N > 25:
439-
raise ValueError(f"Currently only tensors with ndim <= 25 can be traced"
440-
f"in the PyTorch backend (yours was {N})")
441-
442-
if axis1 < 0:
443-
axis1 = N+axis1
444-
if axis2 < 0:
445-
axis2 = N+axis2
446-
447-
inds = list(map(chr, range(98, 98+N)))
448-
indsout = [i for n, i in enumerate(inds) if n not in (axis1, axis2)]
449-
inds[axis1] = 'a'
450-
inds[axis2] = 'a'
451-
return torchlib.einsum(''.join(inds) + '->' +''.join(indsout), tensor)
422+
return torchlib.sum(torchlib.diagonal(tensor, offset=offset, dim1=axis1, dim2=axis2), dim=-1)
452423

453424
def abs(self, tensor: Tensor) -> Tensor:
454425
"""

0 commit comments

Comments
 (0)