@@ -232,7 +232,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_
232232 /* out*/ std::vector<int64_t >& dims,
233233 /* out*/ std::vector<std::string>& symbolic_dims);
234234static Ort::Status OrtValueInfoToProto (const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto);
235- static Ort::Status OrtOpAttrToProto (const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto);
235+ static Ort::Status OrtOpAttrToProto (const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto);
236236
237237Ort::Status OrtGraphToProto (const OrtGraph& ort_graph,
238238 onnx::GraphProto& graph_proto,
@@ -379,7 +379,7 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph,
379379 }
380380
381381 onnx::AttributeProto* attr_proto = node_proto->add_attribute ();
382- ORT_EP_UTILS_CXX_RETURN_IF_ERROR (OrtOpAttrToProto (*ort_attr, *attr_proto));
382+ ORT_EP_UTILS_CXX_RETURN_IF_ERROR (OrtOpAttrToProto (*ort_node, * ort_attr, *attr_proto));
383383 }
384384 }
385385
@@ -652,7 +652,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info,
652652 return Ort::Status{nullptr };
653653}
654654
655- static Ort::Status OrtOpAttrToProto (const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) {
655+ static Ort::Status OrtOpAttrToProto (const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) {
656656 const OrtApi& ort_api = Ort::GetApi ();
657657
658658 const char * attr_name = nullptr ;
@@ -758,6 +758,103 @@ static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributePr
758758
759759 break ;
760760 }
761+ case OrtOpAttrType::ORT_OP_ATTR_TENSOR: {
762+ attr_proto.set_type (onnx::AttributeProto_AttributeType_TENSOR);
763+
764+ onnx::TensorProto tensor_proto;
765+
766+ // TensorProto as an attribute value doesn't require a name.
767+
768+ OrtValue* ort_value = nullptr ;
769+ ORT_EP_UTILS_C_RETURN_IF_ERROR (ort_api.Node_GetTensorAttributeAsOrtValue (&ort_node, &ort_attr, &ort_value));
770+
771+ Ort::Value tensor (ort_value);
772+
773+ // Get tensor type and shape info
774+ Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo ();
775+
776+ // Get tensor type
777+ ONNXTensorElementDataType element_type = type_shape_info.GetElementType ();
778+
779+ size_t element_size = 0 ;
780+ switch (element_type) {
781+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
782+ tensor_proto.set_data_type (onnx::TensorProto_DataType_FLOAT);
783+ element_size = sizeof (float );
784+ break ;
785+ }
786+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: {
787+ tensor_proto.set_data_type (onnx::TensorProto_DataType_UINT8);
788+ element_size = sizeof (uint8_t );
789+ break ;
790+ }
791+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: {
792+ tensor_proto.set_data_type (onnx::TensorProto_DataType_INT8);
793+ element_size = sizeof (int8_t );
794+ break ;
795+ }
796+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: {
797+ tensor_proto.set_data_type (onnx::TensorProto_DataType_UINT16);
798+ element_size = sizeof (uint16_t );
799+ break ;
800+ }
801+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: {
802+ tensor_proto.set_data_type (onnx::TensorProto_DataType_INT16);
803+ element_size = sizeof (int16_t );
804+ break ;
805+ }
806+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
807+ tensor_proto.set_data_type (onnx::TensorProto_DataType_INT32);
808+ element_size = sizeof (int32_t );
809+ break ;
810+ }
811+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
812+ tensor_proto.set_data_type (onnx::TensorProto_DataType_INT64);
813+ element_size = sizeof (int64_t );
814+ break ;
815+ }
816+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
817+ tensor_proto.set_data_type (onnx::TensorProto_DataType_BOOL);
818+ element_size = sizeof (bool );
819+ break ;
820+ }
821+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
822+ tensor_proto.set_data_type (onnx::TensorProto_DataType_DOUBLE);
823+ element_size = sizeof (double );
824+ break ;
825+ }
826+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: {
827+ tensor_proto.set_data_type (onnx::TensorProto_DataType_UINT32);
828+ element_size = sizeof (uint32_t );
829+ break ;
830+ }
831+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: {
832+ tensor_proto.set_data_type (onnx::TensorProto_DataType_UINT64);
833+ element_size = sizeof (uint64_t );
834+ break ;
835+ }
836+ default : {
837+ std::string err_msg = " Unexpected ONNXTensorElementDataType with value " + std::to_string (static_cast <int >(element_type));
838+ return Ort::Status (err_msg.c_str (), ORT_FAIL);
839+ }
840+ }
841+
842+ auto shape = type_shape_info.GetShape ();
843+
844+ for (auto & dim : shape) {
845+ tensor_proto.add_dims (dim);
846+ }
847+
848+ size_t element_count = type_shape_info.GetElementCount ();
849+ size_t data_bytes = element_count * element_size;
850+ const void * data = tensor.GetTensorData <void >();
851+
852+ // Copy the Ortvalue to TensorProto as raw data
853+ tensor_proto.set_raw_data (data, data_bytes);
854+
855+ *(attr_proto.mutable_t ()) = std::move (tensor_proto);
856+ break ;
857+ }
761858 default : {
762859 std::string err_msg = " Unexpected OrtOpAttrType with value " + std::to_string (static_cast <int >(attr_type));
763860 return Ort::Status (err_msg.c_str (), ORT_FAIL);
0 commit comments