 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
 #include <torch/csrc/inductor/aoti_torch/utils.h>
 #include <libtorchaudio/accessor.h>
+#include <torch/headeronly/util/Half.h>
 
 
 using namespace std;
@@ -22,7 +23,7 @@ template <typename scalar_t, typename target_t>
 void forced_align_impl(
     const Tensor logProbs,
     const Tensor targets,
-    const Tensor blank,
+    target_t blank,
     Tensor paths) {
   const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
   const auto batchIndex =
@@ -143,15 +144,15 @@ std::tuple<Tensor, Tensor> compute(
   TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
   TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
   TORCH_CHECK(
-      logProbs.device() == targets.device(),
+      logProbs.get_device() == targets.get_device(),
       "log_probs and targets need to be on the same device");
   TORCH_CHECK(
-      logProbs.dtype() == torch::kFloat64 ||
-      logProbs.dtype() == torch::kFloat32 ||
-      logProbs.dtype() == torch::kFloat16,
+      logProbs.dtype() == aoti_torch_dtype_float64() ||
+      logProbs.dtype() == aoti_torch_dtype_float32() ||
+      logProbs.dtype() == aoti_torch_dtype_float16(),
       "log_probs must be float64, float32 or float16 (half) type");
   TORCH_CHECK(
-      targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64,
+      targets.dtype() == aoti_torch_dtype_int32() || targets.dtype() == aoti_torch_dtype_int64(),
       "targets must be int32 or int64 type");
   TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
   TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
@@ -174,38 +175,41 @@ std::tuple<Tensor, Tensor> compute(
       blank >= 0 && blank < logProbs.size(-1),
       "blank must be within [0, num classes)");
 
-  TORCH_CHECK(
-      logProbs.size(1) == at::max(inputLengths).item().toInt(),
-      "input length mismatch");
-  TORCH_CHECK(
-      targets.size(1) == at::max(targetLengths).item().toInt(),
-      "target length mismatch");
+  // TODO: Requires port of `max` operator.
+  // TORCH_CHECK(
+  //     logProbs.size(1) == at::max(inputLengths).item().toInt(),
+  //     "input length mismatch");
+  // TORCH_CHECK(
+  //     targets.size(1) == at::max(targetLengths).item().toInt(),
+  //     "target length mismatch");
 
   const auto B = logProbs.size(0);
   const auto T = logProbs.size(1);
 
   int64_t paths_size[2] = {B, T};
   int64_t paths_stride[2] = {T, 1};
   AtenTensorHandle paths_h;
-  aoti_torch_empty_strided(1, paths_size, paths_stride, targets_dtype, targets_device, targets_device_index, &paths_h);
+  int32_t targets_device;
+  aoti_torch_get_device_type(targets.get(), &targets_device);
+  aoti_torch_empty_strided(1, paths_size, paths_stride, targets.dtype(), targets_device, targets.get_device(), &paths_h);
   auto paths = Tensor(paths_h);
 
 
   if (targets.dtype() == aoti_torch_dtype_int64()) {
-    if (logProbs.scalar_type() == aoti_torch_dtype_float64()) {
-      forced_align_impl<float64, int64>(logProbs, targets, blank, paths);
-    } else if (logProbs.scalar_type() == aoti_torch_dtype_float32()) {
-      forced_align_impl<float32, int64>(logProbs, targets, blank, paths);
-    } else if (logProbs.scalar_type() == aoti_torch_dtype_float16()) {
-      forced_align_impl<float16, int64>(logProbs, targets, blank, paths);
+    if (logProbs.dtype() == aoti_torch_dtype_float64()) {
+      forced_align_impl<double, int64_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float32()) {
+      forced_align_impl<float, int64_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float16()) {
+      forced_align_impl<c10::Half, int64_t>(logProbs, targets, blank, paths);
     }
-  } else if (targets.scalar_type() == aoti_torch_dtype_int32()) {
-    if (logProbs.scalar_type() == aoti_torch_dtype_float64()) {
-      forced_align_impl<float64, int32>(logProbs, targets, blank, paths);
-    } else if (logProbs.scalar_type() == aoti_torch_dtype_float32()) {
-      forced_align_impl<float32, int32>(logProbs, targets, blank, paths);
-    } else if (logProbs.scalar_type() == aoti_torch_dtype_float16()) {
-      forced_align_impl<float16, int32>(logProbs, targets, blank, paths);
+  } else if (targets.dtype() == aoti_torch_dtype_int32()) {
+    if (logProbs.dtype() == aoti_torch_dtype_float64()) {
+      forced_align_impl<double, int32_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float32()) {
+      forced_align_impl<float, int32_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float16()) {
+      forced_align_impl<c10::Half, int32_t>(logProbs, targets, blank, paths);
     }
   }
   return std::make_tuple(
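For reference, the disabled length checks above still need a replacement for at::max. A minimal sketch of how they could be restored under the stable ABI, assuming inputLengths and targetLengths are contiguous int64 CPU tensors and that the stable Tensor wrapper exposes data_ptr() and numel() (the max_length helper below is hypothetical and not part of this commit):

// Hypothetical helper, not part of this commit: computes the maximum entry of
// a contiguous int64 CPU lengths tensor without at::max.
int64_t max_length(const Tensor lengths) {
  const auto* data = static_cast<const int64_t*>(lengths.data_ptr());
  int64_t maxVal = 0;
  for (int64_t i = 0; i < lengths.numel(); ++i) {
    if (data[i] > maxVal) {
      maxVal = data[i];
    }
  }
  return maxVal;
}

// The commented-out checks could then read:
// TORCH_CHECK(logProbs.size(1) == max_length(inputLengths),
//     "input length mismatch");
// TORCH_CHECK(targets.size(1) == max_length(targetLengths),
//     "target length mismatch");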