Skip to content

Commit f148d3e

Browse files
committed
Change default dataset download path to user home directory
1 parent 5116a3b commit f148d3e

File tree

4 files changed

+60
-23
lines changed

4 files changed

+60
-23
lines changed

aeon/datasets/_data_loaders.py

Lines changed: 54 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@
4141
from aeon.utils.conversion import convert_collection
4242

4343
DIRNAME = "data"
44-
MODULE = Path(aeon.__file__).parent / "datasets"
44+
MODULE = os.path.join(os.path.dirname(aeon.__file__), "datasets")
4545

4646
CONNECTION_ERRORS = (
4747
HTTPError,
@@ -466,7 +466,9 @@ def _download_and_extract(url, extract_path=None):
466466
with open(zip_file_name, "wb") as out_file:
467467
out_file.write(response.read())
468468
if extract_path is None:
469-
extract_path = os.path.join(MODULE, "local_data/%s/" % file_name.split(".")[0])
469+
extract_path = os.path.join(
470+
str(Path.home() / ".aeon"), "local_data/%s/" % file_name.split(".")[0]
471+
)
470472
else:
471473
extract_path = os.path.join(extract_path, "%s/" % file_name.split(".")[0])
472474

@@ -525,8 +527,14 @@ def _load_tsc_dataset(
525527
local_module = extract_path
526528
local_dirname = ""
527529
else:
528-
local_module = MODULE
529-
local_dirname = "data"
530+
bundled_path = os.path.join(MODULE, "data", name)
531+
if os.path.exists(bundled_path):
532+
local_module = MODULE
533+
local_dirname = "data"
534+
else:
535+
aeon_home = Path.home() / ".aeon"
536+
local_module = str(aeon_home)
537+
local_dirname = "data"
530538

531539
if not os.path.exists(os.path.join(local_module, local_dirname)):
532540
os.makedirs(os.path.join(local_module, local_dirname))
@@ -546,7 +554,11 @@ def _load_tsc_dataset(
546554
try:
547555
_download_and_extract(
548556
url,
549-
extract_path=extract_path,
557+
extract_path=(
558+
extract_path
559+
if extract_path is not None
560+
else os.path.join(local_module, local_dirname)
561+
),
550562
)
551563
except zipfile.BadZipFile as e:
552564
raise ValueError(
@@ -988,8 +1000,13 @@ def load_forecasting(name, extract_path=None, return_metadata=False):
9881000
local_module = extract_path
9891001
local_dirname = ""
9901002
else:
991-
local_module = MODULE
992-
local_dirname = "data"
1003+
bundled_path = os.path.join(MODULE, "data", name)
1004+
if os.path.exists(bundled_path):
1005+
local_module = MODULE
1006+
local_dirname = "data"
1007+
else:
1008+
local_module = str(Path.home() / ".aeon")
1009+
local_dirname = "data"
9931010

9941011
if not os.path.exists(os.path.join(local_module, local_dirname)):
9951012
os.makedirs(os.path.join(local_module, local_dirname))
@@ -1029,7 +1046,11 @@ def load_forecasting(name, extract_path=None, return_metadata=False):
10291046
try:
10301047
_download_and_extract(
10311048
url,
1032-
extract_path=extract_path,
1049+
extract_path=(
1050+
extract_path
1051+
if extract_path is not None
1052+
else os.path.join(local_module, local_dirname)
1053+
),
10331054
)
10341055
except zipfile.BadZipFile:
10351056
raise ValueError(
@@ -1142,8 +1163,13 @@ def load_regression(
11421163
local_module = extract_path
11431164
local_dirname = ""
11441165
else:
1145-
local_module = MODULE
1146-
local_dirname = "data"
1166+
bundled_path = os.path.join(MODULE, "data", name)
1167+
if os.path.exists(bundled_path):
1168+
local_module = MODULE
1169+
local_dirname = "data"
1170+
else:
1171+
local_module = str(Path.home() / ".aeon")
1172+
local_dirname = "data"
11471173
error_str = (
11481174
f"File name {name} is not in the list of valid files to download,"
11491175
f"see aeon.datasets.tser_datasetss.tser_soton for the list. "
@@ -1183,7 +1209,11 @@ def load_regression(
11831209
try:
11841210
_download_and_extract(
11851211
url,
1186-
extract_path=extract_path,
1212+
extract_path=(
1213+
extract_path
1214+
if extract_path is not None
1215+
else os.path.join(local_module, local_dirname)
1216+
),
11871217
)
11881218
except zipfile.BadZipFile:
11891219
try_monash = True
@@ -1323,8 +1353,13 @@ def load_classification(
13231353
local_module = extract_path
13241354
local_dirname = None
13251355
else:
1326-
local_module = MODULE
1327-
local_dirname = "data"
1356+
bundled_path = os.path.join(MODULE, "data", name)
1357+
if os.path.exists(bundled_path):
1358+
local_module = MODULE
1359+
local_dirname = "data"
1360+
else:
1361+
local_module = str(Path.home() / ".aeon")
1362+
local_dirname = "data"
13281363
if local_dirname is None:
13291364
path = local_module
13301365
else:
@@ -1363,7 +1398,11 @@ def load_classification(
13631398
try:
13641399
_download_and_extract(
13651400
url,
1366-
extract_path=extract_path,
1401+
extract_path=(
1402+
extract_path
1403+
if extract_path is not None
1404+
else os.path.join(local_module, local_dirname)
1405+
),
13671406
)
13681407
except zipfile.BadZipFile:
13691408
try_zenodo = True
@@ -1444,7 +1483,7 @@ def download_all_regression(extract_path=None):
14441483
local_module = extract_path
14451484
local_dirname = ""
14461485
else:
1447-
local_module = MODULE
1486+
local_module = str(Path.home() / ".aeon")
14481487
local_dirname = "data"
14491488

14501489
if not os.path.exists(os.path.join(local_module, local_dirname)):

aeon/datasets/_single_problem_loaders.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@
2424
]
2525

2626
import os
27-
from pathlib import Path
2827

2928
import numpy as np
3029
import pandas as pd
@@ -33,7 +32,7 @@
3332
from aeon.datasets._data_loaders import _load_saved_dataset, _load_tsc_dataset
3433

3534
DIRNAME = "data"
36-
MODULE = Path(__file__).parent
35+
MODULE = os.path.dirname(__file__)
3736

3837

3938
def load_gunpoint(split=None, return_type="numpy3d"):
@@ -990,4 +989,4 @@ def load_longley(return_array=True):
990989
data = data.astype(float)
991990
if return_array:
992991
return data.to_numpy().T
993-
return data.T
992+
return data.T

aeon/datasets/dataset_collections.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -34,14 +34,13 @@
3434
"get_available_tsf_datasets",
3535
]
3636
import os
37-
from pathlib import Path
3837

3938
import aeon
4039
from aeon.datasets.tsc_datasets import multivariate, univariate
4140
from aeon.datasets.tser_datasets import tser_monash, tser_soton
4241
from aeon.datasets.tsf_datasets import tsf_all
4342

44-
MODULE = Path(aeon.__file__).parent / "datasets"
43+
MODULE = os.path.join(os.path.dirname(aeon.__file__), "datasets")
4544

4645

4746
def get_available_tser_datasets(name="tser_soton", return_list=True):
@@ -160,4 +159,4 @@ def get_downloaded_tsf_datasets(extract_path=None):
160159
all_files = os.listdir(sub_dir)
161160
if name + ".tsf" in all_files:
162161
datasets.append(name)
163-
return datasets
162+
return datasets

aeon/datasets/tests/test_data_loaders.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ def test_load_forecasting_from_repo():
5757
assert not meta["contain_missing_values"]
5858
assert not meta["contain_equal_length"]
5959

60-
shutil.rmtree(os.path.dirname(__file__) + "/../local_data")
60+
shutil.rmtree(os.path.dirname(__file__) + "/../local_data", ignore_errors=True)
6161

6262

6363
@pytest.mark.skipif(
@@ -84,7 +84,7 @@ def test_load_classification_from_repo():
8484
assert meta["classlabel"]
8585
assert not meta["targetlabel"]
8686
assert meta["class_values"] == ["1", "2"]
87-
shutil.rmtree(os.path.dirname(__file__) + "/../local_data")
87+
shutil.rmtree(os.path.dirname(__file__) + "/../local_data", ignore_errors=True)
8888

8989

9090
@pytest.mark.skipif(

0 commit comments

Comments
 (0)