Skip to content

Commit 4be5b43

Browse files
authored
Fix check_install test script
1 parent 7a604cd commit 4be5b43

File tree

1 file changed

+9
-21
lines changed

1 file changed

+9
-21
lines changed

dlclive/check_install/check_install.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,17 @@
66
"""
77

88

9-
import os
109
import sys
1110
import shutil
1211
import warnings
13-
from dlclibrary.dlcmodelzoo.modelzoo_download import (
14-
download_huggingface_model,
15-
MODELOPTIONS,
16-
)
1712

1813
from dlclive import benchmark_videos
1914
import urllib.request
2015
import argparse
2116
from pathlib import Path
22-
import tarfile
17+
from dlclibrary.dlcmodelzoo.modelzoo_download import (
18+
download_huggingface_model,
19+
)
2320

2421

2522
def urllib_pbar(count, blockSize, totalSize):
@@ -29,6 +26,7 @@ def urllib_pbar(count, blockSize, totalSize):
2926
sys.stdout.write("\b"*len(outstr))
3027
sys.stdout.flush()
3128

29+
3230
def main(display:bool=None):
3331
parser = argparse.ArgumentParser(
3432
description="Test DLC-Live installation by downloading and evaluating a demo DLC project!")
@@ -47,33 +45,23 @@ def main(display:bool=None):
4745
tmp_dir.mkdir(mode=0o775,exist_ok=True)
4846

4947
video_file = str(tmp_dir / 'dog_clip.avi')
50-
model_tarball = tmp_dir / 'DLC_Dog_resnet_50_iteration-0_shuffle-0.tar.gz'
51-
model_dir = model_tarball.with_suffix('').with_suffix('') # remove two suffixes (tar.gz)
52-
48+
model_dir = tmp_dir / 'DLC_Dog_resnet_50_iteration-0_shuffle-0'
5349

5450
# download dog test video from github:
5551
print(f"Downloading Video to {video_file}")
5652
url_link = "https://github.com/DeepLabCut/DeepLabCut-live/blob/master/check_install/dog_clip.avi?raw=True"
5753
urllib.request.urlretrieve(url_link, video_file, reporthook=urllib_pbar)
5854

5955
# download exported dog model from DeepLabCut Model Zoo
60-
if Path(model_tarball).exists():
61-
print('Tarball already downloaded, using cached version')
56+
if Path(model_dir / 'snapshot-75000.pb').exists():
57+
print('Model already downloaded, using cached version')
6258
else:
6359
print("Downloading full_dog model from the DeepLabCut Model Zoo...")
64-
model = 'superanimal_quadruped'
65-
download_huggingface_model(model,model_dir)
66-
#model_url = "http://deeplabcut.rowland.harvard.edu/models/DLC_Dog_resnet_50_iteration-0_shuffle-0.tar.gz"
67-
#urllib.request.urlretrieve(model_url, str(model_tarball), reporthook=urllib_pbar)
68-
69-
#print('Untarring compressed model')
70-
#model_file = tarfile.open(str(model_tarball))
71-
# model_file.extractall(str(model_dir.parent))
72-
#model_file.close()
60+
download_huggingface_model("full_dog", model_dir)
7361

7462
# assert these things exist so we can give informative error messages
7563
assert Path(video_file).exists()
76-
assert Path(model_dir).exists() and Path(model_dir).is_dir()
64+
assert Path(model_dir / 'snapshot-75000.pb').exists()
7765

7866
# run benchmark videos
7967
print("\n Running inference...\n")

0 commit comments

Comments
 (0)