Skip to content

Commit df970aa

Browse files
csukuangfjdanpovey
authored andcommitted
[pybind] Wrap kaldi_pybind to kaldi package. (#3815)
1 parent 752df8c commit df970aa

16 files changed

+995
-51
lines changed

egs/aishell/s10/chain/chain_loss.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torch.utils.dlpack import to_dlpack
99

1010
import kaldi
11-
import kaldi_pybind.chain as chain
11+
from kaldi import chain
1212

1313
g_nnet_output_deriv_tensor = None
1414
g_xent_output_deriv_tensor = None
@@ -56,15 +56,15 @@ def forward(ctx, opts, den_graph, supervision, nnet_output_tensor,
5656
# it contains [objf, l2_term, weight] and will be returned to the caller
5757
objf_l2_term_weight_tensor = torch.zeros(3).float()
5858

59-
nnet_output = kaldi.CuSubMatrixFromDLPack(to_dlpack(nnet_output_tensor))
59+
nnet_output = kaldi.PytorchToCuSubMatrix(to_dlpack(nnet_output_tensor))
6060

61-
nnet_output_deriv = kaldi.CuSubMatrixFromDLPack(
61+
nnet_output_deriv = kaldi.PytorchToCuSubMatrix(
6262
to_dlpack(g_nnet_output_deriv_tensor))
6363

64-
xent_output_deriv = kaldi.CuSubMatrixFromDLPack(
64+
xent_output_deriv = kaldi.PytorchToCuSubMatrix(
6565
to_dlpack(g_xent_output_deriv_tensor))
6666

67-
objf_l2_term_weight = kaldi.SubVectorFromDLPack(
67+
objf_l2_term_weight = kaldi.PytorchToSubVector(
6868
to_dlpack(objf_l2_term_weight_tensor))
6969

7070
chain.ComputeChainObjfAndDeriv(opts=opts,

src/pybind/Makefile

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ test: all
121121
make -C chain test
122122
make -C cudamatrix test
123123
make -C dlpack test
124-
$(eval include ../../tools/env.sh)
125124
make -C feat test
126125
make -C fst test
127126
make -C matrix test

src/pybind/chain/chain_supervision_pybind_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import numpy as np
1313

14-
import kaldi_pybind.chain as chain
14+
from kaldi import chain
1515

1616

1717
class TestChainSupervision(unittest.TestCase):

src/pybind/feat/feat_pybind_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
import unittest
1111
import numpy as np
1212

13-
import kaldi_pybind as k
14-
15-
import kaldi_pybind.feat as feat
13+
import kaldi
14+
from kaldi import feat
1615
from kaldi import SequentialWaveReader
1716
from kaldi import SequentialMatrixReader
1817

@@ -35,8 +34,8 @@ def test_mfcc(self):
3534
value.Duration() * value.SampFreq(),
3635
places=1)
3736

38-
waveform = k.FloatSubVector(nd.reshape(nsamp))
39-
features = k.FloatMatrix(1, 1)
37+
waveform = kaldi.FloatSubVector(nd.reshape(nsamp))
38+
features = kaldi.FloatMatrix(1, 1)
4039
mfcc.ComputeFeatures(waveform, value.SampFreq(), 1.0, features)
4140
self.assertEqual(key, gold_reader.Key())
4241
gold_feat = gold_reader.Value().numpy()

src/pybind/fst/arc_pybind_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import unittest
1212

13-
import kaldi_pybind.fst as fst
13+
from kaldi import fst
1414

1515

1616
class TestArc(unittest.TestCase):

src/pybind/fst/fst_pybind_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import unittest
1111

12-
import kaldi_pybind.fst as fst
12+
from kaldi import fst
1313

1414

1515
class TestArc(unittest.TestCase):

src/pybind/fst/symbol_table_pybind_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
import unittest
1111

12-
import kaldi_pybind.fst as fst
1312
import kaldi
13+
from kaldi import fst
1414

1515

1616
class TestSymbolTable(unittest.TestCase):

src/pybind/fst/vector_fst_pybind_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
import unittest
1111

12-
import kaldi_pybind.fst as fst
1312
import kaldi
13+
from kaldi import fst
1414

1515

1616
class TestStdVectorFst(unittest.TestCase):

src/pybind/fst/weight_pybind_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import unittest
1212

13-
import kaldi_pybind.fst as fst
13+
from kaldi import fst
1414

1515

1616
class TestWeight(unittest.TestCase):

src/pybind/kaldi.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

0 commit comments

Comments
 (0)