diff --git a/mplc/constants.py b/mplc/constants.py
index 7f9171b8..db571aee 100644
--- a/mplc/constants.py
+++ b/mplc/constants.py
@@ -49,8 +49,9 @@ TITANIC = "titanic"
 ESC50 = "esc50"
 IMDB = 'imdb'
+CIFAR100 = "cifar100"
 
 # Supported datasets
-SUPPORTED_DATASETS_NAMES = [MNIST, CIFAR10, TITANIC, ESC50, IMDB]
+SUPPORTED_DATASETS_NAMES = [MNIST, CIFAR10, TITANIC, ESC50, IMDB, CIFAR100]
 
 # Number of attempts allowed before raising an error while trying to download dataset
 NUMBER_OF_DOWNLOAD_ATTEMPTS = 3
diff --git a/mplc/dataset.py b/mplc/dataset.py
index 4f4286bd..3e492da3 100644
--- a/mplc/dataset.py
+++ b/mplc/dataset.py
@@ -17,7 +17,7 @@ from librosa.feature import mfcc
 from loguru import logger
 from sklearn.model_selection import train_test_split
-from tensorflow.keras.datasets import cifar10, mnist, imdb
+from tensorflow.keras.datasets import cifar10, cifar100, mnist, imdb
 from tensorflow.keras.layers import Activation
 from tensorflow.keras.layers import Conv2D, GlobalAveragePooling2D, MaxPooling2D
 from tensorflow.keras.layers import Dense, Dropout
@@ -29,7 +29,8 @@ from tensorflow.keras.utils import to_categorical
 
 from . import constants
-from .models import LogisticRegression
+from .models import LogisticRegression, ModelPytorch
+from torchvision import models
 
 
 class Dataset(ABC):
@@ -194,6 +195,77 @@ def generate_new_model(self):
         return model
 
 
+class Cifar100(Dataset):
+    def __init__(self):
+        self.input_shape = (3, 32, 32)
+        self.num_classes = 100
+        x_test, x_train, y_test, y_train = self.load_data()
+
+        super(Cifar100, self).__init__(dataset_name='cifar100',
+                                       num_classes=self.num_classes,
+                                       input_shape=self.input_shape,
+                                       x_train=x_train,
+                                       y_train=y_train,
+                                       x_test=x_test,
+                                       y_test=y_test)
+
+    def load_data(self):
+        attempts = 0
+        while True:
+            try:
+                (x_train, y_train), (x_test, y_test) = cifar100.load_data()
+                break
+            except (HTTPError, URLError) as e:
+                if hasattr(e, 'code'):
+                    temp = e.code
+                else:
+                    temp = e.errno
+                logger.debug(
+                    f'URL fetch failure on '
+                    f'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz : '
+                    f'{temp} -- {e.reason}')
+                if attempts < constants.NUMBER_OF_DOWNLOAD_ATTEMPTS:
+                    sleep(2)
+                    attempts += 1
+                else:
+                    raise
+
+        # Pre-process inputs
+        x_train = self.preprocess_dataset_inputs(x_train)
+        x_test = self.preprocess_dataset_inputs(x_test)
+        # y_train = self.preprocess_dataset_labels(y_train)
+        # y_test = self.preprocess_dataset_labels(y_test)
+        return x_test, x_train, y_test, y_train
+
+    # Data samples pre-processing method for inputs
+    @staticmethod
+    def preprocess_dataset_inputs(x):
+        x = x.astype("float32")
+        x /= 255
+
+        return x
+
+    # Data samples pre-processing method for labels
+    def preprocess_dataset_labels(self, y):
+        y = to_categorical(y, self.num_classes)
+
+        return y
+
+    # Model structure and generation
+    def generate_new_model(self):
+        model = ModelPytorch()
+        return model
+
+    # train, test, val splits
+    @staticmethod
+    def train_test_split_local(x, y):
+        return train_test_split(x, y, test_size=0.1, random_state=42)
+
+    @staticmethod
+    def train_val_split_local(x, y):
+        return train_test_split(x, y, test_size=0.1, random_state=42)
+
+
 class Titanic(Dataset):
 
     def __init__(self, proportion=1, val_proportion=0.1):
diff --git a/mplc/models.py b/mplc/models.py
index 6ed4ca10..45da8a0e 100644
--- a/mplc/models.py
+++ b/mplc/models.py
@@ -1,10 +1,16 @@
 import numpy as np
+import collections
 from joblib import dump, load
 from loguru import logger
 from sklearn.linear_model import LogisticRegression as skLR
 from sklearn.metrics import log_loss
 from tensorflow.keras.backend import dot
 from tensorflow.keras.layers import Dense
+import torch, torchvision
+import torch.nn as nn
+import torch.optim as optim
+import torch.utils.data as data
+import torchvision.transforms as transforms
 
 
 class LogisticRegression(skLR):
@@ -88,6 +94,157 @@ def load_model(path):
             path.replace('.h5', '.joblib')
         return load(path)
+class cifar100_dataset(torch.utils.data.Dataset):
+
+    def __init__(self, x, y, transform=None):
+        self.x = x
+        self.y = y
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.x)
+
+    def __getitem__(self, index):
+
+        x = self.x[index]
+        # Labels are kept as integer class indices, as expected by the cross-entropy loss
+        y = torch.tensor(int(self.y[index][0]))
+
+        if self.transform:
+            x = self.transform(x)
+
+        return x, y
+
+class ModelPytorch(nn.Module):
+    def __init__(self):
+        super(ModelPytorch, self).__init__()
+        model = torchvision.models.vgg16()
+        self.features = nn.Sequential(model.features)
+        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(7, 7))
+        self.classifier = nn.Sequential(
+            nn.Linear(25088, 4096),
+            nn.ReLU(inplace=True),
+            nn.Dropout(p=0.5, inplace=False),
+            nn.Linear(4096, 4096),
+            nn.ReLU(inplace=True),
+            nn.Dropout(p=0.5, inplace=False),
+            nn.Linear(4096, 100)  # CIFAR-100 has 100 classes
+        )
+        # Optimize this module's own parameters so the new classifier head is trained too
+        self.optimizer = optim.Adam(self.parameters(), lr=1e-3)
+
+
+    def forward(self, x):
+        x = self.features(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        return self.classifier(x)
+
+
+    def fit(self, x_train, y_train, batch_size, validation_data, epochs=1, verbose=False, callbacks=None):
+        criterion = nn.CrossEntropyLoss()
+        transform = transforms.Compose([transforms.ToTensor()])
+
+        train_data = cifar100_dataset(x_train, y_train, transform)
+        train_loader = data.DataLoader(train_data, batch_size=int(batch_size), shuffle=True)
+
+        # train() switches to training mode and returns self; a Keras-like history is attached below
+        history = super(ModelPytorch, self).train()
+
+        for batch_idx, (image, label) in enumerate(train_loader):
+            images, labels = torch.autograd.Variable(image), torch.autograd.Variable(label)
+
+            outputs = self.forward(images)
+            loss = criterion(outputs, labels)
+
+            self.optimizer.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+
+        [loss, acc] = self.evaluate(x_train, y_train)
+        [val_loss, val_acc] = self.evaluate(*validation_data)
+        # Mimic Keras' history
+        history.history = {
+            'loss': [loss],
+            'accuracy': [acc],
+            'val_loss': [val_loss],
+            'val_accuracy': [val_acc]
+        }
+
+        return history
+
+    def evaluate(self, x_eval, y_eval, **kwargs):
+        criterion = nn.CrossEntropyLoss()
+        transform = transforms.Compose([transforms.ToTensor()])
+
+        test_data = cifar100_dataset(x_eval, y_eval, transform)
+        test_loader = data.DataLoader(test_data, shuffle=True)
+
+        self.eval()
+
+        with torch.no_grad():
+
+            val_loss = 0
+            val_acc = 0
+            count = 0
+            for i, (images, labels) in enumerate(test_loader):
+                count += 1
+                N = images.size(0)
+
+                images = torch.autograd.Variable(images)
+                labels = torch.autograd.Variable(labels)
+
+                outputs = self(images)
+                predictions = outputs.max(1, keepdim=True)[1]
+
+                val_loss += criterion(outputs, labels).item()
+                val_acc += (predictions.eq(labels.view_as(predictions)).sum().item() / N)
+
+            model_evaluation = [val_loss / count, val_acc / count]
+
+        return model_evaluation
+
+
+    def save_weights(self, path):
+        if '.h5' in path:
+            logger.debug('Automatically switch file format from .h5 to .pth')
+            path = path.replace('.h5', '.pth')
+        torch.save(self.state_dict(), path)
+
+
+    def load_weights(self, path):
+        if '.h5' in path:
+            logger.debug('Automatically switch file format from .h5 to .pth')
+            path = path.replace('.h5', '.pth')
+        weights = torch.load(path)
+        self.load_state_dict(weights)
+
+    def get_weights(self):
+        state_dict = self.state_dict()
+        weights = []
+        for layer in state_dict.keys():
+            weights.append(state_dict[layer].numpy())
+        return weights
+
+
+    def set_weights(self, weights):
+        # Rebuild a state dict from the list of numpy arrays and load it into the module
+        new_state_dict = collections.OrderedDict()
+        for i, layer in enumerate(self.state_dict().keys()):
+            new_state_dict[layer] = torch.Tensor(weights[i])
+        self.load_state_dict(new_state_dict)
+
+
+    def save_model(self, path):
+        if '.h5' in path:
+            logger.debug('Automatically switch file format from .h5 to .pth')
+            path = path.replace('.h5', '.pth')
+        torch.save(self, path)
+
+
+    @staticmethod
+    def load_model(path):
+        if '.h5' in path:
+            logger.debug('Automatically switch file format from .h5 to .pth')
+            path = path.replace('.h5', '.pth')
+        model = torch.load(path)
+        return model.eval()
+
 
 
 class NoiseAdaptationChannel(Dense):
     """
diff --git a/mplc/scenario.py b/mplc/scenario.py
index a77d3e1f..485a5414 100644
--- a/mplc/scenario.py
+++ b/mplc/scenario.py
@@ -138,6 +138,8 @@ def __init__(
             self.dataset = dataset_module.Esc50()
         elif dataset_name == constants.IMDB:
             self.dataset = dataset_module.Imdb()
+        elif dataset_name == constants.CIFAR100:
+            self.dataset = dataset_module.Cifar100()
         else:
             raise Exception(
                 f"Dataset named '{dataset_name}' is not supported (yet). You can construct your own "
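
For reference, a minimal usage sketch (not part of the patch) of how the new CIFAR-100 pipeline could be smoke-tested end-to-end. It assumes the Dataset base class exposes the x_train/y_train/x_test/y_test arrays passed to its constructor, as the other mplc datasets do; the batch size and slice sizes below are arbitrary.

from mplc import dataset as dataset_module

# Build the new dataset (downloads CIFAR-100 through Keras on first use)
# and the VGG16-based PyTorch model it generates.
cifar100_ds = dataset_module.Cifar100()
model = cifar100_ds.generate_new_model()

# One short training pass on a small slice, exercising the Keras-style
# fit()/evaluate() interface implemented by ModelPytorch.
history = model.fit(cifar100_ds.x_train[:256], cifar100_ds.y_train[:256],
                    batch_size=32,
                    validation_data=(cifar100_ds.x_test[:128], cifar100_ds.y_test[:128]))
print(history.history)

# Round-trip the weights through the numpy-based helpers used by the framework.
model.set_weights(model.get_weights())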