@@ -325,11 +325,92 @@ function LinearAlgebra.ldiv!(B::AbstractGPUVecOrMat,
325325 B
326326end
327327
328+ # XXX : figure out how to do dynamically
329+ MAX_TILE_DIM = 16
328330
329331# # matrix multiplication
330332# legacy method
331333generic_matmatmul! (C:: AbstractArray , A:: AbstractArray , B:: AbstractArray , a:: Number , b:: Number ) =
332334 generic_matmatmul! (C, A, B, MulAddMul (a, b))
335+ function generic_matmatmul! (C:: AbstractGPUMatrix{R} , A:: AbstractGPUMatrix{T} , B:: AbstractGPUMatrix{S} , add:: MulAddMul ) where {T<: Number ,S<: Number ,R<: Number }
336+ N = size (A,1 )
337+ Q = size (A,2 )
338+ M = size (B,2 )
339+ if Q != size (B,1 )
340+ throw (DimensionMismatch (" matrix A has dimensions $(size (A)) , matrix B has dimensions $(size (B)) " ))
341+ end
342+ if size (C,1 ) != N || size (C,2 ) != M
343+ throw (DimensionMismatch (" result C has dimensions $(size (C)) , needs $((N,M)) " ))
344+ end
345+ if isempty (A) || isempty (B)
346+ return fill! (C, zero (R))
347+ end
348+
349+ @kernel unsafe_indices= true function coalesced_matmul_kernel! (
350+ output, @Const (input1), @Const (input2), N, Q, M,
351+ :: Val{BANK} = Val (1 ),
352+ ) where {BANK}
353+ grow, gcol = @index (Group, NTuple)
354+ tile_row, tile_col = @index (Local, NTuple)
355+
356+ TILE_DIM = @uniform @groupsize ()[1 ]
357+
358+ # +1 to avoid bank conflicts on shared memory
359+ tile1 = @localmem (R, (TILE_DIM + BANK, TILE_DIM))
360+ tile2 = @localmem (R, (TILE_DIM + BANK, TILE_DIM))
361+
362+ # private variable for tile output
363+ outval = @private R 1
364+ @inbounds outval[1 ] = - zero (R)
365+
366+ # number of tiles depends on inner dimension
367+ @uniform NUM_TILES = div (Q + TILE_DIM - 1 , TILE_DIM)
368+
369+ # loop over all tiles needed for this calculation
370+ for t in 0 : (NUM_TILES - 1 )
371+ I = (grow - 1 ) * TILE_DIM + tile_row
372+ J = (gcol - 1 ) * TILE_DIM + tile_col
373+
374+ # load inputs into tiles, with bounds checking for non-square matrices
375+ if I <= N && t * TILE_DIM + tile_col <= Q
376+ @inbounds tile1[tile_row, tile_col] = input1[I, t * TILE_DIM + tile_col]
377+ else
378+ @inbounds tile1[tile_row, tile_col] = zero (R)
379+ end
380+ if J <= M && t * TILE_DIM + tile_row <= Q
381+ @inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, J]
382+ else
383+ @inbounds tile2[tile_row, tile_col] = zero (R)
384+ end
385+
386+ # wait for all tiles to be loaded
387+ @synchronize
388+
389+ I = (grow - 1 ) * TILE_DIM + tile_row
390+ J = (gcol - 1 ) * TILE_DIM + tile_col
391+
392+ # calculate value of spot in output, use temporary value to allow for vectorization
393+ out = zero (R)
394+ @simd for k in 1 : TILE_DIM
395+ @inbounds out += tile1[tile_row, k] * tile2[k, tile_col]
396+ end
397+ outval[1 ] += out
398+
399+ @synchronize
400+ end
401+
402+ I = (grow - 1 ) * TILE_DIM + tile_row
403+ J = (gcol - 1 ) * TILE_DIM + tile_col
404+
405+ # save if inbounds
406+ if I <= N && J <= M
407+ @inbounds output[I, J] = add (outval[1 ], output[I, J])
408+ end
409+ end
410+
411+ coalesced_matmul_kernel! (get_backend (C), (MAX_TILE_DIM, MAX_TILE_DIM))(C, A, B, N, Q, M;ndrange= map (x -> ceil (Int,x/ MAX_TILE_DIM)* MAX_TILE_DIM, size (C)))
412+ C
413+ end
333414function generic_matmatmul! (C:: AbstractArray{R} , A:: AbstractArray{T} , B:: AbstractArray{S} , add:: MulAddMul ) where {T,S,R}
334415 if size (A,2 ) != size (B,1 )
335416 throw (DimensionMismatch (" matrix A has dimensions $(size (A)) , matrix B has dimensions $(size (B)) " ))
@@ -744,7 +825,7 @@ function LinearAlgebra.kron!(z::AbstractGPUVector{T1}, x::AbstractGPUVector{T2},
744825
745826 @kernel function kron_kernel! (z, @Const (x), @Const (y))
746827 i, j = @index (Global, NTuple)
747-
828+
748829 @inbounds z[(i - 1 ) * length (y) + j] = x[i] * y[j]
749830 end
750831
@@ -777,13 +858,13 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
777858
778859 ta = $ transa (T1)
779860 tb = $ transb (T2)
780-
861+
781862 @kernel function kron_kernel! (C, @Const (A), @Const (B))
782863 ai, aj = @index (Global, NTuple) # Indices in the result matrix
783-
864+
784865 # lb1, lb2 = size(B) # Dimensions of B
785866 lb1, lb2 = tb == ' N' ? size (B) : reverse (size (B))
786-
867+
787868 # Map global indices (ai, aj) to submatrices of the Kronecker product
788869 i_a = (ai - 1 ) ÷ lb1 + 1 # Corresponding row index in A
789870 i_b = (ai - 1 ) % lb1 + 1 # Corresponding row index in B
@@ -797,12 +878,12 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
797878 C[ai, aj] = a_ij * b_ij
798879 end
799880 end
800-
881+
801882 backend = KernelAbstractions. get_backend (C)
802883 kernel = kron_kernel! (backend)
803-
884+
804885 kernel (C, $ (unwrapa (:A )), $ (unwrapb (:B )), ndrange= (size (C, 1 ), size (C, 2 )))
805-
886+
806887 return C
807888 end
808889
0 commit comments