@@ -36,10 +36,6 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
3636 if (ValidateSubgraph (const_outputs_map_))
3737 return ;
3838
39- // Pre-requisite is provider_option "context" must be set
40- auto auto_unified_compile = ((hw_target.find (" AUTO" ) == std::string::npos) ||
41- (session_context_.OpenVINO_Version .at (0 ) >= 2024 &&
42- session_context_.OpenVINO_Version .at (1 ) > 2 ));
4339 ov::AnyMap device_config;
4440 SetOVDeviceConfiguration (device_config);
4541 if (subgraph_context_.is_ep_ctx_graph ) {
@@ -81,42 +77,46 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
8177 ORT_THROW (msg);
8278 } // Delete stream after it is no longer needed
8379 } else {
80+ std::shared_ptr<const onnxruntime::openvino_ep::OVNetwork> ov_model;
8481 std::string model = model_proto->SerializeAsString ();
8582 if (!subgraph_context.has_dynamic_input_shape ) {
8683 model_proto.reset ();
8784 }
85+ bool eligible_for_cpu_fallback = session_context_.device_type .find (" NPU" ) != std::string::npos &&
86+ !session_context_.so_disable_cpu_ep_fallback &&
87+ !subgraph_context_.is_ep_ctx_graph ;
88+ #if defined(OPENVINO_DISABLE_NPU_FALLBACK)
89+ eligible_for_cpu_fallback = false ;
90+ #endif
91+ auto auto_unified_compile = (hw_target.find (" AUTO" ) == std::string::npos);
92+
93+ // Unified compile is efficient with cache_dir cached model loading that bypasses Read Model
94+ // Does not support models with external weights, dynamic input shape, Epctx onnx cached model,
95+ // reshape, enable_causallm, or NPU CPU fallback
96+
97+ auto is_unified_compile = (!session_context_.has_external_weights &&
98+ !subgraph_context_.has_dynamic_input_shape &&
99+ !session_context_.so_context_enable &&
100+ session_context_.reshape .empty () &&
101+ !enable_causallm &&
102+ !eligible_for_cpu_fallback &&
103+ auto_unified_compile);
88104 try {
89- // SetOVDeviceConfiguration(device_config);
90- if (!session_context_.has_external_weights &&
91- !subgraph_context_.has_dynamic_input_shape &&
92- !session_context_.so_context_enable &&
93- session_context_.reshape .empty () &&
94- !enable_causallm &&
95- auto_unified_compile) {
96- // Unified OV compile_model is efficient when ov model caching is enabled
97- // Unified OV compile_model API is supported with AUTO from version 2024.3 and above
98- // Inputs with static dimensions
99- // Not enabled for models with external weights and when ep context is set.
100-
105+ if (is_unified_compile) {
101106 exe_network_ = OVCore::Get ()->CompileModel (model,
102107 hw_target,
103108 device_config,
104109 subgraph_context_.subgraph_name );
105110 } else { // For all other types use ov::ov_core read_model() to generate OV IR
106111 // followed by ov::ov_core compile_model()
107- auto ov_model = CreateOVModel (std::move (model), session_context_, const_outputs_map_);
112+ ov_model = CreateOVModel (std::move (model), session_context_, const_outputs_map_);
108113 exe_network_ = OVCore::Get ()->CompileModel (
109114 ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name );
110115 }
111116 LOGS_DEFAULT (INFO) << log_tag << " Loaded model to the plugin" ;
112117 } catch (const OnnxRuntimeException& ex) {
113118 std::string exception_str = ex.what ();
114- bool eligible_for_cpu_fallback = session_context_.device_type .find (" NPU" ) != std::string::npos &&
115- !session_context_.so_disable_cpu_ep_fallback &&
116- !subgraph_context_.is_ep_ctx_graph ;
117- #if defined(OPENVINO_DISABLE_NPU_FALLBACK)
118- eligible_for_cpu_fallback = false ;
119- #endif
119+
120120 if (eligible_for_cpu_fallback) {
121121 LOGS_DEFAULT (WARNING) << " Model compilation failed at OV NPU."
122122 << " Falling back to OV CPU for execution" ;
@@ -125,8 +125,6 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
125125 device_config.clear ();
126126 SetOVDeviceConfiguration (device_config);
127127 try {
128- // Recreate the model with CPU device type
129- auto ov_model = CreateOVModel (std::move (model), session_context_, const_outputs_map_);
130128 exe_network_ = OVCore::Get ()->CompileModel (
131129 ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name );
132130 } catch (std::string const & msg) {
0 commit comments