@@ -3,6 +3,7 @@ export Trace, Traces, MultiplexTraces, Episode, Episodes
33import MacroTools: @forward
44
55import CircularArrayBuffers
6+ import Adapt
67
78# ####
89
@@ -13,11 +14,23 @@ Base.convert(::Type{AbstractTrace}, x::AbstractTrace) = x
1314Base. summary (io:: IO , t:: AbstractTrace ) = print (io, " $(length (t)) -element $(nameof (typeof (t))) " )
1415
1516# ####
17+
18+ """
19+ Trace(A::AbstractArray)
20+
21+ Similar to
22+ [`Slices`](https://github.com/JuliaLang/julia/blob/master/base/slicearray.jl)
23+ which will be introduced in `Julia@v1.9`. The main difference is that, the
24+ `axes` info in the `Slices` is static, while it may be dynamic with `Trace`.
25+
26+ We only support slices along the last dimension since it's the most common usage
27+ in RL.
28+ """
1629struct Trace{T,E} <: AbstractTrace{E}
1730 parent:: T
1831end
1932
20- Base. summary (io:: IO , t:: Trace{T} ) where {T} = print (io, " $(length (t)) -element $(nameof (typeof (t))) {$T }" )
33+ Base. summary (io:: IO , t:: Trace{T} ) where {T} = print (io, " $(length (t)) -element$( length (t) > 0 ? ' s ' : " " ) $(nameof (typeof (t))) {$T }" )
2134
2235function Trace (x:: T ) where {T<: AbstractArray }
2336 E = eltype (x)
@@ -27,6 +40,8 @@ function Trace(x::T) where {T<:AbstractArray}
2740 Trace {T,SubArray{E,N,P,I,true}} (x)
2841end
2942
43+ Adapt. adapt_structure (to, t:: Trace ) = Trace (Adapt. adapt_structure (to, t. parent))
44+
3045Base. convert (:: Type{AbstractTrace} , x:: AbstractArray ) = Trace (x)
3146
3247Base. size (x:: Trace ) = (size (x. parent, ndims (x. parent)),)
@@ -59,6 +74,21 @@ Base.haskey(t::AbstractTraces{names}, k::Symbol) where {names} = k in names
5974
6075# ####
6176
77+ """
78+ Dedicated for `MultiplexTraces` to avoid scalar indexing when `view(view(t::MultiplexTrace, 1:end-1), I)`.
79+ """
80+ struct RelativeTrace{left,right,T,E} <: AbstractTrace{E}
81+ trace:: Trace{T,E}
82+ end
83+ RelativeTrace {left,right} (t:: Trace{T,E} ) where {left,right,T,E} = RelativeTrace {left,right,T,E} (t)
84+
85+ Base. size (x:: RelativeTrace{0,-1} ) = (max (0 , length (x. trace) - 1 ),)
86+ Base. size (x:: RelativeTrace{1,0} ) = (max (0 , length (x. trace) - 1 ),)
87+ Base. getindex (s:: RelativeTrace{0,-1} , I) = getindex (s. trace, I)
88+ Base. getindex (s:: RelativeTrace{1,0} , I) = getindex (s. trace, I .+ 1 )
89+ Base. setindex! (s:: RelativeTrace{0,-1} , v, I) = setindex! (s. trace, v, I)
90+ Base. setindex! (s:: RelativeTrace{1,0} , v, I) = setindex! (s. trace, v, I .+ 1 )
91+
6292"""
6393 MultiplexTraces{names}(trace)
6494
@@ -89,12 +119,14 @@ function MultiplexTraces{names}(t) where {names}
89119 MultiplexTraces {names,typeof(trace),eltype(trace)} (trace)
90120end
91121
122+ Adapt. adapt_structure (to, t:: MultiplexTraces{names} ) where {names} = MultiplexTraces {names} (Adapt. adapt_structure (to, t. trace))
123+
92124function Base. getindex (t:: MultiplexTraces{names} , k:: Symbol ) where {names}
93125 a, b = names
94126 if k == a
95- convert (AbstractTrace, t. trace[ 1 : end - 1 ] )
127+ RelativeTrace {0,-1} ( convert (AbstractTrace, t. trace) )
96128 elseif k == b
97- convert (AbstractTrace, t. trace[ 2 : end ] )
129+ RelativeTrace {1,0} ( convert (AbstractTrace, t. trace) )
98130 else
99131 throw (ArgumentError (" unknown trace name: $k " ))
100132 end
133165
134166Episode (t:: AbstractTraces{names,T} ) where {names,T} = Episode {typeof(t),names,T} (t, Ref (false ))
135167
168+ Adapt. adapt_structure (to, t:: Episode{T,names,E} ) where {T,names,E} = Episode {T,names,E} (Adapt. adapt_structure (to, t. traces), t. is_terminated)
169+
136170@forward Episode. traces Base. getindex, Base. setindex!, Base. size
137171
138172Base. getindex (e:: Episode ) = getindex (e. is_terminated)
@@ -175,6 +209,11 @@ struct Episodes{names,E,T} <: AbstractTraces{names,E}
175209 inds:: Vector{Tuple{Int,Int}}
176210end
177211
212+ Adapt. adapt_structure (to, t:: Episodes ) =
213+ Episodes () do
214+ Adapt. adapt_structure (to, t. init ())
215+ end
216+
178217function Episodes (init)
179218 x = init ()
180219 T = typeof (x)
@@ -249,6 +288,11 @@ struct Traces{names,T,N,E} <: AbstractTraces{names,E}
249288 inds:: NamedTuple{names,NTuple{N,Int}}
250289end
251290
291+ function Adapt. adapt_structure (to, t:: Traces{names,T,N,E} ) where {names,T,N,E}
292+ data = Adapt. adapt_structure (to, t. traces)
293+ # FIXME : `E` is not adapted here
294+ Traces {names,typeof(data),length(names),E} (data, t. inds)
295+ end
252296
253297function Traces (; kw... )
254298 data = map (x -> convert (AbstractTrace, x), values (kw))
0 commit comments