Skip to content

Commit 5c78370

Browse files
Add CIFAR_10 dataset loading and available for benchmarking (#121)
* Add CIFAR_10 dataset loading and available for benchmarking * Remove line according to PEP8
1 parent 3083ef8 commit 5c78370

File tree

3 files changed

+48
-2
lines changed

3 files changed

+48
-2
lines changed

configs/sklearn/performance/tsne.json

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,16 @@
2323
"x": "data/mnist_x_test.npy",
2424
"y": "data/mnist_y_test.npy"
2525
}
26-
}
26+
},
27+
{
28+
"source": "npy",
29+
"name": "cifar_10",
30+
"training":
31+
{
32+
"x": "data/cifar_10_x_train.npy",
33+
"y": "data/cifar_10_y_train.npy"
34+
}
35+
}
2736
],
2837
"workload-size": "medium"
2938
}

datasets/load_datasets.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
fraud, gisette, hepmass_150K,
2828
higgs, higgs_one_m, higgs_150K, ijcnn, klaverjas,
2929
santander, skin_segmentation, susy)
30-
from .loader_multiclass import (connect, covertype, covtype, letters, mlsr,
30+
from .loader_multiclass import (cifar_10, connect, covertype, covtype, letters, mlsr,
3131
mnist, msrank, plasticc, sensit)
3232
from .loader_regression import (abalone, california_housing, fried, higgs_10500K,
3333
medical_charges_nominal, mortgage_first_q,
@@ -47,6 +47,7 @@
4747
"census": census,
4848
"cifar_binary": cifar_binary,
4949
"cifar_cluster": cifar_cluster,
50+
"cifar_10": cifar_10,
5051
"codrnanorm": codrnanorm,
5152
"connect": connect,
5253
"covertype": covertype,

datasets/loader_multiclass.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,42 @@
2828
from .loader_utils import count_lines, read_libsvm_msrank, retrieve
2929

3030

31+
def cifar_10(dataset_dir: Path) -> bool:
32+
"""
33+
Source:
34+
University of Toronto
35+
Collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton
36+
https://www.cs.toronto.edu/~kriz/cifar.html
37+
38+
Classification task. n_classes = 10
39+
cifar_10 x train dataset (54000, 3072)
40+
cifar_10 y train dataset (54000, 1)
41+
cifar_10 x test dataset (6000, 3072)
42+
cifar_10 y test dataset (6000, 1)
43+
44+
"""
45+
dataset_name = 'cifar_10'
46+
os.makedirs(dataset_dir, exist_ok=True)
47+
48+
X, y = fetch_openml(data_id=40927, return_X_y=True,
49+
as_frame=False, data_home=dataset_dir)
50+
51+
X = pd.DataFrame(X)
52+
y = pd.DataFrame(y)
53+
y = y.astype(int)
54+
55+
logging.info(f'{dataset_name} is loaded, started parsing...')
56+
57+
x_train, x_test, y_train, y_test = train_test_split(
58+
X, y, test_size=0.1, random_state=42)
59+
for data, name in zip((x_train, x_test, y_train, y_test),
60+
('x_train', 'x_test', 'y_train', 'y_test')):
61+
filename = f'{dataset_name}_{name}.npy'
62+
np.save(os.path.join(dataset_dir, filename), data)
63+
logging.info(f'dataset {dataset_name} is ready.')
64+
return True
65+
66+
3167
def connect(dataset_dir: Path) -> bool:
3268
"""
3369
Source:

0 commit comments

Comments
 (0)