2525# *****************************************************************************
2626
2727import dpctl .tensor as dpt
28+ from dpctl .tensor ._numpy_helper import AxisError
2829
2930import dpnp
3031
@@ -205,6 +206,7 @@ def __bool__(self):
205206 return self ._array_obj .__bool__ ()
206207
207208 # '__class__',
209+ # `__class_getitem__`,
208210
209211 def __complex__ (self ):
210212 return self ._array_obj .__complex__ ()
@@ -335,6 +337,8 @@ def __getitem__(self, key):
335337 res ._array_obj = item
336338 return res
337339
340+ # '__getstate__',
341+
338342 def __gt__ (self , other ):
339343 """Return ``self>value``."""
340344 return dpnp .greater (self , other )
@@ -361,7 +365,31 @@ def __ilshift__(self, other):
361365 dpnp .left_shift (self , other , out = self )
362366 return self
363367
364- # '__imatmul__',
368+ def __imatmul__ (self , other ):
369+ """Return ``self@=value``."""
370+
371+ """
372+ Unlike `matmul(a, b, out=a)` we ensure that the result is not broadcast
373+ if the result without `out` would have less dimensions than `a`.
374+ Since the signature of matmul is '(n?,k),(k,m?)->(n?,m?)' this is the
375+ case exactly when the second operand has both core dimensions.
376+ We have to enforce this check by passing the correct `axes=`.
377+ """
378+ if self .ndim == 1 :
379+ axes = [(- 1 ,), (- 2 , - 1 ), (- 1 ,)]
380+ else :
381+ axes = [(- 2 , - 1 ), (- 2 , - 1 ), (- 2 , - 1 )]
382+
383+ try :
384+ dpnp .matmul (self , other , out = self , axes = axes )
385+ except AxisError :
386+ # AxisError should indicate that the axes argument didn't work out
387+ # which should mean the second operand not being 2 dimensional.
388+ raise ValueError (
389+ "inplace matrix multiplication requires the first operand to "
390+ "have at least one and the second at least two dimensions."
391+ )
392+ return self
365393
366394 def __imod__ (self , other ):
367395 """Return ``self%=value``."""
@@ -469,9 +497,11 @@ def __pow__(self, other):
469497 return dpnp .power (self , other )
470498
471499 def __radd__ (self , other ):
500+ """Return ``value+self``."""
472501 return dpnp .add (other , self )
473502
474503 def __rand__ (self , other ):
504+ """Return ``value&self``."""
475505 return dpnp .bitwise_and (other , self )
476506
477507 # '__rdivmod__',
@@ -483,40 +513,51 @@ def __repr__(self):
483513 return dpt .usm_ndarray_repr (self ._array_obj , prefix = "array" )
484514
485515 def __rfloordiv__ (self , other ):
516+ """Return ``value//self``."""
486517 return dpnp .floor_divide (self , other )
487518
488519 def __rlshift__ (self , other ):
520+ """Return ``value<<self``."""
489521 return dpnp .left_shift (other , self )
490522
491523 def __rmatmul__ (self , other ):
524+ """Return ``value@self``."""
492525 return dpnp .matmul (other , self )
493526
494527 def __rmod__ (self , other ):
528+ """Return ``value%self``."""
495529 return dpnp .remainder (other , self )
496530
497531 def __rmul__ (self , other ):
532+ """Return ``value*self``."""
498533 return dpnp .multiply (other , self )
499534
500535 def __ror__ (self , other ):
536+ """Return ``value|self``."""
501537 return dpnp .bitwise_or (other , self )
502538
503539 def __rpow__ (self , other ):
540+ """Return ``value**self``."""
504541 return dpnp .power (other , self )
505542
506543 def __rrshift__ (self , other ):
544+ """Return ``value>>self``."""
507545 return dpnp .right_shift (other , self )
508546
509547 def __rshift__ (self , other ):
510548 """Return ``self>>value``."""
511549 return dpnp .right_shift (self , other )
512550
513551 def __rsub__ (self , other ):
552+ """Return ``value-self``."""
514553 return dpnp .subtract (other , self )
515554
516555 def __rtruediv__ (self , other ):
556+ """Return ``value/self``."""
517557 return dpnp .true_divide (other , self )
518558
519559 def __rxor__ (self , other ):
560+ """Return ``value^self``."""
520561 return dpnp .bitwise_xor (other , self )
521562
522563 # '__setattr__',
0 commit comments