@@ -37,99 +37,38 @@ contains
3737 end function matmul_chain_order
3838
3939#:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES
40+
41+ pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s) result(r)
42+ ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:)
43+ integer, intent(in) :: start, s(:,:)
44+ ${t}$, allocatable :: r(:,:)
4045
41- pure module function stdlib_matmul_${s}$_3 (a, b, c) result(d)
42- ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:)
43- ${t}$, allocatable :: d(:,:)
44- integer :: sa(2), sb(2), sc(2), cost1, cost2
45- sa = shape(a)
46- sb = shape(b)
47- sc = shape(c)
48-
49- if ((sa(2) /= sb(1)) .or. (sb(2) /= sc(1))) then
50- error stop "stdlib_matmul: Incompatible array shapes"
51- end if
52-
53- ! computes the cost (number of scalar multiplications required)
54- ! cost(A, B) = shape(A)(1) * shape(A)(2) * shape(B)(2)
55- cost1 = sa(1) * sa(2) * sb(2) + sa(1) * sb(2) * sc(2) ! ((AB)C)
56- cost2 = sb(1) * sb(2) * sc(2) + sa(1) * sa(2) * sc(2) ! (A(BC))
57-
58- if (cost1 < cost2) then
59- d = matmul(matmul(a, b), c)
60- else
61- d = matmul(a, matmul(b, c))
62- end if
63- end function stdlib_matmul_${s}$_3
64-
65- pure module function stdlib_matmul_${s}$_4 (a, b, c, d) result(e)
66- ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:)
67- ${t}$, allocatable :: e(:,:)
68- integer :: p(5), i
69- integer :: s(3,2:4)
70-
71- p(1) = size(a, 1)
72- p(2) = size(b, 1)
73- p(3) = size(c, 1)
74- p(4) = size(d, 1)
75- p(5) = size(d, 2)
76-
77- s = matmul_chain_order(p)
78-
79- select case (s(1,4))
46+ select case (s(start, start + 2))
8047 case (1)
81- select case (s(2, 4))
82- case (2)
83- e = matmul(a, matmul(b, matmul(c, d)))
84- case (3)
85- e = matmul(a, matmul(matmul(b, c), d))
86- case default
87- error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
88- end select
48+ r = matmul(m1, matmul(m2, m3))
8949 case (2)
90- e = matmul(matmul(a, b), matmul(c, d))
91- case (3)
92- select case (s(1, 3))
93- case (1)
94- e = matmul(matmul(a, matmul(b, c)), d)
95- case (2)
96- e = matmul(matmul(matmul(a, b), c), d)
97- case default
98- error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
99- end select
50+ r = matmul(matmul(m1, m2), m3)
10051 case default
10152 error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
10253 end select
103- end function stdlib_matmul_${s}$_4
104-
105- pure module function stdlib_matmul_${s}$_5 (a, b, c, d, e) result(f)
106- ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:), e(:,:)
107- ${t}$, allocatable :: f(:,:)
108- integer :: p(6), i
109- integer :: s(4,2:5)
110-
111- p(1) = size(a, 1)
112- p(2) = size(b, 1)
113- p(3) = size(c, 1)
114- p(4) = size(d, 1)
115- p(5) = size(e, 1)
116- p(6) = size(e, 2)
54+ end function matmul_chain_mult_${s}$_3
11755
118- s = matmul_chain_order(p)
56+ pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s) result(r)
57+ ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:)
58+ integer, intent(in) :: start, s(:,:)
59+ ${t}$, allocatable :: r(:,:)
11960
120- select case (s(1,5 ))
61+ select case (s(start, start + 3 ))
12162 case (1)
122- f = matmul(a, stdlib_matmul(b, c, d, e ))
63+ r = matmul(m1, matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s ))
12364 case (2)
124- f = matmul(matmul(a, b ), stdlib_matmul(c, d, e ))
65+ r = matmul(matmul(m1, m2 ), matmul(m3, m4 ))
12566 case (3)
126- f = matmul(stdlib_matmul(a, b ,c), matmul(d, e))
127- case (4)
128- f = matmul(stdlib_matmul(a, b, c, d), e)
67+ r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4)
12968 case default
13069 error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
13170 end select
132- end function stdlib_matmul_ ${s}$_5
71+ end function matmul_chain_mult_ ${s}$_4
13372
13473#:endfor
13574end submodule stdlib_intrinsics_matmul
0 commit comments