-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
Copy pathtft_torchhub.py
95 lines (85 loc) · 4.22 KB
/
tft_torchhub.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import urllib.request
from zipfile import ZipFile
import torch
from torch.utils.data import DataLoader
# NGC download URLs of the zipped TFT checkpoints, keyed by dataset name.
NGC_CHECKPOINT_URLS = {
    "electricity": "https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-electricity/versions/22.11.0_amp/zip",
    "traffic": "https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-traffic/versions/22.11.0_amp/zip",
}
def _download_checkpoint(checkpoint, force_reload):
model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
if not os.path.exists(model_dir):
os.makedirs(model_dir)
ckpt_file = os.path.join(model_dir, os.path.basename(checkpoint))
if not os.path.exists(ckpt_file) or force_reload:
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
urllib.request.urlretrieve(checkpoint, ckpt_file)
with ZipFile(ckpt_file, "r") as zf:
zf.extractall(path=model_dir)
return os.path.join(model_dir, "checkpoint.pt")
def nvidia_tft(pretrained=True, **kwargs):
    """Constructs a TFT model.

    For detailed information on model input and output, training recipes, inference and performance
    visit: github.com/NVIDIA/DeepLearningExamples and/or ngc.nvidia.com

    The checkpoint archive is downloaded even when ``pretrained`` is False,
    because the model configuration is read from the checkpoint file.

    Args (type[, default value]):
        pretrained (bool, True): If True, returns a pretrained model.
        dataset (str, 'electricity'): loads selected model type electricity or traffic. Defaults to electricity

    Returns:
        TemporalFusionTransformer: The model, switched to eval mode.

    Raises:
        KeyError: If ``dataset`` is not a key of ``NGC_CHECKPOINT_URLS``.
    """
    from .modeling import TemporalFusionTransformer

    ds_type = kwargs.get("dataset", "electricity")
    if ds_type not in NGC_CHECKPOINT_URLS:
        # Same exception type as a plain dict lookup, but with a usable message.
        raise KeyError(
            "Unknown dataset {!r}; expected one of {}".format(ds_type, sorted(NGC_CHECKPOINT_URLS))
        )

    # The download is always forced: the cached archive is named after the
    # last URL component, which is "zip" for every entry, so a cached file
    # could belong to the other dataset.
    ckpt = _download_checkpoint(NGC_CHECKPOINT_URLS[ds_type], True)

    # map_location="cpu" lets the checkpoint load on hosts without CUDA; the
    # weights are copied into the freshly built model by load_state_dict, so
    # the device of the returned model does not change.
    # NOTE(review): torch.load unpickles the file, which is only safe for
    # trusted sources. Recent torch releases default to weights_only=True,
    # which may reject the pickled 'config' entry -- confirm against the
    # torch version in use.
    state_dict = torch.load(ckpt, map_location="cpu")
    config = state_dict['config']

    model = TemporalFusionTransformer(config)
    if pretrained:
        model.load_state_dict(state_dict['model'])
    model.eval()
    return model
def nvidia_tft_data_utils(**kwargs):
    """Builds a helper for downloading and preparing the electricity dataset.

    Args:
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        Processing: An object exposing the static methods ``download_data``,
        ``preprocess`` and ``get_batch``. Each takes the dataset root ``path``
        and works on ``<path>/raw`` and ``<path>/processed``.
    """
    from .data_utils import TFTDataset
    from .configuration import ElectricityConfig

    class Processing:
        """Download, preprocessing and batch-sampling helpers for the electricity dataset."""

        @staticmethod
        def download_data(path):
            """Fetch the electricity archive into ``<path>/raw`` and unpack it.

            The archive is downloaded only when ``<path>/raw/electricity.zip``
            does not exist yet; it is extracted on every call.
            """
            os.makedirs(os.path.join(path, "raw"), exist_ok=True)

            dataset_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00321/LD2011_2014.txt.zip"
            ckpt_file = os.path.join(path, "raw/electricity.zip")
            if not os.path.exists(ckpt_file):
                sys.stderr.write('Downloading dataset from {}\n'.format(dataset_url))
                # Move the file into place only after a complete download, so
                # an interrupted transfer cannot be mistaken for a cached
                # archive on the next call.
                partial_file = ckpt_file + ".part"
                try:
                    urllib.request.urlretrieve(dataset_url, partial_file)
                    os.replace(partial_file, ckpt_file)
                finally:
                    if os.path.exists(partial_file):
                        os.remove(partial_file)

            with ZipFile(ckpt_file, "r") as zf:
                zf.extractall(path=os.path.join(path, "raw/electricity/"))

        @staticmethod
        def preprocess(path):
            """Standardize the raw data and write the processed dataset.

            Reads ``<path>/raw/electricity`` and writes to
            ``<path>/processed/electricity_bin/``.
            """
            config = ElectricityConfig()
            os.makedirs(os.path.join(path, "processed"), exist_ok=True)

            # NOTE(review): these are absolute imports, unlike the relative
            # ``.data_utils`` import used for TFTDataset, so they require the
            # directory holding data_utils.py to be on sys.path. Confirm
            # whether that is intentional before unifying them.
            from data_utils import standarize_electricity as standarize
            from data_utils import preprocess

            standarize(os.path.join(path, "raw/electricity"))
            preprocess(
                os.path.join(path, "raw/electricity/standarized.csv"),
                os.path.join(path, "processed/electricity_bin/"),
                config,
            )

        @staticmethod
        def get_batch(path):
            """Return one batch of 16 samples from the processed test split.

            The batch at index 40 is returned, or the last available batch
            when the split yields fewer than 41 batches.

            Raises:
                RuntimeError: If the test split yields no batches.
            """
            config = ElectricityConfig()
            test_split = TFTDataset(os.path.join(path, "processed/electricity_bin/", "test.csv"), config)
            data_loader = DataLoader(test_split, batch_size=16, num_workers=0)

            batch = None
            for i, batch in enumerate(data_loader):
                if i == 40:
                    break
            if batch is None:
                # An empty loader would otherwise surface as an UnboundLocalError.
                raise RuntimeError("Test split at {!r} produced no batches".format(path))
            return batch

    return Processing()