@@ -66,7 +66,137 @@ def example_reading_spec(self, label_key=None):
6666 return data_fields , data_items_to_decoders
6767
6868
69- # French street names dataset.
@registry.register_problem("image_celeba_tune")
class ImageCeleba(ImageProblem):
  """CelebA dataset, aligned and cropped images."""

  # (filename, Google Drive URL) pairs for the three CelebA source files.
  IMG_DATA = ("img_align_celeba.zip",
              "https://drive.google.com/uc?export=download&"
              "id=0B7EVK8r0v71pZjFTYXZWM3FlRnM")
  LANDMARKS_DATA = ("celeba_landmarks_align",
                    "https://drive.google.com/uc?export=download&"
                    "id=0B7EVK8r0v71pd0FJY3Blby1HUTQ")
  ATTR_DATA = ("celeba_attr", "https://drive.google.com/uc?export=download&"
               "id=0B7EVK8r0v71pblRyaVFSWGxPY0U")

  # Column headings of the landmarks file, in file order.
  LANDMARK_HEADINGS = ("lefteye_x lefteye_y righteye_x righteye_y "
                       "nose_x nose_y leftmouth_x leftmouth_y rightmouth_x "
                       "rightmouth_y").split()
  # Column headings of the attributes file, in file order.
  ATTR_HEADINGS = (
      "5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs "
      "Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair "
      "Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair "
      "Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache "
      "Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline "
      "Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings "
      "Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young"
  ).split()

  def preprocess_examples(self, examples, unused_mode, unused_hparams):
    """Crop away the CelebA borders and build an 8x8 input / 32x32 target pair."""

    def resize(img, size):
      # AREA resize, cast to int64 as expected by the identity image modality.
      return tf.to_int64(
          tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA))

    inputs = examples["inputs"]
    # Remove boundaries in CelebA images. Remove 40 pixels each side
    # vertically and 20 pixels each side horizontally.
    inputs = tf.image.crop_to_bounding_box(inputs, 40, 20, 218 - 80, 178 - 40)
    examples["inputs"] = resize(inputs, 8)
    examples["targets"] = resize(inputs, 32)
    return examples

  def hparams(self, defaults, model_hparams):
    """Set problem hparams: identity (no-pad) image modalities on both sides."""
    p = defaults
    p.input_modality = {"inputs": ("image:identity_no_pad", None)}
    p.target_modality = ("image:identity_no_pad", None)
    p.batch_size_multiplier = 256
    p.max_expected_batch_size_per_shard = 4
    p.input_space_id = 1
    p.target_space_id = 1

  @staticmethod
  def _parse_label_file(raw_data):
    """Parse a CelebA per-image label file (landmarks or attributes).

    Both files share the same layout: line 0 is the example count, line 1
    the column headings, and every later line is "<img_name> <int> <int> ...".

    Args:
      raw_data: full contents of the label file as one string.

    Returns:
      A pair (values, headings) where values maps image file name to the
      list of integer label values and headings is the heading list.
    """
    values = {}
    lines = raw_data.split("\n")
    headings = lines[1].strip().split()
    # Skip the count and heading lines; the file ends with a trailing newline,
    # so the final (empty) entry is dropped as well.
    for line in lines[2:-1]:
      fields = line.strip().split()
      img_name = fields[0]
      values[img_name] = [int(v) for v in fields[1:]]
    return values, headings

  def generator(self, tmp_dir, how_many, start_from=0):
    """Image generator for CELEBA dataset.

    Args:
      tmp_dir: path to temporary storage directory.
      how_many: how many images and labels to generate.
      start_from: from which image to start.

    Yields:
      A dictionary representing the images with the following fields:
      * image/encoded: the string encoding the image as JPEG,
      * image/format: the string "jpeg" representing image format,
      * attributes: the list of integer attribute values for the image,
      * landmarks: the list of integer landmark coordinates for the image.
    """
    out_paths = []
    for fname, url in [self.IMG_DATA, self.LANDMARKS_DATA, self.ATTR_DATA]:
      path = generator_utils.maybe_download_from_drive(tmp_dir, fname, url)
      out_paths.append(path)

    img_path, landmarks_path, attr_path = out_paths  # pylint: disable=unbalanced-tuple-unpacking
    unzipped_folder = img_path[:-4]  # strip the ".zip" suffix
    if not tf.gfile.Exists(unzipped_folder):
      zipfile.ZipFile(img_path, "r").extractall(tmp_dir)

    with tf.gfile.Open(landmarks_path) as f:
      landmarks_raw = f.read()

    with tf.gfile.Open(attr_path) as f:
      attr_raw = f.read()

    # The two label files share a format; parse both with one helper.
    img_landmarks, _ = self._parse_label_file(landmarks_raw)
    img_attrs, _ = self._parse_label_file(attr_raw)

    image_files = tf.gfile.Glob(unzipped_folder + "/*.jpg")
    for filename in image_files[start_from:start_from + how_many]:
      img_name = os.path.basename(filename)
      landmarks = img_landmarks[img_name]
      attrs = img_attrs[img_name]

      # JPEG data is binary: read in "rb" so Python 3 does not attempt to
      # decode the bytes as text.
      with tf.gfile.Open(filename, "rb") as f:
        encoded_image_data = f.read()
      yield {
          "image/encoded": [encoded_image_data],
          "image/format": ["jpeg"],
          "attributes": attrs,
          "landmarks": landmarks,
      }

  @property
  def train_shards(self):
    """Number of output file shards for the training split."""
    return 100

  @property
  def dev_shards(self):
    """Number of output file shards for the dev split."""
    return 10

  def generate_data(self, data_dir, tmp_dir, task_id=-1):
    """Generate, shard, and shuffle TFRecords for the train and dev splits."""
    generator_utils.generate_dataset_and_shuffle(
        self.generator(tmp_dir, 162770),  # train
        self.training_filepaths(data_dir, self.train_shards, shuffled=False),
        self.generator(tmp_dir, 19867, 162770),  # dev
        self.dev_filepaths(data_dir, self.dev_shards, shuffled=False))
70200
71201
72202@registry .register_problem
@@ -199,7 +329,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
199329 "instructions at https://github.com/tensorflow/models/blob/master"
200330 "/inception/README.md#getting-started" )
201331
202- def preprocess_examples (self , examples , mode ):
332+ def preprocess_examples (self , examples , mode , _ ):
203333 return imagenet_preprocess_examples (examples , mode )
204334
205335
@@ -638,7 +768,7 @@ def train_shards(self):
638768 def dev_shards (self ):
639769 return 10
640770
641- def preprocess_examples (self , examples , mode ):
771+ def preprocess_examples (self , examples , mode , _ ):
642772 return imagenet_preprocess_examples (examples , mode )
643773
644774 def generator (self , data_dir , tmp_dir , is_training ):
@@ -700,41 +830,3 @@ class ImageMsCocoTokens32k(ImageMsCocoTokens8k):
700830 @property
701831 def targeted_vocab_size (self ):
702832 return 2 ** 15 # 32768
703-
704-
705- # URL and filename for CELEBA data.
706- _CELEBA_NAME = "img_align_celeba"
707- _CELEBA_URL = "https://drive.google.com/uc?export=download&id=0B7EVK8r0v71pZjFTYXZWM3FlRnM"
708-
709-
710- def _get_celeba (directory ):
711- """Download and extract CELEBA to directory unless it is there."""
712- # path = os.path.join(directory, _CELEBA_NAME)
713- path = generator_utils .maybe_download_from_drive (directory , _CELEBA_NAME ,
714- _CELEBA_URL )
715- if not tf .gfile .Exists (path ):
716- zipfile .ZipFile (path + ".zip" , "r" ).extractall (directory )
717-
718-
719- def celeba_generator (tmp_dir , how_many , start_from = 0 ):
720- """Image generator for CELEBA dataset.
721-
722- Args:
723- tmp_dir: path to temporary storage directory.
724- how_many: how many images and labels to generate.
725- start_from: from which image to start.
726-
727- Yields:
728- A dictionary representing the images with the following fields:
729- * image/encoded: the string encoding the image as JPEG,
730- * image/format: the string "jpeg" representing image format,
731- """
732- _get_celeba (tmp_dir )
733- image_files = tf .gfile .Glob (os .path .join (tmp_dir , _CELEBA_NAME ) + "/*.jpg" )
734- for filename in image_files [start_from :start_from + how_many ]:
735- with tf .gfile .Open (filename , "r" ) as f :
736- encoded_image_data = f .read ()
737- yield {
738- "image/encoded" : [encoded_image_data ],
739- "image/format" : ["jpeg" ],
740- }
0 commit comments