tesseract 3.03
wordrec/language_model.cpp

// File:        language_model.cpp
// Description: Functions that utilize the knowledge about the properties,
//              structure and statistics of the language to help recognition.
// Author:      Daria Antonova
// Created:     Mon Nov 11 11:26:43 PST 2009
//
// (C) Copyright 2009, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

#include <math.h>

#include "language_model.h"

#include "dawg.h"
#include "freelist.h"
#include "intproto.h"
#include "helpers.h"
#include "lm_state.h"
#include "lm_pain_points.h"
#include "matrix.h"
#include "params.h"
#include "params_training_featdef.h"

#if defined(_MSC_VER) || defined(ANDROID)
double log2(double n) {
  return log(n) / log(2.0);
}
#endif  // _MSC_VER || ANDROID
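// Note: the shim above is a minimal fallback for toolchains without log2();
// it uses the identity log2(n) = ln(n) / ln(2), e.g. log2(8.0) == 3.0.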

namespace tesseract {

const float LanguageModel::kMaxAvgNgramCost = 25.0f;

LanguageModel::LanguageModel(const UnicityTable<FontInfo> *fontinfo_table,
                             Dict *dict)
  : INT_MEMBER(language_model_debug_level, 0, "Language model debug level",
               dict->getCCUtil()->params()),
    BOOL_INIT_MEMBER(language_model_ngram_on, false,
                     "Turn on/off the use of character ngram model",
                     dict->getCCUtil()->params()),
    INT_MEMBER(language_model_ngram_order, 8,
               "Maximum order of the character ngram model",
               dict->getCCUtil()->params()),
    INT_MEMBER(language_model_viterbi_list_max_num_prunable, 10,
               "Maximum number of prunable (those for which"
               " PrunablePath() is true) entries in each viterbi list"
               " recorded in BLOB_CHOICEs",
               dict->getCCUtil()->params()),
    INT_MEMBER(language_model_viterbi_list_max_size, 500,
               "Maximum size of viterbi lists recorded in BLOB_CHOICEs",
               dict->getCCUtil()->params()),
    double_MEMBER(language_model_ngram_small_prob, 0.000001,
                  "To avoid overly small denominators use this as the "
                  "floor of the probability returned by the ngram model.",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_ngram_nonmatch_score, -40.0,
                  "Average classifier score of a non-matching unichar.",
                  dict->getCCUtil()->params()),
    BOOL_MEMBER(language_model_ngram_use_only_first_uft8_step, false,
                "Use only the first UTF8 step of the given string"
                " when computing log probabilities.",
                dict->getCCUtil()->params()),
    double_MEMBER(language_model_ngram_scale_factor, 0.03,
                  "Strength of the character ngram model relative to the"
                  " character classifier ",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_ngram_rating_factor, 16.0,
                  "Factor to bring log-probs into the same range as ratings"
                  " when multiplied by outline length ",
                  dict->getCCUtil()->params()),
    BOOL_MEMBER(language_model_ngram_space_delimited_language, true,
                "Words are delimited by space",
                dict->getCCUtil()->params()),
    INT_MEMBER(language_model_min_compound_length, 3,
               "Minimum length of compound words",
               dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_non_freq_dict_word, 0.1,
                  "Penalty for words not in the frequent word dictionary",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_non_dict_word, 0.15,
                  "Penalty for non-dictionary words",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_punc, 0.2,
                  "Penalty for inconsistent punctuation",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_case, 0.1,
                  "Penalty for inconsistent case",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_script, 0.5,
                  "Penalty for inconsistent script",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_chartype, 0.3,
                  "Penalty for inconsistent character type",
                  dict->getCCUtil()->params()),
    // TODO(daria, rays): enable font consistency checking
    // after improving font analysis.
    double_MEMBER(language_model_penalty_font, 0.00,
                  "Penalty for inconsistent font",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_spacing, 0.05,
                  "Penalty for inconsistent spacing",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_increment, 0.01,
                  "Penalty increment",
                  dict->getCCUtil()->params()),
    INT_MEMBER(wordrec_display_segmentations, 0, "Display Segmentations",
               dict->getCCUtil()->params()),
    BOOL_INIT_MEMBER(language_model_use_sigmoidal_certainty, false,
                     "Use sigmoidal score for certainty",
                     dict->getCCUtil()->params()),
  fontinfo_table_(fontinfo_table), dict_(dict),
  fixed_pitch_(false), max_char_wh_ratio_(0.0),
  acceptable_choice_found_(false) {
  ASSERT_HOST(dict_ != NULL);
  dawg_args_ = new DawgArgs(NULL, new DawgPositionVector(), NO_PERM);
  very_beginning_active_dawgs_ = new DawgPositionVector();
  beginning_active_dawgs_ = new DawgPositionVector();
}

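// Ownership note (inferred from the constructor and GenerateDawgInfo): the
// two beginning vectors, dawg_args_ and dawg_args_->updated_dawgs are owned
// by this class, while dawg_args_->active_dawgs only points at vectors owned
// elsewhere, so the destructor deliberately does not delete it.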
LanguageModel::~LanguageModel() {
  delete very_beginning_active_dawgs_;
  delete beginning_active_dawgs_;
  delete dawg_args_->updated_dawgs;
  delete dawg_args_;
}

void LanguageModel::InitForWord(const WERD_CHOICE *prev_word,
                                bool fixed_pitch, float max_char_wh_ratio,
                                float rating_cert_scale) {
  fixed_pitch_ = fixed_pitch;
  max_char_wh_ratio_ = max_char_wh_ratio;
  rating_cert_scale_ = rating_cert_scale;
  acceptable_choice_found_ = false;
  correct_segmentation_explored_ = false;

  // Initialize vectors with beginning DawgInfos.
  very_beginning_active_dawgs_->clear();
  dict_->init_active_dawgs(very_beginning_active_dawgs_, false);
  beginning_active_dawgs_->clear();
  dict_->default_dawgs(beginning_active_dawgs_, false);

  // Fill prev_word_str_ with the last language_model_ngram_order
  // unichars from prev_word.
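  // For example (hypothetical input): prev_word "Hello" in a space-delimited
  // language yields prev_word_str_ "Hello " and a
  // prev_word_unichar_step_len_ of 6.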
  if (language_model_ngram_on) {
    if (prev_word != NULL && prev_word->unichar_string() != NULL) {
      prev_word_str_ = prev_word->unichar_string();
      if (language_model_ngram_space_delimited_language) prev_word_str_ += ' ';
    } else {
      prev_word_str_ = " ";
    }
    const char *str_ptr = prev_word_str_.string();
    const char *str_end = str_ptr + prev_word_str_.length();
    int step;
    prev_word_unichar_step_len_ = 0;
    while (str_ptr != str_end && (step = UNICHAR::utf8_step(str_ptr))) {
      str_ptr += step;
      ++prev_word_unichar_step_len_;
    }
    ASSERT_HOST(str_ptr == str_end);
  }
}

// Helper scans the collection of predecessors for competing siblings that
// have the same letter with the opposite case, setting competing_vse.
static void ScanParentsForCaseMix(const UNICHARSET& unicharset,
                                  LanguageModelState* parent_node) {
  if (parent_node == NULL) return;
  ViterbiStateEntry_IT vit(&parent_node->viterbi_state_entries);
  for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
    ViterbiStateEntry* vse = vit.data();
    vse->competing_vse = NULL;
    UNICHAR_ID unichar_id = vse->curr_b->unichar_id();
    if (unicharset.get_isupper(unichar_id) ||
        unicharset.get_islower(unichar_id)) {
      UNICHAR_ID other_case = unicharset.get_other_case(unichar_id);
      if (other_case == unichar_id) continue;  // Not in unicharset.
      // Find other case in same list. There could be multiple entries with
      // the same unichar_id, but in theory, they should all point to the
      // same BLOB_CHOICE, and that is what we will be using to decide
      // which to keep.
      ViterbiStateEntry_IT vit2(&parent_node->viterbi_state_entries);
      for (vit2.mark_cycle_pt(); !vit2.cycled_list() &&
           vit2.data()->curr_b->unichar_id() != other_case;
           vit2.forward()) {}
      if (!vit2.cycled_list()) {
        vse->competing_vse = vit2.data();
      }
    }
  }
}

// Helper returns true if the given choice has a better case variant before
// it in the choice_list that is not distinguishable by size.
static bool HasBetterCaseVariant(const UNICHARSET& unicharset,
                                 const BLOB_CHOICE* choice,
                                 BLOB_CHOICE_LIST* choices) {
  UNICHAR_ID choice_id = choice->unichar_id();
  UNICHAR_ID other_case = unicharset.get_other_case(choice_id);
  if (other_case == choice_id || other_case == INVALID_UNICHAR_ID)
    return false;  // Not upper or lower or not in unicharset.
  if (unicharset.SizesDistinct(choice_id, other_case))
    return false;  // Can be separated by size.
  BLOB_CHOICE_IT bc_it(choices);
  for (bc_it.mark_cycle_pt(); !bc_it.cycled_list(); bc_it.forward()) {
    BLOB_CHOICE* better_choice = bc_it.data();
    if (better_choice->unichar_id() == other_case)
      return true;  // Found an earlier instance of other_case.
    else if (better_choice == choice)
      return false;  // Reached the original choice.
  }
  return false;  // Should never happen, but just in case.
}

// UpdateState has the job of combining the ViterbiStateEntry lists on each
// of the choices on parent_list with each of the blob choices in curr_list,
// making a new ViterbiStateEntry for each sensible path.
// This could be a huge set of combinations, creating a lot of work only to
// be truncated by some beam limit, but only certain kinds of paths will
// continue at the next step:
//  paths that are liked by an active language model component: either a
//    DAWG or the character n-gram model.
//  paths that represent some kind of top choice. The old permuter permuted
//   the top raw classifier score, the top upper case word and the top lower-
//   case word. UpdateState now concentrates its top-choice paths on top
//   lower-case, top upper-case (or caseless alpha), and top digit sequence,
//   with allowance for continuation of these paths through blobs where such
//   a character does not appear in the choices list.
// GetNextParentVSE enforces some of these models to minimize the number of
// calls to AddViterbiStateEntry, even prior to looking at the language model.
// Thus an n-blob sequence of [l1I] will produce 3n calls to
// AddViterbiStateEntry instead of 3^n.
// Of course it isn't quite that simple as Title Case is handled by allowing
// lower case to continue an upper case initial, but it has to be detected
// in the combiner so it knows which upper case letters are initial alphas.
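// Worked example (hypothetical): with 4 blobs whose choice lists each
// contain {l, 1, I}, tracking only the top lower, top upper and top digit
// path per step costs about 3 * 4 = 12 AddViterbiStateEntry calls, versus
// 3^4 = 81 for the full cross-product of parent and child choices.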
bool LanguageModel::UpdateState(
    bool just_classified,
    int curr_col, int curr_row,
    BLOB_CHOICE_LIST *curr_list,
    LanguageModelState *parent_node,
    LMPainPoints *pain_points,
    WERD_RES *word_res,
    BestChoiceBundle *best_choice_bundle,
    BlamerBundle *blamer_bundle) {
  if (language_model_debug_level > 0) {
    tprintf("\nUpdateState: col=%d row=%d %s",
            curr_col, curr_row, just_classified ? "just_classified" : "");
    if (language_model_debug_level > 5)
      tprintf("(parent=%p)\n", parent_node);
    else
      tprintf("\n");
  }
  // Initialize helper variables.
  bool word_end = (curr_row+1 >= word_res->ratings->dimension());
  bool new_changed = false;
  float denom = (language_model_ngram_on) ? ComputeDenom(curr_list) : 1.0f;
  const UNICHARSET& unicharset = dict_->getUnicharset();
  BLOB_CHOICE *first_lower = NULL;
  BLOB_CHOICE *first_upper = NULL;
  BLOB_CHOICE *first_digit = NULL;
  bool has_alnum_mix = false;
  if (parent_node != NULL) {
    int result = SetTopParentLowerUpperDigit(parent_node);
    if (result < 0) {
      if (language_model_debug_level > 0)
        tprintf("No parents found to process\n");
      return false;
    }
    if (result > 0)
      has_alnum_mix = true;
  }
  if (!GetTopLowerUpperDigit(curr_list, &first_lower, &first_upper,
                             &first_digit))
    has_alnum_mix = false;
  ScanParentsForCaseMix(unicharset, parent_node);
  if (language_model_debug_level > 3 && parent_node != NULL) {
    parent_node->Print("Parent viterbi list");
  }
  LanguageModelState *curr_state = best_choice_bundle->beam[curr_row];

  // Call AddViterbiStateEntry() for each parent+child ViterbiStateEntry.
  ViterbiStateEntry_IT vit;
  BLOB_CHOICE_IT c_it(curr_list);
  for (c_it.mark_cycle_pt(); !c_it.cycled_list(); c_it.forward()) {
    BLOB_CHOICE* choice = c_it.data();
    // TODO(antonova): make sure commenting this out is ok for ngram
    // model scoring (I think this was introduced to fix ngram model quirks).
    // Skip NULL unichars unless it is the only choice.
    //if (!curr_list->singleton() && c_it.data()->unichar_id() == 0) continue;
    UNICHAR_ID unichar_id = choice->unichar_id();
    if (unicharset.get_fragment(unichar_id)) {
      continue;  // skip fragments
    }
    // Set top choice flags.
    LanguageModelFlagsType blob_choice_flags = kXhtConsistentFlag;
    if (c_it.at_first() || !new_changed)
      blob_choice_flags |= kSmallestRatingFlag;
    if (first_lower == choice) blob_choice_flags |= kLowerCaseFlag;
    if (first_upper == choice) blob_choice_flags |= kUpperCaseFlag;
    if (first_digit == choice) blob_choice_flags |= kDigitFlag;

    if (parent_node == NULL) {
      // Process the beginning of a word.
      // If there is a better case variant that is not distinguished by size,
      // skip this blob choice, as we have no choice but to accept the result
      // of the character classifier to distinguish between them, even if
      // followed by an upper case.
      // With words like iPoc, and other CamelBackWords, the lower-upper
      // transition can only be achieved if the classifier has the correct case
      // as the top choice, and leaving an initial I lower down the list
      // increases the chances of choosing IPoc simply because it doesn't
      // include such a transition. iPoc will beat iPOC and ipoc because
      // the other words are baseline/x-height inconsistent.
      if (HasBetterCaseVariant(unicharset, choice, curr_list))
        continue;
      // Upper counts as lower at the beginning of a word.
      if (blob_choice_flags & kUpperCaseFlag)
        blob_choice_flags |= kLowerCaseFlag;
      new_changed |= AddViterbiStateEntry(
          blob_choice_flags, denom, word_end, curr_col, curr_row,
          choice, curr_state, NULL, pain_points,
          word_res, best_choice_bundle, blamer_bundle);
    } else {
      // Get viterbi entries from each parent ViterbiStateEntry.
      vit.set_to_list(&parent_node->viterbi_state_entries);
      int vit_counter = 0;
      vit.mark_cycle_pt();
      ViterbiStateEntry* parent_vse = NULL;
      LanguageModelFlagsType top_choice_flags;
      while ((parent_vse = GetNextParentVSE(just_classified, has_alnum_mix,
                                            c_it.data(), blob_choice_flags,
                                            unicharset, word_res, &vit,
                                            &top_choice_flags)) != NULL) {
        // Skip pruned entries and do not look at prunable entries if already
        // examined language_model_viterbi_list_max_num_prunable of those.
        if (PrunablePath(*parent_vse) &&
            (++vit_counter > language_model_viterbi_list_max_num_prunable ||
             (language_model_ngram_on && parent_vse->ngram_info->pruned))) {
          continue;
        }
        // If the parent has no alnum choice, (ie choice is the first in a
        // string of alnum), and there is a better case variant that is not
        // distinguished by size, skip this blob choice/parent, as with the
        // initial blob treatment above.
        if (!parent_vse->HasAlnumChoice(unicharset) &&
            HasBetterCaseVariant(unicharset, choice, curr_list))
          continue;
        // Create a new ViterbiStateEntry if BLOB_CHOICE in c_it.data()
        // looks good according to the Dawgs or character ngram model.
        new_changed |= AddViterbiStateEntry(
            top_choice_flags, denom, word_end, curr_col, curr_row,
            c_it.data(), curr_state, parent_vse, pain_points,
            word_res, best_choice_bundle, blamer_bundle);
      }
    }
  }
  return new_changed;
}

// Finds the first lower and upper case letter and first digit in curr_list.
// For non-upper/lower languages, alpha counts as upper.
// Uses the first character in the list in place of empty results.
// Returns true if both alpha and digits are found.
bool LanguageModel::GetTopLowerUpperDigit(BLOB_CHOICE_LIST *curr_list,
                                          BLOB_CHOICE **first_lower,
                                          BLOB_CHOICE **first_upper,
                                          BLOB_CHOICE **first_digit) const {
  BLOB_CHOICE_IT c_it(curr_list);
  const UNICHARSET &unicharset = dict_->getUnicharset();
  BLOB_CHOICE *first_unichar = NULL;
  for (c_it.mark_cycle_pt(); !c_it.cycled_list(); c_it.forward()) {
    UNICHAR_ID unichar_id = c_it.data()->unichar_id();
    if (unicharset.get_fragment(unichar_id)) continue;  // skip fragments
    if (first_unichar == NULL) first_unichar = c_it.data();
    if (*first_lower == NULL && unicharset.get_islower(unichar_id)) {
      *first_lower = c_it.data();
    }
    if (*first_upper == NULL && unicharset.get_isalpha(unichar_id) &&
        !unicharset.get_islower(unichar_id)) {
      *first_upper = c_it.data();
    }
    if (*first_digit == NULL && unicharset.get_isdigit(unichar_id)) {
      *first_digit = c_it.data();
    }
  }
  ASSERT_HOST(first_unichar != NULL);
  bool mixed = (*first_lower != NULL || *first_upper != NULL) &&
      *first_digit != NULL;
  if (*first_lower == NULL) *first_lower = first_unichar;
  if (*first_upper == NULL) *first_upper = first_unichar;
  if (*first_digit == NULL) *first_digit = first_unichar;
  return mixed;
}

// Forces there to be at least one entry in the overall set of the
// viterbi_state_entries of each element of parent_node that has the
// top_choice_flag set for lower, upper and digit using the same rules as
// GetTopLowerUpperDigit, setting the flag on the first found suitable
// candidate, whether or not the flag is set on some other parent.
// Returns 1 if both alpha and digits are found among the parents, -1 if no
// parents are found at all (a legitimate case), and 0 otherwise.
int LanguageModel::SetTopParentLowerUpperDigit(
    LanguageModelState *parent_node) const {
  if (parent_node == NULL) return -1;
  UNICHAR_ID top_id = INVALID_UNICHAR_ID;
  ViterbiStateEntry* top_lower = NULL;
  ViterbiStateEntry* top_upper = NULL;
  ViterbiStateEntry* top_digit = NULL;
  ViterbiStateEntry* top_choice = NULL;
  float lower_rating = 0.0f;
  float upper_rating = 0.0f;
  float digit_rating = 0.0f;
  float top_rating = 0.0f;
  const UNICHARSET &unicharset = dict_->getUnicharset();
  ViterbiStateEntry_IT vit(&parent_node->viterbi_state_entries);
  for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
    ViterbiStateEntry* vse = vit.data();
    // INVALID_UNICHAR_ID should be treated like a zero-width joiner, so scan
    // back to the real character if needed.
    ViterbiStateEntry* unichar_vse = vse;
    UNICHAR_ID unichar_id = unichar_vse->curr_b->unichar_id();
    float rating = unichar_vse->curr_b->rating();
    while (unichar_id == INVALID_UNICHAR_ID &&
           unichar_vse->parent_vse != NULL) {
      unichar_vse = unichar_vse->parent_vse;
      unichar_id = unichar_vse->curr_b->unichar_id();
      rating = unichar_vse->curr_b->rating();
    }
    if (unichar_id != INVALID_UNICHAR_ID) {
      if (unicharset.get_islower(unichar_id)) {
        if (top_lower == NULL || lower_rating > rating) {
          top_lower = vse;
          lower_rating = rating;
        }
      } else if (unicharset.get_isalpha(unichar_id)) {
        if (top_upper == NULL || upper_rating > rating) {
          top_upper = vse;
          upper_rating = rating;
        }
      } else if (unicharset.get_isdigit(unichar_id)) {
        if (top_digit == NULL || digit_rating > rating) {
          top_digit = vse;
          digit_rating = rating;
        }
      }
    }
    if (top_choice == NULL || top_rating > rating) {
      top_choice = vse;
      top_rating = rating;
      top_id = unichar_id;
    }
  }
  if (top_choice == NULL) return -1;
  bool mixed = (top_lower != NULL || top_upper != NULL) &&
      top_digit != NULL;
  if (top_lower == NULL) top_lower = top_choice;
  top_lower->top_choice_flags |= kLowerCaseFlag;
  if (top_upper == NULL) top_upper = top_choice;
  top_upper->top_choice_flags |= kUpperCaseFlag;
  if (top_digit == NULL) top_digit = top_choice;
  top_digit->top_choice_flags |= kDigitFlag;
  top_choice->top_choice_flags |= kSmallestRatingFlag;
  if (top_id != INVALID_UNICHAR_ID && dict_->compound_marker(top_id) &&
      (top_choice->top_choice_flags &
          (kLowerCaseFlag | kUpperCaseFlag | kDigitFlag))) {
    // If the compound marker top choice carries any of the top alnum flags,
    // then give it all of them, allowing words like I-295 to be chosen.
    top_choice->top_choice_flags |=
        kLowerCaseFlag | kUpperCaseFlag | kDigitFlag;
  }
  return mixed ? 1 : 0;
}

// Finds the next ViterbiStateEntry with which the given unichar_id can
// combine sensibly, taking into account any mixed alnum/mixed case
// situation, and whether this combination has been inspected before.
ViterbiStateEntry* LanguageModel::GetNextParentVSE(
    bool just_classified, bool mixed_alnum, const BLOB_CHOICE* bc,
    LanguageModelFlagsType blob_choice_flags, const UNICHARSET& unicharset,
    WERD_RES* word_res, ViterbiStateEntry_IT* vse_it,
    LanguageModelFlagsType* top_choice_flags) const {
  for (; !vse_it->cycled_list(); vse_it->forward()) {
    ViterbiStateEntry* parent_vse = vse_it->data();
    // Only consider the parent if it has been updated or
    // if the current ratings cell has just been classified.
    if (!just_classified && !parent_vse->updated) continue;
    if (language_model_debug_level > 2)
      parent_vse->Print("Considering");
    // If the parent is non-alnum, then upper counts as lower.
    *top_choice_flags = blob_choice_flags;
    if ((blob_choice_flags & kUpperCaseFlag) &&
        !parent_vse->HasAlnumChoice(unicharset)) {
      *top_choice_flags |= kLowerCaseFlag;
    }
    *top_choice_flags &= parent_vse->top_choice_flags;
    UNICHAR_ID unichar_id = bc->unichar_id();
    const BLOB_CHOICE* parent_b = parent_vse->curr_b;
    UNICHAR_ID parent_id = parent_b->unichar_id();
    // Digits do not bind to alphas if there is a mix in both parent and
    // current or if the alpha is not the top choice.
    if (unicharset.get_isdigit(unichar_id) &&
        unicharset.get_isalpha(parent_id) &&
        (mixed_alnum || *top_choice_flags == 0))
      continue;  // Digits don't bind to alphas.
    // Likewise alphas do not bind to digits if there is a mix in both or if
    // the digit is not the top choice.
    if (unicharset.get_isalpha(unichar_id) &&
        unicharset.get_isdigit(parent_id) &&
        (mixed_alnum || *top_choice_flags == 0))
      continue;  // Alphas don't bind to digits.
    // If there is a case mix of the same alpha in the parent list, then
    // competing_vse is non-null and will be used to determine whether
    // or not to bind the current blob choice.
    if (parent_vse->competing_vse != NULL) {
      const BLOB_CHOICE* competing_b = parent_vse->competing_vse->curr_b;
      UNICHAR_ID other_id = competing_b->unichar_id();
      if (language_model_debug_level >= 5) {
        tprintf("Parent %s has competition %s\n",
                unicharset.id_to_unichar(parent_id),
                unicharset.id_to_unichar(other_id));
      }
      if (unicharset.SizesDistinct(parent_id, other_id)) {
        // If other_id matches bc wrt position and size, and parent_id
        // doesn't, don't bind to the current parent.
        if (bc->PosAndSizeAgree(*competing_b, word_res->x_height,
                                language_model_debug_level >= 5) &&
            !bc->PosAndSizeAgree(*parent_b, word_res->x_height,
                                language_model_debug_level >= 5))
          continue;  // Competing blobchoice has a better vertical match.
      }
    }
    vse_it->forward();
    return parent_vse;  // This one is good!
  }
  return NULL;  // Ran out of possibilities.
}

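// Attempts to add a new ViterbiStateEntry for blob choice b to curr_state.
// The flow, summarized from the body below: consult the Dawg and n-gram
// components, check x-height and other path consistency, compute the
// associate (shape) cost and adjusted path cost, then keep the entry only if
// some component liked it and it survives the prunable-list limits.
// Returns true if an entry was actually added.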
bool LanguageModel::AddViterbiStateEntry(
    LanguageModelFlagsType top_choice_flags,
    float denom,
    bool word_end,
    int curr_col, int curr_row,
    BLOB_CHOICE *b,
    LanguageModelState *curr_state,
    ViterbiStateEntry *parent_vse,
    LMPainPoints *pain_points,
    WERD_RES *word_res,
    BestChoiceBundle *best_choice_bundle,
    BlamerBundle *blamer_bundle) {
  ViterbiStateEntry_IT vit;
  if (language_model_debug_level > 1) {
    tprintf("AddViterbiStateEntry for unichar %s rating=%.4f"
            " certainty=%.4f top_choice_flags=0x%x",
            dict_->getUnicharset().id_to_unichar(b->unichar_id()),
            b->rating(), b->certainty(), top_choice_flags);
    if (language_model_debug_level > 5)
      tprintf(" parent_vse=%p\n", parent_vse);
    else
      tprintf("\n");
  }
  // Check whether the list is full.
  if (curr_state != NULL &&
      curr_state->viterbi_state_entries_length >=
          language_model_viterbi_list_max_size) {
    if (language_model_debug_level > 1) {
      tprintf("AddViterbiStateEntry: viterbi list is full!\n");
    }
    return false;
  }

  // Invoke Dawg language model component.
  LanguageModelDawgInfo *dawg_info =
    GenerateDawgInfo(word_end, curr_col, curr_row, *b, parent_vse);

  float outline_length =
      AssociateUtils::ComputeOutlineLength(rating_cert_scale_, *b);
  // Invoke Ngram language model component.
  LanguageModelNgramInfo *ngram_info = NULL;
  if (language_model_ngram_on) {
    ngram_info = GenerateNgramInfo(
        dict_->getUnicharset().id_to_unichar(b->unichar_id()), b->certainty(),
        denom, curr_col, curr_row, outline_length, parent_vse);
    ASSERT_HOST(ngram_info != NULL);
  }
  bool liked_by_language_model = dawg_info != NULL ||
      (ngram_info != NULL && !ngram_info->pruned);
  // Quick escape if not liked by the language model, can't be consistent
  // xheight, and not top choice.
  if (!liked_by_language_model && top_choice_flags == 0) {
    if (language_model_debug_level > 1) {
      tprintf("Language model components very early pruned this entry\n");
    }
    delete ngram_info;
    delete dawg_info;
    return false;
  }

  // Check consistency of the path and set the relevant consistency_info.
  LMConsistencyInfo consistency_info(
    parent_vse != NULL ? &parent_vse->consistency_info : NULL);
  // Start with just the x-height consistency, as it provides significant
  // pruning opportunity.
  consistency_info.ComputeXheightConsistency(
      b, dict_->getUnicharset().get_ispunctuation(b->unichar_id()));
  // Turn off xheight consistent flag if not consistent.
  if (consistency_info.InconsistentXHeight()) {
    top_choice_flags &= ~kXhtConsistentFlag;
  }

  // Quick escape if not liked by the language model, not consistent xheight,
  // and not top choice.
  if (!liked_by_language_model && top_choice_flags == 0) {
    if (language_model_debug_level > 1) {
      tprintf("Language model components early pruned this entry\n");
    }
    delete ngram_info;
    delete dawg_info;
    return false;
  }

  // Compute the rest of the consistency info.
  FillConsistencyInfo(curr_col, word_end, b, parent_vse,
                      word_res, &consistency_info);
  if (dawg_info != NULL && consistency_info.invalid_punc) {
    consistency_info.invalid_punc = false;  // do not penalize dict words
  }

  // Compute cost of associating the blobs that represent the current unichar.
  AssociateStats associate_stats;
  ComputeAssociateStats(curr_col, curr_row, max_char_wh_ratio_,
                        parent_vse, word_res, &associate_stats);
  if (parent_vse != NULL) {
    associate_stats.shape_cost += parent_vse->associate_stats.shape_cost;
    associate_stats.bad_shape |= parent_vse->associate_stats.bad_shape;
  }

  // Create the new ViterbiStateEntry and compute the adjusted cost of the
  // path.
  ViterbiStateEntry *new_vse = new ViterbiStateEntry(
      parent_vse, b, 0.0, outline_length,
      consistency_info, associate_stats, top_choice_flags, dawg_info,
      ngram_info, (language_model_debug_level > 0) ?
          dict_->getUnicharset().id_to_unichar(b->unichar_id()) : NULL);
  new_vse->cost = ComputeAdjustedPathCost(new_vse);

  // Invoke Top Choice language model component to make the final adjustments
  // to new_vse->top_choice_flags.
  if (!curr_state->viterbi_state_entries.empty() && new_vse->top_choice_flags) {
    GenerateTopChoiceInfo(new_vse, parent_vse, curr_state);
  }

  // If language model components did not like this unichar - return.
  bool keep = new_vse->top_choice_flags || liked_by_language_model;
  if (!(top_choice_flags & kSmallestRatingFlag) &&  // no non-top choice paths
      consistency_info.inconsistent_script) {       // with inconsistent script
    keep = false;
  }
  if (!keep) {
    if (language_model_debug_level > 1) {
      tprintf("Language model components did not like this entry\n");
    }
    delete new_vse;
    return false;
  }

  // Discard this entry if it represents a prunable path and
  // language_model_viterbi_list_max_num_prunable such entries with a lower
  // cost have already been recorded.
  if (PrunablePath(*new_vse) &&
      (curr_state->viterbi_state_entries_prunable_length >=
       language_model_viterbi_list_max_num_prunable) &&
      new_vse->cost >= curr_state->viterbi_state_entries_prunable_max_cost) {
    if (language_model_debug_level > 1) {
      tprintf("Discarded ViterbiEntry with high cost %g max cost %g\n",
              new_vse->cost,
              curr_state->viterbi_state_entries_prunable_max_cost);
    }
    delete new_vse;
    return false;
  }

  // Update best choice if needed.
  if (word_end) {
    UpdateBestChoice(new_vse, pain_points, word_res,
                     best_choice_bundle, blamer_bundle);
    // Discard the entry if UpdateBestChoice() found flaws in it.
    if (new_vse->cost >= WERD_CHOICE::kBadRating &&
        new_vse != best_choice_bundle->best_vse) {
      if (language_model_debug_level > 1) {
        tprintf("Discarded ViterbiEntry with high cost %g\n", new_vse->cost);
      }
      delete new_vse;
      return false;
    }
  }

  // Add the new ViterbiStateEntry to curr_state->viterbi_state_entries.
  curr_state->viterbi_state_entries.add_sorted(ViterbiStateEntry::Compare,
                                               false, new_vse);
  curr_state->viterbi_state_entries_length++;
  if (PrunablePath(*new_vse)) {
    curr_state->viterbi_state_entries_prunable_length++;
  }

  // Update curr_state->viterbi_state_entries_prunable_max_cost and clear
  // top_choice_flags of entries with cost higher than new_vse->cost.
  if ((curr_state->viterbi_state_entries_prunable_length >=
       language_model_viterbi_list_max_num_prunable) ||
      new_vse->top_choice_flags) {
    ASSERT_HOST(!curr_state->viterbi_state_entries.empty());
    int prunable_counter = language_model_viterbi_list_max_num_prunable;
    vit.set_to_list(&(curr_state->viterbi_state_entries));
    for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
      ViterbiStateEntry *curr_vse = vit.data();
      // Clear the appropriate top choice flags of the entries in the
      // list that have cost higher than new_vse->cost
      // (since they will not be top choices any more).
      if (curr_vse->top_choice_flags && curr_vse != new_vse &&
          curr_vse->cost > new_vse->cost) {
        curr_vse->top_choice_flags &= ~(new_vse->top_choice_flags);
      }
      if (prunable_counter > 0 && PrunablePath(*curr_vse)) --prunable_counter;
      // Update curr_state->viterbi_state_entries_prunable_max_cost.
      if (prunable_counter == 0) {
        curr_state->viterbi_state_entries_prunable_max_cost = vit.data()->cost;
        if (language_model_debug_level > 1) {
          tprintf("Set viterbi_state_entries_prunable_max_cost to %g\n",
                  curr_state->viterbi_state_entries_prunable_max_cost);
        }
        prunable_counter = -1;  // stop counting
      }
    }
  }

  // Print the newly created ViterbiStateEntry.
  if (language_model_debug_level > 2) {
    new_vse->Print("New");
    if (language_model_debug_level > 5)
      curr_state->Print("Updated viterbi list");
  }

  return true;
}

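// Clears flags in new_vse->top_choice_flags that are already claimed by a
// cheaper entry in lms->viterbi_state_entries, so that each top-choice flag
// is held only by the lowest-cost path that qualifies for it.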
void LanguageModel::GenerateTopChoiceInfo(ViterbiStateEntry *new_vse,
                                          const ViterbiStateEntry *parent_vse,
                                          LanguageModelState *lms) {
  ViterbiStateEntry_IT vit(&(lms->viterbi_state_entries));
  for (vit.mark_cycle_pt(); !vit.cycled_list() && new_vse->top_choice_flags &&
       new_vse->cost >= vit.data()->cost; vit.forward()) {
    // Clear the appropriate flags if the list already contains
    // a top choice entry with a lower cost.
    new_vse->top_choice_flags &= ~(vit.data()->top_choice_flags);
  }
  if (language_model_debug_level > 2) {
    tprintf("GenerateTopChoiceInfo: top_choice_flags=0x%x\n",
            new_vse->top_choice_flags);
  }
}

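// Runs the Dawg component for blob choice b: continues the parent's active
// dawgs (or starts fresh at the beginning of a word), handles hyphenated and
// compound words, and returns a LanguageModelDawgInfo if the letter keeps at
// least one dawg alive, or NULL otherwise. For example (hypothetical input),
// in "foo-bar" the '-' is accepted only if the path "foo" before it already
// ends a dictionary word of at least language_model_min_compound_length
// letters.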
LanguageModelDawgInfo *LanguageModel::GenerateDawgInfo(
    bool word_end,
    int curr_col, int curr_row,
    const BLOB_CHOICE &b,
    const ViterbiStateEntry *parent_vse) {
  // Initialize active_dawgs from parent_vse if it is not NULL.
  // Otherwise use very_beginning_active_dawgs_.
  if (parent_vse == NULL) {
    dawg_args_->active_dawgs = very_beginning_active_dawgs_;
    dawg_args_->permuter = NO_PERM;
  } else {
    if (parent_vse->dawg_info == NULL) return NULL;  // not a dict word path
    dawg_args_->active_dawgs = parent_vse->dawg_info->active_dawgs;
    dawg_args_->permuter = parent_vse->dawg_info->permuter;
  }

  // Deal with hyphenated words.
  if (word_end && dict_->has_hyphen_end(b.unichar_id(), curr_col == 0)) {
    if (language_model_debug_level > 0) tprintf("Hyphenated word found\n");
    return new LanguageModelDawgInfo(dawg_args_->active_dawgs,
                                     COMPOUND_PERM);
  }

  // Deal with compound words.
  if (dict_->compound_marker(b.unichar_id()) &&
      (parent_vse == NULL || parent_vse->dawg_info->permuter != NUMBER_PERM)) {
    if (language_model_debug_level > 0) tprintf("Found compound marker\n");
    // Do not allow compound operators at the beginning and end of the word.
    // Do not allow more than one compound operator per word.
    // Do not allow compounding of words with lengths shorter than
    // language_model_min_compound_length.
    if (parent_vse == NULL || word_end ||
        dawg_args_->permuter == COMPOUND_PERM ||
        parent_vse->length < language_model_min_compound_length) return NULL;

    int i;
    // Check that the path terminated before the current character is a word.
    bool has_word_ending = false;
    for (i = 0; i < parent_vse->dawg_info->active_dawgs->size(); ++i) {
      const DawgPosition &pos = (*parent_vse->dawg_info->active_dawgs)[i];
      const Dawg *pdawg = pos.dawg_index < 0
          ? NULL : dict_->GetDawg(pos.dawg_index);
      if (pdawg == NULL || pos.back_to_punc) continue;
      if (pdawg->type() == DAWG_TYPE_WORD && pos.dawg_ref != NO_EDGE &&
          pdawg->end_of_word(pos.dawg_ref)) {
        has_word_ending = true;
        break;
      }
    }
    if (!has_word_ending) return NULL;

    if (language_model_debug_level > 0) tprintf("Compound word found\n");
    return new LanguageModelDawgInfo(beginning_active_dawgs_, COMPOUND_PERM);
  }  // done dealing with compound words

  LanguageModelDawgInfo *dawg_info = NULL;

  // Call LetterIsOkay().
  // Use the normalized IDs so that all shapes of ' can be allowed in words
  // like don't.
  const GenericVector<UNICHAR_ID>& normed_ids =
      dict_->getUnicharset().normed_ids(b.unichar_id());
  DawgPositionVector tmp_active_dawgs;
  for (int i = 0; i < normed_ids.size(); ++i) {
    if (language_model_debug_level > 2)
      tprintf("Test Letter OK for unichar %d, normed %d\n",
              b.unichar_id(), normed_ids[i]);
    dict_->LetterIsOkay(dawg_args_, normed_ids[i],
                        word_end && i == normed_ids.size() - 1);
    if (dawg_args_->permuter == NO_PERM) {
      break;
    } else if (i < normed_ids.size() - 1) {
      tmp_active_dawgs = *dawg_args_->updated_dawgs;
      dawg_args_->active_dawgs = &tmp_active_dawgs;
    }
    if (language_model_debug_level > 2)
      tprintf("Letter was OK for unichar %d, normed %d\n",
              b.unichar_id(), normed_ids[i]);
  }
  dawg_args_->active_dawgs = NULL;
  if (dawg_args_->permuter != NO_PERM) {
    dawg_info = new LanguageModelDawgInfo(dawg_args_->updated_dawgs,
                                          dawg_args_->permuter);
  } else if (language_model_debug_level > 3) {
    tprintf("Letter %s not OK!\n",
            dict_->getUnicharset().id_to_unichar(b.unichar_id()));
  }

  return dawg_info;
}

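// Runs the n-gram component for the given unichar: inherits the parent's
// context (or prev_word_str_ at the start of a word), computes the combined
// ngram and classifier cost via ComputeNgramCost(), and trims the stored
// context so it never exceeds language_model_ngram_order unichar steps.
// For example (hypothetical input), with order 8 and parent context
// "probabil", appending "i" drops the leading "p", leaving "robabili".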
LanguageModelNgramInfo *LanguageModel::GenerateNgramInfo(
    const char *unichar, float certainty, float denom,
    int curr_col, int curr_row, float outline_length,
    const ViterbiStateEntry *parent_vse) {
  // Initialize parent context.
  const char *pcontext_ptr = "";
  int pcontext_unichar_step_len = 0;
  if (parent_vse == NULL) {
    pcontext_ptr = prev_word_str_.string();
    pcontext_unichar_step_len = prev_word_unichar_step_len_;
  } else {
    pcontext_ptr = parent_vse->ngram_info->context.string();
    pcontext_unichar_step_len =
      parent_vse->ngram_info->context_unichar_step_len;
  }
  // Compute p(unichar | parent context).
  int unichar_step_len = 0;
  bool pruned = false;
  float ngram_cost;
  float ngram_and_classifier_cost =
      ComputeNgramCost(unichar, certainty, denom,
                       pcontext_ptr, &unichar_step_len,
                       &pruned, &ngram_cost);
  // Normalize just the ngram_and_classifier_cost by outline_length.
  // The ngram_cost is used by the params_model, so it needs to be left as-is,
  // and the params model cost will be normalized by outline_length.
  ngram_and_classifier_cost *=
      outline_length / language_model_ngram_rating_factor;
  // Add the ngram_cost of the parent.
  if (parent_vse != NULL) {
    ngram_and_classifier_cost +=
        parent_vse->ngram_info->ngram_and_classifier_cost;
    ngram_cost += parent_vse->ngram_info->ngram_cost;
  }

  // Shorten the parent context so that the extended context stays within
  // language_model_ngram_order unichar steps.
  int num_remove = (unichar_step_len + pcontext_unichar_step_len -
                    language_model_ngram_order);
  if (num_remove > 0) pcontext_unichar_step_len -= num_remove;
  while (num_remove > 0 && *pcontext_ptr != '\0') {
    pcontext_ptr += UNICHAR::utf8_step(pcontext_ptr);
    --num_remove;
  }

  // Decide whether to prune this ngram path and update pruned accordingly.
  if (parent_vse != NULL && parent_vse->ngram_info->pruned) pruned = true;

  // Construct and return the new LanguageModelNgramInfo.
  LanguageModelNgramInfo *ngram_info = new LanguageModelNgramInfo(
      pcontext_ptr, pcontext_unichar_step_len, pruned, ngram_cost,
      ngram_and_classifier_cost);
  ngram_info->context += unichar;
  ngram_info->context_unichar_step_len += unichar_step_len;
  assert(ngram_info->context_unichar_step_len <= language_model_ngram_order);
  return ngram_info;
}

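// Computes the combined ngram and classifier cost of unichar given its
// context. In terms of the code below:
//   *ngram_cost = -log2(P(unichar | context))
//   return value = -log2(CertaintyScore(certainty) / denom)
//                  + *ngram_cost * language_model_ngram_scale_factor
// where P is averaged over the UTF-8 steps of unichar and floored at
// language_model_ngram_small_prob (setting *found_small_prob when hit).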
float LanguageModel::ComputeNgramCost(const char *unichar,
                                      float certainty,
                                      float denom,
                                      const char *context,
                                      int *unichar_step_len,
                                      bool *found_small_prob,
                                      float *ngram_cost) {
  const char *context_ptr = context;
  char *modified_context = NULL;
  char *modified_context_end = NULL;
  const char *unichar_ptr = unichar;
  const char *unichar_end = unichar_ptr + strlen(unichar_ptr);
  float prob = 0.0f;
  int step = 0;
  while (unichar_ptr < unichar_end &&
         (step = UNICHAR::utf8_step(unichar_ptr)) > 0) {
    if (language_model_debug_level > 1) {
      tprintf("prob(%s | %s)=%g\n", unichar_ptr, context_ptr,
              dict_->ProbabilityInContext(context_ptr, -1, unichar_ptr, step));
    }
    prob += dict_->ProbabilityInContext(context_ptr, -1, unichar_ptr, step);
    ++(*unichar_step_len);
    if (language_model_ngram_use_only_first_uft8_step) break;
    unichar_ptr += step;
    // If there are multiple UTF8 characters present in unichar, context is
    // updated to include the previously examined characters from str,
    // unless use_only_first_uft8_step is true.
    if (unichar_ptr < unichar_end) {
      if (modified_context == NULL) {
        int context_len = strlen(context);
        modified_context =
          new char[context_len + strlen(unichar_ptr) + step + 1];
        strncpy(modified_context, context, context_len);
        modified_context_end = modified_context + context_len;
        context_ptr = modified_context;
      }
      strncpy(modified_context_end, unichar_ptr - step, step);
      modified_context_end += step;
      *modified_context_end = '\0';
    }
  }
  prob /= static_cast<float>(*unichar_step_len);  // normalize
  if (prob < language_model_ngram_small_prob) {
    if (language_model_debug_level > 0) tprintf("Found small prob %g\n", prob);
    *found_small_prob = true;
    prob = language_model_ngram_small_prob;
  }
  *ngram_cost = -1.0 * log2(prob);
  float ngram_and_classifier_cost =
      -1.0 * log2(CertaintyScore(certainty) / denom) +
      *ngram_cost * language_model_ngram_scale_factor;
  if (language_model_debug_level > 1) {
    tprintf("-log [ p(%s) * p(%s | %s) ] = -log2(%g*%g) = %g\n", unichar,
            unichar, context_ptr, CertaintyScore(certainty)/denom, prob,
            ngram_and_classifier_cost);
  }
  if (modified_context != NULL) delete[] modified_context;
  return ngram_and_classifier_cost;
}

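// Returns the normalizing denominator for CertaintyScore(), summarized from
// the body below:
//   denom = sum over choices c of CertaintyScore(c.certainty)
//         + (unicharset.size() - len) *
//           CertaintyScore(language_model_ngram_nonmatch_score)
// i.e. the scores actually present in curr_list plus a crude estimate for
// the classifications that are missing from it.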
float LanguageModel::ComputeDenom(BLOB_CHOICE_LIST *curr_list) {
  if (curr_list->empty()) return 1.0f;
  float denom = 0.0f;
  int len = 0;
  BLOB_CHOICE_IT c_it(curr_list);
  for (c_it.mark_cycle_pt(); !c_it.cycled_list(); c_it.forward()) {
    ASSERT_HOST(c_it.data() != NULL);
    ++len;
    denom += CertaintyScore(c_it.data()->certainty());
  }
  assert(len != 0);
  // The ideal situation would be to have the classifier scores for
  // classifying each position as each of the characters in the unicharset.
  // Since we can not do this because of speed, we add a very crude estimate
  // of what these scores for the "missing" classifications would sum up to.
  denom += (dict_->getUnicharset().size() - len) *
    CertaintyScore(language_model_ngram_nonmatch_score);

  return denom;
}

void LanguageModel::FillConsistencyInfo(
    int curr_col,
    bool word_end,
    BLOB_CHOICE *b,
    ViterbiStateEntry *parent_vse,
    WERD_RES *word_res,
    LMConsistencyInfo *consistency_info) {
  const UNICHARSET &unicharset = dict_->getUnicharset();
  UNICHAR_ID unichar_id = b->unichar_id();
  BLOB_CHOICE* parent_b = parent_vse != NULL ? parent_vse->curr_b : NULL;

  // Check punctuation validity.
  if (unicharset.get_ispunctuation(unichar_id)) consistency_info->num_punc++;
  if (dict_->GetPuncDawg() != NULL && !consistency_info->invalid_punc) {
    if (dict_->compound_marker(unichar_id) && parent_b != NULL &&
        (unicharset.get_isalpha(parent_b->unichar_id()) ||
         unicharset.get_isdigit(parent_b->unichar_id()))) {
      // reset punc_ref for compound words
      consistency_info->punc_ref = NO_EDGE;
    } else {
      bool is_apos = dict_->is_apostrophe(unichar_id);
      bool prev_is_numalpha = (parent_b != NULL &&
          (unicharset.get_isalpha(parent_b->unichar_id()) ||
           unicharset.get_isdigit(parent_b->unichar_id())));
      UNICHAR_ID pattern_unichar_id =
        (unicharset.get_isalpha(unichar_id) ||
         unicharset.get_isdigit(unichar_id) ||
         (is_apos && prev_is_numalpha)) ?
        Dawg::kPatternUnicharID : unichar_id;
      if (consistency_info->punc_ref == NO_EDGE ||
          pattern_unichar_id != Dawg::kPatternUnicharID ||
          dict_->GetPuncDawg()->edge_letter(consistency_info->punc_ref) !=
          Dawg::kPatternUnicharID) {
        NODE_REF node = Dict::GetStartingNode(dict_->GetPuncDawg(),
                                              consistency_info->punc_ref);
        consistency_info->punc_ref =
          (node != NO_EDGE) ? dict_->GetPuncDawg()->edge_char_of(
              node, pattern_unichar_id, word_end) : NO_EDGE;
        if (consistency_info->punc_ref == NO_EDGE) {
          consistency_info->invalid_punc = true;
        }
      }
    }
  }

  // Update case related counters.
  if (parent_vse != NULL && !word_end && dict_->compound_marker(unichar_id)) {
    // Reset counters if we are dealing with a compound word.
    consistency_info->num_lower = 0;
    consistency_info->num_non_first_upper = 0;
  } else if (unicharset.get_islower(unichar_id)) {
    consistency_info->num_lower++;
  } else if ((parent_b != NULL) && unicharset.get_isupper(unichar_id)) {
    if (unicharset.get_isupper(parent_b->unichar_id()) ||
        consistency_info->num_lower > 0 ||
        consistency_info->num_non_first_upper > 0) {
      consistency_info->num_non_first_upper++;
    }
  }

  // Initialize consistency_info->script_id (use script of unichar_id
  // if it is not Common, use script id recorded by the parent otherwise).
  // Set inconsistent_script to true if the script of the current unichar
  // is not consistent with that of the parent.
  consistency_info->script_id = unicharset.get_script(unichar_id);
  // Hiragana and Katakana can mix with Han.
  if (dict_->getUnicharset().han_sid() != dict_->getUnicharset().null_sid()) {
    if ((unicharset.hiragana_sid() != unicharset.null_sid() &&
         consistency_info->script_id == unicharset.hiragana_sid()) ||
        (unicharset.katakana_sid() != unicharset.null_sid() &&
         consistency_info->script_id == unicharset.katakana_sid())) {
      consistency_info->script_id = dict_->getUnicharset().han_sid();
    }
  }

  if (parent_vse != NULL &&
      (parent_vse->consistency_info.script_id !=
       dict_->getUnicharset().common_sid())) {
    int parent_script_id = parent_vse->consistency_info.script_id;
    // If script_id is Common, use script id of the parent instead.
    if (consistency_info->script_id == dict_->getUnicharset().common_sid()) {
      consistency_info->script_id = parent_script_id;
    }
    if (consistency_info->script_id != parent_script_id) {
      consistency_info->inconsistent_script = true;
    }
  }

  // Update chartype related counters.
  if (unicharset.get_isalpha(unichar_id)) {
    consistency_info->num_alphas++;
  } else if (unicharset.get_isdigit(unichar_id)) {
    consistency_info->num_digits++;
  } else if (!unicharset.get_ispunctuation(unichar_id)) {
    consistency_info->num_other++;
  }

  // Check font and spacing consistency.
  if (fontinfo_table_->size() > 0 && parent_b != NULL) {
    int fontinfo_id = -1;
    if (parent_b->fontinfo_id() == b->fontinfo_id() ||
        parent_b->fontinfo_id2() == b->fontinfo_id()) {
      fontinfo_id = b->fontinfo_id();
    } else if (parent_b->fontinfo_id() == b->fontinfo_id2() ||
               parent_b->fontinfo_id2() == b->fontinfo_id2()) {
      fontinfo_id = b->fontinfo_id2();
    }
    if (language_model_debug_level > 1) {
      tprintf("pfont %s pfont %s font %s font2 %s common %s(%d)\n",
              (parent_b->fontinfo_id() >= 0) ?
                  fontinfo_table_->get(parent_b->fontinfo_id()).name : "",
              (parent_b->fontinfo_id2() >= 0) ?
                  fontinfo_table_->get(parent_b->fontinfo_id2()).name : "",
              (b->fontinfo_id() >= 0) ?
                  fontinfo_table_->get(b->fontinfo_id()).name : "",
              (b->fontinfo_id2() >= 0) ?
                  fontinfo_table_->get(b->fontinfo_id2()).name : "",
              (fontinfo_id >= 0) ? fontinfo_table_->get(fontinfo_id).name : "",
              fontinfo_id);
    }
01118     }
01119     if (!word_res->blob_widths.empty()) {  // if we have widths/gaps info
01120       bool expected_gap_found = false;
01121       float expected_gap;
01122       int temp_gap;
01123       if (fontinfo_id >= 0) {  // found a common font
01124         ASSERT_HOST(fontinfo_id < fontinfo_table_->size());
01125         if (fontinfo_table_->get(fontinfo_id).get_spacing(
01126             parent_b->unichar_id(), unichar_id, &temp_gap)) {
01127           expected_gap = temp_gap;
01128           expected_gap_found = true;
01129         }
01130       } else {
01131         consistency_info->inconsistent_font = true;
01132         // Get an average of the expected gaps in each font
01133         int num_addends = 0;
01134         expected_gap = 0;
01135         int temp_fid;
01136         for (int i = 0; i < 4; ++i) {
01137           if (i == 0) {
01138             temp_fid = parent_b->fontinfo_id();
01139           } else if (i == 1) {
01140             temp_fid = parent_b->fontinfo_id2();
01141           } else if (i == 2) {
01142             temp_fid = b->fontinfo_id();
01143           } else {
01144             temp_fid = b->fontinfo_id2();
01145           }
01146           ASSERT_HOST(temp_fid < 0 || temp_fid < fontinfo_table_->size());
01147           if (temp_fid >= 0 && fontinfo_table_->get(temp_fid).get_spacing(
01148               parent_b->unichar_id(), unichar_id, &temp_gap)) {
01149             expected_gap += temp_gap;
01150             num_addends++;
01151           }
01152         }
01153         expected_gap_found = (num_addends > 0);
01154         if (num_addends > 0) {
01155           expected_gap /= static_cast<float>(num_addends);
01156         }
01157       }
01158       if (expected_gap_found) {
01159         float actual_gap =
01160             static_cast<float>(word_res->GetBlobsGap(curr_col-1));
01161         float gap_ratio = expected_gap / actual_gap;
01162         // TODO(daria): find a good way to tune this heuristic estimate.
01163         if (gap_ratio < 0.5f || gap_ratio > 2.0f) {
01164           consistency_info->num_inconsistent_spaces++;
01165         }
01166         if (language_model_debug_level > 1) {
01167           tprintf("spacing for %s(%d) %s(%d) col %d: expected %g actual %g\n",
01168                   unicharset.id_to_unichar(parent_b->unichar_id()),
01169                   parent_b->unichar_id(), unicharset.id_to_unichar(unichar_id),
01170                   unichar_id, curr_col, expected_gap, actual_gap);
01171         }
01172       }
01173     }
01174   }
01175 }
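
// ---------------------------------------------------------------------------
// Editorial sketch, not part of the original file: the spacing heuristic
// above reduced to its core. A blob gap counts as inconsistent when the
// font's expected gap and the observed gap differ by more than a factor of
// two in either direction. The zero-gap guard is an added assumption; the
// original divides unconditionally.
// ---------------------------------------------------------------------------
static bool GapIsInconsistent(float expected_gap, float actual_gap) {
  if (actual_gap == 0.0f) return false;  // nothing measurable to compare
  float gap_ratio = expected_gap / actual_gap;
  return gap_ratio < 0.5f || gap_ratio > 2.0f;
}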
01176 
01177 float LanguageModel::ComputeAdjustedPathCost(ViterbiStateEntry *vse) {
01178   ASSERT_HOST(vse != NULL);
01179   if (params_model_.Initialized()) {
01180     float features[PTRAIN_NUM_FEATURE_TYPES];
01181     ExtractFeaturesFromPath(*vse, features);
01182     float cost = params_model_.ComputeCost(features);
01183     if (language_model_debug_level > 3) {
01184       tprintf("ComputeAdjustedPathCost %g ParamsModel features:\n", cost);
01185       if (language_model_debug_level >= 5) {
01186         for (int f = 0; f < PTRAIN_NUM_FEATURE_TYPES; ++f) {
01187           tprintf("%s=%g\n", kParamsTrainingFeatureTypeName[f], features[f]);
01188         }
01189       }
01190     }
01191     return cost * vse->outline_length;
01192   } else {
01193     float adjustment = 1.0f;
01194     if (vse->dawg_info == NULL || vse->dawg_info->permuter != FREQ_DAWG_PERM) {
01195       adjustment += language_model_penalty_non_freq_dict_word;
01196     }
01197     if (vse->dawg_info == NULL) {
01198       adjustment += language_model_penalty_non_dict_word;
01199       if (vse->length > language_model_min_compound_length) {
01200         adjustment += ((vse->length - language_model_min_compound_length) *
01201             language_model_penalty_increment);
01202       }
01203     }
01204     if (vse->associate_stats.shape_cost > 0) {
01205       adjustment += vse->associate_stats.shape_cost /
01206           static_cast<float>(vse->length);
01207     }
01208     if (language_model_ngram_on) {
01209       ASSERT_HOST(vse->ngram_info != NULL);
01210       return vse->ngram_info->ngram_and_classifier_cost * adjustment;
01211     } else {
01212       adjustment += ComputeConsistencyAdjustment(vse->dawg_info,
01213                                                  vse->consistency_info);
01214       return vse->ratings_sum * adjustment;
01215     }
01216   }
01217 }
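
// ---------------------------------------------------------------------------
// Editorial sketch, not part of the original file: the fallback branch of
// ComputeAdjustedPathCost() as a standalone formula. The multiplier starts
// at a neutral 1.0, each penalty is added to it, and the result scales the
// raw classifier ratings sum. The ngram and consistency branches are omitted
// for brevity; the function name and parameter list are hypothetical.
// ---------------------------------------------------------------------------
static float AdjustedCostSketch(float ratings_sum, int length,
                                bool is_dict_word, bool is_freq_dict_word,
                                float shape_cost, float penalty_non_freq,
                                float penalty_non_dict,
                                float penalty_increment,
                                int min_compound_length) {
  float adjustment = 1.0f;
  if (!is_freq_dict_word) adjustment += penalty_non_freq;
  if (!is_dict_word) {
    adjustment += penalty_non_dict;
    if (length > min_compound_length) {  // extra penalty per unichar over
      adjustment += (length - min_compound_length) * penalty_increment;
    }
  }
  if (shape_cost > 0.0f) {
    adjustment += shape_cost / static_cast<float>(length);
  }
  return ratings_sum * adjustment;
}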
01218 
01219 void LanguageModel::UpdateBestChoice(
01220     ViterbiStateEntry *vse,
01221     LMPainPoints *pain_points,
01222     WERD_RES *word_res,
01223     BestChoiceBundle *best_choice_bundle,
01224     BlamerBundle *blamer_bundle) {
01225   bool truth_path;
01226   WERD_CHOICE *word = ConstructWord(vse, word_res, &best_choice_bundle->fixpt,
01227                                     blamer_bundle, &truth_path);
01228   ASSERT_HOST(word != NULL);
01229   if (dict_->stopper_debug_level >= 1) {
01230     STRING word_str;
01231     word->string_and_lengths(&word_str, NULL);
01232     vse->Print(word_str.string());
01233   }
01234   if (language_model_debug_level > 0) {
01235     word->print("UpdateBestChoice() constructed word");
01236   }
01237   // Record features from the current path if necessary.
01238   ParamsTrainingHypothesis curr_hyp;
01239   if (blamer_bundle != NULL) {
01240     if (vse->dawg_info != NULL)
01241       vse->dawg_info->permuter = static_cast<PermuterType>(word->permuter());
01242     ExtractFeaturesFromPath(*vse, curr_hyp.features);
01243     word->string_and_lengths(&(curr_hyp.str), NULL);
01244     curr_hyp.cost = vse->cost;  // record cost for error rate computations
01245     if (language_model_debug_level > 0) {
01246       tprintf("Raw features extracted from %s (cost=%g) [ ",
01247               curr_hyp.str.string(), curr_hyp.cost);
01248       for (int deb_i = 0; deb_i < PTRAIN_NUM_FEATURE_TYPES; ++deb_i) {
01249         tprintf("%g ", curr_hyp.features[deb_i]);
01250       }
01251       tprintf("]\n");
01252     }
01253     // Record the current hypothesis in params_training_bundle.
01254     blamer_bundle->AddHypothesis(curr_hyp);
01255     if (truth_path)
01256       blamer_bundle->UpdateBestRating(word->rating());
01257   }
01258   if (blamer_bundle != NULL && blamer_bundle->GuidedSegsearchStillGoing()) {
01259     // The word was constructed solely for blamer_bundle->AddHypothesis, so
01260     // we no longer need it.
01261     delete word;
01262     return;
01263   }
01264   if (word_res->chopped_word != NULL && !word_res->chopped_word->blobs.empty())
01265     word->SetScriptPositions(false, word_res->chopped_word);
01266   // Update and log new raw_choice if needed.
01267   if (word_res->raw_choice == NULL ||
01268       word->rating() < word_res->raw_choice->rating()) {
01269     if (word_res->LogNewRawChoice(word) && language_model_debug_level > 0)
01270       tprintf("Updated raw choice\n");
01271   }
01272   // Set the modified rating for best choice to vse->cost and log best choice.
01273   word->set_rating(vse->cost);
01274   // Call LogNewChoice() for best choice from Dict::adjust_word() since it
01275   // computes adjust_factor that is used by the adaption code (e.g. by
01276   // ClassifyAdaptableWord() to compute adaption acceptance thresholds).
01277   // Note: the rating of the word is not adjusted.
01278   dict_->adjust_word(word, vse->dawg_info == NULL,
01279                      vse->consistency_info.xht_decision, 0.0,
01280                      false, language_model_debug_level > 0);
01281   // Hand ownership of the word over to the word_res.
01282   if (!word_res->LogNewCookedChoice(dict_->tessedit_truncate_wordchoice_log,
01283                                     dict_->stopper_debug_level >= 1, word)) {
01284     // The word was so bad that it was deleted.
01285     return;
01286   }
01287   if (word_res->best_choice == word) {
01288     // Word was the new best.
01289     if (dict_->AcceptableChoice(*word, vse->consistency_info.xht_decision) &&
01290         AcceptablePath(*vse)) {
01291       acceptable_choice_found_ = true;
01292     }
01293     // Update best_choice_bundle.
01294     best_choice_bundle->updated = true;
01295     best_choice_bundle->best_vse = vse;
01296     if (language_model_debug_level > 0) {
01297       tprintf("Updated best choice\n");
01298       word->print_state("New state ");
01299     }
01300     // Update hyphen state if we are dealing with a dictionary word.
01301     if (vse->dawg_info != NULL) {
01302       if (dict_->has_hyphen_end(*word)) {
01303         dict_->set_hyphen_word(*word, *(dawg_args_->active_dawgs));
01304       } else {
01305         dict_->reset_hyphen_vars(true);
01306       }
01307     }
01308 
01309     if (blamer_bundle != NULL) {
01310       blamer_bundle->set_best_choice_is_dict_and_top_choice(
01311           vse->dawg_info != NULL && vse->top_choice_flags);
01312     }
01313   }
01314   if (wordrec_display_segmentations) {
01315     word->DisplaySegmentation(word_res->chopped_word);
01316   }
01317 }
01318 
01319 void LanguageModel::ExtractFeaturesFromPath(
01320     const ViterbiStateEntry &vse, float features[]) {
01321   memset(features, 0, sizeof(float) * PTRAIN_NUM_FEATURE_TYPES);
01322   // Record dictionary match info.
01323   int len = vse.length <= kMaxSmallWordUnichars ? 0 :
01324       vse.length <= kMaxMediumWordUnichars ? 1 : 2;
01325   if (vse.dawg_info != NULL) {
01326     int permuter = vse.dawg_info->permuter;
01327     if (permuter == NUMBER_PERM || permuter == USER_PATTERN_PERM) {
01328       if (vse.consistency_info.num_digits == vse.length) {
01329         features[PTRAIN_DIGITS_SHORT+len] = 1.0;
01330       } else {
01331         features[PTRAIN_NUM_SHORT+len] = 1.0;
01332       }
01333     } else if (permuter == DOC_DAWG_PERM) {
01334       features[PTRAIN_DOC_SHORT+len] = 1.0;
01335     } else if (permuter == SYSTEM_DAWG_PERM || permuter == USER_DAWG_PERM ||
01336         permuter == COMPOUND_PERM) {
01337       features[PTRAIN_DICT_SHORT+len] = 1.0;
01338     } else if (permuter == FREQ_DAWG_PERM) {
01339       features[PTRAIN_FREQ_SHORT+len] = 1.0;
01340     }
01341   }
01342   // Record shape cost feature (normalized by path length).
01343   features[PTRAIN_SHAPE_COST_PER_CHAR] =
01344       vse.associate_stats.shape_cost / static_cast<float>(vse.length);
01345   // Record ngram cost (normalized by the path length).
01346   features[PTRAIN_NGRAM_COST_PER_CHAR] = 0.0;
01347   if (vse.ngram_info != NULL) {
01348     features[PTRAIN_NGRAM_COST_PER_CHAR] =
01349         vse.ngram_info->ngram_cost / static_cast<float>(vse.length);
01350   }
01351   // Record consistency-related features.
01352   // Disabled this feature for now due to its poor performance.
01353   // features[PTRAIN_NUM_BAD_PUNC] = vse.consistency_info.NumInconsistentPunc();
01354   features[PTRAIN_NUM_BAD_CASE] = vse.consistency_info.NumInconsistentCase();
01355   features[PTRAIN_XHEIGHT_CONSISTENCY] = vse.consistency_info.xht_decision;
01356   features[PTRAIN_NUM_BAD_CHAR_TYPE] = vse.dawg_info == NULL ?
01357       vse.consistency_info.NumInconsistentChartype() : 0.0;
01358   features[PTRAIN_NUM_BAD_SPACING] =
01359       vse.consistency_info.NumInconsistentSpaces();
01360   // Disabled this feature for now due to its poor performance.
01361   // features[PTRAIN_NUM_BAD_FONT] = vse.consistency_info.inconsistent_font;
01362 
01363   // Classifier-related features.
01364   features[PTRAIN_RATING_PER_CHAR] =
01365       vse.ratings_sum / static_cast<float>(vse.outline_length);
01366 }
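
// ---------------------------------------------------------------------------
// Editorial sketch, not part of the original file: the length bucketing used
// at the top of ExtractFeaturesFromPath(). The bucket index (0 short, 1
// medium, 2 long) is added to a base feature id, so PTRAIN_DICT_SHORT + 1,
// for example, addresses the medium-length dictionary-word feature. The
// threshold values below are assumptions standing in for
// kMaxSmallWordUnichars and kMaxMediumWordUnichars.
// ---------------------------------------------------------------------------
static int LengthBucket(int num_unichars) {
  const int kMaxSmall = 3;   // assumed stand-in for kMaxSmallWordUnichars
  const int kMaxMedium = 6;  // assumed stand-in for kMaxMediumWordUnichars
  if (num_unichars <= kMaxSmall) return 0;
  if (num_unichars <= kMaxMedium) return 1;
  return 2;
}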
01367 
01368 WERD_CHOICE *LanguageModel::ConstructWord(
01369     ViterbiStateEntry *vse,
01370     WERD_RES *word_res,
01371     DANGERR *fixpt,
01372     BlamerBundle *blamer_bundle,
01373     bool *truth_path) {
01374   if (truth_path != NULL) {
01375     *truth_path =
01376         (blamer_bundle != NULL &&
01377          vse->length == blamer_bundle->correct_segmentation_length());
01378   }
01379   BLOB_CHOICE *curr_b = vse->curr_b;
01380   ViterbiStateEntry *curr_vse = vse;
01381 
01382   int i;
01383   bool compound = dict_->hyphenated();  // treat hyphenated words as compound
01384 
01385   // Re-compute the variance of the width-to-height ratios (since we now
01386   // can compute the mean over the whole word).
01387   float full_wh_ratio_mean = 0.0f;
01388   if (vse->associate_stats.full_wh_ratio_var != 0.0f) {
01389     vse->associate_stats.shape_cost -= vse->associate_stats.full_wh_ratio_var;
01390     full_wh_ratio_mean = (vse->associate_stats.full_wh_ratio_total /
01391                           static_cast<float>(vse->length));
01392     vse->associate_stats.full_wh_ratio_var = 0.0f;
01393   }
01394 
01395   // Construct a WERD_CHOICE by tracing parent pointers.
01396   WERD_CHOICE *word = new WERD_CHOICE(word_res->uch_set, vse->length);
01397   word->set_length(vse->length);
01398   int total_blobs = 0;
01399   for (i = (vse->length-1); i >= 0; --i) {
01400     if (blamer_bundle != NULL && truth_path != NULL && *truth_path &&
01401         !blamer_bundle->MatrixPositionCorrect(i, curr_b->matrix_cell())) {
01402       *truth_path = false;
01403     }
01404     // The number of blobs used for this choice is row - col + 1.
01405     int num_blobs = curr_b->matrix_cell().row - curr_b->matrix_cell().col + 1;
01406     total_blobs += num_blobs;
01407     word->set_blob_choice(i, num_blobs, curr_b);
01408     // Update the width-to-height ratio variance. Useful for non-space
01409     // delimited languages to ensure that the blobs are of uniform width.
01410     // Skip leading and trailing punctuation when computing the variance.
01411     if ((full_wh_ratio_mean != 0.0f &&
01412          ((curr_vse != vse && curr_vse->parent_vse != NULL) ||
01413           !dict_->getUnicharset().get_ispunctuation(curr_b->unichar_id())))) {
01414       vse->associate_stats.full_wh_ratio_var +=
01415         pow(full_wh_ratio_mean - curr_vse->associate_stats.full_wh_ratio, 2);
01416       if (language_model_debug_level > 2) {
01417         tprintf("full_wh_ratio_var += (%g-%g)^2\n",
01418                 full_wh_ratio_mean, curr_vse->associate_stats.full_wh_ratio);
01419       }
01420     }
01421 
01422     // Mark the word as compound if compound permuter was set for any of
01423     // the unichars on the path (usually this will happen for unichars
01424     // that are compounding operators, like "-" and "/").
01425     if (!compound && curr_vse->dawg_info &&
01426         curr_vse->dawg_info->permuter == COMPOUND_PERM) compound = true;
01427 
01428     // Update curr_* pointers.
01429     curr_vse = curr_vse->parent_vse;
01430     if (curr_vse == NULL) break;
01431     curr_b = curr_vse->curr_b;
01432   }
01433   ASSERT_HOST(i == 0);  // check that we recorded all the unichar ids.
01434   ASSERT_HOST(total_blobs == word_res->ratings->dimension());
01435   // Re-adjust shape cost to include the updated width-to-height variance.
01436   if (full_wh_ratio_mean != 0.0f) {
01437     vse->associate_stats.shape_cost += vse->associate_stats.full_wh_ratio_var;
01438   }
01439 
01440   word->set_rating(vse->ratings_sum);
01441   word->set_certainty(vse->min_certainty);
01442   word->set_x_heights(vse->consistency_info.BodyMinXHeight(),
01443                       vse->consistency_info.BodyMaxXHeight());
01444   if (vse->dawg_info != NULL) {
01445     word->set_permuter(compound ? COMPOUND_PERM : vse->dawg_info->permuter);
01446   } else if (language_model_ngram_on && !vse->ngram_info->pruned) {
01447     word->set_permuter(NGRAM_PERM);
01448   } else if (vse->top_choice_flags) {
01449     word->set_permuter(TOP_CHOICE_PERM);
01450   } else {
01451     word->set_permuter(NO_PERM);
01452   }
01453   word->set_dangerous_ambig_found_(!dict_->NoDangerousAmbig(word, fixpt, true,
01454                                                             word_res->ratings));
01455   return word;
01456 }
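
// ---------------------------------------------------------------------------
// Editorial sketch, not part of the original file: the variance recomputation
// in ConstructWord(). While the path is being built only a running total of
// the width-to-height ratios is available, so once the full word length is
// known the variance term is redone as a sum of squared deviations from the
// whole-word mean. The helper below is a hypothetical distillation.
// ---------------------------------------------------------------------------
static float SumSquaredDeviations(const float *ratios, int n) {
  float mean = 0.0f;
  for (int i = 0; i < n; ++i) mean += ratios[i];
  mean /= static_cast<float>(n);
  float var = 0.0f;  // matches the += pow(mean - ratio, 2) accumulation above
  for (int i = 0; i < n; ++i) {
    float d = mean - ratios[i];
    var += d * d;
  }
  return var;
}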
01457 
01458 }  // namespace tesseract