@@ -166,6 +166,45 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
166166 return gemm_batch_event;
167167}
168168
169+ void standardize_strides_to_nonzero (std::vector<py::ssize_t > &strides,
170+ const py::ssize_t *shape)
171+ {
172+ // When shape of an array along any particular dimension is 1, the stride
173+ // along that dimension is undefined. This function standardize the strides
174+ // by calculating the non-zero value of the strides.
175+ std::size_t ndim = strides.size ();
176+ bool has_zero_stride = std::accumulate (strides.begin (), strides.end (), 1 ,
177+ std::multiplies<py::ssize_t >{}) == 0 ;
178+
179+ if (has_zero_stride) {
180+ for (std::size_t i = 0 ; i < ndim - 1 ; ++i) {
181+ strides[i] = strides[i] == 0
182+ ? std::accumulate (shape + i + 1 , shape + ndim, 1 ,
183+ std::multiplies<py::ssize_t >{})
184+ : strides[i];
185+ }
186+ strides[ndim - 1 ] = strides[ndim - 1 ] == 0 ? 1 : strides[ndim - 1 ];
187+ }
188+ }
189+
190+ void standardize_strides_to_zero (std::vector<py::ssize_t > &strides,
191+ const py::ssize_t *shape)
192+ {
193+ // When shape of an array along any particular dimension is 1, the stride
194+ // along that dimension is undefined. This function standardize the strides
195+ // by defining such a stride as zero. This is because for these cases,
196+ // instead of copying the array into the additional dimension for batch
197+ // multiplication, we choose to use zero as the stride between different
198+ // matrices. Therefore, the same array is used repeatedly.
199+ std::size_t ndim = strides.size ();
200+
201+ for (size_t i = 0 ; i < ndim; ++i) {
202+ if (shape[i] <= 1 ) {
203+ strides[i] = 0 ;
204+ }
205+ }
206+ }
207+
169208std::tuple<sycl::event, sycl::event, bool >
170209 gemm_batch (sycl::queue &exec_q,
171210 dpctl::tensor::usm_ndarray matrixA,
@@ -240,10 +279,15 @@ std::tuple<sycl::event, sycl::event, bool>
240279 std::vector<py::ssize_t > a_stride = matrixA.get_strides_vector ();
241280 std::vector<py::ssize_t > b_stride = matrixB.get_strides_vector ();
242281 std::vector<py::ssize_t > c_stride = resultC.get_strides_vector ();
282+ standardize_strides_to_zero (a_stride, a_shape);
283+ standardize_strides_to_zero (b_stride, b_shape);
284+ standardize_strides_to_zero (c_stride, c_shape);
243285 const std::int64_t stridea = a_stride[0 ];
244286 const std::int64_t strideb = b_stride[0 ];
245287 const std::int64_t stridec = c_stride[0 ];
246288
289+ standardize_strides_to_nonzero (a_stride, a_shape);
290+ standardize_strides_to_nonzero (b_stride, b_shape);
247291 bool A_base_is_f_contig = a_stride[1 ] == 1 && a_stride[2 ] == a_shape[1 ];
248292 bool B_base_is_f_contig = b_stride[1 ] == 1 && b_stride[2 ] == b_shape[1 ];
249293
0 commit comments