@@ -32,10 +32,9 @@ void forced_align_impl(
3232 const auto L = targets.size (1 );
3333 const auto S = 2 * L + 1 ;
3434
35- auto alphas_a = new scalar_t [S][2 ]; // scalar_t is just logProbs.dtype()
36- for (int i = 0 ; i < S; i++) {
37- alphas_a[i][0 ] = kNegInfinity ;
38- alphas_a[i][1 ] = kNegInfinity ;
35+ auto alphas_a = new scalar_t [2 * S]; // scalar_t is just logProbs.dtype()
36+ for (int i = 0 ; i < 2 * S; i++) {
37+ alphas_a[i] = kNegInfinity ;
3938 }
4039
4140 auto backPtr_a = new int8_t [T * S];
@@ -64,7 +63,8 @@ void forced_align_impl(
6463 auto end = (S == 1 ) ? 1 : 2 ;
6564 for (auto i = start; i < end; i++) {
6665 auto labelIdx = (i % 2 == 0 ) ? blank : targets_a.index (batchIndex, i / 2 );
67- alphas_a[i][0 ] = logProbs_a.index (batchIndex,0 ,labelIdx);
66+ alphas_a[i] = logProbs_a.index (batchIndex,0 ,labelIdx);
67+
6868 }
6969 for (auto t = 1 ; t < T; t++) {
7070 if (T - t <= L + R) {
@@ -87,18 +87,18 @@ void forced_align_impl(
8787 auto curIdxOffset = t % 2 ;
8888 auto prevIdxOffset = (t - 1 ) % 2 ;
8989 for (auto j = 0 ; j < S; ++j) {
90- alphas_a[j][curIdxOffset] = -std::numeric_limits<scalar_t >::infinity ();
90+ alphas_a[curIdxOffset * S + j] = -std::numeric_limits<scalar_t >::infinity (); // alphas_a[curIdxOffset][j]
9191 }
9292 if (start == 0 ) {
93- alphas_a[0 ][ curIdxOffset] =
94- alphas_a[0 ][ prevIdxOffset] + logProbs_a.index (batchIndex, t, blank);
93+ alphas_a[curIdxOffset * S ] =
94+ alphas_a[prevIdxOffset * S ] + logProbs_a.index (batchIndex, t, blank);
9595 backPtr_a[S * t] = 0 ; // backPtr_a[t][0] = 0
9696 startloop += 1 ;
9797 }
9898
9999 for (auto i = startloop; i < end; i++) {
100- auto x0 = alphas_a[i] [prevIdxOffset];
101- auto x1 = alphas_a[i - 1 ][prevIdxOffset];
100+ auto x0 = alphas_a[prevIdxOffset * S + i]; // alphas_a [prevIdxOffset][i ];
101+ auto x1 = alphas_a[prevIdxOffset * S + i - 1 ]; // alphas_a [prevIdxOffset][i - 1 ];
102102 auto x2 = -std::numeric_limits<scalar_t >::infinity ();
103103
104104 auto labelIdx = (i % 2 == 0 ) ? blank : targets_a.index (batchIndex, i / 2 );
@@ -109,7 +109,7 @@ void forced_align_impl(
109109 // (i != 1) just ensures we don't access targets[i - 2] if its i < 2
110110 if (i % 2 != 0 && i != 1 &&
111111 targets_a.index (batchIndex, i / 2 ) != targets_a.index (batchIndex, i / 2 - 1 )) {
112- x2 = alphas_a[i - 2 ][prevIdxOffset];
112+ x2 = alphas_a[prevIdxOffset * S + i - 2 ]; // alphas_a [prevIdxOffset][i - 2 ];
113113 }
114114 scalar_t result = 0.0 ;
115115 if (x2 > x1 && x2 > x0) {
@@ -122,11 +122,13 @@ void forced_align_impl(
122122 result = x0;
123123 backPtr_a[t * S + i] = 0 ; // backPtr_a[t][i] = 0
124124 }
125- alphas_a[i][curIdxOffset] = result + logProbs_a.index (batchIndex, t, labelIdx);
125+
126+ alphas_a[curIdxOffset * S + i] = result + logProbs_a.index (batchIndex, t, labelIdx); // alphas_a[curIdxOffset][i]
126127 }
127128 }
128129 auto idx1 = (T - 1 ) % 2 ;
129- auto ltrIdx = alphas_a[S - 1 ][idx1] > alphas_a[S - 2 ][idx1] ? S - 1 : S - 2 ;
130+ auto ltrIdx = alphas_a[S * idx1 + S - 1 ] >
131+ alphas_a[S * idx1 + S - 2 ] ? S - 1 : S - 2 ; // alphas_a[idx1][S - 1], alphas_a[idx1][S - 2]
130132 delete[] alphas_a;
131133 // path stores the token index for each time step after force alignment.
132134 for (auto t = T - 1 ; t > -1 ; t--) {
0 commit comments