@@ -4,137 +4,28 @@ using Base.PermutedDimsArrays: genperm
44# i.e. `ContractAdd`?
55function output_axes (
66 :: typeof (contract),
7- biperm_dest:: BlockedPermutation {2} ,
7+ biperm_dest:: AbstractBlockPermutation {2} ,
88 a1:: AbstractArray ,
9- biperm1:: BlockedPermutation {2} ,
9+ biperm1:: AbstractBlockPermutation {2} ,
1010 a2:: AbstractArray ,
11- biperm2:: BlockedPermutation {2} ,
11+ biperm2:: AbstractBlockPermutation {2} ,
1212 α:: Number = one (Bool),
1313)
14- axes_codomain, axes_contracted = blockpermute (axes (a1), biperm1)
15- axes_contracted2, axes_domain = blockpermute (axes (a2), biperm2)
14+ axes_codomain, axes_contracted = blocks (axes (a1)[ biperm1] )
15+ axes_contracted2, axes_domain = blocks (axes (a2)[ biperm2] )
1616 @assert axes_contracted == axes_contracted2
1717 return genperm ((axes_codomain... , axes_domain... ), invperm (Tuple (biperm_dest)))
1818end
1919
20- # Inner-product contraction.
21- # TODO : Use `ArrayLayouts`-like `MulAdd` object,
22- # i.e. `ContractAdd`?
23- function output_axes (
24- :: typeof (contract),
25- perm_dest:: BlockedPermutation{0} ,
26- a1:: AbstractArray ,
27- perm1:: BlockedPermutation{1} ,
28- a2:: AbstractArray ,
29- perm2:: BlockedPermutation{1} ,
30- α:: Number = one (Bool),
31- )
32- axes_contracted = blockpermute (axes (a1), perm1)
33- axes_contracted′ = blockpermute (axes (a2), perm2)
34- @assert axes_contracted == axes_contracted′
35- return ()
36- end
37-
38- # Vec-mat.
39- function output_axes (
40- :: typeof (contract),
41- perm_dest:: BlockedPermutation{1} ,
42- a1:: AbstractArray ,
43- perm1:: BlockedPermutation{1} ,
44- a2:: AbstractArray ,
45- biperm2:: BlockedPermutation{2} ,
46- α:: Number = one (Bool),
47- )
48- (axes_contracted,) = blockpermute (axes (a1), perm1)
49- axes_contracted′, axes_dest = blockpermute (axes (a2), biperm2)
50- @assert axes_contracted == axes_contracted′
51- return genperm ((axes_dest... ,), invperm (Tuple (perm_dest)))
52- end
53-
54- # Mat-vec.
55- function output_axes (
56- :: typeof (contract),
57- perm_dest:: BlockedPermutation{1} ,
58- a1:: AbstractArray ,
59- perm1:: BlockedPermutation{2} ,
60- a2:: AbstractArray ,
61- biperm2:: BlockedPermutation{1} ,
62- α:: Number = one (Bool),
63- )
64- axes_dest, axes_contracted = blockpermute (axes (a1), perm1)
65- (axes_contracted′,) = blockpermute (axes (a2), biperm2)
66- @assert axes_contracted == axes_contracted′
67- return genperm ((axes_dest... ,), invperm (Tuple (perm_dest)))
68- end
69-
70- # Outer product.
71- function output_axes (
72- :: typeof (contract),
73- biperm_dest:: BlockedPermutation{2} ,
74- a1:: AbstractArray ,
75- perm1:: BlockedPermutation{1} ,
76- a2:: AbstractArray ,
77- perm2:: BlockedPermutation{1} ,
78- α:: Number = one (Bool),
79- )
80- @assert istrivialperm (Tuple (perm1))
81- @assert istrivialperm (Tuple (perm2))
82- axes_dest = (axes (a1)... , axes (a2)... )
83- return genperm (axes_dest, invperm (Tuple (biperm_dest)))
84- end
85-
86- # Array-scalar contraction.
87- function output_axes (
88- :: typeof (contract),
89- perm_dest:: BlockedPermutation{1} ,
90- a1:: AbstractArray ,
91- perm1:: BlockedPermutation{1} ,
92- a2:: AbstractArray ,
93- perm2:: BlockedPermutation{0} ,
94- α:: Number = one (Bool),
95- )
96- @assert istrivialperm (Tuple (perm1))
97- axes_dest = axes (a1)
98- return genperm (axes_dest, invperm (Tuple (perm_dest)))
99- end
100-
101- # Scalar-array contraction.
102- function output_axes (
103- :: typeof (contract),
104- perm_dest:: BlockedPermutation{1} ,
105- a1:: AbstractArray ,
106- perm1:: BlockedPermutation{0} ,
107- a2:: AbstractArray ,
108- perm2:: BlockedPermutation{1} ,
109- α:: Number = one (Bool),
110- )
111- @assert istrivialperm (Tuple (perm2))
112- axes_dest = axes (a2)
113- return genperm (axes_dest, invperm (Tuple (perm_dest)))
114- end
115-
116- # Scalar-scalar contraction.
117- function output_axes (
118- :: typeof (contract),
119- perm_dest:: BlockedPermutation{0} ,
120- a1:: AbstractArray ,
121- perm1:: BlockedPermutation{0} ,
122- a2:: AbstractArray ,
123- perm2:: BlockedPermutation{0} ,
124- α:: Number = one (Bool),
125- )
126- return ()
127- end
128-
12920# TODO : Use `ArrayLayouts`-like `MulAdd` object,
13021# i.e. `ContractAdd`?
13122function allocate_output (
13223 :: typeof (contract),
133- biperm_dest:: BlockedPermutation ,
24+ biperm_dest:: AbstractBlockPermutation ,
13425 a1:: AbstractArray ,
135- biperm1:: BlockedPermutation ,
26+ biperm1:: AbstractBlockPermutation ,
13627 a2:: AbstractArray ,
137- biperm2:: BlockedPermutation ,
28+ biperm2:: AbstractBlockPermutation ,
13829 α:: Number = one (Bool),
13930)
14031 axes_dest = output_axes (contract, biperm_dest, a1, biperm1, a2, biperm2, α)
0 commit comments