Skip to content

Commit ebdbbb7

Browse files
authored
[VitisAI] Int4 support (microsoft#22850)
### Description <!-- Describe your changes. --> 1. Add support for throwing an error when the hardware is not supported by VitisAI. 2. Add support for unloading the VitisAI EP. 3. Add an API for Win25. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> This is a requirement for Win25.
1 parent 6806174 commit ebdbbb7

File tree

10 files changed

+80
-8
lines changed

10 files changed

+80
-8
lines changed

onnxruntime/core/providers/shared_library/provider_interfaces.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,7 @@ struct ProviderHost {
589589
virtual const ConfigOptions& RunOptions__GetConfigOptions(const RunOptions* p) = 0;
590590
// OrtSessionOptions
591591
virtual const std::unordered_map<std::string, std::string>& SessionOptions__GetConfigOptionsMap(const OrtSessionOptions* p) = 0;
592+
virtual bool SessionOptions__GetEnableProfiling(const OrtSessionOptions* p) = 0;
592593
// ComputeCapability
593594
virtual std::unique_ptr<ComputeCapability> ComputeCapability__construct(std::unique_ptr<IndexedSubGraph> t_sub_graph) = 0;
594595
virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0;

onnxruntime/core/providers/shared_library/provider_wrappedtypes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,5 +1476,8 @@ struct OrtSessionOptions final {
14761476
const std::unordered_map<std::string, std::string>& GetConfigOptions() const {
14771477
return onnxruntime::g_host->SessionOptions__GetConfigOptionsMap(this);
14781478
}
1479+
bool GetEnableProfiling() const {
1480+
return onnxruntime::g_host->SessionOptions__GetEnableProfiling(this);
1481+
}
14791482
PROVIDER_DISALLOW_ALL(OrtSessionOptions)
14801483
};

onnxruntime/core/providers/vitisai/imp/global_api.cc

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ struct OrtVitisAIEpAPI {
4747
void (*initialize_onnxruntime_vitisai_ep)(vaip_core::OrtApiForVaip* api, std::vector<OrtCustomOpDomain*>& ret_domain);
4848
std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>* (*compile_onnx_model_with_options)(
4949
const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options);
50+
std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>* (*compile_onnx_model_vitisai_ep_with_error_handling)(
51+
const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options, void* status, vaip_core::error_report_func func);
5052
uint32_t (*vaip_get_version)();
5153
void (*create_ep_context_nodes)(
5254
const std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>& eps,
@@ -77,10 +79,11 @@ struct OrtVitisAIEpAPI {
7779
ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_));
7880
#endif
7981
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "initialize_onnxruntime_vitisai_ep", (void**)&initialize_onnxruntime_vitisai_ep));
80-
auto status = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options);
81-
if (!status.IsOK()) {
82-
::onnxruntime::LogRuntimeError(0, status, __FILE__, static_cast<const char*>(__FUNCTION__), __LINE__);
83-
ORT_THROW(status);
82+
auto status1 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_error_handling", (void**)&compile_onnx_model_vitisai_ep_with_error_handling);
83+
auto status2 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options);
84+
if ((!status1.IsOK()) && (!status2.IsOK())) {
85+
::onnxruntime::LogRuntimeError(0, status2, __FILE__, static_cast<const char*>(__FUNCTION__), __LINE__);
86+
ORT_THROW(status2);
8487
}
8588
std::ignore = env.GetSymbolFromLibrary(handle_, "vaip_get_version",
8689
(void**)&vaip_get_version);
@@ -89,6 +92,14 @@ struct OrtVitisAIEpAPI {
8992
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_on_run_start", (void**)&vitisai_ep_on_run_start));
9093
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_set_ep_dynamic_options", (void**)&vitisai_ep_set_ep_dynamic_options));
9194
}
95+
void Clear() {
96+
if (handle_) {
97+
auto& env = Provider_GetHost()->Env__Default();
98+
auto status = env.UnloadDynamicLibrary(handle_);
99+
vai_assert(status.IsOK(), status.ErrorMessage());
100+
handle_ = nullptr;
101+
}
102+
}
92103

93104
private:
94105
void* handle_{};
@@ -109,10 +120,25 @@ void profiler_collect(
109120
}
110121
}
111122

123+
void change_status_with_error(void* status_ptr, int error_code, const char* error_msg) {
124+
auto status = reinterpret_cast<Status*>(status_ptr);
125+
*status = Status(onnxruntime::common::ONNXRUNTIME, error_code, error_msg);
126+
}
127+
112128
vaip_core::DllSafe<std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>> compile_onnx_model(
113-
const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) {
129+
const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options) {
114130
auto model_path = graph_viewer.ModelPath().string();
115-
return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options));
131+
if (s_library_vitisaiep.compile_onnx_model_vitisai_ep_with_error_handling) {
132+
Status status = Status::OK();
133+
auto status_ptr = reinterpret_cast<void*>(&status);
134+
auto ret = vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_vitisai_ep_with_error_handling(model_path, graph_viewer.GetGraph(), options, status_ptr, change_status_with_error));
135+
if (!status.IsOK()) {
136+
ORT_THROW(status);
137+
}
138+
return ret;
139+
} else {
140+
return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options));
141+
}
116142
}
117143

118144
std::optional<std::vector<Node*>> create_ep_context_nodes(
@@ -396,10 +422,12 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
396422
the_global_api.tensor_proto_get_shape_unsafe = vaip::tensor_proto_get_shape;
397423
the_global_api.tensor_proto_data_type = [](const ONNX_NAMESPACE::TensorProto& t) -> int { return t.data_type(); };
398424
the_global_api.tensor_proto_delete = [](ONNX_NAMESPACE::TensorProto* tp) { delete tp; };
425+
the_global_api.tensor_proto_new_i4 = vaip::tensor_proto_new_i4;
399426
the_global_api.tensor_proto_new_i8 = vaip::tensor_proto_new_i8;
400427
the_global_api.tensor_proto_new_i16 = vaip::tensor_proto_new_i16;
401428
the_global_api.tensor_proto_new_i32 = vaip::tensor_proto_new_i32;
402429
the_global_api.tensor_proto_new_i64 = vaip::tensor_proto_new_i64;
430+
the_global_api.tensor_proto_new_u4 = vaip::tensor_proto_new_u4;
403431
the_global_api.tensor_proto_new_u8 = vaip::tensor_proto_new_u8;
404432
the_global_api.tensor_proto_new_u16 = vaip::tensor_proto_new_u16;
405433
the_global_api.tensor_proto_new_u32 = vaip::tensor_proto_new_u32;
@@ -468,9 +496,21 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
468496
return vaip_core::DllSafe<std::string>(std::move(local_str));
469497
};
470498

499+
the_global_api.is_profiling_enabled = [](void* session_options) {
500+
auto options = reinterpret_cast<OrtSessionOptions*>(session_options);
501+
return options->GetEnableProfiling();
502+
};
503+
the_global_api.graph_remove_initialized_tensor = [](Graph& graph, const std::string& tensor_name) {
504+
graph.RemoveInitializedTensor(tensor_name);
505+
};
471506
if (!s_library_vitisaiep.vaip_get_version) {
472507
return reinterpret_cast<vaip_core::OrtApiForVaip*>(&(the_global_api.host_));
473508
} else {
474509
return &the_global_api;
475510
}
476511
}
512+
513+
void deinitialize_vitisai_ep() {
514+
s_library_vitisaiep.Clear();
515+
s_kernel_registry_vitisaiep.reset();
516+
}

onnxruntime/core/providers/vitisai/imp/tensor_proto.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ static ONNX_NAMESPACE::TensorProto* tensor_proto_new(const std::string& name, co
8787
return tensor_proto.release();
8888
}
8989

90+
ONNX_NAMESPACE::TensorProto* tensor_proto_new_i4(const std::string& name, const std::vector<int64_t>& shape,
91+
const std::vector<int8_t>& data) {
92+
return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT4,
93+
reinterpret_cast<const char*>(&data[0]), data.size() * sizeof(data[0]));
94+
}
95+
9096
ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector<int64_t>& shape,
9197
const std::vector<int8_t>& data) {
9298
return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT8,
@@ -108,6 +114,13 @@ ONNX_NAMESPACE::TensorProto* tensor_proto_new_i64(const std::string& name, const
108114
return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT64,
109115
reinterpret_cast<const char*>(&data[0]), data.size() * sizeof(data[0]));
110116
}
117+
118+
ONNX_NAMESPACE::TensorProto* tensor_proto_new_u4(const std::string& name, const std::vector<int64_t>& shape,
119+
const std::vector<uint8_t>& data) {
120+
return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_UINT4,
121+
reinterpret_cast<const char*>(&data[0]), data.size() * sizeof(data[0]));
122+
}
123+
111124
ONNX_NAMESPACE::TensorProto* tensor_proto_new_u8(const std::string& name, const std::vector<int64_t>& shape,
112125
const std::vector<uint8_t>& data) {
113126
return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_UINT8,

onnxruntime/core/providers/vitisai/imp/tensor_proto.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ namespace vaip {
99
gsl::span<const char> tensor_proto_as_raw(const onnxruntime::Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor);
1010
vaip_core::DllSafe<std::vector<int64_t>> tensor_proto_get_shape(const ONNX_NAMESPACE::TensorProto& tensor);
1111
const std::string& tensor_proto_get_name(const ONNX_NAMESPACE::TensorProto& tensor);
12+
ONNX_NAMESPACE::TensorProto* tensor_proto_new_i4(const std::string& name, const std::vector<int64_t>& shape,
13+
const std::vector<int8_t>& data);
14+
ONNX_NAMESPACE::TensorProto* tensor_proto_new_u4(const std::string& name, const std::vector<int64_t>& shape,
15+
const std::vector<uint8_t>& data);
1216
ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector<int64_t>& shape,
1317
const std::vector<int8_t>& data);
1418
ONNX_NAMESPACE::TensorProto* tensor_proto_new_u8(const std::string& name, const std::vector<int64_t>& shape,

onnxruntime/core/providers/vitisai/include/vaip/global_api.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "vaip/custom_op.h"
1212
#include <optional>
1313
void initialize_vitisai_ep();
14+
void deinitialize_vitisai_ep();
1415
vaip_core::DllSafe<std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options);
1516
std::shared_ptr<onnxruntime::KernelRegistry> get_kernel_registry_vitisaiep();
1617
const std::vector<OrtCustomOpDomain*>& get_domains_vitisaiep();

onnxruntime/core/providers/vitisai/include/vaip/my_ort.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,5 @@ using InitializedTensorSet =
122122
std::unordered_map<std::string, const TensorProto*>;
123123

124124
using ModelMetaData = std::unordered_map<std::string, std::string>;
125+
using error_report_func = void (*)(void*, int, const char*);
125126
} // namespace vaip_core

onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ struct OrtApi;
1313

1414
namespace vaip_core {
1515

16-
#define VAIP_ORT_API_MAJOR (12u)
16+
#define VAIP_ORT_API_MAJOR (13u)
1717
#define VAIP_ORT_API_MINOR (0u)
1818
#define VAIP_ORT_API_PATCH (0u)
1919
struct OrtApiForVaip {
@@ -235,6 +235,14 @@ struct OrtApiForVaip {
235235
DllSafe<std::string> (*model_proto_serialize_as_string)(ModelProto& model_proto); // [96]
236236
void (*model_proto_delete)(ModelProto* p); // [97]
237237
DllSafe<std::string> (*attr_proto_release_string)(AttributeProto* attr); // [98]
238+
bool (*is_profiling_enabled)(void* session_options); // [99] // [98]
239+
TensorProto* (*tensor_proto_new_i4)(const std::string& name,
240+
const std::vector<int64_t>& shape,
241+
const std::vector<int8_t>& data); // [100]
242+
TensorProto* (*tensor_proto_new_u4)(const std::string& name,
243+
const std::vector<int64_t>& shape,
244+
const std::vector<uint8_t>& data); // [101]
245+
void (*graph_remove_initialized_tensor)(Graph& graph, const std::string& tensor_name); // [102]
238246
};
239247

240248
#ifndef USE_VITISAI

onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ struct VitisAI_Provider : Provider {
5050
// Called right after loading the shared library, if this throws any errors Shutdown() will be called and the library unloaded
5151
void Initialize() override { initialize_vitisai_ep(); }
5252
// Called right before unloading the shared library
53-
void Shutdown() override {}
53+
void Shutdown() override { deinitialize_vitisai_ep(); }
5454
} g_provider;
5555

5656
} // namespace onnxruntime

onnxruntime/core/session/provider_bridge_ort.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,7 @@ struct ProviderHostImpl : ProviderHost {
720720

721721
// OrtSessionOptions (wrapped)
722722
const std::unordered_map<std::string, std::string>& SessionOptions__GetConfigOptionsMap(const OrtSessionOptions* p) override { return p->value.config_options.configurations; }
723+
bool SessionOptions__GetEnableProfiling(const OrtSessionOptions* p) override { return p->value.enable_profiling; };
723724
// ComputeCapability (wrapped)
724725
std::unique_ptr<ComputeCapability> ComputeCapability__construct(std::unique_ptr<IndexedSubGraph> t_sub_graph) override { return std::make_unique<ComputeCapability>(std::move(t_sub_graph)); }
725726
void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; }

0 commit comments

Comments
 (0)