|
1 | | -#include <libtorchaudio/stable/dispatch.h> |
2 | 1 | #include <libtorchaudio/stable/ops.h> |
3 | 2 | #include <libtorchaudio/utils.h> |
4 | 3 | #include <torch/csrc/stable/library.h> |
5 | 4 | #include <torch/csrc/stable/tensor.h> |
| 5 | +#include <torch/headeronly/core/Dispatch_v2.h> |
| 6 | +#include <torch/headeronly/core/ScalarType.h> |
6 | 7 |
|
7 | 8 | namespace torchaudio { |
8 | 9 | namespace alignment { |
@@ -138,6 +139,13 @@ void forced_align_impl( |
138 | 139 | delete[] backPtr_a; |
139 | 140 | } |
140 | 141 |
|
| 142 | +template <typename scalar_t> |
| 143 | +const auto forced_align_long_impl = |
| 144 | + forced_align_impl<scalar_t, ScalarType::Long>; |
| 145 | + |
| 146 | +template <typename scalar_t> |
| 147 | +const auto forced_align_int_impl = forced_align_impl<scalar_t, ScalarType::Int>; |
| 148 | + |
141 | 149 | std::tuple<Tensor, Tensor> compute( |
142 | 150 | const Tensor& logProbs, |
143 | 151 | const Tensor& targets, |
@@ -178,32 +186,41 @@ std::tuple<Tensor, Tensor> compute( |
178 | 186 | STD_TORCH_CHECK( |
179 | 187 | blank >= 0 && blank < logProbs.size(-1), |
180 | 188 | "blank must be within [0, num classes)"); |
181 | | - STABLE_DISPATCH_INDEX_TYPES( |
182 | | - inputLengths.scalar_type(), "forced_align_impl", [&] { |
| 189 | + THO_DISPATCH_V2( |
| 190 | + inputLengths.scalar_type(), |
| 191 | + "forced_align_impl", |
| 192 | + AT_WRAP([&] { |
183 | 193 | STD_TORCH_CHECK( |
184 | | - logProbs.size(1) == torchaudio::util::max<index_t>(inputLengths), |
| 194 | + logProbs.size(1) == torchaudio::util::max<scalar_t>(inputLengths), |
185 | 195 | "input length mismatch"); |
186 | | - }); |
187 | | - STABLE_DISPATCH_INDEX_TYPES( |
188 | | - targetLengths.scalar_type(), "forced_align_impl", [&] { |
| 196 | + }), |
| 197 | + ScalarType::Int, |
| 198 | + ScalarType::Long); |
| 199 | + THO_DISPATCH_V2( |
| 200 | + targetLengths.scalar_type(), |
| 201 | + "forced_align_impl", |
| 202 | + AT_WRAP([&] { |
189 | 203 | STD_TORCH_CHECK( |
190 | | - targets.size(1) == torchaudio::util::max<index_t>(targetLengths), |
| 204 | + targets.size(1) == torchaudio::util::max<scalar_t>(targetLengths), |
191 | 205 | "target length mismatch"); |
192 | | - }); |
| 206 | + }), |
| 207 | + ScalarType::Int, |
| 208 | + ScalarType::Long); |
193 | 209 | const auto B = logProbs.size(0); |
194 | 210 | const auto T = logProbs.size(1); |
195 | 211 | Tensor paths = torchaudio::stable::new_zeros(targets, {B, T}); |
196 | | - |
197 | | - STABLE_DISPATCH_FLOATING_TYPES_AND_HALF( |
198 | | - logProbs.scalar_type(), "forced_align_impl", [&] { |
| 212 | + THO_DISPATCH_V2( |
| 213 | + logProbs.scalar_type(), |
| 214 | + "forced_align_impl", |
| 215 | + AT_WRAP([&] { |
199 | 216 | if (targets.scalar_type() == ScalarType::Long) { |
200 | | - forced_align_impl<scalar_t, ScalarType::Long>( |
201 | | - logProbs, targets, blank, paths); |
| 217 | + forced_align_long_impl<scalar_t>(logProbs, targets, blank, paths); |
202 | 218 | } else { |
203 | | - forced_align_impl<scalar_t, ScalarType::Int>( |
204 | | - logProbs, targets, blank, paths); |
| 219 | + forced_align_int_impl<scalar_t>(logProbs, targets, blank, paths); |
205 | 220 | } |
206 | | - }); |
| 221 | + }), |
| 222 | + AT_EXPAND(AT_FLOATING_TYPES), |
| 223 | + ScalarType::Half); |
207 | 224 | return std::make_tuple(paths, logProbs); |
208 | 225 | } |
209 | 226 |
|
|
0 commit comments