Skip to content

Commit a758ba4

Browse files
committed
update comments and the functions about PNE()
1 parent d011342 commit a758ba4

File tree

2 files changed

+57
-50
lines changed

2 files changed

+57
-50
lines changed

src/decoder/lattice-faster-decoder-combine.cc

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
4141
const LatticeFasterDecoderCombineConfig &config, FST *fst):
4242
fst_(fst), delete_fst_(true), config_(config), num_toks_(0) {
4343
config.Check();
44+
prev_toks_.reserve(1000);
45+
cur_toks_.reserve(1000);
4446
}
4547

4648

@@ -53,8 +55,8 @@ LatticeFasterDecoderCombineTpl<FST, Token>::~LatticeFasterDecoderCombineTpl() {
5355
template <typename FST, typename Token>
5456
void LatticeFasterDecoderCombineTpl<FST, Token>::InitDecoding() {
5557
// clean up from last time:
58+
prev_toks_.clear();
5659
cur_toks_.clear();
57-
next_toks_.clear();
5860
cost_offsets_.clear();
5961
ClearActiveTokens();
6062

@@ -67,7 +69,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::InitDecoding() {
6769
active_toks_.resize(1);
6870
Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL);
6971
active_toks_[0].toks = start_tok;
70-
next_toks_[start_state] = start_tok; // initialize current tokens map
72+
cur_toks_[start_state] = start_tok; // initialize current tokens map
7173
num_toks_++;
7274
}
7375

@@ -87,9 +89,7 @@ bool LatticeFasterDecoderCombineTpl<FST, Token>::Decode(DecodableInterface *deco
8789
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
8890
ProcessForFrame(decodable);
8991
}
90-
// Procss non-emitting arcs for the last frame.
91-
ProcessNonemitting(NULL);
92-
92+
// A complete token list of the last frame will be generated in FinalizeDecoding()
9393
FinalizeDecoding();
9494

9595
// Returns true if we have any kind of traceback available (not necessarily
@@ -126,11 +126,10 @@ bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
126126
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
127127
<< "GetRawLattice() with use_final_probs == false";
128128

129-
std::unordered_map<Token*, BaseFloat> *recover_map = NULL;
129+
std::unordered_map<Token*, BaseFloat> recover_map;
130130
if (!decoding_finalized_) {
131-
recover_map = new std::unordered_map<Token*, BaseFloat>();
132131
// Process the non-emitting arcs for the unfinished last frame.
133-
ProcessNonemitting(recover_map);
132+
ProcessNonemitting(&recover_map);
134133
}
135134

136135

@@ -202,9 +201,8 @@ bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
202201
}
203202
}
204203

205-
if (recover_map) { // recover last token list
204+
if (!decoding_finalized_) { // recover last token list
206205
RecoverLastTokenList(recover_map);
207-
delete recover_map;
208206
}
209207
return (ofst->NumStates() > 0);
210208
}
@@ -217,13 +215,13 @@ bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
217215
// will not be affected.
218216
template<typename FST, typename Token>
219217
void LatticeFasterDecoderCombineTpl<FST, Token>::RecoverLastTokenList(
220-
std::unordered_map<Token*, BaseFloat> *recover_map) {
221-
if (recover_map) {
218+
const std::unordered_map<Token*, BaseFloat> &recover_map) {
219+
if (!recover_map.empty()) {
222220
for (Token* tok = active_toks_[active_toks_.size() - 1].toks;
223221
tok != NULL;) {
224-
if (recover_map->find(tok) != recover_map->end()) {
222+
if (recover_map.find(tok) != recover_map.end()) {
225223
DeleteForwardLinks(tok);
226-
tok->tot_cost = (*recover_map)[tok];
224+
tok->tot_cost = recover_map.find(tok)->second;
227225
tok->in_current_queue = false;
228226
tok = tok->next;
229227
} else {
@@ -588,8 +586,8 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ComputeFinalCosts(
588586
BaseFloat best_cost = infinity,
589587
best_cost_with_final = infinity;
590588

591-
// The final tokens are recorded in unordered_map "next_toks_".
592-
for (IterType iter = next_toks_.begin(); iter != next_toks_.end(); iter++) {
589+
// The final tokens are recorded in unordered_map "cur_toks_".
590+
for (IterType iter = cur_toks_.begin(); iter != cur_toks_.end(); iter++) {
593591
StateId state = iter->first;
594592
Token *tok = iter->second;
595593
BaseFloat final_cost = fst_->Final(state).Value();
@@ -658,7 +656,6 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::AdvanceDecoding(
658656
}
659657
ProcessForFrame(decodable);
660658
}
661-
ProcessNonemitting(NULL);
662659
}
663660

664661
// FinalizeDecoding() is a version of PruneActiveTokens that we call
@@ -686,7 +683,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::FinalizeDecoding() {
686683
template <typename FST, typename Token>
687684
BaseFloat LatticeFasterDecoderCombineTpl<FST, Token>::GetCutoff(
688685
const StateIdToTokenMap &toks, BaseFloat *adaptive_beam,
689-
StateId *best_elem_id, Token **best_elem) {
686+
StateId *best_state_id, Token **best_token) {
690687
// positive == high cost == bad.
691688
// best_weight is the minimum value.
692689
BaseFloat best_weight = std::numeric_limits<BaseFloat>::infinity();
@@ -696,9 +693,9 @@ BaseFloat LatticeFasterDecoderCombineTpl<FST, Token>::GetCutoff(
696693
BaseFloat w = static_cast<BaseFloat>(iter->second->tot_cost);
697694
if (w < best_weight) {
698695
best_weight = w;
699-
if (best_elem) {
700-
*best_elem_id = iter->first;
701-
*best_elem = iter->second;
696+
if (best_token) {
697+
*best_state_id = iter->first;
698+
*best_token = iter->second;
702699
}
703700
}
704701
}
@@ -711,9 +708,9 @@ BaseFloat LatticeFasterDecoderCombineTpl<FST, Token>::GetCutoff(
711708
tmp_array_.push_back(w);
712709
if (w < best_weight) {
713710
best_weight = w;
714-
if (best_elem) {
715-
*best_elem_id = iter->first;
716-
*best_elem = iter->second;
711+
if (best_token) {
712+
*best_state_id = iter->first;
713+
*best_token = iter->second;
717714
}
718715
}
719716
}
@@ -722,8 +719,8 @@ BaseFloat LatticeFasterDecoderCombineTpl<FST, Token>::GetCutoff(
722719
min_active_cutoff = std::numeric_limits<BaseFloat>::infinity(),
723720
max_active_cutoff = std::numeric_limits<BaseFloat>::infinity();
724721

725-
KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded()
726-
<< " is " << tmp_array_.size();
722+
KALDI_VLOG(6) << "Number of emitting tokens on frame "
723+
<< NumFramesDecoded() - 1 << " is " << tmp_array_.size();
727724

728725
if (tmp_array_.size() > static_cast<size_t>(config_.max_active)) {
729726
std::nth_element(tmp_array_.begin(),
@@ -766,9 +763,9 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
766763
// from the decodable object.
767764
active_toks_.resize(active_toks_.size() + 1);
768765

766+
prev_toks_.swap(cur_toks_);
769767
cur_toks_.clear();
770-
cur_toks_.swap(next_toks_);
771-
if (cur_toks_.empty()) {
768+
if (prev_toks_.empty()) {
772769
if (!warned_) {
773770
KALDI_WARN << "Error, no surviving tokens on frame " << frame;
774771
warned_ = true;
@@ -780,7 +777,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
780777
StateId best_tok_state_id;
781778
// "cur_cutoff" is used to constrain the epsilon emission in current frame.
782779
// It will not be updated.
783-
BaseFloat cur_cutoff = GetCutoff(cur_toks_, &adaptive_beam,
780+
BaseFloat cur_cutoff = GetCutoff(prev_toks_, &adaptive_beam,
784781
&best_tok_state_id, &best_tok);
785782
KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
786783
<< adaptive_beam;
@@ -801,7 +798,8 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
801798
// Notice: As the difference between the combine version and the traditional
802799
// version, this "best_tok" is chosen from emission tokens. Normally, the
803800
// best token of one frame comes from an epsilon non-emittion. So the best
804-
// token is a looser boundary. Use it to estimate a bound on the next cutoff.
801+
// token is a looser boundary. We use it to estimate a bound on the next
802+
// cutoff and we will update the "next_cutoff" once we have better tokens.
805803
// The "next_cutoff" will be updated in further processing.
806804
if (best_tok) {
807805
cost_offset = - best_tok->tot_cost;
@@ -827,7 +825,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
827825

828826
// Build a queue which contains the emission tokens from previous frame.
829827
std::vector<StateId> cur_queue;
830-
for (IterType iter = cur_toks_.begin(); iter != cur_toks_.end(); iter++) {
828+
for (IterType iter = prev_toks_.begin(); iter != prev_toks_.end(); iter++) {
831829
cur_queue.push_back(iter->first);
832830
iter->second->in_current_queue = true;
833831
}
@@ -837,9 +835,11 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
837835
StateId state = cur_queue.back();
838836
cur_queue.pop_back();
839837

840-
KALDI_ASSERT(cur_toks_.find(state) != cur_toks_.end());
841-
Token *tok = cur_toks_[state];
838+
KALDI_ASSERT(prev_toks_.find(state) != prev_toks_.end());
839+
Token *tok = prev_toks_[state];
840+
842841
BaseFloat cur_cost = tok->tot_cost;
842+
tok->in_current_queue = false; // out of queue
843843
if (cur_cost > cur_cutoff) // Don't bother processing successors.
844844
continue;
845845
// If "tok" has any existing forward links, delete them,
@@ -857,7 +857,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
857857
BaseFloat tot_cost = cur_cost + graph_cost;
858858
if (tot_cost < cur_cutoff) {
859859
Token *new_tok = FindOrAddToken(arc.nextstate, frame, tot_cost,
860-
tok, &cur_toks_, &changed);
860+
tok, &prev_toks_, &changed);
861861

862862
// Add ForwardLink from tok to new_tok. Put it on the head of
863863
// tok->link list
@@ -882,29 +882,29 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
882882

883883
// no change flag is needed
884884
Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
885-
tok, &next_toks_, NULL);
885+
tok, &cur_toks_, NULL);
886886
// Add ForwardLink from tok to next_tok. Put it on the head of tok->link
887887
// list
888888
tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel,
889889
graph_cost, ac_cost, tok->links);
890890
}
891891
} // for all arcs
892-
tok->in_current_queue = false; // out of queue
893892
} // end of while loop
894-
KALDI_VLOG(6) << "toks after: " << cur_toks_.size();
893+
KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() - 1
894+
<< " is " << prev_toks_.size();
895895
}
896896

897897

898898
template <typename FST, typename Token>
899899
void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting(
900900
std::unordered_map<Token*, BaseFloat> *recover_map) {
901901
if (recover_map) { // Build the elements which are used to recover
902-
for (IterType iter = next_toks_.begin(); iter != next_toks_.end(); iter++) {
902+
for (IterType iter = cur_toks_.begin(); iter != cur_toks_.end(); iter++) {
903903
(*recover_map)[iter->second] = iter->second->tot_cost;
904904
}
905905
}
906906

907-
StateIdToTokenMap tmp_toks(next_toks_);
907+
StateIdToTokenMap tmp_toks(cur_toks_);
908908
int32 frame = active_toks_.size() - 1;
909909
// Build the queue to process non-emitting arcs
910910
std::vector<StateId> cur_queue;

src/decoder/lattice-faster-decoder-combine.h

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,9 @@ class LatticeFasterDecoderCombineTpl {
243243
using Weight = typename Arc::Weight;
244244
using ForwardLinkT = decodercombine::ForwardLink<Token>;
245245

246-
using StateIdToTokenMap = typename std::unordered_map<StateId, Token*>;
246+
using StateIdToTokenMap = typename std::unordered_map<StateId, Token*,
247+
std::hash<StateId>, std::equal_to<StateId>,
248+
fst::PoolAllocator<std::pair<const StateId, Token*> > >;
247249
using IterType = typename StateIdToTokenMap::const_iterator;
248250

249251
// Instantiate this class once for each thing you have to decode.
@@ -295,9 +297,10 @@ class LatticeFasterDecoderCombineTpl {
295297
/// of the graph then it will include those as final-probs, else
296298
/// it will treat all final-probs as one.
297299
/// The raw lattice will be topologically sorted.
298-
/// The function can be called during decoding, it will take "next_toks_" map
299-
/// and generate the complete token list for the last frame. Then recover it
300-
/// to ensure the consistency of ProcessForFrame().
300+
/// The function can be called during decoding, it will process non-emitting
301+
/// arcs from "cur_toks_" map to get tokens from both non-emitting and
302+
/// emitting arcs for getting raw lattice. Then recover it to ensure the
303+
/// consistency of ProcessForFrame().
301304
///
302305
/// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
303306
/// which also supports a pruning beam, in case for some reason
@@ -447,15 +450,18 @@ class LatticeFasterDecoderCombineTpl {
447450
void PruneActiveTokens(BaseFloat delta);
448451

449452
/// Processes non-emitting (epsilon) arcs and emitting arcs for one frame
450-
/// together. It takes the emittion tokens in "cur_toks_" from last frame.
451-
/// Generates non-emitting tokens for current frame and emitting tokens for
453+
/// together. It takes the emission tokens in "prev_toks_" from last frame.
454+
/// Generates non-emitting tokens for previous frame and emitting tokens for
452455
/// next frame.
456+
/// Notice: The emitting tokens for the current frame means the tokens that take
457+
/// acoustic scores of the current frame. (i.e. the destinations of emitting
458+
/// arcs.)
453459
void ProcessForFrame(DecodableInterface *decodable);
454460

455461
/// Processes nonemitting (epsilon) arcs for one frame.
456462
/// Calls this function once when all frames were processed.
457463
/// Or calls it in GetRawLattice() to generate the complete token list for
458-
/// the last frame. [Deal With the tokens in map "next_toks_" which would
464+
/// the last frame. [Deal With the tokens in map "cur_toks_" which would
459465
/// only contains emission tokens from previous frame.]
460466
/// If "recover_map" isn't NULL, we build the recover_map which will be used
461467
/// to recover "active_toks_[last_frame]" token list for the last frame.
@@ -466,17 +472,18 @@ class LatticeFasterDecoderCombineTpl {
466472
/// ProcessForFrame(), recover it.
467473
/// Notice: as new token will be added to the head of TokenList, tok->next
468474
/// will not be affacted.
469-
void RecoverLastTokenList(std::unordered_map<Token*, BaseFloat> *recover_map);
475+
void RecoverLastTokenList(
476+
const std::unordered_map<Token*, BaseFloat> &recover_map);
470477

471478

472-
/// The "cur_toks_" and "next_toks_" actually allow us to maintain current
479+
/// The "prev_toks_" and "cur_toks_" actually allow us to maintain current
473480
/// and next frames. They are indexed by StateId. It is indexed by frame-index
474481
/// plus one, where the frame-index is zero-based, as used in decodable object.
475482
/// That is, the emitting probs of frame t are accounted for in tokens at
476483
/// toks_[t+1]. The zeroth frame is for nonemitting transition at the start of
477484
/// the graph.
485+
StateIdToTokenMap prev_toks_;
478486
StateIdToTokenMap cur_toks_;
479-
StateIdToTokenMap next_toks_;
480487

481488
/// Gets the weight cutoff.
482489
/// Notice: In traditional version, the histogram pruning method is applied
@@ -485,7 +492,7 @@ class LatticeFasterDecoderCombineTpl {
485492
/// and min_active values might be narrowed.
486493
BaseFloat GetCutoff(const StateIdToTokenMap& toks,
487494
BaseFloat *adaptive_beam,
488-
StateId *best_elem_id, Token **best_elem);
495+
StateId *best_state_id, Token **best_token);
489496

490497
std::vector<TokenList> active_toks_; // Lists of tokens, indexed by
491498
// frame (members of TokenList are toks, must_prune_forward_links,

0 commit comments

Comments
 (0)