@@ -2997,6 +2997,49 @@ def test_prb_update_max_priority(self, max_priority_within_buffer):
29972997 assert rb ._sampler ._max_priority [0 ] == 21
29982998 assert rb ._sampler ._max_priority [1 ] == 0
29992999
def test_prb_serialization(self, tmpdir):
    """Check that a prioritized replay buffer round-trips through save/load.

    A buffer built with one set of sampler hyper-parameters is filled with a
    single all-zeros transition and saved to ``tmpdir``. A second buffer,
    built with *different* ``alpha``/``beta`` values and pre-filled with an
    all-ones transition, then loads that checkpoint. After loading, the
    second buffer's length and its sampler's ``_alpha``, ``_beta`` and
    ``_max_priority`` entries must match the first buffer's.

    Args:
        tmpdir: directory the checkpoint is written to and read back from
            (presumably the pytest ``tmpdir`` fixture -- confirm against the
            test class setup).
    """

    def make_transition(factory):
        # ``factory`` is ``torch.zeros`` or ``torch.ones`` so that the two
        # buffers hold distinguishable content before the load.
        return TensorDict(
            {
                "observations": factory(1, 3),
                "actions": factory(1, 1),
                "rewards": factory(1, 1),
                "next_observations": factory(1, 3),
                "terminations": factory(1, 1, dtype=torch.bool),
            },
            batch_size=[1],
        )

    rb = ReplayBuffer(
        storage=LazyMemmapStorage(max_size=10),
        sampler=PrioritizedSampler(max_capacity=10, alpha=0.8, beta=0.6),
    )
    rb.extend(make_transition(torch.zeros))
    rb.save(tmpdir)

    # Deliberately use different alpha/beta here: the equality assertions
    # below can only pass if loading overwrites the sampler's settings.
    rb2 = ReplayBuffer(
        storage=LazyMemmapStorage(max_size=10),
        sampler=PrioritizedSampler(max_capacity=10, alpha=0.5, beta=0.5),
    )
    rb2.extend(make_transition(torch.ones))
    rb2.load(tmpdir)

    # Saving must not alter the source buffer ...
    assert len(rb) == 1
    # ... and the restored buffer must report the saved length as well.
    assert len(rb2) == 1
    assert rb.sampler._alpha == rb2.sampler._alpha
    assert rb.sampler._beta == rb2.sampler._beta
    assert rb.sampler._max_priority[0] == rb2.sampler._max_priority[0]
    assert rb.sampler._max_priority[1] == rb2.sampler._max_priority[1]
3042+
30003043 def test_prb_ndim (self ):
30013044 """This test lists all the possible ways of updating the priority of a PRB with RB, TRB and TPRB.
30023045
0 commit comments