Skip to content

Commit a07a807

Browse files
IzzyPuttermannv-kkudrynski
authored andcommitted
[TFT/PyTorch] Adding TFT to Torchhub
1 parent 7475648 commit a07a807

File tree

3 files changed

+101
-1
lines changed

3 files changed

+101
-1
lines changed

PyTorch/Forecasting/TFT/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def forward(self, x: Tensor, mask_future_timesteps: bool = True) -> Tuple[Tensor
265265
out = self.out_proj(m_attn_vec)
266266
out = self.out_dropout(out)
267267

268-
return out, attn_vec
268+
return out, attn_prob
269269

270270

271271

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import sys
17+
import urllib.request
18+
from zipfile import ZipFile
19+
import torch
20+
from torch.utils.data import DataLoader
21+
NGC_CHECKPOINT_URLS = {}
22+
NGC_CHECKPOINT_URLS["electricity"] = "https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_eletricity_amp/versions/21.06.0/zip"
23+
NGC_CHECKPOINT_URLS["traffic"] = "https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_traffic_amp/versions/21.06.0/zip"
24+
25+
26+
def _download_checkpoint(checkpoint, force_reload):
27+
model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
28+
if not os.path.exists(model_dir):
29+
os.makedirs(model_dir)
30+
ckpt_file = os.path.join(model_dir, os.path.basename(checkpoint))
31+
if not os.path.exists(ckpt_file) or force_reload:
32+
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
33+
urllib.request.urlretrieve(checkpoint, ckpt_file)
34+
with ZipFile(ckpt_file, "r") as zf:
35+
zf.extractall(path=model_dir)
36+
return os.path.join(model_dir, "checkpoint.pt")
37+
38+
def nvidia_tft(pretrained=True, **kwargs):
39+
from .modeling import TemporalFusionTransformer
40+
"""Constructs a TFT model.
41+
For detailed information on model input and output, training recipies, inference and performance
42+
visit: github.com/NVIDIA/DeepLearningExamples and/or ngc.nvidia.com
43+
Args (type[, default value]):
44+
pretrained (bool, True): If True, returns a pretrained model.
45+
dataset (str, 'electricity'): loads selected model type electricity or traffic. Defaults to electricity
46+
"""
47+
ds_type = kwargs.get("dataset", "electricity")
48+
ckpt = _download_checkpoint(NGC_CHECKPOINT_URLS[ds_type], True)
49+
state_dict = torch.load(ckpt)
50+
config = state_dict['config']
51+
52+
model = TemporalFusionTransformer(config)
53+
if pretrained:
54+
model.load_state_dict(state_dict['model'])
55+
model.eval()
56+
return model
57+
58+
def nvidia_tft_data_utils(**kwargs):
59+
60+
from .data_utils import TFTDataset
61+
from .configuration import ElectricityConfig
62+
class Processing:
63+
@staticmethod
64+
def download_data(path):
65+
if not os.path.exists(os.path.join(path, "raw")):
66+
os.makedirs(os.path.join(path, "raw"), exist_ok=True)
67+
dataset_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00321/LD2011_2014.txt.zip"
68+
ckpt_file = os.path.join(path, "raw/electricity.zip")
69+
if not os.path.exists(ckpt_file):
70+
sys.stderr.write('Downloading checkpoint from {}\n'.format(dataset_url))
71+
urllib.request.urlretrieve(dataset_url, ckpt_file)
72+
with ZipFile(ckpt_file, "r") as zf:
73+
zf.extractall(path=os.path.join(path, "raw/electricity/"))
74+
75+
@staticmethod
76+
def preprocess(path):
77+
config = ElectricityConfig()
78+
if not os.path.exists(os.path.join(path, "processed")):
79+
os.makedirs(os.path.join(path, "processed"), exist_ok=True)
80+
from data_utils import standarize_electricity as standarize
81+
from data_utils import preprocess
82+
standarize(os.path.join(path, "raw/electricity"))
83+
preprocess(os.path.join(path, "raw/electricity/standarized.csv"), os.path.join(path, "processed/electricity_bin/"), config)
84+
85+
86+
@staticmethod
87+
def get_batch(path):
88+
config = ElectricityConfig()
89+
test_split = TFTDataset(os.path.join(path, "processed/electricity_bin/", "test.csv"), config)
90+
data_loader = DataLoader(test_split, batch_size=16, num_workers=0)
91+
for i, batch in enumerate(data_loader):
92+
if i == 40:
93+
break
94+
return batch
95+
96+
return Processing()
97+

hubconf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,6 @@
2424
from PyTorch.SpeechSynthesis.Tacotron2.tacotron2 import nvidia_tts_utils
2525
from PyTorch.SpeechSynthesis.Tacotron2.waveglow import nvidia_waveglow
2626
sys.path.append(os.path.join(sys.path[0], 'PyTorch/SpeechSynthesis/Tacotron2'))
27+
28+
from PyTorch.Forecasting.TFT.tft_torchhub import nvidia_tft, nvidia_tft_data_utils
29+
sys.path.append(os.path.join(sys.path[0], 'PyTorch/Forecasting/TFT'))

0 commit comments

Comments
 (0)