@@ -213,6 +213,8 @@ def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_l
213213 # Main loop.
214214 item_subset = [(i * opts .num_gpus + opts .rank ) % num_items for i in range ((num_items - 1 ) // opts .num_gpus + 1 )]
215215 for images , _labels in torch .utils .data .DataLoader (dataset = dataset , sampler = item_subset , batch_size = batch_size , ** data_loader_kwargs ):
216+ if images .shape [1 ] == 1 :
217+ images = images .repeat ([1 , 3 , 1 , 1 ])
216218 features = detector (images .to (opts .device ), ** detector_kwargs )
217219 stats .append_torch (features , num_gpus = opts .num_gpus , rank = opts .rank )
218220 progress .update (stats .num_items )
@@ -262,7 +264,10 @@ def run_generator(z, c):
262264 c = [dataset .get_label (np .random .randint (len (dataset ))) for _i in range (batch_gen )]
263265 c = torch .from_numpy (np .stack (c )).pin_memory ().to (opts .device )
264266 images .append (run_generator (z , c ))
265- features = detector (torch .cat (images ), ** detector_kwargs )
267+ images = torch .cat (images )
268+ if images .shape [1 ] == 1 :
269+ images = images .repeat ([1 , 3 , 1 , 1 ])
270+ features = detector (images , ** detector_kwargs )
266271 stats .append_torch (features , num_gpus = opts .num_gpus , rank = opts .rank )
267272 progress .update (stats .num_items )
268273 return stats
0 commit comments