@@ -64,10 +64,10 @@ def predict(self, identifiers, input_data):
6464 outputs , raw_outputs = outputs
6565 else :
6666 raw_outputs = outputs
67- encoder_output = np . array ( outputs [self .encoder_out ]). squeeze ()
67+ encoder_output = outputs [self .encoder_out ]
6868 self .h0 = outputs [self .h0_out ]
6969 self .c0 = outputs [self .c0_out ]
70- return encoder_output , raw_outputs
70+ return encoder_output . squeeze () , raw_outputs
7171
7272 def fit_to_input (self , input_data ):
7373 return {self .input : input_data , self .h0_input : self .h0 , self .c0_input : self .c0 }
@@ -127,7 +127,7 @@ def predict(self, identifiers, input_data, hidden=None):
127127 raw_outputs = outputs
128128 self .h0 = outputs [self .h0_out ]
129129 self .c0 = outputs [self .c0_out ]
130- return np . array ( outputs [self .decoder_out ]) .squeeze (), (self .h0 , self .c0 ), raw_outputs
130+ return outputs [self .decoder_out ].squeeze (), (self .h0 , self .c0 ), raw_outputs
131131
132132 def fit_to_input (self , token_id , hidden ):
133133 if hidden is None :
@@ -189,7 +189,7 @@ def predict(self, identifiers, input_data):
189189 else :
190190 raw_outputs = outputs
191191 joint_out = outputs [self .output ]
192- return log_softmax (np . array ( joint_out ). squeeze () ), raw_outputs
192+ return log_softmax (joint_out ), raw_outputs
193193
194194 def fit_to_input (self , encoder_out , predictor_out ):
195195 return {self .input1 : encoder_out , self .input2 : predictor_out }
@@ -339,7 +339,7 @@ def __init__(self, network_info, launcher, suffix=None, delayed_model_loading=Fa
339339class OVJoint (Joint , CommonOpenVINOModel ):
340340 def __init__ (self , network_info , launcher , suffix = None , delayed_model_loading = False ):
341341 self .default_inputs = ['0' , '1' ]
342- self .default_outputs = ['8/sink_port ' ]
342+ self .default_outputs = ['8/sink_port_0 ' ]
343343 super ().__init__ (network_info , launcher , suffix , delayed_model_loading )
344344
345345
@@ -353,25 +353,22 @@ def infer(self, input_data):
353353 results = self .inference_session .run (self .output_names , input_data )
354354 return dict (zip (self .output_names , results ))
355355
356- def select_inputs_outputs (self , network_info ):
357- pass
358-
359356
360- class ONNXEncoder (CommonONNXModel , Encoder ):
357+ class ONNXEncoder (Encoder , CommonONNXModel ):
361358 def __init__ (self , network_info , launcher , suffix = None , delayed_model_loading = False ):
362359 self .default_inputs = ['input_0' , 'input_1' , 'input_2' ]
363360 self .default_outputs = ['output_0' , 'output_1' , 'output_2' ]
364361 super ().__init__ (network_info , launcher , suffix , delayed_model_loading )
365362
366363
367- class ONNXDecoder (CommonONNXModel , Decoder ):
364+ class ONNXDecoder (Decoder , CommonONNXModel ):
368365 def __init__ (self , network_info , launcher , suffix = None , delayed_model_loading = False ):
369366 self .default_inputs = ['input_0' , 'input_1' , 'input_2' ]
370367 self .default_outputs = ['output_0' , 'output_1' , 'output_2' ]
371368 super ().__init__ (network_info , launcher , suffix , delayed_model_loading )
372369
373370
374- class ONNXJoint (CommonONNXModel , Joint ):
371+ class ONNXJoint (Joint , CommonONNXModel ):
375372 def __init__ (self , network_info , launcher , suffix = None , delayed_model_loading = False ):
376373 self .default_inputs = ['0' , '1' ]
377374 self .default_outputs = ['8' ]
@@ -454,7 +451,7 @@ def predict(self, identifiers, input_data, encoder_callback=None):
454451 if len (B ) >= self .beam_width and yb .log_prob >= y_hat .log_prob :
455452 break
456453 B = heapq .nlargest (self .beam_width , B )
457- return self .adapter .process ([B [0 ].sequence ], identifiers , [{}]), {}
454+ return [{}], self .adapter .process ([B [0 ].sequence ], identifiers , [{}])
458455
459456 @staticmethod
460457 def prepare_records (features ):
0 commit comments