Commit d011342

Update design and comments
1 parent 2538a32 commit d011342

File tree

2 files changed: +88 -87 lines changed

src/decoder/lattice-faster-decoder-combine.cc

Lines changed: 60 additions & 65 deletions
@@ -1,9 +1,10 @@
-// decoder/lattice-faster-decoder.cc
+// decoder/lattice-faster-decoder-combine.cc
 
 // Copyright 2009-2012  Microsoft Corporation  Mirko Hannemann
-//           2013-2018  Johns Hopkins University (Author: Daniel Povey)
+//           2013-2019  Johns Hopkins University (Author: Daniel Povey)
 //                2014  Guoguo Chen
 //                2018  Zhehuai Chen
+//                2019  Hang Lyu
 
 // See ../../COPYING for clarification regarding multiple authors
 //
@@ -68,10 +69,6 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::InitDecoding() {
   active_toks_[0].toks = start_tok;
   next_toks_[start_state] = start_tok;  // initialize current tokens map
   num_toks_++;
-
-  recover_ = false;
-  frame_processed_.resize(1);
-  frame_processed_[0] = false;
 }
 
 // Returns true if any kind of traceback is available (not necessarily from
@@ -91,8 +88,7 @@ bool LatticeFasterDecoderCombineTpl<FST, Token>::Decode(DecodableInterface *deco
     ProcessForFrame(decodable);
   }
   // Process non-emitting arcs for the last frame.
-  ProcessNonemitting(false);
-  frame_processed_[active_toks_.size() - 1] = true;  // the last frame is processed.
+  ProcessNonemitting(NULL);
 
   FinalizeDecoding();
 
@@ -123,17 +119,21 @@ bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
   typedef Arc::StateId StateId;
   typedef Arc::Weight Weight;
   typedef Arc::Label Label;
-  // Process the non-emitting arcs for the unfinished last frame.
-  if (!frame_processed_[active_toks_.size() - 1]) {
-    ProcessNonemitting(true);
-  }
   // Note: you can't use the old interface (Decode()) if you want to
   // get the lattice with use_final_probs = false. You'd have to do
   // InitDecoding() and then AdvanceDecoding().
   if (decoding_finalized_ && !use_final_probs)
     KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
               << "GetRawLattice() with use_final_probs == false";
 
+  std::unordered_map<Token*, BaseFloat> *recover_map = NULL;
+  if (!decoding_finalized_) {
+    recover_map = new std::unordered_map<Token*, BaseFloat>();
+    // Process the non-emitting arcs for the unfinished last frame.
+    ProcessNonemitting(recover_map);
+  }
+
+
   unordered_map<Token*, BaseFloat> final_costs_local;
 
   const unordered_map<Token*, BaseFloat> &final_costs =
@@ -201,10 +201,42 @@ bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
       }
     }
   }
+
+  if (recover_map) {  // recover last token list
+    RecoverLastTokenList(recover_map);
+    delete recover_map;
+  }
   return (ofst->NumStates() > 0);
 }
 
 
+// When GetRawLattice() is called during decoding, the
+// active_toks_[last_frame] list is changed. To keep the behavior of
+// ProcessForFrame() consistent, recover it.
+// Notice: as new tokens are added to the head of the TokenList, tok->next
+// will not be affected.
+template<typename FST, typename Token>
+void LatticeFasterDecoderCombineTpl<FST, Token>::RecoverLastTokenList(
+    std::unordered_map<Token*, BaseFloat> *recover_map) {
+  if (recover_map) {
+    for (Token* tok = active_toks_[active_toks_.size() - 1].toks;
+         tok != NULL;) {
+      if (recover_map->find(tok) != recover_map->end()) {
+        DeleteForwardLinks(tok);
+        tok->tot_cost = (*recover_map)[tok];
+        tok->in_current_queue = false;
+        tok = tok->next;
+      } else {
+        DeleteForwardLinks(tok);
+        Token *next_tok = tok->next;
+        delete tok;
+        num_toks_--;
+        tok = next_tok;
+      }
+    }
+  }
+}
+
 // This function is now deprecated, since now we do determinization from outside
 // the LatticeFasterDecoder class. Outputs an FST corresponding to the
 // lattice-determinized lattice (one path per word sequence).
@@ -258,19 +290,19 @@ bool LatticeFasterDecoderCombineTpl<FST, Token>::GetLattice(
    only do it every 'config_.prune_interval' frames).
  */
 
-// FindOrAddToken either locates a token in hash of toks_,
+// FindOrAddToken either locates a token in hash map "token_map"
 // or if necessary inserts a new, empty token (i.e. with no forward links)
 // for the current frame.  [note: it's inserted if necessary into hash toks_
 // and also into the singly linked list of tokens active on this frame
 // (whose head is at active_toks_[frame]).
 template <typename FST, typename Token>
 inline Token* LatticeFasterDecoderCombineTpl<FST, Token>::FindOrAddToken(
-    StateId state, int32 frame, BaseFloat tot_cost, Token *backpointer,
+    StateId state, int32 frame_plus_one, BaseFloat tot_cost, Token *backpointer,
     StateIdToTokenMap *token_map, bool *changed) {
   // Returns the Token pointer.  Sets "changed" (if non-NULL) to true
   // if the token was newly created or the cost changed.
-  KALDI_ASSERT(frame < active_toks_.size());
-  Token *&toks = active_toks_[frame].toks;
+  KALDI_ASSERT(frame_plus_one < active_toks_.size());
+  Token *&toks = active_toks_[frame_plus_one].toks;
   typename StateIdToTokenMap::iterator e_found = token_map->find(state);
   if (e_found == token_map->end()) {  // no such token presently.
     const BaseFloat extra_cost = 0.0;
@@ -626,7 +658,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::AdvanceDecoding(
     }
     ProcessForFrame(decodable);
   }
-  ProcessNonemitting(false);
+  ProcessNonemitting(NULL);
 }
 
 // FinalizeDecoding() is a version of PruneActiveTokens that we call
@@ -732,38 +764,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
   int32 frame = active_toks_.size() - 1;  // frame is the frame-index
                                           // (zero-based) used to get likelihoods
                                           // from the decodable object.
-  if (!recover_ && frame_processed_[frame]) {
-    KALDI_ERR << "Maybe the whole utterance has been processed, you shouldn't"
-              << " call ProcessForFrame() again.";
-  } else if (recover_ && !frame_processed_[frame]) {
-    KALDI_ERR << "Should not happen.";
-  }
-
-  // Maybe called GetRawLattice() in the middle of an utterance. The
-  // active_toks_[frame] is changed. Recover it.
-  // Notice: as new token will be added to the head of TokenList, tok->next
-  // will not be affacted.
-  if (recover_) {
-    frame_processed_[frame] = false;
-    for (Token* tok = active_toks_[frame].toks; tok != NULL;) {
-      if (recover_map_.find(tok) != recover_map_.end()) {
-        DeleteForwardLinks(tok);
-        tok->tot_cost = recover_map_[tok];
-        tok->in_current_queue = false;
-        tok = tok->next;
-      } else {
-        DeleteForwardLinks(tok);
-        Token *next_tok = tok->next;
-        delete tok;
-        num_toks_--;
-        tok = next_tok;
-      }
-    }
-    recover_ = false;
-  }
-
   active_toks_.resize(active_toks_.size() + 1);
-  frame_processed_.resize(frame_processed_.size() + 1);
 
   cur_toks_.clear();
   cur_toks_.swap(next_toks_);
@@ -890,27 +891,24 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
     }  // for all arcs
     tok->in_current_queue = false;  // out of queue
   }  // end of while loop
-  frame_processed_[frame] = true;
-  frame_processed_[frame + 1] = false;
+  KALDI_VLOG(6) << "toks after: " << cur_toks_.size();
 }
 
 
 template <typename FST, typename Token>
-void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting(bool recover) {
-  if (recover) {  // Build the elements which are used to recover
-    // Set the flag to true so that we will recover "next_toks_" map in
-    // ProcessForFrame() firstly.
-    recover_ = true;
+void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting(
+    std::unordered_map<Token*, BaseFloat> *recover_map) {
+  if (recover_map) {  // Build the elements which are used to recover
     for (IterType iter = next_toks_.begin(); iter != next_toks_.end(); iter++) {
-      recover_map_[iter->second] = iter->second->tot_cost;
+      (*recover_map)[iter->second] = iter->second->tot_cost;
     }
   }
 
-  StateIdToTokenMap tmp_toks_(next_toks_);
+  StateIdToTokenMap tmp_toks(next_toks_);
   int32 frame = active_toks_.size() - 1;
   // Build the queue to process non-emitting arcs
   std::vector<StateId> cur_queue;
-  for (IterType iter = tmp_toks_.begin(); iter != tmp_toks_.end(); iter++) {
+  for (IterType iter = tmp_toks.begin(); iter != tmp_toks.end(); iter++) {
     if (fst_->NumInputEpsilons(iter->first) != 0) {
       cur_queue.push_back(iter->first);
       iter->second->in_current_queue = true;
@@ -920,14 +918,14 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting(bool recover
   // "cur_cutoff" is used to constrain the epsilon emission in current frame.
   // It will not be updated.
   BaseFloat adaptive_beam;
-  BaseFloat cur_cutoff = GetCutoff(tmp_toks_, &adaptive_beam, NULL, NULL);
+  BaseFloat cur_cutoff = GetCutoff(tmp_toks, &adaptive_beam, NULL, NULL);
 
   while (!cur_queue.empty()) {
     StateId state = cur_queue.back();
     cur_queue.pop_back();
 
-    KALDI_ASSERT(tmp_toks_.find(state) != tmp_toks_.end());
-    Token *tok = tmp_toks_[state];
+    KALDI_ASSERT(tmp_toks.find(state) != tmp_toks.end());
+    Token *tok = tmp_toks[state];
     BaseFloat cur_cost = tok->tot_cost;
     if (cur_cost > cur_cutoff)  // Don't bother processing successors.
       continue;
@@ -946,7 +944,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting(bool recover
       BaseFloat tot_cost = cur_cost + graph_cost;
       if (tot_cost < cur_cutoff) {
         Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost,
-                                        tok, &tmp_toks_, &changed);
+                                        tok, &tmp_toks, &changed);
 
         // Add ForwardLink from tok to new_tok. Put it on the head of
         // tok->link list
@@ -964,9 +962,6 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting(bool recover
     }  // end of for loop
     tok->in_current_queue = false;
   }  // end of while loop
-  frame_processed_[active_toks_.size() - 1] = true;  // in case someone call
-                                                     // GetRawLattice() twice
-                                                     // continuously.
 }
 
 
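The change above boils down to a record-and-restore pattern: before GetRawLattice() relaxes non-emitting arcs on the unfinished last frame, ProcessNonemitting() records the original tot_cost of every token already present, and RecoverLastTokenList() later puts those costs back and drops tokens created during the pass. The following is a minimal, self-contained sketch of that pattern only; MiniToken is a hypothetical stand-in for the decoder's Token type, and it unlinks instead of deleting because the example tokens live on the stack.

#include <iostream>
#include <unordered_map>

// Hypothetical, simplified stand-in for the decoder's token-list entries.
// The real Token/ForwardLink types live in lattice-faster-decoder-combine.h;
// this struct only illustrates the record-and-restore idea.
struct MiniToken {
  float tot_cost;
  MiniToken *next;   // singly linked list; new tokens are pushed at the head
};

int main() {
  // Token list of the "last frame" before the non-emitting pass.
  MiniToken c{3.0f, nullptr}, b{2.0f, &c}, a{1.0f, &b};
  MiniToken *head = &a;

  // 1) Record the original cost of every pre-existing token, as
  //    ProcessNonemitting(recover_map) does for the tokens in next_toks_.
  std::unordered_map<MiniToken*, float> recover_map;
  for (MiniToken *t = head; t != nullptr; t = t->next)
    recover_map[t] = t->tot_cost;

  // 2) Simulate the pass: one cost is lowered and a newly reached token is
  //    pushed at the head, so recorded tokens keep their 'next' pointers.
  a.tot_cost = 0.5f;
  MiniToken eps{0.9f, head};
  head = &eps;

  // 3) Restore, in the spirit of RecoverLastTokenList(): recorded tokens get
  //    their old cost back; tokens not in the map are removed from the list
  //    (the real code deletes them and decrements num_toks_).
  MiniToken **link = &head;
  for (MiniToken *t = head; t != nullptr; ) {
    auto it = recover_map.find(t);
    if (it != recover_map.end()) {
      t->tot_cost = it->second;
      link = &t->next;
      t = t->next;
    } else {
      *link = t->next;   // unlink the token created during the pass
      t = *link;
    }
  }

  for (MiniToken *t = head; t != nullptr; t = t->next)
    std::cout << t->tot_cost << " ";   // prints: 1 2 3
  std::cout << std::endl;
  return 0;
}
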
src/decoder/lattice-faster-decoder-combine.h

Lines changed: 28 additions & 22 deletions
@@ -1,9 +1,10 @@
-// decoder/lattice-faster-decoder.h
+// decoder/lattice-faster-decoder-combine.h
 
 // Copyright 2009-2013  Microsoft Corporation;  Mirko Hannemann;
-//           2013-2014  Johns Hopkins University (Author: Daniel Povey)
+//           2013-2019  Johns Hopkins University (Author: Daniel Povey)
 //                2014  Guoguo Chen
 //                2018  Zhehuai Chen
+//                2019  Hang Lyu
 
 // See ../../COPYING for clarification regarding multiple authors
 //
@@ -294,6 +295,9 @@ class LatticeFasterDecoderCombineTpl {
   /// of the graph then it will include those as final-probs, else
   /// it will treat all final-probs as one.
   /// The raw lattice will be topologically sorted.
+  /// This function can be called during decoding; it takes the "next_toks_"
+  /// map, generates the complete token list for the last frame, and then
+  /// recovers it to keep ProcessForFrame() consistent.
   ///
   /// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
   /// which also supports a pruning beam, in case for some reason
@@ -373,17 +377,17 @@ class LatticeFasterDecoderCombineTpl {
         must_prune_tokens(true) { }
   };
 
-  // FindOrAddToken either locates a token in hash of toks_, or if necessary
+  // FindOrAddToken either locates a token in hash map "token_map", or if necessary
   // inserts a new, empty token (i.e. with no forward links) for the current
-  // frame.  [note: it's inserted if necessary into hash toks_ and also into the
+  // frame.  [note: it's inserted if necessary into hash map and also into the
   // singly linked list of tokens active on this frame (whose head is at
   // active_toks_[frame]).  The frame_plus_one argument is the acoustic frame
   // index plus one, which is used to index into the active_toks_ array.
   // Returns the Token pointer.  Sets "changed" (if non-NULL) to true if the
   // token was newly created or the cost changed.
   // If Token == StdToken, the 'backpointer' argument has no purpose (and will
   // hopefully be optimized out).
-  inline Token *FindOrAddToken(StateId state, int32 frame,
+  inline Token *FindOrAddToken(StateId state, int32 frame_plus_one,
                                BaseFloat tot_cost, Token *backpointer,
                                StateIdToTokenMap *token_map,
                                bool *changed);
@@ -442,18 +446,28 @@ class LatticeFasterDecoderCombineTpl {
   // less far.
   void PruneActiveTokens(BaseFloat delta);
 
-  /// Processes nonemitting (epsilon) arcs and emitting arcs for one frame
-  /// together. Consider it as a combination of ProcessEmitting() and
-  /// ProcessNonemitting().
+  /// Processes non-emitting (epsilon) arcs and emitting arcs for one frame
+  /// together.  It takes the emitting tokens in "cur_toks_" from the previous
+  /// frame, and generates non-emitting tokens for the current frame and
+  /// emitting tokens for the next frame.
   void ProcessForFrame(DecodableInterface *decodable);
 
   /// Processes nonemitting (epsilon) arcs for one frame.
-  /// Called once when all frames were processed or in GetRawLattice().
-  /// Deal With the tokens in map "next_toks_" which would only contains
-  /// emittion tokens from previous frame.
-  /// If you call this function not in the end of an utterance, recover
-  /// should be true.
-  void ProcessNonemitting(bool recover);
+  /// Call it once after all frames have been processed, or from
+  /// GetRawLattice() to generate the complete token list for the last frame.
+  /// [It deals with the tokens in map "next_toks_", which only contains
+  /// emitting tokens from the previous frame.]
+  /// If "recover_map" isn't NULL, we fill it so that it can later be used to
+  /// recover the "active_toks_[last_frame]" token list.
+  void ProcessNonemitting(std::unordered_map<Token*, BaseFloat> *recover_map);
+
+  /// When GetRawLattice() is called during decoding, the
+  /// active_toks_[last_frame] list is changed.  To keep ProcessForFrame()
+  /// consistent, recover it.
+  /// Notice: as new tokens are added to the head of the TokenList, tok->next
+  /// will not be affected.
+  void RecoverLastTokenList(std::unordered_map<Token*, BaseFloat> *recover_map);
+
 
   /// The "cur_toks_" and "next_toks_" actually allow us to maintain current
   /// and next frames. They are indexed by StateId. It is indexed by frame-index
@@ -464,14 +478,6 @@ class LatticeFasterDecoderCombineTpl {
   StateIdToTokenMap cur_toks_;
   StateIdToTokenMap next_toks_;
 
-  /// When we call GetRawLattice() in the middle of an utterance, we have to
-  /// process non-emitting arcs so that we need to recover it original status.
-  std::unordered_map<Token*, BaseFloat> recover_map_;  // Token pointer to tot_cost
-  bool recover_;
-  /// Indicate each frame is processed wholly or not. The size equals to
-  /// active_toks_.
-  std::vector<bool> frame_processed_;
-
   /// Gets the weight cutoff.
   /// Notice: In the traditional version, the histogram pruning method is applied
   /// on a complete token list on one frame. But, in this version, it is used