Skip to content

Commit 7e93cd7

Browse files
mingyueliuhliumingyue
andauthored
[VitisAI] Align TensorProto_DataType with onnx1.16 (microsoft#21067)
### Description Vitis AI EP synchronously supports the TensorProto data types supported by ONNX 1.16. Add error message show when graph resolve fail for troubleshooting. ### Motivation and Context ONNX 1.15 & 1.16 add support some new TensorProto DataType , such as - FLOAT8E4M3FN - FLOAT8E4M3FNUZ - FLOAT8E5M2 - FLOAT8E5M2FNUZ - UINT4 - INT4 --------- Co-authored-by: liumingyue <mingyue@xilinx.com>
1 parent 6baaaf5 commit 7e93cd7

File tree

4 files changed

+35
-2
lines changed

4 files changed

+35
-2
lines changed

onnxruntime/core/providers/vitisai/imp/global_api.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,9 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
270270
graph.SetGraphResolveNeeded();
271271
}
272272
auto status = graph.Resolve();
273+
if (!status.IsOK()) {
274+
std::cerr << "graph resolve error:" << status.ErrorMessage() << std::endl;
275+
}
273276
return status.Code();
274277
};
275278
the_global_api.graph_get_consumer_nodes_unsafe = [](const Graph& graph, const std::string& node_arg_name) -> auto {

onnxruntime/core/providers/vitisai/include/vaip/my_ort.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,13 @@ enum TensorProto_DataType : int {
3838
TensorProto_DataType_UINT64 = 13,
3939
TensorProto_DataType_COMPLEX64 = 14,
4040
TensorProto_DataType_COMPLEX128 = 15,
41-
TensorProto_DataType_BFLOAT16 = 16
41+
TensorProto_DataType_BFLOAT16 = 16,
42+
TensorProto_DataType_FLOAT8E4M3FN = 17,
43+
TensorProto_DataType_FLOAT8E4M3FNUZ = 18,
44+
TensorProto_DataType_FLOAT8E5M2 = 19,
45+
TensorProto_DataType_FLOAT8E5M2FNUZ = 20,
46+
TensorProto_DataType_UINT4 = 21,
47+
TensorProto_DataType_INT4 = 22
4248
};
4349
enum AttributeProto_AttributeType : int {
4450
AttributeProto_AttributeType_UNDEFINED = 0,

onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ struct OrtApi;
1313
namespace vaip_core {
1414

1515
#define VAIP_ORT_API_MAJOR (3u)
16-
#define VAIP_ORT_API_MINOR (0u)
16+
#define VAIP_ORT_API_MINOR (1u)
1717
#define VAIP_ORT_API_PATCH (0u)
1818
struct OrtApiForVaip {
1919
uint32_t magic; // 'VAIP' or something else to make sure the following field

onnxruntime/core/session/provider_bridge_ort.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,8 +613,12 @@ struct ProviderHostImpl : ProviderHost {
613613
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT8;
614614
} else if (data_type->s() == "int32") {
615615
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT32;
616+
} else if (data_type->s() == "uint32") {
617+
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT32;
616618
} else if (data_type->s() == "int64") {
617619
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT64;
620+
} else if (data_type->s() == "uint64") {
621+
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT64;
618622
} else if (data_type->s() == "int1") {
619623
elemType = ONNX_NAMESPACE::TensorProto_DataType_BOOL;
620624
} else if (data_type->s() == "bfloat16") {
@@ -625,6 +629,26 @@ struct ProviderHostImpl : ProviderHost {
625629
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT16;
626630
} else if (data_type->s() == "int16") {
627631
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT16;
632+
} else if (data_type->s() == "double") {
633+
elemType = ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
634+
} else if (data_type->s() == "string") {
635+
elemType = ONNX_NAMESPACE::TensorProto_DataType_STRING;
636+
} else if (data_type->s() == "complex64") {
637+
elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64;
638+
} else if (data_type->s() == "complex128") {
639+
elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128;
640+
} else if (data_type->s() == "float8e4m3fn") {
641+
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN;
642+
} else if (data_type->s() == "float8e4m3fnuz") {
643+
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ;
644+
} else if (data_type->s() == "float8e5m2") {
645+
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
646+
} else if (data_type->s() == "float8e5m2funz") {
647+
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ;
648+
} else if (data_type->s() == "uint4") {
649+
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT4;
650+
} else if (data_type->s() == "int4") {
651+
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT4;
628652
} else {
629653
return;
630654
}

0 commit comments

Comments
 (0)