@@ -410,45 +410,16 @@ def trace(self, tensor: Tensor, offset: int = 0, axis1: int = -2,
410410 axis1 and axis2 are used to determine the 2-D sub-array whose diagonal is
411411 summed.
412412
413- In the PyTorch backend the trace is always over the main diagonal of the
414- last two entries.
415-
416413 Args:
417414 tensor: A tensor.
418415 offset: Offset of the diagonal from the main diagonal.
419- This argument is not supported by the PyTorch
420- backend and an error will be raised if they are
421- specified.
422416 axis1, axis2: Axis to be used as the first/second axis of the 2D
423417 sub-arrays from which the diagonals should be taken.
424- Defaults to first/second axis.
425- These arguments are not supported by the PyTorch
426- backend and an error will be raised if they are
427- specified.
418+ Defaults to second-last/last axis.
428419 Returns:
429420 array_of_diagonals: The batched summed diagonals.
430421 """
431- if offset != 0 :
432- errstr = (f"offset = { offset } must be 0 (the default)"
433- f"with PyTorch backend." )
434- raise NotImplementedError (errstr )
435- if axis1 == axis2 :
436- raise ValueError (f"axis1 = { axis1 } cannot equal axis2 = { axis2 } " )
437- N = len (tensor .shape )
438- if N > 25 :
439- raise ValueError (f"Currently only tensors with ndim <= 25 can be traced"
440- f"in the PyTorch backend (yours was { N } )" )
441-
442- if axis1 < 0 :
443- axis1 = N + axis1
444- if axis2 < 0 :
445- axis2 = N + axis2
446-
447- inds = list (map (chr , range (98 , 98 + N )))
448- indsout = [i for n , i in enumerate (inds ) if n not in (axis1 , axis2 )]
449- inds [axis1 ] = 'a'
450- inds [axis2 ] = 'a'
451- return torchlib .einsum ('' .join (inds ) + '->' + '' .join (indsout ), tensor )
422+ return torchlib .sum (torchlib .diagonal (tensor , offset = offset , dim1 = axis1 , dim2 = axis2 ), dim = - 1 )
452423
453424 def abs (self , tensor : Tensor ) -> Tensor :
454425 """
0 commit comments