|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +import torch.nn.functional as F |
| 4 | +from torch.autograd import Variable |
| 5 | + |
| 6 | + |
| 7 | +class IrisNet(nn.Module): |
| 8 | + def __init__(self): |
| 9 | + super(IrisNet, self).__init__() |
| 10 | + self.fc1 = nn.Linear(4, 100) |
| 11 | + self.fc2 = nn.Linear(100, 100) |
| 12 | + self.fc3 = nn.Linear(100, 3) |
| 13 | + self.softmax = nn.Softmax(dim=1) |
| 14 | + |
| 15 | + def forward(self, X): |
| 16 | + X = F.relu(self.fc1(X)) |
| 17 | + X = self.fc2(X) |
| 18 | + X = self.fc3(X) |
| 19 | + X = self.softmax(X) |
| 20 | + return X |
| 21 | + |
| 22 | + |
| 23 | +if __name__ == "__main__": |
| 24 | + from sklearn.datasets import load_iris |
| 25 | + from sklearn.model_selection import train_test_split |
| 26 | + from sklearn.metrics import accuracy_score |
| 27 | + |
| 28 | + iris = load_iris() |
| 29 | + X, y = iris.data, iris.target |
| 30 | + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=42) |
| 31 | + |
| 32 | + train_X = Variable(torch.Tensor(X_train).float()) |
| 33 | + test_X = Variable(torch.Tensor(X_test).float()) |
| 34 | + train_y = Variable(torch.Tensor(y_train).long()) |
| 35 | + test_y = Variable(torch.Tensor(y_test).long()) |
| 36 | + |
| 37 | + model = IrisNet() |
| 38 | + |
| 39 | + criterion = nn.CrossEntropyLoss() |
| 40 | + |
| 41 | + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) |
| 42 | + |
| 43 | + for epoch in range(1000): |
| 44 | + optimizer.zero_grad() |
| 45 | + out = model(train_X) |
| 46 | + loss = criterion(out, train_y) |
| 47 | + loss.backward() |
| 48 | + optimizer.step() |
| 49 | + |
| 50 | + if epoch % 100 == 0: |
| 51 | + print("number of epoch {} loss {}".format(epoch, loss)) |
| 52 | + |
| 53 | + predict_out = model(test_X) |
| 54 | + _, predict_y = torch.max(predict_out, 1) |
| 55 | + |
| 56 | + print("prediction accuracy {}".format(accuracy_score(test_y.data, predict_y.data))) |
| 57 | + |
| 58 | + torch.save(model.state_dict(), "weights.pth") |
0 commit comments