@@ -122,59 +122,60 @@ def get_random_transformation(self, data, training=True, seed=None):
122122 return h_start , w_start
123123
124124 def transform_images (self , images , transformation , training = True ):
125- images = self .backend .cast (images , self .compute_dtype )
126- crop_box_hstart , crop_box_wstart = transformation
127- crop_height = self .height
128- crop_width = self .width
125+ if training :
126+ images = self .backend .cast (images , self .compute_dtype )
127+ crop_box_hstart , crop_box_wstart = transformation
128+ crop_height = self .height
129+ crop_width = self .width
129130
130- if self .data_format == "channels_last" :
131- if len (images .shape ) == 4 :
132- images = images [
133- :,
134- crop_box_hstart : crop_box_hstart + crop_height ,
135- crop_box_wstart : crop_box_wstart + crop_width ,
136- :,
137- ]
138- else :
139- images = images [
140- crop_box_hstart : crop_box_hstart + crop_height ,
141- crop_box_wstart : crop_box_wstart + crop_width ,
142- :,
143- ]
144- else :
145- if len (images .shape ) == 4 :
146- images = images [
147- :,
148- :,
149- crop_box_hstart : crop_box_hstart + crop_height ,
150- crop_box_wstart : crop_box_wstart + crop_width ,
151- ]
131+ if self .data_format == "channels_last" :
132+ if len (images .shape ) == 4 :
133+ images = images [
134+ :,
135+ crop_box_hstart : crop_box_hstart + crop_height ,
136+ crop_box_wstart : crop_box_wstart + crop_width ,
137+ :,
138+ ]
139+ else :
140+ images = images [
141+ crop_box_hstart : crop_box_hstart + crop_height ,
142+ crop_box_wstart : crop_box_wstart + crop_width ,
143+ :,
144+ ]
152145 else :
153- images = images [
154- :,
155- crop_box_hstart : crop_box_hstart + crop_height ,
156- crop_box_wstart : crop_box_wstart + crop_width ,
157- ]
146+ if len (images .shape ) == 4 :
147+ images = images [
148+ :,
149+ :,
150+ crop_box_hstart : crop_box_hstart + crop_height ,
151+ crop_box_wstart : crop_box_wstart + crop_width ,
152+ ]
153+ else :
154+ images = images [
155+ :,
156+ crop_box_hstart : crop_box_hstart + crop_height ,
157+ crop_box_wstart : crop_box_wstart + crop_width ,
158+ ]
158159
159- shape = self .backend .shape (images )
160- new_height = shape [self .height_axis ]
161- new_width = shape [self .width_axis ]
162- if (
163- not isinstance (new_height , int )
164- or not isinstance (new_width , int )
165- or new_height != self .height
166- or new_width != self .width
167- ):
168- # Resize images if size mismatch or
169- # if size mismatch cannot be determined
170- # (in the case of a TF dynamic shape).
171- images = self .backend .image .resize (
172- images ,
173- size = (self .height , self .width ),
174- data_format = self .data_format ,
175- )
176- # Resize may have upcasted the outputs
177- images = self .backend .cast (images , self .compute_dtype )
160+ shape = self .backend .shape (images )
161+ new_height = shape [self .height_axis ]
162+ new_width = shape [self .width_axis ]
163+ if (
164+ not isinstance (new_height , int )
165+ or not isinstance (new_width , int )
166+ or new_height != self .height
167+ or new_width != self .width
168+ ):
169+ # Resize images if size mismatch or
170+ # if size mismatch cannot be determined
171+ # (in the case of a TF dynamic shape).
172+ images = self .backend .image .resize (
173+ images ,
174+ size = (self .height , self .width ),
175+ data_format = self .data_format ,
176+ )
177+ # Resize may have upcasted the outputs
178+ images = self .backend .cast (images , self .compute_dtype )
178179 return images
179180
180181 def transform_labels (self , labels , transformation , training = True ):
@@ -197,56 +198,59 @@ def transform_bounding_boxes(
197198 "labels": (num_boxes, num_classes),
198199 }
199200 """
200- h_start , w_start = transformation
201- if not self .backend .is_tensor (bounding_boxes ["boxes" ]):
202- bounding_boxes = densify_bounding_boxes (
203- bounding_boxes , backend = self .backend
204- )
205- boxes = bounding_boxes ["boxes" ]
206- # Convert to a standard xyxy as operations are done xyxy by default.
207- boxes = convert_format (
208- boxes = boxes ,
209- source = self .bounding_box_format ,
210- target = "xyxy" ,
211- height = self .height ,
212- width = self .width ,
213- )
214- h_start = self .backend .cast (h_start , boxes .dtype )
215- w_start = self .backend .cast (w_start , boxes .dtype )
216- if len (self .backend .shape (boxes )) == 3 :
217- boxes = self .backend .numpy .stack (
218- [
219- self .backend .numpy .maximum (boxes [:, :, 0 ] - h_start , 0 ),
220- self .backend .numpy .maximum (boxes [:, :, 1 ] - w_start , 0 ),
221- self .backend .numpy .maximum (boxes [:, :, 2 ] - h_start , 0 ),
222- self .backend .numpy .maximum (boxes [:, :, 3 ] - w_start , 0 ),
223- ],
224- axis = - 1 ,
225- )
226- else :
227- boxes = self .backend .numpy .stack (
228- [
229- self .backend .numpy .maximum (boxes [:, 0 ] - h_start , 0 ),
230- self .backend .numpy .maximum (boxes [:, 1 ] - w_start , 0 ),
231- self .backend .numpy .maximum (boxes [:, 2 ] - h_start , 0 ),
232- self .backend .numpy .maximum (boxes [:, 3 ] - w_start , 0 ),
233- ],
234- axis = - 1 ,
201+
202+ if training :
203+ h_start , w_start = transformation
204+ if not self .backend .is_tensor (bounding_boxes ["boxes" ]):
205+ bounding_boxes = densify_bounding_boxes (
206+ bounding_boxes , backend = self .backend
207+ )
208+ boxes = bounding_boxes ["boxes" ]
209+ # Convert to a standard xyxy as operations are done xyxy by default.
210+ boxes = convert_format (
211+ boxes = boxes ,
212+ source = self .bounding_box_format ,
213+ target = "xyxy" ,
214+ height = self .height ,
215+ width = self .width ,
235216 )
217+ h_start = self .backend .cast (h_start , boxes .dtype )
218+ w_start = self .backend .cast (w_start , boxes .dtype )
219+ if len (self .backend .shape (boxes )) == 3 :
220+ boxes = self .backend .numpy .stack (
221+ [
222+ self .backend .numpy .maximum (boxes [:, :, 0 ] - h_start , 0 ),
223+ self .backend .numpy .maximum (boxes [:, :, 1 ] - w_start , 0 ),
224+ self .backend .numpy .maximum (boxes [:, :, 2 ] - h_start , 0 ),
225+ self .backend .numpy .maximum (boxes [:, :, 3 ] - w_start , 0 ),
226+ ],
227+ axis = - 1 ,
228+ )
229+ else :
230+ boxes = self .backend .numpy .stack (
231+ [
232+ self .backend .numpy .maximum (boxes [:, 0 ] - h_start , 0 ),
233+ self .backend .numpy .maximum (boxes [:, 1 ] - w_start , 0 ),
234+ self .backend .numpy .maximum (boxes [:, 2 ] - h_start , 0 ),
235+ self .backend .numpy .maximum (boxes [:, 3 ] - w_start , 0 ),
236+ ],
237+ axis = - 1 ,
238+ )
236239
237- # Convert to user defined bounding box format
238- boxes = convert_format (
239- boxes = boxes ,
240- source = "xyxy" ,
241- target = self .bounding_box_format ,
242- height = self .height ,
243- width = self .width ,
244- )
240+ # Convert to user defined bounding box format
241+ boxes = convert_format (
242+ boxes = boxes ,
243+ source = "xyxy" ,
244+ target = self .bounding_box_format ,
245+ height = self .height ,
246+ width = self .width ,
247+ )
245248
246- return {
247- "boxes" : boxes ,
248- "labels" : bounding_boxes ["labels" ],
249- }
249+ return {
250+ "boxes" : boxes ,
251+ "labels" : bounding_boxes ["labels" ],
252+ }
253+ return bounding_boxes
250254
251255 def transform_segmentation_masks (
252256 self , segmentation_masks , transformation , training = True
0 commit comments