11from __future__ import annotations
2- from pydantic import BaseModel
2+ from pydantic import BaseModel , PositiveInt
33from typing import (
44 Tuple ,
55 Dict ,
@@ -56,6 +56,9 @@ def __init__(
5656 dataset : Dataset [T ],
5757 sampler : torch .utils .data .Sampler = None
5858 ):
59+ if len (dataset ) == 0 :
60+ raise ValueError ('Cannot create datastream from empty dataset' )
61+
5962 super ().__init__ (
6063 dataset = dataset ,
6164 sampler = (
@@ -67,6 +70,9 @@ def __init__(
6770
6871 def __len__ (self ):
6972 return len (self .sampler )
73+
74+ def __iter__ (self ):
75+ return map (self .dataset .__getitem__ , iter (self .sampler ))
7076
7177 @staticmethod
7278 def merge (datastreams_and_ns : Tuple [Union [
@@ -151,6 +157,10 @@ def data_loader(
151157 '''
152158 Get ``torch.utils.data.DataLoader`` for use in pytorch pipeline.
153159
160+ The argument ``n_batches_per_epoch`` overrides the underlying length
161+ of the dataset. If the epoch ends before the full dataset has been
162+ processed then it will continue from the same point the next epoch.
163+
154164 >>> data_loader = (
155165 ... Datastream(Dataset.from_subscriptable([5, 5, 5]))
156166 ... .data_loader(batch_size=5, n_batches_per_epoch=10)
@@ -218,13 +228,15 @@ def sample_proportion(
218228
219229 def take (
220230 self : Datastream [T ],
221- n_samples : int ,
231+ n_samples : PositiveInt ,
222232 ) -> Datastream [T ]:
223233 '''
224234 Like :func:`Datastream.sample_proportion` but specify the number of
225235 samples instead of a proportion.
226236 '''
227- return self .sample_proportion (min (1 , n_samples / len (self )))
237+ if n_samples < 1 :
238+ raise ValueError ('n_samples must be greater than or equal to 1' )
239+ return self .sample_proportion (n_samples / len (self ))
228240
229241 def state_dict (self ) -> Dict :
230242 '''Get state of datastream. Useful for checkpointing sample weights.'''
@@ -278,6 +290,28 @@ def cache(
278290 )
279291
280292
def test_infinite():
    # n_batches_per_epoch larger than the dataset: the loader must keep
    # producing batches past the underlying dataset length.
    stream = Datastream(Dataset.from_subscriptable(list('abc')))
    loader = stream.data_loader(batch_size=8, n_batches_per_epoch=10)
    batches = iter(loader)
    for _ in range(10):
        next(batches)
299+
300+
def test_iter():
    # Iterating the datastream directly yields one item per dataset element.
    stream = Datastream(Dataset.from_subscriptable(list('abc')))
    items = list(iter(stream))
    assert len(items) == 3
305+
306+
def test_empty():
    # Constructing a datastream over an empty dataset must raise.
    import pytest

    with pytest.raises(ValueError):
        Datastream(Dataset.from_subscriptable([]))
313+
314+
281315def test_datastream_merge ():
282316
283317 datastream = Datastream .merge ([
@@ -289,10 +323,16 @@ def test_datastream_merge():
289323 for _ in range (2 ):
290324 index = next (it )
291325
292- it = iter (datastream .data_loader (batch_size = 8 ))
326+ it = iter (datastream .data_loader (batch_size = 8 , n_batches_per_epoch = 10 ))
293327 for _ in range (10 ):
294328 batch = next (it )
295329
330+ assert (
331+ len (list (
332+ datastream .data_loader (batch_size = 1 )
333+ )) == len (datastream )
334+ )
335+
296336
297337def test_datastream_zip ():
298338
@@ -314,6 +354,12 @@ def test_datastream_zip():
314354 assert batch [1 ][0 ] == 3 and batch [1 ][1 ] == 4 and batch [1 ][2 ] == 5
315355 assert batch [2 ][0 ] == 6 and batch [2 ][1 ] == 7 and batch [2 ][2 ] == 6
316356
357+ assert (
358+ len (list (
359+ zipped_datastream .data_loader (batch_size = 1 )
360+ )) == len (zipped_datastream )
361+ )
362+
317363
318364def test_datastream_merge_zip_merge ():
319365 '''
@@ -442,3 +488,20 @@ def test_multi_sample():
442488 zero_indices = set ([index for _ , index in output [:2 ]])
443489 for number , index in output2 :
444490 assert index not in zero_indices
491+
492+
def test_take():
    import pytest

    # Taking 2 of 3 samples yields exactly 2 batches of size 1.
    taken = Datastream(Dataset.from_subscriptable(list('abc'))).take(2)
    assert len(list(taken.data_loader(batch_size=1))) == 2

    # A non-positive sample count is rejected.
    with pytest.raises(ValueError):
        Datastream(Dataset.from_subscriptable(list('abc'))).take(0)

    # take also works on a merged datastream.
    merged = Datastream.merge([
        Datastream(Dataset.from_subscriptable(list('abc'))),
        Datastream(Dataset.from_subscriptable(list('d'))),
    ])
    assert len(list(merged.take(2).data_loader(batch_size=1))) == 2
0 commit comments