File tree Expand file tree Collapse file tree 5 files changed +78
-39
lines changed Expand file tree Collapse file tree 5 files changed +78
-39
lines changed Original file line number Diff line number Diff line change 66 lfilter.cpp
77 overdrive.cpp
88 utils.cpp
9+ accessor_tests.cpp
910 )
1011
1112set (
Original file line number Diff line number Diff line change 1+ #pragma once
2+
3+ #include < torch/torch.h>
4+ #include < type_traits>
5+ #include < cstdarg>
6+
7+ template <unsigned int k, typename T, bool IsConst = true >
8+ class Accessor {
9+ int64_t strides[k];
10+ T *data;
11+
12+ public:
13+ using tensor_type = typename std::conditional<IsConst, const torch::Tensor&, torch::Tensor&>::type;
14+
15+ Accessor (tensor_type tensor) {
16+ data = tensor.template data_ptr <T>();
17+ for (int i = 0 ; i < k; i++) {
18+ strides[i] = tensor.stride (i);
19+ }
20+ }
21+
22+ T index (...) {
23+ va_list args;
24+ va_start (args, k);
25+ int64_t ix = 0 ;
26+ for (int i = 0 ; i < k; i++) {
27+ ix += strides[i] * va_arg (args, int );
28+ }
29+ va_end (args);
30+ return data[ix];
31+ }
32+
33+ template <bool C = IsConst>
34+ typename std::enable_if<!C, void >::type set_index (T value, ...) {
35+ va_list args;
36+ va_start (args, value);
37+ int64_t ix = 0 ;
38+ for (int i = 0 ; i < k; i++) {
39+ ix += strides[i] * va_arg (args, int );
40+ }
41+ va_end (args);
42+ data[ix] = value;
43+ }
44+ };
Original file line number Diff line number Diff line change 1+ #include < libtorchaudio/accessor.h>
2+ #include < cstdint>
3+ #include < torch/torch.h>
4+
5+ using namespace std ;
6+
7+ bool test_accessor (const torch::Tensor& tensor) {
8+ int64_t * data_ptr = tensor.template data_ptr <int64_t >();
9+ auto accessor = Accessor<3 , int64_t >(tensor);
10+ for (int i = 0 ; i < tensor.size (0 ); i++) {
11+ for (int j = 0 ; j < tensor.size (1 ); j++) {
12+ for (int k = 0 ; k < tensor.size (2 ); k++) {
13+ auto check = *(data_ptr++) == accessor.index (i, j, k);
14+ if (!check) {
15+ return false ;
16+ }
17+ }
18+ }
19+ }
20+ return true ;
21+ }
22+
23+ TORCH_LIBRARY_FRAGMENT (torchaudio, m) {
24+ m.def (" torchaudio::_test_accessor" , &test_accessor);
25+ }
Original file line number Diff line number Diff line change 55#include < torch/csrc/stable/ops.h>
66#include < torch/csrc/inductor/aoti_torch/c/shim.h>
77#include < torch/csrc/inductor/aoti_torch/utils.h>
8- #include < cstdarg>
9- #include < type_traits>
8+ #include < libtorchaudio/accessor.h>
109
1110
1211using namespace std ;
@@ -15,44 +14,7 @@ namespace torchaudio {
1514namespace alignment {
1615namespace cpu {
1716
18- template <unsigned int k, typename T, bool IsConst = true >
19- class Accessor {
20- int64_t strides[k];
21- T *data;
22-
23- public:
24- using tensor_type = typename std::conditional<IsConst, const torch::Tensor&, torch::Tensor&>::type;
25-
26- Accessor (tensor_type tensor) {
27- data = tensor.template data_ptr <T>();
28- for (int i = 0 ; i < k; i++) {
29- strides[i] = tensor.stride (i);
30- }
31- }
3217
33- T index (...) {
34- va_list args;
35- va_start (args, k);
36- int64_t ix = 0 ;
37- for (int i = 0 ; i < k; i++) {
38- ix += strides[i] * va_arg (args, int );
39- }
40- va_end (args);
41- return data[ix];
42- }
43-
44- template <bool C = IsConst>
45- typename std::enable_if<!C, void >::type set_index (T value, ...) {
46- va_list args;
47- va_start (args, value);
48- int64_t ix = 0 ;
49- for (int i = 0 ; i < k; i++) {
50- ix += strides[i] * va_arg (args, int );
51- }
52- va_end (args);
53- data[ix] = value;
54- }
55- };
5618
5719// Inspired from
5820// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
Original file line number Diff line number Diff line change 1+ import torch
2+ from torchaudio ._extension import _IS_TORCHAUDIO_EXT_AVAILABLE
3+
4+ if _IS_TORCHAUDIO_EXT_AVAILABLE :
5+ def test_accessor ():
6+ tensor = torch .randint (1000 , (5 ,4 ,3 ))
7+ assert torch .ops .torchaudio ._test_accessor (tensor )
You can’t perform that action at this time.
0 commit comments