@@ -128,15 +128,18 @@ end
128128
129129Adapt. adapt_structure (to, t:: MultiplexTraces{names} ) where {names} = MultiplexTraces {names} (Adapt. adapt_structure (to, t. trace))
130130
131- function Base. getindex (t:: MultiplexTraces{names} , k:: Symbol ) where {names}
132- a, b = names
133- if k == a
134- RelativeTrace {0,-1} (convert (AbstractTrace, t. trace))
135- elseif k == b
136- RelativeTrace {1,0} (convert (AbstractTrace, t. trace))
131+ Base. getindex (t:: MultiplexTraces{names} , k:: Symbol ) where {names} = _getindex (t, Val (k))
132+
133+ @generated function _getindex (t:: MultiplexTraces{names} , :: Val{k} ) where {names,k}
134+ ex = :()
135+ if QuoteNode (names[1 ]) == QuoteNode (k)
136+ ex = :(RelativeTrace {0,-1} (t. trace))
137+ elseif QuoteNode (names[2 ]) == QuoteNode (k)
138+ ex = :(RelativeTrace {1,0} (t. trace))
137139 else
138- throw (ArgumentError (" unknown trace name: $k " ))
140+ ex = :( throw (ArgumentError (" unknown trace name: $k " ) ))
139141 end
142+ return :($ ex)
140143end
141144
142145Base. getindex (t:: MultiplexTraces{names} , I:: Int ) where {names} = NamedTuple {names} ((t. trace[I], t. trace[I+ 1 ]))
@@ -163,83 +166,144 @@ end
163166
164167struct Traces{names,T,N,E} <: AbstractTraces{names,E}
165168 traces:: T
166- inds:: NamedTuple{names,NTuple{N,Int}}
167169end
168170
169171function Adapt. adapt_structure (to, t:: Traces{names,T,N,E} ) where {names,T,N,E}
170172 data = Adapt. adapt_structure (to, t. traces)
171173 # FIXME : `E` is not adapted here
172- Traces {names,typeof(data),length(names),E} (data, t . inds )
174+ Traces {names,typeof(data),length(names),E} (data)
173175end
174176
175177function Traces (; kw... )
176178 data = map (x -> convert (AbstractTrace, x), values (kw))
177179 names = keys (data)
178- inds = NamedTuple (k => i for (i, k) in enumerate (names))
179- Traces {names,typeof(data),length(names),typeof(values(data))} (data, inds)
180+ Traces {names,typeof(data),length(names),typeof(values(data))} (data)
180181end
181182
182183
183- function Base. getindex (ts:: Traces , s:: Symbol )
184- t = ts. traces[ts. inds[s]]
184+ Base. getindex (ts:: Traces , s:: Symbol ) = Base. getindex (ts:: Traces , Val (s))
185+
186+ function Base. getindex (ts:: Traces , :: Val{s} ) where {s}
187+ t = _gettrace (ts, Val (s))
185188 if t isa AbstractTrace
186189 t
190+ elseif t isa MultiplexTraces
191+ _getindex (t, Val (s))
187192 else
188- t[s]
193+ throw ( ArgumentError ( " unknown trace name: $s " ))
189194 end
190195end
191196
192- Base. getindex (t:: Traces{names} , i) where {names} = NamedTuple {names} (map (k -> t[k][i], names))
197+ @generated function _gettrace (ts:: Traces{names,Trs,N,E} , :: Val{k} ) where {names,Trs,N,E,k}
198+ index_ = build_trace_index (names, Trs)
199+ # Generate code, i.e. find the correct index for a given key
200+ ex = :()
201+
202+ for name in names
203+ if QuoteNode (name) == QuoteNode (k)
204+ index_element = index_[k]
205+ ex = :(ts. traces[$ index_element])
206+ break
207+ end
208+ end
209+
210+ return :($ ex)
211+ end
212+
213+ @generated function Base. getindex (t:: Traces{names} , i) where {names}
214+ ex = :(NamedTuple {$(names)} ($ (Expr (:tuple ))))
215+ for k in names
216+ push! (ex. args[2 ]. args, :(t[Val ($ (QuoteNode (k)))][i]))
217+ end
218+ return ex
219+ end
193220
194221function Base.:(+ )(t1:: AbstractTraces{k1,T1} , t2:: AbstractTraces{k2,T2} ) where {k1,k2,T1,T2}
195222 ks = (k1... , k2... )
196223 ts = (t1, t2)
197- inds = (; (k => 1 for k in k1). .. , (k => 2 for k in k2). .. )
198- Traces {ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}} (ts, inds)
224+ Traces {ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}} (ts)
199225end
200226
201227function Base.:(+ )(t1:: AbstractTraces{k1,T1} , t2:: Traces{k2,T,N,T2} ) where {k1,T1,k2,T,N,T2}
202228 ks = (k1... , k2... )
203229 ts = (t1, t2. traces... )
204- inds = merge (NamedTuple (k => 1 for k in k1), map (v -> v + 1 , t2. inds))
205- Traces {ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}} (ts, inds)
230+ Traces {ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}} (ts)
206231end
207232
208233
209234function Base.:(+ )(t1:: Traces{k1,T,N,T1} , t2:: AbstractTraces{k2,T2} ) where {k1,T,N,T1,k2,T2}
210235 ks = (k1... , k2... )
211236 ts = (t1. traces... , t2)
212- inds = merge (t1. inds, (; (k => length (ts) for k in k2). .. ))
213- Traces {ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}} (ts, inds)
237+ Traces {ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}} (ts)
214238end
215239
216240function Base.:(+ )(t1:: Traces{k1,T1,N1,E1} , t2:: Traces{k2,T2,N2,E2} ) where {k1,T1,N1,E1,k2,T2,N2,E2}
217241 ks = (k1... , k2... )
218242 ts = (t1. traces... , t2. traces... )
219- inds = merge (t1. inds, map (x -> x + length (t1. traces), t2. inds))
220- Traces {ks,typeof(ts),length(ks),Tuple{E1.types...,E2.types...}} (ts, inds)
243+ Traces {ks,typeof(ts),length(ks),Tuple{E1.types...,E2.types...}} (ts)
221244end
222245
223246Base. size (t:: Traces ) = (mapreduce (length, min, t. traces),)
224- capacity (t:: Traces ) = minimum (map (idx-> capacity (t. traces[idx]),t. inds))
225247
226- for f in (:push! , :pushfirst! )
227- @eval function Base. $f (ts:: Traces , xs:: NamedTuple )
228- for (k, v) in pairs (xs)
229- $ f (ts, Val (k), v)
248+ function capacity (t:: Traces{names,Trs,N,E} ) where {names,Trs,N,E}
249+ minimum (map (idx-> capacity (t[idx]), names))
250+ end
251+
252+ @generated function Base. push! (ts:: Traces , xs:: NamedTuple{N,T} ) where {N,T}
253+ ex = :()
254+ for n in N
255+ ex = :($ ex; push! (ts, Val ($ (QuoteNode (n))), xs.$ n))
256+ end
257+ return :($ ex)
258+ end
259+
260+ @generated function Base. pushfirst! (ts:: Traces , xs:: NamedTuple{N,T} ) where {N,T}
261+ ex = :()
262+ for n in N
263+ ex = :($ ex; pushfirst! (ts, Val ($ (QuoteNode (n))), xs.$ n))
264+ end
265+ return :($ ex)
266+ end
267+
268+ @generated function Base. pushfirst! (ts:: Traces{names,Trs,N,E} , :: Val{k} , v) where {names,Trs,N,E,k}
269+ index_ = build_trace_index (names, Trs)
270+ # Generate code, i.e. find the correct index for a given key
271+ ex = :()
272+
273+ for name in names
274+ if QuoteNode (name) == QuoteNode (k)
275+ index_element = index_[k]
276+ ex = :(pushfirst! (ts. traces[$ index_element], Val ($ (QuoteNode (k))), v))
277+ break
230278 end
231279 end
232280
233- @eval function Base. $f (ts:: Traces , :: Val{k} , v) where {k}
234- $ f (ts. traces[ts. inds[k]], Val (k), v)
281+ return :($ ex)
282+ end
283+
284+ @generated function Base. push! (ts:: Traces{names,Trs,N,E} , :: Val{k} , v) where {names,Trs,N,E,k}
285+ index_ = build_trace_index (names, Trs)
286+ # Generate code, i.e. find the correct index for a given key
287+ ex = :()
288+
289+ for name in names
290+ if QuoteNode (name) == QuoteNode (k)
291+ index_element = index_[k]
292+ ex = :(push! (ts. traces[$ index_element], Val ($ (QuoteNode (k))), v))
293+ break
294+ end
235295 end
236296
297+ return :($ ex)
298+ end
299+
300+ for f in (:push! , :pushfirst! )
237301 @eval function Base. $f (t:: AbstractTrace , :: Val{k} , v) where {k}
238302 $ f (t, v)
239303 end
240304
241305 @eval function Base. $f (t:: Trace , :: Val{k} , v) where {k}
242- $ f (t, v)
306+ $ f (t. parent , v)
243307 end
244308
245309 @eval function Base. $f (ts:: MultiplexTraces , :: Val{k} , v) where {k}
251315for f in (:append! , :prepend! )
252316 @eval function Base. $f (ts:: Traces , xs:: Traces )
253317 for k in keys (xs)
254- t = ts . traces[ts . inds[k]]
318+ t = _gettrace (ts, Val (k))
255319 $ f (t, xs[k])
256320 end
257321 end
@@ -264,3 +328,38 @@ for f in (:pop!, :popfirst!, :empty!)
264328 end
265329 end
266330end
331+
332+
333+ """
334+ build_trace_index(names::NTuple, traces_signature::DataType)
335+
336+ Take type signature from `Traces` and build a mapping from trace name to trace index
337+ """
338+ function build_trace_index (names:: NTuple , traces_signature:: DataType )
339+ # Build index
340+ index_ = Dict ()
341+
342+ if traces_signature <: NamedTuple
343+ # Handle simple Traces
344+ index_ = Dict (name => i for (name, i) ∈ zip (names, 1 : length (names)))
345+ elseif traces_signature <: Tuple
346+ # Handle MultiplexTracesup
347+ i = 1
348+ j = 1
349+ trace_list = traces_signature. parameters
350+ for tr in trace_list
351+ if tr <: MultiplexTraces
352+ index_[names[i]] = j
353+ i += 1
354+ index_[names[i]] = j
355+ else
356+ index_[names[i]] = j
357+ end
358+ i += 1
359+ j += 1
360+ end
361+ else
362+ error (" Traces store is neither a tuple nor a named tuple!" )
363+ end
364+ return index_
365+ end
0 commit comments