@@ -14,19 +14,17 @@ namespace torchaudio {
1414namespace alignment {
1515namespace cpu {
1616
17-
17+ using torch::stable::Tensor;
1818
1919// Inspired from
2020// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
// Inner alignment kernel, migrated to the stable ABI. Fills `paths` with the
// forced alignment for one batch element (batchIndex is hard-coded to 0 below).
// Template change: the non-type at::ScalarType parameter plus the
// std::conditional mapping (removed lines 28-29) is replaced by a plain
// `typename target_t` chosen by the caller's dtype dispatch in compute().
21- template <typename scalar_t , at::ScalarType target_scalar_type >
21+ template <typename scalar_t , typename target_t >
2222void forced_align_impl (
23- const torch:: Tensor& logProbs,
24- const torch:: Tensor& targets,
25- const int64_t blank,
26- torch:: Tensor& paths) {
// NOTE(review): tensors are now taken by value — presumably stable::Tensor is a
// cheap shared handle, so `Tensor paths` still mutates the caller's storage;
// confirm against torch::stable::Tensor semantics.
// NOTE(review): `blank` changed from `const int64_t` to `const Tensor`, but
// compute() below still passes an int64_t — there is no implicit int64_t→Tensor
// conversion, so one of the two sides looks wrong. Verify intended type.
23+ const Tensor logProbs,
24+ const Tensor targets,
25+ const Tensor blank,
26+ Tensor paths) {
2727 const scalar_t kNegInfinity = -std::numeric_limits<scalar_t >::infinity ();
// Removed: the old target_scalar_type→target_t mapping, superseded by the
// target_t template parameter.
28- using target_t = typename std::
29- conditional<target_scalar_type == torch::kInt , int , int64_t >::type;
3028 const auto batchIndex =
3129 0 ; // TODO: support batch version and use the real batch index
3230 const auto T = logProbs.size (1 );
// (dynamic-programming body elided by the hunk boundary below)
@@ -136,11 +134,11 @@ void forced_align_impl(
136134 }
137135}
138136
// Public entry point for torchaudio::forced_align on CPU. Returns a tuple of
// (paths, values gathered from logProbs); the gather expression is partially
// elided by the hunk below, so its exact indexing is not visible here.
139- std::tuple<torch:: Tensor, torch:: Tensor> compute (
140- const torch:: Tensor& logProbs,
141- const torch:: Tensor& targets,
142- const torch:: Tensor& inputLengths,
143- const torch:: Tensor& targetLengths,
137+ std::tuple<Tensor, Tensor> compute (
138+ const Tensor& logProbs,
139+ const Tensor& targets,
140+ const Tensor& inputLengths,
141+ const Tensor& targetLengths,
// NOTE(review): `blank` stays int64_t here while the new forced_align_impl
// signature takes `const Tensor blank`; the calls below pass this int64_t
// straight through — confirm the intended type on both sides.
144142 const int64_t blank) {
145143 TORCH_CHECK (logProbs.is_cpu (), " log_probs must be a CPU tensor" );
146144 TORCH_CHECK (targets.is_cpu (), " targets must be a CPU tensor" );
// (remaining input validation is in the elided hunk)
@@ -185,19 +183,31 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
185183
186184 const auto B = logProbs.size (0 );
187185 const auto T = logProbs.size (1 );
// Old path: zero-initialized (B, T) tensor on targets' device/dtype.
188- auto paths = torch::zeros (
189- {B, T},
190- torch::TensorOptions ().device (targets.device ()).dtype (targets.dtype ()));
191- AT_DISPATCH_FLOATING_TYPES_AND_HALF (
192- logProbs.scalar_type (), " forced_align_impl" , [&] {
193- if (targets.scalar_type () == torch::kInt64 ) {
194- forced_align_impl<scalar_t , torch::kInt64 >(
195- logProbs, targets, blank, paths);
196- } else {
197- forced_align_impl<scalar_t , torch::kInt32 >(
198- logProbs, targets, blank, paths);
199- }
200- });
186+
// New path: allocate `paths` through the AOTI C shim.
// NOTE(review): BUG — the first argument of aoti_torch_empty_strided is ndim,
// and paths_size/paths_stride describe a 2-D {B, T} tensor, so it must be 2,
// not 1.
// NOTE(review): empty_strided leaves memory uninitialized, whereas the removed
// code used torch::zeros — confirm forced_align_impl writes every element of
// paths (and that the no-match dtype fallthrough below cannot return it
// untouched).
// NOTE(review): targets_dtype / targets_device / targets_device_index are not
// defined in any visible hunk — verify they are obtained in the elided code
// (e.g. via the aoti_torch getter shims on `targets`).
187+ int64_t paths_size[2 ] = {B, T};
188+ int64_t paths_stride[2 ] = {T, 1 };
189+ AtenTensorHandle paths_h;
190+ aoti_torch_empty_strided (1 , paths_size, paths_stride, targets_dtype, targets_device, targets_device_index, &paths_h);
191+ auto paths = Tensor (paths_h);
192+
193+
// Manual dtype dispatch replacing AT_DISPATCH_FLOATING_TYPES_AND_HALF.
// NOTE(review): float64/float32/float16 and int64/int32 are not C++ type
// names — confirm they are aliases (e.g. double/float and int64_t/int32_t)
// declared elsewhere; otherwise this will not compile.
// NOTE(review): unlike AT_DISPATCH, an unsupported logProbs or targets dtype
// now falls through silently with no error — consider a trailing
// TORCH_CHECK/else rejecting unhandled dtypes.
194+ if (targets.scalar_type () == aoti_torch_dtype_int64 ()) {
195+ if (logProbs.scalar_type () == aoti_torch_dtype_float64 ()) {
196+ forced_align_impl<float64, int64>(logProbs, targets, blank, paths);
197+ } else if (logProbs.scalar_type () == aoti_torch_dtype_float32 ()) {
198+ forced_align_impl<float32, int64>(logProbs, targets, blank, paths);
199+ } else if (logProbs.scalar_type () == aoti_torch_dtype_float16 ()) {
200+ forced_align_impl<float16, int64>(logProbs, targets, blank, paths);
201+ }
202+ } else if (targets.scalar_type () == aoti_torch_dtype_int32 ()) {
203+ if (logProbs.scalar_type () == aoti_torch_dtype_float64 ()) {
204+ forced_align_impl<float64, int32>(logProbs, targets, blank, paths);
205+ } else if (logProbs.scalar_type () == aoti_torch_dtype_float32 ()) {
206+ forced_align_impl<float32, int32>(logProbs, targets, blank, paths);
207+ } else if (logProbs.scalar_type () == aoti_torch_dtype_float16 ()) {
208+ forced_align_impl<float16, int32>(logProbs, targets, blank, paths);
209+ }
210+ }
// Second output: indexes into logProbs using paths.index({0}) (expression
// partially elided by the hunk below).
201211 return std::make_tuple (
202212 paths,
203213 logProbs.index (
@@ -207,8 +217,21 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
207217 paths.index ({0 })}));
208218}
209219
210- TORCH_LIBRARY_IMPL (torchaudio, CPU, m) {
211- m.impl (" forced_align" , &compute);
220+
221+ void boxed_compute (StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
222+ Tensor t1 (to<AtenTensorHandle>(stack[0 ]));
223+ Tensor t2 (to<AtenTensorHandle>(stack[1 ]));
224+ Tensor t3 (to<AtenTensorHandle>(stack[2 ]));
225+ Tensor t4 (to<AtenTensorHandle>(stack[3 ]));
226+ int64_t blank = to<int64_t >(stack[4 ]);
227+ auto result = compute (
228+ std::move (t1), std::move (t2), std::move (t3), std::move (t4), blank);
229+ stack[0 ] = from (std::get<0 >(result));
230+ stack[1 ] = from (std::get<1 >(result));
231+ }
232+
// Register the boxed kernel as the CPU implementation of
// torchaudio::forced_align via the stable-ABI library macro (this replaces
// the removed TORCH_LIBRARY_IMPL / m.impl("forced_align", &compute) pair).
233+ STABLE_TORCH_LIBRARY_IMPL (torchaudio, CPU, m) {
234+ m.impl (" forced_align" , &boxed_compute);
212235}
213236
214237} // namespace cpu
0 commit comments