@@ -13,7 +13,7 @@ contains
1313 ! Internal use only!
1414 pure function matmul_chain_order(p) result(s)
1515 integer, intent(in) :: p(:)
16- integer :: s(1:size(p) - 2, 2: size(p) - 1), m(1: size(p) - 1, 1: size(p) - 1)
16+ integer :: s(1:size(p) - 2, 2:size(p) - 1), m(1:size(p) - 1, 1:size(p) - 1)
1717 integer :: n, l, i, j, k, q
1818 n = size(p) - 1
1919 m(:,:) = 0
@@ -40,35 +40,102 @@ contains
4040
4141 pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s) result(r)
4242 ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:)
43- integer, intent(in) :: start, s(:,:)
43+ integer, intent(in) :: start, s(:,2 :)
4444 ${t}$, allocatable :: r(:,:)
45+ integer :: tmp
46+ tmp = s(start, start + 2)
47+
48+ if (tmp == start) then
49+ r = matmul(m1, matmul(m2, m3))
50+ else if (tmp == start + 1) then
51+ r = matmul(matmul(m1, m2), m3)
52+ else
53+ error stop "stdlib_matmul: error: unexpected s(i,j)"
54+ end if
4555
46- select case (s(start, start + 2))
47- case (1)
48- r = matmul(m1, matmul(m2, m3))
49- case (2)
50- r = matmul(matmul(m1, m2), m3)
51- case default
52- error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
53- end select
5456 end function matmul_chain_mult_${s}$_3
5557
5658 pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s) result(r)
5759 ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:)
58- integer, intent(in) :: start, s(:,:)
60+ integer, intent(in) :: start, s(:,2:)
61+ ${t}$, allocatable :: r(:,:)
62+ integer :: tmp
63+ tmp = s(start, start + 3)
64+
65+ if (tmp == start) then
66+ r = matmul(m1, matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s))
67+ else if (tmp == start + 1) then
68+ r = matmul(matmul(m1, m2), matmul(m3, m4))
69+ else if (tmp == start + 2) then
70+ r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4)
71+ else
72+ error stop "stdlib_matmul: error: unexpected s(i,j)"
73+ end if
74+
75+ end function matmul_chain_mult_${s}$_4
76+
77+ pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
78+ ${t}$, intent(in) :: m1(:,:), m2(:,:)
79+ ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
5980 ${t}$, allocatable :: r(:,:)
81+ integer :: p(6), num_present
82+ integer, allocatable :: s(:,:)
6083
61- select case (s(start, start + 3))
84+ p(1) = size(m1, 1)
85+ p(2) = size(m2, 1)
86+ p(3) = size(m2, 2)
87+
88+ num_present = 2
89+ if (present(m3)) then
90+ p(3) = size(m3, 1)
91+ p(4) = size(m3, 2)
92+ num_present = num_present + 1
93+ end if
94+ if (present(m4)) then
95+ p(4) = size(m4, 1)
96+ p(5) = size(m4, 2)
97+ num_present = num_present + 1
98+ end if
99+ if (present(m5)) then
100+ p(5) = size(m5, 1)
101+ p(6) = size(m5, 2)
102+ num_present = num_present + 1
103+ end if
104+
105+ if (num_present == 2) then
106+ r = matmul(m1, m2)
107+ return
108+ end if
109+
110+ ! Now num_present >= 3
111+ allocate(s(1:num_present - 1, 2:num_present))
112+
113+ s = matmul_chain_order(p(1: num_present + 1))
114+
115+ if (num_present == 3) then
116+ r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s)
117+ return
118+ else if (num_present == 4) then
119+ r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s)
120+ return
121+ end if
122+
123+ ! Now num_present is 5
124+
125+ select case (s(1, 5))
62126 case (1)
63- r = matmul(m1, matmul_chain_mult_${s}$_3 (m2, m3, m4, start + 1 , s))
127+ r = matmul(m1, matmul_chain_mult_${s}$_4 (m2, m3, m4, m5, 2 , s))
64128 case (2)
65- r = matmul(matmul(m1, m2), matmul (m3, m4))
129+ r = matmul(matmul(m1, m2), matmul_chain_mult_${s}$_3 (m3, m4, m5, 3, s ))
66130 case (3)
67- r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4)
131+ r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s), matmul(m4, m5))
132+ case (4)
133+ r = matmul(matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s), m5)
68134 case default
69- error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
135+ error stop "stdlib_matmul: error: unexpected s(i,j)"
70136 end select
71- end function matmul_chain_mult_${s}$_4
137+
138+ end function stdlib_matmul_${s}$
72139
73140#:endfor
74141end submodule stdlib_intrinsics_matmul
0 commit comments