@@ -227,31 +227,50 @@ def version_7(cls, ctx, node, **kwargs):
227227 ctx .remove_input (node , node .input [1 ], 1 )
228228
229229
230+ def _const_like_version_1 (ctx , node , value ):
231+ shapes = node .output_shapes
232+ dtypes = node .output_dtypes
233+ ctx .remove_node (node .name )
234+ casted_input = ctx .make_node ("Cast" , node .input , attr = {'to' : onnx_pb .TensorProto .INT64 })
235+ const_value = ctx .make_const (utils .make_name ("value" ), np .array (value ).astype (np .int64 ))
236+ mul_node = ctx .make_node ('Mul' , inputs = [casted_input .output [0 ], const_value .output [0 ]])
237+ ctx .make_node ("Cast" , inputs = [mul_node .output [0 ]],
238+ attr = {'to' : dtypes [0 ]},
239+ name = node .name , outputs = node .output ,
240+ shapes = shapes , dtypes = dtypes )
241+
242+
243+ def _const_like_version_9 (ctx , node , value ):
244+ dtypes = node .output_dtypes
245+ ctx .remove_node (node .name )
246+ shape = ctx .make_node ("Shape" , node .input ).output [0 ]
247+ value_tensor = helper .make_tensor ("value" , dtypes [0 ], [1 ], vals = [value ])
248+ ctx .make_node ("ConstantOfShape" , inputs = [shape ],
249+ attr = {'value' : value_tensor },
250+ name = node .name , outputs = node .output ,
251+ dtypes = dtypes )
252+
253+
230254@tf_op ("ZerosLike" )
231255class ZerosLike :
232256 @classmethod
233257 def version_1 (cls , ctx , node , ** kwargs ):
234- shapes = node .output_shapes
235- dtypes = node .output_dtypes
236- ctx .remove_node (node .name )
237- casted_input = ctx .make_node ("Cast" , node .input , attr = {'to' : onnx_pb .TensorProto .INT64 })
238- const_zero = ctx .make_const (utils .make_name ("zero" ), np .array (0 ).astype (np .int64 ))
239- mul_node = ctx .make_node ('Mul' , inputs = [casted_input .output [0 ], const_zero .output [0 ]])
240- ctx .make_node ("Cast" , inputs = [mul_node .output [0 ]],
241- attr = {'to' : dtypes [0 ]},
242- name = node .name , outputs = node .output ,
243- shapes = shapes , dtypes = dtypes )
258+ _const_like_version_1 (ctx , node , 0 )
244259
245260 @classmethod
246261 def version_9 (cls , ctx , node , ** kwargs ):
247- dtypes = node .output_dtypes
248- ctx .remove_node (node .name )
249- shape = ctx .make_node ("Shape" , node .input ).output [0 ]
250- zero_tensor = helper .make_tensor ("value" , dtypes [0 ], [1 ], vals = [0 ])
251- ctx .make_node ("ConstantOfShape" , inputs = [shape ],
252- attr = {'value' : zero_tensor },
253- name = node .name , outputs = node .output ,
254- dtypes = dtypes )
262+ _const_like_version_9 (ctx , node , 0 )
263+
264+
265+ @tf_op ("OnesLike" )
266+ class OnesLike :
267+ @classmethod
268+ def version_1 (cls , ctx , node , ** kwargs ):
269+ _const_like_version_1 (ctx , node , 1 )
270+
271+ @classmethod
272+ def version_9 (cls , ctx , node , ** kwargs ):
273+ _const_like_version_9 (ctx , node , 1 )
255274
256275
257276@tf_op (["IteratorV2" , "FIFOQueueV2" ])
0 commit comments