From 22649074a87673f567ee7a5522384cf4c65461d7 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 19 Aug 2025 21:20:43 +0000 Subject: [PATCH 1/4] Use stable tensors in overdrive --- src/libtorchaudio/overdrive.cpp | 106 +++++++++++++++++++++++--------- 1 file changed, 76 insertions(+), 30 deletions(-) diff --git a/src/libtorchaudio/overdrive.cpp b/src/libtorchaudio/overdrive.cpp index 4954271e41..545e172acb 100644 --- a/src/libtorchaudio/overdrive.cpp +++ b/src/libtorchaudio/overdrive.cpp @@ -1,52 +1,98 @@ #include #include +#include +#include +#include +#include +#include +#include -namespace { +using namespace std; + +namespace torchaudio { + +using torch::stable::Tensor; template void overdrive_cpu_kernel( - at::TensorAccessor waveform_accessor, - at::TensorAccessor temp_accessor, - at::TensorAccessor last_in_accessor, - at::TensorAccessor last_out_accessor, - at::TensorAccessor output_waveform_accessor) { + Accessor<2, scalar_t> waveform_accessor, + Accessor<2, scalar_t> temp_accessor, + Accessor<1, scalar_t, false> last_in_accessor, + Accessor<1, scalar_t> last_out_accessor, + Accessor<2, scalar_t, false> output_waveform_accessor) { int64_t n_frames = waveform_accessor.size(1); int64_t n_channels = waveform_accessor.size(0); at::parallel_for(0, n_channels, 1, [&](int64_t begin, int64_t end) { for (int64_t i_channel = begin; i_channel < end; ++i_channel) { for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) { - last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] - - last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel]; - last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame]; - output_waveform_accessor[i_channel][i_frame] = - waveform_accessor[i_channel][i_frame] * 0.5 + - last_out_accessor[i_channel] * 0.75; + last_out_accessor.set_index( + temp_accessor.index(i_channel, i_frame) - + last_in_accessor.index(i_channel) + 0.995 * last_out_accessor.index(i_channel), + i_channel); + last_in_accessor.set_index(temp_accessor.index(i_channel, i_frame), i_channel); + output_waveform_accessor.set_index( + waveform_accessor.index(i_channel, i_frame) * 0.5 + + last_out_accessor.index(i_channel) * 0.75, + i_channel, i_frame); } } }); } void overdrive_core_loop_cpu( - at::Tensor& waveform, - at::Tensor& temp, - at::Tensor& last_in, - at::Tensor& last_out, - at::Tensor& output_waveform) { - AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] { - overdrive_cpu_kernel( - waveform.accessor(), - temp.accessor(), - last_in.accessor(), - last_out.accessor(), - output_waveform.accessor()); - })); + const Tensor waveform, + const Tensor temp, + Tensor last_in, + const Tensor last_out, + Tensor output_waveform) { + int32_t dtype; + aoti_torch_get_dtype(waveform.get(), &dtype); + if (dtype == aoti_torch_dtype_float64()) { + overdrive_cpu_kernel( + Accessor<2, double>(wave_acc), + Accessor<2, double>(temp_acc), + Accessor<1, double>(last_in_acc), + Accessor<1, double>(last_out_acc), + Accessor<2, double>(out_acc)); + } else if (dtype == aoti_torch_dtype_float32()) { + overdrive_cpu_kernel( + Accessor<2, float>(wave_acc), + Accessor<2, float>(temp_acc), + Accessor<1, float>(last_in_acc), + Accessor<1, float>(last_out_acc), + Accessor<2, float>(out_acc)); + } else if (dtype == aoti_torch_dtype_float16()) { + overdrive_cpu_kernel( + Accessor<2, c10::Half>(wave_acc), + Accessor<2, c10::Half>(temp_acc), + Accessor<1, c10::Half>(last_in_acc), + Accessor<1, c10::Half>(last_out_acc), + Accessor<2, c10::Half>(out_acc)); + } } -} // namespace -// Note: We want to avoid using "catch-all" kernel. -// The following registration should be replaced with CPU specific registration. -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def("torchaudio::_overdrive_core_loop", &overdrive_core_loop_cpu); + +void boxed_overdrive_core_loop(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + Tensor t1(to(stack[0])); + Tensor t2(to(stack[1])); + Tensor t3(to(stack[2])); + Tensor t4(to(stack[3])); + Tensor t5(to(stack[4])); + overdrive_core_loop( + std::move(t1), std::move(t2), std::move(t3), std::move(t4), std::move(t5)); } + +STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def( + "overdrive_core_loop(Tensor waveform," + "Tensor temp, Tensor last_in, Tensor last_out," + "Tensor output_waveform)" +} + +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("overdrive_core_loop", &overdrive_core_loop_cpu); +} + +} // namespace From 3853f57750fcc6d16f03c1f3bf159cd100fbe82d Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 19 Aug 2025 22:05:53 +0000 Subject: [PATCH 2/4] Fix misnaming in overdrive --- src/libtorchaudio/overdrive.cpp | 36 ++++++++++++++++----------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/libtorchaudio/overdrive.cpp b/src/libtorchaudio/overdrive.cpp index 545e172acb..22ae2576a4 100644 --- a/src/libtorchaudio/overdrive.cpp +++ b/src/libtorchaudio/overdrive.cpp @@ -50,25 +50,25 @@ void overdrive_core_loop_cpu( aoti_torch_get_dtype(waveform.get(), &dtype); if (dtype == aoti_torch_dtype_float64()) { overdrive_cpu_kernel( - Accessor<2, double>(wave_acc), - Accessor<2, double>(temp_acc), - Accessor<1, double>(last_in_acc), - Accessor<1, double>(last_out_acc), - Accessor<2, double>(out_acc)); + Accessor<2, double>(waveform), + Accessor<2, double>(temp), + Accessor<1, double, false>(last_in), + Accessor<1, double>(last_out), + Accessor<2, double, false>(output_waveform)); } else if (dtype == aoti_torch_dtype_float32()) { overdrive_cpu_kernel( - Accessor<2, float>(wave_acc), - Accessor<2, float>(temp_acc), - Accessor<1, float>(last_in_acc), - Accessor<1, float>(last_out_acc), - Accessor<2, float>(out_acc)); + Accessor<2, float>(waveform), + Accessor<2, float>(temp), + Accessor<1, float, false>(last_in), + Accessor<1, float>(last_out), + Accessor<2, float, false>(output_waveform)); } else if (dtype == aoti_torch_dtype_float16()) { overdrive_cpu_kernel( - Accessor<2, c10::Half>(wave_acc), - Accessor<2, c10::Half>(temp_acc), - Accessor<1, c10::Half>(last_in_acc), - Accessor<1, c10::Half>(last_out_acc), - Accessor<2, c10::Half>(out_acc)); + Accessor<2, c10::Half>(waveform), + Accessor<2, c10::Half>(temp), + Accessor<1, c10::Half, false>(last_in), + Accessor<1, c10::Half>(last_out), + Accessor<2, c10::Half, false>(output_waveform)); } } @@ -80,7 +80,7 @@ void boxed_overdrive_core_loop(StableIValue* stack, uint64_t num_args, uint64_t Tensor t3(to(stack[2])); Tensor t4(to(stack[3])); Tensor t5(to(stack[4])); - overdrive_core_loop( + overdrive_core_loop_cpu( std::move(t1), std::move(t2), std::move(t3), std::move(t4), std::move(t5)); } @@ -88,11 +88,11 @@ STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def( "overdrive_core_loop(Tensor waveform," "Tensor temp, Tensor last_in, Tensor last_out," - "Tensor output_waveform)" + "Tensor output_waveform)"); } STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("overdrive_core_loop", &overdrive_core_loop_cpu); + m.impl("overdrive_core_loop", &boxed_overdrive_core_loop); } } // namespace From 2a06d7e904db86a338b7d69ffb89f514f9120268 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 19 Aug 2025 22:21:59 +0000 Subject: [PATCH 3/4] Fix missing const errors --- src/libtorchaudio/overdrive.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/libtorchaudio/overdrive.cpp b/src/libtorchaudio/overdrive.cpp index 22ae2576a4..8f8001928f 100644 --- a/src/libtorchaudio/overdrive.cpp +++ b/src/libtorchaudio/overdrive.cpp @@ -18,7 +18,7 @@ void overdrive_cpu_kernel( Accessor<2, scalar_t> waveform_accessor, Accessor<2, scalar_t> temp_accessor, Accessor<1, scalar_t, false> last_in_accessor, - Accessor<1, scalar_t> last_out_accessor, + Accessor<1, scalar_t, false> last_out_accessor, Accessor<2, scalar_t, false> output_waveform_accessor) { int64_t n_frames = waveform_accessor.size(1); int64_t n_channels = waveform_accessor.size(0); @@ -44,7 +44,7 @@ void overdrive_core_loop_cpu( const Tensor waveform, const Tensor temp, Tensor last_in, - const Tensor last_out, + Tensor last_out, Tensor output_waveform) { int32_t dtype; aoti_torch_get_dtype(waveform.get(), &dtype); @@ -53,21 +53,21 @@ void overdrive_core_loop_cpu( Accessor<2, double>(waveform), Accessor<2, double>(temp), Accessor<1, double, false>(last_in), - Accessor<1, double>(last_out), + Accessor<1, double, false>(last_out), Accessor<2, double, false>(output_waveform)); } else if (dtype == aoti_torch_dtype_float32()) { overdrive_cpu_kernel( Accessor<2, float>(waveform), Accessor<2, float>(temp), Accessor<1, float, false>(last_in), - Accessor<1, float>(last_out), + Accessor<1, float, false>(last_out), Accessor<2, float, false>(output_waveform)); } else if (dtype == aoti_torch_dtype_float16()) { overdrive_cpu_kernel( Accessor<2, c10::Half>(waveform), Accessor<2, c10::Half>(temp), Accessor<1, c10::Half, false>(last_in), - Accessor<1, c10::Half>(last_out), + Accessor<1, c10::Half, false>(last_out), Accessor<2, c10::Half, false>(output_waveform)); } } From d7db30a9375c85bf05e55716f68fbc1c7138ef30 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 19 Aug 2025 22:36:49 +0000 Subject: [PATCH 4/4] Fix python export of overdrive --- src/libtorchaudio/overdrive.cpp | 2 +- src/torchaudio/functional/filtering.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libtorchaudio/overdrive.cpp b/src/libtorchaudio/overdrive.cpp index 8f8001928f..387e4a7f78 100644 --- a/src/libtorchaudio/overdrive.cpp +++ b/src/libtorchaudio/overdrive.cpp @@ -88,7 +88,7 @@ STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def( "overdrive_core_loop(Tensor waveform," "Tensor temp, Tensor last_in, Tensor last_out," - "Tensor output_waveform)"); + "Tensor output_waveform) -> ()"); } STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py index 76deb04a96..662e59cb95 100644 --- a/src/torchaudio/functional/filtering.py +++ b/src/torchaudio/functional/filtering.py @@ -1114,7 +1114,7 @@ def _overdrive_core_loop_generic( if _IS_TORCHAUDIO_EXT_AVAILABLE: - _overdrive_core_loop_cpu = torch.ops.torchaudio._overdrive_core_loop + _overdrive_core_loop_cpu = torch.ops.torchaudio.overdrive_core_loop.default else: _overdrive_core_loop_cpu = _overdrive_core_loop_generic