11#include < torch/script.h>
22#include < torch/torch.h>
3+ #include < torch/csrc/stable/library.h>
4+ #include < torch/csrc/stable/tensor.h>
5+ #include < torch/csrc/stable/ops.h>
6+ #include < torch/csrc/inductor/aoti_torch/c/shim.h>
37
48using namespace std ;
59
@@ -22,17 +26,17 @@ void forced_align_impl(
2226 const auto T = logProbs.size (1 );
2327 const auto L = targets.size (1 );
2428 const auto S = 2 * L + 1 ;
25- torch::Tensor alphas = torch::empty (
26- {2 , S},
27- torch::TensorOptions ()
28- .device (logProbs.device ())
29- .dtype (logProbs.dtype ()))
30- .fill_ (kNegInfinity );
29+
30+ auto alphas_a = new scalar_t [S][2 ];
31+ for (int i = 0 ; i < S; i++) {
32+ alphas_a[i][0 ] = kNegInfinity ;
33+ alphas_a[i][1 ] = kNegInfinity ;
34+ }
35+
3136 torch::Tensor backPtr = torch::empty ({T, S}, torch::kInt8 ).fill_ (-1 );
3237 auto logProbs_a = logProbs.accessor <scalar_t , 3 >();
3338 auto targets_a = targets.accessor <target_t , 2 >();
3439 auto paths_a = paths.accessor <target_t , 2 >();
35- auto alphas_a = alphas.accessor <scalar_t , 2 >();
3640 auto backPtr_a = backPtr.accessor <int8_t , 2 >();
3741 auto R = 0 ;
3842 for (auto i = 1 ; i < L; i++) {
@@ -52,7 +56,7 @@ void forced_align_impl(
5256 auto end = (S == 1 ) ? 1 : 2 ;
5357 for (auto i = start; i < end; i++) {
5458 auto labelIdx = (i % 2 == 0 ) ? blank : targets_a[batchIndex][i / 2 ];
55- alphas_a[0 ][i ] = logProbs_a[batchIndex][0 ][labelIdx];
59+ alphas_a[i][ 0 ] = logProbs_a[batchIndex][0 ][labelIdx];
5660 }
5761 for (auto t = 1 ; t < T; t++) {
5862 if (T - t <= L + R) {
@@ -75,18 +79,18 @@ void forced_align_impl(
7579 auto curIdxOffset = t % 2 ;
7680 auto prevIdxOffset = (t - 1 ) % 2 ;
7781 for (auto j = 0 ; j < S; ++j) {
78- alphas_a[curIdxOffset][j ] = -std::numeric_limits<scalar_t >::infinity ();
82+ alphas_a[j][curIdxOffset ] = -std::numeric_limits<scalar_t >::infinity ();
7983 }
8084 if (start == 0 ) {
81- alphas_a[curIdxOffset][ 0 ] =
82- alphas_a[prevIdxOffset][ 0 ] + logProbs_a[batchIndex][t][blank];
85+ alphas_a[0 ][curIdxOffset ] =
86+ alphas_a[0 ][prevIdxOffset ] + logProbs_a[batchIndex][t][blank];
8387 backPtr_a[t][0 ] = 0 ;
8488 startloop += 1 ;
8589 }
8690
8791 for (auto i = startloop; i < end; i++) {
88- auto x0 = alphas_a[prevIdxOffset][i ];
89- auto x1 = alphas_a[prevIdxOffset][ i - 1 ];
92+ auto x0 = alphas_a[i][prevIdxOffset ];
93+ auto x1 = alphas_a[i - 1 ][prevIdxOffset ];
9094 auto x2 = -std::numeric_limits<scalar_t >::infinity ();
9195
9296 auto labelIdx = (i % 2 == 0 ) ? blank : targets_a[batchIndex][i / 2 ];
@@ -97,7 +101,7 @@ void forced_align_impl(
97101 // (i != 1) just ensures we don't access targets[i / 2 - 1] when i < 2
98102 if (i % 2 != 0 && i != 1 &&
99103 targets_a[batchIndex][i / 2 ] != targets_a[batchIndex][i / 2 - 1 ]) {
100- x2 = alphas_a[prevIdxOffset][ i - 2 ];
104+ x2 = alphas_a[i - 2 ][prevIdxOffset ];
101105 }
102106 scalar_t result = 0.0 ;
103107 if (x2 > x1 && x2 > x0) {
@@ -110,11 +114,11 @@ void forced_align_impl(
110114 result = x0;
111115 backPtr_a[t][i] = 0 ;
112116 }
113- alphas_a[curIdxOffset][i ] = result + logProbs_a[batchIndex][t][labelIdx];
117+ alphas_a[i][curIdxOffset ] = result + logProbs_a[batchIndex][t][labelIdx];
114118 }
115119 }
116120 auto idx1 = (T - 1 ) % 2 ;
117- auto ltrIdx = alphas_a[idx1][ S - 1 ] > alphas_a[idx1][ S - 2 ] ? S - 1 : S - 2 ;
121+ auto ltrIdx = alphas_a[S - 1 ][idx1] > alphas_a[S - 2 ][idx1 ] ? S - 1 : S - 2 ;
118122 // path stores the token index for each time step after force alignment.
119123 for (auto t = T - 1 ; t > -1 ; t--) {
120124 auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2 ];
0 commit comments