@@ -41,6 +41,8 @@ LatticeFasterDecoderCombineTpl<FST, Token>::LatticeFasterDecoderCombineTpl(
4141 const LatticeFasterDecoderCombineConfig &config, FST *fst):
4242 fst_ (fst), delete_fst_(true ), config_(config), num_toks_(0 ) {
4343 config.Check ();
44+ prev_toks_.reserve (1000 );
45+ cur_toks_.reserve (1000 );
4446}
4547
4648
@@ -53,8 +55,8 @@ LatticeFasterDecoderCombineTpl<FST, Token>::~LatticeFasterDecoderCombineTpl() {
5355template <typename FST, typename Token>
5456void LatticeFasterDecoderCombineTpl<FST, Token>::InitDecoding() {
5557 // clean up from last time:
58+ prev_toks_.clear ();
5659 cur_toks_.clear ();
57- next_toks_.clear ();
5860 cost_offsets_.clear ();
5961 ClearActiveTokens ();
6062
@@ -67,7 +69,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::InitDecoding() {
6769 active_toks_.resize (1 );
6870 Token *start_tok = new Token (0.0 , 0.0 , NULL , NULL , NULL );
6971 active_toks_[0 ].toks = start_tok;
70- next_toks_ [start_state] = start_tok; // initialize current tokens map
72+ cur_toks_ [start_state] = start_tok; // initialize current tokens map
7173 num_toks_++;
7274}
7375
@@ -87,9 +89,7 @@ bool LatticeFasterDecoderCombineTpl<FST, Token>::Decode(DecodableInterface *deco
8789 PruneActiveTokens (config_.lattice_beam * config_.prune_scale );
8890 ProcessForFrame (decodable);
8991 }
90- // Procss non-emitting arcs for the last frame.
91- ProcessNonemitting (NULL );
92-
92+ // A complete token list of the last frame will be generated in FinalizeDecoding()
9393 FinalizeDecoding ();
9494
9595 // Returns true if we have any kind of traceback available (not necessarily
@@ -126,11 +126,10 @@ bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
126126 KALDI_ERR << " You cannot call FinalizeDecoding() and then call "
127127 << " GetRawLattice() with use_final_probs == false" ;
128128
129- std::unordered_map<Token*, BaseFloat> * recover_map = NULL ;
129+ std::unordered_map<Token*, BaseFloat> recover_map;
130130 if (!decoding_finalized_) {
131- recover_map = new std::unordered_map<Token*, BaseFloat>();
132131 // Process the non-emitting arcs for the unfinished last frame.
133- ProcessNonemitting (recover_map);
132+ ProcessNonemitting (& recover_map);
134133 }
135134
136135
@@ -202,9 +201,8 @@ bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
202201 }
203202 }
204203
205- if (recover_map ) { // recover last token list
204+ if (!decoding_finalized_ ) { // recover last token list
206205 RecoverLastTokenList (recover_map);
207- delete recover_map;
208206 }
209207 return (ofst->NumStates () > 0 );
210208}
@@ -217,13 +215,13 @@ bool LatticeFasterDecoderCombineTpl<FST, Token>::GetRawLattice(
217215// will not be affected.
218216template <typename FST, typename Token>
219217void LatticeFasterDecoderCombineTpl<FST, Token>::RecoverLastTokenList(
220- std::unordered_map<Token*, BaseFloat> * recover_map) {
221- if (recover_map) {
218+ const std::unordered_map<Token*, BaseFloat> & recover_map) {
219+ if (! recover_map. empty () ) {
222220 for (Token* tok = active_toks_[active_toks_.size () - 1 ].toks ;
223221 tok != NULL ;) {
224- if (recover_map-> find (tok) != recover_map-> end ()) {
222+ if (recover_map. find (tok) != recover_map. end ()) {
225223 DeleteForwardLinks (tok);
226- tok->tot_cost = (* recover_map)[ tok] ;
224+ tok->tot_cost = recover_map. find ( tok)-> second ;
227225 tok->in_current_queue = false ;
228226 tok = tok->next ;
229227 } else {
@@ -588,8 +586,8 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ComputeFinalCosts(
588586 BaseFloat best_cost = infinity,
589587 best_cost_with_final = infinity;
590588
591- // The final tokens are recorded in unordered_map "next_toks_ ".
592- for (IterType iter = next_toks_ .begin (); iter != next_toks_ .end (); iter++) {
589+ // The final tokens are recorded in unordered_map "cur_toks_ ".
590+ for (IterType iter = cur_toks_ .begin (); iter != cur_toks_ .end (); iter++) {
593591 StateId state = iter->first ;
594592 Token *tok = iter->second ;
595593 BaseFloat final_cost = fst_->Final (state).Value ();
@@ -658,7 +656,6 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::AdvanceDecoding(
658656 }
659657 ProcessForFrame (decodable);
660658 }
661- ProcessNonemitting (NULL );
662659}
663660
664661// FinalizeDecoding() is a version of PruneActiveTokens that we call
@@ -686,7 +683,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::FinalizeDecoding() {
686683template <typename FST, typename Token>
687684BaseFloat LatticeFasterDecoderCombineTpl<FST, Token>::GetCutoff(
688685 const StateIdToTokenMap &toks, BaseFloat *adaptive_beam,
689- StateId *best_elem_id , Token **best_elem ) {
686+ StateId *best_state_id , Token **best_token ) {
690687 // positive == high cost == bad.
691688 // best_weight is the minimum value.
692689 BaseFloat best_weight = std::numeric_limits<BaseFloat>::infinity ();
@@ -696,9 +693,9 @@ BaseFloat LatticeFasterDecoderCombineTpl<FST, Token>::GetCutoff(
696693 BaseFloat w = static_cast <BaseFloat>(iter->second ->tot_cost );
697694 if (w < best_weight) {
698695 best_weight = w;
699- if (best_elem ) {
700- *best_elem_id = iter->first ;
701- *best_elem = iter->second ;
696+ if (best_token ) {
697+ *best_state_id = iter->first ;
698+ *best_token = iter->second ;
702699 }
703700 }
704701 }
@@ -711,9 +708,9 @@ BaseFloat LatticeFasterDecoderCombineTpl<FST, Token>::GetCutoff(
711708 tmp_array_.push_back (w);
712709 if (w < best_weight) {
713710 best_weight = w;
714- if (best_elem ) {
715- *best_elem_id = iter->first ;
716- *best_elem = iter->second ;
711+ if (best_token ) {
712+ *best_state_id = iter->first ;
713+ *best_token = iter->second ;
717714 }
718715 }
719716 }
@@ -722,8 +719,8 @@ BaseFloat LatticeFasterDecoderCombineTpl<FST, Token>::GetCutoff(
722719 min_active_cutoff = std::numeric_limits<BaseFloat>::infinity (),
723720 max_active_cutoff = std::numeric_limits<BaseFloat>::infinity ();
724721
725- KALDI_VLOG (6 ) << " Number of tokens active on frame " << NumFramesDecoded ()
726- << " is " << tmp_array_.size ();
722+ KALDI_VLOG (6 ) << " Number of emitting tokens on frame "
723+ << NumFramesDecoded () - 1 << " is " << tmp_array_.size ();
727724
728725 if (tmp_array_.size () > static_cast <size_t >(config_.max_active )) {
729726 std::nth_element (tmp_array_.begin (),
@@ -766,9 +763,9 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
766763 // from the decodable object.
767764 active_toks_.resize (active_toks_.size () + 1 );
768765
766+ prev_toks_.swap (cur_toks_);
769767 cur_toks_.clear ();
770- cur_toks_.swap (next_toks_);
771- if (cur_toks_.empty ()) {
768+ if (prev_toks_.empty ()) {
772769 if (!warned_) {
773770 KALDI_WARN << " Error, no surviving tokens on frame " << frame;
774771 warned_ = true ;
@@ -780,7 +777,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
780777 StateId best_tok_state_id;
781778 // "cur_cutoff" is used to constrain the epsilon emission in current frame.
782779 // It will not be updated.
783- BaseFloat cur_cutoff = GetCutoff (cur_toks_ , &adaptive_beam,
780+ BaseFloat cur_cutoff = GetCutoff (prev_toks_ , &adaptive_beam,
784781 &best_tok_state_id, &best_tok);
785782 KALDI_VLOG (6 ) << " Adaptive beam on frame " << NumFramesDecoded () << " is "
786783 << adaptive_beam;
@@ -801,7 +798,8 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
801798 // Notice: As the difference between the combine version and the traditional
802799 // version, this "best_tok" is chosen from emission tokens. Normally, the
803800 // best token of one frame comes from an epsilon non-emission. So the best
804- // token is a looser boundary. Use it to estimate a bound on the next cutoff.
801+ // token is a looser boundary. We use it to estimate a bound on the next
802+ // cutoff and we will update the "next_cutoff" once we have better tokens.
805803 // The "next_cutoff" will be updated in further processing.
806804 if (best_tok) {
807805 cost_offset = - best_tok->tot_cost ;
@@ -827,7 +825,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
827825
828826 // Build a queue which contains the emission tokens from previous frame.
829827 std::vector<StateId> cur_queue;
830- for (IterType iter = cur_toks_ .begin (); iter != cur_toks_ .end (); iter++) {
828+ for (IterType iter = prev_toks_ .begin (); iter != prev_toks_ .end (); iter++) {
831829 cur_queue.push_back (iter->first );
832830 iter->second ->in_current_queue = true ;
833831 }
@@ -837,9 +835,11 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
837835 StateId state = cur_queue.back ();
838836 cur_queue.pop_back ();
839837
840- KALDI_ASSERT (cur_toks_.find (state) != cur_toks_.end ());
841- Token *tok = cur_toks_[state];
838+ KALDI_ASSERT (prev_toks_.find (state) != prev_toks_.end ());
839+ Token *tok = prev_toks_[state];
840+
842841 BaseFloat cur_cost = tok->tot_cost ;
842+ tok->in_current_queue = false ; // out of queue
843843 if (cur_cost > cur_cutoff) // Don't bother processing successors.
844844 continue ;
845845 // If "tok" has any existing forward links, delete them,
@@ -857,7 +857,7 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
857857 BaseFloat tot_cost = cur_cost + graph_cost;
858858 if (tot_cost < cur_cutoff) {
859859 Token *new_tok = FindOrAddToken (arc.nextstate , frame, tot_cost,
860- tok, &cur_toks_ , &changed);
860+ tok, &prev_toks_ , &changed);
861861
862862 // Add ForwardLink from tok to new_tok. Put it on the head of
863863 // tok->link list
@@ -882,29 +882,29 @@ void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessForFrame(
882882
883883 // no change flag is needed
884884 Token *next_tok = FindOrAddToken (arc.nextstate , frame + 1 , tot_cost,
885- tok, &next_toks_ , NULL );
885+ tok, &cur_toks_ , NULL );
886886 // Add ForwardLink from tok to next_tok. Put it on the head of tok->link
887887 // list
888888 tok->links = new ForwardLinkT (next_tok, arc.ilabel , arc.olabel ,
889889 graph_cost, ac_cost, tok->links );
890890 }
891891 } // for all arcs
892- tok->in_current_queue = false ; // out of queue
893892 } // end of while loop
894- KALDI_VLOG (6 ) << " toks after: " << cur_toks_.size ();
893+ KALDI_VLOG (6 ) << " Number of tokens active on frame " << NumFramesDecoded () - 1
894+ << " is " << prev_toks_.size ();
895895}
896896
897897
898898template <typename FST, typename Token>
899899void LatticeFasterDecoderCombineTpl<FST, Token>::ProcessNonemitting(
900900 std::unordered_map<Token*, BaseFloat> *recover_map) {
901901 if (recover_map) { // Build the elements which are used to recover
902- for (IterType iter = next_toks_ .begin (); iter != next_toks_ .end (); iter++) {
902+ for (IterType iter = cur_toks_ .begin (); iter != cur_toks_ .end (); iter++) {
903903 (*recover_map)[iter->second ] = iter->second ->tot_cost ;
904904 }
905905 }
906906
907- StateIdToTokenMap tmp_toks (next_toks_ );
907+ StateIdToTokenMap tmp_toks (cur_toks_ );
908908 int32 frame = active_toks_.size () - 1 ;
909909 // Build the queue to process non-emitting arcs
910910 std::vector<StateId> cur_queue;
0 commit comments