@@ -306,6 +306,39 @@ static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector<float>& outpu
306306 output_data.assign (output_values, output_values + num_output_elems);
307307}
308308
309+ static void RunConstantOfShapeModel (const ORTCHAR_T* model_path, std::vector<float >& output_data) {
310+ auto memory_info = Ort::MemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeCPU);
311+ Ort::SessionOptions sess_options;
312+ Ort::Session session (*ort_env, model_path, sess_options);
313+
314+ std::vector<int64_t > input_shape = {3 };
315+ std::vector<int64_t > input_data = {2 , 3 , 4 };
316+ std::vector<Ort::Value> ort_inputs;
317+ std::vector<const char *> ort_input_names;
318+
319+ // Add 'x'
320+ ort_inputs.emplace_back (Ort::Value::CreateTensor<int64_t >(
321+ memory_info, input_data.data (), input_data.size (), input_shape.data (), input_shape.size ()));
322+ ort_input_names.push_back (" x" );
323+
324+ // Run session and get outputs
325+ std::array<const char *, 1 > output_names{" y" };
326+ std::vector<Ort::Value> ort_outputs = session.Run (Ort::RunOptions{nullptr }, ort_input_names.data (), ort_inputs.data (),
327+ ort_inputs.size (), output_names.data (), output_names.size ());
328+
329+ // Check output type and number of elements.
330+ Ort::Value& ort_output = ort_outputs[0 ];
331+ auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo ();
332+ size_t num_output_elems = output_type_shape.GetElementCount ();
333+
334+ ASSERT_EQ (output_type_shape.GetElementType (), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
335+ ASSERT_EQ (num_output_elems, 24 );
336+
337+ // Return output data.
338+ const float * output_values = ort_output.GetTensorData <float >();
339+ output_data.assign (output_values, output_values + num_output_elems);
340+ }
341+
309342// Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file.
310343// Checks that the outputs of the serialized and original models are identical.
311344TEST (EpGraphTest, SerializeToProto_Mnist) {
@@ -436,6 +469,65 @@ TEST(EpGraphTest, SerializeToProto_ExternalInitializersInMemory) {
436469 }
437470}
438471
472+ // Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file.
473+ // Checks that the outputs of the serialized and original models are identical.
474+ TEST (EpGraphTest, SerializeToProto_ConstantOfShape) {
475+ const ORTCHAR_T* original_model_path = ORT_TSTR (" testdata/ort_minimal_test_models/tensor_attribute.onnx" );
476+ const ORTCHAR_T* serialized_model_path = ORT_TSTR (" constant_of_shape.onnx" );
477+ std::filesystem::remove (serialized_model_path);
478+
479+ {
480+ auto test_graph = TestGraph::Load (original_model_path);
481+ ASSERT_NE (test_graph, nullptr ) << " Failed to load test model" ;
482+
483+ // Serialize OrtGraph to GraphProto. Save initializers to external file.
484+ std::string ext_ini_file_path = " constant_of_shape_serialized.bin" ;
485+ std::filesystem::remove (ext_ini_file_path);
486+ std::ofstream ext_ini_ofs (ext_ini_file_path, std::ios::binary);
487+ auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info,
488+ const void * data, size_t bytes,
489+ bool & is_external, std::string& location,
490+ int64_t & offset) -> Ort::Status {
491+ // OrtValueInfo* could be used to query initializer's name, type, shape,
492+ // node consumers, etc.
493+ static_cast <void >(value_info);
494+
495+ if (bytes <= 127 ) {
496+ is_external = false ; // Keep small initializers stored inside the TensorProto.
497+ return Ort::Status{nullptr };
498+ }
499+
500+ offset = ext_ini_ofs.tellp ();
501+ location = ext_ini_file_path;
502+ ext_ini_ofs.write (static_cast <const char *>(data), bytes);
503+ ext_ini_ofs.flush ();
504+ is_external = true ; // True if is external initializer.
505+
506+ return Ort::Status{nullptr };
507+ };
508+
509+ ONNX_NAMESPACE::ModelProto model_proto;
510+ ASSERT_CXX_ORTSTATUS_OK (OrtEpUtils::OrtGraphToProto (test_graph->GetOrtGraph (), model_proto,
511+ handle_initializer_data));
512+
513+ std::ofstream ofs (serialized_model_path, std::ios::binary);
514+ model_proto.SerializeToOstream (&ofs);
515+ ofs.flush ();
516+
517+ ASSERT_TRUE (std::filesystem::exists (serialized_model_path));
518+ ASSERT_TRUE (std::filesystem::exists (ext_ini_file_path));
519+ }
520+
521+ // Compare output of the original and serialized models. Should be identical.
522+ std::vector<float > output_original;
523+ std::vector<float > output_serialized;
524+
525+ RunConstantOfShapeModel (original_model_path, output_original);
526+ RunConstantOfShapeModel (serialized_model_path, output_serialized);
527+
528+ EXPECT_EQ (output_serialized, output_original);
529+ }
530+
439531static void Run3LayerModel (const ORTCHAR_T* model_path, bool input_cond, std::vector<float >& output_data) {
440532 auto memory_info = Ort::MemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeCPU);
441533 Ort::SessionOptions sess_options;
@@ -978,6 +1070,10 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_
9781070 ASSERT_EQ (api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_GRAPH);
9791071 break ;
9801072 }
1073+ case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR: {
1074+ ASSERT_EQ (api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_TENSOR);
1075+ break ;
1076+ }
9811077 default :
9821078 // The unsupported type should be skipped by 'continue' above. It's unexpected so we force test to fail.
9831079 ASSERT_ORTSTATUS_OK (ort_api.CreateStatus (ORT_FAIL, " The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit." ));
0 commit comments