From e4967a564e3f215cd6f76f67b6722d8c5ee25fb5 Mon Sep 17 00:00:00 2001 From: Callum McLean Date: Tue, 20 Jun 2023 16:40:52 +0100 Subject: [PATCH 1/8] Add `run_speech_recognition_seq2seq.py` --- .../run_speech_recognition_seq2seq.py | 618 ++++++++++++++++++ 1 file changed, 618 insertions(+) create mode 100755 examples/speech-recognition/run_speech_recognition_seq2seq.py diff --git a/examples/speech-recognition/run_speech_recognition_seq2seq.py b/examples/speech-recognition/run_speech_recognition_seq2seq.py new file mode 100755 index 000000000..12d82b86e --- /dev/null +++ b/examples/speech-recognition/run_speech_recognition_seq2seq.py @@ -0,0 +1,618 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for sequence to sequence speech recognition. +""" +# You can also adapt this script on your own sequence to sequence speech +# recognition task. Pointers for this are left as comments. + +import logging +import os +import sys +import warnings +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Union + +import datasets +import evaluate +import torch +from datasets import DatasetDict, load_dataset + +import transformers +from transformers import ( + AutoConfig, + AutoFeatureExtractor, + AutoModelForSpeechSeq2Seq, + AutoProcessor, + AutoTokenizer, + HfArgumentParser, + WhisperProcessor, + set_seed, +) +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version, send_example_telemetry +from transformers.utils.versions import require_version + +from optimum.graphcore import IPUConfig, IPUSeq2SeqTrainer +from optimum.graphcore import IPUSeq2SeqTrainingArguments as Seq2SeqTrainingArguments + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.29.0") + +require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
+ """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + feature_extractor_name: Optional[str] = field( + default=None, metadata={"help": "feature extractor name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": ( + "Will use the token generated when running `huggingface-cli login` (necessary to use this script " + "with private models)." + ) + }, + ) + freeze_feature_encoder: bool = field( + default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."} + ) + freeze_encoder: bool = field( + default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."} + ) + forced_decoder_ids: List[List[int]] = field( + default=None, + metadata={ + "help": ( + "A list of pairs of integers which indicates a mapping from generation indices to token indices " + "that will be forced before sampling. For example, [[0, 123]] means the first generated token " + "will always be a token of index 123." + ) + }, + ) + suppress_tokens: List[int] = field( + default=None, metadata={"help": "A list of tokens that will be suppressed at generation."} + ) + apply_spec_augment: bool = field( + default=False, + metadata={ + "help": "Whether to apply *SpecAugment* data augmentation to the input features. This is currently only relevant for Wav2Vec2, HuBERT, WavLM and Whisper models." + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ) + }, + ) + audio_column_name: str = field( + default="audio", + metadata={"help": "The name of the dataset column containing the audio data. 
Defaults to 'audio'"}, + ) + text_column_name: str = field( + default="text", + metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"}, + ) + max_duration_in_seconds: float = field( + default=20.0, + metadata={ + "help": ( + "Truncate audio files that are longer than `max_duration_in_seconds` seconds to" + " 'max_duration_in_seconds`" + ) + }, + ) + min_duration_in_seconds: float = field( + default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"} + ) + preprocessing_only: bool = field( + default=False, + metadata={ + "help": ( + "Whether to only do data preprocessing and skip training. This is especially useful when data" + " preprocessing errors out in distributed training due to timeout. In this case, one should run the" + " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets" + " can consequently be loaded in distributed training" + ) + }, + ) + train_split_name: str = field( + default="train", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + eval_split_name: str = field( + default="test", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + do_lower_case: bool = field( + default=True, + metadata={"help": "Whether the target text should be lower cased."}, + ) + language: str = field( + default=None, + metadata={ + "help": ( + "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning " + "only. For English speech recognition, it should be set to `None`." + ) + }, + ) + task: str = field( + default="transcribe", + metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."}, + ) + + +@dataclass +class DataCollatorSpeechSeq2SeqWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor ([`WhisperProcessor`]) + The processor used for processing the data. + decoder_start_token_id (`int`) + The begin-of-sentence of the decoder. + forward_attention_mask (`bool`) + Whether to return attention_mask. 
+ """ + + processor: Any + decoder_start_token_id: int + forward_attention_mask: bool + padding: Union[bool, str] = "longest" + pad_to_multiple_of: Optional[int] = None + pad_to_multiple_of_labels: Optional[int] = None + + def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: + # split inputs and labels since they have to be of different lengths and need + # different padding methods + model_input_name = self.processor.model_input_names[0] + input_features = [{model_input_name: feature[model_input_name]} for feature in features] + label_features = [{"input_ids": feature["labels"]} for feature in features] + + batch = self.processor.feature_extractor.pad( + input_features, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors="pt" + ) + + if self.forward_attention_mask: + batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features]) + + labels_batch = self.processor.tokenizer.pad( + label_features, + pad_to_multiple_of=self.pad_to_multiple_of_labels, + return_tensors="pt" + ) + + # replace padding with -100 to ignore loss correctly + labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) + + # if bos token is appended in previous tokenization step, + # cut bos token here as it's append later anyways + if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item(): + labels = labels[:, 1:] + + batch["labels"] = labels + + return batch.data + + +def main(): + # 1. Parse input arguments + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if training_args.gradient_checkpointing: + raise ValueError("Gradient checkpointing not supported.") + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_speech_recognition_seq2seq", model_args, data_args) + + # 2. Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # 3. 
Detecting last checkpoint and eventually continue from last checkpoint + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. + set_seed(training_args.seed) + + # 4. Load dataset + raw_datasets = DatasetDict() + + if training_args.do_train: + raw_datasets["train"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=data_args.train_split_name, + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + ) + + if training_args.do_eval: + raw_datasets["eval"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=data_args.eval_split_name, + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + ) + + if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--audio_column_name` to the correct audio column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--text_column_name` to the correct text column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + # 5. 
Load pretrained model, tokenizer, and feature extractor + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens}) + + # SpecAugment for whisper models + if getattr(config, "model_type", None) == "whisper": + if model_args.apply_spec_augment: + raise ValueError("SpecAugment is not supported on IPU") + config.update({"apply_spec_augment": model_args.apply_spec_augment}) + + # IPU specific config updates + config.update({"apply_spec_augment": False}) + + # Whisper does not have a layer_norm_eps option, remains to be seen if this is a problem + #config.update({"layer_norm_eps": 0.0001}) + + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_args.model_name_or_path, + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + ipu_config = IPUConfig.from_pretrained( + training_args.ipu_config_name if training_args.ipu_config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_auth_token=True if data_args.use_auth_token else None, + ) + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + + if model_args.freeze_feature_encoder: + model.freeze_feature_encoder() + + if model_args.freeze_encoder: + model.freeze_encoder() + model.model.encoder.gradient_checkpointing = False + + if data_args.language is not None: + # We only need to set the task id when the language is specified (i.e. in a multilingual setting) + tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task) + + # 6. Resample speech dataset if necessary + dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate + if dataset_sampling_rate != feature_extractor.sampling_rate: + raw_datasets = raw_datasets.cast_column( + data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) + ) + + # 7. Preprocessing the datasets. + # We need to read the audio files as arrays and tokenize the targets. 
+ max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate + min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate + audio_column_name = data_args.audio_column_name + num_workers = data_args.preprocessing_num_workers + text_column_name = data_args.text_column_name + model_input_name = feature_extractor.model_input_names[0] + do_lower_case = data_args.do_lower_case + # if SpecAugment is used for whisper models, return attention_mask to guide the mask along time axis + forward_attention_mask = ( + getattr(config, "model_type", None) == "whisper" + and getattr(config, "apply_spec_augment", False) + and getattr(config, "mask_time_prob", 0) > 0 + ) + + if data_args.max_train_samples is not None: + raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples)) + + if data_args.max_eval_samples is not None: + raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples)) + + def prepare_dataset(batch, feature_extractor, tokenizer): + # process audio + sample = batch[audio_column_name] + inputs = feature_extractor( + sample["array"], sampling_rate=sample["sampling_rate"], return_attention_mask=forward_attention_mask + ) + # process audio length + batch[model_input_name] = inputs.get(model_input_name)[0] + batch["input_length"] = len(sample["array"]) + if forward_attention_mask: + batch["attention_mask"] = inputs.get("attention_mask")[0] + + if not training_args.fp32: + # Cast audio inputs to FP16 + batch["input_values"] = batch["input_values"].astype(np.float16) + + # process targets + input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name] + batch["labels"] = tokenizer(input_str).input_ids + return batch + + with training_args.main_process_first(desc="dataset map pre-processing"): + vectorized_datasets = raw_datasets.map( + lambda batch: prepare_dataset(batch, feature_extractor, tokenizer), + remove_columns=next(iter(raw_datasets.values())).column_names, + num_proc=data_args.preprocessing_num_workers, + desc="preprocess train dataset", + ) + + # filter data that is shorter than min_input_length or longer than + # max_input_length + def is_audio_in_length_range(length): + return length > min_input_length and length < max_input_length + + vectorized_datasets = vectorized_datasets.filter( + is_audio_in_length_range, + num_proc=num_workers, + input_columns=["input_length"], + ) + + # for large datasets it is advised to run the preprocessing on a + # single machine first with `args.preprocessing_only` since there will mostly likely + # be a timeout when running the script in distributed mode. + # In a second step `args.preprocessing_only` can then be set to `False` to load the + # cached dataset + if data_args.preprocessing_only: + cache = {k: v.cache_files for k, v in vectorized_datasets.items()} + logger.info(f"Data preprocessing finished. Files cached at {cache}.") + return + + # 8. Load Metric + metric = evaluate.load("wer") + + def compute_metrics(pred): + pred_ids = pred.predictions + + pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id + + pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) + # we do not want to group tokens when computing the metrics + label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True) + + wer = metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer} + + # 9. 
Create a single speech processor + # save feature extractor, tokenizer and config + feature_extractor.save_pretrained(training_args.output_dir) + tokenizer.save_pretrained(training_args.output_dir) + config.save_pretrained(training_args.output_dir) + + processor = WhisperProcessor(feature_extractor, tokenizer) + + # 10. Define data collator + data_collator = DataCollatorSpeechSeq2SeqWithPadding( + processor=processor, + decoder_start_token_id=model.config.decoder_start_token_id, + forward_attention_mask=forward_attention_mask, + pad_to_multiple_of=max_input_length, + pad_to_multiple_of_labels=500 + ) + + # 11. Initialize Trainer + trainer = IPUSeq2SeqTrainer( + model=model, + ipu_config=ipu_config, + args=training_args, + train_dataset=vectorized_datasets["train"] if training_args.do_train else None, + eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + # 12. Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() # Saves the feature extractor too for easy upload + + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples + if data_args.max_train_samples is not None + else len(vectorized_datasets["train"]) + ) + metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"])) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # 13. Evaluation + results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate( + metric_key_prefix="eval", + max_length=training_args.generation_max_length, + num_beams=training_args.generation_num_beams, + ) + max_eval_samples = ( + data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"]) + ) + metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"])) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # 14. 
Write Training Stats + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "automatic-speech-recognition"} + if data_args.dataset_name is not None: + kwargs["dataset_tags"] = data_args.dataset_name + if data_args.dataset_config_name is not None: + kwargs["dataset_args"] = data_args.dataset_config_name + kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" + else: + kwargs["dataset"] = data_args.dataset_name + + if training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(**kwargs) + + return results + + +if __name__ == "__main__": + main() From 6e01df2977d717eef3f5681539cb479a3e531a3c Mon Sep 17 00:00:00 2001 From: Callum McLean Date: Tue, 4 Jul 2023 16:32:21 +0100 Subject: [PATCH 2/8] Script now working with interleaved training + validation --- .../run_speech_recognition_seq2seq.py | 36 ++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/examples/speech-recognition/run_speech_recognition_seq2seq.py b/examples/speech-recognition/run_speech_recognition_seq2seq.py index 12d82b86e..692aa0286 100755 --- a/examples/speech-recognition/run_speech_recognition_seq2seq.py +++ b/examples/speech-recognition/run_speech_recognition_seq2seq.py @@ -20,6 +20,7 @@ # recognition task. Pointers for this are left as comments. import logging +import math import os import sys import warnings @@ -28,6 +29,7 @@ import datasets import evaluate +import numpy as np import torch from datasets import DatasetDict, load_dataset @@ -312,10 +314,6 @@ def main(): transformers.utils.logging.enable_explicit_format() # Log on each process the small summary: - logger.warning( - f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" - f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" - ) logger.info(f"Training/evaluation parameters {training_args}") # 3. 
Detecting last checkpoint and eventually continue from last checkpoint @@ -419,8 +417,9 @@ def main(): ipu_config = IPUConfig.from_pretrained( training_args.ipu_config_name if training_args.ipu_config_name else model_args.model_name_or_path, cache_dir=model_args.cache_dir, - use_auth_token=True if data_args.use_auth_token else None, + use_auth_token=True if model_args.use_auth_token else None, ) + if model.config.decoder_start_token_id is None: raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") @@ -470,7 +469,7 @@ def prepare_dataset(batch, feature_extractor, tokenizer): inputs = feature_extractor( sample["array"], sampling_rate=sample["sampling_rate"], return_attention_mask=forward_attention_mask ) - # process audio length + batch[model_input_name] = inputs.get(model_input_name)[0] batch["input_length"] = len(sample["array"]) if forward_attention_mask: @@ -478,20 +477,19 @@ def prepare_dataset(batch, feature_extractor, tokenizer): if not training_args.fp32: # Cast audio inputs to FP16 - batch["input_values"] = batch["input_values"].astype(np.float16) + batch[model_input_name] = batch[model_input_name].astype(np.float16) # process targets input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name] batch["labels"] = tokenizer(input_str).input_ids return batch - with training_args.main_process_first(desc="dataset map pre-processing"): - vectorized_datasets = raw_datasets.map( - lambda batch: prepare_dataset(batch, feature_extractor, tokenizer), - remove_columns=next(iter(raw_datasets.values())).column_names, - num_proc=data_args.preprocessing_num_workers, - desc="preprocess train dataset", - ) + vectorized_datasets = raw_datasets.map( + lambda batch: prepare_dataset(batch, feature_extractor, tokenizer), + remove_columns=next(iter(raw_datasets.values())).column_names, + num_proc=data_args.preprocessing_num_workers, + desc="preprocess train dataset", + ) # filter data that is shorter than min_input_length or longer than # max_input_length @@ -543,8 +541,9 @@ def compute_metrics(pred): processor=processor, decoder_start_token_id=model.config.decoder_start_token_id, forward_attention_mask=forward_attention_mask, - pad_to_multiple_of=max_input_length, - pad_to_multiple_of_labels=500 + #pad_to_multiple_of=math.ceil(max_input_length), + pad_to_multiple_of=80, + pad_to_multiple_of_labels=training_args.generation_max_length ) # 11. Initialize Trainer @@ -556,6 +555,11 @@ def compute_metrics(pred): eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None, data_collator=data_collator, compute_metrics=compute_metrics if training_args.predict_with_generate else None, + eval_parallelize_kwargs={ + 'use_cache': True, + 'use_cross_cache': True, + 'max_length': training_args.generation_max_length + } ) # 12. 
Training From 5ac00cab369ebcc87ccd863cd42b9f79e2d0ed23 Mon Sep 17 00:00:00 2001 From: Callum McLean Date: Tue, 11 Jul 2023 17:33:26 +0100 Subject: [PATCH 3/8] `make style` + add diff .txt to tests --- .../run_speech_recognition_seq2seq.py | 35 ++-- .../run_speech_recognition_seq2seq.txt | 151 ++++++++++++++++++ 2 files changed, 165 insertions(+), 21 deletions(-) create mode 100644 tests/examples/run_speech_recognition_seq2seq.txt diff --git a/examples/speech-recognition/run_speech_recognition_seq2seq.py b/examples/speech-recognition/run_speech_recognition_seq2seq.py index 692aa0286..9ba32c4ca 100755 --- a/examples/speech-recognition/run_speech_recognition_seq2seq.py +++ b/examples/speech-recognition/run_speech_recognition_seq2seq.py @@ -20,10 +20,8 @@ # recognition task. Pointers for this are left as comments. import logging -import math import os import sys -import warnings from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Union @@ -31,14 +29,12 @@ import evaluate import numpy as np import torch -from datasets import DatasetDict, load_dataset - import transformers +from datasets import DatasetDict, load_dataset from transformers import ( AutoConfig, AutoFeatureExtractor, AutoModelForSpeechSeq2Seq, - AutoProcessor, AutoTokenizer, HfArgumentParser, WhisperProcessor, @@ -51,6 +47,7 @@ from optimum.graphcore import IPUConfig, IPUSeq2SeqTrainer from optimum.graphcore import IPUSeq2SeqTrainingArguments as Seq2SeqTrainingArguments + # Will error if the minimal version of Transformers is not installed. Remove at your own risks. check_min_version("4.29.0") @@ -252,18 +249,14 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> label_features = [{"input_ids": feature["labels"]} for feature in features] batch = self.processor.feature_extractor.pad( - input_features, - pad_to_multiple_of=self.pad_to_multiple_of, - return_tensors="pt" + input_features, pad_to_multiple_of=self.pad_to_multiple_of, return_tensors="pt" ) if self.forward_attention_mask: batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features]) labels_batch = self.processor.tokenizer.pad( - label_features, - pad_to_multiple_of=self.pad_to_multiple_of_labels, - return_tensors="pt" + label_features, pad_to_multiple_of=self.pad_to_multiple_of_labels, return_tensors="pt" ) # replace padding with -100 to ignore loss correctly @@ -390,9 +383,9 @@ def main(): # IPU specific config updates config.update({"apply_spec_augment": False}) - + # Whisper does not have a layer_norm_eps option, remains to be seen if this is a problem - #config.update({"layer_norm_eps": 0.0001}) + # config.update({"layer_norm_eps": 0.0001}) feature_extractor = AutoFeatureExtractor.from_pretrained( model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path, @@ -477,8 +470,8 @@ def prepare_dataset(batch, feature_extractor, tokenizer): if not training_args.fp32: # Cast audio inputs to FP16 - batch[model_input_name] = batch[model_input_name].astype(np.float16) - + batch[model_input_name] = batch[model_input_name].astype(np.float16) + # process targets input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name] batch["labels"] = tokenizer(input_str).input_ids @@ -541,9 +534,9 @@ def compute_metrics(pred): processor=processor, decoder_start_token_id=model.config.decoder_start_token_id, forward_attention_mask=forward_attention_mask, - #pad_to_multiple_of=math.ceil(max_input_length), + # 
pad_to_multiple_of=math.ceil(max_input_length), pad_to_multiple_of=80, - pad_to_multiple_of_labels=training_args.generation_max_length + pad_to_multiple_of_labels=training_args.generation_max_length, ) # 11. Initialize Trainer @@ -556,10 +549,10 @@ def compute_metrics(pred): data_collator=data_collator, compute_metrics=compute_metrics if training_args.predict_with_generate else None, eval_parallelize_kwargs={ - 'use_cache': True, - 'use_cross_cache': True, - 'max_length': training_args.generation_max_length - } + "use_cache": True, + "use_cross_cache": True, + "max_length": training_args.generation_max_length, + }, ) # 12. Training diff --git a/tests/examples/run_speech_recognition_seq2seq.txt b/tests/examples/run_speech_recognition_seq2seq.txt new file mode 100644 index 000000000..81d3cb0b9 --- /dev/null +++ b/tests/examples/run_speech_recognition_seq2seq.txt @@ -0,0 +1,151 @@ +29a30 +> import numpy as np +31,32d31 +< from datasets import DatasetDict, load_dataset +< +33a33 +> from datasets import DatasetDict, load_dataset +38d37 +< AutoProcessor, +41,42c40 +< Seq2SeqTrainer, +< Seq2SeqTrainingArguments, +--- +> WhisperProcessor, +45c43 +< from transformers.trainer_utils import get_last_checkpoint, is_main_process +--- +> from transformers.trainer_utils import get_last_checkpoint +48a47,49 +> from optimum.graphcore import IPUConfig, IPUSeq2SeqTrainer +> from optimum.graphcore import IPUSeq2SeqTrainingArguments as Seq2SeqTrainingArguments +> +51c52 +< check_min_version("4.31.0.dev0") +--- +> check_min_version("4.29.0") +238a240,242 +> padding: Union[bool, str] = "longest" +> pad_to_multiple_of: Optional[int] = None +> pad_to_multiple_of_labels: Optional[int] = None +247c251,253 +< batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt") +--- +> batch = self.processor.feature_extractor.pad( +> input_features, pad_to_multiple_of=self.pad_to_multiple_of, return_tensors="pt" +> ) +252c258,260 +< labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt") +--- +> labels_batch = self.processor.tokenizer.pad( +> label_features, pad_to_multiple_of=self.pad_to_multiple_of_labels, return_tensors="pt" +> ) +264c272 +< return batch +--- +> return batch.data +280a289,294 +> if training_args.gradient_checkpointing: +> raise ValueError("Gradient checkpointing not supported.") +> +> print(f"{training_args.ipu_config_name = }") +> print(f"{training_args.ipu_config_name if training_args.ipu_config_name else model_args.model_name_or_path = }") +> +298,299d311 +< logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) +< +301,304d312 +< logger.warning( +< f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" +< f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" +< ) +307,311d314 +< # Set the verbosity to info of the Transformers logger (on main process only): +< if is_main_process(training_args.local_rank): +< transformers.utils.logging.set_verbosity_info() +< logger.info("Training/evaluation parameters %s", training_args) +< +379a383,384 +> if model_args.apply_spec_augment: +> raise ValueError("SpecAugment is not supported on IPU") +381a387,392 +> # IPU specific config updates +> config.update({"apply_spec_augment": False}) +> +> # Whisper does not have a layer_norm_eps option, remains to be seen if this is a problem +> # config.update({"layer_norm_eps": 0.0001}) +> +401a413,417 +> ipu_config = IPUConfig.from_pretrained( +> 
training_args.ipu_config_name if training_args.ipu_config_name else model_args.model_name_or_path, +> cache_dir=model_args.cache_dir, +> use_auth_token=True if model_args.use_auth_token else None, +> ) +446c462 +< def prepare_dataset(batch): +--- +> def prepare_dataset(batch, feature_extractor, tokenizer): +451a468,470 +> +> igmin = inputs.get(model_input_name)[0] +> # print(igmin.shape) +453c472 +< batch[model_input_name] = inputs.get(model_input_name)[0] +--- +> batch[model_input_name] = igmin +457a477,480 +> if not training_args.fp32: +> # Cast audio inputs to FP16 +> batch[model_input_name] = batch[model_input_name].astype(np.float16) +> +463,469c486,491 +< with training_args.main_process_first(desc="dataset map pre-processing"): +< vectorized_datasets = raw_datasets.map( +< prepare_dataset, +< remove_columns=next(iter(raw_datasets.values())).column_names, +< num_proc=data_args.preprocessing_num_workers, +< desc="preprocess train dataset", +< ) +--- +> vectorized_datasets = raw_datasets.map( +> lambda batch: prepare_dataset(batch, feature_extractor, tokenizer), +> remove_columns=next(iter(raw_datasets.values())).column_names, +> num_proc=data_args.preprocessing_num_workers, +> desc="preprocess train dataset", +> ) +509,516c531,534 +< # make sure all processes wait until data is saved +< with training_args.main_process_first(): +< # only the main process saves them +< if is_main_process(training_args.local_rank): +< # save feature extractor, tokenizer and config +< feature_extractor.save_pretrained(training_args.output_dir) +< tokenizer.save_pretrained(training_args.output_dir) +< config.save_pretrained(training_args.output_dir) +--- +> # save feature extractor, tokenizer and config +> feature_extractor.save_pretrained(training_args.output_dir) +> tokenizer.save_pretrained(training_args.output_dir) +> config.save_pretrained(training_args.output_dir) +518c536 +< processor = AutoProcessor.from_pretrained(training_args.output_dir) +--- +> processor = WhisperProcessor(feature_extractor, tokenizer) +524a543,545 +> # pad_to_multiple_of=math.ceil(max_input_length), +> pad_to_multiple_of=80, +> pad_to_multiple_of_labels=training_args.generation_max_length, +528c549 +< trainer = Seq2SeqTrainer( +--- +> trainer = IPUSeq2SeqTrainer( +529a551 +> ipu_config=ipu_config, +533d554 +< tokenizer=feature_extractor, +535a557,561 +> eval_parallelize_kwargs={ +> "use_cache": True, +> "use_cross_cache": True, +> "max_length": training_args.generation_max_length, +> }, From b42ddf8a19cae1073d6af1f7f54490d7df8e29da Mon Sep 17 00:00:00 2001 From: Callum McLean Date: Mon, 17 Jul 2023 14:51:27 +0100 Subject: [PATCH 4/8] `eval_parallelize_kwargs` -> `inference_parallelize_kwargs` --- examples/speech-recognition/run_speech_recognition_seq2seq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speech-recognition/run_speech_recognition_seq2seq.py b/examples/speech-recognition/run_speech_recognition_seq2seq.py index 9ba32c4ca..5509e1eee 100755 --- a/examples/speech-recognition/run_speech_recognition_seq2seq.py +++ b/examples/speech-recognition/run_speech_recognition_seq2seq.py @@ -548,7 +548,7 @@ def compute_metrics(pred): eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None, data_collator=data_collator, compute_metrics=compute_metrics if training_args.predict_with_generate else None, - eval_parallelize_kwargs={ + inference_parallelize_kwargs={ "use_cache": True, "use_cross_cache": True, "max_length": training_args.generation_max_length, From 
c7c3f9eaa8ad6fbb331c9644a785449bea1659cd Mon Sep 17 00:00:00 2001 From: Callum McLean Date: Mon, 17 Jul 2023 15:26:53 +0100 Subject: [PATCH 5/8] Remove `pad_to_multiple_of` --- examples/speech-recognition/run_speech_recognition_seq2seq.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/speech-recognition/run_speech_recognition_seq2seq.py b/examples/speech-recognition/run_speech_recognition_seq2seq.py index 5509e1eee..b3fec1ae7 100755 --- a/examples/speech-recognition/run_speech_recognition_seq2seq.py +++ b/examples/speech-recognition/run_speech_recognition_seq2seq.py @@ -534,8 +534,6 @@ def compute_metrics(pred): processor=processor, decoder_start_token_id=model.config.decoder_start_token_id, forward_attention_mask=forward_attention_mask, - # pad_to_multiple_of=math.ceil(max_input_length), - pad_to_multiple_of=80, pad_to_multiple_of_labels=training_args.generation_max_length, ) From 53742af4c7da98a6ccc0513e20e87bc135673a9f Mon Sep 17 00:00:00 2001 From: Callum McLean Date: Mon, 17 Jul 2023 15:36:09 +0100 Subject: [PATCH 6/8] Redo diff file --- .../run_speech_recognition_seq2seq.txt | 51 ++++++++----------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/tests/examples/run_speech_recognition_seq2seq.txt b/tests/examples/run_speech_recognition_seq2seq.txt index 81d3cb0b9..d11b90719 100644 --- a/tests/examples/run_speech_recognition_seq2seq.txt +++ b/tests/examples/run_speech_recognition_seq2seq.txt @@ -44,61 +44,54 @@ < return batch --- > return batch.data -280a289,294 +280a289,291 > if training_args.gradient_checkpointing: > raise ValueError("Gradient checkpointing not supported.") > -> print(f"{training_args.ipu_config_name = }") -> print(f"{training_args.ipu_config_name if training_args.ipu_config_name else model_args.model_name_or_path = }") -> -298,299d311 +298,299d308 < logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) < -301,304d312 +301,304d309 < logger.warning( < f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" < f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" < ) -307,311d314 +307,311d311 < # Set the verbosity to info of the Transformers logger (on main process only): < if is_main_process(training_args.local_rank): < transformers.utils.logging.set_verbosity_info() < logger.info("Training/evaluation parameters %s", training_args) < -379a383,384 +379a380,381 > if model_args.apply_spec_augment: > raise ValueError("SpecAugment is not supported on IPU") -381a387,392 +381a384,389 > # IPU specific config updates > config.update({"apply_spec_augment": False}) > > # Whisper does not have a layer_norm_eps option, remains to be seen if this is a problem > # config.update({"layer_norm_eps": 0.0001}) > -401a413,417 +401a410,414 > ipu_config = IPUConfig.from_pretrained( > training_args.ipu_config_name if training_args.ipu_config_name else model_args.model_name_or_path, > cache_dir=model_args.cache_dir, > use_auth_token=True if model_args.use_auth_token else None, > ) -446c462 +446c459 < def prepare_dataset(batch): --- > def prepare_dataset(batch, feature_extractor, tokenizer): -451a468,470 -> -> igmin = inputs.get(model_input_name)[0] -> # print(igmin.shape) -453c472 -< batch[model_input_name] = inputs.get(model_input_name)[0] +452c465 +< # process audio length --- -> batch[model_input_name] = igmin -457a477,480 +> +457a471,474 > if not training_args.fp32: > # Cast audio inputs to FP16 > batch[model_input_name] = 
batch[model_input_name].astype(np.float16) > -463,469c486,491 +463,469c480,485 < with training_args.main_process_first(desc="dataset map pre-processing"): < vectorized_datasets = raw_datasets.map( < prepare_dataset, @@ -113,7 +106,7 @@ > num_proc=data_args.preprocessing_num_workers, > desc="preprocess train dataset", > ) -509,516c531,534 +509,516c525,528 < # make sure all processes wait until data is saved < with training_args.main_process_first(): < # only the main process saves them @@ -127,24 +120,22 @@ > feature_extractor.save_pretrained(training_args.output_dir) > tokenizer.save_pretrained(training_args.output_dir) > config.save_pretrained(training_args.output_dir) -518c536 +518c530 < processor = AutoProcessor.from_pretrained(training_args.output_dir) --- > processor = WhisperProcessor(feature_extractor, tokenizer) -524a543,545 -> # pad_to_multiple_of=math.ceil(max_input_length), -> pad_to_multiple_of=80, +524a537 > pad_to_multiple_of_labels=training_args.generation_max_length, -528c549 +528c541 < trainer = Seq2SeqTrainer( --- > trainer = IPUSeq2SeqTrainer( -529a551 +529a543 > ipu_config=ipu_config, -533d554 +533d546 < tokenizer=feature_extractor, -535a557,561 -> eval_parallelize_kwargs={ +535a549,553 +> inference_parallelize_kwargs={ > "use_cache": True, > "use_cross_cache": True, > "max_length": training_args.generation_max_length, From ee5282cd1697511a518f9fce4c23f045eed6578b Mon Sep 17 00:00:00 2001 From: Callum McLean Date: Mon, 17 Jul 2023 17:12:25 +0100 Subject: [PATCH 7/8] Set `inference_parallelize_kwargs` in `IPUConfig` --- .../run_speech_recognition_seq2seq.py | 10 +++--- .../run_speech_recognition_seq2seq.txt | 33 +++++++++---------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/examples/speech-recognition/run_speech_recognition_seq2seq.py b/examples/speech-recognition/run_speech_recognition_seq2seq.py index b3fec1ae7..244a4a201 100755 --- a/examples/speech-recognition/run_speech_recognition_seq2seq.py +++ b/examples/speech-recognition/run_speech_recognition_seq2seq.py @@ -411,6 +411,11 @@ def main(): training_args.ipu_config_name if training_args.ipu_config_name else model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, + inference_parallelize_kwargs={ + "use_cache": True, + "use_cross_cache": True, + "max_length": training_args.generation_max_length, + }, ) if model.config.decoder_start_token_id is None: @@ -546,11 +551,6 @@ def compute_metrics(pred): eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None, data_collator=data_collator, compute_metrics=compute_metrics if training_args.predict_with_generate else None, - inference_parallelize_kwargs={ - "use_cache": True, - "use_cross_cache": True, - "max_length": training_args.generation_max_length, - }, ) # 12. 
Training diff --git a/tests/examples/run_speech_recognition_seq2seq.txt b/tests/examples/run_speech_recognition_seq2seq.txt index d11b90719..e259c424a 100644 --- a/tests/examples/run_speech_recognition_seq2seq.txt +++ b/tests/examples/run_speech_recognition_seq2seq.txt @@ -72,26 +72,31 @@ > # Whisper does not have a layer_norm_eps option, remains to be seen if this is a problem > # config.update({"layer_norm_eps": 0.0001}) > -401a410,414 +401a410,419 > ipu_config = IPUConfig.from_pretrained( > training_args.ipu_config_name if training_args.ipu_config_name else model_args.model_name_or_path, > cache_dir=model_args.cache_dir, > use_auth_token=True if model_args.use_auth_token else None, +> inference_parallelize_kwargs={ +> "use_cache": True, +> "use_cross_cache": True, +> "max_length": training_args.generation_max_length, +> }, > ) -446c459 +446c464 < def prepare_dataset(batch): --- > def prepare_dataset(batch, feature_extractor, tokenizer): -452c465 +452c470 < # process audio length --- > -457a471,474 +457a476,479 > if not training_args.fp32: > # Cast audio inputs to FP16 > batch[model_input_name] = batch[model_input_name].astype(np.float16) > -463,469c480,485 +463,469c485,490 < with training_args.main_process_first(desc="dataset map pre-processing"): < vectorized_datasets = raw_datasets.map( < prepare_dataset, @@ -106,7 +111,7 @@ > num_proc=data_args.preprocessing_num_workers, > desc="preprocess train dataset", > ) -509,516c525,528 +509,516c530,533 < # make sure all processes wait until data is saved < with training_args.main_process_first(): < # only the main process saves them @@ -120,23 +125,17 @@ > feature_extractor.save_pretrained(training_args.output_dir) > tokenizer.save_pretrained(training_args.output_dir) > config.save_pretrained(training_args.output_dir) -518c530 +518c535 < processor = AutoProcessor.from_pretrained(training_args.output_dir) --- > processor = WhisperProcessor(feature_extractor, tokenizer) -524a537 +524a542 > pad_to_multiple_of_labels=training_args.generation_max_length, -528c541 +528c546 < trainer = Seq2SeqTrainer( --- > trainer = IPUSeq2SeqTrainer( -529a543 +529a548 > ipu_config=ipu_config, -533d546 +533d551 < tokenizer=feature_extractor, -535a549,553 -> inference_parallelize_kwargs={ -> "use_cache": True, -> "use_cross_cache": True, -> "max_length": training_args.generation_max_length, -> }, From fe466efe3e5d800f8a24968f353b2ea0bc0b4194 Mon Sep 17 00:00:00 2001 From: Callum McLean Date: Tue, 1 Aug 2023 11:22:22 +0100 Subject: [PATCH 8/8] Try manually removing transformers version bit from diff file --- tests/examples/run_speech_recognition_seq2seq.txt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/examples/run_speech_recognition_seq2seq.txt b/tests/examples/run_speech_recognition_seq2seq.txt index e259c424a..d56c58c0e 100644 --- a/tests/examples/run_speech_recognition_seq2seq.txt +++ b/tests/examples/run_speech_recognition_seq2seq.txt @@ -20,10 +20,6 @@ > from optimum.graphcore import IPUConfig, IPUSeq2SeqTrainer > from optimum.graphcore import IPUSeq2SeqTrainingArguments as Seq2SeqTrainingArguments > -51c52 -< check_min_version("4.31.0.dev0") ---- -> check_min_version("4.29.0") 238a240,242 > padding: Union[bool, str] = "longest" > pad_to_multiple_of: Optional[int] = None
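
The padding changes running through these patches all serve the same constraint: the IPU compiles the model for fixed tensor shapes, so the collator pads the labels up to a fixed multiple (`pad_to_multiple_of_labels=training_args.generation_max_length`) and converts the padded positions to -100, instead of padding each batch only to its own longest sequence. A minimal sketch of that label-padding behaviour in plain PyTorch, rather than through the `WhisperProcessor`/tokenizer `pad()` call the collator actually uses — the helper name and token ids below are illustrative, not taken from the patches:

import torch

def pad_labels_to_multiple(label_ids, multiple):
    # Round the batch length up to the nearest multiple so every compiled
    # batch has the same (static) label length on the IPU.
    longest = max(len(ids) for ids in label_ids)
    target = ((longest + multiple - 1) // multiple) * multiple
    # Initialise with -100 so padded positions are ignored by the loss,
    # mirroring the masked_fill(..., -100) step in DataCollatorSpeechSeq2SeqWithPadding.
    labels = torch.full((len(label_ids), target), -100, dtype=torch.long)
    for row, ids in enumerate(label_ids):
        labels[row, : len(ids)] = torch.tensor(ids, dtype=torch.long)
    return labels

if __name__ == "__main__":
    # Two label sequences of different lengths, padded to a multiple of 8:
    # both rows come out with length 8, trailing positions set to -100.
    batch = [[50258, 440, 2655, 13, 50257], [50258, 634, 50257]]
    print(pad_labels_to_multiple(batch, multiple=8))

Padding to a fixed multiple, rather than to the per-batch maximum, is what lets the IPU reuse a single compiled executable across batches instead of recompiling whenever the longest sequence changes.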