11# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
3+ # pyre-strict
4+
5+ import functools
26import io
7+ import os
8+
9+ import pickle
310from datetime import timedelta
411from typing import Any
512
916from torchcomms ._comms import TorchComm
1017
1118
19+ class _Serialization :
20+ """Serialization helper with serialize and deserialize methods."""
21+
22+ def __init__ (self ) -> None :
23+ self .use_pickle : bool = os .getenv ("TORCHCOMMS_SERIALIZATION" ) == "pickle"
24+
25+ def serialize (self , f : io .BytesIO , obj : object ) -> None :
26+ if self .use_pickle :
27+ pickle .Pickler (f ).dump (obj )
28+ else :
29+ torch .save (obj , f )
30+
31+ def deserialize (self , f : io .BytesIO , weights_only : bool ) -> object :
32+ if self .use_pickle :
33+ return pickle .Unpickler (f ).load ()
34+ else :
35+ return torch .load (f , weights_only = weights_only )
36+
37+
38+ @functools .lru_cache (maxsize = None )
39+ def get_serialization () -> _Serialization :
40+ """Returns a cached serialization object with serialize and deserialize methods."""
41+ return _Serialization ()
42+
43+
1244def _object_to_tensor (
1345 obj : object , device : torch .device
1446) -> tuple [torch .Tensor , torch .Tensor ]:
1547 with _WaitCounter ("pytorch.wait_counter.torchcomms._object_to_tensor" ).guard ():
1648 f = io .BytesIO ()
17- torch .save (obj , f )
49+ serialization = get_serialization ()
50+ serialization .serialize (f , obj )
1851 byte_storage = torch .ByteStorage ._from_buffer (f .getvalue ()) # type: ignore[attr-defined]
1952 # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
2053 # Otherwise, it will cause 100X slowdown.
@@ -30,7 +63,8 @@ def _tensor_to_object(
3063 with _WaitCounter ("pytorch.wait_counter.torchcomms._tensor_to_object" ).guard ():
3164 tensor = tensor .cpu ()
3265 buf = tensor .numpy ().tobytes ()[:tensor_size ]
33- return torch .load (io .BytesIO (buf ), weights_only = weights_only )
66+ serialization = get_serialization ()
67+ return serialization .deserialize (io .BytesIO (buf ), weights_only = weights_only )
3468
3569
3670def all_gather_object (
@@ -445,7 +479,7 @@ def broadcast_object_list(
445479 root : int ,
446480 timeout : timedelta | None = None ,
447481 weights_only : bool = True ,
448- ):
482+ ) -> None :
449483 """
450484 Broadcasts picklable objects in ``object_list`` to the whole comm.
451485
0 commit comments