Skip to content

Commit 2c79d68

Browse files
committed
Eliminate src/libtorchaudio/stable/dispatch.h
1 parent 220f6aa commit 2c79d68

File tree

4 files changed

+49
-137
lines changed

4 files changed

+49
-137
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 34 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
1-
#include <libtorchaudio/stable/dispatch.h>
21
#include <libtorchaudio/stable/ops.h>
32
#include <libtorchaudio/utils.h>
43
#include <torch/csrc/stable/library.h>
54
#include <torch/csrc/stable/tensor.h>
5+
#include <torch/headeronly/core/Dispatch_v2.h>
6+
#include <torch/headeronly/core/ScalarType.h>
67

78
namespace torchaudio {
89
namespace alignment {
@@ -138,6 +139,13 @@ void forced_align_impl(
138139
delete[] backPtr_a;
139140
}
140141

142+
template <typename scalar_t>
143+
const auto forced_align_long_impl =
144+
forced_align_impl<scalar_t, ScalarType::Long>;
145+
146+
template <typename scalar_t>
147+
const auto forced_align_int_impl = forced_align_impl<scalar_t, ScalarType::Int>;
148+
141149
std::tuple<Tensor, Tensor> compute(
142150
const Tensor& logProbs,
143151
const Tensor& targets,
@@ -178,32 +186,41 @@ std::tuple<Tensor, Tensor> compute(
178186
STD_TORCH_CHECK(
179187
blank >= 0 && blank < logProbs.size(-1),
180188
"blank must be within [0, num classes)");
181-
STABLE_DISPATCH_INDEX_TYPES(
182-
inputLengths.scalar_type(), "forced_align_impl", [&] {
189+
THO_DISPATCH_V2(
190+
inputLengths.scalar_type(),
191+
"forced_align_impl",
192+
AT_WRAP([&] {
183193
STD_TORCH_CHECK(
184-
logProbs.size(1) == torchaudio::util::max<index_t>(inputLengths),
194+
logProbs.size(1) == torchaudio::util::max<scalar_t>(inputLengths),
185195
"input length mismatch");
186-
});
187-
STABLE_DISPATCH_INDEX_TYPES(
188-
targetLengths.scalar_type(), "forced_align_impl", [&] {
196+
}),
197+
ScalarType::Int,
198+
ScalarType::Long);
199+
THO_DISPATCH_V2(
200+
targetLengths.scalar_type(),
201+
"forced_align_impl",
202+
AT_WRAP([&] {
189203
STD_TORCH_CHECK(
190-
targets.size(1) == torchaudio::util::max<index_t>(targetLengths),
204+
targets.size(1) == torchaudio::util::max<scalar_t>(targetLengths),
191205
"target length mismatch");
192-
});
206+
}),
207+
ScalarType::Int,
208+
ScalarType::Long);
193209
const auto B = logProbs.size(0);
194210
const auto T = logProbs.size(1);
195211
Tensor paths = torchaudio::stable::new_zeros(targets, {B, T});
196-
197-
STABLE_DISPATCH_FLOATING_TYPES_AND_HALF(
198-
logProbs.scalar_type(), "forced_align_impl", [&] {
212+
THO_DISPATCH_V2(
213+
logProbs.scalar_type(),
214+
"forced_align_impl",
215+
AT_WRAP([&] {
199216
if (targets.scalar_type() == ScalarType::Long) {
200-
forced_align_impl<scalar_t, ScalarType::Long>(
201-
logProbs, targets, blank, paths);
217+
forced_align_long_impl<scalar_t>(logProbs, targets, blank, paths);
202218
} else {
203-
forced_align_impl<scalar_t, ScalarType::Int>(
204-
logProbs, targets, blank, paths);
219+
forced_align_int_impl<scalar_t>(logProbs, targets, blank, paths);
205220
}
206-
});
221+
}),
222+
AT_EXPAND(AT_FLOATING_TYPES),
223+
ScalarType::Half);
207224
return std::make_tuple(paths, logProbs);
208225
}
209226

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,8 @@
11
#include <libtorchaudio/utils.h>
22
#include <libtorchaudio/stable/TensorAccessor.h>
3-
#include <libtorchaudio/stable/dispatch.h>
43
#include <torch/csrc/stable/library.h>
4+
#include <torch/headeronly/core/Dispatch_v2.h>
5+
#include <torch/headeronly/core/ScalarType.h>
56

67
#include <cub/cub.cuh>
78
#include <limits.h>
@@ -243,6 +244,13 @@ void forced_align_impl(
243244
}
244245
}
245246

247+
template <typename scalar_t>
248+
const auto forced_align_long_impl =
249+
forced_align_impl<scalar_t, ScalarType::Long>;
250+
251+
template <typename scalar_t>
252+
const auto forced_align_int_impl = forced_align_impl<scalar_t, ScalarType::Int>;
253+
246254
std::tuple<Tensor, Tensor> compute(
247255
const Tensor& logProbs,
248256
const Tensor& targets,
@@ -297,16 +305,13 @@ std::tuple<Tensor, Tensor> compute(
297305

298306
Tensor paths = torchaudio::stable::new_zeros(targets, {B, T}, /*dtype=*/std::nullopt, /*layout=*/std::nullopt, /*device=*/torchaudio::stable::cpu_device());
299307

300-
STABLE_DISPATCH_FLOATING_TYPES_AND_HALF(
301-
logProbs.scalar_type(), "forced_align_impl", [&] {
308+
THO_DISPATCH_V2(logProbs.scalar_type(), "forced_align_impl", AT_WRAP([&] {
302309
if (targets.scalar_type() == ScalarType::Long) {
303-
forced_align_impl<scalar_t, ScalarType::Long>(
304-
logProbs, targets, blank, paths);
310+
forced_align_long_impl<scalar_t>(logProbs, targets, blank, paths);
305311
} else {
306-
forced_align_impl<scalar_t, ScalarType::Int>(
307-
logProbs, targets, blank, paths);
312+
forced_align_int_impl<scalar_t>(logProbs, targets, blank, paths);
308313
}
309-
});
314+
}), AT_EXPAND(AT_FLOATING_TYPES), ScalarType::Half);
310315

311316
Tensor pathsCuda = torchaudio::stable::cuda(paths, logProbs.get_device_index());
312317
return std::make_tuple(pathsCuda, logProbs);

src/libtorchaudio/stable/TensorAccessor.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ class TensorAccessorBase {
4444
// Originally, TensorAccessor is a view of sizes and strides as
4545
// these are ArrayRef instances. Until torch::stable supports
4646
// ArrayRef-like features, we store copies of sizes and strides:
47-
for (auto i = 0; i < N; ++i) {
47+
for (size_t i = 0; i < N; ++i) {
4848
this->sizes_[i] = sizes_[i];
4949
this->strides_[i] = strides_[i];
5050
}
@@ -146,7 +146,7 @@ class GenericPackedTensorAccessorBase {
146146
const source_index_t* sizes_,
147147
const source_index_t* strides_)
148148
: data_(data_) {
149-
for (auto i = 0; i < N; ++i) {
149+
for (size_t i = 0; i < N; ++i) {
150150
this->sizes_[i] = sizes_[i];
151151
this->strides_[i] = strides_[i];
152152
}

src/libtorchaudio/stable/dispatch.h

Lines changed: 0 additions & 110 deletions
This file was deleted.

0 commit comments

Comments (0)