@@ -40,17 +40,16 @@ def test_generate_movielens_dataset(self):
4040 stats = loader .generate_movielens_dataset (self .dataset_dir , gen_dir ,
4141 'train.tfrecord' , 'test.tfrecord' ,
4242 'movie_vocab.json' , 'meta.json' )
43- self .assertDictContainsSubset (
44- {
45- 'train_file' : os .path .join (gen_dir , 'train.tfrecord' ),
46- 'test_file' : os .path .join (gen_dir , 'test.tfrecord' ),
47- 'vocab_file' : os .path .join (gen_dir , 'movie_vocab.json' ),
48- 'train_size' : _testutil .TRAIN_SIZE ,
49- 'test_size' : _testutil .TEST_SIZE ,
50- 'vocab_size' : _testutil .VOCAB_SIZE ,
51- 'vocab_max_id' : _testutil .MAX_ITEM_ID ,
52- }, stats )
53-
43+ expected = {
44+ 'train_file' : os .path .join (gen_dir , 'train.tfrecord' ),
45+ 'test_file' : os .path .join (gen_dir , 'test.tfrecord' ),
46+ 'vocab_file' : os .path .join (gen_dir , 'movie_vocab.json' ),
47+ 'train_size' : _testutil .TRAIN_SIZE ,
48+ 'test_size' : _testutil .TEST_SIZE ,
49+ 'vocab_size' : _testutil .VOCAB_SIZE ,
50+ 'vocab_max_id' : _testutil .MAX_ITEM_ID ,
51+ }
52+ self .assertEqual (stats , {** stats , ** expected })
5453 self .assertTrue (os .path .exists (gen_dir ))
5554 self .assertGreater (len (os .listdir (gen_dir )), 0 )
5655
0 commit comments