@@ -33,11 +33,14 @@ void forced_align_impl(
3333 alphas_a[i][1 ] = kNegInfinity ;
3434 }
3535
36- torch::Tensor backPtr = torch::empty ({T, S}, torch::kInt8 ).fill_ (-1 );
36+ auto backPtr_a = new int8_t [T * S];
37+ for (int i = 0 ; i < T * S; i++) {
38+ backPtr_a[i] = -1 ;
39+ }
40+
3741 auto logProbs_a = logProbs.accessor <scalar_t , 3 >();
3842 auto targets_a = targets.accessor <target_t , 2 >();
3943 auto paths_a = paths.accessor <target_t , 2 >();
40- auto backPtr_a = backPtr.accessor <int8_t , 2 >();
4144 auto R = 0 ;
4245 for (auto i = 1 ; i < L; i++) {
4346 if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1 ]) {
@@ -84,7 +87,7 @@ void forced_align_impl(
8487 if (start == 0 ) {
8588 alphas_a[0 ][curIdxOffset] =
8689 alphas_a[0 ][prevIdxOffset] + logProbs_a[batchIndex][t][blank];
87- backPtr_a[t][ 0 ] = 0 ;
90+ backPtr_a[S * t ] = 0 ;
8891 startloop += 1 ;
8992 }
9093
@@ -106,13 +109,13 @@ void forced_align_impl(
106109 scalar_t result = 0.0 ;
107110 if (x2 > x1 && x2 > x0) {
108111 result = x2;
109- backPtr_a[t][ i] = 2 ;
112+ backPtr_a[t * S + i] = 2 ;
110113 } else if (x1 > x0 && x1 > x2) {
111114 result = x1;
112- backPtr_a[t][ i] = 1 ;
115+ backPtr_a[t * S + i] = 1 ;
113116 } else {
114117 result = x0;
115- backPtr_a[t][ i] = 0 ;
118+ backPtr_a[t * S + i] = 0 ;
116119 }
117120 alphas_a[i][curIdxOffset] = result + logProbs_a[batchIndex][t][labelIdx];
118121 }
@@ -123,7 +126,7 @@ void forced_align_impl(
123126 for (auto t = T - 1 ; t > -1 ; t--) {
124127 auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2 ];
125128 paths_a[batchIndex][t] = lbl_idx;
126- ltrIdx -= backPtr_a[t][ ltrIdx];
129+ ltrIdx -= backPtr_a[t * S + ltrIdx];
127130 }
128131}
129132
0 commit comments