@@ -122,7 +122,11 @@ When sampling a normalized trace, it will first normalize the samples to zero me
122122variance. Traces that do not have a normalizer are sample as usual.
123123
124124Note that when used in combination with [`Episodes`](@ref), `NormalizedTraces` must wrap
125- the `Episodes` struct, not the inner `AbstractTraces` contained in an `Episode`, otherwise
125+ the `Episodes` struct, not the inner `AbstractTraces` contained in an `Episode`, otherwise
126+ the running estimate will reset after each episode.
127+
128+ When used with a MultiplexTraces, the normalizer used with for one symbol (e.g. :state) will
129+ be the same used for the other one (e.g. :next_state).
126130
127131Preconfigured normalizers are provided for scalar (see [`scalar_normalizer`](@ref)) and
128132arrays (see [`array_normalizer`](@ref)).
@@ -131,6 +135,7 @@ arrays (see [`array_normalizer`](@ref)).
131135```
132136t = CircularArraySARTTraces(capacity = 10, state = Float64 => (5,))
133137nt = NormalizedTraces(t, reward = scalar_normalizer(), state = array_normalizer((5,)))
138+ # :next_state will also be normalized.
134139traj = Trajectory(
135140 container = nt,
136141 sampler = BatchSampler(10)
@@ -147,6 +152,18 @@ function NormalizedTraces(traces::AbstractTraces{names, TT}; trace_normalizer_pa
147152 @assert key in keys (traces) " Traces do not have key $key , valid keys are $(keys (traces)) ."
148153 end
149154 nt = (; trace_normalizer_pairs... )
155+ for trace in traces. traces
156+ # check if all traces of MultiplexTraces are in pairs
157+ if trace isa MultiplexTraces
158+ if length (intersect (keys (trace), keys (trace_normalizer_pairs))) in [0 , length (keys (trace))] # check if none or all keys are in normalizers
159+ continue
160+ else # if not then one is missing
161+ present_key = only (intersect (keys (trace), keys (trace_normalizer_pairs)))
162+ absent_key = only (setdiff (keys (trace), keys (trace_normalizer_pairs)))
163+ nt = merge (nt, (;(absent_key => nt[present_key],). .. )) # assign the same normalizer
164+ end
165+ end
166+ end
150167 NormalizedTraces {names, TT, typeof(traces), keys(nt), typeof(values(nt))} (traces, nt)
151168end
152169
0 commit comments