Skip to content

Commit c6ba7ed

Browse files
authored
Enable pointer-generator T5 models in BeamSearch (microsoft#23134)
### Description Introduces a new optional input (encoder_input_ids) in the decoder graph of the T5 implementation for BeamSearch. This allows usage of pointer-generator networks in the decoder graph. ### Motivation and Context - Fixes microsoft#23123
1 parent ebdbbb7 commit c6ba7ed

File tree

5 files changed

+448
-26
lines changed

5 files changed

+448
-26
lines changed

onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ namespace transformers {
2020
2121
Inputs:
2222
input_ids: int32 (B, 1)
23+
encoder_input_ids: int32 (B, encode_sequence_length) (optional)
2324
encoder_attention_mask: int32 (B, encode_sequence_length)
24-
encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
25+
encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) (optional)
2526
2627
past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
2728
past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
@@ -49,11 +50,9 @@ namespace transformers {
4950

5051
Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
5152
const std::vector<const NodeArg*>& subgraph_outputs) {
52-
bool has_hidden_state = subgraph_inputs[2]->Name() == "encoder_hidden_states" ? true : false;
53-
SetPastInputIndex(has_hidden_state);
54-
55-
ORT_RETURN_IF(first_past_input_index_ != 2 && first_past_input_index_ != 3,
56-
"kFirstPastInputIndex currently only supports 2 or 3");
53+
bool has_encoder_input_ids = subgraph_inputs[1]->Name() == "encoder_input_ids";
54+
bool has_hidden_state = subgraph_inputs[2 + has_encoder_input_ids]->Name() == "encoder_hidden_states";
55+
SetPastInputIndex(has_hidden_state, has_encoder_input_ids);
5756

5857
if (!past_present_share_buffer_) {
5958
ORT_RETURN_IF(has_decoder_masked_attention_, "decoder_masked_attention shall use with past_present_share_buffer");
@@ -75,13 +74,17 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i
7574

7675
ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids",
7776
"decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name());
78-
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask",
79-
"decoder subgraph input 1 shall be named as encoder_attention_mask, got: ",
80-
subgraph_inputs[1]->Name());
81-
if (first_past_input_index_ == 3) {
82-
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "encoder_hidden_states",
83-
"decoder subgraph input 2 shall be named as encoder_hidden_states, got: ",
84-
subgraph_inputs[2]->Name());
77+
const int enc_attn_mask_index = 1 + has_encoder_input_ids_;
78+
const int enc_hidden_state_index = enc_attn_mask_index + 1;
79+
ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->Name() != "encoder_attention_mask",
80+
"decoder subgraph input ", std::to_string(enc_attn_mask_index),
81+
" shall be named as encoder_attention_mask, got: ",
82+
subgraph_inputs[enc_attn_mask_index]->Name());
83+
if (has_hidden_state_) {
84+
ORT_RETURN_IF(subgraph_inputs[enc_hidden_state_index]->Name() != "encoder_hidden_states",
85+
"decoder subgraph input ", std::to_string(enc_hidden_state_index),
86+
" shall be named as encoder_hidden_states, got: ",
87+
subgraph_inputs[enc_hidden_state_index]->Name());
8588
}
8689

8790
// check subgraph outputs
@@ -108,12 +111,19 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i
108111

109112
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
110113
"decoder subgraph input 0 (input_ids) shall have int32 type");
111-
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
112-
"decoder subgraph input 1 (encoder_attention_mask) shall have int32 type");
113-
114-
auto float_type = subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type();
115-
ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
116-
"decoder subgraph input 2 (encoder_hidden_states) shall have float or float16 type");
114+
if (has_encoder_input_ids_) {
115+
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
116+
"decoder subgraph input 1 (encoder_input_ids) shall have int32 type");
117+
}
118+
ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->TypeAsProto()->tensor_type().elem_type() != int32_type,
119+
"decoder subgraph input ", std::to_string(enc_attn_mask_index),
120+
" (encoder_attention_mask) shall have int32 type");
121+
122+
auto float_type = subgraph_inputs[enc_hidden_state_index]->TypeAsProto()->tensor_type().elem_type();
123+
if (has_hidden_state_) {
124+
ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
125+
"decoder subgraph input ", std::to_string(enc_hidden_state_index), " (encoder_hidden_states) shall have float or float16 type");
126+
}
117127

118128
for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) {
119129
ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type,
@@ -219,6 +229,19 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
219229
decoder_feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
220230
decoder_feeds.push_back(input_ids);
221231

232+
if (has_encoder_input_ids_) {
233+
// The encoder_input_ids is copied from the first input of encoder.
234+
OrtValue expanded_encoder_input_ids;
235+
ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
236+
encoder_feeds[0],
237+
num_beam,
238+
allocator,
239+
expanded_encoder_input_ids,
240+
false,
241+
0 /*max_sequence_length*/));
242+
decoder_feeds.push_back(expanded_encoder_input_ids);
243+
}
244+
222245
// The encoder_attention_mask is copied from the second input of encoder.
223246
OrtValue expanded_decoder_attention_masks;
224247
ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
@@ -238,7 +261,9 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
238261
// When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
239262
// of encoder.
240263
// When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
241-
for (size_t j = static_cast<size_t>(4) - first_past_input_index_; j < encoder_fetches.size(); j++) {
264+
// TODO - probably more robust to introduce a encoder_out/decoder_in mapping instead of relying on positions.
265+
// What happens if encoder_hidden_states is present in the encoder_fetches but not in the decoder_feeds?
266+
for (size_t j = static_cast<size_t>(2) - has_hidden_state_; j < encoder_fetches.size(); j++) {
242267
if (j == 1) {
243268
ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false");
244269
OrtValue expanded_hidden_states;

onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,10 @@ class T5DecoderSubgraph : public Subgraph {
5454
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
5555
const std::vector<const NodeArg*>& subgraph_outputs) override;
5656

57-
void SetPastInputIndex(bool has_hidden_state) {
57+
void SetPastInputIndex(bool has_hidden_state, bool has_encoder_input_ids) {
5858
has_hidden_state_ = has_hidden_state;
59-
if (!has_hidden_state_) {
60-
first_past_input_index_ = 2;
61-
} else {
62-
first_past_input_index_ = 3;
63-
}
59+
has_encoder_input_ids_ = has_encoder_input_ids;
60+
first_past_input_index_ = 2 + has_hidden_state_ + has_encoder_input_ids_;
6461
}
6562

6663
int GetFirstPastInputIndex() const {
@@ -79,6 +76,7 @@ class T5DecoderSubgraph : public Subgraph {
7976
int first_past_input_index_;
8077
int first_present_output_index_;
8178
bool has_hidden_state_;
79+
bool has_encoder_input_ids_;
8280
bool use_sequence_as_input_ids_;
8381
};
8482

onnxruntime/test/contrib_ops/beam_search_test.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,8 @@ TEST(BeamSearchTest, DummyT5) {
394394
#if defined(USE_CUDA) && defined(USE_DML)
395395
SKIP_CUDA_TEST_WITH_DML;
396396
#endif
397+
// dummy_t5.onnx model generated using following command:
398+
// python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5.onnx
397399
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5.onnx"));
398400
tester.ConfigEp(DefaultCpuExecutionProvider());
399401
tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7});
@@ -408,6 +410,8 @@ TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) {
408410
#if defined(USE_CUDA) && defined(USE_DML)
409411
SKIP_CUDA_TEST_WITH_DML;
410412
#endif
413+
// dummy_t5_with_outer_scope_initializers.onnx model generated using following command:
414+
// python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_with_outer_scope_initializers.onnx --move-initializers
411415
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_outer_scope_initializers.onnx"));
412416
tester.ConfigEp(DefaultCpuExecutionProvider());
413417
tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7});
@@ -422,6 +426,8 @@ TEST(BeamSearchTest, DummyT5WithSequenceInputIds) {
422426
#if defined(USE_CUDA) && defined(USE_DML)
423427
SKIP_CUDA_TEST_WITH_DML;
424428
#endif
429+
// dummy_t5_with_sequence_input_ids.onnx model generated using following command:
430+
// python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_with_sequence_input_ids.onnx --sequence-as-input
425431
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx"));
426432
tester.ConfigEp(DefaultCpuExecutionProvider());
427433
tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8});
@@ -432,5 +438,21 @@ TEST(BeamSearchTest, DummyT5WithSequenceInputIds) {
432438
tester.RunWithConfig();
433439
}
434440

441+
TEST(BeamSearchTest, DummyT5PointerGenerator) {
442+
#if defined(USE_CUDA) && defined(USE_DML)
443+
SKIP_CUDA_TEST_WITH_DML;
444+
#endif
445+
// dummy_t5_pointer_generator.onnx model generated using following command:
446+
// python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_pointer_generator.onnx --decoder-needs-input-ids
447+
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_pointer_generator.onnx"));
448+
tester.ConfigEp(DefaultCpuExecutionProvider());
449+
tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7});
450+
tester.AddOutput("sequences", {1, 3, 10}, {2, 3, 6, 7, 3, 6, 7, 18, 3, 6, 2, 3, 6, 7, 18, 3, 6, 7, 18, 3, 2, 3, 6, 7, 3, 6, 7, 3, 6, 7});
451+
#ifdef USE_CUDA
452+
tester.ConfigEp(DefaultCudaExecutionProvider());
453+
#endif
454+
tester.RunWithConfig();
455+
}
456+
435457
} // namespace test
436458
} // namespace onnxruntime

0 commit comments

Comments
 (0)