diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index ca4867b7d8ae4..90a6f94d50cb7 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -2,6 +2,8 @@ // Licensed under the MIT License #include "core/providers/openvino/ov_stateful_patch_utils.h" +#include "core/providers/shared_library/provider_api.h" +#include "core/common/common.h" namespace onnxruntime { namespace openvino_ep { @@ -132,29 +134,109 @@ void MakeStateful(std::shared_ptr& ov_model, manager.run_passes(ov_model); } -// Converted to C++ from below reference URL: -// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281 -void PatchStatefulDecoder(std::shared_ptr model) { +// Helper function to extract KV patterns from output names dynamically +// +// Example: Given output names ["present_key_cross_0", "present_key_cross_1", "present_value_cross_0", "present_value_cross_1", "logits"] +// key_value_output_names = ["present_key_cross_0", "present_key_cross_1", "present_value_cross_0", "present_value_cross_1"] +// unique_patterns = {"key_cross", "value_cross"} +std::pair, std::unordered_set> ExtractKVPatternsFromOutputs(const std::shared_ptr& model) { + std::vector key_value_output_names; + std::unordered_set unique_patterns; + + const std::string prefix = "present_"; + const size_t prefix_len = prefix.length(); + for (const ov::Output& output : model->outputs()) { + const auto& names = output.get_names(); + for (const auto& name : names) { + if (name.find(prefix) == 0 && name.length() > prefix_len) { + size_t last_underscore_pos = name.rfind('_'); + // Extract pattern between "present_" and the last underscore + if (last_underscore_pos != std::string::npos && last_underscore_pos > prefix_len) { + std::string pattern = name.substr(prefix_len, last_underscore_pos - prefix_len); + if (!pattern.empty()) { + unique_patterns.insert(pattern); + key_value_output_names.push_back(name); + } + } + break; + } + } + } + + if (unique_patterns.size() > 2) { + ORT_THROW("More than two unique KV patterns found in output names."); + } + return std::make_pair(key_value_output_names, unique_patterns); +} + +// Main function to extract KV tensors using dynamic pattern matching +// +// Example: Given input names ["input_ids", "attention_mask", "past_key_cross_0", "past_key_cross_1", "past_value_cross_0", "past_value_cross_1"] +// kv_patterns = {"key_cross", "value_cross"} +// +// key_value_input_names = ["past_key_cross_0", "past_key_cross_1", "past_value_cross_0", "past_value_cross_1"] +// not_kv_inputs = ["input_ids", "attention_mask"] +std::pair, std::vector> ExtractInputKVTensors( + const std::shared_ptr& model, const std::unordered_set& kv_patterns) { + std::vector key_value_input_names; std::vector not_kv_inputs; + + if (kv_patterns.empty()) { + // Fallback: use original substring matching + for (const ov::Output& input : model->inputs()) { + const auto& names = input.get_names(); + const std::string input_name = input.get_any_name(); + + bool is_kv_input = false; + for (const auto& name : names) { + if (name.find("key_values") != std::string::npos || + name.find("keys") != std::string::npos || + name.find("values") != std::string::npos) { + key_value_input_names.push_back(name); + is_kv_input = true; + break; + } + } + + if (!is_kv_input) { + not_kv_inputs.push_back(input_name); + } + } + + return std::make_pair(key_value_input_names, not_kv_inputs); + } + + // Inline helper function to check if name is matched with provided pattern followed by "_%d" + auto matches_pattern = [](const std::string& name, const std::string& pattern) -> bool { + size_t pos = name.find(pattern); + if (pos == std::string::npos) { + return false; + } + + size_t after_pattern = pos + pattern.length(); + if (after_pattern >= name.length() || name[after_pattern] != '_') { + return false; + } + + std::string suffix = name.substr(after_pattern + 1); + return !suffix.empty() && std::all_of(suffix.begin(), suffix.end(), ::isdigit); + }; + for (const ov::Output& input : model->inputs()) { auto& names = input.get_names(); - bool found = false; - for (auto& name : names) { - if (name.find("key_values") != std::string::npos) { - key_value_input_names.push_back(name); - found = true; - break; - } else if (name.find("keys") != std::string::npos) { - key_value_input_names.push_back(name); - found = true; - break; - } else if (name.find("values") != std::string::npos) { - key_value_input_names.push_back(name); - found = true; - break; + + // Check if any input name contains either key or value pattern + for (const auto& name : names) { + for (const auto& pattern : kv_patterns) { + if (matches_pattern(name, pattern)) { + key_value_input_names.push_back(name); + found = true; + break; + } } + if (found) break; } if (!found) { @@ -162,20 +244,25 @@ void PatchStatefulDecoder(std::shared_ptr model) { } } - std::vector key_value_output_names; - for (const ov::Output& output : model->outputs()) { - auto& names = output.get_names(); - for (auto& name : names) { - if (name.find("present") != std::string::npos) { - key_value_output_names.push_back(name); - break; - } - } - } + return std::make_pair(key_value_input_names, not_kv_inputs); +} + +// Updated PatchStatefulDecoder function +void PatchStatefulDecoder(std::shared_ptr model) { + // Use the dynamic pattern-based extraction logic + auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model); + auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns); if (key_value_input_names.empty() || key_value_output_names.empty()) { - std::cout << "no key_value_input_names or key_value_output_names found" << std::endl; - return; + ORT_THROW("No key_value_input_names or key_value_output_names found"); + } + + if (key_value_input_names.size() != key_value_output_names.size()) { + ORT_THROW("Found different sizes between key_value_input_names (", + key_value_input_names.size(), + ") and key_value_output_names (", + key_value_output_names.size(), + "). They couldn't be paired."); } // By default, batch is the 0 - th but chatglm uses 1 - st dimension as batch