|
11 | 11 | "In this notebook we build a binary classifier for the ATIS dataset using [BERT](https://arxiv.org/abs/1810.04805), a pre-trained NLP model open-sourced by Google in late 2018 that can be used for [transfer learning](https://towardsdatascience.com/transfer-learning-in-nlp-fecc59f546e4) on text data. This notebook has been adapted from this [article](https://towardsdatascience.com/bert-for-dummies-step-by-step-tutorial-fb90890ffe03). The dataset can be found [here](https://www.kaggle.com/siddhadev/ms-cntk-atis/data#).<br> This notebook requires a GPU. We suggest running it on your local machine only if you have a GPU set up; otherwise, you can use Google Colab." |
12 | 12 | ] |
13 | 13 | }, |
| 14 | + { |
| 15 | + "cell_type": "markdown", |
| 16 | + "metadata": {}, |
| 17 | + "source": [ |
| 18 | + "## Imports" |
| 19 | + ] |
| 20 | + }, |
14 | 21 | { |
15 | 22 | "cell_type": "code", |
16 | 23 | "execution_count": 0, |
|
115 | 122 | } |
116 | 123 | ], |
117 | 124 | "source": [ |
118 | | - "#importing a few necessary packages and setting the DATA directory\n", |
119 | 125 | "\n", |
| 126 | + "# If not using Colab, comment out the line below\n", |
120 | 127 | "%tensorflow_version 1.x\n", |
121 | 128 | "\n", |
122 | 129 | "from torch.optim import Adam\n", |
|
150 | 157 | "torch.cuda.get_device_name(0)" |
151 | 158 | ] |
152 | 159 | }, |
| 160 | + { |
| 161 | + "cell_type": "markdown", |
| 162 | + "metadata": {}, |
| 163 | + "source": [ |
| 164 | + "## Data Loading" |
| 165 | + ] |
| 166 | + }, |
153 | 167 | { |
154 | 168 | "cell_type": "code", |
155 | 169 | "execution_count": 0, |
|
345 | 359 | "query_data_test, intent_data_test, intent_data_label_test, slot_data_test = load_atis('atis.test.pkl')\n" |
346 | 360 | ] |
347 | 361 | }, |
| 362 | + { |
| 363 | + "cell_type": "markdown", |
| 364 | + "metadata": {}, |
| 365 | + "source": [ |
| 366 | + "Let's look at a few training queries." |
| 367 | + ] |
| 368 | + }, |
348 | 369 | { |
349 | 370 | "cell_type": "code", |
350 | 371 | "execution_count": 0, |
|
381 | 402 | "query_data_train" |
382 | 403 | ] |
383 | 404 | }, |
| 405 | + { |
| 406 | + "cell_type": "markdown", |
| 407 | + "metadata": {}, |
| 408 | + "source": [ |
| 409 | + "## Data Pre-processing\n", |
| 410 | + "We need to convert the sentences to tensors." |
| 411 | + ] |
| 412 | + }, |
384 | 413 | { |
385 | 414 | "cell_type": "code", |
386 | 415 | "execution_count": 0, |
|
431 | 460 | ] |
432 | 461 | }, |
433 | 462 | { |
434 | | - "cell_type": "code", |
435 | | - "execution_count": 0, |
436 | | - "metadata": { |
437 | | - "colab": {}, |
438 | | - "colab_type": "code", |
439 | | - "id": "S9SMEwslo-ve" |
440 | | - }, |
441 | | - "outputs": [], |
442 | | - "source": [] |
| 463 | + "cell_type": "markdown", |
| 464 | + "metadata": {}, |
| 465 | + "source": [ |
| 466 | + "BERT expects data to be in a specific format, i.e., [CLS] token1, token2, ... [SEP]" |
| 467 | + ] |
443 | 468 | }, |
444 | 469 | { |
445 | 470 | "cell_type": "code", |
|
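The `[CLS] ... [SEP]` format described in the markdown cell above can be sketched in a few lines. This is a minimal illustration, not the notebook's actual cell (which is elided from this diff); the helper name `add_special_tokens` and the sample query are assumptions made for the example.

```python
def add_special_tokens(sentences):
    """Wrap each raw sentence in the [CLS] ... [SEP] markers BERT expects.

    Hypothetical helper for illustration; the notebook may do this inline.
    """
    return ["[CLS] " + sentence.strip() + " [SEP]" for sentence in sentences]


# Made-up ATIS-style query, used only to show the resulting format.
sample_queries = ["i want to fly from boston to denver"]
formatted = add_special_tokens(sample_queries)
print(formatted[0])
```

The formatted strings would then go to the BERT tokenizer, which splits them into word pieces and treats `[CLS]` and `[SEP]` as single special tokens.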
508 | 533 | "input_ids = pad_sequences(input_ids, maxlen=MAX_LEN, dtype=\"long\", truncating=\"post\", padding=\"post\")" |
509 | 534 | ] |
510 | 535 | }, |
| 536 | + { |
| 537 | + "cell_type": "markdown", |
| 538 | + "metadata": {}, |
| 539 | + "source": [ |
| 540 | + "Creating the BERT attention masks" |
| 541 | + ] |
| 542 | + }, |
511 | 543 | { |
512 | 544 | "cell_type": "code", |
513 | 545 | "execution_count": 0, |
|
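The attention masks mentioned above mark which positions hold real tokens and which hold padding. A minimal sketch, assuming the sequences were padded with id 0 (the default fill value of `pad_sequences`); the helper name `build_attention_masks` and the example ids are hypothetical, and the notebook's own cell is not shown in this diff.

```python
def build_attention_masks(padded_input_ids, pad_id=0):
    """Return one mask per sequence: 1.0 for real tokens, 0.0 for padding.

    Assumes padding positions hold `pad_id` (0 here, an assumption).
    """
    return [
        [float(token_id != pad_id) for token_id in sequence]
        for sequence in padded_input_ids
    ]


# Illustrative ids only; real ids come from the BERT tokenizer.
padded = [
    [101, 1045, 2215, 102, 0, 0],
    [101, 3462, 102, 0, 0, 0],
]
attention_masks = build_attention_masks(padded)
print(attention_masks)
```

Passing these masks alongside the input ids tells the model to ignore padded positions when computing attention.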
579 | 611 | "validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)\n" |
580 | 612 | ] |
581 | 613 | }, |
| 614 | + { |
| 615 | + "cell_type": "markdown", |
| 616 | + "metadata": {}, |
| 617 | + "source": [ |
| 618 | + "## Training" |
| 619 | + ] |
| 620 | + }, |
582 | 621 | { |
583 | 622 | "cell_type": "code", |
584 | 623 | "execution_count": 0, |
|
913 | 952 | "model.cuda()" |
914 | 953 | ] |
915 | 954 | }, |
| 955 | + { |
| 956 | + "cell_type": "markdown", |
| 957 | + "metadata": {}, |
| 958 | + "source": [ |
| 959 | + "## Fine-Tuning BERT" |
| 960 | + ] |
| 961 | + }, |
916 | 962 | { |
917 | 963 | "cell_type": "code", |
918 | 964 | "execution_count": 0, |
|
1149 | 1195 | "name": "python", |
1150 | 1196 | "nbconvert_exporter": "python", |
1151 | 1197 | "pygments_lexer": "ipython3", |
1152 | | - "version": "3.6.10" |
| 1198 | + "version": "3.6.12" |
1153 | 1199 | } |
1154 | 1200 | }, |
1155 | 1201 | "nbformat": 4, |
|