Skip to content

Commit d5a4e7c

Browse files
Good practice for next(iter(data))
1 parent 3a95d11 commit d5a4e7c

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

intro-to-pytorch/Part 3 - Training Neural Networks (Exercises).ipynb

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,10 @@
118118
"criterion = nn.CrossEntropyLoss()\n",
119119
"\n",
120120
"# Get our data\n",
121-
"images, labels = next(iter(trainloader))\n",
121+
"dataiter = iter(trainloader)\n",
122+
"\n",
123+
"images, labels = next(dataiter)\n",
124+
"\n",
122125
"# Flatten images\n",
123126
"images = images.view(images.shape[0], -1)\n",
124127
"\n",
@@ -153,7 +156,10 @@
153156
"\n",
154157
"### Run this to check your work\n",
155158
"# Get our data\n",
156-
"images, labels = next(iter(trainloader))\n",
159+
"dataiter = iter(trainloader)\n",
160+
"\n",
161+
"images, labels = next(dataiter)\n",
162+
"\n",
157163
"# Flatten images\n",
158164
"images = images.view(images.shape[0], -1)\n",
159165
"\n",
@@ -310,7 +316,8 @@
310316
" nn.LogSoftmax(dim=1))\n",
311317
"\n",
312318
"criterion = nn.NLLLoss()\n",
313-
"images, labels = next(iter(trainloader))\n",
319+
"dataiter = iter(trainloader)\n",
320+
"images, labels = next(dataiter)\n",
314321
"images = images.view(images.shape[0], -1)\n",
315322
"\n",
316323
"logits = model(images)\n",
@@ -373,7 +380,8 @@
373380
"source": [
374381
"print('Initial weights - ', model[0].weight)\n",
375382
"\n",
376-
"images, labels = next(iter(trainloader))\n",
383+
"dataiter = iter(trainloader)\n",
384+
"images, labels = next(dataiter)\n",
377385
"images.resize_(64, 784)\n",
378386
"\n",
379387
"# Clear the gradients, do this because gradients are accumulated\n",
@@ -458,7 +466,8 @@
458466
"%matplotlib inline\n",
459467
"import helper\n",
460468
"\n",
461-
"images, labels = next(iter(trainloader))\n",
469+
"dataiter = iter(trainloader)\n",
470+
"images, labels = next(dataiter)\n",
462471
"\n",
463472
"img = images[0].view(1, 784)\n",
464473
"# Turn off gradients to speed up this part\n",
@@ -494,7 +503,7 @@
494503
"name": "python",
495504
"nbconvert_exporter": "python",
496505
"pygments_lexer": "ipython3",
497-
"version": "3.6.7"
506+
"version": "3.8.5"
498507
}
499508
},
500509
"nbformat": 4,

0 commit comments

Comments
 (0)