Commit 5c4f245

Add week04 materials
1 parent fa050e1 commit 5c4f245

File tree

12 files changed: +1226 -1 lines


README.md

Lines changed: 3 additions & 1 deletion
@@ -13,7 +13,9 @@ __This branch corresponds to the ongoing 2025 course. If you want to see full ma
  - [__Week 3:__](./week03_fast_pipelines) __Training optimizations, FP16/BF16/FP8 formats, profiling deep learning code__
    - Lecture: Measuring performance of GPU-accelerated software. Mixed-precision training. Data storage and loading optimizations. Tools for profiling deep learning workloads.
    - Seminar: Automatic Mixed Precision in PyTorch. Dynamic padding for sequence data and JPEG decoding benchmarks. Basics of profiling with py-spy, PyTorch Profiler, Memory Snapshot and Nsight Systems.
- - __Week 4:__ __Data-parallel training and All-Reduce__
+ - [__Week 4:__](./week04_data_parallel) __Data-parallel training and All-Reduce__
+   - Lecture: Introduction to distributed training. Data-parallel training of neural networks. All-Reduce and its efficient implementations.
+   - Seminar: Introduction to PyTorch Distributed. Data-parallel training primitives.
  - __Week 5:__ __Sharded data-parallel training, distributed training optimizations__
  - __Week 6:__ __Training large models__
  - __Week 7:__ __Python web application deployment__

week04_data_parallel/README.md

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# Week 4: Data-parallel training and All-Reduce

* Lecture: [link](./lecture.pdf)
* Seminar: [link](./practice.ipynb)
* Homework: see the [homework](./homework) folder

## Further reading
* [Numba parallel](https://numba.pydata.org/numba-doc/dev/user/parallel.html) - a way to write threaded parallel code in Python that is not limited by the GIL (see the short sketch after this list)
* [joblib](https://joblib.readthedocs.io/) - a library of multiprocessing primitives similar to `mp.Pool`, but with some extra conveniences
* [BytePS paper](https://www.usenix.org/system/files/osdi20-jiang.pdf)
* Alternative lecture: [Parameter servers](https://www.youtube.com/watch?v=N241lmq5mqk) from CMU 10-605
* Alternative seminar: Python multiprocessing - [playlist](https://www.youtube.com/watch?v=RR4SoktDQAw&list=PL5tcWHG-UPH3SX16DI6EP1FlEibgxkg_6)
* [Python multiprocessing docs](https://docs.python.org/3/library/multiprocessing.html) (pay attention to `fork` vs `spawn`!)
* [PyTorch Distributed tutorial](https://pytorch.org/tutorials/intermediate/dist_tuto.html)
* [Collective communication protocols in NCCL](https://images.nvidia.com/events/sc15/pdfs/NCCL-Woolley.pdf)
* There are many more links on the slides; please check the PDF.
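As a quick illustration of the Numba item above, here is a minimal parallel-loop sketch (it assumes `numba` and `numpy` are installed; the function name is made up):

```python
import numpy as np
from numba import njit, prange


@njit(parallel=True)  # the loop body is compiled and runs on multiple threads, outside the GIL
def parallel_sum_of_squares(x):
    total = 0.0
    for i in prange(x.shape[0]):  # prange marks the loop as parallel
        total += x[i] * x[i]      # scalar reductions are handled automatically
    return total


print(parallel_sum_of_squares(np.random.rand(1_000_000)))
```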
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# Week 4 home assignment

The assignment for this week consists of four parts: the first three are obligatory, and the fourth is a bonus.
Include all files with the implemented functions/classes and the report for Tasks 2 and 4 in your submission.

## Task 1 (1 point)

Implement a function for deterministic sequential printing of N numbers by N processes,
using [sequential_print.py](./sequential_print.py) as a template.
You should be able to test it with `torchrun --nproc_per_node N sequential_print.py`.
Pay attention to the output format!

## Task 2 (7 points)

The pipeline you saw in the seminar shows only the basic building blocks of distributed training. Now, let's train
something actually interesting!

### SyncBatchNorm implementation
For this task, let's take the [CIFAR-100](https://pytorch.org/vision/0.8/datasets.html#torchvision.datasets.CIFAR100)
dataset and train a model with **synchronized** Batch Normalization: this version of the layer aggregates
the statistics **across all workers** during each forward pass.

Importantly, you may call a communication primitive **only once** during each forward or backward pass;
if you use it more than once, you will earn at most 4 points for this task.
Additionally, you are **not allowed** to use internal PyTorch functions that compute the backward pass
of batch normalization: implement it manually.
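
One possible shape of the forward-pass statistics aggregation that respects the single-call constraint, as a minimal sketch: the helper name is made up, it assumes a 2D input of shape `[batch, features]` and an already initialized process group, and it omits running statistics, affine parameters, and the backward pass.

```python
import torch
import torch.distributed as dist


def syncbn_forward(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Pack per-worker sum, sum of squares, and sample count into one tensor,
    # so a single all_reduce yields the global batch statistics.
    count = torch.tensor([x.shape[0]], dtype=x.dtype)
    stats = torch.cat([x.sum(dim=0), (x ** 2).sum(dim=0), count])
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)  # the only communication call

    num_features = x.shape[1]
    total_count = stats[-1]
    mean = stats[:num_features] / total_count
    sq_mean = stats[num_features : 2 * num_features] / total_count
    var = sq_mean - mean ** 2  # biased variance, as in training-mode BatchNorm

    return (x - mean) / torch.sqrt(var + eps)
```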

### Reducing gradient synchronization
Also, implement a version of distributed training that is aware of **gradient accumulation**:
for every batch that does not call `optimizer.step`, you do not need to run All-Reduce on the gradients at all.
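
A minimal sketch of what this can look like in the manual pipeline (the loop details are placeholders, not a required interface); the DDP baseline achieves the same effect with its `no_sync()` context manager.

```python
import torch
import torch.distributed as dist


def train_with_accumulation(model, loader, optimizer, criterion, accum_steps: int):
    """Run gradient All-Reduce only on the steps that actually call optimizer.step()."""
    model.train()
    optimizer.zero_grad()
    for step, (data, target) in enumerate(loader):
        loss = criterion(model(data), target) / accum_steps
        loss.backward()  # gradients accumulate locally in .grad

        if (step + 1) % accum_steps == 0:
            # The only communication per effective batch: average .grad across workers.
            for param in model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                    param.grad /= dist.get_world_size()
            optimizer.step()
            optimizer.zero_grad()
```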

### Benchmarking the training pipeline
Compare the performance (in terms of speed, memory footprint, and final quality) of your distributed training
pipeline with the one that uses primitives from PyTorch (i.e., [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel) **and** [torch.nn.SyncBatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html)).
You need to compare the implementations by training with **at least two** processes, and your pipeline needs to have
at least 2 gradient accumulation steps.
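
For the PyTorch side of this comparison, the two primitives are combined roughly as follows (a sketch assuming one GPU per process and an already initialized NCCL process group, since `SyncBatchNorm` under DDP expects CUDA modules):

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def build_baseline(model: nn.Module, local_rank: int) -> nn.Module:
    # Replace every BatchNorm*d module with SyncBatchNorm...
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(torch.device("cuda", local_rank))
    # ...and let DDP handle gradient averaging (one bucketed All-Reduce per backward pass).
    return DistributedDataParallel(model, device_ids=[local_rank])
```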

### Tests for SyncBatchNorm
In addition, **test the SyncBN layer itself** by comparing its results with those of standard **BatchNorm1d** while varying
the number of workers (1 and 4), the size of activations (128, 256, 512, 1024), and the batch size (32, 64).

Compare the results of the forward/backward passes in the following setup:
* FP32 inputs come from the standard Gaussian distribution;
* The loss function takes the outputs of batch normalization and computes the sum over all dimensions
for the first B/2 samples (B is the total batch size).

A working implementation of SyncBN should pass the comparison with `rtol` equal to 0 and a reasonably low `atol` (1e-3 or tighter).

This test needs to be implemented via `pytest` in [test_syncbn.py](./test_syncbn.py): in particular, all the above
parameters (including the number of workers) need to be inputs of that test.
Therefore, you will need to **start worker processes** within the test as well: `test_batchnorm` contains helpful
comments to get you started.
For simplicity, the test may work only on the CPU.
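
A possible skeleton for such a test, as a sketch only: the worker body and the exact comparison are up to you, the port is fixed for brevity, and the `gloo` backend keeps everything on the CPU.

```python
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank, world_size, hid_dim, batch_size, port):
    dist.init_process_group(
        "gloo", init_method=f"tcp://127.0.0.1:{port}", rank=rank, world_size=world_size
    )
    torch.manual_seed(0)
    x = torch.randn(batch_size, hid_dim)
    # ... run SyncBN on this worker's shard of x, BatchNorm1d on the full batch,
    # and compare outputs/gradients, e.g. with torch.testing.assert_close(atol=1e-3, rtol=0) ...
    dist.destroy_process_group()


@pytest.mark.parametrize("world_size", [1, 4])
@pytest.mark.parametrize("hid_dim", [128, 256, 512, 1024])
@pytest.mark.parametrize("batch_size", [32, 64])
def test_batchnorm(world_size, hid_dim, batch_size):
    # Spawn worker processes inside the test; each one joins the same process group.
    mp.spawn(_worker, args=(world_size, hid_dim, batch_size, 29500),
             nprocs=world_size, join=True)
```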

### Performance benchmarks
Finally, measure the GPU time (with 2+ workers) and the memory footprint of the standard **SyncBatchNorm**
and of your implementation in the setup above: in total, you should have 8 speed/memory benchmarks for each implementation.
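
One way to collect these numbers (a sketch; `run_forward_backward` stands for a hypothetical callable that performs one forward/backward pass of the layer under test):

```python
import torch


def benchmark_gpu(run_forward_backward, num_iters: int = 100):
    """Returns (average milliseconds per iteration, peak GPU memory in bytes)."""
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for _ in range(10):  # warm-up iterations are excluded from timing
        run_forward_backward()
    torch.cuda.synchronize()

    start.record()
    for _ in range(num_iters):
        run_forward_backward()
    end.record()
    torch.cuda.synchronize()

    return start.elapsed_time(end) / num_iters, torch.cuda.max_memory_allocated()
```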

### Submission format
Provide the results of your experiments in an `.ipynb`/`.pdf` report and attach it to your code
when submitting the homework.
Your report should include a brief experimental setup (if changed), the results of all experiments **with the commands/code
to reproduce them**, and a description of the infrastructure (version of PyTorch, number of processes, type of GPUs, etc.).

Use [syncbn.py](./syncbn.py) and [ddp_cifar100.py](./ddp_cifar100.py) as templates.

## Task 3 (2 points)

Until now, we have only aggregated the gradients across different workers during training. But what if we want to run
distributed validation on a large dataset as well? In this task, you have to implement distributed metric
aggregation: shard the dataset across the workers (with [scatter](https://pytorch.org/docs/stable/distributed.html#torch.distributed.scatter)), compute accuracy for each subset on
its respective worker, and then average the metric values on the master process.

Also, make one more quality-of-life improvement to the pipeline by logging the loss (and accuracy!)
only from the rank-0 process, to avoid flooding the standard output of your training command.
Submit training code that includes all the enhancements from Tasks 2 and 3.
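
A sketch of the final aggregation step (it assumes every worker has already computed accuracy on its own shard; the helper name is made up):

```python
import torch
import torch.distributed as dist


def report_accuracy(local_accuracy: float) -> None:
    # Average per-worker accuracies on the master process (rank 0) and log only there.
    acc = torch.tensor([local_accuracy])
    dist.reduce(acc, dst=0, op=dist.ReduceOp.SUM)
    if dist.get_rank() == 0:
        print(f"Validation accuracy: {(acc / dist.get_world_size()).item():.4f}")
```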

## Task 4 (bonus, 3 points)

Using [allreduce.py](./allreduce.py) as a template, implement the Ring All-Reduce algorithm
using only point-to-point communication primitives from `torch.distributed`.
Compare it with the provided implementation of Butterfly All-Reduce
and with `torch.distributed.all_reduce` in terms of CPU speed, memory usage, and the accuracy of averaging.
Specifically, compare the custom implementations of All-Reduce with 1–32 workers, and compare your implementation of
Ring All-Reduce with `torch.distributed.all_reduce` on 1–16 processes and vectors of 1,000–100,000 elements.
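
The point-to-point primitives to build on are `dist.send`/`dist.recv` and their asynchronous counterparts; below is a minimal sketch of a single chunk exchange between ring neighbours, not the full algorithm.

```python
import torch
import torch.distributed as dist


def exchange_with_neighbours(chunk: torch.Tensor, rank: int, size: int) -> torch.Tensor:
    """Send `chunk` to the next rank in the ring and receive one from the previous rank."""
    received = torch.empty_like(chunk)
    send_req = dist.isend(chunk, dst=(rank + 1) % size)  # non-blocking send avoids deadlock
    dist.recv(received, src=(rank - 1) % size)           # blocking receive from the left neighbour
    send_req.wait()
    return received
```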
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
import os
import random

import torch
import torch.distributed as dist
from torch.multiprocessing import Process


def init_process(rank, size, fn, master_port, backend="gloo"):
    """Initialize the distributed environment."""
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(master_port)
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


def butterfly_allreduce(send, rank, size):
    """
    Performs Butterfly All-Reduce over the process group. Modifies the input tensor in place.
    Args:
        send: torch.Tensor to be averaged with other processes.
        rank: Current process rank (in a range from 0 to size - 1)
        size: Number of workers
    """

    buffer_for_chunk = torch.empty((size,), dtype=torch.float)

    send_futures = []

    for i, elem in enumerate(send):
        if i != rank:
            send_futures.append(dist.isend(elem, i))

    recv_futures = []

    for i, elem in enumerate(buffer_for_chunk):
        if i != rank:
            recv_futures.append(dist.irecv(elem, i))
        else:
            elem.copy_(send[i])

    for future in recv_futures:
        future.wait()

    # compute the average
    torch.mean(buffer_for_chunk, dim=0, out=send[rank])

    for i in range(size):
        if i != rank:
            send_futures.append(dist.isend(send[rank], i))

    recv_futures = []

    for i, elem in enumerate(send):
        if i != rank:
            recv_futures.append(dist.irecv(elem, i))

    for future in recv_futures:
        future.wait()
    for future in send_futures:
        future.wait()


def ring_allreduce(send, rank, size):
    """
    Performs Ring All-Reduce over the process group. Modifies the input tensor in place.
    Args:
        send: torch.Tensor to be averaged with other processes.
        rank: Current process rank (in a range from 0 to size - 1)
        size: Number of workers
    """
    pass


def run_butterfly_allreduce(rank, size):
    """Runs Butterfly All-Reduce on a random tensor and prints the result."""
    torch.manual_seed(rank)
    tensor = torch.randn((size,), dtype=torch.float)
    print("Rank ", rank, " has data ", tensor)
    butterfly_allreduce(tensor, rank, size)
    print("Rank ", rank, " has data ", tensor)


if __name__ == "__main__":
    size = 5
    processes = []
    port = random.randint(25000, 30000)
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, run_butterfly_allreduce, port))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import CIFAR100

torch.set_num_threads(1)


def init_process(local_rank, fn, backend="nccl"):
    """Initialize the distributed environment."""
    dist.init_process_group(backend, rank=local_rank)
    size = dist.get_world_size()
    fn(local_rank, size)


class Net(nn.Module):
    """
    A very simple model with minimal changes from the tutorial, used for the sake of simplicity.
    Feel free to replace it with EffNetV2-XL once you get comfortable injecting SyncBN into models programmatically.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(6272, 128)
        self.fc2 = nn.Linear(128, 100)
        self.bn1 = nn.BatchNorm1d(128, affine=False)  # to be replaced with SyncBatchNorm

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = F.relu(x)

        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)

        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        output = self.fc2(x)
        return output


def average_gradients(model):
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size


def run_training(rank, size):
    torch.manual_seed(0)

    dataset = CIFAR100(
        "./cifar",
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
            ]
        ),
        download=True,
    )
    # where's the validation dataset?
    loader = DataLoader(dataset, sampler=DistributedSampler(dataset, size, rank), batch_size=64)

    model = Net()
    device = torch.device("cpu")  # replace with "cuda" afterwards
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    num_batches = len(loader)

    for _ in range(10):
        epoch_loss = torch.zeros((1,), device=device)

        for data, target in loader:
            data = data.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = torch.nn.functional.cross_entropy(output, target)
            epoch_loss += loss.detach()
            loss.backward()
            average_gradients(model)
            optimizer.step()

            acc = (output.argmax(dim=1) == target).float().mean()

            print(f"Rank {dist.get_rank()}, loss: {epoch_loss / num_batches}, acc: {acc}")
            epoch_loss = 0
    # where's the validation loop?


if __name__ == "__main__":
    local_rank = int(os.environ["LOCAL_RANK"])
    init_process(local_rank, fn=run_training, backend="gloo")  # replace with "nccl" when testing on several GPUs
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
pytest==8.3.4
torch==2.4.0
torchvision==0.19.0
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
import os

import torch.distributed as dist


def run_sequential(rank, size, num_iter=10):
    """
    Prints the process rank sequentially according to its number over `num_iter` iterations,
    separating the output for each iteration by `---`
    Example (3 processes, num_iter=2):
    ```
    Process 0
    Process 1
    Process 2
    ---
    Process 0
    Process 1
    Process 2
    ```
    """

    pass


if __name__ == "__main__":
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(rank=local_rank, backend="gloo")

    run_sequential(local_rank, dist.get_world_size())
