Skip to content

Commit e0a1f99

Browse files
siyengarmeta-codesync[bot]
authored andcommitted
pickle changes
Summary: changes the serialization to pickle to match c10d changing the format caused some issues with objects not being consistent, so revert back to the old format Reviewed By: d4l3k Differential Revision: D86941971 fbshipit-source-id: 97b9b1f2415bfc239fa8ffffbd964b36a09e2199
1 parent 7d77302 commit e0a1f99

File tree

2 files changed

+46
-3
lines changed

2 files changed

+46
-3
lines changed

comms/torchcomms/objcol.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
3+
# pyre-strict
4+
5+
import functools
26
import io
7+
import os
8+
9+
import pickle
310
from datetime import timedelta
411
from typing import Any
512

@@ -9,12 +16,38 @@
916
from torchcomms._comms import TorchComm
1017

1118

19+
class _Serialization:
20+
"""Serialization helper with serialize and deserialize methods."""
21+
22+
def __init__(self) -> None:
23+
self.use_pickle: bool = os.getenv("TORCHCOMMS_SERIALIZATION") == "pickle"
24+
25+
def serialize(self, f: io.BytesIO, obj: object) -> None:
26+
if self.use_pickle:
27+
pickle.Pickler(f).dump(obj)
28+
else:
29+
torch.save(obj, f)
30+
31+
def deserialize(self, f: io.BytesIO, weights_only: bool) -> object:
32+
if self.use_pickle:
33+
return pickle.Unpickler(f).load()
34+
else:
35+
return torch.load(f, weights_only=weights_only)
36+
37+
38+
@functools.lru_cache(maxsize=None)
39+
def get_serialization() -> _Serialization:
40+
"""Returns a cached serialization object with serialize and deserialize methods."""
41+
return _Serialization()
42+
43+
1244
def _object_to_tensor(
1345
obj: object, device: torch.device
1446
) -> tuple[torch.Tensor, torch.Tensor]:
1547
with _WaitCounter("pytorch.wait_counter.torchcomms._object_to_tensor").guard():
1648
f = io.BytesIO()
17-
torch.save(obj, f)
49+
serialization = get_serialization()
50+
serialization.serialize(f, obj)
1851
byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined]
1952
# Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
2053
# Otherwise, it will cause 100X slowdown.
@@ -30,7 +63,8 @@ def _tensor_to_object(
3063
with _WaitCounter("pytorch.wait_counter.torchcomms._tensor_to_object").guard():
3164
tensor = tensor.cpu()
3265
buf = tensor.numpy().tobytes()[:tensor_size]
33-
return torch.load(io.BytesIO(buf), weights_only=weights_only)
66+
serialization = get_serialization()
67+
return serialization.deserialize(io.BytesIO(buf), weights_only=weights_only)
3468

3569

3670
def all_gather_object(
@@ -445,7 +479,7 @@ def broadcast_object_list(
445479
root: int,
446480
timeout: timedelta | None = None,
447481
weights_only: bool = True,
448-
):
482+
) -> None:
449483
"""
450484
Broadcasts picklable objects in ``object_list`` to the whole comm.
451485

comms/torchcomms/tests/integration/py/ObjColTest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,5 +182,14 @@ def test_scatter_object_list(self) -> None:
182182
self.assertEqual(scatter_output[0], expected_object)
183183

184184

185+
class ObjColTestWithPickle(ObjColTest):
186+
def setUp(self):
187+
"""Set up test environment before each test."""
188+
import os
189+
190+
os.environ["TORCHCOMMS_SERIALIZATION"] = "pickle"
191+
super().setUp()
192+
193+
185194
if __name__ == "__main__":
186195
unittest.main()

0 commit comments

Comments
 (0)