@@ -198,13 +198,14 @@ def _define_dim_flags(x, axis):
198198 """
199199 Define useful flags for the calculations in dpnp_matmul and dpnp_vecdot.
200200 x_is_1D: `x` is 1D array or inherently 1D (all dimensions are equal to one
201- except for one of them ), for instance, if x.shape = (1, 1, 1, 2),
202- then x_is_1D = True
201+ except for dimension at `axis` ), for instance, if x.shape = (1, 1, 1, 2),
202+ and axis=-1, then x_is_1D = True.
203203 x_is_2D: `x` is 2D array or inherently 2D (all dimensions are equal to one
204204 except for the last two of them), for instance, if x.shape = (1, 1, 3, 2),
205- then x_is_2D = True
205+ then x_is_2D = True.
206206 x_base_is_1D: `x` is 1D considering only its last two dimensions, for instance,
207- if x.shape = (3, 4, 1, 2), then x_base_is_1D = True
207+ if x.shape = (3, 4, 1, 2), then x_base_is_1D = True.
208+
208209 """
209210
210211 x_shape = x .shape
@@ -326,14 +327,11 @@ def _get_result_shape_vecdot(x1, x2, x1_ndim, x2_ndim):
326327 if x1_shape [- 1 ] != x2_shape [- 1 ]:
327328 _shape_error (x1_shape [- 1 ], x2_shape [- 1 ], "vecdot" , err_msg = 0 )
328329
329- _ , x1_is_1D , _ = _define_dim_flags (x1 , axis = - 1 )
330- _ , x2_is_1D , _ = _define_dim_flags (x2 , axis = - 1 )
331-
332330 if x1_ndim == 1 and x2_ndim == 1 :
333331 result_shape = ()
334- elif x1_is_1D :
332+ elif x1_ndim == 1 :
335333 result_shape = x2_shape [:- 1 ]
336- elif x2_is_1D :
334+ elif x2_ndim == 1 :
337335 result_shape = x1_shape [:- 1 ]
338336 else : # at least 2D
339337 if x1_ndim != x2_ndim :
0 commit comments