@@ -26,10 +26,16 @@ arguments(a::CartesianProduct, n::Int) = arguments(a)[n]
2626arg1 (a:: CartesianProduct ) = a. a
2727arg2 (a:: CartesianProduct ) = a. b
2828
29+ Base. copy (a:: CartesianProduct ) = copy (arg1 (a)) × copy (arg2 (a))
30+
2931function Base. show (io:: IO , a:: CartesianProduct )
3032 print (io, a. a, " × " , a. b)
3133 return nothing
3234end
35+ function Base. show (io:: IO , :: MIME"text/plain" , a:: CartesianProduct )
36+ show (io, a)
37+ return nothing
38+ end
3339
3440× (a:: AbstractVector , b:: AbstractVector ) = CartesianProduct (a, b)
3541Base. length (a:: CartesianProduct ) = length (a. a) * length (a. b)
@@ -42,8 +48,38 @@ function Base.getindex(a::CartesianProduct, i::CartesianPair)
4248 return arg1 (a)[arg1 (i)] × arg2 (a)[arg2 (i)]
4349end
4450function Base. getindex (a:: CartesianProduct , i:: Int )
45- I = Tuple (CartesianIndices ((length (arg1 (a)), length (arg2 (a))))[i])
46- return a[I[1 ] × I[2 ]]
51+ I = Tuple (CartesianIndices ((length (arg2 (a)), length (arg1 (a))))[i])
52+ return a[I[2 ] × I[1 ]]
53+ end
54+
55+ struct CartesianProductVector{T,P<: CartesianProduct ,V<: AbstractVector{T} } < :
56+ AbstractVector{T}
57+ product:: P
58+ values:: V
59+ end
60+ cartesianproduct (r:: CartesianProductVector ) = getfield (r, :product )
61+ unproduct (r:: CartesianProductVector ) = getfield (r, :values )
62+ Base. length (a:: CartesianProductVector ) = length (unproduct (a))
63+ Base. size (a:: CartesianProductVector ) = (length (a),)
64+ function Base. axes (r:: CartesianProductVector )
65+ return (CartesianProductUnitRange (cartesianproduct (r), only (axes (unproduct (r)))),)
66+ end
67+ function Base. copy (a:: CartesianProductVector )
68+ return CartesianProductVector (copy (cartesianproduct (a)), copy (unproduct (a)))
69+ end
70+ function Base. getindex (r:: CartesianProductVector , i:: Integer )
71+ return unproduct (r)[i]
72+ end
73+
74+ function Base. show (io:: IO , a:: CartesianProductVector )
75+ show (io, unproduct (a))
76+ return nothing
77+ end
78+ function Base. show (io:: IO , mime:: MIME"text/plain" , a:: CartesianProductVector )
79+ show (io, mime, cartesianproduct (a))
80+ println (io)
81+ show (io, mime, unproduct (a))
82+ return nothing
4783end
4884
4985struct CartesianProductUnitRange{T,P<: CartesianProduct ,R<: AbstractUnitRange{T} } < :
@@ -60,13 +96,24 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range)
6096arg1 (a:: CartesianProductUnitRange ) = arg1 (cartesianproduct (a))
6197arg2 (a:: CartesianProductUnitRange ) = arg2 (cartesianproduct (a))
6298
99+ function Base. show (io:: IO , a:: CartesianProductUnitRange )
100+ show (io, unproduct (a))
101+ return nothing
102+ end
103+ function Base. show (io:: IO , mime:: MIME"text/plain" , a:: CartesianProductUnitRange )
104+ show (io, mime, cartesianproduct (a))
105+ println (io)
106+ show (io, mime, unproduct (a))
107+ return nothing
108+ end
109+
63110function CartesianProductUnitRange (p:: CartesianProduct )
64111 return CartesianProductUnitRange (p, Base. OneTo (length (p)))
65112end
66113function CartesianProductUnitRange (a, b)
67114 return CartesianProductUnitRange (a × b)
68115end
69- to_product_indices (a:: AbstractUnitRange ) = a
116+ to_product_indices (a:: AbstractVector ) = a
70117to_product_indices (i:: Integer ) = Base. OneTo (i)
71118cartesianrange (a, b) = cartesianrange (to_product_indices (a) × to_product_indices (b))
72119function cartesianrange (p:: CartesianPair )
@@ -94,10 +141,16 @@ function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::Carte
94141 return checkindex (Bool, arg1 (inds), arg1 (i)) && checkindex (Bool, arg2 (inds), arg2 (i))
95142end
96143
144+ function Base. getindex (a:: CartesianProductUnitRange , I:: CartesianProduct )
145+ prod = cartesianproduct (a)
146+ prod_I = arg1 (prod)[arg1 (I)] × arg2 (prod)[arg2 (I)]
147+ return CartesianProductVector (prod_I, map (Base. Fix1 (getindex, a), I))
148+ end
149+
97150# Reverse map from CartesianPair to linear index in the range.
98151function Base. getindex (inds:: CartesianProductUnitRange , i:: CartesianPair )
99- i′ = (findfirst (== (arg1 (i)), arg1 (inds)), findfirst (== (arg2 (i)), arg2 (inds)))
100- return inds[LinearIndices ((length (arg1 (inds)), length (arg2 (inds))))[i′... ]]
152+ i′ = (findfirst (== (arg2 (i)), arg2 (inds)), findfirst (== (arg1 (i)), arg1 (inds)))
153+ return inds[LinearIndices ((length (arg2 (inds)), length (arg1 (inds))))[i′... ]]
101154end
102155
103156using Base. Broadcast: DefaultArrayStyle
0 commit comments