Commit 777d174

IzzyPutterman authored and nv-kkudrynski committed

[TFT/PyTorch] Move to nvFuser

1 parent e040735 commit 777d174

File tree: 16 files changed (+348, -204 lines)

PyTorch/Forecasting/TFT/Dockerfile (file mode 100644 → 100755)

Lines changed: 2 additions & 1 deletion

@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.12-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:22.11-py3
+
 FROM ${FROM_IMAGE_NAME}
 
 # Set workdir and python path

PyTorch/Forecasting/TFT/Dockerfile-triton

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.12-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:22.11-py3
 FROM ${FROM_IMAGE_NAME}
 
 # Ensure apt-get won't prompt for selecting options

PyTorch/Forecasting/TFT/README.md

Lines changed: 43 additions & 41 deletions

@@ -123,9 +123,6 @@ For information about:
 Training of Deep Neural
 Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/)
 blog.
-* APEX tools for mixed precision training, refer to the [NVIDIA Apex: Tools for Easy Mixed-Precision Training in
-PyTorch](https://devblogs.nvidia.com/apex-pytorch-easy-mixed-precision-training/)
-.
 
 
 #### Enabling mixed precision
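With Apex removed, mixed precision in this codebase now rests on native `torch.cuda.amp`. A minimal sketch of that pattern; the tiny model and random tensors below are placeholders, not the repository's actual trainer:

```python
import torch
import torch.nn as nn

# Native-AMP training step: the pattern that replaces Apex's amp.initialize.
model = nn.Linear(128, 3).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(1024, 128, device='cuda')
y = torch.randn(1024, 3, device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast():          # forward pass runs in mixed precision
    loss = nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()            # loss scaling guards against fp16 underflow
scaler.step(optimizer)
scaler.update()
```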
@@ -169,7 +166,7 @@ The following section lists the requirements that you need to meet in order to s
 
 This repository contains Dockerfile, which extends the PyTorch NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
 - [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
-- [PyTorch 21.12 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch)
+- [PyTorch 22.11 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch)
 - Supported GPUs:
 - [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
 - [NVIDIA Turing architecture](https://www.nvidia.com/en-us/design-visualization/technologies/turing-architecture/)
@@ -371,7 +368,7 @@ The [NVIDIA Triton Inference Server](https://github.com/triton-inference-server/
 
 ### Benchmarking
 
-The following section shows how to run benchmarks measuring the model performance in training and inference modes.
+The following section shows how to run benchmarks measuring the model performance in training and inference modes. Note that the first three steps of each epoch are excluded from the throughput and latency calculations, because nvFuser performs its optimizations on the third step of the first epoch, causing a multi-second pause.
 
 #### Training performance benchmark
 
@@ -390,24 +387,24 @@ We conducted an extensive hyperparameter search along with stability tests. The
 
 ##### Training accuracy: NVIDIA DGX A100 (8x A100 80GB)
 
-Our results were obtained by running the `train.sh` training script in the [PyTorch 21.06 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA A100 (8x A100 80GB) GPUs.
+Our results were obtained by running the `train.sh` training script in the [PyTorch 22.11 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA A100 (8x A100 80GB) GPUs.
 
 | Dataset | GPUs | Batch size / GPU | Accuracy - TF32 | Accuracy - mixed precision | Time to train - TF32 | Time to train - mixed precision | Time to train speedup (TF32 to mixed precision)
 |-------------|---|------|-----------------------|-----------------------|-------|-------|-------
-| Electricity | 8 | 1024 | 0.027 / 0.057 / 0.029 | 0.028 / 0.057 / 0.029 | 216s | 176s | 1.227x
-| Traffic | 8 | 1024 | 0.043 / 0.108 / 0.079 | 0.042 / 0.107 / 0.078 | 151s | 126s | 1.198x
+| Electricity | 8 | 1024 | 0.026 / 0.056 / 0.029 | 0.028 / 0.058 / 0.029 | 200s | 176s | 1.136x
+| Traffic | 8 | 1024 | 0.044 / 0.108 / 0.078 | 0.044 / 0.109 / 0.079 | 140s | 129s | 1.085x
 
 
 ##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
 
-Our results were obtained by running the `train.sh` training script in the [PyTorch 21.06 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX-1 with (8x V100 16GB) GPUs.
+Our results were obtained by running the `train.sh` training script in the [PyTorch 22.11 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX-1 (8x V100 16GB) GPUs.
 
 | Dataset | GPUs | Batch size / GPU | Accuracy - FP32 | Accuracy - mixed precision | Time to train - FP32 | Time to train - mixed precision | Time to train speedup (FP32 to mixed precision)
 |-------------|---|------|-----------------------|-----------------------|-------|-------|-----------
-| Electricity | 8 | 1024 | 0.028 / 0.057 / 0.029 | 0.027 / 0.057 / 0.029 | 381s | 261s | 1.460x
-| Traffic | 8 | 1024 | 0.042 / 0.106 / 0.076 | 0.040 / 0.103 / 0.074 | 256s | 176s | 1.455x
+| Electricity | 8 | 1024 | 0.028 / 0.057 / 0.028 | 0.027 / 0.059 / 0.030 | 371s | 269s | 1.379x
+| Traffic | 8 | 1024 | 0.042 / 0.110 / 0.080 | 0.043 / 0.109 / 0.080 | 251s | 191s | 1.314x
 
 
@@ -417,22 +414,22 @@ In order to get a greater picture of the model’s accuracy, we performed a hype
 
 | Dataset | #GPU | Hidden size | #Heads | Local BS | LR | Gradient clipping | Dropout | Mean q-risk | Std q-risk | Min q-risk | Max q-risk
 |-------------|------|-------------|--------|----------|------|-------------------|---------|-------------|------------|------------|------
-| Electricity | 8 | 128 | 4 | 1024 | 1e-3 | 0.0 | 0.1 | 0.1131 | 0.0025 | 0.1080 | 0.1200
-| Traffic | 8 | 128 | 4 | 1024 | 1e-3 | 0.0 | 0.3 | 0.2180 | 0.0049 | 0.2069 | 0.2336
+| Electricity | 8 | 128 | 4 | 1024 | 1e-3 | 0.0 | 0.1 | 0.1129 | 0.0025 | 0.1074 | 0.1244
+| Traffic | 8 | 128 | 4 | 1024 | 1e-3 | 0.0 | 0.3 | 0.2262 | 0.0027 | 0.2207 | 0.2331
 
 
 #### Training performance results
 
 ##### Training performance: NVIDIA DGX A100 (8x A100 80GB)
 
-Our results were obtained by running the `train.sh` training script in the [PyTorch 21.06 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA A100 (8x A100 80GB) GPUs. Performance numbers (in items/images per second) were averaged over an entire training epoch.
+Our results were obtained by running the `train.sh` training script in the [PyTorch 22.11 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA A100 (8x A100 80GB) GPUs. Performance numbers (in items per second) were averaged over an entire training epoch.
 
 | Dataset | GPUs | Batch size / GPU | Throughput - TF32 | Throughput - mixed precision | Throughput speedup (TF32 - mixed precision) | Weak scaling - TF32 | Weak scaling - mixed precision
 |-------------|---|------|--------|--------|-------|-------|-----
-| Electricity | 1 | 1024 | 10173 | 13703 | 1.35x | 1 | 1
-| Electricity | 8 | 1024 | 80596 | 107761 | 1.34x | 7.92x | 7.86x
-| Traffic | 1 | 1024 | 10197 | 13779 | 1.35x | 1 | 1
-| Traffic | 8 | 1024 | 80692 | 107979 | 1.34x | 7.91x | 7.84x
+| Electricity | 1 | 1024 | 12435 | 17608 | 1.42x | 1 | 1
+| Electricity | 8 | 1024 | 94389 | 130769 | 1.39x | 7.59x | 7.42x
+| Traffic | 1 | 1024 | 12509 | 17591 | 1.40x | 1 | 1
+| Traffic | 8 | 1024 | 94476 | 130992 | 1.39x | 7.55x | 7.45x
 
 
 To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
@@ -442,14 +439,14 @@ The performance metrics used were items per second.
 
 ##### Training performance: NVIDIA DGX-1 (8x V100 16GB)
 
-Our results were obtained by running the `train.sh` training script in the [PyTorch 21.06 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX-1 with (8x V100 16GB) GPUs. Performance numbers (in items/images per second) were averaged over an entire training epoch.
+Our results were obtained by running the `train.sh` training script in the [PyTorch 22.11 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX-1 (8x V100 16GB) GPUs. Performance numbers (in items per second) were averaged over an entire training epoch.
 
 | Dataset | GPUs | Batch size / GPU | Throughput - FP32 | Throughput - mixed precision | Throughput speedup (FP32 - mixed precision) | Weak scaling - FP32 | Weak scaling - mixed precision
 |-------------|---|------|-------|-------|-------|------|----
-| Electricity | 1 | 1024 | 5580 | 9148 | 1.64x | 1 | 1
-| Electricity | 8 | 1024 | 43351 | 69855 | 1.61x | 7.77x | 7.64x
-| Traffic | 1 | 1024 | 5593 | 9194 | 1.64x | 1 | 1
-| Traffic | 8 | 1024 | 43426 | 69983 | 1.61x | 7.76x | 7.61x
+| Electricity | 1 | 1024 | 5932 | 10163 | 1.71x | 1 | 1
+| Electricity | 8 | 1024 | 45566 | 75660 | 1.66x | 7.68x | 7.44x
+| Traffic | 1 | 1024 | 5971 | 10166 | 1.70x | 1 | 1
+| Traffic | 8 | 1024 | 45925 | 75640 | 1.64x | 7.69x | 7.44x
 
 
@@ -463,39 +460,44 @@ The performance metrics used were items per second.
 
 ##### Inference Performance: NVIDIA DGX A100
 
-Our results were obtained by running the `inference.py` script in the [PyTorch 21.12 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX A100. Throughput is measured in items per second and latency is measured in milliseconds.
+Our results were obtained by running the `inference.py` script in the [PyTorch 22.11 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX A100. Throughput is measured in items per second and latency is measured in milliseconds.
 To benchmark the inference performance on a specific batch size and dataset, run the `inference.py` script.
 | Dataset | GPUs | Batch size / GPU | Throughput - mixed precision (item/s) | Average Latency (ms) | Latency p90 (ms) | Latency p95 (ms) | Latency p99 (ms)
 |-------------|--------|-----|---------------------------------|-----------------|-------------|-------------|------------
-| Electricity | 1 | 1 | 144.37 | 6.93 | 7.00 | 7.04 | 7.25
-| Electricity | 1 | 2 | 277.53 | 7.21 | 7.25 | 7.27 | 7.48
-| Electricity | 1 | 4 | 564.37 | 7.09 | 7.13 | 7.15 | 7.64
-| Electricity | 1 | 8 | 1399.25 | 5.72 | 5.71 | 5.77 | 7.51
-| Traffic | 1 | 1 | 145.26 | 6.88 | 6.91 | 6.95 | 7.60
-| Traffic | 1 | 2 | 277.97 | 7.19 | 7.28 | 7.30 | 7.46
-| Traffic | 1 | 4 | 563.05 | 7.10 | 7.14 | 7.16 | 7.42
-| Traffic | 1 | 8 | 1411.62 | 5.67 | 5.69 | 5.79 | 6.21
+| Electricity | 1 | 1 | 272.43 | 3.67 | 3.70 | 3.87 | 4.18
+| Electricity | 1 | 2 | 518.13 | 3.86 | 3.88 | 3.93 | 4.19
+| Electricity | 1 | 4 | 1039.31 | 3.85 | 3.89 | 3.97 | 4.15
+| Electricity | 1 | 8 | 2039.54 | 3.92 | 3.93 | 3.95 | 4.32
+| Traffic | 1 | 1 | 269.59 | 3.71 | 3.74 | 3.79 | 4.30
+| Traffic | 1 | 2 | 518.73 | 3.86 | 3.78 | 3.91 | 4.66
+| Traffic | 1 | 4 | 1021.49 | 3.92 | 3.94 | 3.95 | 4.25
+| Traffic | 1 | 8 | 2005.54 | 3.99 | 4.01 | 4.03 | 4.39
 
 
 ##### Inference Performance: NVIDIA DGX-1 V100
 
-Our results were obtained by running the `inference.py` script in the [PyTorch 21.12 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX-1 V100. Throughput is measured in items per second and latency is measured in milliseconds.
+Our results were obtained by running the `inference.py` script in the [PyTorch 22.11 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX-1 V100. Throughput is measured in items per second and latency is measured in milliseconds.
 To benchmark the inference performance on a specific batch size and dataset, run the `inference.py` script.
 | Dataset | GPUs | Batch size / GPU | Throughput - mixed precision (item/s) | Average Latency (ms) | Latency p90 (ms) | Latency p95 (ms) | Latency p99 (ms)
 |-------------|--------|-----|---------------------------------|-----------------|-------------|-------------|------------
-| Electricity | 1 | 1 | 95.65 | 10.45 | 11.30 | 11.95 | 12.13
-| Electricity | 1 | 2 | 193.15 | 10.35 | 10.80 | 11.46 | 12.16
-| Electricity | 1 | 4 | 381.09 | 10.49 | 10.75 | 12.29 | 12.41
-| Electricity | 1 | 8 | 805.49 | 9.93 | 10.41 | 10.48 | 10.91
-| Traffic | 1 | 1 | 96.72 | 10.34 | 10.53 | 11.99 | 12.13
-| Traffic | 1 | 2 | 192.93 | 10.37 | 10.80 | 11.97 | 12.12
-| Traffic | 1 | 4 | 379.00 | 10.55 | 10.88 | 11.09 | 11.96
-| Traffic | 1 | 8 | 859.69 | 9.30 | 10.58 | 10.65 | 11.28
+| Electricity | 1 | 1 | 171.68 | 5.82 | 5.99 | 6.17 | 7.00
+| Electricity | 1 | 2 | 318.92 | 6.27 | 6.43 | 6.60 | 7.51
+| Electricity | 1 | 4 | 684.79 | 5.84 | 6.02 | 6.08 | 6.47
+| Electricity | 1 | 8 | 1275.54 | 6.27 | 7.31 | 7.36 | 7.51
+| Traffic | 1 | 1 | 183.39 | 5.45 | 5.64 | 5.86 | 6.73
+| Traffic | 1 | 2 | 340.73 | 5.87 | 6.07 | 6.77 | 7.25
+| Traffic | 1 | 4 | 647.33 | 6.18 | 6.35 | 7.99 | 8.07
+| Traffic | 1 | 8 | 1364.39 | 5.86 | 6.07 | 6.40 | 7.31
 ## Release notes
 The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA’s latest software release. For the most up-to-date performance measurements, go to https://developer.nvidia.com/deep-learning-performance-training-inference.
 
 ### Changelog
 
+March 2023
+- 23.01 Container Update
+- Switch from NVIDIA Apex AMP and NVIDIA Apex FusedLayerNorm to Native PyTorch AMP and Native PyTorch LayerNorm
+- Acceleration using nvFuser
+
 February 2022
 - 21.12 Container Update
 - Triton Inference Performance Numbers
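The changelog's nvFuser line refers to TorchScript's CUDA fuser: in the PyTorch builds shipped in these containers, nvFuser is the default fuser for scripted modules, so the model benefits once it is compiled with `torch.jit.script`. A hedged sketch of that flow, with `TemporalFusionTransformer`, `config`, and `batch` standing in for the repository's real objects:

```python
import torch

# Placeholders: the model class, its config, and a real input batch come from
# the repository; this sketch only shows where nvFuser enters the picture.
model = TemporalFusionTransformer(config).cuda().eval()
model = torch.jit.script(model)   # TorchScript IR; nvFuser fuses pointwise chains on CUDA

with torch.no_grad(), torch.cuda.amp.autocast():
    out = model(batch)            # the first few calls pause while kernels are generated
```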

PyTorch/Forecasting/TFT/configuration.py

Lines changed: 1 addition & 1 deletion

@@ -124,5 +124,5 @@ def __init__(self):
 
 
 CONFIGS = {'electricity': ElectricityConfig,
-          'traffic': TrafficConfig,
+           'traffic': TrafficConfig,
           }
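For context, `CONFIGS` maps a dataset name to a configuration class, so callers instantiate on lookup. A one-line usage sketch:

```python
from configuration import CONFIGS

config = CONFIGS['traffic']()  # look up the class, then instantiate it
```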

PyTorch/Forecasting/TFT/criterions.py

Lines changed: 9 additions & 0 deletions

@@ -15,6 +15,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import numpy as np
 
 class QuantileLoss(nn.Module):
     def __init__(self, config):
@@ -26,3 +27,11 @@ def forward(self, predictions, targets):
         ql = (1-self.q)*F.relu(diff) + self.q*F.relu(-diff)
         losses = ql.view(-1, ql.shape[-1]).mean(0)
         return losses
+
+def qrisk(pred, tgt, quantiles):
+    diff = pred - tgt
+    ql = (1-quantiles)*np.clip(diff, 0, float('inf')) + quantiles*np.clip(-diff, 0, float('inf'))
+    losses = ql.reshape(-1, ql.shape[-1])
+    normalizer = np.abs(tgt).mean()
+    risk = 2 * losses / normalizer
+    return risk.mean(0)
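The new `qrisk` helper computes the normalized quantile risk reported in the README tables: since `losses` and `normalizer` average over the same number of elements, `risk.mean(0)` equals twice the summed quantile loss divided by the summed absolute target, per quantile. A NumPy usage sketch with made-up shapes:

```python
import numpy as np

# Hypothetical shapes: 256 windows, 24 horizon steps, quantiles (0.1, 0.5, 0.9).
pred = np.random.rand(256, 24, 3)      # one prediction per quantile
tgt = np.random.rand(256, 24, 1)       # broadcasts across the quantile axis
quantiles = np.array([0.1, 0.5, 0.9])

print(qrisk(pred, tgt, quantiles))     # per-quantile normalized risk, shape (3,)
```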

PyTorch/Forecasting/TFT/data_utils.py

Lines changed: 47 additions & 2 deletions

@@ -41,7 +41,8 @@
 from bisect import bisect
 
 import torch
-from torch.utils.data import Dataset,IterableDataset,DataLoader
+from torch.utils.data import Dataset, IterableDataset, DataLoader, DistributedSampler, RandomSampler
+from torch.utils.data.dataloader import default_collate
 
 class DataTypes(enum.IntEnum):
     """Defines numerical types of each column."""
@@ -401,6 +402,51 @@ def sample_data(dataset, num_samples):
     else:
         return torch.utils.data.Subset(dataset, np.random.choice(np.arange(len(dataset)), size=num_samples, replace=False))
 
+def load_dataset(args, config, collate_fn=default_collate):
+    from utils import print_once
+    train_split = TFTBinaryDataset(os.path.join(args.data_path, 'train.bin'), config)
+    train_split = sample_data(train_split, args.sample_data[0])
+    if args.distributed_world_size > 1:
+        data_sampler = DistributedSampler(train_split, args.distributed_world_size, args.distributed_rank, seed=args.seed + args.distributed_rank, drop_last=True)
+    else:
+        data_sampler = RandomSampler(train_split)
+    train_loader = DataLoader(train_split,
+                              batch_size=args.batch_size,
+                              num_workers=4,
+                              sampler=data_sampler,
+                              collate_fn=collate_fn,
+                              pin_memory=True)
+
+    valid_split = TFTBinaryDataset(os.path.join(args.data_path, 'valid.bin'), config)
+    valid_split = sample_data(valid_split, args.sample_data[1])
+    if args.distributed_world_size > 1:
+        data_sampler = DistributedSampler(valid_split, args.distributed_world_size, args.distributed_rank, shuffle=False, drop_last=False)
+    else:
+        data_sampler = None
+    valid_loader = DataLoader(valid_split,
+                              batch_size=args.batch_size,
+                              sampler=data_sampler,
+                              num_workers=4,
+                              collate_fn=collate_fn,
+                              pin_memory=True)
+
+    test_split = TFTBinaryDataset(os.path.join(args.data_path, 'test.bin'), config)
+    if args.distributed_world_size > 1:
+        data_sampler = DistributedSampler(test_split, args.distributed_world_size, args.distributed_rank, shuffle=False, drop_last=False)
+    else:
+        data_sampler = None
+    test_loader = DataLoader(test_split,
+                             batch_size=args.batch_size,
+                             sampler=data_sampler,
+                             num_workers=4,
+                             collate_fn=collate_fn,
+                             pin_memory=True)
+
+    print_once(f'Train split length: {len(train_split)}')
+    print_once(f'Valid split length: {len(valid_split)}')
+    print_once(f'Test split length: {len(test_split)}')
+
+    return train_loader, valid_loader, test_loader
 
 def standarize_electricity(path):
     """Code taken from https://github.com/google-research/google-research/blob/master/tft/script_download_data.py"""
@@ -574,4 +620,3 @@ def read_matrix(filename):
 
     flat_df.to_csv(os.path.join(path, 'standarized.csv'))
 
-
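Taken together, `load_dataset` turns the three binary splits into ready DataLoaders. A hedged invocation sketch: the argument namespace below is inferred from the fields the function reads (the real entry point builds an equivalent object with argparse), and the sentinel values are assumptions:

```python
from types import SimpleNamespace

from configuration import CONFIGS
from data_utils import load_dataset

# Hypothetical arguments; field names come from load_dataset's body.
args = SimpleNamespace(
    data_path='/data/processed/electricity_bin',  # assumed to hold train/valid/test.bin
    batch_size=1024,
    sample_data=[-1, -1],        # presumably "keep the full split" sentinels
    distributed_world_size=1,    # single-process run -> RandomSampler branch
    distributed_rank=0,
    seed=1,
)
config = CONFIGS['electricity']()
train_loader, valid_loader, test_loader = load_dataset(args, config)
```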
