11from __future__ import annotations
2- from pydantic import BaseModel , PositiveInt
3- from typing import (
4- Tuple ,
5- Dict ,
6- List ,
7- Callable ,
8- Optional ,
9- TypeVar ,
10- Generic ,
11- Union ,
12- )
2+
3+ from typing import Callable , Dict , Generic , List , Optional , Tuple , TypeVar , Union
4+
135import numpy as np
146import torch
15- from pathlib import Path
7+ from pydantic import BaseModel , PositiveInt
168
179from datastream import Dataset
1810from datastream .samplers import (
19- StandardSampler ,
2011 MergeSampler ,
21- ZipSampler ,
2212 MultiSampler ,
2313 RepeatSampler ,
14+ StandardSampler ,
15+ ZipSampler ,
2416)
2517
26-
2718T = TypeVar ("T" )
2819R = TypeVar ("R" )
2920
@@ -46,7 +37,7 @@ class Datastream(BaseModel, Generic[T]):
4637 16
4738 """
4839
49- dataset : Dataset [ T ]
40+ dataset : Dataset
5041 sampler : Optional [torch .utils .data .Sampler ]
5142
5243 class Config :
@@ -286,29 +277,25 @@ def cache(
286277
287278
288279def test_infinite ():
289-
290280 datastream = Datastream (Dataset .from_subscriptable (list ("abc" )))
291281 it = iter (datastream .data_loader (batch_size = 8 , n_batches_per_epoch = 10 ))
292282 for _ in range (10 ):
293283 batch = next (it )
294284
295285
296286def test_iter ():
297-
298287 datastream = Datastream (Dataset .from_subscriptable (list ("abc" )))
299288 assert len (list (datastream )) == 3
300289
301290
302291def test_empty ():
303-
304292 import pytest
305293
306294 with pytest .raises (ValueError ):
307295 Datastream (Dataset .from_subscriptable (list ()))
308296
309297
310298def test_datastream_merge ():
311-
312299 datastream = Datastream .merge (
313300 [
314301 Datastream (Dataset .from_subscriptable (list ("abc" ))),
@@ -328,7 +315,6 @@ def test_datastream_merge():
328315
329316
330317def test_datastream_zip ():
331-
332318 datasets = [
333319 Dataset .from_subscriptable ([1 , 2 ]),
334320 Dataset .from_subscriptable ([3 , 4 , 5 ]),
@@ -384,7 +370,6 @@ def ZippedMergedDatastream():
384370
385371
386372def test_datastream_simple_weights ():
387-
388373 dataset = Dataset .from_subscriptable ([1 , 2 , 3 , 4 ])
389374 datastream = (
390375 Datastream (dataset )
@@ -412,7 +397,6 @@ def test_datastream_simple_weights():
412397
413398
414399def test_merge_datastream_weights ():
415-
416400 datasets = [
417401 Dataset .from_subscriptable ([1 , 2 ]),
418402 Dataset .from_subscriptable ([3 , 4 , 5 ]),
@@ -441,7 +425,6 @@ def test_merge_datastream_weights():
441425
442426
443427def test_multi_sample ():
444-
445428 data = [1 , 2 , 4 ]
446429 n_multi_sample = 2
447430
@@ -475,7 +458,6 @@ def test_multi_sample():
475458
476459
477460def test_take ():
478-
479461 import pytest
480462
481463 datastream = Datastream (Dataset .from_subscriptable (list ("abc" ))).take (2 )
@@ -494,7 +476,6 @@ def test_take():
494476
495477
496478def test_sequential_sampler ():
497-
498479 from datastream .samplers import SequentialSampler
499480
500481 dataset = Dataset .from_subscriptable (list ("abc" ))
0 commit comments