11#include < libtorchaudio/utils.h>
2+ #include < libtorchaudio/stable/TensorAccessor.h>
3+ #include < libtorchaudio/stable/dispatch.h>
24#include < torch/csrc/stable/library.h>
35
46#include < cub/cub.cuh>
@@ -20,9 +22,9 @@ using torch::headeronly::ScalarType;
2022
2123template <typename scalar_t , typename target_t >
2224__global__ void falign_cuda_step_kernel (
23- const at:: PackedTensorAccessor32<scalar_t , 3 , at ::RestrictPtrTraits>
25+ const torchaudio::stable:: PackedTensorAccessor32<scalar_t , 3 , torchaudio::stable ::RestrictPtrTraits>
2426 logProbs_a,
25- const at:: PackedTensorAccessor32<target_t , 2 , at ::RestrictPtrTraits>
27+ const torchaudio::stable:: PackedTensorAccessor32<target_t , 2 , torchaudio::stable ::RestrictPtrTraits>
2628 targets_a,
2729 const int T,
2830 const int L,
@@ -33,9 +35,9 @@ __global__ void falign_cuda_step_kernel(
3335 int start,
3436 int end,
3537 int backPtrBufferLen,
36- at:: PackedTensorAccessor32<scalar_t , 2 , at ::RestrictPtrTraits>
38+ torchaudio::stable:: PackedTensorAccessor32<scalar_t , 2 , torchaudio::stable ::RestrictPtrTraits>
3739 alphas_a,
38- at:: PackedTensorAccessor32<int8_t , 2 , at ::RestrictPtrTraits>
40+ torchaudio::stable:: PackedTensorAccessor32<int8_t , 2 , torchaudio::stable ::RestrictPtrTraits>
3941 backPtrBuffer_a) {
4042 scalar_t kNegInfinity = -std::numeric_limits<scalar_t >::infinity ();
4143 const int batchIndex =
@@ -122,15 +124,15 @@ void forced_align_impl(
122124 const scalar_t kNegInfinity = -std::numeric_limits<scalar_t >::infinity ();
123125 using target_t = typename std::
124126 conditional<target_scalar_type == ScalarType::Int, int , int64_t >::type;
125- auto paths_a = paths. accessor <target_t , 2 >();
127+ auto paths_a = torchaudio::stable:: accessor<target_t , 2 >(paths );
126128 const int batchIndex =
127129 0 ; // TODO: support batch version and use the real batch index
128130 const int T = logProbs.size (1 ); // num frames
129131 const int N = logProbs.size (2 ); // alphabet size
130132 const int L = targets.size (1 ); // label length
131133 const int S = 2 * L + 1 ;
132134
133- auto targetsCpu = torch ::stable::cpu (targets);
135+ auto targetsCpu = torchaudio ::stable::cpu (targets);
// backPtrBuffer stores the index offset of the best path at the current position.
// We copy the values to CPU after running every kBackPtrBufferSize of
// frames.
@@ -147,8 +149,8 @@ void forced_align_impl(
147149 torch::stable::fill_ (alphas, kNegInfinity );
148150
149151 // CPU accessors
150- auto targetsCpu_a = targetsCpu. accessor <target_t , 2 >();
151- auto backPtrCpu_a = backPtrCpu. accessor <int8_t , 2 >();
152+ auto targetsCpu_a = torchaudio::stable:: accessor<target_t , 2 >(targetsCpu );
153+ auto backPtrCpu_a = torchaudio::stable:: accessor<int8_t , 2 >(backPtrCpu );
152154 // count the number of repeats in label
153155 int R = 0 ;
154156 for (int i = 1 ; i < L; ++i) {
@@ -189,8 +191,8 @@ void forced_align_impl(
189191 }
190192 falign_cuda_step_kernel<scalar_t , target_t >
191193 <<<1 , kNumThreads , 0 , defaultStream>>> (
192- logProbs. packed_accessor32 <scalar_t , 3 , at:: RestrictPtrTraits>(),
193- targets. packed_accessor32 <target_t , 2 , at:: RestrictPtrTraits>(),
194+ torchaudio::stable:: packed_accessor32<scalar_t , 3 , torchaudio::stable:: RestrictPtrTraits>(logProbs ),
195+ torchaudio::stable:: packed_accessor32<target_t , 2 , torchaudio::stable:: RestrictPtrTraits>(targets ),
194196 T,
195197 L,
196198 N,
@@ -200,15 +202,14 @@ void forced_align_impl(
200202 start,
201203 end,
202204 backPtrBufferLen,
203- alphas.packed_accessor32 <scalar_t , 2 , at::RestrictPtrTraits>(),
204- backPtrBuffer
205- .packed_accessor32 <int8_t , 2 , at::RestrictPtrTraits>());
205+ torchaudio::stable::packed_accessor32<scalar_t , 2 , torchaudio::stable::RestrictPtrTraits>(alphas),
206+ torchaudio::stable::packed_accessor32<int8_t , 2 , torchaudio::stable::RestrictPtrTraits>(backPtrBuffer));
206207 C10_CUDA_KERNEL_LAUNCH_CHECK ();
207208 ++backPtrBufferLen;
208209 if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1 ) {
209210 cpuDataTranferStream.synchronize ();
210211 // GPU -> GPU copy
211- bufferCopy = backPtrBuffer. clone ();
212+ bufferCopy = torchaudio::stable:: clone (backPtrBuffer );
212213 STD_TORCH_CHECK (bufferCopy.is_contiguous (), " unexpected fail, need to implement stable::Tensor::contiguous()" )
213214 defaultStream.synchronize ();
214215 at::cuda::setCurrentCUDAStream (cpuDataTranferStream);
@@ -227,8 +228,8 @@ void forced_align_impl(
227228 }
228229 cpuDataTranferStream.synchronize ();
229230
230- auto alphasCpu = torch ::stable::cpu (alphas);
231- auto alphasCpu_a = alphasCpu. accessor <scalar_t , 2 >();
231+ auto alphasCpu = torchaudio ::stable::cpu (alphas);
232+ auto alphasCpu_a = torchaudio::stable:: accessor<scalar_t , 2 >(alphasCpu );
232233 int curIdxOffset = ((T - 1 ) % 2 );
233234 int ltrIdx =
234235 alphasCpu_a[curIdxOffset][S - 1 ] > alphasCpu_a[curIdxOffset][S - 2 ]
@@ -294,50 +295,20 @@ std::tuple<Tensor, Tensor> compute(
294295 auto B = logProbs.size (0 );
295296 auto T = logProbs.size (1 ); // num frames
296297
297- Tensor paths = torch::stable::new_empty (targets, {B, T}, std::nullopt , aoti_torch_device_type_cpu ());
298- torch::stable::zero_ (paths);
298+ Tensor paths = torchaudio::stable::new_zeros (targets, {B, T}, /* dtype=*/ std::nullopt , /* layout=*/ std::nullopt , /* device=*/ torchaudio::stable::cpu_device ());
299299
300- switch (logProbs.scalar_type ()) {
301- case ScalarType::Double: {
302- if (targets.scalar_type () == ScalarType::Long) {
303- forced_align_impl<double , ScalarType::Long>(logProbs, targets, blank, paths);
304- } else if (targets.scalar_type () == ScalarType::Int) {
305- forced_align_impl<double , ScalarType::Int>(logProbs, targets, blank, paths);
306- } else {
307- STD_TORCH_CHECK (false , " unreachable" );
308- }
309- break ;
310- }
311- case ScalarType::Float: {
312- if (targets.scalar_type () == ScalarType::Long) {
313- forced_align_impl<float , ScalarType::Long>(logProbs, targets, blank, paths);
314- } else if (targets.scalar_type () == ScalarType::Int) {
315- forced_align_impl<float , ScalarType::Int>(logProbs, targets, blank, paths);
316- } else {
317- STD_TORCH_CHECK (false , " unreachable" );
318- }
319- break ;
320- }
321- case ScalarType::Half: {
322- if (targets.scalar_type () == ScalarType::Long) {
323- forced_align_impl<c10::Half, ScalarType::Long>(logProbs, targets, blank, paths);
324- } else if (targets.scalar_type () == ScalarType::Int) {
325- forced_align_impl<c10::Half, ScalarType::Int>(logProbs, targets, blank, paths);
326- } else {
327- STD_TORCH_CHECK (false , " unreachable" );
328- }
329- break ;
330- }
331- default : {
332- STD_TORCH_CHECK (false , " unreachable" );
333- }
334- };
335- Tensor pathsCuda = torch::stable::new_empty (paths,
336- torchaudio::util::sizes (paths),
337- std::nullopt ,
338- aoti_torch_device_type_cuda (),
339- logProbs.get_device_index ());
340- torch::stable::copy_ (pathsCuda, paths);
300+ STABLE_DISPATCH_FLOATING_TYPES_AND_HALF (
301+ logProbs.scalar_type (), " forced_align_impl" , [&] {
302+ if (targets.scalar_type () == ScalarType::Long) {
303+ forced_align_impl<scalar_t , ScalarType::Long>(
304+ logProbs, targets, blank, paths);
305+ } else {
306+ forced_align_impl<scalar_t , ScalarType::Int>(
307+ logProbs, targets, blank, paths);
308+ }
309+ });
310+
311+ Tensor pathsCuda = torchaudio::stable::cuda (paths, logProbs.get_device_index ());
341312 return std::make_tuple (pathsCuda, logProbs);
342313}
343314
0 commit comments