@@ -332,7 +332,7 @@ const TILE_DIM = 16
# Legacy entry point: takes the two scalars `a` and `b` directly, wraps them in a
# `MulAddMul`, and forwards to the `MulAddMul`-based method.
function generic_matmatmul!(C::AbstractArray, A::AbstractArray, B::AbstractArray,
                            a::Number, b::Number)
    return generic_matmatmul!(C, A, B, MulAddMul(a, b))
end
335- function generic_matmatmul! (C:: AbstractArray {R} , A:: AbstractArray {T} , B:: AbstractArray {S} , add:: MulAddMul ) where {T,S,R}
335+ function generic_matmatmul! (C:: AbstractGPUMatrix {R} , A:: AbstractGPUMatrix {T} , B:: AbstractGPUMatrix {S} , add:: MulAddMul ) where {T,S,R}
336336 N = size (A,1 )
337337 Q = size (A,2 )
338338 M = size (B,2 )
@@ -347,7 +347,7 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
347347 end
348348
349349 @kernel unsafe_indices= true function coalesced_matmul_kernel! (
350- output, @Const ( input1), @Const ( input2) , N, Q, M,
350+ output, input1, input2, N, Q, M,
351351 :: Val{BANK} = Val (1 ),
352352 ) where {BANK}
353353 grow, gcol = @index (Group, NTuple)
@@ -363,7 +363,6 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
363363 outval = @private R 1
364364 @inbounds outval[1 ] = - zero (R)
365365
366- # @uniform N = size(output, 1)
367366 # number of tiles depends on inner dimension
368367 @uniform NUM_TILES = div (Q + TILE_DIM - 1 , TILE_DIM)
369368
@@ -406,6 +405,34 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
406405 coalesced_matmul_kernel! (get_backend (C), (TILE_DIM, TILE_DIM))(C, A, B, N, Q, M;ndrange= map (x -> ceil (Int,x/ TILE_DIM)* TILE_DIM, size (C)))
407406 C
408407end
# Fallback for arbitrary array types: a straightforward kernel in which every work-item
# computes one entry of the product and combines it with the existing entry of `C`
# through `add`, i.e. `C[i,j] = add(sum_k A[i,k]*B[k,j], C[i,j])`.
#
# Returns `C`.
#
# Throws `DimensionMismatch` when `size(A,2) != size(B,1)` or when `C` does not have
# size `(size(A,1), size(B,2))`.
function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, add::MulAddMul) where {T,S,R}
    if size(A,2) != size(B,1)
        throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
    end
    if size(C,1) != size(A,1) || size(C,2) != size(B,2)
        throw(DimensionMismatch("result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))"))
    end
    if isempty(A) || isempty(B)
        # An empty product contributes nothing, so only the `C * beta` term of `add`
        # remains. Zero-fill when beta is zero (this also clears NaN/Inf entries);
        # otherwise scale the existing contents instead of discarding them.
        if iszero(add.beta)
            return fill!(C, zero(R))
        end
        C .= C .* add.beta
        return C
    end

    @kernel function matmatmul_kernel!(C, A, B)
        # NOTE(review): `assume` is defined outside this view; presumably a compiler
        # hint that every dimension of `C` is positive -- confirm.
        assume.(size(C) .> 0)
        idx = @index(Global, Linear)
        # Convert the linear index to Cartesian form. The trailing `1` supplies `j`
        # when `C` has a single dimension; for a matrix it is simply dropped by the
        # destructuring.
        i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1

        @inbounds if i <= size(A,1) && j <= size(B,2)
            # Build a zero of the accumulation type from a representative product/sum.
            z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j])
            Cij = convert(promote_type(R, typeof(z2)), z2)
            for k in 1:size(A,2)
                Cij += A[i, k]*B[k, j]
            end
            C[i,j] = add(Cij, C[i,j])
        end
    end
    matmatmul_kernel!(get_backend(C))(C, A, B; ndrange = size(C))
    return C
end
409436
410437@static if VERSION < v " 1.12.0-"
411438function LinearAlgebra. generic_matvecmul! (C:: AbstractGPUVector , tA:: AbstractChar , A:: AbstractGPUMatrix , B:: AbstractGPUVector , _add:: MulAddMul = MulAddMul ())
0 commit comments