@@ -51,10 +51,9 @@ Specifically:
 import torch.multiprocessing as mp
 import torch.nn as nn

-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import fully_shard
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

 CHECKPOINT_DIR = "checkpoint"

@@ -74,7 +73,7 @@ Specifically:

     def state_dict(self):
         # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
         return {
             "model": model_state_dict,
             "optim": optimizer_state_dict
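For context, the hunk above rewrites ``state_dict`` to read from the instance attributes rather than from free variables. A minimal sketch of the surrounding ``Stateful`` wrapper, assuming it is the recipe's ``AppState`` class holding the model and optimizer passed to its constructor (the ``load_state_dict`` side is not shown in this diff)::

    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
    from torch.distributed.checkpoint.stateful import Stateful

    class AppState(Stateful):
        """Wraps a model and optimizer so dcp.save/async_save can checkpoint them as one object."""

        def __init__(self, model, optimizer=None):
            self.model = model
            self.optimizer = optimizer

        def state_dict(self):
            # get_state_dict resolves FSDP FQNs and returns sharded state dicts
            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
            return {"model": model_state_dict, "optim": optimizer_state_dict}

        def load_state_dict(self, state_dict):
            # set_state_dict restores both objects in place from the loaded dicts
            set_state_dict(
                self.model,
                self.optimizer,
                model_state_dict=state_dict["model"],
                optim_state_dict=state_dict["optim"],
            )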
@@ -105,7 +104,7 @@ Specifically:
     os.environ["MASTER_PORT"] = "12355"

     # initialize the process group
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+    dist.init_process_group("gloo", rank=rank, world_size=world_size)
     torch.cuda.set_device(rank)

@@ -119,7 +118,7 @@ Specifically:

     # create a model and move it to GPU with id rank
     model = ToyModel().to(rank)
-    model = FSDP(model)
+    model = fully_shard(model)

     loss_fn = nn.MSELoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
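``fully_shard`` applies FSDP2 to the module in place and returns the same module, so the reassignment above behaves like the old wrapper. A rough sketch of how the rest of this first example presumably proceeds, with ``rank``, ``ToyModel``, ``AppState``, and ``CHECKPOINT_DIR`` taken from the recipe and the ``checkpoint_id`` naming an assumption::

    import torch
    import torch.distributed.checkpoint as dcp
    from torch.distributed.fsdp import fully_shard

    model = ToyModel().to(rank)
    model = fully_shard(model)                      # shards parameters in place, returns the module
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    state_dict = {"app": AppState(model, optimizer)}

    checkpoint_future = None
    for step in range(10):
        # ... forward, backward, and optimizer.step() as in the recipe ...
        if checkpoint_future is not None:
            checkpoint_future.result()              # avoid queuing more than one save at a time
        # async_save returns right after staging; serialization continues in the background
        checkpoint_future = dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")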
@@ -158,9 +157,9 @@ Specifically, this optimization attacks the main overhead of asynchronous checkp
 checkpoint requests users can take advantage of direct memory access to speed up this copy.

 .. note::
-   The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without
-   the pinned memory optimization (as demonstrated above), any checkpointing buffers are released as soon as
-   checkpointing is finished. With the pinned memory implementation, this buffer is maintained between steps,
+   The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without
+   the pinned memory optimization (as demonstrated above), any checkpointing buffers are released as soon as
+   checkpointing is finished. With the pinned memory implementation, this buffer is maintained between steps,
    leading to the same
    peak memory pressure being sustained through the application life.

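The direct-memory-access point can be illustrated outside the checkpointing APIs: a device-to-host copy into page-locked (pinned) CPU memory can run asynchronously over DMA, while a copy into ordinary pageable memory stalls the host. A toy illustration, not part of the recipe::

    import torch

    src = torch.randn(1024, 1024, device="cuda")

    pageable = torch.empty(src.shape, device="cpu")                    # ordinary pageable memory
    pinned = torch.empty(src.shape, device="cpu", pin_memory=True)     # page-locked staging buffer

    pageable.copy_(src)                     # blocks the host until the copy finishes
    pinned.copy_(src, non_blocking=True)    # DMA copy that can overlap with other GPU work
    torch.cuda.synchronize()                # wait before reading the pinned buffer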
@@ -175,11 +174,10 @@ checkpoint requests users can take advantage of direct memory access to speed up
 import torch.multiprocessing as mp
 import torch.nn as nn

-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import fully_shard
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-from torch.distributed.checkpoint import StorageWriter
+from torch.distributed.checkpoint import FileSystemWriter as StorageWriter

 CHECKPOINT_DIR = "checkpoint"

@@ -199,7 +197,7 @@ checkpoint requests users can take advantage of direct memory access to speed up

     def state_dict(self):
         # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
-        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
         return {
             "model": model_state_dict,
             "optim": optimizer_state_dict
@@ -230,7 +228,7 @@ checkpoint requests users can take advantage of direct memory access to speed up
     os.environ["MASTER_PORT"] = "12355"

     # initialize the process group
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+    dist.init_process_group("gloo", rank=rank, world_size=world_size)
     torch.cuda.set_device(rank)

@@ -244,7 +242,7 @@ checkpoint requests users can take advantage of direct memory access to speed up

     # create a model and move it to GPU with id rank
     model = ToyModel().to(rank)
-    model = FSDP(model)
+    model = fully_shard(model)

     loss_fn = nn.MSELoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
@@ -254,7 +252,7 @@ checkpoint requests users can take advantage of direct memory access to speed up
     # into a persistent buffer with pinned memory enabled.
     # Note: It's important that the writer persists in between checkpointing requests, since it maintains the
     # pinned memory buffer.
-    writer = StorageWriter(cached_state_dict=True)
+    writer = StorageWriter(cache_staged_state_dict=True, path=CHECKPOINT_DIR)
     checkpoint_future = None
     for step in range(10):
         optimizer.zero_grad()
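A sketch of how the truncated loop above presumably continues, reusing the one ``writer`` for every request so its pinned staging buffer stays alive; here ``dcp`` stands for ``torch.distributed.checkpoint`` and ``state_dict`` for the ``{"app": AppState(...)}`` mapping from the first example, both assumptions rather than lines shown in this diff::

        # ... forward, backward, and optimizer.step() as in the recipe ...

        if checkpoint_future is not None:
            checkpoint_future.result()      # drain the previous async save first
        # passing the persistent writer reuses its pinned buffer; its path (CHECKPOINT_DIR)
        # determines where the checkpoint is written
        checkpoint_future = dcp.async_save(state_dict, storage_writer=writer)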