@@ -20,8 +20,9 @@ namespace transformers {
 
  Inputs:
     input_ids: int32 (B, 1)
+    encoder_input_ids: int32 (B, encode_sequence_length) (optional)
     encoder_attention_mask: int32 (B, encode_sequence_length)
-    encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
+    encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) (optional)
 
     past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
     past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
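Both new inputs are optional, so every input after input_ids can shift by up to two positions. A minimal sketch of the resulting index arithmetic (illustrative only, not the ORT code; names follow the doc comment above):

```cpp
// Illustrative only: how the two optional inputs shift the decoder input indices.
const bool has_encoder_input_ids = true;  // depends on the exported subgraph
const bool has_hidden_state = true;       // depends on the exported subgraph

const int input_ids_index = 0;
const int encoder_attention_mask_index = 1 + static_cast<int>(has_encoder_input_ids);
const int encoder_hidden_states_index = encoder_attention_mask_index + 1;  // only meaningful if present
const int first_past_input_index =
    2 + static_cast<int>(has_encoder_input_ids) + static_cast<int>(has_hidden_state);
```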
@@ -49,11 +50,9 @@ namespace transformers {
 
 Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
                                    const std::vector<const NodeArg*>& subgraph_outputs) {
-  bool has_hidden_state = subgraph_inputs[2]->Name() == "encoder_hidden_states" ? true : false;
-  SetPastInputIndex(has_hidden_state);
-
-  ORT_RETURN_IF(first_past_input_index_ != 2 && first_past_input_index_ != 3,
-                "kFirstPastInputIndex currently only supports 2 or 3");
+  bool has_encoder_input_ids = subgraph_inputs[1]->Name() == "encoder_input_ids";
+  bool has_hidden_state = subgraph_inputs[2 + has_encoder_input_ids]->Name() == "encoder_hidden_states";
+  SetPastInputIndex(has_hidden_state, has_encoder_input_ids);
 
   if (!past_present_share_buffer_) {
     ORT_RETURN_IF(has_decoder_masked_attention_, "decoder_masked_attention shall use with past_present_share_buffer");
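The body of SetPastInputIndex is outside this diff. A plausible sketch consistent with the new two-argument call above, assuming the method records both flags and derives the first past-state index from them:

```cpp
// Plausible sketch, assuming SetPastInputIndex stores both flags; not the actual ORT body.
void T5DecoderSubgraph::SetPastInputIndex(bool has_hidden_state, bool has_encoder_input_ids) {
  has_hidden_state_ = has_hidden_state;
  has_encoder_input_ids_ = has_encoder_input_ids;
  // input_ids and encoder_attention_mask are always present; each optional input adds one slot.
  first_past_input_index_ = 2 + static_cast<int>(has_encoder_input_ids_) +
                            static_cast<int>(has_hidden_state_);
}
```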
@@ -75,13 +74,17 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i
 
   ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids",
                 "decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name());
-  ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask",
-                "decoder subgraph input 1 shall be named as encoder_attention_mask, got: ",
-                subgraph_inputs[1]->Name());
-  if (first_past_input_index_ == 3) {
-    ORT_RETURN_IF(subgraph_inputs[2]->Name() != "encoder_hidden_states",
-                  "decoder subgraph input 2 shall be named as encoder_hidden_states, got: ",
-                  subgraph_inputs[2]->Name());
+  const int enc_attn_mask_index = 1 + has_encoder_input_ids_;
+  const int enc_hidden_state_index = enc_attn_mask_index + 1;
+  ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->Name() != "encoder_attention_mask",
+                "decoder subgraph input ", std::to_string(enc_attn_mask_index),
+                " shall be named as encoder_attention_mask, got: ",
+                subgraph_inputs[enc_attn_mask_index]->Name());
+  if (has_hidden_state_) {
+    ORT_RETURN_IF(subgraph_inputs[enc_hidden_state_index]->Name() != "encoder_hidden_states",
+                  "decoder subgraph input ", std::to_string(enc_hidden_state_index),
+                  " shall be named as encoder_hidden_states, got: ",
+                  subgraph_inputs[enc_hidden_state_index]->Name());
   }
 
   // check subgraph outputs
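The name checks above now cover four possible leading layouts. A self-contained helper (hypothetical, for illustration only) that enumerates the names being enforced:

```cpp
#include <string>
#include <vector>

// Hypothetical helper: the leading input names the Validate checks expect,
// for a given combination of the two optional inputs.
std::vector<std::string> ExpectedLeadingInputs(bool has_encoder_input_ids, bool has_hidden_state) {
  std::vector<std::string> names{"input_ids"};
  if (has_encoder_input_ids) names.push_back("encoder_input_ids");
  names.push_back("encoder_attention_mask");
  if (has_hidden_state) names.push_back("encoder_hidden_states");
  return names;  // the past_* inputs follow, starting at index names.size()
}
```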
@@ -108,12 +111,19 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i
 
   ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
                 "decoder subgraph input 0 (input_ids) shall have int32 type");
-  ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
-                "decoder subgraph input 1 (encoder_attention_mask) shall have int32 type");
-
-  auto float_type = subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type();
-  ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
-                "decoder subgraph input 2 (encoder_hidden_states) shall have float or float16 type");
+  if (has_encoder_input_ids_) {
+    ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
+                  "decoder subgraph input 1 (encoder_input_ids) shall have int32 type");
+  }
+  ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->TypeAsProto()->tensor_type().elem_type() != int32_type,
+                "decoder subgraph input ", std::to_string(enc_attn_mask_index),
+                " (encoder_attention_mask) shall have int32 type");
+
+  auto float_type = subgraph_inputs[enc_hidden_state_index]->TypeAsProto()->tensor_type().elem_type();
+  if (has_hidden_state_) {
+    ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
+                  "decoder subgraph input ", std::to_string(enc_hidden_state_index),
+                  " (encoder_hidden_states) shall have float or float16 type");
+  }
 
   for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) {
     ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type,
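The loop above walks 4 * num_layers past-state inputs and requires them all to share float_type. A sketch of the names it iterates over; the per-layer grouping into self and cross key/value pairs is an assumption extrapolated from the past_key_self_0/past_value_self_0 names in the doc comment:

```cpp
#include <string>
#include <vector>

// Sketch of the past-state input names behind the 4 * num_layers type check.
// The cross-attention names are an assumption; the doc comment above only
// shows the self-attention pair.
std::vector<std::string> PastInputNames(int num_layers) {
  std::vector<std::string> names;
  for (int layer = 0; layer < num_layers; ++layer) {
    names.push_back("past_key_self_" + std::to_string(layer));
    names.push_back("past_value_self_" + std::to_string(layer));
    names.push_back("past_key_cross_" + std::to_string(layer));
    names.push_back("past_value_cross_" + std::to_string(layer));
  }
  return names;
}
```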
@@ -219,6 +229,19 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
   decoder_feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
   decoder_feeds.push_back(input_ids);
 
+  if (has_encoder_input_ids_) {
+    // The encoder_input_ids is copied from the first input of the encoder.
+    OrtValue expanded_encoder_input_ids;
+    ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
+                                                 encoder_feeds[0],
+                                                 num_beam,
+                                                 allocator,
+                                                 expanded_encoder_input_ids,
+                                                 false,
+                                                 0 /*max_sequence_length*/));
+    decoder_feeds.push_back(expanded_encoder_input_ids);
+  }
+
   // The encoder_attention_mask is copied from the second input of the encoder.
   OrtValue expanded_decoder_attention_masks;
   ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
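Both expand_buffer_int32_func calls repeat each batch row num_beam times so every beam gets its own copy of the encoder-side feeds. A standalone sketch of that expansion on a flat (B, S) int32 buffer (the function name and signature here are hypothetical):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch of beam expansion: (batch, seq_len) -> (batch * num_beams, seq_len),
// repeating each input row num_beams times.
std::vector<int32_t> ExpandForBeams(const std::vector<int32_t>& src,
                                    size_t batch, size_t seq_len, size_t num_beams) {
  std::vector<int32_t> dst;
  dst.reserve(batch * num_beams * seq_len);
  for (size_t b = 0; b < batch; ++b) {
    const auto row_begin = src.begin() + b * seq_len;
    for (size_t beam = 0; beam < num_beams; ++beam) {
      dst.insert(dst.end(), row_begin, row_begin + seq_len);
    }
  }
  return dst;
}
```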
@@ -238,7 +261,9 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
   // When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second
   // output of the encoder onward.
   // When first_past_input_index_ == 2, only the past states are copied, starting from the third output of the
   // encoder.
-  for (size_t j = static_cast<size_t>(4) - first_past_input_index_; j < encoder_fetches.size(); j++) {
+  // TODO: it would probably be more robust to introduce an encoder_out/decoder_in mapping instead of relying
+  // on positions. What happens if encoder_hidden_states is present in encoder_fetches but not in decoder_feeds?
+  for (size_t j = static_cast<size_t>(2) - has_hidden_state_; j < encoder_fetches.size(); j++) {
     if (j == 1) {
       ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expansion: has_hidden_state_ == false");
       OrtValue expanded_hidden_states;
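The new start index 2 - has_hidden_state_ encodes the same two cases as the old 4 - first_past_input_index_, but stays correct when the optional encoder_input_ids shifts first_past_input_index_ to 4 (the old expression would then start at 0 and re-copy the logits). A worked example, assuming the encoder fetch order is logits, then encoder_hidden_states, then the present states:

```cpp
#include <cstddef>

// Worked example of the loop bound. Assumed fetch order: encoder_fetches[0] = logits,
// [1] = encoder_hidden_states, [2..] = present key/value states.
size_t FirstEncoderFetchToCopy(bool has_hidden_state) {
  // true  -> 1: expand encoder_hidden_states for the decoder, then the past states.
  // false -> 2: skip the hidden states and copy only the past states.
  return 2 - static_cast<size_t>(has_hidden_state);
}
```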