@@ -49,20 +49,22 @@ def _export_from_module(self, module, input_type, save_directory):
4949 {input_type : 'serving_default' })
5050 tf .saved_model .save (module , save_directory , signatures = signatures )
5151
52- def _get_dummy_input (self , input_type , input_image_size ):
52+ def _get_dummy_input (self , input_type , input_image_size , num_channels ):
5353 """Get dummy input for the given input type."""
5454
5555 height = input_image_size [0 ]
5656 width = input_image_size [1 ]
5757 if input_type == 'image_tensor' :
58- return tf .zeros ((1 , height , width , 3 ), dtype = np .uint8 )
58+ return tf .zeros ((1 , height , width , num_channels ), dtype = np .uint8 )
5959 elif input_type == 'image_bytes' :
60- image = Image .fromarray (np .zeros ((height , width , 3 ), dtype = np .uint8 ))
60+ image = Image .fromarray (
61+ np .zeros ((height , width , num_channels ), dtype = np .uint8 )
62+ )
6163 byte_io = io .BytesIO ()
6264 image .save (byte_io , 'PNG' )
6365 return [byte_io .getvalue ()]
6466 elif input_type == 'tf_example' :
65- image_tensor = tf .zeros ((height , width , 3 ), dtype = tf .uint8 )
67+ image_tensor = tf .zeros ((height , width , num_channels ), dtype = tf .uint8 )
6668 encoded_jpeg = tf .image .encode_jpeg (tf .constant (image_tensor )).numpy ()
6769 example = tf .train .Example (
6870 features = tf .train .Features (
@@ -73,7 +75,7 @@ def _get_dummy_input(self, input_type, input_image_size):
7375 })).SerializeToString ()
7476 return [example ]
7577 elif input_type == 'tflite' :
76- return tf .zeros ((1 , height , width , 3 ), dtype = np .float32 )
78+ return tf .zeros ((1 , height , width , num_channels ), dtype = np .float32 )
7779
7880 @parameterized .parameters (
7981 ('image_tensor' , False , [112 , 112 ], False ),
@@ -105,7 +107,7 @@ def test_export(self, input_type, rescale_output, input_image_size,
105107 imported = tf .saved_model .load (tmp_dir )
106108 segmentation_fn = imported .signatures ['serving_default' ]
107109
108- images = self ._get_dummy_input (input_type , input_image_size )
110+ images = self ._get_dummy_input (input_type , input_image_size , num_channels = 3 )
109111 if input_type != 'tflite' :
110112 processed_images , _ = tf .nest .map_structure (
111113 tf .stop_gradient ,
@@ -128,6 +130,68 @@ def test_export(self, input_type, rescale_output, input_image_size,
128130 out = segmentation_fn (tf .constant (images ))
129131 self .assertAllClose (out ['logits' ].numpy (), expected_output .numpy ())
130132
133+ @parameterized .parameters (
134+ ('image_tensor' ,),
135+ ('tflite' ,),
136+ )
137+ def test_export_with_extra_input_channels (self , input_type ):
138+ tmp_dir = self .get_temp_dir ()
139+ num_channels = 6
140+ params = exp_factory .get_exp_config ('mnv2_deeplabv3_pascal' )
141+ params .task .init_checkpoint = None
142+ params .task .model .input_size = [112 , 112 , num_channels ]
143+ params .task .export_config .rescale_output = False
144+ params .task .train_data .preserve_aspect_ratio = False
145+ params .task .train_data .image_feature .mean = [0.5 ] * num_channels
146+ params .task .train_data .image_feature .stddev = [0.5 ] * num_channels
147+ params .task .train_data .image_feature .num_channels = num_channels
148+ module = semantic_segmentation .SegmentationModule (
149+ params ,
150+ batch_size = 1 ,
151+ input_image_size = [112 , 112 ],
152+ input_type = input_type ,
153+ num_channels = num_channels ,
154+ )
155+
156+ self ._export_from_module (module , input_type , tmp_dir )
157+
158+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , 'saved_model.pb' )))
159+ self .assertTrue (
160+ os .path .exists (os .path .join (tmp_dir , 'variables' , 'variables.index' ))
161+ )
162+ self .assertTrue (
163+ os .path .exists (
164+ os .path .join (tmp_dir , 'variables' , 'variables.data-00000-of-00001' )
165+ )
166+ )
167+
168+ imported = tf .saved_model .load (tmp_dir )
169+ segmentation_fn = imported .signatures ['serving_default' ]
170+
171+ images = self ._get_dummy_input (input_type , [112 , 112 ], num_channels )
172+
173+ if input_type != 'tflite' :
174+ processed_images , _ = tf .nest .map_structure (
175+ tf .stop_gradient ,
176+ tf .map_fn (
177+ module ._build_inputs ,
178+ elems = tf .zeros ((1 , 112 , 112 , num_channels ), dtype = tf .uint8 ),
179+ fn_output_signature = (
180+ tf .TensorSpec (
181+ shape = [112 , 112 , num_channels ], dtype = tf .float32
182+ ),
183+ tf .TensorSpec (shape = [4 , 2 ], dtype = tf .float32 ),
184+ ),
185+ ),
186+ )
187+ else :
188+ processed_images = images
189+
190+ logits = module .model (processed_images , training = False )['logits' ]
191+ expected_output = tf .image .resize (logits , [112 , 112 ], method = 'bilinear' )
192+ out = segmentation_fn (tf .constant (images ))
193+ self .assertAllClose (out ['logits' ].numpy (), expected_output .numpy ())
194+
131195 def test_export_invalid_batch_size (self ):
132196 batch_size = 3
133197 tmp_dir = self .get_temp_dir ()
0 commit comments