@@ -359,11 +359,11 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
359359
360360 @inbounds if i <= size (A,1 ) && j <= size (B,2 )
361361 z2 = zero (A[i, 1 ]* B[1 , j] + A[i, 1 ]* B[1 , j])
362- Ctmp = convert (promote_type (R, typeof (z2)), z2)
362+ Cij = convert (promote_type (R, typeof (z2)), z2)
363363 for k in 1 : size (A,2 )
364- Ctmp += A[i, k]* B[k, j]
364+ Cij += A[i, k]* B[k, j]
365365 end
366- C[i,j] = add (Ctmp , C[i,j])
366+ C[i,j] = add (Cij , C[i,j])
367367 end
368368
369369 return
@@ -388,7 +388,184 @@ end
388388function LinearAlgebra. generic_matmatmul! (C:: AbstractGPUVecOrMat , tA, tB, A:: AbstractGPUVecOrMat , B:: AbstractGPUVecOrMat , a:: Number , b:: Number )
389389 LinearAlgebra. @stable_muladdmul generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), MulAddMul (a, b))
390390end
391- end
391+ end
392+
393+ function generic_trimatmul! (C:: AbstractGPUVecOrMat{R} , uploc, isunitc, tfun:: Function , A:: AbstractGPUMatrix{T} , B:: AbstractGPUVecOrMat{S} ) where {T,S,R}
394+ if size (A,2 ) != size (B,1 )
395+ throw (DimensionMismatch (lazy " matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))" ))
396+ end
397+ if size (C,1 ) != size (A,1 ) || size (C,2 ) != size (B,2 )
398+ throw (DimensionMismatch (lazy " result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))" ))
399+ end
400+ if isempty (A) || isempty (B)
401+ return fill! (C, zero (R))
402+ end
403+
404+ upper = tfun === identity ? uploc == ' U' : uploc != ' U'
405+ unit = isunitc == ' U'
406+
407+ function trimatmul (ctx, C, A, B)
408+ idx = @linearidx C
409+ assume .(size (C) .> 0 )
410+ i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
411+ l, m, n = size (A, 1 ), size (B, 1 ), size (B, 2 )
412+
413+ @inbounds if i <= l && j <= n
414+ z2 = zero (A[i,1 ] * B[1 ,j] + A[i,1 ] * B[1 ,j])
415+ Cij = convert (promote_type (R, typeof (z2)), z2)
416+ Cij += (unit ? one (Cij) : A[i,i]) * B[i,j]
417+ for k in (upper ? (i + 1 ) : 1 ): (upper ? m : (i - 1 ))
418+ Cij += A[i,k] * B[k,j]
419+ end
420+ C[i,j] += Cij
421+ end
422+
423+ return
424+ end
425+
426+ function trimatmul_t (ctx, C, A, B)
427+ idx = @linearidx C
428+ assume .(size (C) .> 0 )
429+ i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
430+ l, m, n = size (A, 1 ), size (B, 1 ), size (B, 2 )
431+
432+ @inbounds if i <= l && j <= n
433+ z2 = zero (A[i,1 ] * B[1 ,j] + A[i,1 ] * B[1 ,j])
434+ Cij = convert (promote_type (R, typeof (z2)), z2)
435+ Cij += (unit ? one (Cij) : transpose (A[i,i])) * B[i,j]
436+ for k in (upper ? (i + 1 ) : 1 ): (upper ? m : (i - 1 ))
437+ Cij += transpose (A[k,i]) * B[k,j]
438+ end
439+ C[i,j] += Cij
440+ end
441+
442+ return
443+ end
444+
445+ function trimatmul_a (ctx, C, A, B)
446+ idx = @linearidx C
447+ assume .(size (C) .> 0 )
448+ i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
449+ l, m, n = size (A, 1 ), size (B, 1 ), size (B, 2 )
450+
451+ @inbounds if i <= l && j <= n
452+ z2 = zero (A[i,1 ] * B[1 ,j] + A[i,1 ] * B[1 ,j])
453+ Cij = convert (promote_type (R, typeof (z2)), z2)
454+ Cij += (unit ? one (Cij) : adjoint (A[i,i])) * B[i,j]
455+ for k in (upper ? (i + 1 ) : 1 ): (upper ? m : (i - 1 ))
456+ Cij += adjoint (A[k,i]) * B[k,j]
457+ end
458+ C[i,j] += Cij
459+ end
460+
461+ return
462+ end
463+
464+ if tfun === identity
465+ gpu_call (trimatmul, C, A, B; name= " trimatmul" )
466+ elseif tfun == transpose
467+ gpu_call (trimatmul_t, C, A, B; name= " trimatmul_t" )
468+ elseif tfun === adjoint
469+ gpu_call (trimatmul_a, C, A, B; name= " trimatmul_a" )
470+ else
471+ error (" Not supported" )
472+ end
473+
474+ C
475+ end
476+
477+ function generic_mattrimul! (C:: AbstractGPUVecOrMat{R} , uploc, isunitc, tfun:: Function , A:: AbstractGPUMatrix{T} , B:: AbstractGPUVecOrMat{S} ) where {T,S,R}
478+ if size (A,2 ) != size (B,1 )
479+ throw (DimensionMismatch (lazy " matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))" ))
480+ end
481+ if size (C,1 ) != size (A,1 ) || size (C,2 ) != size (B,2 )
482+ throw (DimensionMismatch (lazy " result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))" ))
483+ end
484+ if isempty (A) || isempty (B)
485+ return fill! (C, zero (R))
486+ end
487+
488+ upper = tfun === identity ? uploc == ' U' : uploc != ' U'
489+ unit = isunitc == ' U'
490+
491+ function mattrimul (ctx, C, A, B)
492+ idx = @linearidx C
493+ assume .(size (C) .> 0 )
494+ i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
495+ l, m, n = size (A, 1 ), size (B, 1 ), size (B, 2 )
496+
497+ @inbounds if i <= l && j <= n
498+ z2 = zero (A[i,1 ] * B[1 ,j] + A[i,1 ] * B[1 ,j])
499+ Cij = convert (promote_type (R, typeof (z2)), z2)
500+ Cij += A[i,j] * (unit ? one (Cij) : B[j,j])
501+ for k in (upper ? 1 : (j + 1 )): (upper ? (j - 1 ) : m)
502+ Cij += A[i,k] * B[k,j]
503+ end
504+ C[i,j] += Cij
505+ end
506+
507+ return
508+ end
509+
510+ function mattrimul_t (ctx, C, A, B)
511+ idx = @linearidx C
512+ assume .(size (C) .> 0 )
513+ i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
514+ l, m, n = size (A, 1 ), size (B, 1 ), size (B, 2 )
515+
516+ @inbounds if i <= l && j <= n
517+ z2 = zero (A[i,1 ] * B[1 ,j] + A[i,1 ] * B[1 ,j])
518+ Cij = convert (promote_type (R, typeof (z2)), z2)
519+ Cij += A[i,j] * (unit ? one (Cij) : transpose (B[j,j]))
520+ for k in (upper ? 1 : (j + 1 ) ): (upper ? (j - 1 ) : m)
521+ Cij += A[i,k] * transpose (B[j,k])
522+ end
523+ C[i,j] += Cij
524+ end
525+
526+ return
527+ end
528+
529+ function mattrimul_a (ctx, C, A, B)
530+ idx = @linearidx C
531+ assume .(size (C) .> 0 )
532+ i, j = @inbounds Tuple (CartesianIndices (C)[idx])... , 1
533+ l, m, n = size (A, 1 ), size (B, 1 ), size (B, 2 )
534+
535+ @inbounds if i <= l && j <= n
536+ z2 = zero (A[i,1 ] * B[1 ,j] + A[i,1 ] * B[1 ,j])
537+ Cij = convert (promote_type (R, typeof (z2)), z2)
538+ Cij += A[i,j] * (unit ? one (Cij) : adjoint (B[j,j]))
539+ for k in (upper ? 1 : (j + 1 )): (upper ? (j - 1 ) : m)
540+ Cij += A[i,k] * adjoint (B[j,k])
541+ end
542+ C[i,j] += Cij
543+ end
544+
545+ return
546+ end
547+
548+ if tfun === identity
549+ gpu_call (mattrimul, C, A, B; name= " mattrimul" )
550+ elseif tfun == transpose
551+ gpu_call (mattrimul_t, C, A, B; name= " mattrimul_t" )
552+ elseif tfun === adjoint
553+ gpu_call (mattrimul_a, C, A, B; name= " mattrimul_a" )
554+ else
555+ error (" Not supported" )
556+ end
557+
558+ C
559+ end
560+
561+ if VERSION >= v " 1.10-"
562+ function LinearAlgebra. generic_trimatmul! (C:: AbstractGPUVecOrMat , uploc, isunitc, tfun:: Function , A:: AbstractGPUMatrix , B:: AbstractGPUVecOrMat )
563+ generic_trimatmul! (C, uploc, isunitc, tfun, A, B)
564+ end
565+ function LinearAlgebra. generic_mattrimul! (C:: AbstractGPUMatrix , uploc, isunitc, tfun:: Function , A:: AbstractGPUMatrix , B:: AbstractGPUMatrix )
566+ generic_mattrimul! (C, uploc, isunitc, tfun, A, B)
567+ end
568+ end
392569
393570if VERSION < v " 1.10.0-DEV.1365"
394571# catch other functions that are called by LinearAlgebra's mul!
0 commit comments