Skip to content

Commit 3c1c4c3

Browse files
committed
revert original code which is functional
1 parent d7ee534 commit 3c1c4c3

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -137,18 +137,28 @@ void MakeStateful(std::shared_ptr<ov::Model>& ov_model,
137137
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
138138
std::vector<std::string> key_value_input_names;
139139
std::vector<std::string> not_kv_inputs;
140-
const auto& params = model->get_parameters();
141-
142-
for (size_t i = 0; i < params.size(); i++) {
143-
auto param_name = params[i]->output(0).get_any_name();
144-
if (param_name.find("key_values") != std::string::npos) {
145-
key_value_input_names.push_back(param_name);
146-
} else if (param_name.find("keys") != std::string::npos) {
147-
key_value_input_names.push_back(param_name);
148-
} else if (param_name.find("values") != std::string::npos) {
149-
key_value_input_names.push_back(param_name);
150-
} else{
151-
not_kv_inputs.push_back(param_name);
140+
141+
for (const ov::Output<ov::Node>& input : model->inputs()) {
142+
auto& names = input.get_names();
143+
144+
bool found = false;
145+
for (auto& name : names) {
146+
if (name.find("key_values") != std::string::npos) {
147+
key_value_input_names.push_back(name);
148+
found = true;
149+
break;
150+
} else if (name.find("keys") != std::string::npos) {
151+
key_value_input_names.push_back(name);
152+
found = true;
153+
break;
154+
} else if (name.find("values") != std::string::npos) {
155+
key_value_input_names.push_back(name);
156+
found = true;
157+
break;
158+
}
159+
}
160+
if (!found) {
161+
not_kv_inputs.push_back(input.get_any_name());
152162
}
153163
}
154164

0 commit comments

Comments
 (0)