Skip to content

Commit 752df8c

Browse files
csukuangfjdanpovey
authored andcommitted
[src,pybind] Replace MatrixWriter with CompressedMatrixWriter. (#3803)
1 parent 277bbaf commit 752df8c

19 files changed

+355
-123
lines changed

egs/aishell/s10/chain/inference.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,10 @@ def main():
2525
logging.info(' '.join(sys.argv))
2626

2727
if torch.cuda.is_available() == False:
28-
logging.error('No GPU detected!')
29-
sys.exit(-1)
30-
31-
kaldi.SelectGpuDevice(device_id=args.device_id)
32-
kaldi.CuDeviceAllowMultithreading()
33-
device = torch.device('cuda', args.device_id)
28+
logging.warning('No GPU detected! Use CPU for inference.')
29+
device = torch.device('cpu')
30+
else:
31+
device = torch.device('cuda', args.device_id)
3432

3533
model = get_chain_model(feat_dim=args.feat_dim,
3634
output_dim=args.output_dim,
@@ -47,7 +45,14 @@ def main():
4745
specifier = 'ark,scp:{filename}.ark,{filename}.scp'.format(
4846
filename=os.path.join(args.dir, 'confidence'))
4947

50-
writer = kaldi.MatrixWriter(specifier)
48+
if args.save_as_compressed:
49+
Writer = kaldi.CompressedMatrixWriter
50+
Matrix = kaldi.CompressedMatrix
51+
else:
52+
Writer = kaldi.MatrixWriter
53+
Matrix = kaldi.FloatMatrix
54+
55+
writer = Writer(specifier)
5156

5257
dataloader = get_feat_dataloader(
5358
feats_scp=args.feats_scp,
@@ -69,13 +74,14 @@ def main():
6974
value = value.cpu()
7075

7176
m = kaldi.SubMatrixFromDLPack(to_dlpack(value))
72-
m = kaldi.FloatMatrix(m)
77+
m = Matrix(m)
7378
writer.Write(key, m)
7479

7580
if batch_idx % 10 == 0:
7681
logging.info('Processed batch {}/{} ({:.6f}%)'.format(
7782
batch_idx, len(dataloader),
7883
float(batch_idx) / len(dataloader) * 100))
84+
7985
writer.Close()
8086
logging.info('confidence is saved to {}'.format(
8187
os.path.join(args.dir, 'confidence.scp')))

egs/aishell/s10/chain/options.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,21 @@
77
import os
88

99

10+
def _str2bool(v):
11+
'''
12+
This function is modified from
13+
https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
14+
'''
15+
if isinstance(v, bool):
16+
return v
17+
elif v.lower() in ('yes', 'true', 't', 'y', '1'):
18+
return True
19+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
20+
return False
21+
else:
22+
raise argparse.ArgumentTypeError('Boolean value expected.')
23+
24+
1025
def _set_inference_args(parser):
1126
parser.add_argument('--feats-scp',
1227
dest='feats_scp',
@@ -25,6 +40,13 @@ def _set_inference_args(parser):
2540
type=int,
2641
default=-1)
2742

43+
parser.add_argument(
44+
'--save-as-compressed',
45+
dest='save_as_compressed',
46+
help='true to save the neural network output to CompressedMatrix,'
47+
' false to Matrix<float>',
48+
type=_str2bool)
49+
2850

2951
def _set_training_args(parser):
3052
parser.add_argument('--train.cegs-dir',
@@ -92,9 +114,7 @@ def _check_inference_args(args):
92114

93115
def _check_args(args):
94116

95-
assert args.is_training in [0, 1]
96-
97-
if args.is_training == 1:
117+
if args.is_training:
98118
_check_training_args(args)
99119
else:
100120
_check_inference_args(args)
@@ -146,9 +166,9 @@ def get_args():
146166

147167
parser.add_argument('--is-training',
148168
dest='is_training',
149-
help='1 for training, 0 for inference',
169+
help='true for training, false for inference',
150170
required=True,
151-
type=int)
171+
type=_str2bool)
152172

153173
parser.add_argument(
154174
'--lda-mat-filename',

egs/aishell/s10/local/run_chain.sh

100644100755
Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,16 @@ minibatch_size=128
3030
num_epochs=10
3131
lr=2e-3
3232

33-
feat_dim=$(cat exp/chain/egs/info/feat_dim)
34-
output_dim=$(cat exp/chain/egs/info/num_pdfs)
35-
3633
hidden_dim=625
3734
kernel_size_list="1, 3, 3, 3, 3, 3" # comma separated list
3835
stride_list="1, 1, 3, 1, 1, 1" # comma separated list
3936

4037
log_level=info # valid values: debug, info, warning
4138

39+
# true to save network output as kaldi::CompressedMatrix
40+
# false to save it as kaldi::Matrix<float>
41+
save_nn_output_as_compressed=false
42+
4243
. ./path.sh
4344
. ./cmd.sh
4445

@@ -116,6 +117,9 @@ if [[ $stage -le 5 ]]; then
116117
exp/chain $lat_dir exp/chain/egs
117118
fi
118119

120+
feat_dim=$(cat exp/chain/egs/info/feat_dim)
121+
output_dim=$(cat exp/chain/egs/info/num_pdfs)
122+
119123
if [[ $stage -le 6 ]]; then
120124
echo "merging egs"
121125
mkdir -p exp/chain/merged_egs
@@ -152,7 +156,7 @@ if [[ $stage -le 8 ]]; then
152156
--dir exp/chain/train \
153157
--feat-dim $feat_dim \
154158
--hidden-dim $hidden_dim \
155-
--is-training 1 \
159+
--is-training true \
156160
--kernel-size-list "$kernel_size_list" \
157161
--log-level $log_level \
158162
--output-dim $output_dim \
@@ -182,12 +186,13 @@ if [[ $stage -le 9 ]]; then
182186
--feat-dim $feat_dim \
183187
--feats-scp data/fbank_pitch/$x/feats.scp \
184188
--hidden-dim $hidden_dim \
185-
--is-training 0 \
189+
--is-training false \
186190
--kernel-size-list "$kernel_size_list" \
187191
--log-level $log_level \
188192
--model-left-context $model_left_context \
189193
--model-right-context $model_right_context \
190194
--output-dim $output_dim \
195+
--save-as-compressed $save_nn_output_as_compressed \
191196
--stride-list "$stride_list" || exit 1
192197
fi
193198
done

src/pybind/Makefile

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,12 @@ fst/vector_fst_pybind.cc \
5959
fst/weight_pybind.cc \
6060
fstext/kaldi_fst_io_pybind.cc \
6161
kaldi_pybind.cc \
62+
matrix/compressed_matrix_pybind.cc \
63+
matrix/kaldi_matrix_pybind.cc \
64+
matrix/kaldi_vector_pybind.cc \
6265
matrix/matrix_common_pybind.cc \
6366
matrix/matrix_pybind.cc \
6467
matrix/sparse_matrix_pybind.cc \
65-
matrix/vector_pybind.cc \
6668
nnet3/nnet3_pybind.cc \
6769
nnet3/nnet_chain_example_pybind.cc \
6870
nnet3/nnet_common_pybind.cc \

src/pybind/kaldi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,5 @@
1717
from util.table import SequentialVectorReader
1818
from util.table import RandomAccessVectorReader
1919
from util.table import VectorWriter
20+
21+
from util.table import CompressedMatrixWriter

src/pybind/kaldi_pybind.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,7 @@
2424
#include "cudamatrix/cudamatrix_pybind.h"
2525
#include "feat/feat_pybind.h"
2626
#include "feat/wave_reader_pybind.h"
27-
#include "matrix/matrix_common_pybind.h"
2827
#include "matrix/matrix_pybind.h"
29-
#include "matrix/sparse_matrix_pybind.h"
30-
#include "matrix/vector_pybind.h"
3128
#include "util/util_pybind.h"
3229

3330
#include "fst/fst_pybind.h"
@@ -43,10 +40,7 @@ PYBIND11_MODULE(kaldi_pybind, m) {
4340
"src/matrix and src/util directories. "
4441
"Source is in $(KALDI_ROOT)/src/pybind/kaldi_pybind.cc";
4542

46-
pybind_matrix_common(m);
4743
pybind_matrix(m);
48-
pybind_sparse_matrix(m);
49-
pybind_vector(m);
5044
pybind_util(m);
5145
pybind_feat(m);
5246

src/pybind/matrix/Makefile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

22
test:
3-
python3 ./vector_pybind_test.py
4-
python3 ./matrix_pybind_test.py
3+
python3 ./compressed_matrix_pybind_test.py
4+
python3 ./kaldi_matrix_pybind_test.py
5+
python3 ./kaldi_vector_pybind_test.py
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// pybind/matrix/compressed_matrix_pybind.cc
2+
3+
// Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
4+
5+
// See ../../../COPYING for clarification regarding multiple authors
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15+
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16+
// MERCHANTABLITY OR NON-INFRINGEMENT.
17+
// See the Apache 2 License for the specific language governing permissions and
18+
// limitations under the License.
19+
20+
#include "matrix/compressed_matrix_pybind.h"
21+
22+
#include "matrix/compressed-matrix.h"
23+
24+
using namespace kaldi;
25+
26+
void pybind_compressed_matrix(py::module& m) {
27+
py::enum_<CompressionMethod>(
28+
m, "CompressionMethod", py::arithmetic(),
29+
"The enum CompressionMethod is used when creating a CompressedMatrix (a "
30+
"lossily compressed matrix) from a regular Matrix. It dictates how we "
31+
"choose the compressed format and how we choose the ranges of floats "
32+
"that are represented by particular integers.")
33+
.value(
34+
"kAutomaticMethod", kAutomaticMethod,
35+
"This is the default when you don't specify the compression method. "
36+
" It is a shorthand for using kSpeechFeature if the num-rows is "
37+
"more than 8, and kTwoByteAuto otherwise.")
38+
.value(
39+
"kSpeechFeature", kSpeechFeature,
40+
"This is the most complicated of the compression methods, and was "
41+
"designed for speech features which have a roughly Gaussian "
42+
"distribution with different ranges for each dimension. Each "
43+
"element is stored in one byte, but there is an 8-byte header per "
44+
"column; the spacing of the integer values is not uniform but is in "
45+
"3 ranges.")
46+
.value("kTwoByteAuto", kTwoByteAuto,
47+
"Each element is stored in two bytes as a uint16, with the "
48+
"representable range of values chosen automatically with the "
49+
"minimum and maximum elements of the matrix as its edges.")
50+
.value("kTwoByteSignedInteger", kTwoByteSignedInteger,
51+
"Each element is stored in two bytes as a uint16, with the "
52+
"representable range of value chosen to coincide with what you'd "
53+
"get if you stored signed integers, i.e. [-32768.0, 32767.0]. "
54+
"Suitable for waveform data that was previously stored as 16-bit "
55+
"PCM.")
56+
.value("kOneByteAuto", kOneByteAuto,
57+
"Each element is stored in one byte as a uint8, with the "
58+
"representable range of values chosen automatically with the "
59+
"minimum and maximum elements of the matrix as its edges.")
60+
.value("kOneByteUnsignedInteger", kOneByteUnsignedInteger,
61+
"Each element is stored in one byte as a uint8, with the "
62+
"representable range of values equal to [0.0, 255.0].")
63+
.value("kOneByteZeroOne", kOneByteZeroOne,
64+
"Each element is stored in one byte as a uint8, with the "
65+
"representable range of values equal to [0.0, 1.0]. Suitable for "
66+
"image data that has previously been compressed as int8.")
67+
.export_values();
68+
{
69+
using PyClass = CompressedMatrix;
70+
py::class_<PyClass>(m, "CompressedMatrix")
71+
.def(py::init<>())
72+
.def(py::init<const MatrixBase<float>&, CompressionMethod>(),
73+
py::arg("mat"), py::arg("method") = kAutomaticMethod)
74+
.def("NumRows", &PyClass::NumRows,
75+
"Returns number of rows (or zero for emtpy matrix).")
76+
.def("NumCols", &PyClass::NumCols,
77+
"Returns number of columns (or zero for emtpy matrix).");
78+
}
79+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// pybind/matrix/compressed_matrix_pybind.h
2+
3+
// Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
4+
5+
// See ../../../COPYING for clarification regarding multiple authors
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
12+
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
13+
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
14+
// MERCHANTABLITY OR NON-INFRINGEMENT.
15+
// See the Apache 2 License for the specific language governing permissions and
16+
// limitations under the License.
17+
18+
#ifndef KALDI_PYBIND_MATRIX_COMPRESSED_MATRIX_PYBIND_H_
19+
#define KALDI_PYBIND_MATRIX_COMPRESSED_MATRIX_PYBIND_H_
20+
21+
#include "pybind/kaldi_pybind.h"
22+
23+
void pybind_compressed_matrix(py::module& m);
24+
25+
#endif // KALDI_PYBIND_MATRIX_COMPRESSED_MATRIX_PYBIND_H_
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
4+
# Apache 2.0
5+
6+
import os
7+
import sys
8+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir))
9+
10+
import unittest
11+
12+
import kaldi
13+
14+
class TestCompressedMatrix(unittest.TestCase):
15+
def test_from_float_matrix(self):
16+
num_rows = 2
17+
num_cols = 3
18+
m = kaldi.FloatMatrix(num_rows, num_cols)
19+
20+
cm = kaldi.CompressedMatrix(m)
21+
22+
self.assertEqual(cm.NumRows(), num_rows)
23+
self.assertEqual(cm.NumCols(), num_cols)
24+
25+
if __name__ == '__main__':
26+
unittest.main()

0 commit comments

Comments
 (0)