Skip to content
145 changes: 116 additions & 29 deletions onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Licensed under the MIT License

#include "core/providers/openvino/ov_stateful_patch_utils.h"
#include "core/providers/shared_library/provider_api.h"
#include "core/common/common.h"

namespace onnxruntime {
namespace openvino_ep {
Expand Down Expand Up @@ -132,50 +134,135 @@
manager.run_passes(ov_model);
}

// Converted to C++ from below reference URL:
// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// Helper function to extract KV patterns from output names dynamically
//
// Example: Given output names ["present_key_cross_0", "present_key_cross_1", "present_value_cross_0", "present_value_cross_1", "logits"]
// key_value_output_names = ["present_key_cross_0", "present_key_cross_1", "present_value_cross_0", "present_value_cross_1"]
// unique_patterns = {"key_cross", "value_cross"}
std::pair<std::vector<std::string>, std::unordered_set<std::string>> ExtractKVPatternsFromOutputs(const std::shared_ptr<ov::Model>& model) {
std::vector<std::string> key_value_output_names;
std::unordered_set<std::string> unique_patterns;

const std::string prefix = "present_";
const size_t prefix_len = prefix.length();
for (const ov::Output<ov::Node>& output : model->outputs()) {
const auto& names = output.get_names();
for (const auto& name : names) {
if (name.find(prefix) == 0 && name.length() > prefix_len) {
size_t last_underscore_pos = name.rfind('_');
// Extract pattern between "present_" and the last underscore
if (last_underscore_pos != std::string::npos && last_underscore_pos > prefix_len) {
std::string pattern = name.substr(prefix_len, last_underscore_pos - prefix_len);
if (!pattern.empty()) {
unique_patterns.insert(pattern);
key_value_output_names.push_back(name);
}
}
break;
}
}
}

if (unique_patterns.size() > 2) {
ORT_THROW("More than two unique KV patterns found in output names.");
}
return std::make_pair(key_value_output_names, unique_patterns);
}

// Main function to extract KV tensors using dynamic pattern matching
//
// Example: Given input names ["input_ids", "attention_mask", "past_key_cross_0", "past_key_cross_1", "past_value_cross_0", "past_value_cross_1"]
// kv_patterns = {"key_cross", "value_cross"}
//
// key_value_input_names = ["past_key_cross_0", "past_key_cross_1", "past_value_cross_0", "past_value_cross_1"]
// not_kv_inputs = ["input_ids", "attention_mask"]
std::pair<std::vector<std::string>, std::vector<std::string>> ExtractInputKVTensors(
const std::shared_ptr<ov::Model>& model, const std::unordered_set<std::string>& kv_patterns) {

Check warning on line 180 in onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc:180: Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4]

std::vector<std::string> key_value_input_names;
std::vector<std::string> not_kv_inputs;

if (kv_patterns.empty()) {
// Fallback: use original substring matching
for (const ov::Output<ov::Node>& input : model->inputs()) {
const auto& names = input.get_names();
const std::string input_name = input.get_any_name();

bool is_kv_input = false;
for (const auto& name : names) {
if (name.find("key_values") != std::string::npos ||
name.find("keys") != std::string::npos ||
name.find("values") != std::string::npos) {
key_value_input_names.push_back(name);
is_kv_input = true;
break;
}
}

if (!is_kv_input) {
not_kv_inputs.push_back(input_name);
}
}

return std::make_pair(key_value_input_names, not_kv_inputs);
}

// Inline helper function to check if name is matched with provided pattern followed by "_%d"
auto matches_pattern = [](const std::string& name, const std::string& pattern) -> bool {
size_t pos = name.find(pattern);
if (pos == std::string::npos) {
return false;
}

size_t after_pattern = pos + pattern.length();
if (after_pattern >= name.length() || name[after_pattern] != '_') {
return false;
}

std::string suffix = name.substr(after_pattern + 1);
return !suffix.empty() && std::all_of(suffix.begin(), suffix.end(), ::isdigit);
};

for (const ov::Output<ov::Node>& input : model->inputs()) {
auto& names = input.get_names();

bool found = false;
for (auto& name : names) {
if (name.find("key_values") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;
} else if (name.find("keys") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;
} else if (name.find("values") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;

// Check if any input name contains either key or value pattern
for (const auto& name : names) {
for (const auto& pattern : kv_patterns) {
if (matches_pattern(name, pattern)) {
key_value_input_names.push_back(name);
found = true;
break;
}
}
if (found) break;
}

if (!found) {
not_kv_inputs.push_back(input.get_any_name());
}
}

std::vector<std::string> key_value_output_names;
for (const ov::Output<ov::Node>& output : model->outputs()) {
auto& names = output.get_names();
for (auto& name : names) {
if (name.find("present") != std::string::npos) {
key_value_output_names.push_back(name);
break;
}
}
}
return std::make_pair(key_value_input_names, not_kv_inputs);
}

// Updated PatchStatefulDecoder function
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// Use the dynamic pattern-based extraction logic
auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model);
auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns);

if (key_value_input_names.empty() || key_value_output_names.empty()) {
std::cout << "no key_value_input_names or key_value_output_names found" << std::endl;
return;
ORT_THROW("No key_value_input_names or key_value_output_names found");
}

if (key_value_input_names.size() != key_value_output_names.size()) {
ORT_THROW("Found different sizes between key_value_input_names (",
key_value_input_names.size(),
") and key_value_output_names (",
key_value_output_names.size(),
"). They couldn't be paired.");
}

// By default, batch is the 0 - th but chatglm uses 1 - st dimension as batch
Expand Down
Loading