@@ -418,7 +418,7 @@ contains
418418 character(1), intent(in), optional :: op
419419 ${t1}$ :: alpha_
420420 character(1) :: op_
421- integer(ilp) :: i, nz, rowidx, num_chunks, rm
421+ integer(ilp) :: i, j, k, nz, rowidx, num_chunks, rm
422422
423423 op_ = sparse_op_none; if(present(op)) op_ = op
424424 alpha_ = one_${s1}$
@@ -447,7 +447,12 @@ contains
447447 do i = 1, num_chunks
448448 nz = ia(i+1) - ia(i)
449449 rowidx = (i - 1)*${chunk}$ + 1
450- call chunk_kernel_${chunk}$(nz,data(:,ia(i)),ja(:,ia(i)),vec_x,vec_y(rowidx:))
450+ associate(col => ja(1:${chunk}$,ia(i):ia(i)+nz-1), mat => data(1:${chunk}$,ia(i):ia(i)+nz-1), &
451+ & x => vec_x, y => vec_y(rowidx:rowidx+${chunk}$-1) )
452+ do j = 1, nz
453+ where(col(:,j) > 0) y = y + alpha_ * mat(:,j) * x(col(:,j))
454+ end do
455+ end associate
451456 end do
452457 #:endfor
453458 end select
@@ -457,7 +462,12 @@ contains
457462 i = num_chunks + 1
458463 nz = ia(i+1) - ia(i)
459464 rowidx = (i - 1)*cs + 1
460- call chunk_kernel_remainder(nz,cs,rm,data(:,ia(i)),ja(:,ia(i)),vec_x,vec_y(rowidx:))
465+ associate(col => ja(1:cs,ia(i):ia(i)+nz-1), mat => data(1:cs,ia(i):ia(i)+nz-1), &
466+ & x => vec_x, y => vec_y(rowidx:rowidx+rm-1) ) ! 1:cs, not the fypp loop variable (loop closed above)
467+ do j = 1, nz ! remainder chunk: rm rows of y, nz = ia(i+1) - ia(i) columns
468+ where(col(1:rm,j) > 0) y = y + alpha_ * mat(1:rm,j) * x(col(1:rm,j)) ! entries with col <= 0 are skipped
469+ end do
470+ end associate
461471 end if
462472
463473 else if( storage == sparse_full .and. op_==sparse_op_transpose ) then
@@ -468,7 +478,14 @@ contains
468478 do i = 1, num_chunks
469479 nz = ia(i+1) - ia(i)
470480 rowidx = (i - 1)*${chunk}$ + 1
471- call chunk_kernel_trans_${chunk}$(nz,data(:,ia(i)),ja(:,ia(i)),vec_x(rowidx:),vec_y)
481+ associate(col => ja(1:${chunk}$,ia(i):ia(i)+nz-1), mat => data(1:${chunk}$,ia(i):ia(i)+nz-1), &
482+ & x => vec_x(rowidx:rowidx+${chunk}$-1), y => vec_y )
483+ do j = 1, nz
484+ do k = 1, ${chunk}$
485+ if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * mat(k,j) * x(k)
486+ end do
487+ end do
488+ end associate
472489 end do
473490 #:endfor
474491 end select
@@ -478,63 +495,21 @@ contains
478495 i = num_chunks + 1
479496 nz = ia(i+1) - ia(i)
480497 rowidx = (i - 1)*cs + 1
481- call chunk_kernel_remainder_trans(nz,cs,rm,data(:,ia(i)),ja(:,ia(i)),vec_x(rowidx:),vec_y)
498+ associate(col => ja(1:cs,ia(i):ia(i)+nz-1), mat => data(1:cs,ia(i):ia(i)+nz-1), &
499+ & x => vec_x(rowidx:rowidx+rm-1), y => vec_y ) ! 1:cs, not the fypp loop variable (loop closed above)
500+ do j = 1, nz ! remainder chunk: nz = ia(i+1) - ia(i) columns
501+ do k = 1, rm ! only the rm rows left in the last chunk
502+ if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * mat(k,j) * x(k) ! scatter; col <= 0 is skipped
503+ end do
504+ end do
505+ end associate
482506 end if
483507 else
484508 print *, "error: sellc format for spmv operation not yet supported."
485509 return
486510 end if
487511 end associate
488512
489- contains
490- #:for chunk in CHUNKS
491- pure subroutine chunk_kernel_${chunk}$(n,a,col,x,y)
492- integer, value :: n
493- ${t1}$, intent(in) :: a(${chunk}$,n), x(*)
494- integer(ilp), intent(in) :: col(${chunk}$,n)
495- ${t1}$, intent(inout) :: y(${chunk}$)
496- integer :: j
497- do j = 1, n
498- where(col(:,j) > 0) y = y + alpha_ * a(:,j) * x(col(:,j))
499- end do
500- end subroutine
501- pure subroutine chunk_kernel_trans_${chunk}$(n,a,col,x,y)
502- integer, value :: n
503- ${t1}$, intent(in) :: a(${chunk}$,n), x(${chunk}$)
504- integer(ilp), intent(in) :: col(${chunk}$,n)
505- ${t1}$, intent(inout) :: y(*)
506- integer :: j, k
507- do j = 1, n
508- do k = 1, ${chunk}$
509- if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * a(k,j) * x(k)
510- end do
511- end do
512- end subroutine
513- #:endfor
514-
515- pure subroutine chunk_kernel_remainder(n,cs,rm,a,col,x,y)
516- integer, value :: n, cs, rm
517- ${t1}$, intent(in) :: a(cs,n), x(*)
518- integer(ilp), intent(in) :: col(cs,n)
519- ${t1}$, intent(inout) :: y(rm)
520- integer :: j
521- do j = 1, n
522- where(col(1:rm,j) > 0) y = y + alpha_ * a(1:rm,j) * x(col(1:rm,j))
523- end do
524- end subroutine
525- pure subroutine chunk_kernel_remainder_trans(n,cs,rm,a,col,x,y)
526- integer, value :: n, cs, rm
527- ${t1}$, intent(in) :: a(cs,n), x(rm)
528- integer(ilp), intent(in) :: col(cs,n)
529- ${t1}$, intent(inout) :: y(*)
530- integer :: j, k
531- do j = 1, n
532- do k = 1, rm
533- if(col(k,j) > 0) y(col(k,j)) = y(col(k,j)) + alpha_ * a(k,j) * x(k)
534- end do
535- end do
536- end subroutine
537-
538513 end subroutine
539514
540515 #:endfor
0 commit comments