Skip to content

Commit b3a074d

Browse files
chilo-msJaswanth51
authored andcommitted
[EP ABI] Add OpAttr_GetTensorAttributeAsOrtValue and replace the existing Node_GetTensorAttributeAsOrtValue (microsoft#25886)
### Description Replace `Node_GetTensorAttributeAsOrtValue` with `OpAttr_GetTensorAttributeAsOrtValue`. Change the API signature to make it one of the `OpAttr` interfaces instead of the `OrtNode` interface. The original API was added [here](microsoft#25566).
1 parent 47a5306 commit b3a074d

File tree

8 files changed

+41
-54
lines changed

8 files changed

+41
-54
lines changed

include/onnxruntime/core/providers/utils/ort_graph_to_proto.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_
232232
/*out*/ std::vector<int64_t>& dims,
233233
/*out*/ std::vector<std::string>& symbolic_dims);
234234
static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto);
235-
static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto);
235+
static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto);
236236

237237
Ort::Status OrtGraphToProto(const OrtGraph& ort_graph,
238238
onnx::GraphProto& graph_proto,
@@ -379,7 +379,7 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph,
379379
}
380380

381381
onnx::AttributeProto* attr_proto = node_proto->add_attribute();
382-
ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_node, *ort_attr, *attr_proto));
382+
ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto));
383383
}
384384
}
385385

@@ -652,7 +652,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info,
652652
return Ort::Status{nullptr};
653653
}
654654

655-
static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) {
655+
static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) {
656656
const OrtApi& ort_api = Ort::GetApi();
657657

658658
const char* attr_name = nullptr;
@@ -766,7 +766,7 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or
766766
// TensorProto as an attribute value doesn't require a name.
767767

768768
OrtValue* ort_value = nullptr;
769-
ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value));
769+
ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetTensorAttributeAsOrtValue(&ort_attr, &ort_value));
770770

771771
Ort::Value tensor(ort_value);
772772

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6079,7 +6079,6 @@ struct OrtApi {
60796079

60806080
/** \brief Get the OrtNode's 'TENSOR' attribute as an OrtValue.
60816081
*
6082-
* \param[in] node The OrtNode instance.
60836082
* \param[in] attribute The OrtOpAttr instance.
60846083
* \param[out] attr_tensor If successful, contains the 'TENSOR' attribute as a newly created OrtValue.
60856084
Must be freed with OrtApi::ReleaseValue.
@@ -6088,7 +6087,7 @@ struct OrtApi {
60886087
*
60896088
* \since Version 1.23.
60906089
*/
6091-
ORT_API2_STATUS(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute,
6090+
ORT_API2_STATUS(OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute,
60926091
_Outptr_result_maybenull_ OrtValue** attr_tensor);
60936092

60946093
/** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr.

onnxruntime/core/graph/abi_graph_types.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -252,16 +252,6 @@ struct OrtNode {
252252
/// <returns>A status indicating success or an error.</returns>
253253
virtual onnxruntime::Status GetAttributes(gsl::span<const OrtOpAttr*> attrs) const = 0;
254254

255-
/// <summary>
256-
/// Gets the node's 'TENSOR' attribute as an OrtValue.
257-
/// </summary>
258-
/// <param name="attr">Node's 'TENSOR' attribute.</param>
259-
/// <param name="value">Output parameter is set to a newly created OrtValue containing the 'TENSOR' attribute value,
260-
/// only if the attribute is of type 'TENSOR'</param>
261-
/// <returns>A status indicating success or an error.</returns>
262-
virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attr,
263-
OrtValue*& value) const = 0;
264-
265255
/// <summary>
266256
/// Gets the number of node subgraphs.
267257
/// </summary>

onnxruntime/core/graph/ep_api_types.cc

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -249,32 +249,6 @@ Status EpNode::GetAttributes(gsl::span<const OrtOpAttr*> dst) const {
249249
return Status::OK();
250250
}
251251

252-
Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue*& result) const {
253-
const auto* attr_proto = reinterpret_cast<const ONNX_NAMESPACE::AttributeProto*>(attribute);
254-
255-
if (attr_proto->type() != onnx::AttributeProto::TENSOR) {
256-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute");
257-
}
258-
259-
const auto& graph_viewer = ep_graph_->GetGraphViewer();
260-
const auto& tensor_proto = attr_proto->t();
261-
262-
// Check that TensorProto is valid.
263-
ORT_ENFORCE(utils::HasDataType(tensor_proto), "Tensor proto doesn't have data type.");
264-
ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type()), "Tensor proto has invalid data type.");
265-
ORT_ENFORCE(!utils::HasExternalData(tensor_proto),
266-
"Tensor proto with external data for value attribute is not supported.");
267-
268-
// Initialize OrtValue for tensor attribute.
269-
auto tensor_attribute_value = std::make_unique<OrtValue>();
270-
AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance();
271-
ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto,
272-
tensor_attribute_allocator, *tensor_attribute_value));
273-
274-
result = tensor_attribute_value.release();
275-
return Status::OK();
276-
}
277-
278252
Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const {
279253
num_subgraphs = subgraphs_.size();
280254
return Status::OK();

onnxruntime/core/graph/ep_api_types.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,6 @@ struct EpNode : public OrtNode {
183183
// Gets the node's attributes.
184184
Status GetAttributes(gsl::span<const OrtOpAttr*> attrs) const override;
185185

186-
Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute,
187-
OrtValue*& attr_tensor) const override;
188-
189186
// Gets the number of subgraphs contained by this node.
190187
Status GetNumSubgraphs(size_t& num_subgraphs) const override;
191188

onnxruntime/core/graph/model_editor_api_types.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,6 @@ struct ModelEditorNode : public OrtNode {
138138
"OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode");
139139
}
140140

141-
Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, OrtValue*& /*attr_tensor*/) const override {
142-
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
143-
"OrtModelEditorApi does not support getting 'TENSOR' attribute for OrtNode");
144-
}
145-
146141
Status GetNumSubgraphs(size_t& /*num_subgraphs*/) const override {
147142
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
148143
"OrtModelEditorApi does not support getting the subgraphs for OrtNode");

onnxruntime/core/session/onnxruntime_c_api.cc

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3036,7 +3036,7 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node,
30363036
API_IMPL_END
30373037
}
30383038

3039-
ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) {
3039+
ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) {
30403040
API_IMPL_BEGIN
30413041
if (attr_tensor == nullptr) {
30423042
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null");
@@ -3045,7 +3045,39 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNo
30453045
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null");
30463046
}
30473047

3048-
ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor));
3048+
const auto* attr_proto = reinterpret_cast<const ONNX_NAMESPACE::AttributeProto*>(attribute);
3049+
3050+
if (attr_proto->type() != onnx::AttributeProto::TENSOR) {
3051+
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "This OrtOpAttr instance is not a 'TENSOR' attribute");
3052+
}
3053+
3054+
const auto& tensor_proto = attr_proto->t();
3055+
3056+
// Check that TensorProto is valid.
3057+
if (!utils::HasDataType(tensor_proto)) {
3058+
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Tensor proto doesn't have data type.");
3059+
}
3060+
3061+
if (!ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type())) {
3062+
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Tensor proto has invalid data type.");
3063+
}
3064+
3065+
if (utils::HasExternalData(tensor_proto)) {
3066+
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT,
3067+
"Tensor proto with external data for value attribute is not supported.");
3068+
}
3069+
3070+
// Initialize OrtValue for tensor attribute.
3071+
auto tensor_attribute_value = std::make_unique<OrtValue>();
3072+
AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance();
3073+
// The tensor in the 'Tensor' attribute's TensorProto is stored inline, not in an external file.
3074+
// Therefore, the 'model_path' passed to TensorProtoToOrtValue() may be an empty path.
3075+
std::filesystem::path model_path;
3076+
ORT_API_RETURN_IF_STATUS_NOT_OK(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto,
3077+
tensor_attribute_allocator, *tensor_attribute_value));
3078+
3079+
*attr_tensor = tensor_attribute_value.release();
3080+
30493081
return nullptr;
30503082
API_IMPL_END
30513083
}
@@ -4134,7 +4166,7 @@ static constexpr OrtApi ort_api_1_to_23 = {
41344166
&OrtApis::Node_GetNumAttributes,
41354167
&OrtApis::Node_GetAttributes,
41364168
&OrtApis::Node_GetAttributeByName,
4137-
&OrtApis::Node_GetTensorAttributeAsOrtValue,
4169+
&OrtApis::OpAttr_GetTensorAttributeAsOrtValue,
41384170
&OrtApis::OpAttr_GetType,
41394171
&OrtApis::OpAttr_GetName,
41404172
&OrtApis::Node_GetNumSubgraphs,

onnxruntime/core/session/ort_apis.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node,
687687
_Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes);
688688
ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name,
689689
_Outptr_result_maybenull_ const OrtOpAttr** attribute);
690-
ORT_API_STATUS_IMPL(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute,
690+
ORT_API_STATUS_IMPL(OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute,
691691
_Outptr_result_maybenull_ OrtValue** attr_tensor);
692692
ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type);
693693
ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name);

0 commit comments

Comments
 (0)