@@ -27,10 +27,9 @@ void forced_align_impl(
2727 const auto L = targets.size (1 );
2828 const auto S = 2 * L + 1 ;
2929
30- auto alphas_a = new scalar_t [S][2 ]; // scalar_t is just logProbs.dtype()
31- for (int i = 0 ; i < S; i++) {
32- alphas_a[i][0 ] = kNegInfinity ;
33- alphas_a[i][1 ] = kNegInfinity ;
30+ auto alphas_a = new scalar_t [2 * S]; // scalar_t is just logProbs.dtype()
31+ for (int i = 0 ; i < 2 * S; i++) {
32+ alphas_a[i] = kNegInfinity ;
3433 }
3534
3635 torch::Tensor backPtr = torch::empty ({T, S}, torch::kInt8 ).fill_ (-1 );
@@ -56,7 +55,7 @@ void forced_align_impl(
5655 auto end = (S == 1 ) ? 1 : 2 ;
5756 for (auto i = start; i < end; i++) {
5857 auto labelIdx = (i % 2 == 0 ) ? blank : targets_a[batchIndex][i / 2 ];
59- alphas_a[i][ 0 ] = logProbs_a[batchIndex][0 ][labelIdx];
58+ alphas_a[i] = logProbs_a[batchIndex][0 ][labelIdx]; // alphas_a[0, i]
6059 }
6160 for (auto t = 1 ; t < T; t++) {
6261 if (T - t <= L + R) {
@@ -79,18 +78,18 @@ void forced_align_impl(
7978 auto curIdxOffset = t % 2 ;
8079 auto prevIdxOffset = (t - 1 ) % 2 ;
8180 for (auto j = 0 ; j < S; ++j) {
82- alphas_a[j][curIdxOffset] = -std::numeric_limits<scalar_t >::infinity ();
81+ alphas_a[curIdxOffset * S + j] = -std::numeric_limits<scalar_t >::infinity (); // alphas_a[curIdxOffset][j]
8382 }
8483 if (start == 0 ) {
85- alphas_a[0 ][ curIdxOffset] =
86- alphas_a[0 ][ prevIdxOffset] + logProbs_a[batchIndex][t][blank];
84+ alphas_a[curIdxOffset * S ] =
85+ alphas_a[prevIdxOffset * S ] + logProbs_a[batchIndex][t][blank];
8786 backPtr_a[t][0 ] = 0 ;
8887 startloop += 1 ;
8988 }
9089
9190 for (auto i = startloop; i < end; i++) {
92- auto x0 = alphas_a[i] [prevIdxOffset];
93- auto x1 = alphas_a[i - 1 ][prevIdxOffset];
91+ auto x0 = alphas_a[prevIdxOffset * S + i]; // alphas_a [prevIdxOffset][i ];
92+ auto x1 = alphas_a[prevIdxOffset * S + i - 1 ]; // alphas_a [prevIdxOffset][i - 1 ];
9493 auto x2 = -std::numeric_limits<scalar_t >::infinity ();
9594
9695 auto labelIdx = (i % 2 == 0 ) ? blank : targets_a[batchIndex][i / 2 ];
@@ -101,7 +100,7 @@ void forced_align_impl(
101100 // (i != 1) just ensures we don't access targets[i - 2] if its i < 2
102101 if (i % 2 != 0 && i != 1 &&
103102 targets_a[batchIndex][i / 2 ] != targets_a[batchIndex][i / 2 - 1 ]) {
104- x2 = alphas_a[i - 2 ][prevIdxOffset];
103+ x2 = alphas_a[prevIdxOffset * S + i - 2 ]; // alphas_a [prevIdxOffset][i - 2 ];
105104 }
106105 scalar_t result = 0.0 ;
107106 if (x2 > x1 && x2 > x0) {
@@ -114,11 +113,12 @@ void forced_align_impl(
114113 result = x0;
115114 backPtr_a[t][i] = 0 ;
116115 }
117- alphas_a[i][curIdxOffset] = result + logProbs_a[batchIndex][t][labelIdx];
116+ alphas_a[curIdxOffset * S + i] = result + logProbs_a[batchIndex][t][labelIdx]; // alphas_a[curIdxOffset][i]
118117 }
119118 }
120119 auto idx1 = (T - 1 ) % 2 ;
121- auto ltrIdx = alphas_a[S - 1 ][idx1] > alphas_a[S - 2 ][idx1] ? S - 1 : S - 2 ;
120+ auto ltrIdx = alphas_a[S * idx1 + S - 1 ] >
121+ alphas_a[S * idx1 + S - 2 ] ? S - 1 : S - 2 ; // alphas_a[idx1][S - 1], alphas_a[idx1][S - 2]
122122 delete[] alphas_a;
123123 // path stores the token index for each time step after force alignment.
124124 for (auto t = T - 1 ; t > -1 ; t--) {
0 commit comments