|
3 | 3 | { |
4 | 4 | "cell_type": "code", |
5 | 5 | "execution_count": null, |
6 | | - "id": "frequent-field", |
| 6 | + "id": "0", |
7 | 7 | "metadata": { |
8 | 8 | "tags": [] |
9 | 9 | }, |
|
18 | 18 | { |
19 | 19 | "cell_type": "code", |
20 | 20 | "execution_count": null, |
21 | | - "id": "opened-virgin", |
| 21 | + "id": "1", |
22 | 22 | "metadata": { |
23 | 23 | "tags": [] |
24 | 24 | }, |
|
31 | 31 | }, |
32 | 32 | { |
33 | 33 | "cell_type": "markdown", |
34 | | - "id": "lasting-express", |
| 34 | + "id": "2", |
35 | 35 | "metadata": {}, |
36 | 36 | "source": [ |
37 | 37 | "# Optimized Learning\n", |
|
41 | 41 | }, |
42 | 42 | { |
43 | 43 | "cell_type": "markdown", |
44 | | - "id": "forward-process", |
| 44 | + "id": "3", |
45 | 45 | "metadata": {}, |
46 | 46 | "source": [ |
47 | 47 | "## Autograd to JAX\n", |
|
52 | 52 | }, |
53 | 53 | { |
54 | 54 | "cell_type": "markdown", |
55 | | - "id": "correct-cyprus", |
| 55 | + "id": "4", |
56 | 56 | "metadata": {}, |
57 | 57 | "source": [ |
58 | 58 | "## Example: Transforming a function into its derivative\n", |
|
68 | 68 | { |
69 | 69 | "cell_type": "code", |
70 | 70 | "execution_count": null, |
71 | | - "id": "demanding-opportunity", |
| 71 | + "id": "5", |
72 | 72 | "metadata": { |
73 | 73 | "tags": [] |
74 | 74 | }, |
|
90 | 90 | }, |
91 | 91 | { |
92 | 92 | "cell_type": "markdown", |
93 | | - "id": "forty-lindsay", |
| 93 | + "id": "6", |
94 | 94 | "metadata": {}, |
95 | 95 | "source": [ |
96 | 96 | "Here's another example using a polynomial function:\n", |
|
105 | 105 | { |
106 | 106 | "cell_type": "code", |
107 | 107 | "execution_count": null, |
108 | | - "id": "neutral-neighbor", |
| 108 | + "id": "7", |
109 | 109 | "metadata": { |
110 | 110 | "tags": [] |
111 | 111 | }, |
|
128 | 128 | }, |
129 | 129 | { |
130 | 130 | "cell_type": "markdown", |
131 | | - "id": "steady-bikini", |
| 131 | + "id": "8", |
132 | 132 | "metadata": {}, |
133 | 133 | "source": [ |
134 | 134 | "## Using grad to solve minimization problems\n", |
|
147 | 147 | { |
148 | 148 | "cell_type": "code", |
149 | 149 | "execution_count": null, |
150 | | - "id": "opponent-modification", |
| 150 | + "id": "9", |
151 | 151 | "metadata": { |
152 | 152 | "tags": [] |
153 | 153 | }, |
|
163 | 163 | }, |
164 | 164 | { |
165 | 165 | "cell_type": "markdown", |
166 | | - "id": "beautiful-theory", |
| 166 | + "id": "10", |
167 | 167 | "metadata": {}, |
168 | 168 | "source": [ |
169 | 169 | "We know from calculus that the sign of the second derivative tells us whether we have a minimum or maximum at a point.\n", |
|
178 | 178 | { |
179 | 179 | "cell_type": "code", |
180 | 180 | "execution_count": null, |
181 | | - "id": "former-syracuse", |
| 181 | + "id": "11", |
182 | 182 | "metadata": {}, |
183 | 183 | "outputs": [], |
184 | 184 | "source": [ |
|
189 | 189 | }, |
190 | 190 | { |
191 | 191 | "cell_type": "markdown", |
192 | | - "id": "surrounded-plain", |
| 192 | + "id": "12", |
193 | 193 | "metadata": {}, |
194 | 194 | "source": [ |
195 | 195 | "Grad is composable an arbitrary number of times. You can keep calling grad as many times as you like." |
196 | 196 | ] |
197 | 197 | }, |
198 | 198 | { |
199 | 199 | "cell_type": "markdown", |
200 | | - "id": "brazilian-atlas", |
| 200 | + "id": "13", |
201 | 201 | "metadata": {}, |
202 | 202 | "source": [ |
203 | 203 | "## Maximum likelihood estimation\n", |
|
216 | 216 | { |
217 | 217 | "cell_type": "code", |
218 | 218 | "execution_count": null, |
219 | | - "id": "confidential-sympathy", |
| 219 | + "id": "14", |
220 | 220 | "metadata": { |
221 | 221 | "tags": [] |
222 | 222 | }, |
|
236 | 236 | }, |
237 | 237 | { |
238 | 238 | "cell_type": "markdown", |
239 | | - "id": "atlantic-excellence", |
| 239 | + "id": "15", |
240 | 240 | "metadata": {}, |
241 | 241 | "source": [ |
242 | 242 | "Our estimation task will necessitate calculating the total joint log likelihood of our data under a Gaussian model.\n", |
|
248 | 248 | { |
249 | 249 | "cell_type": "code", |
250 | 250 | "execution_count": null, |
251 | | - "id": "known-terrain", |
| 251 | + "id": "16", |
252 | 252 | "metadata": { |
253 | 253 | "tags": [] |
254 | 254 | }, |
|
263 | 263 | }, |
264 | 264 | { |
265 | 265 | "cell_type": "markdown", |
266 | | - "id": "terminal-census", |
| 266 | + "id": "17", |
267 | 267 | "metadata": {}, |
268 | 268 | "source": [ |
269 | 269 | "If you're wondering why we use `log_sigma` rather than `sigma`, it is a choice made for practical reasons.\n", |
|
280 | 280 | { |
281 | 281 | "cell_type": "code", |
282 | 282 | "execution_count": null, |
283 | | - "id": "dominant-delight", |
| 283 | + "id": "18", |
284 | 284 | "metadata": { |
285 | 285 | "tags": [] |
286 | 286 | }, |
|
293 | 293 | }, |
294 | 294 | { |
295 | 295 | "cell_type": "markdown", |
296 | | - "id": "equal-brazilian", |
| 296 | + "id": "19", |
297 | 297 | "metadata": {}, |
298 | 298 | "source": [ |
299 | 299 | "Now, we can create the gradient function of our negative log likelihood.\n", |
|
307 | 307 | { |
308 | 308 | "cell_type": "code", |
309 | 309 | "execution_count": null, |
310 | | - "id": "meaning-scanning", |
| 310 | + "id": "20", |
311 | 311 | "metadata": { |
312 | 312 | "tags": [] |
313 | 313 | }, |
|
322 | 322 | }, |
323 | 323 | { |
324 | 324 | "cell_type": "markdown", |
325 | | - "id": "hourly-miller", |
| 325 | + "id": "21", |
326 | 326 | "metadata": {}, |
327 | 327 | "source": [ |
328 | 328 | "Now, we can do the gradient descent step!" |
|
331 | 331 | { |
332 | 332 | "cell_type": "code", |
333 | 333 | "execution_count": null, |
334 | | - "id": "cosmetic-perception", |
| 334 | + "id": "22", |
335 | 335 | "metadata": { |
336 | 336 | "tags": [] |
337 | 337 | }, |
|
347 | 347 | }, |
348 | 348 | { |
349 | 349 | "cell_type": "markdown", |
350 | | - "id": "defensive-family", |
| 350 | + "id": "23", |
351 | 351 | "metadata": {}, |
352 | 352 | "source": [ |
353 | 353 | "And voila! We have gradient descended our way to the maximum likelihood parameters :)." |
354 | 354 | ] |
355 | 355 | }, |
356 | 356 | { |
357 | 357 | "cell_type": "markdown", |
358 | | - "id": "constant-account", |
| 358 | + "id": "24", |
359 | 359 | "metadata": {}, |
360 | 360 | "source": [ |
361 | 361 | "## Exercise: Where is the gold? It's at the minima!\n", |
|
368 | 368 | { |
369 | 369 | "cell_type": "code", |
370 | 370 | "execution_count": null, |
371 | | - "id": "focal-climate", |
| 371 | + "id": "25", |
372 | 372 | "metadata": { |
373 | 373 | "tags": [] |
374 | 374 | }, |
|
383 | 383 | }, |
384 | 384 | { |
385 | 385 | "cell_type": "markdown", |
386 | | - "id": "massive-corps", |
| 386 | + "id": "26", |
387 | 387 | "metadata": {}, |
388 | 388 | "source": [ |
389 | 389 | "It should be evident from here that there are two minima in the function.\n", |
|
398 | 398 | { |
399 | 399 | "cell_type": "code", |
400 | 400 | "execution_count": null, |
401 | | - "id": "opened-beads", |
| 401 | + "id": "27", |
402 | 402 | "metadata": { |
403 | 403 | "tags": [] |
404 | 404 | }, |
|
420 | 420 | }, |
421 | 421 | { |
422 | 422 | "cell_type": "markdown", |
423 | | - "id": "brown-violation", |
| 423 | + "id": "28", |
424 | 424 | "metadata": {}, |
425 | 425 | "source": [ |
426 | 426 | "Now, implement the optimization loop!" |
|
429 | 429 | { |
430 | 430 | "cell_type": "code", |
431 | 431 | "execution_count": null, |
432 | | - "id": "alternative-wisdom", |
| 432 | + "id": "29", |
433 | 433 | "metadata": { |
434 | 434 | "tags": [] |
435 | 435 | }, |
|
450 | 450 | }, |
451 | 451 | { |
452 | 452 | "cell_type": "markdown", |
453 | | - "id": "alternative-iraqi", |
| 453 | + "id": "30", |
454 | 454 | "metadata": {}, |
455 | 455 | "source": [ |
456 | 456 | "## Exercise: programming a robot that only moves along one axis\n", |
|
464 | 464 | { |
465 | 465 | "cell_type": "code", |
466 | 466 | "execution_count": null, |
467 | | - "id": "operational-advantage", |
| 467 | + "id": "31", |
468 | 468 | "metadata": { |
469 | 469 | "tags": [] |
470 | 470 | }, |
|
490 | 490 | }, |
491 | 491 | { |
492 | 492 | "cell_type": "markdown", |
493 | | - "id": "ecological-asian", |
| 493 | + "id": "32", |
494 | 494 | "metadata": {}, |
495 | 495 | "source": [ |
496 | 496 | "For your reference we have the function plotted below." |
|
499 | 499 | { |
500 | 500 | "cell_type": "code", |
501 | 501 | "execution_count": null, |
502 | | - "id": "convenient-optics", |
| 502 | + "id": "33", |
503 | 503 | "metadata": { |
504 | 504 | "tags": [] |
505 | 505 | }, |
|
531 | 531 | { |
532 | 532 | "cell_type": "code", |
533 | 533 | "execution_count": null, |
534 | | - "id": "loaded-labor", |
| 534 | + "id": "34", |
535 | 535 | "metadata": {}, |
536 | 536 | "outputs": [], |
537 | 537 | "source": [] |
|