|
1 | | -#include <torch/csrc/inductor/aoti_torch/c/shim.h> |
| 1 | +#include <libtorchaudio/utils.h> |
2 | 2 | #include <torch/csrc/stable/library.h> |
3 | | -#include <torch/csrc/stable/ops.h> |
4 | | -#include <torch/csrc/stable/tensor.h> |
5 | | -#include <torch/script.h> |
6 | | -#include <torch/torch.h> |
7 | | - |
8 | | -using namespace std; |
9 | 3 |
|
10 | 4 | namespace torchaudio { |
11 | 5 | namespace alignment { |
12 | 6 | namespace cpu { |
| 7 | + |
| 8 | +using torch::stable::Tensor; |
| 9 | +using torch::headeronly::ScalarType; |
| 10 | + |
13 | 11 | // Inspired from |
14 | 12 | // https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp |
15 | | -template <typename scalar_t, at::ScalarType target_scalar_type> |
| 13 | +template <typename scalar_t, ScalarType target_scalar_type> |
16 | 14 | void forced_align_impl( |
17 | | - const torch::Tensor& logProbs, |
18 | | - const torch::Tensor& targets, |
| 15 | + const Tensor& logProbs, |
| 16 | + const Tensor& targets, |
19 | 17 | const int64_t blank, |
20 | | - torch::Tensor& paths) { |
| 18 | + Tensor& paths) { |
21 | 19 | const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity(); |
22 | 20 | using target_t = typename std:: |
23 | | - conditional<target_scalar_type == torch::kInt, int, int64_t>::type; |
| 21 | + conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type; |
24 | 22 | const auto batchIndex = |
25 | 23 | 0; // TODO: support batch version and use the real batch index |
26 | 24 | const auto T = logProbs.size(1); |
@@ -138,73 +136,111 @@ void forced_align_impl( |
138 | 136 | delete[] backPtr_a; |
139 | 137 | } |
140 | 138 |
|
141 | | -std::tuple<torch::Tensor, torch::Tensor> compute( |
142 | | - const torch::Tensor& logProbs, |
143 | | - const torch::Tensor& targets, |
144 | | - const torch::Tensor& inputLengths, |
145 | | - const torch::Tensor& targetLengths, |
| 139 | +std::tuple<Tensor, Tensor> compute( |
| 140 | + const Tensor& logProbs, |
| 141 | + const Tensor& targets, |
| 142 | + const Tensor& inputLengths, |
| 143 | + const Tensor& targetLengths, |
146 | 144 | const int64_t blank) { |
147 | | - TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor"); |
148 | | - TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor"); |
149 | | - TORCH_CHECK( |
150 | | - logProbs.device() == targets.device(), |
151 | | - "log_probs and targets need to be on the same device"); |
152 | | - TORCH_CHECK( |
153 | | - logProbs.dtype() == torch::kFloat64 || |
154 | | - logProbs.dtype() == torch::kFloat32 || |
155 | | - logProbs.dtype() == torch::kFloat16, |
| 145 | + STD_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor"); |
| 146 | + STD_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor"); |
| 147 | + STD_TORCH_CHECK(inputLengths.is_cpu(), "input_lengths must be a CPU tensor"); |
| 148 | + STD_TORCH_CHECK(targetLengths.is_cpu(), "target_lengths must be a CPU tensor"); |
| 149 | + STD_TORCH_CHECK( |
| 150 | + logProbs.scalar_type() == ScalarType::Double || |
| 151 | + logProbs.scalar_type() == ScalarType::Float || |
| 152 | + logProbs.scalar_type() == ScalarType::Half, |
156 | 153 | "log_probs must be float64, float32 or float16 (half) type"); |
157 | | - TORCH_CHECK( |
158 | | - targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64, |
| 154 | + STD_TORCH_CHECK( |
| 155 | + targets.scalar_type() == ScalarType::Int || targets.scalar_type() == ScalarType::Long, |
159 | 156 | "targets must be int32 or int64 type"); |
160 | | - TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous"); |
161 | | - TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); |
162 | | - TORCH_CHECK( |
| 157 | + STD_TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous"); |
| 158 | + STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); |
| 159 | + STD_TORCH_CHECK( |
163 | 160 | logProbs.dim() == 3, |
164 | 161 | "log_probs must be 3-D (batch_size, input length, num classes)"); |
165 | | - TORCH_CHECK( |
| 162 | + STD_TORCH_CHECK( |
166 | 163 | targets.dim() == 2, "targets must be 2-D (batch_size, target length,)"); |
167 | | - TORCH_CHECK( |
| 164 | + STD_TORCH_CHECK( |
168 | 165 | inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)"); |
169 | | - TORCH_CHECK( |
| 166 | + STD_TORCH_CHECK( |
170 | 167 | targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)"); |
171 | | - TORCH_CHECK( |
| 168 | + STD_TORCH_CHECK( |
172 | 169 | logProbs.size(0) == 1, |
173 | 170 | "The batch dimension for log_probs must be 1 at the current version.") |
174 | | - TORCH_CHECK( |
| 171 | + STD_TORCH_CHECK( |
175 | 172 | targets.size(0) == 1, |
176 | 173 | "The batch dimension for targets must be 1 at the current version.") |
177 | | - TORCH_CHECK( |
| 174 | + STD_TORCH_CHECK( |
178 | 175 | blank >= 0 && blank < logProbs.size(-1), |
179 | 176 | "blank must be within [0, num classes)"); |
180 | 177 |
|
181 | | - TORCH_CHECK( |
182 | | - logProbs.size(1) == at::max(inputLengths).item().toInt(), |
| 178 | + STD_TORCH_CHECK( |
| 179 | + logProbs.size(1) == torchaudio::util::max<int>(inputLengths), |
183 | 180 | "input length mismatch"); |
184 | | - TORCH_CHECK( |
185 | | - targets.size(1) == at::max(targetLengths).item().toInt(), |
| 181 | + STD_TORCH_CHECK( |
| 182 | + targets.size(1) == torchaudio::util::max<int>(targetLengths), |
186 | 183 | "target length mismatch"); |
187 | 184 |
|
188 | 185 | const auto B = logProbs.size(0); |
189 | 186 | const auto T = logProbs.size(1); |
190 | | - auto paths = torch::zeros( |
191 | | - {B, T}, |
192 | | - torch::TensorOptions().device(targets.device()).dtype(targets.dtype())); |
193 | | - AT_DISPATCH_FLOATING_TYPES_AND_HALF( |
194 | | - logProbs.scalar_type(), "forced_align_impl", [&] { |
195 | | - if (targets.scalar_type() == torch::kInt64) { |
196 | | - forced_align_impl<scalar_t, torch::kInt64>( |
197 | | - logProbs, targets, blank, paths); |
198 | | - } else { |
199 | | - forced_align_impl<scalar_t, torch::kInt32>( |
200 | | - logProbs, targets, blank, paths); |
201 | | - } |
202 | | - }); |
| 187 | + Tensor paths = torch::stable::new_empty(targets, {B, T}); |
| 188 | + torch::stable::zero_(paths); |
| 189 | + |
| 190 | + switch (logProbs.scalar_type()) { |
| 191 | + case ScalarType::Double: { |
| 192 | + if (targets.scalar_type() == ScalarType::Long) { |
| 193 | + forced_align_impl<double, ScalarType::Long>(logProbs, targets, blank, paths); |
| 194 | + } else if (targets.scalar_type() == ScalarType::Int) { |
| 195 | + forced_align_impl<double, ScalarType::Int>(logProbs, targets, blank, paths); |
| 196 | + } else { |
| 197 | + STD_TORCH_CHECK(false, "unreachable"); |
| 198 | + } |
| 199 | + break; |
| 200 | + } |
| 201 | + case ScalarType::Float: { |
| 202 | + if (targets.scalar_type() == ScalarType::Long) { |
| 203 | + forced_align_impl<float, ScalarType::Long>(logProbs, targets, blank, paths); |
| 204 | + } else if (targets.scalar_type() == ScalarType::Int) { |
| 205 | + forced_align_impl<float, ScalarType::Int>(logProbs, targets, blank, paths); |
| 206 | + } else { |
| 207 | + STD_TORCH_CHECK(false, "unreachable"); |
| 208 | + } |
| 209 | + break; |
| 210 | + } |
| 211 | + case ScalarType::Half: { |
| 212 | + if (targets.scalar_type() == ScalarType::Long) { |
| 213 | + forced_align_impl<c10::Half, ScalarType::Long>(logProbs, targets, blank, paths); |
| 214 | + } else if (targets.scalar_type() == ScalarType::Int) { |
| 215 | + forced_align_impl<c10::Half, ScalarType::Int>(logProbs, targets, blank, paths); |
| 216 | + } else { |
| 217 | + STD_TORCH_CHECK(false, "unreachable"); |
| 218 | + } |
| 219 | + break; |
| 220 | + } |
| 221 | + default: { |
| 222 | + STD_TORCH_CHECK(false, "unreachable"); |
| 223 | + } |
| 224 | + }; |
| 225 | + |
203 | 226 | return std::make_tuple(paths, logProbs); |
204 | 227 | } |
205 | 228 |
|
206 | | -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { |
207 | | - m.impl("forced_align", &compute); |
| 229 | +void boxed_forced_align_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { |
| 230 | + STD_TORCH_CHECK(num_args == 5, "num_args must be 5"); |
| 231 | + STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2"); |
| 232 | + std::tuple<Tensor, Tensor> res = compute( |
| 233 | + /*logProbs*/to<Tensor>(stack[0]), |
| 234 | + /*targets*/to<Tensor>(stack[1]), |
| 235 | + /*logit_lengths*/to<Tensor>(stack[2]), |
| 236 | + /*target_lengths*/to<Tensor>(stack[3]), |
| 237 | + /*blank*/float(to<int64_t>(stack[4]))); |
| 238 | + stack[0] = from(std::get<0>(res)); |
| 239 | + stack[1] = from(std::get<1>(res)); |
| 240 | +} |
| 241 | + |
// Register the CPU backend of torchaudio::forced_align through the stable
// ABI's boxed-kernel registration macro.
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("forced_align", &boxed_forced_align_cpu);
}
209 | 245 |
|
210 | 246 | } // namespace cpu |
|
0 commit comments