
Commit 297e285

[SE3Transformer/DGLPyT] Update container and fix benchmarking
1 parent 678b470

File tree

7 files changed: +54 −39 lines changed

DGLPyTorch/DrugDiscovery/SE3Transformer/Dockerfile

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@
 # run docker daemon with --default-runtime=nvidia for GPU detection during build
 # multistage build for DGL with CUDA and FP16

-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:22.08-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:23.01-py3

 FROM ${FROM_IMAGE_NAME} AS dgl_builder

@@ -33,7 +33,7 @@ RUN apt-get update \
     && apt-get install -y git build-essential python3-dev make cmake \
     && rm -rf /var/lib/apt/lists/*
 WORKDIR /dgl
-RUN git clone --branch 0.9.0 --recurse-submodules --depth 1 https://github.com/dmlc/dgl.git .
+RUN git clone --branch 1.0.0 --recurse-submodules --depth 1 https://github.com/dmlc/dgl.git .
 WORKDIR build
 RUN export NCCL_ROOT=/usr \
     && cmake .. -GNinja -DCMAKE_BUILD_TYPE=Release \
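The two changed lines move the build onto the PyTorch 23.01 NGC base image and DGL 1.0.0. A minimal sketch for checking the stack from inside the rebuilt container; the version strings printed are not part of this commit and may differ:

```python
# Hypothetical sanity check, run inside the container built from this Dockerfile.
import torch
import dgl

print("PyTorch:", torch.__version__)
print("CUDA:   ", torch.version.cuda)
print("DGL:    ", dgl.__version__)  # should report a 1.0.x build after this change

# GPU visibility requires starting the container with the NVIDIA runtime (e.g. --gpus all).
assert torch.cuda.is_available()
```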

DGLPyTorch/DrugDiscovery/SE3Transformer/README.md

Lines changed: 34 additions & 29 deletions
@@ -252,9 +252,9 @@ The following section lists the requirements that you need to meet in order to s

 ### Requirements

-This repository contains a Dockerfile which extends the PyTorch 21.07 NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
+This repository contains a Dockerfile which extends the PyTorch 23.01 NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
 - [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
-- PyTorch 21.07+ NGC container
+- PyTorch 23.01+ NGC container
 - Supported GPUs:
 - [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
 - [NVIDIA Turing architecture](https://www.nvidia.com/en-us/design-visualization/technologies/turing-architecture/)
@@ -290,7 +290,7 @@ To train your model using mixed or TF32 precision with Tensor Cores or FP32, per

 4. Start training.
    ```
-   bash scripts/train.sh
+   bash scripts/train.sh # or scripts/train_multi_gpu.sh
    ```

 5. Start inference/predictions.
@@ -474,7 +474,7 @@ The following sections provide details on how we achieved our performance and ac

 ##### Training accuracy: NVIDIA DGX A100 (8x A100 80GB)

-Our results were obtained by running the `scripts/train.sh` training script in the PyTorch 21.07 NGC container on NVIDIA DGX A100 (8x A100 80GB) GPUs.
+Our results were obtained by running the `scripts/train.sh` and `scripts/train_multi_gpu.sh` training scripts in the PyTorch 23.01 NGC container on NVIDIA DGX A100 (8x A100 80GB) GPUs.

 | GPUs | Batch size / GPU | Absolute error - TF32 | Absolute error - mixed precision | Time to train - TF32 | Time to train - mixed precision | Time to train speedup (mixed precision to TF32) |
 |:----:|:----------------:|:---------------------:|:--------------------------------:|:--------------------:|:-------------------------------:|:-----------------------------------------------:|
@@ -484,7 +484,7 @@ Our results were obtained by running the `scripts/train.sh` training script in t

 ##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)

-Our results were obtained by running the `scripts/train.sh` training script in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with (8x V100 16GB) GPUs.
+Our results were obtained by running the `scripts/train.sh` and `scripts/train_multi_gpu.sh` training scripts in the PyTorch 23.01 NGC container on NVIDIA DGX-1 with (8x V100 16GB) GPUs.

 | GPUs | Batch size / GPU | Absolute error - FP32 | Absolute error - mixed precision | Time to train - FP32 | Time to train - mixed precision | Time to train speedup (mixed precision to FP32) |
 |:----:|:----------------:|:---------------------:|:--------------------------------:|:--------------------:|:-------------------------------:|:-----------------------------------------------:|
@@ -497,29 +497,29 @@ Our results were obtained by running the `scripts/train.sh` training script in t

 ##### Training performance: NVIDIA DGX A100 (8x A100 80GB)

-Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 21.07 NGC container on NVIDIA DGX A100 with 8x A100 80GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.
+Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 23.01 NGC container on NVIDIA DGX A100 with 8x A100 80GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.

 | GPUs | Batch size / GPU | Throughput - TF32 [mol/ms] | Throughput - mixed precision [mol/ms] | Throughput speedup (mixed precision - TF32) | Weak scaling - TF32 | Weak scaling - mixed precision |
 |:----------------:|:-------------------:|:--------------------------:|:-------------------------------------:|:-------------------------------------------:|:-------------------:|:------------------------------:|
-| 1 | 240 | 2.61 | 3.35 | 1.28x | | |
-| 1 | 120 | 1.94 | 2.07 | 1.07x | | |
-| 8 | 240 | 18.80 | 23.90 | 1.27x | 7.20 | 7.13 |
-| 8 | 120 | 14.10 | 14.52 | 1.03x | 7.27 | 7.01 |
+| 1 | 240 | 2.59 | 3.23 | 1.25x | | |
+| 1 | 120 | 1.89 | 1.89 | 1.00x | | |
+| 8 | 240 | 18.38 | 21.42 | 1.17x | 7.09 | 6.63 |
+| 8 | 120 | 13.23 | 13.23 | 1.00x | 7.00 | 7.00 |


 To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).


 ##### Training performance: NVIDIA DGX-1 (8x V100 16GB)

-Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with 8x V100 16GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.
+Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 23.01 NGC container on NVIDIA DGX-1 with 8x V100 16GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.

 | GPUs | Batch size / GPU | Throughput - FP32 [mol/ms] | Throughput - mixed precision [mol/ms] | Throughput speedup (FP32 - mixed precision) | Weak scaling - FP32 | Weak scaling - mixed precision |
 |:----------------:|:--------------------:|:--------------------------:|:--------------------------------------:|:-------------------------------------------:|:-------------------:|:------------------------------:|
-| 1 | 240 | 1.33 | 2.12 | 1.59x | | |
-| 1 | 120 | 1.11 | 1.45 | 1.31x | | |
-| 8 | 240 | 9.32 | 13.40 | 1.44x | 7.01 | 6.32 |
-| 8 | 120 | 6.90 | 8.39 | 1.22x | 6.21 | 5.79 |
+| 1 | 240 | 1.23 | 1.91 | 1.55x | | |
+| 1 | 120 | 1.01 | 1.23 | 1.22x | | |
+| 8 | 240 | 8.44 | 11.28 | 1.34x | 6.8 | 5.90 |
+| 8 | 120 | 6.06 | 7.36 | 1.21x | 6.00 | 5.98 |


 To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
@@ -530,47 +530,47 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic

 ##### Inference performance: NVIDIA DGX A100 (1x A100 80GB)

-Our results were obtained by running the `scripts/benchmark_inference.sh` inferencing benchmarking script in the PyTorch 21.07 NGC container on NVIDIA DGX A100 with 1x A100 80GB GPU.
+Our results were obtained by running the `scripts/benchmark_inference.sh` inferencing benchmarking script in the PyTorch 23.01 NGC container on NVIDIA DGX A100 with 1x A100 80GB GPU.

 AMP

 | Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
 |:----------:|:-----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
-| 1600 | 13.54 | 121.44 | 118.07 | 119.00 | 366.64 |
-| 800 | 12.63 | 64.11 | 63.78 | 64.37 | 68.19 |
-| 400 | 10.65 | 37.97 | 39.02 | 39.67 | 42.87 |
+| 1600 | 9.71 | 175.2 | 190.2 | 191.8 | 432.4 |
+| 800 | 7.90 | 114.5 | 134.3 | 135.8 | 140.2 |
+| 400 | 7.18 | 75.49 | 108.6 | 109.6 | 113.2 |

 TF32

 | Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
 |:----------:|:-----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
-| 1600 | 8.97 | 180.85 | 178.31 | 178.92 | 375.33 |
-| 800 | 8.86 | 90.76 | 90.77 | 91.11 | 92.96 |
-| 400 | 8.49 | 47.42 | 47.65 | 48.15 | 50.74 |
+| 1600 | 8.19 | 198.2 | 206.8 | 208.5 | 377.0 |
+| 800 | 7.56 | 107.5 | 119.6 | 120.5 | 125.7 |
+| 400 | 6.97 | 59.8 | 75.1 | 75.7 | 81.3 |

 To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).



 ##### Inference performance: NVIDIA DGX-1 (1x V100 16GB)

-Our results were obtained by running the `scripts/benchmark_inference.sh` inferencing benchmarking script in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with 1x V100 16GB GPU.
+Our results were obtained by running the `scripts/benchmark_inference.sh` inferencing benchmarking script in the PyTorch 23.01 NGC container on NVIDIA DGX-1 with 1x V100 16GB GPU.

 AMP

 | Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
 |:----------:|:-----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
-| 1600 | 6.59 | 248.02 | 242.11 | 242.62 | 674.60 |
-| 800 | 6.38 | 126.49 | 125.96 | 126.31 | 127.72 |
-| 400 | 5.90 | 68.24 | 68.53 | 69.02 | 70.87 |
+| 1600 | 5.39 | 306.6 | 321.2 | 324.9 | 819.1 |
+| 800 | 4.67 | 179.8 | 201.5 | 203.8 | 213.3 |
+| 400 | 4.25 | 108.2 | 142.0 | 143.0 | 149.0 |

 FP32

 | Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
 |:----------:|:-----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
-| 1600 | 3.33 | 482.20 | 483.50 | 485.28 | 754.84 |
-| 800 | 3.35 | 239.09 | 242.21 | 243.13 | 244.91 |
-| 400 | 3.27 | 122.68 | 123.60 | 124.18 | 125.85 |
+| 1600 | 3.14 | 510.9 | 518.83 | 521.1 | 808.0 |
+| 800 | 3.10 | 258.7 | 269.4 | 271.1 | 278.9 |
+| 400 | 2.93 | 137.3 | 147.5 | 148.8 | 151.7 |


 To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
@@ -580,6 +580,10 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic

 ### Changelog

+February 2023:
+- Upgraded base container
+- Fixed benchmarking code
+
 August 2022:
 - Slight performance improvements
 - Upgraded base container
@@ -604,3 +608,4 @@ August 2021
 ### Known issues

 If you encounter `OSError: [Errno 12] Cannot allocate memory` during the Dataloader iterator creation (more precisely during the `fork()`), this is most likely due to the use of the `--precompute_bases` flag. If you cannot add more RAM or Swap to your machine, it is recommended to turn off bases precomputation by removing the `--precompute_bases` flag or using `--precompute_bases false`.
+

DGLPyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train.sh

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ AMP=${2:-true}
 CUDA_VISIBLE_DEVICES=0 python -m se3_transformer.runtime.training \
     --amp "$AMP" \
     --batch_size "$BATCH_SIZE" \
-    --epochs 6 \
+    --epochs 16 \
     --use_layer_norm \
     --norm \
     --save_ckpt_path model_qm9.pth \

DGLPyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train_multi_gpu.sh

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0
     se3_transformer.runtime.training \
     --amp "$AMP" \
     --batch_size "$BATCH_SIZE" \
-    --epochs 6 \
+    --epochs 16 \
     --use_layer_norm \
     --norm \
     --save_ckpt_path model_qm9.pth \

DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/convolution.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def __init__(
             nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
         ]

-        self.net = nn.Sequential(*[m for m in modules if m is not None])
+        self.net = torch.jit.script(nn.Sequential(*[m for m in modules if m is not None]))

     def forward(self, features: Tensor) -> Tensor:
         return self.net(features)
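The only functional change here is wrapping the small MLP built in `__init__` in `torch.jit.script`, so TorchScript can compile and optimize it independently of the surrounding module. A self-contained sketch of the same pattern, with made-up layer sizes rather than the model's real hyperparameters:

```python
import torch
import torch.nn as nn

# Illustrative dimensions only; the real values come from the SE(3)-Transformer config.
mid_dim, num_freq, channels_in, channels_out = 32, 5, 16, 16

modules = [
    nn.Linear(8, mid_dim),
    nn.ReLU(),
    nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False),
    None,  # optional entries are filtered out, as in the original list comprehension
]
# Scripting the Sequential compiles its forward pass once; behavior is unchanged.
net = torch.jit.script(nn.Sequential(*[m for m in modules if m is not None]))

print(net(torch.randn(4, 8)).shape)  # torch.Size([4, 1280])
```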

DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/norm.py

Lines changed: 13 additions & 4 deletions
@@ -32,6 +32,15 @@
 from se3_transformer.model.fiber import Fiber


+@torch.jit.script
+def clamped_norm(x, clamp: float):
+    return x.norm(p=2, dim=-1, keepdim=True).clamp(min=clamp)
+
+@torch.jit.script
+def rescale(x, norm, new_norm):
+    return x / norm * new_norm
+
+
 class NormSE3(nn.Module):
     """
     Norm-based SE(3)-equivariant nonlinearity.
@@ -63,7 +72,7 @@ def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Ten
         output = {}
         if hasattr(self, 'group_norm'):
             # Compute per-degree norms of features
-            norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
+            norms = [clamped_norm(features[str(d)], self.NORM_CLAMP)
                      for d in self.fiber.degrees]
             fused_norms = torch.cat(norms, dim=-2)

@@ -73,11 +82,11 @@ def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Ten

             # Scale features to the new norms
             for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
-                output[str(d)] = features[str(d)] / norm * new_norm
+                output[str(d)] = rescale(features[str(d)], norm, new_norm)
         else:
             for degree, feat in features.items():
-                norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
+                norm = clamped_norm(feat, self.NORM_CLAMP)
                 new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
-                output[degree] = new_norm * feat / norm
+                output[degree] = rescale(new_norm, feat, norm)

         return output
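The norm/clamp and divide/multiply expressions are factored into two scripted free functions so TorchScript can fuse the elementwise work. A standalone check of what the helpers compute; the clamp value `2 ** -24` is an assumption standing in for `NormSE3.NORM_CLAMP`, which this diff does not show:

```python
import torch

@torch.jit.script
def clamped_norm(x, clamp: float):
    return x.norm(p=2, dim=-1, keepdim=True).clamp(min=clamp)

@torch.jit.script
def rescale(x, norm, new_norm):
    return x / norm * new_norm

NORM_CLAMP = 2 ** -24          # assumed clamp constant, see lead-in
feat = torch.randn(2, 8, 3)    # toy feature tensor; the norm is taken over the last dim

norm = clamped_norm(feat, NORM_CLAMP)
# p=2 is the default, so the scripted helper matches the original eager expression
print(torch.allclose(norm, feat.norm(dim=-1, keepdim=True).clamp(min=NORM_CLAMP)))  # True

new_norm = torch.rand_like(norm)
print(torch.allclose(rescale(feat, norm, new_norm), feat / norm * new_norm))        # True
```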

DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/callbacks.py

Lines changed: 2 additions & 1 deletion
@@ -133,6 +133,7 @@ def __init__(self, logger, batch_size: int, warmup_epochs: int = 1, mode: str =

     def on_batch_start(self):
         if self.epoch >= self.warmup_epochs:
+            torch.cuda.synchronize()
             self.timestamps.append(time.time() * 1000.0)

     def _log_perf(self):
@@ -153,7 +154,7 @@ def on_fit_end(self):
     def process_performance_stats(self):
         timestamps = np.asarray(self.timestamps)
         deltas = np.diff(timestamps)
-        throughput = (self.batch_size / deltas).mean()
+        throughput = self.batch_size / deltas.mean()
         stats = {
             f"throughput_{self.mode}": throughput,
             f"latency_{self.mode}_mean": deltas.mean(),
