Commit 55cbfa5

Add week09 materials (#22)
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
1 parent 729a6f3 commit 55cbfa5

8 files changed: +1262, -1 lines changed

README.md

Lines changed: 3 additions & 1 deletion
```diff
@@ -27,7 +27,9 @@ __This branch corresponds to the ongoing 2024 course. If you want to see full ma
 - [__Week 8:__](./week08_inference_software) __LLM inference optimizations and software__
   - Lecture: Inference speed metrics. KV caching, batch inference, continuous batching. FlashAttention with its modifications and PagedAttention. Overview of popular LLM serving frameworks.
   - Seminar: Basics of the Triton language. Layer fusion in PyTorch and Triton. Implementation of KV caching, FlashAttention in practice.
-- __Week 9:__ __Efficient model inference__
+- [__Week 9:__](./week09_compression) __Efficient model inference__
+  - Lecture: Hardware utilization metrics for deep learning. Knowledge distillation, quantization, LLM.int8(), SmoothQuant, GPTQ. Efficient model architectures. Speculative decoding.
+  - Seminar: Measuring memory bandwidth utilization in practice. Data-free quantization, GPTQ, and SmoothQuant in PyTorch.
 - __Week 10:__ __Guest lecture__
 
 ## Grading
```
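The added seminar line mentions measuring memory bandwidth utilization (MBU) in practice. As a rough, standalone illustration of the idea (not the seminar's actual code), the achieved bandwidth of a simple device-to-device copy can be timed and divided by the GPU's peak bandwidth:

```python
import time
import torch

def measure_bandwidth_gbps(num_elems: int = 1 << 27, n_iters: int = 50) -> float:
    """Estimate achieved memory bandwidth (GB/s) with a simple fp16 copy."""
    assert torch.cuda.is_available(), "this sketch assumes a CUDA GPU"
    x = torch.empty(num_elems, dtype=torch.float16, device="cuda")
    y = torch.empty_like(x)
    for _ in range(5):  # warm-up
        y.copy_(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        y.copy_(x)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    bytes_moved = 2 * x.numel() * x.element_size() * n_iters  # one read + one write per element
    return bytes_moved / elapsed / 1e9

print(f"Achieved bandwidth: {measure_bandwidth_gbps():.0f} GB/s")
# Divide by the GPU's peak memory bandwidth (e.g. roughly 2000 GB/s for an A100 80GB) to get MBU.
```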

week09_compression/README.md

Lines changed: 58 additions & 0 deletions
# Week 9: Efficient model inference

* Lecture: [slides](./lecture.pdf)
* Seminar: [notebook](./practice.ipynb)
* Homework: see [homework/README.md](homework/README.md)

### Setup for the seminar notebook

You can use [conda](https://docs.anaconda.com/free/miniconda/), [mamba](https://mamba.readthedocs.io/en/latest/user_guide/mamba.html) or [micromamba](https://mamba.readthedocs.io/en/latest/user_guide/micromamba.html) to create the environment.

```bash
conda create -n inference \
    python=3.10 \
    pytorch=2.2.1 \
    torchvision=0.17.1 \
    torchaudio=2.2.1 \
    pytorch-cuda=11.8 \
    matplotlib=3.8.0 \
    seaborn=0.12.2 \
    numpy=1.26.4 \
    ipywidgets=8.1.2 \
    jupyterlab_widgets=3.0.10 \
    jupyterlab=4.0.11 \
    tqdm=4.65.0 \
    -c pytorch -c nvidia -y

conda activate inference

# To run the auto-gptq part
pip install auto-gptq==0.7.1 accelerate==0.28.0
pip install --upgrade git+https://github.com/huggingface/transformers.git

# To run the SmoothQuant part
cd ~
git clone git@github.com:mit-han-lab/smoothquant.git
cd smoothquant
python setup.py install
cd path/to/efficient-dl-systems/week09_compression

# Finally, run the notebook
jupyter lab --no-browser
```
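As a quick smoke test of the `auto-gptq` install above, the GPTQ workflow used in the seminar looks roughly like this condensed quickstart-style example (the model name, calibration text, and output directory are placeholders; the notebook uses its own data):

```python
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "facebook/opt-125m"   # placeholder: any small causal LM
quantized_model_dir = "opt-125m-4bit-gptq"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
# GPTQ calibrates layer by layer on a small set of sample sequences;
# in practice, use a few hundred examples from your target domain.
examples = [tokenizer("GPTQ quantizes weights one layer at a time using a small calibration set.")]

quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
model.quantize(examples)
model.save_quantized(quantized_model_dir)
```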
## Further reading

### Knowledge distillation
* https://arxiv.org/abs/2106.05237
* https://arxiv.org/abs/1910.01108
* https://arxiv.org/abs/1909.10351

### Pruning
* https://arxiv.org/abs/2302.04089
* https://arxiv.org/abs/2301.00774

### Quantization
* https://arxiv.org/abs/2206.09557
* https://arxiv.org/abs/2208.07339
* https://huggingface.co/blog/hf-bitsandbytes-integration
* https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html
week09_compression/homework/README.md

Lines changed: 51 additions & 0 deletions
# Week 9 home assignment

## Submission format

Implement models, training procedures, and benchmarks in `.py` files, run all code in a Jupyter notebook, and convert it to the PDF format.
Include your implementations and the report file in a `.zip` archive and submit it.

## Task 1: knowledge distillation for image classification (6 points)

0. Finetune ResNet101 on CIFAR10: change only the classification linear layer [*] and don't freeze the other weights (**0 points**)

Then take an untrained ResNet101 model, remove the `layer3` block from it (except one conv block that creates the correct number of channels for the 4th layer), and implement 3 training setups (one way to build such a student is sketched below):

1. Train the model on the input data only (**1 point**)
2. Train the model on the data and add a soft cross-entropy loss between the student (truncated ResNet101) and the teacher (finetuned full ResNet101) (**2 points**)
3. Train the model as in the previous subtask, but also add an MSE loss between the corresponding `layer1`, `layer2`, and `layer4` features of the student and the teacher (**3 points**)
4. Report the test accuracy for each of the models

[\*] Vanilla ResNet is not very well suited for CIFAR: it downsamples the image by 32x, while CIFAR images are 32x32 pixels. So you can either:

- upsample the images (easiest to implement, but you will perform more computations), or
- slightly change the first layers (e.g. make `model.conv1` a 3x3 convolution with stride 1 and remove `model.maxpool`).
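Below is a minimal sketch of one way to build such a truncated student with the modified stem; the exact construction is up to you. If you later match intermediate features against the teacher, make sure both networks see inputs processed the same way so the feature shapes line up.

```python
import torch.nn as nn
from torchvision.models import resnet101

def make_student(num_classes: int = 10) -> nn.Module:
    model = resnet101(weights=None)  # untrained student
    # CIFAR-friendly stem: 3x3 stride-1 conv instead of 7x7 stride-2, and no max-pooling
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    # keep only the first block of layer3, so the channel count still matches layer4's input
    model.layer3 = nn.Sequential(model.layer3[0])
    # new classification head for CIFAR10
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model
```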
Feel free to use the dataset and model implementations from PyTorch.
For the losses in the 2nd and 3rd subtasks, use a simple average of all the loss terms.
For the 3rd subtask, you will need to return not only the model's outputs but also the intermediate feature maps.
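For example, the combined objective for the 3rd subtask might look like the sketch below (argument names are illustrative; the teacher outputs are assumed to be computed under `torch.no_grad()`):

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, student_feats, teacher_feats, labels):
    """Simple average of hard CE, soft CE on logits, and per-layer feature MSE."""
    terms = [F.cross_entropy(student_logits, labels)]
    # soft cross-entropy between the teacher's and the student's predicted distributions
    soft_ce = torch.sum(
        -F.softmax(teacher_logits, dim=-1) * F.log_softmax(student_logits, dim=-1), dim=-1
    ).mean()
    terms.append(soft_ce)
    # MSE between corresponding layer1/layer2/layer4 feature maps
    for s_feat, t_feat in zip(student_feats, teacher_feats):
        terms.append(F.mse_loss(s_feat, t_feat))
    return sum(terms) / len(terms)
```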
### Training setup

- Use the standard Adam optimizer without a scheduler.
- Use any suitable batch size from 128 to 512.
- Training stopping criterion: the test-set accuracy (measured from 0 to 1) stabilizes in the second digit after the decimal point for at least 2 epochs.
  That means you must satisfy the condition `torch.abs(acc - acc_prev) < 0.01` for at least two epochs in a row (see the snippet below).
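One way to implement this check in the training loop, assuming hypothetical `train_one_epoch` and `evaluate` helpers:

```python
prev_acc = None
stable_epochs = 0

for epoch in range(max_epochs):
    train_one_epoch(model, train_loader)   # assumed helper
    acc = evaluate(model, test_loader)     # test accuracy in [0, 1], assumed helper
    if prev_acc is not None and abs(acc - prev_acc) < 0.01:
        stable_epochs += 1
    else:
        stable_epochs = 0
    prev_acc = acc
    if stable_epochs >= 2:
        break
```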
## Task 2: use `deepsparse` to prune & quantize your model (4 points)

0. Please read the whole task description before starting.
1. Install `deepsparse==1.7.0` and `sparseml==1.7.0`. Note: they might not work smoothly with the latest PyTorch versions. If so, you can downgrade to `torch==1.12.1`.
2. Take your best trained model from subtasks 1.1-1.3 and run pruning + quantization-aware training, adapting the following [example](./example_train_sparse_and_quantize.py). You will need to change/implement what is marked by #TODO and report the test accuracy of both models. (**3 points**)
3. Take the `onnx` baseline (your best trained model from subtasks 1.1-1.3) and the pruned-quantized version, and benchmark both models on the CPU using `deepsparse.benchmark` at batch sizes 1 and 32 (see the example invocations after this list). (**1 point**)

For task 2.3, you may find [this page](https://web.archive.org/web/20240319095504/https://docs.neuralmagic.com/user-guides/deepsparse-engine/benchmarking/) helpful.

You should not use the training stopping criterion in this part, since the sparsification recipe relies on a fixed number of epochs.
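A possible set of benchmark invocations, using the ONNX files produced by the example script (verify the exact flag names with `deepsparse.benchmark --help`; the core count below is just an example):

```bash
deepsparse.benchmark checkpoints/baseline_resnet.onnx --batch_size 1 --num_cores 8
deepsparse.benchmark checkpoints/baseline_resnet.onnx --batch_size 32 --num_cores 8
deepsparse.benchmark checkpoints/pruned_quantized_resnet.onnx --batch_size 1 --num_cores 8
deepsparse.benchmark checkpoints/pruned_quantized_resnet.onnx --batch_size 32 --num_cores 8
```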
### Tips

- Debug your code with resnet18 to iterate faster.
- Don't forget `model.eval()` before the ONNX export.
- Don't forget `convert_qat=True` in `sparseml.pytorch.utils.export_onnx` after you have trained the model with quantization.
- To visualize ONNX models, you can use [netron](https://netron.app/).
- Explicitly set the number of cores in `deepsparse.benchmark`.
- If you are desperate and don't have time to train bigger models, submit this part with resnet18.

Good luck and have fun!
week09_compression/homework/example_train_sparse_and_quantize.py

Lines changed: 93 additions & 0 deletions
```python
from pathlib import Path
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchvision.models import resnet18, ResNet18_Weights
from sparseml.pytorch.datasets import ImagenetteDataset, ImagenetteSize
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.utils import export_onnx


def save_onnx(model, export_path, convert_qat):
    # It is important to call torch_model.eval() or torch_model.train(False) before exporting the model,
    # to turn the model to inference mode.
    # This is required since operators like dropout or batchnorm behave differently in inference and training mode.
    model.eval()
    sample_batch = torch.randn((1, 3, 224, 224))
    export_onnx(model, sample_batch, export_path, convert_qat=convert_qat)


def main():
    # TODO: add argparse/hydra/... to manage hyperparameters like batch_size, path to the pretrained model, etc.

    # Sparsification recipe -- a YAML file with instructions on how to sparsify the model
    recipe_path = "recipe.yaml"
    assert Path(recipe_path).exists(), "Didn't find sparsification recipe!"

    checkpoints_path = Path("checkpoints")
    checkpoints_path.mkdir(exist_ok=True)

    # Model creation
    # TODO: change to your best model from subtasks 1.1 - 1.3
    NUM_CLASSES = 10  # number of Imagenette classes
    model = resnet18(weights=ResNet18_Weights.DEFAULT)
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

    save_onnx(model, checkpoints_path / "baseline_resnet.onnx", convert_qat=False)

    # Dataset creation
    # TODO: change to CIFAR10, add a test dataset
    batch_size = 64
    train_dataset = ImagenetteDataset(train=True, dataset_size=ImagenetteSize.s320, image_size=224)
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, pin_memory=True, num_workers=8)

    # Device setup
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Loss setup
    criterion = nn.CrossEntropyLoss()
    # Note that the learning rate is being modified in `recipe.yaml`
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    # SparseML integration: the manager wraps the optimizer so the recipe's modifiers run during training
    manager = ScheduledModifierManager.from_yaml(recipe_path)
    optimizer = manager.modify(model, optimizer, steps_per_epoch=len(train_loader))

    # Training loop
    model.train()

    # TODO: implement a `train_one_epoch` function to structure the code better
    pbar = tqdm(range(manager.max_epochs), desc="epoch")
    for epoch in pbar:
        running_loss = 0.0
        running_corrects = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                loss.backward()
                optimizer.step()

            # accumulate plain Python numbers so we don't keep autograd history around
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data).item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects / len(train_loader.dataset)
        pbar.set_description(f"Training loss: {epoch_loss:.3f} Accuracy: {epoch_acc:.3f}")

    # TODO: implement an `evaluate` function to measure accuracy on the test set

    manager.finalize(model)

    # Saving the model
    save_onnx(model, checkpoints_path / "pruned_quantized_resnet.onnx", convert_qat=True)


if __name__ == "__main__":
    main()
```
week09_compression/homework/recipe.yaml

Lines changed: 31 additions & 0 deletions
```yaml
modifiers:
  # gradually prune all prunable layers to 80% global sparsity over the first 30 epochs
  - !GlobalMagnitudePruningModifier
    init_sparsity: 0.05
    final_sparsity: 0.8
    start_epoch: 0.0
    end_epoch: 30.0
    update_frequency: 1.0
    params: __ALL_PRUNABLE__

  - !SetLearningRateModifier
    start_epoch: 0.0
    learning_rate: 0.05

  # cosine learning-rate decay while the pruned weights stabilize
  - !LearningRateFunctionModifier
    start_epoch: 30.0
    end_epoch: 50.0
    lr_func: cosine
    init_lr: 0.05
    final_lr: 0.001

  # quantization-aware training for the final epochs
  - !QuantizationModifier
    start_epoch: 50.0
    freeze_bn_stats_epoch: 53.0

  - !SetLearningRateModifier
    start_epoch: 50.0
    learning_rate: 10e-6

  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: 55.0
```

week09_compression/lecture.pdf

2.38 MB
Binary file not shown.