
Commit 97e062a

Updates with Poplar SDK 2.4 release
1 parent: 9b61ab6

19 files changed: +316 / -110 lines

applications/pytorch/bert/README.md

Lines changed: 83 additions & 38 deletions
@@ -2,7 +2,7 @@
 
 This directory contains an implementation of BERT models in PyTorch for the IPU, leveraging the HuggingFace Transformers library. There are two examples:
 
-1. BERT for pretraining - `run_pretraining.py`
+1. BERT for pre-training - `run_pretraining.py`
 2. BERT for SQuAD - `run_squad.py`
 
 ## Environment setup
@@ -14,31 +14,31 @@ Then, create a virtual environment, install the required packages and build the
 ```console
 virtualenv venv -p python3.6
 source venv/bin/activate
-pip install -r requirements.txt
+pip3 install -r requirements.txt
 make
 ```
 
-## Run the pretraining application
+## Run the pre-training application
 
 Setup your environment as explained above and run the example with the configuration of your choice.
 
 ```console
-python run_pretraining.py --config demo_tiny_128
+python3 run_pretraining.py --config demo_tiny_128
 ```
 
 ## Configurations
 
-To see the available configurations for both SQuAD and pretraining see the `configs.yml` file.
+To see the available configurations for both SQuAD and pre-training see the `configs.yml` file.
 
 To see the available options available to use in the command line interface use the `--help` argument:
 
 ```console
-python run_pretraining.py --help
+python3 run_pretraining.py --help
 # or
-python run_squad.py --help
+python3 run_squad.py --help
 ```
 
-## Running pretraining with checkpointing
+## Running pre-training with checkpointing
 
 To enable the saving of model checkpoints on a run you need to add `--checkpoint-output-dir <path/to/checkpoint/dir>` to the command line. By default this will save a model checkpoint at the start and end of training.
 
@@ -48,67 +48,89 @@ To load model weights from a checkpoint directory use the flag `--pretrained-che
 
 ## Run the SQuAD application
 
-The question answering with SQuAD example is found in the `run_squad.py` script. Like with pretraining there are SQuAD configs defined in `configs.yml`.
+The question answering with SQuAD example is found in the `run_squad.py` script. Like with pre-training there are SQuAD configs defined in `configs.yml`.
 
 To run BERT-Base:
+
 ```console
-python run_squad.py --config squad_base_384
+python3 run_squad.py --config squad_base_384
 ```
 
 For BERT-Large there is `squad_large_384`, which is a high performance large configuration that uses an 8 IPU pipeline, unlike the other configs that use 4.
 
-You will also need to specify a pretrained checkpoint to fine-tune, which is specified with the `--pretrained-checkpoint <FILE-PATH/HF-model-hub-name>` flag.
+You will also need to specify a pre-trained checkpoint to fine-tune, which is specified with the `--pretrained-checkpoint <FILE-PATH/HF-model-hub-name>` flag.
 
 ## Caching executables
 
 When running the application, it is possible to save/load executables to/from a cache store. This allows for reusing a saved executable instead of re-compiling the model when re-running identical model configurations. To enable saving/loading from the cache store, use `--executable-cache-dir <relative/path/to/cache/store>` when running the application.
 
-## Running the entire pretraining and SQuAD pipeline
+## Running the entire pre-training and SQuAD pipeline
 
 For Base on POD16:
+
 ```console
-# Phase 1 pretraining
-python run_pretraining.py --config pretrain_base_128 --checkpoint-output-dir checkpoints/pretrain_base_128
+# Phase 1 pre-training
+python3 run_pretraining.py --config pretrain_base_128 --checkpoint-output-dir checkpoints/pretrain_base_128
 
-# Phase 2 pretraining
-python run_pretraining.py --config pretrain_base_384 --checkpoint-output-dir checkpoints/pretrain_base_384 --pretrained-checkpoint checkpoints/pretrain_base_128/step_N/
+# Phase 2 pre-training
+python3 run_pretraining.py --config pretrain_base_384 --checkpoint-output-dir checkpoints/pretrain_base_384 --pretrained-checkpoint checkpoints/pretrain_base_128/step_N/
+
+# To do phase 2 pretraining with a sequence length of 512, simply replace `384` with `512`.
 
 # SQuAD fine-tuning
-python run_squad.py --config squad_base_384 --pretrained-checkpoint checkpoints/pretrain_base_384/step_N/
+python3 run_squad.py --config squad_base_384 --pretrained-checkpoint checkpoints/pretrain_base_384/step_N/
 ```
 
 For Large on POD16:
+
 ```console
 # Phase 1 pretraining
-python run_pretraining.py --config pretrain_large_128 --checkpoint-output-dir checkpoints/pretrain_large_128
+python3 run_pretraining.py --config pretrain_large_128 --checkpoint-output-dir checkpoints/pretrain_large_128
 
 # Phase 2 pretraining
-python run_pretraining.py --config pretrain_large_384 --checkpoint-output-dir checkpoints/pretrain_large_384 --pretrained-checkpoint checkpoints/pretrain_large_128/step_N/
+python3 run_pretraining.py --config pretrain_large_384 --checkpoint-output-dir checkpoints/pretrain_large_384 --pretrained-checkpoint checkpoints/pretrain_large_128/step_N/
+
+# To do the same on POD64, simply append `_POD64` to the pretraining config names. To do phase 2 pretraining with a sequence length of 512, simply replace `384` with `512`.
 
 # SQuAD fine-tuning
-python run_squad.py --config squad_large_384 --pretrained-checkpoint checkpoints/pretrain_large_384/step_N/
+python3 run_squad.py --config squad_large_384 --pretrained-checkpoint checkpoints/pretrain_large_384/step_N/
 ```
 
 To do the same on POD64, simply append `_POD64` to the pretraining config names.
 
-## Run the tests (optional)
+## POD128 configurations
+
+PopDist and PopRun allow to seamlessly launch applications on large IPU-POD systems such as POD128. Further details about them can be found in the [docs]( https://docs.graphcore.ai/projects/poprun-user-guide/en/latest/index.html).
+
+We provide utility scripts to run the phase 1 and phase 2 pretraining in POD128. They can be executed as:
+
+```console
+# Phase 1 pretraining in POD128
+bash training_scripts/pretrain_large_128_POD128.sh
+
+# Phase 2 pretraining in POD128
+bash training_scripts/pretrain_large_384_POD128.sh
+```
 
-Setup your environment and generate the sample dataset as explained above and run `python -m pytest` from the root folder.
+The resulting pretraining checkpoint can be fine-tuned for SQuAD in a POD16 as described before.
 
+## Run the tests (optional)
+
+Setup your environment and generate the sample dataset as explained above and run `python3 -m pytest` from the root folder.
 
 ## Generate sample_text dataset (optional)
 
 The sample text provided enables training on a very small dataset for small scale testing.
-For convenience it is already provided in the `/data` folder in txt and tfrecord format.
+For convenience it is already provided in the `/data` folder in `txt` and `tfrecord` format.
 In order to re-generate the sample dataset, run the following script:
 
 ```console
-python third_party/create_pretraining_data.py --input-file data/sample_text.txt --output-file data/sample_text.tfrecord --sequence-length 128 --mask-tokens 20 --duplication-factor 4 --do-lower-case --model bert-base-uncased
+python3 third_party/create_pretraining_data.py --input-file data/sample_text.txt --output-file data/sample_text.tfrecord --sequence-length 128 --mask-tokens 20 --duplication-factor 4 --do-lower-case --model bert-base-uncased
 ```
 
 ## Generate pretraining dataset (optional)
 
-The dataset used for pretraining is WIKI-103. It can be generated from a RAW dump of Wikipedia following a four step process.
+The dataset used for pretraining is WIKI-103. It can be generated from a RAW dump of Wikipedia following a five step process.
 
 ### 1. Download
 
@@ -118,13 +140,13 @@ Use the `wikipedia_download.sh` script to download the latest Wikipedia dump, ab
 ./data/wikipedia_download.sh <chosen-path-for-dump-file>
 ```
 
-Dumps are available from https://dumps.wikimedia.org/ (and mirrors) and are licensed under CC BY-SA 3.0 and GNU Free Documentation Licenses.
+Dumps are available from <https://dumps.wikimedia.org/> (and mirrors) and are licensed under CC BY-SA 3.0 and GNU Free Documentation Licenses.
 
 ### 2. Extraction
 
 In order to create the pre-training data we need to extract the Wikipedia dump and put it in this form:
 
-```
+```text
 <doc id = article1>
 Title of article 1
 
@@ -141,46 +163,69 @@ Body of article 2
 
 and so on.
 
-One of the tools that can be used to do so is WikiExtractor, https://github.com/attardi/wikiextractor.
+One of the tools that can be used to do so is WikiExtractor, <https://github.com/attardi/wikiextractor>.
+Install the WikiExtractor package with `pip3 install wikiextractor`.
+
+In order not to encounter a `UnicodeEncodeError` at this step, you may want to run these two commands first:
+
+```console
+export PYTHONIOENCODING=utf-8
+export LC_ALL=C.UTF-8
+```
 
-You can use the the `wikipedia_extract.sh` script to use WikiExtractor to extract the data dump.
+You can then use the the `wikipedia_extract.sh` script to use WikiExtractor to extract the data dump.
 
 ```console
 ./data/wikipedia_extract.sh <chosen-path-for-dump-file>/wikidump.xml <chosen-folder-for-extracted-files>
 ```
 
-The result should be a folder containing directories named `AA`, `AB`...
+The result should be a folder containing directories named `AA`, `AB`, ...
+Note that the number of directories depends on the parameters of the `wikipedia_extract.sh` script, and is not to be confused with alphabetical ordering of the wikipedia articles.
+In other words you should probably not expect all of `AC`, `AD`, ... `ZX`, `ZY`, `ZZ` to be created by the script.
 
 ### 3. Pre-processing
 
-Install nltk package with `pip install nltk`.
+Install nltk package with `pip3 install nltk`.
 Use the `wikipedia_preprocess.py` script to preprocess the extracted files.
 
 ```console
-./data/wikipedia_preprocess.py --input-file-path <chosen-folder-for-extracted-files> --output-file-path <chosen-folder-for-preprocessed-files>
+python3 ./data/wikipedia_preprocess.py --input-file-path <chosen-folder-for-extracted-files> --output-file-path <chosen-folder-for-preprocessed-files>
 ```
 
 ### 4. Tokenization
 
-The script `create_pretraining_data.py` can accept a glob of input files to tokenise. However, attempting to process them all at once may result in the process being killed by the OS for consuming too much memory. It is therefore preferable to convert the files in groups. This is handled by the `./data/wikipedia_tokenize.py` script. At the same time, it is worth bearing in mind that `create_pretraining_data.py` shuffles the training instances across the loaded group of files, so a larger group would result in better shuffling of the samples seen by BERT during pre-training.
+The script `create_pretraining_data.py` can accept a glob of input files to tokenize.
+However, attempting to process them all at once may result in the process being killed by the OS for consuming too much memory.
+It is therefore preferable to convert the files in groups. This is handled by the `./data/wikipedia_tokenize.py` script.
+At the same time, it is worth bearing in mind that `create_pretraining_data.py` shuffles the training instances across the loaded group of files, so a larger group would result in better shuffling of the samples seen by BERT during pre-training.
+
+The tokenization depends on `tensorflow` which can be installed by `pip3 install tensorflow`.
 
 sequence length 128
+
 ```console
-./data/wikipedia_tokenize.py <chosen-folder-for-preprocessed-files> <chosen-folder-for-dataset-files> --sequence-length 128 --mask-tokens 20
+python3 ./data/wikipedia_tokenize.py <chosen-folder-for-preprocessed-files> <chosen-folder-for-dataset-files> --sequence-length 128 --mask-tokens 20
 ```
 
 sequence length 384
+
+```console
+python3 ./data/wikipedia_tokenize.py <chosen-folder-for-preprocessed-files> <chosen-folder-for-dataset-files> --sequence-length 384 --mask-tokens 56
+```
+
+sequence length 512
+
 ```console
-./data/wikipedia_tokenize.py <chosen-folder-for-preprocessed-files> <chosen-folder-for-dataset-files> --sequence-length 384 --mask-tokens 56
+python3 ./data/wikipedia_tokenize.py <chosen-folder-for-preprocessed-files> <chosen-folder-for-dataset-files> --sequence-length 512 --mask-tokens 76
 ```
 
-### Indexing
+### 5. Indexing
 
-In order to use the multi-threaded dataloader, tfrecord index files need to be generated.
+In order to use the multi-threaded `dataloader`, `tfrecord` index files need to be generated.
 First install the `tfrecord` Python package into your Python environment:
 
 ```console
-pip install tfrecord
+pip3 install tfrecord
 ```
 
 Then go to the directory containing the pre-processed Wikipedia files and run:

applications/pytorch/bert/README_Benchmarks.md

Lines changed: 31 additions & 8 deletions
@@ -18,7 +18,11 @@ Run the following commands from inside the applications/pytorch/bert/ directory.
 
 Command:
 ```console
-python run_pretraining.py --config pretrain_base_128 --training-steps 10 --input-file $DATASETS_DIR/wikipedia/128/wiki_1[0-1]*.tfrecord --disable-progress-bar
+python3 run_pretraining.py \
+--config pretrain_base_128 \
+--training-steps 10 \
+--input-file $DATASETS_DIR/wikipedia/128/wiki_1[0-1]*.tfrecord \
+--disable-progress-bar
 ```
 
 ### Pretrain BERT-Base Sequence Length 384
@@ -27,7 +31,11 @@ python run_pretraining.py --config pretrain_base_128 --training-steps 10 --input
 
 Command:
 ```console
-python run_pretraining.py --config pretrain_base_384 --training-steps 10 --input-file $DATASETS_DIR/wikipedia/384/wiki_1[0-1]*.tfrecord --disable-progress-bar
+python3 run_pretraining.py \
+--config pretrain_base_384 \
+--training-steps 10 \
+--input-file $DATASETS_DIR/wikipedia/384/wiki_1[0-1]*.tfrecord \
+--disable-progress-bar
 ```
 
 ### Pretrain BERT-Large Sequence Length 128
@@ -36,17 +44,24 @@ python run_pretraining.py --config pretrain_base_384 --training-steps 10 --input
 
 Command:
 ```console
-python run_pretraining.py --config pretrain_large_128 --training-steps 10 --input-file $DATASETS_DIR/wikipedia/128/wiki_1[0-1]*.tfrecord --disable-progress-bar
+python3 run_pretraining.py \
+--config pretrain_large_128 \
+--training-steps 10 \
+--input-file $DATASETS_DIR/wikipedia/128/wiki_1[0-1]*.tfrecord \
+--disable-progress-bar
 ```
 
 #### 1 x IPU-POD64
 
 Command:
 ```console
-python run_pretraining.py --config pretrain_large_128_POD64 --training-steps 10 --input-file $DATASETS_DIR/wikipedia/128/wiki_1[0-1]*.tfrecord --disable-progress-bar
+python3 run_pretraining.py \
+--config pretrain_large_128_POD64 \
+--training-steps 10 \
+--input-file $DATASETS_DIR/wikipedia/128/wiki_1[0-1]*.tfrecord \
+--disable-progress-bar
 ```
 
-
 #### 1 x IPU-POD128
 
 #### 1 x IPU-POD128
@@ -91,7 +106,7 @@ python run_pretraining.py --config pretrain_large_128_POD64 --replication-factor
 --replicated-tensor-sharding True \
 --random-seed 1984 \
 --input-files $DATASETS_DIR/wikipedia/torch_bert/128/*.tfrecord
-```CAL_HOME}/exec_cache" python run_pretraining.py --config configs/pretrain_large_128_phase1_POD128.json --train-file "$DATASETS_DIR/tf_wikipedia/tokenised_128_dup5_mask20/*.tfrecord"
+
 ```
 
 ### Pretrain BERT-Large Sequence Length 384
@@ -100,14 +115,22 @@ python run_pretraining.py --config pretrain_large_128_POD64 --replication-factor
 
 Command:
 ```console
-python run_pretraining.py --config pretrain_large_384 --training-steps 10 --input-file $DATASETS_DIR/wikipedia/384/wiki_1[0-1]*.tfrecord --disable-progress-bar
+python3 run_pretraining.py \
+--config pretrain_large_384 \
+--training-steps 10 \
+--input-file $DATASETS_DIR/wikipedia/384/wiki_1[0-1]*.tfrecord \
+--disable-progress-bar
 ```
 
 #### 1 x IPU-POD64
 
 Command:
 ```console
-python run_pretraining.py --config pretrain_large_384_POD64 --training-steps 10 --input-file $DATASETS_DIR/wikipedia/384/wiki_1[0-1]*.tfrecord --disable-progress-bar
+python3 run_pretraining.py \
+--config pretrain_large_384_POD64 \
+--training-steps 10 \
+--input-file $DATASETS_DIR/wikipedia/384/wiki_1[0-1]*.tfrecord \
+--disable-progress-bar
 ```
 
 #### 1 x IPU-POD128

applications/pytorch/bert/args.py

Lines changed: 5 additions & 4 deletions
@@ -62,7 +62,8 @@ def parse_bert_args(args=None):
         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
     # Execution
-    parser.add_argument("--batch-size", type=int, help="Set the micro batch-size")
+    parser.add_argument("--micro-batch-size", type=int,
+                        help="Set the micro-batch-size. This is the single forward-backward path batch-size on one replica")
     parser.add_argument("--training-steps", type=int, help="Number of training steps")
    parser.add_argument("--batches-per-step", type=int, help="Number of batches per training step")
     parser.add_argument("--replication-factor", type=int, help="Number of replicas")
@@ -208,11 +209,11 @@ def parse_bert_args(args=None):
         parser.error("checkpoint-steps must be >=1")
 
     if args.use_popdist:
-        args.global_batch_size = args.replication_factor * args.gradient_accumulation * args.batch_size * args.popdist_size
+        args.global_batch_size = args.replication_factor * args.gradient_accumulation * args.micro_batch_size * args.popdist_size
     else:
-        args.global_batch_size = args.replication_factor * args.gradient_accumulation * args.batch_size
+        args.global_batch_size = args.replication_factor * args.gradient_accumulation * args.micro_batch_size
 
-    args.samples_per_step = args.replication_factor * args.gradient_accumulation * args.batch_size * args.batches_per_step
+    args.samples_per_step = args.replication_factor * args.gradient_accumulation * args.micro_batch_size * args.batches_per_step
     args.intermediate_size = args.hidden_size * 4
 
     return args
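
As an illustration of the batch-size arithmetic behind the `--batch-size` to `--micro-batch-size` rename above, the derived quantities combine as follows. This is a minimal standalone sketch with made-up values, not code from the repository:

```python
# Illustrative values only; real values come from configs.yml or the command line.
replication_factor = 4      # data-parallel replicas
gradient_accumulation = 32  # micro-batches accumulated per weight update
micro_batch_size = 8        # samples per forward-backward pass on one replica
batches_per_step = 1        # device iterations per host call
popdist_size = 4            # PopDist instances (only relevant when launched via PopRun)

# Single-instance run (no PopDist):
global_batch_size = replication_factor * gradient_accumulation * micro_batch_size

# Multi-instance run: the global batch additionally scales with the number of instances.
global_batch_size_popdist = global_batch_size * popdist_size

# Samples consumed by one training step on the host side.
samples_per_step = replication_factor * gradient_accumulation * micro_batch_size * batches_per_step

print(global_batch_size, global_batch_size_popdist, samples_per_step)  # 1024 4096 1024
```

With these example numbers, one weight update sees 1024 samples per instance (4096 across four PopDist instances), and each host step feeds 1024 samples to the device.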

applications/pytorch/bert/checkpointing.py

Lines changed: 14 additions & 6 deletions
@@ -27,15 +27,23 @@ def checkpoints_exist(path):
     return False
 
 
-def save_checkpoint(config, model, step, metrics=None):
+def save_checkpoint(config, model, step, optimizer=None, metrics=None):
     if config.checkpoint_output_dir:
         path = os.path.join(os.path.abspath(config.checkpoint_output_dir), f"step_{step}")
         os.makedirs(path, exist_ok=True)
 
         logger(f"Saving checkpoint for step {step} to: {path}\n")
         model.save_pretrained(path)
-        torch.save({
-            "step": step,
-            "metrics": metrics,
-            "config": config
-        }, os.path.join(path, "training_state.pt"))
+        if optimizer is None:
+            torch.save({
+                "step": step,
+                "metrics": metrics,
+                "config": config
+            }, os.path.join(path, "training_state.pt"))
+        else:
+            torch.save({
+                "step": step,
+                "optimizer_state_dict": optimizer.state_dict(),
+                "metrics": metrics,
+                "config": config
+            }, os.path.join(path, "training_state.pt"))
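
With this change, `save_checkpoint` can persist the optimizer state alongside the step, metrics and config. For context, here is a rough sketch of how such a `training_state.pt` could be read back. It is illustrative only: the checkpoint path, model and optimizer below are placeholders, and the repository's own resume logic may differ:

```python
import torch

# Hypothetical checkpoint directory written by save_checkpoint above.
checkpoint_path = "checkpoints/pretrain_base_128/step_100/training_state.pt"

state = torch.load(checkpoint_path)
start_step = state["step"]
metrics = state["metrics"]
config = state["config"]

# Placeholders; in practice these must match the model/optimizer configuration
# that produced the checkpoint.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters())

# The optimizer state is only present when save_checkpoint was given an optimizer.
if "optimizer_state_dict" in state:
    optimizer.load_state_dict(state["optimizer_state_dict"])
```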
