4343 "data_batch_5"
4444]
4545_CIFAR10_TEST_FILES = ["test_batch" ]
46- _CIFAR10_IMAGE_SIZE = 32
46+ _CIFAR10_IMAGE_SIZE = _CIFAR100_IMAGE_SIZE = 32
4747
48+ _CIFAR100_URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
49+ _CIFAR100_PREFIX = "cifar-100-python/"
50+ _CIFAR100_TRAIN_FILES = ["train" ]
51+ _CIFAR100_TEST_FILES = ["test" ]
4852
49- def _get_cifar10 (directory ):
53+
54+ def _get_cifar (directory , url ):
5055 """Download and extract CIFAR to directory unless it is there."""
51- filename = os .path .basename (_CIFAR10_URL )
52- path = generator_utils .maybe_download (directory , filename , _CIFAR10_URL )
56+ filename = os .path .basename (url )
57+ path = generator_utils .maybe_download (directory , filename , url )
5358 tarfile .open (path , "r:gz" ).extractall (directory )
5459
5560
56- def cifar10_generator ( tmp_dir , training , how_many , start_from = 0 ):
57- """Image generator for CIFAR-10.
61+ def cifar_generator ( cifar_version , tmp_dir , training , how_many , start_from = 0 ):
62+ """Image generator for CIFAR-10 and 100 .
5863
5964 Args:
65+ cifar_version: string; one of "cifar10" or "cifar100"
6066 tmp_dir: path to temporary storage directory.
6167 training: a Boolean; if true, we use the train set, otherwise the test set.
6268 how_many: how many images and labels to generate.
@@ -65,21 +71,33 @@ def cifar10_generator(tmp_dir, training, how_many, start_from=0):
6571 Returns:
6672 An instance of image_generator that produces CIFAR-10 images and labels.
6773 """
68- _get_cifar10 (tmp_dir )
69- data_files = _CIFAR10_TRAIN_FILES if training else _CIFAR10_TEST_FILES
74+ if cifar_version == "cifar10" :
75+ url = _CIFAR10_URL
76+ train_files = _CIFAR10_TRAIN_FILES
77+ test_files = _CIFAR10_TEST_FILES
78+ prefix = _CIFAR10_PREFIX
79+ image_size = _CIFAR10_IMAGE_SIZE
80+ elif cifar_version == "cifar100" :
81+ url = _CIFAR100_URL
82+ train_files = _CIFAR100_TRAIN_FILES
83+ test_files = _CIFAR100_TEST_FILES
84+ prefix = _CIFAR100_PREFIX
85+ image_size = _CIFAR100_IMAGE_SIZE
86+
87+ _get_cifar (tmp_dir , url )
88+ data_files = train_files if training else test_files
7089 all_images , all_labels = [], []
7190 for filename in data_files :
72- path = os .path .join (tmp_dir , _CIFAR10_PREFIX , filename )
91+ path = os .path .join (tmp_dir , prefix , filename )
7392 with tf .gfile .Open (path , "r" ) as f :
7493 data = cPickle .load (f )
7594 images = data ["data" ]
7695 num_images = images .shape [0 ]
77- images = images .reshape ((num_images , 3 , _CIFAR10_IMAGE_SIZE ,
78- _CIFAR10_IMAGE_SIZE ))
96+ images = images .reshape ((num_images , 3 , image_size , image_size ))
7997 all_images .extend ([
8098 np .squeeze (images [j ]).transpose ((1 , 2 , 0 )) for j in xrange (num_images )
8199 ])
82- labels = data ["labels" ]
100+ labels = data ["labels" if cifar_version == "cifar10" else "fine_labels" ]
83101 all_labels .extend ([labels [j ] for j in xrange (num_images )])
84102 return image_utils .image_generator (
85103 all_images [start_from :start_from + how_many ],
@@ -112,19 +130,19 @@ def preprocess_example(self, example, mode, unused_hparams):
112130
113131 def generator (self , data_dir , tmp_dir , is_training ):
114132 if is_training :
115- return cifar10_generator ( tmp_dir , True , 48000 )
133+ return cifar_generator ( "cifar10" , tmp_dir , True , 48000 )
116134 else :
117- return cifar10_generator ( tmp_dir , True , 2000 , 48000 )
135+ return cifar_generator ( "cifar10" , tmp_dir , True , 2000 , 48000 )
118136
119137
120138@registry .register_problem
121139class ImageCifar10 (ImageCifar10Tune ):
122140
123141 def generator (self , data_dir , tmp_dir , is_training ):
124142 if is_training :
125- return cifar10_generator ( tmp_dir , True , 50000 )
143+ return cifar_generator ( "cifar10" , tmp_dir , True , 50000 )
126144 else :
127- return cifar10_generator ( tmp_dir , False , 10000 )
145+ return cifar_generator ( "cifar10" , tmp_dir , False , 10000 )
128146
129147
130148@registry .register_problem
@@ -188,3 +206,210 @@ def hparams(self, defaults, unused_model_hparams):
188206 p .batch_size_multiplier = 256
189207 p .input_space_id = 1
190208 p .target_space_id = 1
209+
210+
211+ @registry .register_problem
212+ class ImageCifar100Tune (mnist .ImageMnistTune ):
213+ """Cifar-100 Tune."""
214+
215+ @property
216+ def num_classes (self ):
217+ return 100
218+
219+ @property
220+ def num_channels (self ):
221+ return 3
222+
223+ @property
224+ def class_labels (self ):
225+ return [
226+ "beaver" ,
227+ "dolphin" ,
228+ "otter" ,
229+ "seal" ,
230+ "whale" ,
231+ "aquarium fish" ,
232+ "flatfish" ,
233+ "ray" ,
234+ "shark" ,
235+ "trout" ,
236+ "orchids" ,
237+ "poppies" ,
238+ "roses" ,
239+ "sunflowers" ,
240+ "tulips" ,
241+ "bottles" ,
242+ "bowls" ,
243+ "cans" ,
244+ "cups" ,
245+ "plates" ,
246+ "apples" ,
247+ "mushrooms" ,
248+ "oranges" ,
249+ "pears" ,
250+ "sweet peppers" ,
251+ "clock" ,
252+ "computer keyboard" ,
253+ "lamp" ,
254+ "telephone" ,
255+ "television" ,
256+ "bed" ,
257+ "chair" ,
258+ "couch" ,
259+ "table" ,
260+ "wardrobe" ,
261+ "bee" ,
262+ "beetle" ,
263+ "butterfly" ,
264+ "caterpillar" ,
265+ "cockroach" ,
266+ "bear" ,
267+ "leopard" ,
268+ "lion" ,
269+ "tiger" ,
270+ "wolf" ,
271+ "bridge" ,
272+ "castle" ,
273+ "house" ,
274+ "road" ,
275+ "skyscraper" ,
276+ "cloud" ,
277+ "forest" ,
278+ "mountain" ,
279+ "plain" ,
280+ "sea" ,
281+ "camel" ,
282+ "cattle" ,
283+ "chimpanzee" ,
284+ "elephant" ,
285+ "kangaroo" ,
286+ "fox" ,
287+ "porcupine" ,
288+ "possum" ,
289+ "raccoon" ,
290+ "skunk" ,
291+ "crab" ,
292+ "lobster" ,
293+ "snail" ,
294+ "spider" ,
295+ "worm" ,
296+ "baby" ,
297+ "boy" ,
298+ "girl" ,
299+ "man" ,
300+ "woman" ,
301+ "crocodile" ,
302+ "dinosaur" ,
303+ "lizard" ,
304+ "snake" ,
305+ "turtle" ,
306+ "hamster" ,
307+ "mouse" ,
308+ "rabbit" ,
309+ "shrew" ,
310+ "squirrel" ,
311+ "maple" ,
312+ "oak" ,
313+ "palm" ,
314+ "pine" ,
315+ "willow" ,
316+ "bicycle" ,
317+ "bus" ,
318+ "motorcycle" ,
319+ "pickup truck" ,
320+ "train" ,
321+ "lawn-mower" ,
322+ "rocket" ,
323+ "streetcar" ,
324+ "tank" ,
325+ "tractor" ,
326+ ]
327+
328+ def preprocess_example (self , example , mode , unused_hparams ):
329+ image = example ["inputs" ]
330+ image .set_shape ([_CIFAR100_IMAGE_SIZE , _CIFAR100_IMAGE_SIZE , 3 ])
331+ if mode == tf .estimator .ModeKeys .TRAIN :
332+ image = image_utils .cifar_image_augmentation (image )
333+ image = tf .image .per_image_standardization (image )
334+ example ["inputs" ] = image
335+ return example
336+
337+ def generator (self , data_dir , tmp_dir , is_training ):
338+ if is_training :
339+ return cifar_generator ("cifar100" , tmp_dir , True , 48000 )
340+ else :
341+ return cifar_generator ("cifar100" , tmp_dir , True , 2000 , 48000 )
342+
343+
344+ @registry .register_problem
345+ class ImageCifar100 (ImageCifar100Tune ):
346+
347+ def generator (self , data_dir , tmp_dir , is_training ):
348+ if is_training :
349+ return cifar_generator ("cifar100" , tmp_dir , True , 50000 )
350+ else :
351+ return cifar_generator ("cifar100" , tmp_dir , False , 10000 )
352+
353+
354+ @registry .register_problem
355+ class ImageCifar100Plain (ImageCifar100 ):
356+
357+ def preprocess_example (self , example , mode , unused_hparams ):
358+ image = example ["inputs" ]
359+ image .set_shape ([_CIFAR100_IMAGE_SIZE , _CIFAR100_IMAGE_SIZE , 3 ])
360+ image = tf .image .per_image_standardization (image )
361+ example ["inputs" ] = image
362+ return example
363+
364+
365+ @registry .register_problem
366+ class ImageCifar100PlainGen (ImageCifar100Plain ):
367+ """CIFAR-100 32x32 for image generation without standardization preprep."""
368+
369+ def dataset_filename (self ):
370+ return "image_cifar100_plain" # Reuse CIFAR-100 plain data.
371+
372+ def preprocess_example (self , example , mode , unused_hparams ):
373+ example ["inputs" ].set_shape ([_CIFAR100_IMAGE_SIZE , _CIFAR100_IMAGE_SIZE , 3 ])
374+ example ["inputs" ] = tf .to_int64 (example ["inputs" ])
375+ return example
376+
377+
378+ @registry .register_problem
379+ class ImageCifar100Plain8 (ImageCifar100 ):
380+ """CIFAR-100 rescaled to 8x8 for output: Conditional image generation."""
381+
382+ def dataset_filename (self ):
383+ return "image_cifar100_plain" # Reuse CIFAR-100 plain data.
384+
385+ def preprocess_example (self , example , mode , unused_hparams ):
386+ image = example ["inputs" ]
387+ image = image_utils .resize_by_area (image , 8 )
388+ image = tf .image .per_image_standardization (image )
389+ example ["inputs" ] = image
390+ return example
391+
392+
393+ @registry .register_problem
394+ class Img2imgCifar100 (ImageCifar100 ):
395+ """CIFAR-100 rescaled to 8x8 for input and 32x32 for output."""
396+
397+ def dataset_filename (self ):
398+ return "image_cifar100_plain" # Reuse CIFAR-100 plain data.
399+
400+ def preprocess_example (self , example , unused_mode , unused_hparams ):
401+
402+ inputs = example ["inputs" ]
403+ # For Img2Img resize input and output images as desired.
404+ example ["inputs" ] = image_utils .resize_by_area (inputs , 8 )
405+ example ["targets" ] = image_utils .resize_by_area (inputs , 32 )
406+ return example
407+
408+ def hparams (self , defaults , unused_model_hparams ):
409+ p = defaults
410+ p .input_modality = {"inputs" : ("image:identity" , 256 )}
411+ p .target_modality = ("image:identity" , 256 )
412+ p .batch_size_multiplier = 256
413+ p .max_expected_batch_size_per_shard = 4
414+ p .input_space_id = 1
415+ p .target_space_id = 1
0 commit comments