@@ -695,12 +695,18 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
695695 axis = 0
696696 return torch .index_select (x , axis , indices , ** kwargs )
697697
698-
699-
700698# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
701699# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
700+
701+ # torch.cross also does not support broadcasting when it would add new
702+ # dimensions https://github.com/pytorch/pytorch/issues/39656
702703def cross (x1 : array , x2 : array , / , * , axis : int = - 1 ) -> array :
703704 x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
705+ if not (- builtin_min (x1 .ndim , x2 .ndim ) <= axis < builtin_max (x1 .ndim , x2 .ndim )):
706+ raise ValueError (f"axis { axis } out of bounds for cross product of arrays with shapes { x1 .shape } and { x2 .shape } " )
707+ if not (x1 .shape [axis ] == x2 .shape [axis ] == 3 ):
708+ raise ValueError (f"cross product axis must have size 3, got { x1 .shape [axis ]} and { x2 .shape [axis ]} " )
709+ x1 , x2 = torch .broadcast_tensors (x1 , x2 )
704710 return torch .linalg .cross (x1 , x2 , dim = axis )
705711
706712def vecdot_linalg (x1 : array , x2 : array , / , * , axis : int = - 1 , ** kwargs ) -> array :
0 commit comments