4141from aeon .utils .conversion import convert_collection
4242
4343DIRNAME = "data"
44- MODULE = Path ( aeon .__file__ ). parent / "datasets"
44+ MODULE = os . path . join ( os . path . dirname ( aeon .__file__ ), "datasets" )
4545
4646CONNECTION_ERRORS = (
4747 HTTPError ,
@@ -466,7 +466,9 @@ def _download_and_extract(url, extract_path=None):
466466 with open (zip_file_name , "wb" ) as out_file :
467467 out_file .write (response .read ())
468468 if extract_path is None :
469- extract_path = os .path .join (MODULE , "local_data/%s/" % file_name .split ("." )[0 ])
469+ extract_path = os .path .join (
470+ str (Path .home () / ".aeon" ), "local_data/%s/" % file_name .split ("." )[0 ]
471+ )
470472 else :
471473 extract_path = os .path .join (extract_path , "%s/" % file_name .split ("." )[0 ])
472474
@@ -525,8 +527,14 @@ def _load_tsc_dataset(
525527 local_module = extract_path
526528 local_dirname = ""
527529 else :
528- local_module = MODULE
529- local_dirname = "data"
530+ bundled_path = os .path .join (MODULE , "data" , name )
531+ if os .path .exists (bundled_path ):
532+ local_module = MODULE
533+ local_dirname = "data"
534+ else :
535+ aeon_home = Path .home () / ".aeon"
536+ local_module = str (aeon_home )
537+ local_dirname = "data"
530538
531539 if not os .path .exists (os .path .join (local_module , local_dirname )):
532540 os .makedirs (os .path .join (local_module , local_dirname ))
@@ -546,7 +554,11 @@ def _load_tsc_dataset(
546554 try :
547555 _download_and_extract (
548556 url ,
549- extract_path = extract_path ,
557+ extract_path = (
558+ extract_path
559+ if extract_path is not None
560+ else os .path .join (local_module , local_dirname )
561+ ),
550562 )
551563 except zipfile .BadZipFile as e :
552564 raise ValueError (
@@ -988,8 +1000,13 @@ def load_forecasting(name, extract_path=None, return_metadata=False):
9881000 local_module = extract_path
9891001 local_dirname = ""
9901002 else :
991- local_module = MODULE
992- local_dirname = "data"
1003+ bundled_path = os .path .join (MODULE , "data" , name )
1004+ if os .path .exists (bundled_path ):
1005+ local_module = MODULE
1006+ local_dirname = "data"
1007+ else :
1008+ local_module = str (Path .home () / ".aeon" )
1009+ local_dirname = "data"
9931010
9941011 if not os .path .exists (os .path .join (local_module , local_dirname )):
9951012 os .makedirs (os .path .join (local_module , local_dirname ))
@@ -1029,7 +1046,11 @@ def load_forecasting(name, extract_path=None, return_metadata=False):
10291046 try :
10301047 _download_and_extract (
10311048 url ,
1032- extract_path = extract_path ,
1049+ extract_path = (
1050+ extract_path
1051+ if extract_path is not None
1052+ else os .path .join (local_module , local_dirname )
1053+ ),
10331054 )
10341055 except zipfile .BadZipFile :
10351056 raise ValueError (
@@ -1142,8 +1163,13 @@ def load_regression(
11421163 local_module = extract_path
11431164 local_dirname = ""
11441165 else :
1145- local_module = MODULE
1146- local_dirname = "data"
1166+ bundled_path = os .path .join (MODULE , "data" , name )
1167+ if os .path .exists (bundled_path ):
1168+ local_module = MODULE
1169+ local_dirname = "data"
1170+ else :
1171+ local_module = str (Path .home () / ".aeon" )
1172+ local_dirname = "data"
11471173 error_str = (
11481174 f"File name { name } is not in the list of valid files to download,"
11491175 f"see aeon.datasets.tser_datasetss.tser_soton for the list. "
@@ -1183,7 +1209,11 @@ def load_regression(
11831209 try :
11841210 _download_and_extract (
11851211 url ,
1186- extract_path = extract_path ,
1212+ extract_path = (
1213+ extract_path
1214+ if extract_path is not None
1215+ else os .path .join (local_module , local_dirname )
1216+ ),
11871217 )
11881218 except zipfile .BadZipFile :
11891219 try_monash = True
@@ -1323,8 +1353,13 @@ def load_classification(
13231353 local_module = extract_path
13241354 local_dirname = None
13251355 else :
1326- local_module = MODULE
1327- local_dirname = "data"
1356+ bundled_path = os .path .join (MODULE , "data" , name )
1357+ if os .path .exists (bundled_path ):
1358+ local_module = MODULE
1359+ local_dirname = "data"
1360+ else :
1361+ local_module = str (Path .home () / ".aeon" )
1362+ local_dirname = "data"
13281363 if local_dirname is None :
13291364 path = local_module
13301365 else :
@@ -1363,7 +1398,11 @@ def load_classification(
13631398 try :
13641399 _download_and_extract (
13651400 url ,
1366- extract_path = extract_path ,
1401+ extract_path = (
1402+ extract_path
1403+ if extract_path is not None
1404+ else os .path .join (local_module , local_dirname )
1405+ ),
13671406 )
13681407 except zipfile .BadZipFile :
13691408 try_zenodo = True
@@ -1444,7 +1483,7 @@ def download_all_regression(extract_path=None):
14441483 local_module = extract_path
14451484 local_dirname = ""
14461485 else :
1447- local_module = MODULE
1486+ local_module = str ( Path . home () / ".aeon" )
14481487 local_dirname = "data"
14491488
14501489 if not os .path .exists (os .path .join (local_module , local_dirname )):
0 commit comments