Skip to content

Commit 4a8a289

Browse files
authored
[EP ABI] Support for TENSOR type attribute (microsoft#25566)
### Description Add a new `Node_GetTensorAttributeAsOrtValue` API to support attribute that is a `TENSOR` type. This API returns a const OrtValue that represents the TensorProto in the `TENSOR `attribute.
1 parent c29737d commit 4a8a289

File tree

9 files changed

+276
-3
lines changed

9 files changed

+276
-3
lines changed

include/onnxruntime/core/providers/utils/ort_graph_to_proto.h

Lines changed: 100 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_
232232
/*out*/ std::vector<int64_t>& dims,
233233
/*out*/ std::vector<std::string>& symbolic_dims);
234234
static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto);
235-
static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto);
235+
static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto);
236236

237237
Ort::Status OrtGraphToProto(const OrtGraph& ort_graph,
238238
onnx::GraphProto& graph_proto,
@@ -379,7 +379,7 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph,
379379
}
380380

381381
onnx::AttributeProto* attr_proto = node_proto->add_attribute();
382-
ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto));
382+
ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_node, *ort_attr, *attr_proto));
383383
}
384384
}
385385

@@ -652,7 +652,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info,
652652
return Ort::Status{nullptr};
653653
}
654654

655-
static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) {
655+
static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) {
656656
const OrtApi& ort_api = Ort::GetApi();
657657

658658
const char* attr_name = nullptr;
@@ -758,6 +758,103 @@ static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributePr
758758

759759
break;
760760
}
761+
case OrtOpAttrType::ORT_OP_ATTR_TENSOR: {
762+
attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR);
763+
764+
onnx::TensorProto tensor_proto;
765+
766+
// TensorProto as an attribute value doesn't require a name.
767+
768+
OrtValue* ort_value = nullptr;
769+
ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value));
770+
771+
Ort::Value tensor(ort_value);
772+
773+
// Get tensor type and shape info
774+
Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo();
775+
776+
// Get tensor type
777+
ONNXTensorElementDataType element_type = type_shape_info.GetElementType();
778+
779+
size_t element_size = 0;
780+
switch (element_type) {
781+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
782+
tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT);
783+
element_size = sizeof(float);
784+
break;
785+
}
786+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: {
787+
tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8);
788+
element_size = sizeof(uint8_t);
789+
break;
790+
}
791+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: {
792+
tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8);
793+
element_size = sizeof(int8_t);
794+
break;
795+
}
796+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: {
797+
tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16);
798+
element_size = sizeof(uint16_t);
799+
break;
800+
}
801+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: {
802+
tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16);
803+
element_size = sizeof(int16_t);
804+
break;
805+
}
806+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
807+
tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32);
808+
element_size = sizeof(int32_t);
809+
break;
810+
}
811+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
812+
tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64);
813+
element_size = sizeof(int64_t);
814+
break;
815+
}
816+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
817+
tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL);
818+
element_size = sizeof(bool);
819+
break;
820+
}
821+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
822+
tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE);
823+
element_size = sizeof(double);
824+
break;
825+
}
826+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: {
827+
tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32);
828+
element_size = sizeof(uint32_t);
829+
break;
830+
}
831+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: {
832+
tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64);
833+
element_size = sizeof(uint64_t);
834+
break;
835+
}
836+
default: {
837+
std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast<int>(element_type));
838+
return Ort::Status(err_msg.c_str(), ORT_FAIL);
839+
}
840+
}
841+
842+
auto shape = type_shape_info.GetShape();
843+
844+
for (auto& dim : shape) {
845+
tensor_proto.add_dims(dim);
846+
}
847+
848+
size_t element_count = type_shape_info.GetElementCount();
849+
size_t data_bytes = element_count * element_size;
850+
const void* data = tensor.GetTensorData<void>();
851+
852+
// Copy the Ortvalue to TensorProto as raw data
853+
tensor_proto.set_raw_data(data, data_bytes);
854+
855+
*(attr_proto.mutable_t()) = std::move(tensor_proto);
856+
break;
857+
}
761858
default: {
762859
std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast<int>(attr_type));
763860
return Ort::Status(err_msg.c_str(), ORT_FAIL);

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ typedef enum OrtOpAttrType {
276276
ORT_OP_ATTR_STRING,
277277
ORT_OP_ATTR_STRINGS,
278278
ORT_OP_ATTR_GRAPH,
279+
ORT_OP_ATTR_TENSOR,
279280
} OrtOpAttrType;
280281

281282
//! @}
@@ -6065,6 +6066,20 @@ struct OrtApi {
60656066
ORT_API2_STATUS(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name,
60666067
_Outptr_result_maybenull_ const OrtOpAttr** attribute);
60676068

6069+
/** \brief Get the OrtNode's 'TENSOR' attribute as an OrtValue.
6070+
*
6071+
* \param[in] node The OrtNode instance.
6072+
* \param[in] attribute The OrtOpAttr instance.
6073+
* \param[out] attr_tensor If successful, contains the 'TENSOR' attribute as a newly created OrtValue.
6074+
Must be freed with OrtApi::ReleaseValue.
6075+
*
6076+
* \snippet{doc} snippets.dox OrtStatus Return Value
6077+
*
6078+
* \since Version 1.23.
6079+
*/
6080+
ORT_API2_STATUS(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute,
6081+
_Outptr_result_maybenull_ OrtValue** attr_tensor);
6082+
60686083
/** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr.
60696084
*
60706085
* \param[in] attribute The OrtOpAttr instance.

onnxruntime/core/graph/abi_graph_types.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,16 @@ struct OrtNode {
251251
/// <returns>A status indicating success or an error.</returns>
252252
virtual onnxruntime::Status GetAttributes(gsl::span<const OrtOpAttr*> attrs) const = 0;
253253

254+
/// <summary>
255+
/// Gets the node's 'TENSOR' attribute as an OrtValue.
256+
/// </summary>
257+
/// <param name="attr">Node's 'TENSOR' attribute.</param>
258+
/// <param name="value">Output parameter is set to a newly created OrtValue containing the 'TENSOR' attribute value,
259+
/// only if the attribute is of type 'TENSOR'</param>
260+
/// <returns>A status indicating success or an error.</returns>
261+
virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attr,
262+
OrtValue*& value) const = 0;
263+
254264
/// <summary>
255265
/// Gets the number of node subgraphs.
256266
/// </summary>

onnxruntime/core/graph/ep_api_types.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,32 @@ Status EpNode::GetAttributes(gsl::span<const OrtOpAttr*> dst) const {
248248
return Status::OK();
249249
}
250250

251+
Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue*& result) const {
252+
const auto* attr_proto = reinterpret_cast<const ONNX_NAMESPACE::AttributeProto*>(attribute);
253+
254+
if (attr_proto->type() != onnx::AttributeProto::TENSOR) {
255+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute");
256+
}
257+
258+
const auto& graph_viewer = ep_graph_->GetGraphViewer();
259+
const auto& tensor_proto = attr_proto->t();
260+
261+
// Check that TensorProto is valid.
262+
ORT_ENFORCE(utils::HasDataType(tensor_proto), "Tensor proto doesn't have data type.");
263+
ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type()), "Tensor proto has invalid data type.");
264+
ORT_ENFORCE(!utils::HasExternalData(tensor_proto),
265+
"Tensor proto with external data for value attribute is not supported.");
266+
267+
// Initialize OrtValue for tensor attribute.
268+
auto tensor_attribute_value = std::make_unique<OrtValue>();
269+
AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance();
270+
ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto,
271+
tensor_attribute_allocator, *tensor_attribute_value));
272+
273+
result = tensor_attribute_value.release();
274+
return Status::OK();
275+
}
276+
251277
Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const {
252278
num_subgraphs = subgraphs_.size();
253279
return Status::OK();

onnxruntime/core/graph/ep_api_types.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,9 @@ struct EpNode : public OrtNode {
183183
// Gets the node's attributes.
184184
Status GetAttributes(gsl::span<const OrtOpAttr*> attrs) const override;
185185

186+
Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute,
187+
OrtValue*& attr_tensor) const override;
188+
186189
// Gets the number of subgraphs contained by this node.
187190
Status GetNumSubgraphs(size_t& num_subgraphs) const override;
188191

onnxruntime/core/graph/model_editor_api_types.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ struct ModelEditorNode : public OrtNode {
137137
"OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode");
138138
}
139139

140+
Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, OrtValue*& /*attr_tensor*/) const override {
141+
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
142+
"OrtModelEditorApi does not support getting 'TENSOR' attribute for OrtNode");
143+
}
144+
140145
Status GetNumSubgraphs(size_t& /*num_subgraphs*/) const override {
141146
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
142147
"OrtModelEditorApi does not support getting the subgraphs for OrtNode");

onnxruntime/core/session/onnxruntime_c_api.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3018,6 +3018,20 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node,
30183018
API_IMPL_END
30193019
}
30203020

3021+
ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) {
3022+
API_IMPL_BEGIN
3023+
if (attr_tensor == nullptr) {
3024+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null");
3025+
}
3026+
if (attribute == nullptr) {
3027+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null");
3028+
}
3029+
3030+
ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor));
3031+
return nullptr;
3032+
API_IMPL_END
3033+
}
3034+
30213035
ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type) {
30223036
API_IMPL_BEGIN
30233037
const auto attr = attribute->attr_proto;
@@ -3055,6 +3069,10 @@ ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _O
30553069
*type = OrtOpAttrType::ORT_OP_ATTR_GRAPH;
30563070
break;
30573071
}
3072+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR: {
3073+
*type = OrtOpAttrType::ORT_OP_ATTR_TENSOR;
3074+
break;
3075+
}
30583076
default:
30593077
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type.");
30603078
}
@@ -4037,6 +4055,7 @@ static constexpr OrtApi ort_api_1_to_23 = {
40374055
&OrtApis::Node_GetNumAttributes,
40384056
&OrtApis::Node_GetAttributes,
40394057
&OrtApis::Node_GetAttributeByName,
4058+
&OrtApis::Node_GetTensorAttributeAsOrtValue,
40404059
&OrtApis::OpAttr_GetType,
40414060
&OrtApis::OpAttr_GetName,
40424061
&OrtApis::Node_GetNumSubgraphs,

onnxruntime/core/session/ort_apis.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,8 @@ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node,
679679
_Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes);
680680
ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name,
681681
_Outptr_result_maybenull_ const OrtOpAttr** attribute);
682+
ORT_API_STATUS_IMPL(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute,
683+
_Outptr_result_maybenull_ OrtValue** attr_tensor);
682684
ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type);
683685
ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name);
684686
ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs);

onnxruntime/test/ep_graph/test_ep_graph.cc

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,39 @@ static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector<float>& outpu
306306
output_data.assign(output_values, output_values + num_output_elems);
307307
}
308308

309+
static void RunConstantOfShapeModel(const ORTCHAR_T* model_path, std::vector<float>& output_data) {
310+
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
311+
Ort::SessionOptions sess_options;
312+
Ort::Session session(*ort_env, model_path, sess_options);
313+
314+
std::vector<int64_t> input_shape = {3};
315+
std::vector<int64_t> input_data = {2, 3, 4};
316+
std::vector<Ort::Value> ort_inputs;
317+
std::vector<const char*> ort_input_names;
318+
319+
// Add 'x'
320+
ort_inputs.emplace_back(Ort::Value::CreateTensor<int64_t>(
321+
memory_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size()));
322+
ort_input_names.push_back("x");
323+
324+
// Run session and get outputs
325+
std::array<const char*, 1> output_names{"y"};
326+
std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(),
327+
ort_inputs.size(), output_names.data(), output_names.size());
328+
329+
// Check output type and number of elements.
330+
Ort::Value& ort_output = ort_outputs[0];
331+
auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo();
332+
size_t num_output_elems = output_type_shape.GetElementCount();
333+
334+
ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
335+
ASSERT_EQ(num_output_elems, 24);
336+
337+
// Return output data.
338+
const float* output_values = ort_output.GetTensorData<float>();
339+
output_data.assign(output_values, output_values + num_output_elems);
340+
}
341+
309342
// Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file.
310343
// Checks that the outputs of the serialized and original models are identical.
311344
TEST(EpGraphTest, SerializeToProto_Mnist) {
@@ -436,6 +469,65 @@ TEST(EpGraphTest, SerializeToProto_ExternalInitializersInMemory) {
436469
}
437470
}
438471

472+
// Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file.
473+
// Checks that the outputs of the serialized and original models are identical.
474+
TEST(EpGraphTest, SerializeToProto_ConstantOfShape) {
475+
const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/ort_minimal_test_models/tensor_attribute.onnx");
476+
const ORTCHAR_T* serialized_model_path = ORT_TSTR("constant_of_shape.onnx");
477+
std::filesystem::remove(serialized_model_path);
478+
479+
{
480+
auto test_graph = TestGraph::Load(original_model_path);
481+
ASSERT_NE(test_graph, nullptr) << "Failed to load test model";
482+
483+
// Serialize OrtGraph to GraphProto. Save initializers to external file.
484+
std::string ext_ini_file_path = "constant_of_shape_serialized.bin";
485+
std::filesystem::remove(ext_ini_file_path);
486+
std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary);
487+
auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info,
488+
const void* data, size_t bytes,
489+
bool& is_external, std::string& location,
490+
int64_t& offset) -> Ort::Status {
491+
// OrtValueInfo* could be used to query initializer's name, type, shape,
492+
// node consumers, etc.
493+
static_cast<void>(value_info);
494+
495+
if (bytes <= 127) {
496+
is_external = false; // Keep small initializers stored inside the TensorProto.
497+
return Ort::Status{nullptr};
498+
}
499+
500+
offset = ext_ini_ofs.tellp();
501+
location = ext_ini_file_path;
502+
ext_ini_ofs.write(static_cast<const char*>(data), bytes);
503+
ext_ini_ofs.flush();
504+
is_external = true; // True if is external initializer.
505+
506+
return Ort::Status{nullptr};
507+
};
508+
509+
ONNX_NAMESPACE::ModelProto model_proto;
510+
ASSERT_CXX_ORTSTATUS_OK(OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto,
511+
handle_initializer_data));
512+
513+
std::ofstream ofs(serialized_model_path, std::ios::binary);
514+
model_proto.SerializeToOstream(&ofs);
515+
ofs.flush();
516+
517+
ASSERT_TRUE(std::filesystem::exists(serialized_model_path));
518+
ASSERT_TRUE(std::filesystem::exists(ext_ini_file_path));
519+
}
520+
521+
// Compare output of the original and serialized models. Should be identical.
522+
std::vector<float> output_original;
523+
std::vector<float> output_serialized;
524+
525+
RunConstantOfShapeModel(original_model_path, output_original);
526+
RunConstantOfShapeModel(serialized_model_path, output_serialized);
527+
528+
EXPECT_EQ(output_serialized, output_original);
529+
}
530+
439531
static void Run3LayerModel(const ORTCHAR_T* model_path, bool input_cond, std::vector<float>& output_data) {
440532
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
441533
Ort::SessionOptions sess_options;
@@ -978,6 +1070,10 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_
9781070
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_GRAPH);
9791071
break;
9801072
}
1073+
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR: {
1074+
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_TENSOR);
1075+
break;
1076+
}
9811077
default:
9821078
// The unsupported type should be skipped by 'continue' above. It's unexpected so we force test to fail.
9831079
ASSERT_ORTSTATUS_OK(ort_api.CreateStatus(ORT_FAIL, "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit."));

0 commit comments

Comments
 (0)