Commit 35d8759

alancucki authored and nv-kkudrynski committed
[wav2vec2/PyT] Initial release
1 parent c2bb3fe commit 35d8759

85 files changed (+12,548 additions, 0 deletions)

Large commits have some content hidden by default; only a subset of the changed files is rendered below.

PyTorch/SpeechRecognition/wav2vec2/.gitignore

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@

datasets/
results/
models/
pretrained_models/
tb_*/
*.pt

PyTorch/SpeechRecognition/wav2vec2/Dockerfile

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@

# Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:22.11-py3
FROM ${FROM_IMAGE_NAME}

ENV PYTHONPATH /workspace/wav2vec2
WORKDIR /workspace/wav2vec2

COPY requirements.txt .
RUN pip install -r requirements.txt

COPY . .
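
For context (an assumption, not repo-provided instructions): the Dockerfile copies requirements.txt and installs dependencies before COPY . ., so Docker's layer cache skips the pip install step when only source files change. A typical build-and-run sequence would be docker build -t wav2vec2 . followed by docker run --gpus all -it --rm wav2vec2.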

PyTorch/SpeechRecognition/wav2vec2/README.md

Lines changed: 587 additions & 0 deletions
Large diffs are not rendered by default.

PyTorch/SpeechRecognition/wav2vec2/common/__init__.py

Whitespace-only changes.

PyTorch/SpeechRecognition/wav2vec2/common/dataset.py

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

import numpy as np
from torch.utils.data import DataLoader

from common.fairseq.data import data_utils
from common.helpers import print_once
from common.sampler import DistributedIndicesSampler


def adjust_max_tokens(train_dataset, world_size, args):

    def get_steps_per_epoch(world_size, max_tokens, update_freq):
        train_loader, sampler = get_batch_iterator(
            train_dataset,
            True,
            max_tokens=max_tokens,
            max_sentences=args.batch_size,
            max_positions=(max_tokens, max_tokens),
            ignore_invalid_inputs=True,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=world_size,
            shard_id=0,
            num_workers=args.num_workers)

        steps_per_epoch = len(train_loader) // update_freq
        return steps_per_epoch

    steps_ref = get_steps_per_epoch(args.ref_world_size, args.ref_max_tokens, 1)

    min_ = args.ref_max_tokens // 20
    max_ = args.ref_max_tokens * 20

    prev_max_tokens = 0
    align_to = 1000
    while min_ < max_:
        max_tokens = (max_ + min_) // 2 // align_to * align_to  # try to round
        if max_tokens == prev_max_tokens:
            break
        prev_max_tokens = max_tokens
        steps = get_steps_per_epoch(world_size, max_tokens, args.update_freq)
        print_once(f"max_tokens={max_tokens} yields {steps} steps "
                   f"(adjusting for {steps_ref}).")
        if steps == steps_ref:
            break
        elif steps > steps_ref:
            min_ = max_tokens
        else:
            max_ = max_tokens

    args.max_tokens = max_tokens
    args.max_tokens_valid = max_tokens

def filter_indices_by_size(
    indices, dataset, max_positions=None, ignore_invalid_inputs=False
):
    """
    Filter examples that are too large

    Args:
        indices (np.array): original array of sample indices
        dataset (~fairseq.data.FairseqDataset): dataset to batch
        max_positions (optional): max sentence length supported by the
            model (default: None).
        ignore_invalid_inputs (bool, optional): don't raise Exception for
            sentences that are too long (default: False).
    Returns:
        np.array: array of filtered sample indices
    """
    indices, ignored = dataset.filter_indices_by_size(indices, max_positions)
    # TODO: consider removing this function. If `len(ignored) > 0`,
    # an error is raised in fairseq dataset code, both in sup and unsup case
    if len(ignored) > 0:
        if not ignore_invalid_inputs:
            raise Exception(
                (
                    "Size of sample #{} is invalid (={}) since max_positions={}, "
                    "skip this example with --skip-invalid-size-inputs-valid-test"
                ).format(ignored[0], dataset.size(ignored[0]), max_positions)
            )
        print(
            (
                "WARNING: {:,} samples have invalid sizes and will be skipped, "
                "max_positions={}, first few sample ids={}"
            ).format(len(ignored), max_positions, ignored[:10])
        )
    return indices

def get_batch_iterator(
    dataset,
    training,
    max_tokens=None,
    max_sentences=None,
    max_positions=None,
    ignore_invalid_inputs=False,
    required_batch_size_multiple=1,
    seed=1,
    num_shards=1,
    shard_id=0,
    num_workers=0,
    num_concat_batches=1,
):
    # get indices ordered by example size
    with data_utils.numpy_seed(seed):
        indices = dataset.ordered_indices()

    # filter examples that are too large
    if max_positions is not None:
        indices = filter_indices_by_size(
            indices, dataset, max_positions, ignore_invalid_inputs)

    # create mini-batches with given size constraints
    batch_inds, non_grouped_batch_inds = dataset.batch_by_size(
        indices,
        max_tokens=max_tokens,
        max_sentences=max_sentences,
        required_batch_size_multiple=required_batch_size_multiple,
        num_concat_batches=num_concat_batches,
    )

    batch_ids = copy.deepcopy(non_grouped_batch_inds)
    [bi.fill(i) for i, bi in enumerate(batch_ids)]
    inds_ids = zip(np.concatenate(batch_inds), np.concatenate(batch_ids))
    dataset.batch_ids = {idx: batch_idx for idx, batch_idx in inds_ids}

    # Batches are already specified, now we just need to shuffle them
    batch_ind_sampler = DistributedIndicesSampler(batch_inds, shuffle=training,
                                                  num_replicas=num_shards,
                                                  rank=shard_id, seed=seed,
                                                  drop_last=training,
                                                  fillvalue=[])
    loader = DataLoader(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=batch_ind_sampler,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=num_workers > 0,
    )
    return loader, batch_ind_sampler
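
An editor's sketch, not part of the commit: adjust_max_tokens above binary-searches for the max_tokens value that reproduces the reference configuration's steps per epoch on a different GPU count, rounding each candidate down to a multiple of align_to=1000 (e.g. with ref_max_tokens=1_400_000, min_ and max_ start at 70_000 and 28_000_000, so the first probe is (28_000_000 + 70_000) // 2 // 1000 * 1000 = 14_035_000). The loop below shows how get_batch_iterator's return values might be driven; all argument values are illustrative, and train_dataset is assumed to be a fairseq-style dataset exposing ordered_indices(), batch_by_size(), and collater:

loader, sampler = get_batch_iterator(
    train_dataset,
    True,                      # training: shuffle batches, drop the last ragged one
    max_tokens=1_400_000,      # illustrative cap on total waveform samples per batch
    max_sentences=None,        # no hard cap on utterances per batch
    max_positions=(1_400_000, 1_400_000),
    ignore_invalid_inputs=True,
    seed=1,
    num_shards=2,              # e.g. two DDP ranks
    shard_id=0,
    num_workers=4,
)
for epoch in range(3):
    sampler.set_epoch(epoch)   # assumed to reshuffle per epoch, as in torch's DistributedSampler
    for batch in loader:
        pass                   # forward/backward would go here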

PyTorch/SpeechRecognition/wav2vec2/common/fairseq/data/__init__.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""isort:skip_file"""

from .add_target_dataset import AddTargetDataset, BaseWrapperDataset
from .audio.raw_audio_dataset import FileAudioDataset
from .dictionary import Dictionary


__all__ = [
    "AddTargetDataset",
    "Dictionary",
    "FileAudioDataset",
]
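
A minimal import sketch (not from the diff), assuming the package path used by common/dataset.py; the Dictionary calls follow upstream fairseq's API (load(), pad(), eos()) and are an assumption, since only the class name is visible in this excerpt:

from common.fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset

# Assumed fairseq-style API: load a letter dictionary and read the special
# indices that AddTargetDataset takes as its `pad` and `eos` arguments.
target_dict = Dictionary.load("data/dict.ltr.txt")  # path is illustrative
pad_idx, eos_idx = target_dict.pad(), target_dict.eos()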

PyTorch/SpeechRecognition/wav2vec2/common/fairseq/data/add_target_dataset.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from . import data_utils


class BaseWrapperDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    @property
    def sizes(self):
        return self.dataset.sizes

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
        num_concat_batches=1,
    ):
        return self.dataset.batch_by_size(
            indices,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            num_concat_batches=num_concat_batches,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        return self.dataset.filter_indices_by_size(indices, max_sizes)


class AddTargetDataset(BaseWrapperDataset):
    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        process_label=None,
        add_to_input=False,
    ):
        super().__init__(dataset)
        self.labels = labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.eos = eos
        self.process_label = process_label
        self.add_to_input = add_to_input

    def get_label(self, index):
        return (
            self.labels[index]
            if self.process_label is None
            else self.process_label(self.labels[index])
        )

    def __getitem__(self, index):
        item = self.dataset[index]
        item["label"] = self.get_label(index)
        return item

    def size(self, index):
        sz = self.dataset.size(index)
        own_sz = len(self.get_label(index))
        return (sz, own_sz)

    def collater(self, samples):
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated
        indices = set(collated["id"].tolist())
        target = [s["label"] for s in samples if s["id"] in indices]

        if self.batch_targets:
            collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
        else:
            collated["ntokens"] = sum([len(t) for t in target])

        collated["target"] = target

        if self.add_to_input:
            eos = target.new_full((target.size(0), 1), self.eos)
            collated["target"] = torch.cat([target, eos], dim=-1).long()
            collated["net_input"]["prev_output_tokens"] = torch.cat(
                [eos, target], dim=-1
            ).long()
            collated["ntokens"] += target.size(0)
        return collated

    def __setattr__(self, attr, val):
        if attr == "batch_ids":
            self.dataset.batch_ids = val
        else:
            super().__setattr__(attr, val)
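
A hedged, self-contained sketch (not part of the commit) of what AddTargetDataset adds: it pairs each audio item with a label sequence and, when batch_targets is set, pads the labels into one tensor via data_utils.collate_tokens. ToyAudioDataset below is hypothetical and only mimics the interface the wrapper relies on (a collater returning an "id" field, plus size() and sizes); running it assumes the repo's common.fairseq.data package is on PYTHONPATH:

import torch

from common.fairseq.data import AddTargetDataset


class ToyAudioDataset(torch.utils.data.Dataset):
    """Hypothetical stand-in for the commit's FileAudioDataset."""

    def __init__(self, waves):
        self.waves = waves
        self.sizes = [len(w) for w in waves]

    def __getitem__(self, index):
        return {"id": index, "source": self.waves[index]}

    def __len__(self):
        return len(self.waves)

    def size(self, index):
        return self.sizes[index]

    def collater(self, samples):
        return {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": {
                "source": torch.nn.utils.rnn.pad_sequence(
                    [s["source"] for s in samples], batch_first=True),
            },
        }


waves = [torch.randn(16000), torch.randn(12000)]          # two toy utterances
labels = [torch.LongTensor([5, 6, 7]), torch.LongTensor([8, 9])]  # pre-encoded transcripts
ds = AddTargetDataset(ToyAudioDataset(waves), labels, pad=0, eos=2,
                      batch_targets=True)
batch = ds.collater([ds[0], ds[1]])
# target is right-padded to shape (2, 3); ntokens counts unpadded label tokens (5)
print(batch["target"].shape, batch["target_lengths"], batch["ntokens"])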
