@@ -137,18 +137,28 @@ void MakeStateful(std::shared_ptr<ov::Model>& ov_model,
137137void PatchStatefulDecoder (std::shared_ptr<ov::Model> model) {
138138 std::vector<std::string> key_value_input_names;
139139 std::vector<std::string> not_kv_inputs;
140- const auto & params = model->get_parameters ();
141-
142- for (size_t i = 0 ; i < params.size (); i++) {
143- auto param_name = params[i]->output (0 ).get_any_name ();
144- if (param_name.find (" key_values" ) != std::string::npos) {
145- key_value_input_names.push_back (param_name);
146- } else if (param_name.find (" keys" ) != std::string::npos) {
147- key_value_input_names.push_back (param_name);
148- } else if (param_name.find (" values" ) != std::string::npos) {
149- key_value_input_names.push_back (param_name);
150- } else {
151- not_kv_inputs.push_back (param_name);
140+
141+ for (const ov::Output<ov::Node>& input : model->inputs ()) {
142+ auto & names = input.get_names ();
143+
144+ bool found = false ;
145+ for (auto & name : names) {
146+ if (name.find (" key_values" ) != std::string::npos) {
147+ key_value_input_names.push_back (name);
148+ found = true ;
149+ break ;
150+ } else if (name.find (" keys" ) != std::string::npos) {
151+ key_value_input_names.push_back (name);
152+ found = true ;
153+ break ;
154+ } else if (name.find (" values" ) != std::string::npos) {
155+ key_value_input_names.push_back (name);
156+ found = true ;
157+ break ;
158+ }
159+ }
160+ if (!found) {
161+ not_kv_inputs.push_back (input.get_any_name ());
152162 }
153163 }
154164
0 commit comments