@@ -666,12 +666,16 @@ void Node::ToProto(NodeProto& proto, bool update_subgraphs) const {
666666
667667 // Set attributes.
668668 proto.clear_attribute ();
669- for (const auto & attribute : attributes_) {
669+ for (const auto & [name, attribute] : attributes_) {
670670 const gsl::not_null<AttributeProto*> attr{proto.add_attribute ()};
671- *attr = attribute.second ; // copy
672- if (update_subgraphs && attr->has_g ()) {
671+ *attr = attribute; // copy
672+ if (update_subgraphs && utils::HasGraph (*attr)) {
673+ auto find_hit = attr_to_subgraph_map_.find (name);
674+ // Force ToGraphProto() const to be called so
675+ // that any in-memory TensorProto initializers go back to being inlined
676+ const Graph& subgraph = *find_hit->second ;
673677 attr->clear_g ();
674- *attr->mutable_g () = attr_to_subgraph_map_. find (attribute. first )-> second -> ToGraphProto ();
678+ *attr->mutable_g () = subgraph. ToGraphProto ();
675679 }
676680 }
677681
@@ -3381,7 +3385,12 @@ Status Graph::Resolve(const ResolveOptions& options) {
33813385
33823386 return Status::OK (); };
33833387
3384- ORT_RETURN_IF_ERROR (ForThisAndAllSubgraphs (all_subgraphs, finalize_func));
3388+ return ForThisAndAllSubgraphs (all_subgraphs, finalize_func);
3389+ }
3390+
3391+ Status Graph::ConvertInitializersIntoOrtValues () {
3392+ std::vector<Graph*> all_subgraphs;
3393+ FindAllSubgraphs (all_subgraphs);
33853394
33863395 auto put_weights_maybe_in_memory_func = [&](Graph& graph) -> Status {
33873396 // if we have any initializers that are not in memory, put them there.
@@ -4308,11 +4317,47 @@ Status InlineOrCopyInitializer(const Graph& src_graph, const ONNX_NAMESPACE::Ten
43084317 }
43094318 return Status::OK ();
43104319}
4311-
43124320} // namespace
43134321
4314- Status Graph::ProcessSubgraphsInMemoryData (ONNX_NAMESPACE::GraphProto& output_graph_proto,
4315- bool process_main) const {
4322+ Status Graph::RegenerateInitializersAndReplaceInMemory (gsl::span<const InitializedTensorSet::const_iterator> iterators,
4323+ ONNX_NAMESPACE::GraphProto& output_graph_proto) const {
4324+ auto & mutable_initializers = *output_graph_proto.mutable_initializer ();
4325+
4326+ #if !defined(DISABLE_SPARSE_TENSORS)
4327+ output_graph_proto.clear_sparse_initializer ();
4328+
4329+ const auto & model_path = ModelPath ();
4330+ const bool has_sparse_initializers = !sparse_tensor_names_.empty ();
4331+ const auto sparse_end = sparse_tensor_names_.end ();
4332+
4333+ for (const auto & iter : iterators) {
4334+ const auto & [name, tensor_proto] = *iter;
4335+ const auto & initializer = *tensor_proto;
4336+ if (!has_sparse_initializers || sparse_end == sparse_tensor_names_.find (name)) {
4337+ ORT_RETURN_IF_ERROR (InlineOrCopyInitializer (*this , initializer,
4338+ *mutable_initializers.Add ()));
4339+ } else {
4340+ auto & sparse_initializer = *output_graph_proto.add_sparse_initializer ();
4341+ if (utils::HasExternalDataInMemory (initializer)) {
4342+ ONNX_NAMESPACE::TensorProto tensor_proto_inlined;
4343+ ORT_RETURN_IF_ERROR (InlineOrCopyInitializer (*this , initializer,
4344+ tensor_proto_inlined));
4345+ ORT_RETURN_IF_ERROR (utils::DenseTensorToSparseTensorProto (tensor_proto_inlined, model_path, sparse_initializer));
4346+ } else {
4347+ ORT_RETURN_IF_ERROR (utils::DenseTensorToSparseTensorProto (initializer, model_path, sparse_initializer));
4348+ }
4349+ }
4350+ }
4351+ #else
4352+ for (const auto & iter : iterators) {
4353+ const auto & [name, tensor_proto] = *iter;
4354+ ORT_RETURN_IF_ERROR (InlineOrCopyInitializer (*this , *tensor_proto, *mutable_initializers.Add ()));
4355+ }
4356+ #endif
4357+ return Status::OK ();
4358+ }
4359+
4360+ Status Graph::ProcessSubgraphsInMemoryData (ONNX_NAMESPACE::GraphProto& output_graph_proto) const {
43164361 for (const auto & node : Nodes ()) {
43174362 if (node.ContainsSubgraph ()) {
43184363 // Let's find this node in the output_graph_proto
@@ -4343,103 +4388,48 @@ Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_gr
43434388 " Subgraph " , name, " is referred to in GetAttributeNameToSubgraphMap, but not found in node " ,
43444389 node.Name (), " while attempting to recurse into it." );
43454390 auto & result_subgraph = *sub_hit->mutable_g ();
4346- ORT_RETURN_IF_ERROR (subgraph->ProcessSubgraphsInMemoryData (result_subgraph, process_main ));
4391+ ORT_RETURN_IF_ERROR (subgraph->ProcessSubgraphsInMemoryData (result_subgraph));
43474392 }
43484393 }
43494394 }
43504395
4351- // When graph_proto is copied from graph_proto, initializers already present in the main graph
4352- if (parent_graph_ != nullptr || process_main) {
4353- #if !defined(DISABLE_SPARSE_TENSORS)
4354- auto * mutable_initializers = output_graph_proto.mutable_initializer ();
4355- const auto & model_path = ModelPath ();
4356- const bool has_sparse_initializers = !sparse_tensor_names_.empty ();
4357- const auto sparse_end = sparse_tensor_names_.end ();
4358-
4359- // We want to make sure that sparse initializers do not appear
4360- // as dense duplicates within the initializers list.
4361- std::optional<InlinedHashSet<std::string>> initializer_to_remove;
4362- if (has_sparse_initializers) {
4363- // We need to remove the dense initializers that are sparse tensors
4364- initializer_to_remove.emplace ();
4365- }
4366-
4367- for (auto first = mutable_initializers->begin (), end = mutable_initializers->end (); first != end; ++first) {
4368- auto & initializer = *first;
4369- if (utils::HasExternalDataInMemory (initializer)) {
4370- // If the initializer has external data in memory, we need to inline it.
4371- ORT_RETURN_IF_ERROR (InlineOrCopyInitializer (*this , initializer, initializer));
4372- }
4373- if (has_sparse_initializers && sparse_end != sparse_tensor_names_.find (initializer.name ())) {
4374- auto & sparse_initializer = *output_graph_proto.add_sparse_initializer ();
4375- ORT_RETURN_IF_ERROR (utils::DenseTensorToSparseTensorProto (initializer, model_path, sparse_initializer));
4376- initializer_to_remove->insert (initializer.name ());
4377- }
4378- }
4379-
4380- // erase/remove dense initializers that are sparse tensors so no duplicates are present
4381- if (initializer_to_remove && !initializer_to_remove->empty ()) {
4382- mutable_initializers->erase (std::remove_if (
4383- mutable_initializers->begin (), mutable_initializers->end (),
4384- [&initializer_to_remove](const ONNX_NAMESPACE::TensorProto& initializer) {
4385- return initializer_to_remove->count (initializer.name ()) > 0 ;
4386- }),
4387- mutable_initializers->end ());
4388- }
4389- #else
4390- for (auto & initializer : *output_graph_proto.mutable_initializer ()) {
4391- if (utils::HasExternalDataInMemory (initializer)) {
4392- // If the initializer has external data in memory, we need to inline it.
4393- ORT_RETURN_IF_ERROR (InlineOrCopyInitializer (*this , initializer, initializer));
4394- }
4396+ // Filter in iterators for weights that are present in the name_to_initial_tensor_ map
4397+ // and preserve the order. This is needed for tests.
4398+ InlinedVector<InitializedTensorSet::const_iterator> initializers_to_process;
4399+ initializers_to_process.reserve (name_to_initial_tensor_.size ());
4400+ for (const auto & tensor_proto : output_graph_proto.initializer ()) {
4401+ auto hit = name_to_initial_tensor_.find (tensor_proto.name ());
4402+ if (hit != name_to_initial_tensor_.end ()) {
4403+ initializers_to_process.push_back (hit);
43954404 }
4396- #endif
43974405 }
4398- return Status::OK ();
4406+
4407+ output_graph_proto.clear_initializer ();
4408+ return RegenerateInitializersAndReplaceInMemory (initializers_to_process, output_graph_proto);
43994409}
44004410
44014411ONNX_NAMESPACE::GraphProto Graph::ToGraphProto () const {
44024412 GraphProto result;
44034413 if (!GraphProtoSyncNeeded ()) {
44044414 result = *graph_proto_;
4405- ORT_THROW_IF_ERROR (ProcessSubgraphsInMemoryData (result, /* process_main */ true ));
4415+ ORT_THROW_IF_ERROR (ProcessSubgraphsInMemoryData (result));
44064416 } else {
4417+ // Recursion is handled via Node::ToProto() const -> Graph::ToGraphProto() const (this method)
4418+ // so below we handle this graph only.
44074419 ToGraphProtoInternal (result);
44084420
4409- ORT_THROW_IF_ERROR (ProcessSubgraphsInMemoryData (result, /* process_main*/ false ));
4410-
4411- // Add initializers to parent graph by copy converting them from graph_proto_
4412- // ToGraphProtoInternal() does not copy initializers for the main graph
4413- auto * mutable_initializers = result.mutable_initializer ();
4414-
4415- #if !defined(DISABLE_SPARSE_TENSORS)
4416- const auto & model_path = ModelPath ();
4417- const bool has_sparse_initializers = !sparse_tensor_names_.empty ();
4418- const auto sparse_end = sparse_tensor_names_.end ();
4419-
4420- for (const auto & initializer : graph_proto_->initializer ()) {
4421- if (!has_sparse_initializers || sparse_end == sparse_tensor_names_.find (initializer.name ())) {
4422- ORT_THROW_IF_ERROR (InlineOrCopyInitializer (*this , initializer,
4423- *mutable_initializers->Add ()));
4424- } else {
4425- auto & sparse_initializer = *result.add_sparse_initializer ();
4426- if (utils::HasExternalDataInMemory (initializer)) {
4427- ONNX_NAMESPACE::TensorProto tensor_proto;
4428- ORT_THROW_IF_ERROR (InlineOrCopyInitializer (*this , initializer,
4429- tensor_proto));
4430- ORT_THROW_IF_ERROR (utils::DenseTensorToSparseTensorProto (tensor_proto, model_path, sparse_initializer));
4431- } else {
4432- ORT_THROW_IF_ERROR (utils::DenseTensorToSparseTensorProto (initializer, model_path, sparse_initializer));
4433- }
4421+ InlinedVector<InitializedTensorSet::const_iterator> initializers_to_process;
4422+ initializers_to_process.reserve (name_to_initial_tensor_.size ());
4423+ for (const auto & tensor_proto : graph_proto_->initializer ()) {
4424+ auto hit = name_to_initial_tensor_.find (tensor_proto.name ());
4425+ if (hit != name_to_initial_tensor_.end ()) {
4426+ initializers_to_process.push_back (hit);
44344427 }
44354428 }
4436- #else
4437- for (const auto & initializer : graph_proto_->initializer ()) {
4438- ORT_THROW_IF_ERROR (InlineOrCopyInitializer (*this , initializer, *mutable_initializers->Add ()));
4439- }
4440- #endif
4441- }
44424429
4430+ ORT_THROW_IF_ERROR (RegenerateInitializersAndReplaceInMemory (initializers_to_process,
4431+ result));
4432+ }
44434433 return result;
44444434}
44454435
@@ -5235,23 +5225,7 @@ Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& nod
52355225 tensor_proto.set_name (std::string (new_name.value ()));
52365226 }
52375227
5238- // In the constant node, we won't have symbolic dims.
5239- const auto tensor_shape = utils::GetTensorShapeFromTensorProto (tensor_proto);
5240- auto ml_data = DataTypeImpl::TensorTypeFromONNXEnum (tensor_proto.data_type ())->GetElementType ();
5241- const size_t size_in_bytes = Tensor::CalculateTensorStorageSize (ml_data, tensor_shape);
5242-
5243- if (size_in_bytes > utils::kSmallTensorExternalDataThreshold ) {
5244- OrtValue ort_value;
5245- ORT_RETURN_IF_ERROR (utils::TensorProtoToOrtValue (Env::Default (), ModelPath (), tensor_proto,
5246- CPUAllocator::DefaultInstance (), ort_value));
5247-
5248- constexpr const bool use_tensor_buffer_true = true ;
5249- auto tensor_proto_to_add = utils::TensorToTensorProto (ort_value.Get <Tensor>(), tensor_proto.name (),
5250- use_tensor_buffer_true);
5251- ORT_RETURN_IF_ERROR (AddInitializedOrtValue (tensor_proto_to_add, ort_value));
5252- } else {
5253- AddInitializedTensor (tensor_proto);
5254- }
5228+ AddInitializedTensor (tensor_proto);
52555229
52565230 if (GetNodeArg (tensor_proto.name ()) == nullptr ) {
52575231 TypeProto t{utils::TypeProtoFromTensorProto (tensor_proto)};
0 commit comments