@@ -527,7 +527,7 @@ def test_sequential_sampler():
527527 assert next (it ) == ['a' , 'b' , 'c' , 'a' , 'b' , 'c' ]
528528
529529
530- def test_merge_concat ():
530+ def test_concat_merge ():
531531 dataset = Dataset .concat ([
532532 Dataset .from_subscriptable ([1 , 2 ]),
533533 Dataset .from_subscriptable ([1 , 3 , 5 ]),
@@ -536,8 +536,51 @@ def test_merge_concat():
536536 datastream = Datastream .merge ([
537537 Datastream (dataset ),
538538 Datastream (dataset .subset (
539- lambda df : df [ " index" ] <= 3
539+ lambda df : [ index < 3 for index in range ( len ( df ))]
540540 )),
541541 ])
542542
543- list (datastream )
543+ assert len (dataset .subset (
544+ lambda df : [index < 3 for index in range (len (df ))]
545+ )) == 3
546+
547+ assert len (list (datastream )) == 6
548+
549+
550+ def test_combine_concat_merge ():
551+ dataset = Dataset .concat ([
552+ Dataset .zip ([
553+ Dataset .from_subscriptable ([1 ]),
554+ Dataset .from_subscriptable ([2 ]),
555+ ]),
556+ Dataset .combine ([
557+ Dataset .from_subscriptable ([3 , 3 ]),
558+ Dataset .from_subscriptable ([4 , 4 , 4 ]),
559+ ]),
560+ ])
561+
562+ datastream = Datastream .merge ([
563+ Datastream (dataset ),
564+ Datastream (Dataset .zip ([
565+ Dataset .from_subscriptable ([5 ]),
566+ Dataset .from_subscriptable ([6 ]),
567+ ])),
568+ ])
569+
570+ assert len (list (datastream )) == 2
571+
572+
573+ def test_last_batch ():
574+ from datastream .samplers import SequentialSampler
575+
576+ datastream = Datastream (
577+ Dataset .from_subscriptable (list ('abc' ))
578+ )
579+ assert list (map (len , datastream .data_loader (batch_size = 4 ))) == [3 ]
580+ assert list (map (len , datastream .data_loader (batch_size = 4 , n_batches_per_epoch = 2 ))) == [4 , 4 ]
581+
582+ datastream = Datastream (
583+ Dataset .from_subscriptable (list ('abc' )),
584+ SequentialSampler (3 ),
585+ )
586+ assert list (map (len , datastream .data_loader (batch_size = 2 ))) == [2 , 1 ]
0 commit comments