@@ -325,37 +325,85 @@ function LinearAlgebra.ldiv!(B::AbstractGPUVecOrMat,
325325 B
326326end
327327
328+ # XXX : figure out how to do dynamically
329+ const TILE_DIM = 16
328330
329331# # matrix multiplication
330332# legacy method
331333generic_matmatmul! (C:: AbstractArray , A:: AbstractArray , B:: AbstractArray , a:: Number , b:: Number ) =
332334 generic_matmatmul! (C, A, B, MulAddMul (a, b))
333335function generic_matmatmul! (C:: AbstractArray{R} , A:: AbstractArray{T} , B:: AbstractArray{S} , add:: MulAddMul ) where {T,S,R}
334- if size (A,2 ) != size (B,1 )
336+ N = size (A,1 )
337+ Q = size (A,2 )
338+ M = size (B,2 )
339+ if Q != size (B,1 )
335340 throw (DimensionMismatch (" matrix A has dimensions $(size (A)) , matrix B has dimensions $(size (B)) " ))
336341 end
337- if size (C,1 ) != size (A, 1 ) || size (C,2 ) != size (B, 2 )
338- throw (DimensionMismatch (" result C has dimensions $(size (C)) , needs $((size (A, 1 ), size (B, 2 ) )) " ))
342+ if size (C,1 ) != N || size (C,2 ) != M
343+ throw (DimensionMismatch (" result C has dimensions $(size (C)) , needs $((N,M )) " ))
339344 end
340345 if isempty (A) || isempty (B)
341346 return fill! (C, zero (R))
342347 end
343348
344- @kernel function matmatmul_kernel! (C, A, B)
345- assume .(size (C) .> 0 )
346- idx = @index (Global, Linear)
347- i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
349+ @kernel unsafe_indices= true function coalesced_matmul_kernel! (
350+ output, @Const (input1), @Const (input2), N, Q, M,
351+ :: Val{BANK} = Val (1 ),
352+ ) where {BANK}
353+ grow, gcol = @index (Group, NTuple)
354+ tile_row, tile_col = @index (Local, NTuple)
348355
349- @inbounds if i <= size (A,1 ) && j <= size (B,2 )
350- z2 = zero (A[i, 1 ]* B[1 , j] + A[i, 1 ]* B[1 , j])
351- Cij = convert (promote_type (R, typeof (z2)), z2)
352- for k in 1 : size (A,2 )
353- Cij += A[i, k]* B[k, j]
356+ # TILE_DIM = @uniform @groupsize()[1]
357+
358+ # +1 to avoid bank conflicts on shared memory
359+ tile1 = @localmem (R, (TILE_DIM + BANK, TILE_DIM))
360+ tile2 = @localmem (R, (TILE_DIM + BANK, TILE_DIM))
361+
362+ # private variable for tile output
363+ outval = @private R 1
364+ @inbounds outval[1 ] = - zero (R)
365+
366+ # @uniform N = size(output, 1)
367+ # number of tiles depends on inner dimension
368+ @uniform NUM_TILES = div (Q + TILE_DIM - 1 , TILE_DIM)
369+
370+ I = (grow - 1 ) * TILE_DIM + tile_row
371+ J = (gcol - 1 ) * TILE_DIM + tile_col
372+
373+ # loop over all tiles needed for this calculation
374+ for t in 0 : (NUM_TILES - 1 )
375+ # load inputs into tiles, with bounds checking for non-square matrices
376+ if I <= N && t * TILE_DIM + tile_col <= Q
377+ @inbounds tile1[tile_row, tile_col] = input1[I, t * TILE_DIM + tile_col]
378+ else
379+ @inbounds tile1[tile_row, tile_col] = zero (R)
380+ end
381+ if J <= M && t * TILE_DIM + tile_row <= Q
382+ @inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, J]
383+ else
384+ @inbounds tile2[tile_row, tile_col] = zero (R)
354385 end
355- C[i,j] = add (Cij, C[i,j])
386+
387+ # wait for all tiles to be loaded
388+ @synchronize
389+
390+ # calculate value of spot in output, use temporary value to allow for vectorization
391+ out = zero (R)
392+ @simd for k in 1 : TILE_DIM
393+ @inbounds out += tile1[tile_row, k] * tile2[k, tile_col]
394+ end
395+ outval[1 ] += out
396+
397+ @synchronize
398+ end
399+
400+ # save if inbounds
401+ if I <= N && J <= M
402+ @inbounds output[I, J] = add (outval[1 ], output[I, J])
356403 end
357404 end
358- matmatmul_kernel! (get_backend (C))(C, A, B; ndrange = size (C))
405+
406+ coalesced_matmul_kernel! (get_backend (C), (TILE_DIM, TILE_DIM))(C, A, B, N, Q, M;ndrange= map (x -> ceil (Int,x/ TILE_DIM)* TILE_DIM, size (C)))
359407 C
360408end
361409
@@ -744,7 +792,7 @@ function LinearAlgebra.kron!(z::AbstractGPUVector{T1}, x::AbstractGPUVector{T2},
744792
745793 @kernel function kron_kernel! (z, @Const (x), @Const (y))
746794 i, j = @index (Global, NTuple)
747-
795+
748796 @inbounds z[(i - 1 ) * length (y) + j] = x[i] * y[j]
749797 end
750798
@@ -777,13 +825,13 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
777825
778826 ta = $ transa (T1)
779827 tb = $ transb (T2)
780-
828+
781829 @kernel function kron_kernel! (C, @Const (A), @Const (B))
782830 ai, aj = @index (Global, NTuple) # Indices in the result matrix
783-
831+
784832 # lb1, lb2 = size(B) # Dimensions of B
785833 lb1, lb2 = tb == ' N' ? size (B) : reverse (size (B))
786-
834+
787835 # Map global indices (ai, aj) to submatrices of the Kronecker product
788836 i_a = (ai - 1 ) ÷ lb1 + 1 # Corresponding row index in A
789837 i_b = (ai - 1 ) % lb1 + 1 # Corresponding row index in B
@@ -797,12 +845,12 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
797845 C[ai, aj] = a_ij * b_ij
798846 end
799847 end
800-
848+
801849 backend = KernelAbstractions. get_backend (C)
802850 kernel = kron_kernel! (backend)
803-
851+
804852 kernel (C, $ (unwrapa (:A )), $ (unwrapb (:B )), ndrange= (size (C, 1 ), size (C, 2 )))
805-
853+
806854 return C
807855 end
808856
0 commit comments