|
2 | 2 | "cells": [ |
3 | 3 | { |
4 | 4 | "cell_type": "raw", |
5 | | - "id": "49d2aaf9", |
| 5 | + "id": "ea7ceaf7", |
6 | 6 | "metadata": {}, |
7 | 7 | "source": [ |
8 | 8 | "# Copyright 2023 NVIDIA Corporation. All Rights Reserved.\n", |
|
23 | 23 | }, |
24 | 24 | { |
25 | 25 | "cell_type": "markdown", |
26 | | - "id": "fc0ea279", |
| 26 | + "id": "4e1a2027", |
27 | 27 | "metadata": {}, |
28 | 28 | "source": [ |
29 | 29 | "# Edge Classification Pretraining demo (IEEE)" |
30 | 30 | ] |
31 | 31 | }, |
32 | 32 | { |
33 | 33 | "cell_type": "markdown", |
34 | | - "id": "5a5196e2", |
| 34 | + "id": "05d3d798", |
35 | 35 | "metadata": {}, |
36 | 36 | "source": [ |
37 | 37 | "## Overview\n", |
| 38 | + "\n", |
| 39 | + "Often times it is helpful to pre-train or initialize a network with learned weights on a downstream task of interest and further fine-tune.\n", |
| 40 | + "\n", |
38 | 41 | "This notebook demonstrates the steps for pretraing a GNN on synthetic data and finetuning on real data. " |
39 | 42 | ] |
40 | 43 | }, |
41 | 44 | { |
42 | 45 | "cell_type": "markdown", |
43 | | - "id": "01a6300c", |
| 46 | + "id": "26f39e76", |
44 | 47 | "metadata": {}, |
45 | 48 | "source": [ |
46 | 49 | "### Imports" |
|
49 | 52 | { |
50 | 53 | "cell_type": "code", |
51 | 54 | "execution_count": 1, |
52 | | - "id": "69a28c16", |
| 55 | + "id": "f315cfcd", |
53 | 56 | "metadata": {}, |
54 | 57 | "outputs": [ |
55 | 58 | { |
|
94 | 97 | }, |
95 | 98 | { |
96 | 99 | "cell_type": "markdown", |
97 | | - "id": "01eddd70", |
| 100 | + "id": "20e3e3a6", |
98 | 101 | "metadata": {}, |
99 | 102 | "source": [ |
100 | 103 | "### Generate synthetic data" |
101 | 104 | ] |
102 | 105 | }, |
| 106 | + { |
| 107 | + "cell_type": "markdown", |
| 108 | + "id": "5c3db76c", |
| 109 | + "metadata": {}, |
| 110 | + "source": [ |
| 111 | + "In the following cells, a synthesizer is instantiated and fitted on the IEEE dataset.\n", |
| 112 | + "\n", |
| 113 | + "Once fitted, the synthesizer is used to generate synthetic data with similar characteristics.\n", |
| 114 | + "\n", |
| 115 | + "For a more detailed explanation checkout the `e2e_ieee_demo.ipynb`" |
| 116 | + ] |
| 117 | + }, |
103 | 118 | { |
104 | 119 | "cell_type": "code", |
105 | 120 | "execution_count": 2, |
106 | | - "id": "65da8b0a", |
| 121 | + "id": "8f86bf18", |
107 | 122 | "metadata": {}, |
108 | 123 | "outputs": [ |
109 | 124 | { |
|
131 | 146 | { |
132 | 147 | "cell_type": "code", |
133 | 148 | "execution_count": 3, |
134 | | - "id": "b0b64872", |
| 149 | + "id": "60bb8cfb", |
135 | 150 | "metadata": {}, |
136 | 151 | "outputs": [], |
137 | 152 | "source": [ |
|
145 | 160 | { |
146 | 161 | "cell_type": "code", |
147 | 162 | "execution_count": 4, |
148 | | - "id": "ac0d50f7", |
| 163 | + "id": "37d4eb69", |
149 | 164 | "metadata": {}, |
150 | 165 | "outputs": [ |
151 | 166 | { |
|
164 | 179 | { |
165 | 180 | "cell_type": "code", |
166 | 181 | "execution_count": 5, |
167 | | - "id": "84732600", |
| 182 | + "id": "873d0cf2", |
168 | 183 | "metadata": {}, |
169 | 184 | "outputs": [ |
170 | 185 | { |
|
204 | 219 | { |
205 | 220 | "cell_type": "code", |
206 | 221 | "execution_count": 6, |
207 | | - "id": "b615610c", |
| 222 | + "id": "b08f1603", |
208 | 223 | "metadata": {}, |
209 | 224 | "outputs": [ |
210 | 225 | { |
|
251 | 266 | }, |
252 | 267 | { |
253 | 268 | "cell_type": "markdown", |
254 | | - "id": "66f7a839", |
| 269 | + "id": "03e21408", |
255 | 270 | "metadata": {}, |
256 | 271 | "source": [ |
257 | 272 | "### Train GNN" |
258 | 273 | ] |
259 | 274 | }, |
260 | 275 | { |
261 | 276 | "cell_type": "markdown", |
262 | | - "id": "07805108", |
| 277 | + "id": "a834318e", |
| 278 | + "metadata": {}, |
| 279 | + "source": [ |
| 280 | + "To train an example GNN we need the following:\n", |
| 281 | + "\n", |
| 282 | + "- a dataset object instantiated using either the synthetic or original data\n", |
| 283 | + "- the model, optimizer and hyperparameters defined\n", |
| 284 | + "\n", |
| 285 | + "In the tool an example dataloader is implemented for edge classification under `syngen/benchmark/data_loader`.\n", |
| 286 | + "\n", |
| 287 | + "This dataset object is used to great the dgl graphs corresponding to both the generated data and real data." |
| 288 | + ] |
| 289 | + }, |
| 290 | + { |
| 291 | + "cell_type": "markdown", |
| 292 | + "id": "28fabfa9", |
263 | 293 | "metadata": {}, |
264 | 294 | "source": [ |
265 | 295 | "#### Create datasets" |
|
268 | 298 | { |
269 | 299 | "cell_type": "code", |
270 | 300 | "execution_count": 7, |
271 | | - "id": "0fe941f0", |
| 301 | + "id": "f7e8bd44", |
272 | 302 | "metadata": {}, |
273 | 303 | "outputs": [], |
274 | 304 | "source": [ |
|
279 | 309 | }, |
280 | 310 | { |
281 | 311 | "cell_type": "markdown", |
282 | | - "id": "a8b23137", |
| 312 | + "id": "b830709c", |
283 | 313 | "metadata": {}, |
284 | 314 | "source": [ |
285 | 315 | "#### Create helper function\n" |
286 | 316 | ] |
287 | 317 | }, |
| 318 | + { |
| 319 | + "cell_type": "markdown", |
| 320 | + "id": "b959a3a2", |
| 321 | + "metadata": {}, |
| 322 | + "source": [ |
| 323 | + "The helper function defines a simple trianing loop and standard metrics for edge classification." |
| 324 | + ] |
| 325 | + }, |
288 | 326 | { |
289 | 327 | "cell_type": "code", |
290 | 328 | "execution_count": 8, |
291 | | - "id": "f46973e3", |
| 329 | + "id": "5c4bec86", |
292 | 330 | "metadata": {}, |
293 | 331 | "outputs": [], |
294 | 332 | "source": [ |
|
329 | 367 | }, |
330 | 368 | { |
331 | 369 | "cell_type": "markdown", |
332 | | - "id": "6ad092e6", |
| 370 | + "id": "dc4cea06", |
333 | 371 | "metadata": {}, |
334 | 372 | "source": [ |
335 | 373 | "#### No-Pretrain" |
336 | 374 | ] |
337 | 375 | }, |
| 376 | + { |
| 377 | + "cell_type": "markdown", |
| 378 | + "id": "093203f8", |
| 379 | + "metadata": {}, |
| 380 | + "source": [ |
| 381 | + "Without pre-training the model is trained from scratch using the original data graph." |
| 382 | + ] |
| 383 | + }, |
338 | 384 | { |
339 | 385 | "cell_type": "code", |
340 | 386 | "execution_count": 9, |
341 | | - "id": "d4ad039a", |
| 387 | + "id": "93ab387d", |
342 | 388 | "metadata": {}, |
343 | 389 | "outputs": [ |
344 | 390 | { |
|
383 | 429 | }, |
384 | 430 | { |
385 | 431 | "cell_type": "markdown", |
386 | | - "id": "7f061442", |
| 432 | + "id": "08f5280a", |
387 | 433 | "metadata": {}, |
388 | 434 | "source": [ |
389 | 435 | "#### Pretrain" |
390 | 436 | ] |
391 | 437 | }, |
| 438 | + { |
| 439 | + "cell_type": "markdown", |
| 440 | + "id": "18bebba4", |
| 441 | + "metadata": {}, |
| 442 | + "source": [ |
| 443 | + "In this example the model is first trained on the generated data for a certain epoch budget.\n", |
| 444 | + "\n", |
| 445 | + "Subsequently it is further trained on the original data graph." |
| 446 | + ] |
| 447 | + }, |
392 | 448 | { |
393 | 449 | "cell_type": "code", |
394 | 450 | "execution_count": 10, |
395 | | - "id": "2f3985b2", |
| 451 | + "id": "e21ab679", |
396 | 452 | "metadata": {}, |
397 | 453 | "outputs": [ |
398 | 454 | { |
|
438 | 494 | { |
439 | 495 | "cell_type": "code", |
440 | 496 | "execution_count": 11, |
441 | | - "id": "f33bec4f", |
| 497 | + "id": "8b615c76", |
442 | 498 | "metadata": {}, |
443 | 499 | "outputs": [ |
444 | 500 | { |
|
458 | 514 | }, |
459 | 515 | { |
460 | 516 | "cell_type": "markdown", |
461 | | - "id": "a6f0cfbe", |
| 517 | + "id": "69b9e95c", |
462 | 518 | "metadata": {}, |
463 | 519 | "source": [ |
464 | 520 | "### CLI example" |
465 | 521 | ] |
466 | 522 | }, |
467 | 523 | { |
468 | 524 | "cell_type": "markdown", |
469 | | - "id": "2c48ec37", |
| 525 | + "id": "93fd05a0", |
| 526 | + "metadata": {}, |
| 527 | + "source": [ |
| 528 | + "The tool also provides this functionality through its CLI.\n", |
| 529 | + "\n", |
| 530 | + "The commands used to generate and pretrain/fine tune on the downstream tasks as done above are provided below." |
| 531 | + ] |
| 532 | + }, |
| 533 | + { |
| 534 | + "cell_type": "markdown", |
| 535 | + "id": "8de441fe", |
470 | 536 | "metadata": {}, |
471 | 537 | "source": [ |
472 | 538 | "#### Generate synthetic graph" |
|
475 | 541 | { |
476 | 542 | "cell_type": "code", |
477 | 543 | "execution_count": 1, |
478 | | - "id": "b588c44a", |
| 544 | + "id": "af89d214", |
479 | 545 | "metadata": {}, |
480 | 546 | "outputs": [ |
481 | 547 | { |
|
553 | 619 | }, |
554 | 620 | { |
555 | 621 | "cell_type": "markdown", |
556 | | - "id": "01eeff23", |
| 622 | + "id": "7fef4fb7", |
557 | 623 | "metadata": {}, |
558 | 624 | "source": [ |
559 | 625 | "#### Results without pretraining" |
|
562 | 628 | { |
563 | 629 | "cell_type": "code", |
564 | 630 | "execution_count": 2, |
565 | | - "id": "50238488", |
| 631 | + "id": "c65ab4be", |
566 | 632 | "metadata": {}, |
567 | 633 | "outputs": [ |
568 | 634 | { |
|
607 | 673 | }, |
608 | 674 | { |
609 | 675 | "cell_type": "markdown", |
610 | | - "id": "1a8474cb", |
| 676 | + "id": "e6655f58", |
611 | 677 | "metadata": {}, |
612 | 678 | "source": [ |
613 | 679 | "#### Pretrain and finetune" |
|
616 | 682 | { |
617 | 683 | "cell_type": "code", |
618 | 684 | "execution_count": 3, |
619 | | - "id": "92039366", |
| 685 | + "id": "fd2b8caf", |
620 | 686 | "metadata": {}, |
621 | 687 | "outputs": [ |
622 | 688 | { |
|
668 | 734 | { |
669 | 735 | "cell_type": "code", |
670 | 736 | "execution_count": null, |
671 | | - "id": "f0405bf2", |
| 737 | + "id": "2da530b6", |
672 | 738 | "metadata": {}, |
673 | 739 | "outputs": [], |
674 | 740 | "source": [] |
|
693 | 759 | "name": "python", |
694 | 760 | "nbconvert_exporter": "python", |
695 | 761 | "pygments_lexer": "ipython3", |
696 | | - "version": "3.8.15" |
| 762 | + "version": "3.8.10" |
697 | 763 | } |
698 | 764 | }, |
699 | 765 | "nbformat": 4, |
|
0 commit comments