From d44a556737d5f20ce80c45da7b5bd30e96bc85ac Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Tue, 4 Nov 2025 10:51:29 -0800 Subject: [PATCH 1/9] use output-to-input strategy to get the pairs of KV name --- .../openvino/ov_stateful_patch_utils.cc | 107 ++++++++++++++---- 1 file changed, 82 insertions(+), 25 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index ca4867b7d8ae4..43df0ddcb656f 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License #include "core/providers/openvino/ov_stateful_patch_utils.h" +#include "regex" namespace onnxruntime { namespace openvino_ep { @@ -134,27 +135,85 @@ void MakeStateful(std::shared_ptr& ov_model, // Converted to C++ from below reference URL: // https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281 -void PatchStatefulDecoder(std::shared_ptr model) { +// Helper function to extract KV patterns from output names dynamically +std::pair, std::vector> ExtractKVPatternsFromOutputs(const std::shared_ptr& model) { + std::set unique_patterns; + std::vector key_value_output_names; + + // Regex to match "present_" prefix and numeric suffix + std::regex present_pattern(R"(present_(.+)_(\d+))"); + + // Scan all outputs with "present" in the name + for (const ov::Output& output : model->outputs()) { + const auto& names = output.get_names(); + for (const auto& name : names) { + if (name.starts_with("present")) { + key_value_output_names.push_back(name); + std::smatch match; + if (std::regex_match(name, match, present_pattern)) { + // Extract the middle part (between "present_" and "_number") + std::string pattern = match[1].str(); + unique_patterns.insert(pattern); + } + break; + } + } + } + + std::vector extracted_patterns(unique_patterns.begin(), unique_patterns.end()); + + return std::make_pair(key_value_output_names, extracted_patterns); +} + +// Main function to extract KV tensors using dynamic pattern matching +std::pair, std::vector> ExtractInputKVTensors( + const std::shared_ptr& model, std::vector patterns) { + std::vector key_value_input_names; std::vector not_kv_inputs; + + if (patterns.empty()) { + // Fallback: use original substring matching + for (const ov::Output& input : model->inputs()) { + const auto& names = input.get_names(); + const std::string input_name = input.get_any_name(); + + bool is_kv_input = false; + for (const auto& name : names) { + if (name.find("key_values") != std::string::npos || + name.find("keys") != std::string::npos || + name.find("values") != std::string::npos) { + key_value_input_names.push_back(name); + is_kv_input = true; + break; + } + } + + if (!is_kv_input) { + not_kv_inputs.push_back(input_name); + } + } + + return std::make_pair(key_value_input_names, not_kv_inputs); + } + + std::set found_kv_inputs; + for (const ov::Output& input : model->inputs()) { - auto& names = input.get_names(); + const auto& names = input.get_names(); bool found = false; - for (auto& name : names) { - if (name.find("key_values") != std::string::npos) { - key_value_input_names.push_back(name); - found = true; - break; - } else if (name.find("keys") != std::string::npos) { - key_value_input_names.push_back(name); - found = true; - break; - } else if (name.find("values") != std::string::npos) { - key_value_input_names.push_back(name); - found = true; - break; + + // Check each input name against potential patterns + for (const auto& name : names) { + for (const auto& pattern : patterns) { + if (name.find(pattern) != std::string::npos){ + key_value_input_names.push_back(name); + found = true; + break; + } } + if (found) break; } if (!found) { @@ -162,16 +221,14 @@ void PatchStatefulDecoder(std::shared_ptr model) { } } - std::vector key_value_output_names; - for (const ov::Output& output : model->outputs()) { - auto& names = output.get_names(); - for (auto& name : names) { - if (name.find("present") != std::string::npos) { - key_value_output_names.push_back(name); - break; - } - } - } + return std::make_pair(key_value_input_names, not_kv_inputs); +} + +// Updated PatchStatefulDecoder function +void PatchStatefulDecoder(std::shared_ptr model) { + // Use the dynamic pattern-based extraction logic + auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model); + auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns); if (key_value_input_names.empty() || key_value_output_names.empty()) { std::cout << "no key_value_input_names or key_value_output_names found" << std::endl; From 1634684d1c785451049dd740489206f063d096e4 Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Wed, 5 Nov 2025 16:30:39 -0800 Subject: [PATCH 2/9] minor change --- .../core/providers/openvino/ov_stateful_patch_utils.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index 43df0ddcb656f..ea66a76810517 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -167,7 +167,7 @@ std::pair, std::vector> ExtractKVPatternsF // Main function to extract KV tensors using dynamic pattern matching std::pair, std::vector> ExtractInputKVTensors( - const std::shared_ptr& model, std::vector patterns) { + const std::shared_ptr& model, const std::vector& patterns) { std::vector key_value_input_names; std::vector not_kv_inputs; @@ -200,7 +200,7 @@ std::pair, std::vector> ExtractInputKVTens std::set found_kv_inputs; for (const ov::Output& input : model->inputs()) { - const auto& names = input.get_names(); + auto& names = input.get_names(); bool found = false; From f6b3c6362aaea50c8e2a1839370983b0ef22fbe6 Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Thu, 6 Nov 2025 17:55:21 -0800 Subject: [PATCH 3/9] remove regex for extracting pattern --- .../openvino/ov_stateful_patch_utils.cc | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index ea66a76810517..89b3c0ba1ab7a 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -2,7 +2,6 @@ // Licensed under the MIT License #include "core/providers/openvino/ov_stateful_patch_utils.h" -#include "regex" namespace onnxruntime { namespace openvino_ep { @@ -140,26 +139,27 @@ std::pair, std::vector> ExtractKVPatternsF std::set unique_patterns; std::vector key_value_output_names; - // Regex to match "present_" prefix and numeric suffix - std::regex present_pattern(R"(present_(.+)_(\d+))"); - - // Scan all outputs with "present" in the name + const std::string prefix = "present_"; + const size_t prefix_len = prefix.length(); for (const ov::Output& output : model->outputs()) { const auto& names = output.get_names(); for (const auto& name : names) { - if (name.starts_with("present")) { + if (name.find(prefix) == 0 && name.length() > prefix_len) { key_value_output_names.push_back(name); - std::smatch match; - if (std::regex_match(name, match, present_pattern)) { - // Extract the middle part (between "present_" and "_number") - std::string pattern = match[1].str(); - unique_patterns.insert(pattern); + size_t last_underscore_pos = name.rfind('_'); + + // Extract pattern between "present_" and the last underscore + if (last_underscore_pos != std::string::npos && last_underscore_pos > prefix_len) { + std::string pattern = name.substr(prefix_len, last_underscore_pos - prefix_len); + + if (!pattern.empty()) { + unique_patterns.insert(pattern); + } } break; } } } - std::vector extracted_patterns(unique_patterns.begin(), unique_patterns.end()); return std::make_pair(key_value_output_names, extracted_patterns); @@ -204,7 +204,7 @@ std::pair, std::vector> ExtractInputKVTens bool found = false; - // Check each input name against potential patterns + // Check if any input name contains the extracted patterns for (const auto& name : names) { for (const auto& pattern : patterns) { if (name.find(pattern) != std::string::npos){ @@ -230,11 +230,17 @@ void PatchStatefulDecoder(std::shared_ptr model) { auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model); auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns); + std::cout << key_value_input_names.size() << ";" << key_value_output_names.size() << std::endl; if (key_value_input_names.empty() || key_value_output_names.empty()) { std::cout << "no key_value_input_names or key_value_output_names found" << std::endl; return; } + if (key_value_input_names.size() != key_value_output_names.size()) { + std::cout << "found different sizes btween key_value_input_names and key_value_output_names, they couldn't be paired" << std::endl; + return; + } + // By default, batch is the 0 - th but chatglm uses 1 - st dimension as batch // TODO(ryan): Deduce from a model via ordinal reshape(? ) and topology // batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0 From 165a7cd6d94e8c0c17039fcc49ee9a787895a4b5 Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Fri, 7 Nov 2025 12:04:50 -0800 Subject: [PATCH 4/9] Address review --- .../providers/openvino/ov_stateful_patch_utils.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index 89b3c0ba1ab7a..d887d1855b7a7 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License #include "core/providers/openvino/ov_stateful_patch_utils.h" +#include "core/common/common.h" namespace onnxruntime { namespace openvino_ep { @@ -230,15 +231,16 @@ void PatchStatefulDecoder(std::shared_ptr model) { auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model); auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns); - std::cout << key_value_input_names.size() << ";" << key_value_output_names.size() << std::endl; if (key_value_input_names.empty() || key_value_output_names.empty()) { - std::cout << "no key_value_input_names or key_value_output_names found" << std::endl; - return; + ORT_THROW("No key_value_input_names or key_value_output_names found"); } if (key_value_input_names.size() != key_value_output_names.size()) { - std::cout << "found different sizes btween key_value_input_names and key_value_output_names, they couldn't be paired" << std::endl; - return; + ORT_THROW("Found different sizes between key_value_input_names (", + key_value_input_names.size(), + ") and key_value_output_names (", + key_value_output_names.size(), + "). They couldn't be paired."); } // By default, batch is the 0 - th but chatglm uses 1 - st dimension as batch From 8565e296282cc60999a3b14003aafeb6f40b0ba3 Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Mon, 10 Nov 2025 14:56:24 -0800 Subject: [PATCH 5/9] Design strict KV patterns: only two separately for key and value; patterns have to be followed by _%d --- .../openvino/ov_stateful_patch_utils.cc | 71 ++++++++++++++----- 1 file changed, 55 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index d887d1855b7a7..bc0b15c27e946 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License #include "core/providers/openvino/ov_stateful_patch_utils.h" +#include "core/providers/shared_library/provider_api.h" #include "core/common/common.h" namespace onnxruntime { @@ -136,12 +137,15 @@ void MakeStateful(std::shared_ptr& ov_model, // Converted to C++ from below reference URL: // https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281 // Helper function to extract KV patterns from output names dynamically -std::pair, std::vector> ExtractKVPatternsFromOutputs(const std::shared_ptr& model) { - std::set unique_patterns; +std::pair, std::optional>> ExtractKVPatternsFromOutputs(const std::shared_ptr& model) { std::vector key_value_output_names; + std::string key_pattern; + std::string value_pattern; + std::optional> pattern_pair; const std::string prefix = "present_"; const size_t prefix_len = prefix.length(); + std::unordered_set unique_patterns; for (const ov::Output& output : model->outputs()) { const auto& names = output.get_names(); for (const auto& name : names) { @@ -161,19 +165,41 @@ std::pair, std::vector> ExtractKVPatternsF } } } - std::vector extracted_patterns(unique_patterns.begin(), unique_patterns.end()); - return std::make_pair(key_value_output_names, extracted_patterns); + if (unique_patterns.size() > 2) { + ORT_THROW("More than two unique KV patterns found in output names."); + } + + // Traverse unique patterns and assign them based on "key" or "value" substring + for (const auto& pattern : unique_patterns) { + std::string pattern_lower = pattern; + std::transform(pattern_lower.begin(), pattern_lower.end(), pattern_lower.begin(), ::tolower); + if (pattern_lower.find("key") != std::string::npos) { + key_pattern = pattern; + } else if (pattern_lower.find("value") != std::string::npos) { + value_pattern = pattern; + } + } + + if (key_pattern.empty() || value_pattern.empty()) { + ORT_THROW("Could not find both key and value patterns in output names."); + } + else + { + LOGS_DEFAULT(INFO) << "Extracted key pattern: " << key_pattern << ", value pattern: " << value_pattern; + } + pattern_pair = std::make_pair(key_pattern, value_pattern); + return std::make_pair(key_value_output_names, pattern_pair); } // Main function to extract KV tensors using dynamic pattern matching std::pair, std::vector> ExtractInputKVTensors( - const std::shared_ptr& model, const std::vector& patterns) { + const std::shared_ptr& model, const std::optional>& kv_pattern) { std::vector key_value_input_names; std::vector not_kv_inputs; - if (patterns.empty()) { + if (!kv_pattern.has_value()) { // Fallback: use original substring matching for (const ov::Output& input : model->inputs()) { const auto& names = input.get_names(); @@ -198,23 +224,36 @@ std::pair, std::vector> ExtractInputKVTens return std::make_pair(key_value_input_names, not_kv_inputs); } - std::set found_kv_inputs; + // Extract the key and value patterns from the pair + const auto& [key_pattern, value_pattern] = kv_pattern.value(); + + // Inline helper function to check if name is matched with provided pattern followed by "_%d" + auto matches_pattern = [](const std::string& name, const std::string& pattern) -> bool { + size_t pos = name.find(pattern); + if (pos == std::string::npos) { + return false; + } + + size_t after_pattern = pos + pattern.length(); + if (after_pattern >= name.length() || name[after_pattern] != '_') { + return false; + } + + std::string suffix = name.substr(after_pattern + 1); + return !suffix.empty() && std::all_of(suffix.begin(), suffix.end(), ::isdigit); + }; for (const ov::Output& input : model->inputs()) { auto& names = input.get_names(); - bool found = false; - // Check if any input name contains the extracted patterns + // Check if any input name contains either key or value pattern for (const auto& name : names) { - for (const auto& pattern : patterns) { - if (name.find(pattern) != std::string::npos){ - key_value_input_names.push_back(name); - found = true; - break; - } + if (matches_pattern(name, key_pattern) || matches_pattern(name, value_pattern)) { + key_value_input_names.push_back(name); + found = true; + break; } - if (found) break; } if (!found) { From a873e7b2c2cbaa255b07b7c6a86b648086a36791 Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Tue, 11 Nov 2025 19:19:19 -0800 Subject: [PATCH 6/9] simplify code structure --- .../openvino/ov_stateful_patch_utils.cc | 47 +++++-------------- 1 file changed, 12 insertions(+), 35 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index bc0b15c27e946..d2f8d2e396ac4 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -137,15 +137,12 @@ void MakeStateful(std::shared_ptr& ov_model, // Converted to C++ from below reference URL: // https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281 // Helper function to extract KV patterns from output names dynamically -std::pair, std::optional>> ExtractKVPatternsFromOutputs(const std::shared_ptr& model) { +std::pair, std::unordered_set> ExtractKVPatternsFromOutputs(const std::shared_ptr& model) { std::vector key_value_output_names; - std::string key_pattern; - std::string value_pattern; - std::optional> pattern_pair; + std::unordered_set unique_patterns; const std::string prefix = "present_"; const size_t prefix_len = prefix.length(); - std::unordered_set unique_patterns; for (const ov::Output& output : model->outputs()) { const auto& names = output.get_names(); for (const auto& name : names) { @@ -169,37 +166,17 @@ std::pair, std::optional 2) { ORT_THROW("More than two unique KV patterns found in output names."); } - - // Traverse unique patterns and assign them based on "key" or "value" substring - for (const auto& pattern : unique_patterns) { - std::string pattern_lower = pattern; - std::transform(pattern_lower.begin(), pattern_lower.end(), pattern_lower.begin(), ::tolower); - if (pattern_lower.find("key") != std::string::npos) { - key_pattern = pattern; - } else if (pattern_lower.find("value") != std::string::npos) { - value_pattern = pattern; - } - } - - if (key_pattern.empty() || value_pattern.empty()) { - ORT_THROW("Could not find both key and value patterns in output names."); - } - else - { - LOGS_DEFAULT(INFO) << "Extracted key pattern: " << key_pattern << ", value pattern: " << value_pattern; - } - pattern_pair = std::make_pair(key_pattern, value_pattern); - return std::make_pair(key_value_output_names, pattern_pair); + return std::make_pair(key_value_output_names, unique_patterns); } // Main function to extract KV tensors using dynamic pattern matching std::pair, std::vector> ExtractInputKVTensors( - const std::shared_ptr& model, const std::optional>& kv_pattern) { + const std::shared_ptr& model, const std::unordered_set& kv_patterns) { std::vector key_value_input_names; std::vector not_kv_inputs; - if (!kv_pattern.has_value()) { + if (kv_patterns.empty()) { // Fallback: use original substring matching for (const ov::Output& input : model->inputs()) { const auto& names = input.get_names(); @@ -224,9 +201,6 @@ std::pair, std::vector> ExtractInputKVTens return std::make_pair(key_value_input_names, not_kv_inputs); } - // Extract the key and value patterns from the pair - const auto& [key_pattern, value_pattern] = kv_pattern.value(); - // Inline helper function to check if name is matched with provided pattern followed by "_%d" auto matches_pattern = [](const std::string& name, const std::string& pattern) -> bool { size_t pos = name.find(pattern); @@ -249,11 +223,14 @@ std::pair, std::vector> ExtractInputKVTens // Check if any input name contains either key or value pattern for (const auto& name : names) { - if (matches_pattern(name, key_pattern) || matches_pattern(name, value_pattern)) { - key_value_input_names.push_back(name); - found = true; - break; + for (const auto& pattern : kv_patterns) { + if (matches_pattern(name, pattern)) { + key_value_input_names.push_back(name); + found = true; + break; + } } + if (found) break; } if (!found) { From 5d5d1f206cc9b6467cc8c49315cf0c339dfc6854 Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Wed, 12 Nov 2025 12:17:12 -0800 Subject: [PATCH 7/9] address review --- .../core/providers/openvino/ov_stateful_patch_utils.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index d2f8d2e396ac4..2c6af40c92ad0 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -147,15 +147,13 @@ std::pair, std::unordered_set> ExtractKVPa const auto& names = output.get_names(); for (const auto& name : names) { if (name.find(prefix) == 0 && name.length() > prefix_len) { - key_value_output_names.push_back(name); size_t last_underscore_pos = name.rfind('_'); - // Extract pattern between "present_" and the last underscore if (last_underscore_pos != std::string::npos && last_underscore_pos > prefix_len) { std::string pattern = name.substr(prefix_len, last_underscore_pos - prefix_len); - if (!pattern.empty()) { unique_patterns.insert(pattern); + key_value_output_names.push_back(name); } } break; From cec822e562c49a0fbfdb5206f0a46453e62f3ea5 Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Wed, 12 Nov 2025 15:37:35 -0800 Subject: [PATCH 8/9] remove useless comment --- onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index 2c6af40c92ad0..19764dc2ed53a 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -134,8 +134,6 @@ void MakeStateful(std::shared_ptr& ov_model, manager.run_passes(ov_model); } -// Converted to C++ from below reference URL: -// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281 // Helper function to extract KV patterns from output names dynamically std::pair, std::unordered_set> ExtractKVPatternsFromOutputs(const std::shared_ptr& model) { std::vector key_value_output_names; From 38ae63929f66e5781f7c1208f0d34483c9d46ae4 Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Thu, 13 Nov 2025 10:31:42 -0800 Subject: [PATCH 9/9] add brief example to explain the functionalities --- .../core/providers/openvino/ov_stateful_patch_utils.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index 19764dc2ed53a..90a6f94d50cb7 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -135,6 +135,10 @@ void MakeStateful(std::shared_ptr& ov_model, } // Helper function to extract KV patterns from output names dynamically +// +// Example: Given output names ["present_key_cross_0", "present_key_cross_1", "present_value_cross_0", "present_value_cross_1", "logits"] +// key_value_output_names = ["present_key_cross_0", "present_key_cross_1", "present_value_cross_0", "present_value_cross_1"] +// unique_patterns = {"key_cross", "value_cross"} std::pair, std::unordered_set> ExtractKVPatternsFromOutputs(const std::shared_ptr& model) { std::vector key_value_output_names; std::unordered_set unique_patterns; @@ -166,6 +170,12 @@ std::pair, std::unordered_set> ExtractKVPa } // Main function to extract KV tensors using dynamic pattern matching +// +// Example: Given input names ["input_ids", "attention_mask", "past_key_cross_0", "past_key_cross_1", "past_value_cross_0", "past_value_cross_1"] +// kv_patterns = {"key_cross", "value_cross"} +// +// key_value_input_names = ["past_key_cross_0", "past_key_cross_1", "past_value_cross_0", "past_value_cross_1"] +// not_kv_inputs = ["input_ids", "attention_mask"] std::pair, std::vector> ExtractInputKVTensors( const std::shared_ptr& model, const std::unordered_set& kv_patterns) {