
Commit 0bb0faf

Add pytorch model for cifar100 [WIP]
1 parent 714171f commit 0bb0faf

File tree

1 file changed: +236 -1 lines changed


mplc/dataset.py

Lines changed: 236 additions & 1 deletion
@@ -14,7 +14,10 @@
 import numpy as np
 import pandas as pd
 from joblib import dump, load
-from keras.datasets import cifar10, mnist, imdb
+from keras.datasets import cifar10, cifar100, mnist, imdb
+import torch  # needed by the PyTorch additions below
+import torchvision
+from torch.utils import data
 from keras.layers import Activation
 from keras.layers import Conv2D, GlobalAveragePooling2D, MaxPooling2D
 from keras.layers import Dense, Dropout
@@ -209,6 +209,241 @@ def train_val_split_local(x, y):
         return train_test_split(x, y, test_size=0.1, random_state=42)
 
 
+class Cifar100(Dataset):
+    def __init__(self):
+        self.input_shape = (32, 32, 3)
+        self.num_classes = 100
+        x_test, x_train, y_test, y_train = self.load_data()
+
+        super(Cifar100, self).__init__(dataset_name='cifar100',
+                                       num_classes=self.num_classes,
+                                       input_shape=self.input_shape,
+                                       x_train=x_train,
+                                       y_train=y_train,
+                                       x_test=x_test,
+                                       y_test=y_test)
+
+    def load_data(self):
+        attempts = 0
+        while True:
+            try:
+                (x_train, y_train), (x_test, y_test) = cifar100.load_data()
+                break
+            except (HTTPError, URLError) as e:
+                if hasattr(e, 'code'):
+                    temp = e.code
+                else:
+                    temp = e.errno
+                logger.debug(
+                    f'URL fetch failure on '
+                    f'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz : '
+                    f'{temp} -- {e.reason}')
+                if attempts < constants.NUMBER_OF_DOWNLOAD_ATTEMPTS:
+                    sleep(2)
+                    attempts += 1
+                else:
+                    raise
+
+        # Pre-process inputs
+        x_train = self.preprocess_dataset_inputs(x_train)
+        x_test = self.preprocess_dataset_inputs(x_test)
+        y_train = self.preprocess_dataset_labels(y_train)
+        y_test = self.preprocess_dataset_labels(y_test)
+        return x_test, x_train, y_test, y_train
+
+    # Data samples pre-processing method for inputs
+    @staticmethod
+    def preprocess_dataset_inputs(x):
+        x = x.astype("float32")
+        x /= 255
+
+        return x
+
+    # Data samples pre-processing method for labels
+    def preprocess_dataset_labels(self, y):
+        y = to_categorical(y, self.num_classes)
+
+        return y
+
+    # Model structure and generation
+    def generate_new_model(self):
+        """Return a CNN model built from scratch"""
+
+        # Placeholder backbone until the dedicated CNN below is restored
+        model = torchvision.models.vgg16()
+
+        # TODO: Add new model
+        # model = Sequential()
+        # model.add(Conv2D(32, (3, 3), padding='same', input_shape=self.input_shape))
+        # model.add(Activation('relu'))
+        # model.add(Conv2D(32, (3, 3)))
+        # model.add(Activation('relu'))
+        # model.add(MaxPooling2D(pool_size=(2, 2)))
+        # model.add(Dropout(0.25))
+
+        # model.add(Conv2D(64, (3, 3), padding='same'))
+        # model.add(Activation('relu'))
+        # model.add(Conv2D(64, (3, 3)))
+        # model.add(Activation('relu'))
+        # model.add(MaxPooling2D(pool_size=(2, 2)))
+        # model.add(Dropout(0.25))
+
+        # model.add(Flatten())
+        # model.add(Dense(512))
+        # model.add(Activation('relu'))
+        # model.add(Dropout(0.5))
+        # model.add(Dense(self.num_classes))
+        # model.add(Activation('softmax'))
+
+        # # initiate RMSprop optimizer
+        # opt = RMSprop(learning_rate=0.0001, decay=1e-6)
+
+        # # Let's train the model using RMSprop
+        # model.compile(loss='categorical_crossentropy',
+        #               optimizer=opt,
+        #               metrics=['accuracy'])
+
+        return model
+
+    # train, test, val splits
+    @staticmethod
+    def train_test_split_local(x, y):
+        return train_test_split(x, y, test_size=0.1, random_state=42)
+
+    @staticmethod
+    def train_val_split_local(x, y):
+        return train_test_split(x, y, test_size=0.1, random_state=42)
+
+
+class cifar100_dataset(torch.utils.data.Dataset):
+
+    def __init__(self, x, y, transform=None):
+        # transform must be callable (e.g. a torchvision transform), not a list
+        self.x = x
+        self.y = y
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.x)
+
+    def __getitem__(self, index):
+
+        x = self.x[index]
+        # assumes integer class labels, not the one-hot vectors
+        # produced by preprocess_dataset_labels above
+        y = torch.tensor(int(self.y[index]))
+
+        if self.transform:
+            x = self.transform(x)
+
+        return x, y
+
+
+class ModelPytorch(torch.nn.Module):
+    # torchvision.models.vgg16 is a factory function, not a class, so the
+    # backbone is held as a submodule rather than subclassed
+    def __init__(self, optimizer, criterion):
+        super(ModelPytorch, self).__init__()
+        self.vgg16 = torchvision.models.vgg16(num_classes=100)
+        self.optimizer = optimizer
+        self.criterion = criterion
+
+    def forward(self, x):
+        return self.vgg16(x)
+
+    def fit(self, x_train, y_train, batch_size, validation_data, epochs=1, verbose=False):
+        # NB: x_train holds HWC numpy images; a transform producing CHW
+        # tensors is still needed for the VGG16 backbone
+        train_data = cifar100_dataset(x_train, y_train)
+        train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
+
+        history = self.train()  # nn.Module.train() returns self
+
+        for batch_idx, (images, labels) in enumerate(train_loader):
+
+            outputs = self(images)
+            loss = self.criterion(outputs, labels)
+
+            self.optimizer.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+
+        loss, acc = self.evaluate(x_train, y_train, batch_size=batch_size)
+        val_loss, val_acc = self.evaluate(*validation_data, batch_size=batch_size)
+        # Mimic Keras' history
+        history.history = {
+            'loss': [loss],
+            'accuracy': [acc],
+            'val_loss': [val_loss],
+            'val_accuracy': [val_acc]
+        }
+
+        return history
+
+    def evaluate(self, x_eval, y_eval, **kwargs):
+        batch_size = kwargs.get('batch_size', 32)
+        test_data = cifar100_dataset(x_eval, y_eval)
+        test_loader = data.DataLoader(test_data, batch_size=batch_size, shuffle=True)
+
+        self.eval()
+
+        with torch.no_grad():
+
+            val_loss = 0.0
+            val_acc = 0.0
+            count = 0
+            for i, (images, labels) in enumerate(test_loader):
+                count += 1
+                N = images.size(0)
+
+                outputs = self(images)
+
+                predictions = outputs.max(1, keepdim=True)[1]
+
+                val_loss += self.criterion(outputs, labels).item()
+                val_acc += (predictions.eq(labels.view_as(predictions)).sum().item() / N)
+
+        model_evaluation = [val_loss / count, val_acc / count]
+
+        return model_evaluation
+
+    # TODO
+    # def save_weights(self, path):
+    #     if self.coef_ is None:
+    #         raise ValueError(
+    #             'Coef and intercept are set to None, it seems the model has not been fit properly.')
+    #     if '.h5' in path:
+    #         logger.debug('Automatically switch file format from .h5 to .npy')
+    #         path.replace('.h5', '.npy')
+    #     np.save(path, self.get_weights())
+
+    # def load_weights(self, path):
+    #     if '.h5' in path:
+    #         logger.debug('Automatically switch file format from .h5 to .npy')
+    #         path.replace('.h5', '.npy')
+    #     weights = load(path)
+    #     self.set_weights(weights)
+
+    # def get_weights(self):
+    #     if self.coef_ is None:
+    #         return None
+    #     else:
+    #         return np.concatenate((self.coef_, self.intercept_.reshape(1, 1)), axis=1)
+
+    # def set_weights(self, weights):
+    #     if weights is None:
+    #         self.coef_ = None
+    #         self.intercept_ = None
+    #     else:
+    #         self.coef_ = np.array(weights[0][:-1]).reshape(1, -1)
+    #         self.intercept_ = np.array(weights[0][-1]).reshape(1)
+
+    # def save_model(self, path):
+    #     if '.h5' in path:
+    #         logger.debug('Automatically switch file format from .h5 to .joblib')
+    #         path.replace('.h5', '.joblib')
+    #     dump(self, path)
+
+    # @staticmethod
+    # def load_model(path):
+    #     if '.h5' in path:
+    #         logger.debug('Automatically switch file format from .h5 to .joblib')
+    #         path.replace('.h5', '.joblib')
+    #     return load(path)
+
+
+
 class Titanic(Dataset):
     def __init__(self):
         self.num_classes = 2
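
For orientation, here is a minimal sketch (not part of this commit) of how the new pieces are presumably meant to fit together once the [WIP] gaps are closed. It assumes the Dataset base class exposes x_train/y_train/x_test/y_test, as the constructor arguments suggest; the SGD optimizer, CrossEntropyLoss, and batch size are illustrative choices. Note that as committed, Cifar100 one-hot encodes labels via to_categorical while cifar100_dataset casts them to int, so the label pipeline still needs reconciling.

    # Illustrative usage sketch -- not part of this commit
    import torch
    from mplc.dataset import Cifar100, ModelPytorch

    dataset = Cifar100()  # downloads and pre-processes CIFAR-100

    # The optimizer needs the model's parameters, so attach it after construction
    model = ModelPytorch(optimizer=None, criterion=torch.nn.CrossEntropyLoss())
    model.optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    history = model.fit(dataset.x_train, dataset.y_train,
                        batch_size=32,
                        validation_data=(dataset.x_test, dataset.y_test))
    print(history.history['val_accuracy'])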
