@@ -55,27 +55,73 @@ submodule (stdlib_linalg) stdlib_linalg_solve
5555 type(linalg_state_type), intent(out) :: err
5656 !> Result array/matrix x[n] or x[n,nrhs]
5757 ${rt}$, allocatable, target :: x${nd}$
58+
59+ ! Initialize solution shape from the rhs array
60+ allocate(x,mold=b)
61+
62+ call stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(a,b,x,overwrite_a=overwrite_a,err=err)
63+
64+ end function stdlib_linalg_${ri}$_solve_${ndsuf}$
65+
66+ !> Compute the solution to a real system of linear equations A * X = B (pure interface)
67+ pure module function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$(a,b) result(x)
68+ !> Input matrix a[n,n]
69+ ${rt}$, intent(in) :: a(:,:)
70+ !> Right hand side vector or array, b[n] or b[n,nrhs]
71+ ${rt}$, intent(in) :: b${nd}$
72+ !> Result array/matrix x[n] or x[n,nrhs]
73+ ${rt}$, allocatable, target :: x${nd}$
74+
75+ ! Local variables
76+ ${rt}$, allocatable :: amat(:,:)
77+
78+ ! Copy `a` so it can be intent(in)
79+ allocate(amat,source=a)
80+
81+ ! Initialize solution shape from the rhs array
82+ allocate(x,mold=b)
83+
84+ call stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(amat,b,x,overwrite_a=.true.)
85+
86+ end function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$
87+
88+ !> Compute the solution to a real system of linear equations A * X = B (pure interface)
89+ pure module subroutine stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(a,b,x,pivot,overwrite_a,err)
90+ !> Input matrix a[n,n]
91+ ${rt}$, intent(inout), target :: a(:,:)
92+ !> Right hand side vector or array, b[n] or b[n,nrhs]
93+ ${rt}$, intent(in) :: b${nd}$
94+ !> Result array/matrix x[n] or x[n,nrhs]
95+ ${rt}$, intent(inout), contiguous, target :: x${nd}$
96+ !> [optional] Storage array for the diagonal pivot indices
97+ integer(ilp), optional, intent(inout), target :: pivot(:)
98+ !> [optional] Can A data be overwritten and destroyed?
99+ logical(lk), optional, intent(in) :: overwrite_a
100+ !> [optional] state return flag. On error if not requested, the code will stop
101+ type(linalg_state_type), optional, intent(out) :: err
58102
59103 ! Local variables
60104 type(linalg_state_type) :: err0
61- integer(ilp) :: lda,n,ldb,nrhs,info
62- integer(ilp), allocatable :: ipiv(:)
105+ integer(ilp) :: lda,n,ldb,ldx,nrhsx, nrhs,info,npiv
106+ integer(ilp), pointer :: ipiv(:)
63107 logical(lk) :: copy_a
64108 ${rt}$, pointer :: xmat(:,:),amat(:,:)
65109
66110 ! Problem sizes
67- lda = size(a,1,kind=ilp)
68- n = size(a,2,kind=ilp)
69- ldb = size(b,1,kind=ilp)
70- nrhs = size(b ,kind=ilp)/ldb
71-
72- if (any([lda,n,ldb]<1) .or. any([lda,ldb]/=n)) then
73- err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
74- ', b=',[ldb,nrhs])
75- allocate(x${nde}$)
76- call linalg_error_handling(err0,err)
77- return
78- end if
111+ lda = size(a,1,kind=ilp)
112+ n = size(a,2,kind=ilp)
113+ ldb = size(b,1,kind=ilp)
114+ nrhs = size(b ,kind=ilp)/ldb
115+ ldx = size(x,1,kind=ilp)
116+ nrhsx = size(x ,kind=ilp)/ldx
117+
118+ ! Has a pre-allocated pivots storage array been provided?
119+ if (present(pivot)) then
120+ ipiv => pivot
121+ else
122+ allocate(ipiv(n))
123+ endif
124+ npiv = size(ipiv,kind=ilp)
79125
80126 ! Can A be overwritten? By default, do not overwrite
81127 if (present(overwrite_a)) then
@@ -84,8 +130,13 @@ submodule (stdlib_linalg) stdlib_linalg_solve
84130 copy_a = .true._lk
85131 endif
86132
87- ! Pivot indices
88- allocate(ipiv(n))
133+ if (any([lda,n,ldb]<1) .or. any([lda,ldb,ldx]/=n) .or. nrhsx/=nrhs .or. npiv/=n) then
134+ err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
135+ 'b=',[ldb,nrhs],' x=',[ldx,nrhsx], &
136+ 'pivot=',n)
137+ call linalg_error_handling(err0,err)
138+ return
139+ end if
89140
90141 ! Initialize a matrix temporary
91142 if (copy_a) then
@@ -95,7 +146,7 @@ submodule (stdlib_linalg) stdlib_linalg_solve
95146 endif
96147
97148 ! Initialize solution with the rhs
98- allocate(x,source=b)
149+ x = b
99150 xmat(1:n,1:nrhs) => x
100151
101152 ! Solve system
@@ -105,64 +156,13 @@ submodule (stdlib_linalg) stdlib_linalg_solve
105156 call handle_gesv_info(info,lda,n,nrhs,err0)
106157
107158 if (copy_a) deallocate(amat)
159+ if (.not.present(pivot)) deallocate(ipiv)
108160
109161 ! Process output and return
110162 call linalg_error_handling(err0,err)
111163
112- end function stdlib_linalg_${ri}$_solve_${ndsuf}$
113-
114- !> Compute the solution to a real system of linear equations A * X = B (pure interface)
115- pure module function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$(a,b) result(x)
116- !> Input matrix a[n,n]
117- ${rt}$, intent(in), target :: a(:,:)
118- !> Right hand side vector or array, b[n] or b[n,nrhs]
119- ${rt}$, intent(in) :: b${nd}$
120- !> Result array/matrix x[n] or x[n,nrhs]
121- ${rt}$, allocatable, target :: x${nd}$
122-
123- ! Local variables
124- type(linalg_state_type) :: err0
125- integer(ilp) :: lda,n,ldb,nrhs,info
126- integer(ilp), allocatable :: ipiv(:)
127- ${rt}$, pointer :: xmat(:,:)
128- ${rt}$, allocatable :: amat(:,:)
129-
130- ! Problem sizes
131- lda = size(a,1,kind=ilp)
132- n = size(a,2,kind=ilp)
133- ldb = size(b,1,kind=ilp)
134- nrhs = size(b ,kind=ilp)/ldb
135-
136- if (any([lda,n,ldb]<1) .or. any([lda,ldb]/=n)) then
137- err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
138- ', b=',[ldb,nrhs])
139- allocate(x${nde}$)
140- call linalg_error_handling(err0)
141- return
142- end if
143-
144- ! Pivot indices
145- allocate(ipiv(n))
146-
147- ! Initialize a matrix temporary
148- allocate(amat,source=a)
149-
150- ! Initialize solution with the rhs
151- allocate(x,source=b)
152- xmat(1:n,1:nrhs) => x
153-
154- ! Solve system
155- call gesv(n,nrhs,amat,lda,ipiv,xmat,ldb,info)
156-
157- ! Process output
158- call handle_gesv_info(info,lda,n,nrhs,err0)
159-
160- deallocate(amat)
161-
162- ! Process output and return
163- call linalg_error_handling(err0)
164-
165- end function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$
164+ end subroutine stdlib_linalg_${ri}$_solve_lu_${ndsuf}$
165+
166166 #:endif
167167 #:endfor
168168 #:endfor
0 commit comments