// File:        language_model.cpp
// Description: Functions that utilize the knowledge about the properties,
//              structure and statistics of the language to help recognition.
// Author:      Daria Antonova
// Created:     Mon Nov 11 11:26:43 PST 2009
//
// (C) Copyright 2009, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

#include <math.h>

#include "language_model.h"

#include "dawg.h"
#include "freelist.h"
#include "intproto.h"
#include "helpers.h"
#include "lm_state.h"
#include "lm_pain_points.h"
#include "matrix.h"
#include "params.h"
#include "params_training_featdef.h"

#if defined(_MSC_VER) || defined(ANDROID)
double log2(double n) {
  return log(n) / log(2.0);
}
#endif  // _MSC_VER || ANDROID

namespace tesseract {

const float LanguageModel::kMaxAvgNgramCost = 25.0f;

LanguageModel::LanguageModel(const UnicityTable<FontInfo> *fontinfo_table,
                             Dict *dict)
  : INT_MEMBER(language_model_debug_level, 0, "Language model debug level",
               dict->getCCUtil()->params()),
    BOOL_INIT_MEMBER(language_model_ngram_on, false,
                     "Turn on/off the use of character ngram model",
                     dict->getCCUtil()->params()),
    INT_MEMBER(language_model_ngram_order, 8,
               "Maximum order of the character ngram model",
               dict->getCCUtil()->params()),
    INT_MEMBER(language_model_viterbi_list_max_num_prunable, 10,
               "Maximum number of prunable (those for which"
               " PrunablePath() is true) entries in each viterbi list"
               " recorded in BLOB_CHOICEs",
               dict->getCCUtil()->params()),
    INT_MEMBER(language_model_viterbi_list_max_size, 500,
               "Maximum size of viterbi lists recorded in BLOB_CHOICEs",
               dict->getCCUtil()->params()),
    double_MEMBER(language_model_ngram_small_prob, 0.000001,
                  "To avoid overly small denominators use this as the "
                  "floor of the probability returned by the ngram model.",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_ngram_nonmatch_score, -40.0,
                  "Average classifier score of a non-matching unichar.",
                  dict->getCCUtil()->params()),
    BOOL_MEMBER(language_model_ngram_use_only_first_uft8_step, false,
                "Use only the first UTF8 step of the given string"
                " when computing log probabilities.",
                dict->getCCUtil()->params()),
    double_MEMBER(language_model_ngram_scale_factor, 0.03,
                  "Strength of the character ngram model relative to the"
                  " character classifier ",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_ngram_rating_factor, 16.0,
                  "Factor to bring log-probs into the same range as ratings"
                  " when multiplied by outline length ",
                  dict->getCCUtil()->params()),
    BOOL_MEMBER(language_model_ngram_space_delimited_language, true,
                "Words are delimited by space",
                dict->getCCUtil()->params()),
    INT_MEMBER(language_model_min_compound_length, 3,
               "Minimum length of compound words",
               dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_non_freq_dict_word, 0.1,
                  "Penalty for words not in the frequent word dictionary",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_non_dict_word, 0.15,
                  "Penalty for non-dictionary words",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_punc, 0.2,
                  "Penalty for inconsistent punctuation",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_case, 0.1,
                  "Penalty for inconsistent case",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_script, 0.5,
                  "Penalty for inconsistent script",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_chartype, 0.3,
                  "Penalty for inconsistent character type",
                  dict->getCCUtil()->params()),
    // TODO(daria, rays): enable font consistency checking
    // after improving font analysis.
    double_MEMBER(language_model_penalty_font, 0.00,
                  "Penalty for inconsistent font",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_spacing, 0.05,
                  "Penalty for inconsistent spacing",
                  dict->getCCUtil()->params()),
    double_MEMBER(language_model_penalty_increment, 0.01,
                  "Penalty increment",
                  dict->getCCUtil()->params()),
    INT_MEMBER(wordrec_display_segmentations, 0, "Display Segmentations",
               dict->getCCUtil()->params()),
    BOOL_INIT_MEMBER(language_model_use_sigmoidal_certainty, false,
                     "Use sigmoidal score for certainty",
                     dict->getCCUtil()->params()),
    fontinfo_table_(fontinfo_table), dict_(dict),
    fixed_pitch_(false), max_char_wh_ratio_(0.0),
    acceptable_choice_found_(false) {
  ASSERT_HOST(dict_ != NULL);
  dawg_args_ = new DawgArgs(NULL, new DawgPositionVector(), NO_PERM);
  very_beginning_active_dawgs_ = new DawgPositionVector();
  beginning_active_dawgs_ = new DawgPositionVector();
}

LanguageModel::~LanguageModel() {
  delete very_beginning_active_dawgs_;
  delete beginning_active_dawgs_;
  delete dawg_args_->updated_dawgs;
  delete dawg_args_;
}

void LanguageModel::InitForWord(const WERD_CHOICE *prev_word,
                                bool fixed_pitch, float max_char_wh_ratio,
                                float rating_cert_scale) {
  fixed_pitch_ = fixed_pitch;
  max_char_wh_ratio_ = max_char_wh_ratio;
  rating_cert_scale_ = rating_cert_scale;
  acceptable_choice_found_ = false;
  correct_segmentation_explored_ = false;

  // Initialize vectors with beginning DawgInfos.
  very_beginning_active_dawgs_->clear();
  dict_->init_active_dawgs(very_beginning_active_dawgs_, false);
  beginning_active_dawgs_->clear();
  dict_->default_dawgs(beginning_active_dawgs_, false);

  // Fill prev_word_str_ with the last language_model_ngram_order
  // unichars from prev_word.
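  // For example, with prev_word "search" in a space-delimited language,
  // prev_word_str_ becomes "search " and prev_word_unichar_step_len_ is 7.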
  if (language_model_ngram_on) {
    if (prev_word != NULL && prev_word->unichar_string() != NULL) {
      prev_word_str_ = prev_word->unichar_string();
      if (language_model_ngram_space_delimited_language) prev_word_str_ += ' ';
    } else {
      prev_word_str_ = " ";
    }
    const char *str_ptr = prev_word_str_.string();
    const char *str_end = str_ptr + prev_word_str_.length();
    int step;
    prev_word_unichar_step_len_ = 0;
    while (str_ptr != str_end && (step = UNICHAR::utf8_step(str_ptr))) {
      str_ptr += step;
      ++prev_word_unichar_step_len_;
    }
    ASSERT_HOST(str_ptr == str_end);
  }
}

// Helper scans the collection of predecessors for competing siblings that
// have the same letter with the opposite case, setting competing_vse.
static void ScanParentsForCaseMix(const UNICHARSET& unicharset,
                                  LanguageModelState* parent_node) {
  if (parent_node == NULL) return;
  ViterbiStateEntry_IT vit(&parent_node->viterbi_state_entries);
  for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
    ViterbiStateEntry* vse = vit.data();
    vse->competing_vse = NULL;
    UNICHAR_ID unichar_id = vse->curr_b->unichar_id();
    if (unicharset.get_isupper(unichar_id) ||
        unicharset.get_islower(unichar_id)) {
      UNICHAR_ID other_case = unicharset.get_other_case(unichar_id);
      if (other_case == unichar_id) continue;  // Not in unicharset.
      // Find other case in same list. There could be multiple entries with
      // the same unichar_id, but in theory, they should all point to the
      // same BLOB_CHOICE, and that is what we will be using to decide
      // which to keep.
      ViterbiStateEntry_IT vit2(&parent_node->viterbi_state_entries);
      for (vit2.mark_cycle_pt(); !vit2.cycled_list() &&
           vit2.data()->curr_b->unichar_id() != other_case;
           vit2.forward()) {}
      if (!vit2.cycled_list()) {
        vse->competing_vse = vit2.data();
      }
    }
  }
}

// Helper returns true if the given choice has a better case variant before
// it in the choice_list that is not distinguishable by size.
static bool HasBetterCaseVariant(const UNICHARSET& unicharset,
                                 const BLOB_CHOICE* choice,
                                 BLOB_CHOICE_LIST* choices) {
  UNICHAR_ID choice_id = choice->unichar_id();
  UNICHAR_ID other_case = unicharset.get_other_case(choice_id);
  if (other_case == choice_id || other_case == INVALID_UNICHAR_ID)
    return false;  // Not upper or lower or not in unicharset.
  if (unicharset.SizesDistinct(choice_id, other_case))
    return false;  // Can be separated by size.
  BLOB_CHOICE_IT bc_it(choices);
  for (bc_it.mark_cycle_pt(); !bc_it.cycled_list(); bc_it.forward()) {
    BLOB_CHOICE* better_choice = bc_it.data();
    if (better_choice->unichar_id() == other_case)
      return true;  // Found an earlier instance of other_case.
    else if (better_choice == choice)
      return false;  // Reached the original choice.
  }
  return false;  // Should never happen, but just in case.
}
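// For example, if the choice list ranks 'K' ahead of 'k' (case pairs like
// these are often not separable by size), HasBetterCaseVariant() is true
// for the 'k' choice, and UpdateState() skips 'k' at the start of an
// alnum sequence, building paths only through the better-rated 'K'.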
// UpdateState has the job of combining the ViterbiStateEntry lists on each
// of the choices on parent_list with each of the blob choices in curr_list,
// making a new ViterbiStateEntry for each sensible path.
// This could be a huge set of combinations, creating a lot of work only to
// be truncated by some beam limit, but only certain kinds of paths will
// continue at the next step:
// paths that are liked by the language model: either a DAWG or the n-gram
// model, where active.
// paths that represent some kind of top choice. The old permuter permuted
// the top raw classifier score, the top upper case word and the top lower-
// case word. UpdateState now concentrates its top-choice paths on top
// lower-case, top upper-case (or caseless alpha), and top digit sequence,
// with allowance for continuation of these paths through blobs where such
// a character does not appear in the choices list.
// GetNextParentVSE enforces some of these models to minimize the number of
// calls to AddViterbiStateEntry, even prior to looking at the language model.
// Thus an n-blob sequence of [l1I] will produce 3n calls to
// AddViterbiStateEntry instead of 3^n.
// Of course it isn't quite that simple as Title Case is handled by allowing
// lower case to continue an upper case initial, but it has to be detected
// in the combiner so it knows which upper case letters are initial alphas.
bool LanguageModel::UpdateState(
    bool just_classified,
    int curr_col, int curr_row,
    BLOB_CHOICE_LIST *curr_list,
    LanguageModelState *parent_node,
    LMPainPoints *pain_points,
    WERD_RES *word_res,
    BestChoiceBundle *best_choice_bundle,
    BlamerBundle *blamer_bundle) {
  if (language_model_debug_level > 0) {
    tprintf("\nUpdateState: col=%d row=%d %s",
            curr_col, curr_row, just_classified ? "just_classified" : "");
    if (language_model_debug_level > 5)
      tprintf("(parent=%p)\n", parent_node);
    else
      tprintf("\n");
  }
  // Initialize helper variables.
  bool word_end = (curr_row + 1 >= word_res->ratings->dimension());
  bool new_changed = false;
  float denom = (language_model_ngram_on) ? ComputeDenom(curr_list) : 1.0f;
  const UNICHARSET& unicharset = dict_->getUnicharset();
  BLOB_CHOICE *first_lower = NULL;
  BLOB_CHOICE *first_upper = NULL;
  BLOB_CHOICE *first_digit = NULL;
  bool has_alnum_mix = false;
  if (parent_node != NULL) {
    int result = SetTopParentLowerUpperDigit(parent_node);
    if (result < 0) {
      if (language_model_debug_level > 0)
        tprintf("No parents found to process\n");
      return false;
    }
    if (result > 0)
      has_alnum_mix = true;
  }
  if (!GetTopLowerUpperDigit(curr_list, &first_lower, &first_upper,
                             &first_digit))
    has_alnum_mix = false;
  ScanParentsForCaseMix(unicharset, parent_node);
  if (language_model_debug_level > 3 && parent_node != NULL) {
    parent_node->Print("Parent viterbi list");
  }
  LanguageModelState *curr_state = best_choice_bundle->beam[curr_row];

  // Call AddViterbiStateEntry() for each parent+child ViterbiStateEntry.
  ViterbiStateEntry_IT vit;
  BLOB_CHOICE_IT c_it(curr_list);
  for (c_it.mark_cycle_pt(); !c_it.cycled_list(); c_it.forward()) {
    BLOB_CHOICE* choice = c_it.data();
    // TODO(antonova): make sure commenting this out is ok for ngram
    // model scoring (I think this was introduced to fix ngram model quirks).
    // Skip NULL unichars unless it is the only choice.
    //if (!curr_list->singleton() && c_it.data()->unichar_id() == 0) continue;
    UNICHAR_ID unichar_id = choice->unichar_id();
    if (unicharset.get_fragment(unichar_id)) {
      continue;  // skip fragments
    }
    // Set top choice flags.
    LanguageModelFlagsType blob_choice_flags = kXhtConsistentFlag;
    if (c_it.at_first() || !new_changed)
      blob_choice_flags |= kSmallestRatingFlag;
    if (first_lower == choice) blob_choice_flags |= kLowerCaseFlag;
    if (first_upper == choice) blob_choice_flags |= kUpperCaseFlag;
    if (first_digit == choice) blob_choice_flags |= kDigitFlag;

    if (parent_node == NULL) {
      // Process the beginning of a word.
      // If there is a better case variant that is not distinguished by size,
      // skip this blob choice, as we have no choice but to accept the result
      // of the character classifier to distinguish between them, even if
      // followed by an upper case.
      // With words like iPoc, and other CamelBackWords, the lower-upper
      // transition can only be achieved if the classifier has the correct case
      // as the top choice, and leaving an initial I lower down the list
      // increases the chances of choosing IPoc simply because it doesn't
      // include such a transition. iPoc will beat iPOC and ipoc because
      // the other words are baseline/x-height inconsistent.
      if (HasBetterCaseVariant(unicharset, choice, curr_list))
        continue;
      // Upper counts as lower at the beginning of a word.
      if (blob_choice_flags & kUpperCaseFlag)
        blob_choice_flags |= kLowerCaseFlag;
      new_changed |= AddViterbiStateEntry(
          blob_choice_flags, denom, word_end, curr_col, curr_row,
          choice, curr_state, NULL, pain_points,
          word_res, best_choice_bundle, blamer_bundle);
    } else {
      // Get viterbi entries from each parent ViterbiStateEntry.
      vit.set_to_list(&parent_node->viterbi_state_entries);
      int vit_counter = 0;
      vit.mark_cycle_pt();
      ViterbiStateEntry* parent_vse = NULL;
      LanguageModelFlagsType top_choice_flags;
      while ((parent_vse = GetNextParentVSE(just_classified, has_alnum_mix,
                                            c_it.data(), blob_choice_flags,
                                            unicharset, word_res, &vit,
                                            &top_choice_flags)) != NULL) {
        // Skip pruned entries and do not look at prunable entries if already
        // examined language_model_viterbi_list_max_num_prunable of those.
        if (PrunablePath(*parent_vse) &&
            (++vit_counter > language_model_viterbi_list_max_num_prunable ||
             (language_model_ngram_on && parent_vse->ngram_info->pruned))) {
          continue;
        }
        // If the parent has no alnum choice, (ie choice is the first in a
        // string of alnum), and there is a better case variant that is not
        // distinguished by size, skip this blob choice/parent, as with the
        // initial blob treatment above.
        if (!parent_vse->HasAlnumChoice(unicharset) &&
            HasBetterCaseVariant(unicharset, choice, curr_list))
          continue;
        // Create a new ViterbiStateEntry if BLOB_CHOICE in c_it.data()
        // looks good according to the Dawgs or character ngram model.
        new_changed |= AddViterbiStateEntry(
            top_choice_flags, denom, word_end, curr_col, curr_row,
            c_it.data(), curr_state, parent_vse, pain_points,
            word_res, best_choice_bundle, blamer_bundle);
      }
    }
  }
  return new_changed;
}
// Finds the first lower and upper case letter and first digit in curr_list.
// For non-upper/lower languages, alpha counts as upper.
// Uses the first character in the list in place of empty results.
// Returns true if both alpha and digits are found.
bool LanguageModel::GetTopLowerUpperDigit(BLOB_CHOICE_LIST *curr_list,
                                          BLOB_CHOICE **first_lower,
                                          BLOB_CHOICE **first_upper,
                                          BLOB_CHOICE **first_digit) const {
  BLOB_CHOICE_IT c_it(curr_list);
  const UNICHARSET &unicharset = dict_->getUnicharset();
  BLOB_CHOICE *first_unichar = NULL;
  for (c_it.mark_cycle_pt(); !c_it.cycled_list(); c_it.forward()) {
    UNICHAR_ID unichar_id = c_it.data()->unichar_id();
    if (unicharset.get_fragment(unichar_id)) continue;  // skip fragments
    if (first_unichar == NULL) first_unichar = c_it.data();
    if (*first_lower == NULL && unicharset.get_islower(unichar_id)) {
      *first_lower = c_it.data();
    }
    if (*first_upper == NULL && unicharset.get_isalpha(unichar_id) &&
        !unicharset.get_islower(unichar_id)) {
      *first_upper = c_it.data();
    }
    if (*first_digit == NULL && unicharset.get_isdigit(unichar_id)) {
      *first_digit = c_it.data();
    }
  }
  ASSERT_HOST(first_unichar != NULL);
  bool mixed = (*first_lower != NULL || *first_upper != NULL) &&
      *first_digit != NULL;
  if (*first_lower == NULL) *first_lower = first_unichar;
  if (*first_upper == NULL) *first_upper = first_unichar;
  if (*first_digit == NULL) *first_digit = first_unichar;
  return mixed;
}

// Forces there to be at least one entry in the overall set of the
// viterbi_state_entries of each element of parent_node that has the
// top_choice_flag set for lower, upper and digit using the same rules as
// GetTopLowerUpperDigit, setting the flag on the first found suitable
// candidate, whether or not the flag is set on some other parent.
// Returns 1 if both alpha and digits are found among the parents, -1 if no
// parents are found at all (a legitimate case), and 0 otherwise.
int LanguageModel::SetTopParentLowerUpperDigit(
    LanguageModelState *parent_node) const {
  if (parent_node == NULL) return -1;
  UNICHAR_ID top_id = INVALID_UNICHAR_ID;
  ViterbiStateEntry* top_lower = NULL;
  ViterbiStateEntry* top_upper = NULL;
  ViterbiStateEntry* top_digit = NULL;
  ViterbiStateEntry* top_choice = NULL;
  float lower_rating = 0.0f;
  float upper_rating = 0.0f;
  float digit_rating = 0.0f;
  float top_rating = 0.0f;
  const UNICHARSET &unicharset = dict_->getUnicharset();
  ViterbiStateEntry_IT vit(&parent_node->viterbi_state_entries);
  for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
    ViterbiStateEntry* vse = vit.data();
    // INVALID_UNICHAR_ID should be treated like a zero-width joiner, so scan
    // back to the real character if needed.
    ViterbiStateEntry* unichar_vse = vse;
    UNICHAR_ID unichar_id = unichar_vse->curr_b->unichar_id();
    float rating = unichar_vse->curr_b->rating();
    while (unichar_id == INVALID_UNICHAR_ID &&
           unichar_vse->parent_vse != NULL) {
      unichar_vse = unichar_vse->parent_vse;
      unichar_id = unichar_vse->curr_b->unichar_id();
      rating = unichar_vse->curr_b->rating();
    }
    if (unichar_id != INVALID_UNICHAR_ID) {
      if (unicharset.get_islower(unichar_id)) {
        if (top_lower == NULL || lower_rating > rating) {
          top_lower = vse;
          lower_rating = rating;
        }
      } else if (unicharset.get_isalpha(unichar_id)) {
        if (top_upper == NULL || upper_rating > rating) {
          top_upper = vse;
          upper_rating = rating;
        }
      } else if (unicharset.get_isdigit(unichar_id)) {
        if (top_digit == NULL || digit_rating > rating) {
          top_digit = vse;
          digit_rating = rating;
        }
      }
    }
    if (top_choice == NULL || top_rating > rating) {
      top_choice = vse;
      top_rating = rating;
      top_id = unichar_id;
    }
  }
  if (top_choice == NULL) return -1;
  bool mixed = (top_lower != NULL || top_upper != NULL) &&
      top_digit != NULL;
  if (top_lower == NULL) top_lower = top_choice;
  top_lower->top_choice_flags |= kLowerCaseFlag;
  if (top_upper == NULL) top_upper = top_choice;
  top_upper->top_choice_flags |= kUpperCaseFlag;
  if (top_digit == NULL) top_digit = top_choice;
  top_digit->top_choice_flags |= kDigitFlag;
  top_choice->top_choice_flags |= kSmallestRatingFlag;
  if (top_id != INVALID_UNICHAR_ID && dict_->compound_marker(top_id) &&
      (top_choice->top_choice_flags &
       (kLowerCaseFlag | kUpperCaseFlag | kDigitFlag))) {
    // If the compound marker top choice carries any of the top alnum flags,
    // then give it all of them, allowing words like I-295 to be chosen.
    top_choice->top_choice_flags |=
        kLowerCaseFlag | kUpperCaseFlag | kDigitFlag;
  }
  return mixed ? 1 : 0;
}

// Finds the next ViterbiStateEntry with which the given unichar_id can
// combine sensibly, taking into account any mixed alnum/mixed case
// situation, and whether this combination has been inspected before.
ViterbiStateEntry* LanguageModel::GetNextParentVSE(
    bool just_classified, bool mixed_alnum, const BLOB_CHOICE* bc,
    LanguageModelFlagsType blob_choice_flags, const UNICHARSET& unicharset,
    WERD_RES* word_res, ViterbiStateEntry_IT* vse_it,
    LanguageModelFlagsType* top_choice_flags) const {
  for (; !vse_it->cycled_list(); vse_it->forward()) {
    ViterbiStateEntry* parent_vse = vse_it->data();
    // Only consider the parent if it has been updated or
    // if the current ratings cell has just been classified.
    if (!just_classified && !parent_vse->updated) continue;
    if (language_model_debug_level > 2)
      parent_vse->Print("Considering");
    // If the parent is non-alnum, then upper counts as lower.
    *top_choice_flags = blob_choice_flags;
    if ((blob_choice_flags & kUpperCaseFlag) &&
        !parent_vse->HasAlnumChoice(unicharset)) {
      *top_choice_flags |= kLowerCaseFlag;
    }
    *top_choice_flags &= parent_vse->top_choice_flags;
    UNICHAR_ID unichar_id = bc->unichar_id();
    const BLOB_CHOICE* parent_b = parent_vse->curr_b;
    UNICHAR_ID parent_id = parent_b->unichar_id();
    // Digits do not bind to alphas if there is a mix in both parent and
    // current or if the alpha is not the top choice.
    if (unicharset.get_isdigit(unichar_id) &&
        unicharset.get_isalpha(parent_id) &&
        (mixed_alnum || *top_choice_flags == 0))
      continue;  // Digits don't bind to alphas.
    // Likewise alphas do not bind to digits if there is a mix in both or if
    // the digit is not the top choice.
    if (unicharset.get_isalpha(unichar_id) &&
        unicharset.get_isdigit(parent_id) &&
        (mixed_alnum || *top_choice_flags == 0))
      continue;  // Alphas don't bind to digits.
    // If there is a case mix of the same alpha in the parent list, then
    // competing_vse is non-null and will be used to determine whether
    // or not to bind the current blob choice.
    if (parent_vse->competing_vse != NULL) {
      const BLOB_CHOICE* competing_b = parent_vse->competing_vse->curr_b;
      UNICHAR_ID other_id = competing_b->unichar_id();
      if (language_model_debug_level >= 5) {
        tprintf("Parent %s has competition %s\n",
                unicharset.id_to_unichar(parent_id),
                unicharset.id_to_unichar(other_id));
      }
      if (unicharset.SizesDistinct(parent_id, other_id)) {
        // If other_id matches bc wrt position and size, and parent_id
        // doesn't, don't bind to the current parent.
        if (bc->PosAndSizeAgree(*competing_b, word_res->x_height,
                                language_model_debug_level >= 5) &&
            !bc->PosAndSizeAgree(*parent_b, word_res->x_height,
                                 language_model_debug_level >= 5))
          continue;  // Competing blobchoice has a better vertical match.
      }
    }
    vse_it->forward();
    return parent_vse;  // This one is good!
  }
  return NULL;  // Ran out of possibilities.
}

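// Creates a ViterbiStateEntry for BLOB_CHOICE b on top of parent_vse (NULL
// for the first column), consulting the Dawg and ngram language model
// components, the path consistency checks and the beam limits.
// Returns true if an entry was actually added to curr_state.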
bool LanguageModel::AddViterbiStateEntry(
    LanguageModelFlagsType top_choice_flags,
    float denom,
    bool word_end,
    int curr_col, int curr_row,
    BLOB_CHOICE *b,
    LanguageModelState *curr_state,
    ViterbiStateEntry *parent_vse,
    LMPainPoints *pain_points,
    WERD_RES *word_res,
    BestChoiceBundle *best_choice_bundle,
    BlamerBundle *blamer_bundle) {
  ViterbiStateEntry_IT vit;
  if (language_model_debug_level > 1) {
    tprintf("AddViterbiStateEntry for unichar %s rating=%.4f"
            " certainty=%.4f top_choice_flags=0x%x",
            dict_->getUnicharset().id_to_unichar(b->unichar_id()),
            b->rating(), b->certainty(), top_choice_flags);
    if (language_model_debug_level > 5)
      tprintf(" parent_vse=%p\n", parent_vse);
    else
      tprintf("\n");
  }
  // Check whether the list is full.
  if (curr_state != NULL &&
      curr_state->viterbi_state_entries_length >=
          language_model_viterbi_list_max_size) {
    if (language_model_debug_level > 1) {
      tprintf("AddViterbiStateEntry: viterbi list is full!\n");
    }
    return false;
  }

  // Invoke Dawg language model component.
  LanguageModelDawgInfo *dawg_info =
      GenerateDawgInfo(word_end, curr_col, curr_row, *b, parent_vse);

  float outline_length =
      AssociateUtils::ComputeOutlineLength(rating_cert_scale_, *b);
  // Invoke Ngram language model component.
  LanguageModelNgramInfo *ngram_info = NULL;
  if (language_model_ngram_on) {
    ngram_info = GenerateNgramInfo(
        dict_->getUnicharset().id_to_unichar(b->unichar_id()), b->certainty(),
        denom, curr_col, curr_row, outline_length, parent_vse);
    ASSERT_HOST(ngram_info != NULL);
  }
  bool liked_by_language_model = dawg_info != NULL ||
      (ngram_info != NULL && !ngram_info->pruned);
  // Quick escape if not liked by the language model, can't be consistent
  // xheight, and not top choice.
  if (!liked_by_language_model && top_choice_flags == 0) {
    if (language_model_debug_level > 1) {
      tprintf("Language model components very early pruned this entry\n");
    }
    delete ngram_info;
    delete dawg_info;
    return false;
  }

  // Check consistency of the path and set the relevant consistency_info.
  LMConsistencyInfo consistency_info(
      parent_vse != NULL ? &parent_vse->consistency_info : NULL);
  // Start with just the x-height consistency, as it provides significant
  // pruning opportunity.
  consistency_info.ComputeXheightConsistency(
      b, dict_->getUnicharset().get_ispunctuation(b->unichar_id()));
  // Turn off xheight consistent flag if not consistent.
  if (consistency_info.InconsistentXHeight()) {
    top_choice_flags &= ~kXhtConsistentFlag;
  }

  // Quick escape if not liked by the language model, not consistent xheight,
  // and not top choice.
  if (!liked_by_language_model && top_choice_flags == 0) {
    if (language_model_debug_level > 1) {
      tprintf("Language model components early pruned this entry\n");
    }
    delete ngram_info;
    delete dawg_info;
    return false;
  }

  // Compute the rest of the consistency info.
  FillConsistencyInfo(curr_col, word_end, b, parent_vse,
                      word_res, &consistency_info);
  if (dawg_info != NULL && consistency_info.invalid_punc) {
    consistency_info.invalid_punc = false;  // do not penalize dict words
  }

  // Compute cost of associating the blobs that represent the current unichar.
  AssociateStats associate_stats;
  ComputeAssociateStats(curr_col, curr_row, max_char_wh_ratio_,
                        parent_vse, word_res, &associate_stats);
  if (parent_vse != NULL) {
    associate_stats.shape_cost += parent_vse->associate_stats.shape_cost;
    associate_stats.bad_shape |= parent_vse->associate_stats.bad_shape;
  }

  // Create the new ViterbiStateEntry and compute the adjusted cost of the
  // path.
  ViterbiStateEntry *new_vse = new ViterbiStateEntry(
      parent_vse, b, 0.0, outline_length,
      consistency_info, associate_stats, top_choice_flags, dawg_info,
      ngram_info, (language_model_debug_level > 0) ?
          dict_->getUnicharset().id_to_unichar(b->unichar_id()) : NULL);
  new_vse->cost = ComputeAdjustedPathCost(new_vse);

  // Invoke Top Choice language model component to make the final adjustments
  // to new_vse->top_choice_flags.
  if (!curr_state->viterbi_state_entries.empty() && new_vse->top_choice_flags) {
    GenerateTopChoiceInfo(new_vse, parent_vse, curr_state);
  }

  // If language model components did not like this unichar - return.
  bool keep = new_vse->top_choice_flags || liked_by_language_model;
  if (!(top_choice_flags & kSmallestRatingFlag) &&  // no non-top choice paths
      consistency_info.inconsistent_script) {       // with inconsistent script
    keep = false;
  }
  if (!keep) {
    if (language_model_debug_level > 1) {
      tprintf("Language model components did not like this entry\n");
    }
    delete new_vse;
    return false;
  }

  // Discard this entry if it represents a prunable path and
  // language_model_viterbi_list_max_num_prunable such entries with a lower
  // cost have already been recorded.
  if (PrunablePath(*new_vse) &&
      (curr_state->viterbi_state_entries_prunable_length >=
       language_model_viterbi_list_max_num_prunable) &&
      new_vse->cost >= curr_state->viterbi_state_entries_prunable_max_cost) {
    if (language_model_debug_level > 1) {
      tprintf("Discarded ViterbiEntry with high cost %g max cost %g\n",
              new_vse->cost,
              curr_state->viterbi_state_entries_prunable_max_cost);
    }
    delete new_vse;
    return false;
  }

  // Update best choice if needed.
  if (word_end) {
    UpdateBestChoice(new_vse, pain_points, word_res,
                     best_choice_bundle, blamer_bundle);
    // Discard the entry if UpdateBestChoice() found flaws in it.
    if (new_vse->cost >= WERD_CHOICE::kBadRating &&
        new_vse != best_choice_bundle->best_vse) {
      if (language_model_debug_level > 1) {
        tprintf("Discarded ViterbiEntry with high cost %g\n", new_vse->cost);
      }
      delete new_vse;
      return false;
    }
  }

  // Add the new ViterbiStateEntry to curr_state->viterbi_state_entries.
  curr_state->viterbi_state_entries.add_sorted(ViterbiStateEntry::Compare,
                                               false, new_vse);
  curr_state->viterbi_state_entries_length++;
  if (PrunablePath(*new_vse)) {
    curr_state->viterbi_state_entries_prunable_length++;
  }

  // Update lms->viterbi_state_entries_prunable_max_cost and clear
  // top_choice_flags of entries with cost higher than new_vse->cost.
  if ((curr_state->viterbi_state_entries_prunable_length >=
       language_model_viterbi_list_max_num_prunable) ||
      new_vse->top_choice_flags) {
    ASSERT_HOST(!curr_state->viterbi_state_entries.empty());
    int prunable_counter = language_model_viterbi_list_max_num_prunable;
    vit.set_to_list(&(curr_state->viterbi_state_entries));
    for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
      ViterbiStateEntry *curr_vse = vit.data();
      // Clear the appropriate top choice flags of the entries in the
      // list that have cost higher than new_vse->cost
      // (since they will not be top choices any more).
      if (curr_vse->top_choice_flags && curr_vse != new_vse &&
          curr_vse->cost > new_vse->cost) {
        curr_vse->top_choice_flags &= ~(new_vse->top_choice_flags);
      }
      if (prunable_counter > 0 && PrunablePath(*curr_vse)) --prunable_counter;
      // Update curr_state->viterbi_state_entries_prunable_max_cost.
      if (prunable_counter == 0) {
        curr_state->viterbi_state_entries_prunable_max_cost = vit.data()->cost;
        if (language_model_debug_level > 1) {
          tprintf("Set viterbi_state_entries_prunable_max_cost to %g\n",
                  curr_state->viterbi_state_entries_prunable_max_cost);
        }
        prunable_counter = -1;  // stop counting
      }
    }
  }

  // Print the newly created ViterbiStateEntry.
  if (language_model_debug_level > 2) {
    new_vse->Print("New");
    if (language_model_debug_level > 5)
      curr_state->Print("Updated viterbi list");
  }

  return true;
}

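// Clears from new_vse->top_choice_flags any flags already claimed by an
// entry in lms->viterbi_state_entries that is at least as cheap, so each
// top-choice flag survives only on the lowest-cost path that carries it.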
void LanguageModel::GenerateTopChoiceInfo(ViterbiStateEntry *new_vse,
                                          const ViterbiStateEntry *parent_vse,
                                          LanguageModelState *lms) {
  ViterbiStateEntry_IT vit(&(lms->viterbi_state_entries));
  for (vit.mark_cycle_pt(); !vit.cycled_list() && new_vse->top_choice_flags &&
       new_vse->cost >= vit.data()->cost; vit.forward()) {
    // Clear the appropriate flags if the list already contains
    // a top choice entry with a lower cost.
    new_vse->top_choice_flags &= ~(vit.data()->top_choice_flags);
  }
  if (language_model_debug_level > 2) {
    tprintf("GenerateTopChoiceInfo: top_choice_flags=0x%x\n",
            new_vse->top_choice_flags);
  }
}

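// Runs the DAWG component of the language model: advances the active dawgs
// of parent_vse (or the beginning-of-word dawgs if parent_vse is NULL) by
// the normalized unichar ids of b, handling hyphenated and compound words.
// Returns a new LanguageModelDawgInfo if some dawg still accepts the path,
// or NULL if the path is not a valid dictionary word prefix.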
LanguageModelDawgInfo *LanguageModel::GenerateDawgInfo(
    bool word_end,
    int curr_col, int curr_row,
    const BLOB_CHOICE &b,
    const ViterbiStateEntry *parent_vse) {
  // Initialize active_dawgs from parent_vse if it is not NULL.
  // Otherwise use very_beginning_active_dawgs_.
  if (parent_vse == NULL) {
    dawg_args_->active_dawgs = very_beginning_active_dawgs_;
    dawg_args_->permuter = NO_PERM;
  } else {
    if (parent_vse->dawg_info == NULL) return NULL;  // not a dict word path
    dawg_args_->active_dawgs = parent_vse->dawg_info->active_dawgs;
    dawg_args_->permuter = parent_vse->dawg_info->permuter;
  }

  // Deal with hyphenated words.
  if (word_end && dict_->has_hyphen_end(b.unichar_id(), curr_col == 0)) {
    if (language_model_debug_level > 0) tprintf("Hyphenated word found\n");
    return new LanguageModelDawgInfo(dawg_args_->active_dawgs,
                                     COMPOUND_PERM);
  }

  // Deal with compound words.
  if (dict_->compound_marker(b.unichar_id()) &&
      (parent_vse == NULL || parent_vse->dawg_info->permuter != NUMBER_PERM)) {
    if (language_model_debug_level > 0) tprintf("Found compound marker\n");
    // Do not allow compound operators at the beginning and end of the word.
    // Do not allow more than one compound operator per word.
    // Do not allow compounding of words with lengths shorter than
    // language_model_min_compound_length.
    if (parent_vse == NULL || word_end ||
        dawg_args_->permuter == COMPOUND_PERM ||
        parent_vse->length < language_model_min_compound_length) return NULL;

    int i;
    // Check that the path terminated before the current character is a word.
    bool has_word_ending = false;
    for (i = 0; i < parent_vse->dawg_info->active_dawgs->size(); ++i) {
      const DawgPosition &pos = (*parent_vse->dawg_info->active_dawgs)[i];
      const Dawg *pdawg = pos.dawg_index < 0
          ? NULL : dict_->GetDawg(pos.dawg_index);
      if (pdawg == NULL || pos.back_to_punc) continue;
      if (pdawg->type() == DAWG_TYPE_WORD && pos.dawg_ref != NO_EDGE &&
          pdawg->end_of_word(pos.dawg_ref)) {
        has_word_ending = true;
        break;
      }
    }
    if (!has_word_ending) return NULL;

    if (language_model_debug_level > 0) tprintf("Compound word found\n");
    return new LanguageModelDawgInfo(beginning_active_dawgs_, COMPOUND_PERM);
  }  // done dealing with compound words

  LanguageModelDawgInfo *dawg_info = NULL;

  // Call LetterIsOkay().
  // Use the normalized IDs so that all shapes of ' can be allowed in words
  // like don't.
  const GenericVector<UNICHAR_ID>& normed_ids =
      dict_->getUnicharset().normed_ids(b.unichar_id());
  DawgPositionVector tmp_active_dawgs;
  for (int i = 0; i < normed_ids.size(); ++i) {
    if (language_model_debug_level > 2)
      tprintf("Test Letter OK for unichar %d, normed %d\n",
              b.unichar_id(), normed_ids[i]);
    dict_->LetterIsOkay(dawg_args_, normed_ids[i],
                        word_end && i == normed_ids.size() - 1);
    if (dawg_args_->permuter == NO_PERM) {
      break;
    } else if (i < normed_ids.size() - 1) {
      tmp_active_dawgs = *dawg_args_->updated_dawgs;
      dawg_args_->active_dawgs = &tmp_active_dawgs;
    }
    if (language_model_debug_level > 2)
      tprintf("Letter was OK for unichar %d, normed %d\n",
              b.unichar_id(), normed_ids[i]);
  }
  dawg_args_->active_dawgs = NULL;
  if (dawg_args_->permuter != NO_PERM) {
    dawg_info = new LanguageModelDawgInfo(dawg_args_->updated_dawgs,
                                          dawg_args_->permuter);
  } else if (language_model_debug_level > 3) {
    tprintf("Letter %s not OK!\n",
            dict_->getUnicharset().id_to_unichar(b.unichar_id()));
  }

  return dawg_info;
}

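// Runs the character ngram component of the language model: computes the
// cost of appending unichar to the parent's context (prev_word_str_ when
// parent_vse is NULL), accumulates the costs along the path and trims the
// stored context to at most language_model_ngram_order unichars.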
LanguageModelNgramInfo *LanguageModel::GenerateNgramInfo(
    const char *unichar, float certainty, float denom,
    int curr_col, int curr_row, float outline_length,
    const ViterbiStateEntry *parent_vse) {
  // Initialize parent context.
  const char *pcontext_ptr = "";
  int pcontext_unichar_step_len = 0;
  if (parent_vse == NULL) {
    pcontext_ptr = prev_word_str_.string();
    pcontext_unichar_step_len = prev_word_unichar_step_len_;
  } else {
    pcontext_ptr = parent_vse->ngram_info->context.string();
    pcontext_unichar_step_len =
        parent_vse->ngram_info->context_unichar_step_len;
  }
  // Compute p(unichar | parent context).
  int unichar_step_len = 0;
  bool pruned = false;
  float ngram_cost;
  float ngram_and_classifier_cost =
      ComputeNgramCost(unichar, certainty, denom,
                       pcontext_ptr, &unichar_step_len,
                       &pruned, &ngram_cost);
  // Normalize just the ngram_and_classifier_cost by outline_length.
  // The ngram_cost is used by the params_model, so it needs to be left as-is,
  // and the params model cost will be normalized by outline_length.
  ngram_and_classifier_cost *=
      outline_length / language_model_ngram_rating_factor;
  // Add the ngram_cost of the parent.
  if (parent_vse != NULL) {
    ngram_and_classifier_cost +=
        parent_vse->ngram_info->ngram_and_classifier_cost;
    ngram_cost += parent_vse->ngram_info->ngram_cost;
  }

  // Shorten parent context string by unichar_step_len unichars.
  int num_remove = (unichar_step_len + pcontext_unichar_step_len -
                    language_model_ngram_order);
  if (num_remove > 0) pcontext_unichar_step_len -= num_remove;
  while (num_remove > 0 && *pcontext_ptr != '\0') {
    pcontext_ptr += UNICHAR::utf8_step(pcontext_ptr);
    --num_remove;
  }

  // Decide whether to prune this ngram path and update pruned accordingly.
  if (parent_vse != NULL && parent_vse->ngram_info->pruned) pruned = true;

  // Construct and return the new LanguageModelNgramInfo.
  LanguageModelNgramInfo *ngram_info = new LanguageModelNgramInfo(
      pcontext_ptr, pcontext_unichar_step_len, pruned, ngram_cost,
      ngram_and_classifier_cost);
  ngram_info->context += unichar;
  ngram_info->context_unichar_step_len += unichar_step_len;
  assert(ngram_info->context_unichar_step_len <= language_model_ngram_order);
  return ngram_info;
}

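// Computes the two ngram costs for appending unichar to context:
//   *ngram_cost = -log2(P(unichar | context)), averaged over the UTF8
//   steps of unichar and floored at language_model_ngram_small_prob,
// and the returned combined cost
//   -log2(CertaintyScore(certainty) / denom)
//       + *ngram_cost * language_model_ngram_scale_factor.
// *found_small_prob is set to true if the probability had to be floored.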
float LanguageModel::ComputeNgramCost(const char *unichar,
                                      float certainty,
                                      float denom,
                                      const char *context,
                                      int *unichar_step_len,
                                      bool *found_small_prob,
                                      float *ngram_cost) {
  const char *context_ptr = context;
  char *modified_context = NULL;
  char *modified_context_end = NULL;
  const char *unichar_ptr = unichar;
  const char *unichar_end = unichar_ptr + strlen(unichar_ptr);
  float prob = 0.0f;
  int step = 0;
  while (unichar_ptr < unichar_end &&
         (step = UNICHAR::utf8_step(unichar_ptr)) > 0) {
    if (language_model_debug_level > 1) {
      tprintf("prob(%s | %s)=%g\n", unichar_ptr, context_ptr,
              dict_->ProbabilityInContext(context_ptr, -1, unichar_ptr, step));
    }
    prob += dict_->ProbabilityInContext(context_ptr, -1, unichar_ptr, step);
    ++(*unichar_step_len);
    if (language_model_ngram_use_only_first_uft8_step) break;
    unichar_ptr += step;
    // If there are multiple UTF8 characters present in unichar, context is
    // updated to include the previously examined characters from str,
    // unless use_only_first_uft8_step is true.
    if (unichar_ptr < unichar_end) {
      if (modified_context == NULL) {
        int context_len = strlen(context);
        modified_context =
            new char[context_len + strlen(unichar_ptr) + step + 1];
        strncpy(modified_context, context, context_len);
        modified_context_end = modified_context + context_len;
        context_ptr = modified_context;
      }
      strncpy(modified_context_end, unichar_ptr - step, step);
      modified_context_end += step;
      *modified_context_end = '\0';
    }
  }
  prob /= static_cast<float>(*unichar_step_len);  // normalize
  if (prob < language_model_ngram_small_prob) {
    if (language_model_debug_level > 0) tprintf("Found small prob %g\n", prob);
    *found_small_prob = true;
    prob = language_model_ngram_small_prob;
  }
  *ngram_cost = -1.0 * log2(prob);
  float ngram_and_classifier_cost =
      -1.0 * log2(CertaintyScore(certainty) / denom) +
      *ngram_cost * language_model_ngram_scale_factor;
  if (language_model_debug_level > 1) {
    tprintf("-log [ p(%s) * p(%s | %s) ] = -log2(%g*%g) = %g\n", unichar,
            unichar, context_ptr, CertaintyScore(certainty) / denom, prob,
            ngram_and_classifier_cost);
  }
  if (modified_context != NULL) delete[] modified_context;
  return ngram_and_classifier_cost;
}

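// Computes the normalizing denominator for the ngram model: the sum of
// CertaintyScore() over the choices in curr_list, plus a crude estimate
// for the unicharset entries that are missing from the list.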
float LanguageModel::ComputeDenom(BLOB_CHOICE_LIST *curr_list) {
  if (curr_list->empty()) return 1.0f;
  float denom = 0.0f;
  int len = 0;
  BLOB_CHOICE_IT c_it(curr_list);
  for (c_it.mark_cycle_pt(); !c_it.cycled_list(); c_it.forward()) {
    ASSERT_HOST(c_it.data() != NULL);
    ++len;
    denom += CertaintyScore(c_it.data()->certainty());
  }
  assert(len != 0);
  // The ideal situation would be to have the classifier scores for
  // classifying each position as each of the characters in the unicharset.
  // Since we cannot do this because of speed, we add a very crude estimate
  // of what these scores for the "missing" classifications would sum up to.
  denom += (dict_->getUnicharset().size() - len) *
      CertaintyScore(language_model_ngram_nonmatch_score);

  return denom;
}

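// Fills the remaining fields of consistency_info for the path through b:
// punctuation validity against the punctuation DAWG, case, script,
// character type, and font/spacing agreement with the parent blob choice.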
void LanguageModel::FillConsistencyInfo(
    int curr_col,
    bool word_end,
    BLOB_CHOICE *b,
    ViterbiStateEntry *parent_vse,
    WERD_RES *word_res,
    LMConsistencyInfo *consistency_info) {
  const UNICHARSET &unicharset = dict_->getUnicharset();
  UNICHAR_ID unichar_id = b->unichar_id();
  BLOB_CHOICE* parent_b = parent_vse != NULL ? parent_vse->curr_b : NULL;

  // Check punctuation validity.
  if (unicharset.get_ispunctuation(unichar_id)) consistency_info->num_punc++;
  if (dict_->GetPuncDawg() != NULL && !consistency_info->invalid_punc) {
    if (dict_->compound_marker(unichar_id) && parent_b != NULL &&
        (unicharset.get_isalpha(parent_b->unichar_id()) ||
         unicharset.get_isdigit(parent_b->unichar_id()))) {
      // reset punc_ref for compound words
      consistency_info->punc_ref = NO_EDGE;
    } else {
      bool is_apos = dict_->is_apostrophe(unichar_id);
      bool prev_is_numalpha = (parent_b != NULL &&
          (unicharset.get_isalpha(parent_b->unichar_id()) ||
           unicharset.get_isdigit(parent_b->unichar_id())));
      UNICHAR_ID pattern_unichar_id =
          (unicharset.get_isalpha(unichar_id) ||
           unicharset.get_isdigit(unichar_id) ||
           (is_apos && prev_is_numalpha)) ?
          Dawg::kPatternUnicharID : unichar_id;
      if (consistency_info->punc_ref == NO_EDGE ||
          pattern_unichar_id != Dawg::kPatternUnicharID ||
          dict_->GetPuncDawg()->edge_letter(consistency_info->punc_ref) !=
              Dawg::kPatternUnicharID) {
        NODE_REF node = Dict::GetStartingNode(dict_->GetPuncDawg(),
                                              consistency_info->punc_ref);
        consistency_info->punc_ref =
            (node != NO_EDGE) ? dict_->GetPuncDawg()->edge_char_of(
                node, pattern_unichar_id, word_end) : NO_EDGE;
        if (consistency_info->punc_ref == NO_EDGE) {
          consistency_info->invalid_punc = true;
        }
      }
    }
  }

  // Update case related counters.
  if (parent_vse != NULL && !word_end && dict_->compound_marker(unichar_id)) {
    // Reset counters if we are dealing with a compound word.
    consistency_info->num_lower = 0;
    consistency_info->num_non_first_upper = 0;
  } else if (unicharset.get_islower(unichar_id)) {
    consistency_info->num_lower++;
  } else if ((parent_b != NULL) && unicharset.get_isupper(unichar_id)) {
    if (unicharset.get_isupper(parent_b->unichar_id()) ||
        consistency_info->num_lower > 0 ||
        consistency_info->num_non_first_upper > 0) {
      consistency_info->num_non_first_upper++;
    }
  }

  // Initialize consistency_info->script_id (use script of unichar_id
  // if it is not Common, use script id recorded by the parent otherwise).
  // Set inconsistent_script to true if the script of the current unichar
  // is not consistent with that of the parent.
  consistency_info->script_id = unicharset.get_script(unichar_id);
  // Hiragana and Katakana can mix with Han.
  if (dict_->getUnicharset().han_sid() != dict_->getUnicharset().null_sid()) {
    if ((unicharset.hiragana_sid() != unicharset.null_sid() &&
         consistency_info->script_id == unicharset.hiragana_sid()) ||
        (unicharset.katakana_sid() != unicharset.null_sid() &&
         consistency_info->script_id == unicharset.katakana_sid())) {
      consistency_info->script_id = dict_->getUnicharset().han_sid();
    }
  }

  if (parent_vse != NULL &&
      (parent_vse->consistency_info.script_id !=
       dict_->getUnicharset().common_sid())) {
    int parent_script_id = parent_vse->consistency_info.script_id;
    // If script_id is Common, use script id of the parent instead.
    if (consistency_info->script_id == dict_->getUnicharset().common_sid()) {
      consistency_info->script_id = parent_script_id;
    }
    if (consistency_info->script_id != parent_script_id) {
      consistency_info->inconsistent_script = true;
    }
  }

  // Update chartype related counters.
  if (unicharset.get_isalpha(unichar_id)) {
    consistency_info->num_alphas++;
  } else if (unicharset.get_isdigit(unichar_id)) {
    consistency_info->num_digits++;
  } else if (!unicharset.get_ispunctuation(unichar_id)) {
    consistency_info->num_other++;
  }

  // Check font and spacing consistency.
  if (fontinfo_table_->size() > 0 && parent_b != NULL) {
    int fontinfo_id = -1;
    if (parent_b->fontinfo_id() == b->fontinfo_id() ||
        parent_b->fontinfo_id2() == b->fontinfo_id()) {
      fontinfo_id = b->fontinfo_id();
    } else if (parent_b->fontinfo_id() == b->fontinfo_id2() ||
               parent_b->fontinfo_id2() == b->fontinfo_id2()) {
      fontinfo_id = b->fontinfo_id2();
    }
    if (language_model_debug_level > 1) {
      tprintf("pfont %s pfont %s font %s font2 %s common %s(%d)\n",
              (parent_b->fontinfo_id() >= 0) ?
                  fontinfo_table_->get(parent_b->fontinfo_id()).name : "",
              (parent_b->fontinfo_id2() >= 0) ?
                  fontinfo_table_->get(parent_b->fontinfo_id2()).name : "",
              (b->fontinfo_id() >= 0) ?
                  fontinfo_table_->get(b->fontinfo_id()).name : "",
              (b->fontinfo_id2() >= 0) ?
                  fontinfo_table_->get(b->fontinfo_id2()).name : "",
              (fontinfo_id >= 0) ? fontinfo_table_->get(fontinfo_id).name : "",
              fontinfo_id);
    }
    if (!word_res->blob_widths.empty()) {  // if we have widths/gaps info
      bool expected_gap_found = false;
      float expected_gap;
      int temp_gap;
      if (fontinfo_id >= 0) {  // found a common font
        ASSERT_HOST(fontinfo_id < fontinfo_table_->size());
        if (fontinfo_table_->get(fontinfo_id).get_spacing(
            parent_b->unichar_id(), unichar_id, &temp_gap)) {
          expected_gap = temp_gap;
          expected_gap_found = true;
        }
      } else {
        consistency_info->inconsistent_font = true;
        // Get an average of the expected gaps in each font.
        int num_addends = 0;
        expected_gap = 0;
        int temp_fid;
        for (int i = 0; i < 4; ++i) {
          if (i == 0) {
            temp_fid = parent_b->fontinfo_id();
          } else if (i == 1) {
            temp_fid = parent_b->fontinfo_id2();
          } else if (i == 2) {
            temp_fid = b->fontinfo_id();
          } else {
            temp_fid = b->fontinfo_id2();
          }
          ASSERT_HOST(temp_fid < 0 || temp_fid < fontinfo_table_->size());
          if (temp_fid >= 0 && fontinfo_table_->get(temp_fid).get_spacing(
              parent_b->unichar_id(), unichar_id, &temp_gap)) {
            expected_gap += temp_gap;
            num_addends++;
          }
        }
        expected_gap_found = (num_addends > 0);
        if (num_addends > 0) {
          expected_gap /= static_cast<float>(num_addends);
        }
      }
      if (expected_gap_found) {
        float actual_gap =
            static_cast<float>(word_res->GetBlobsGap(curr_col - 1));
        float gap_ratio = expected_gap / actual_gap;
        // TODO(daria): find a good way to tune this heuristic estimate.
        if (gap_ratio < 0.5f || gap_ratio > 2.0f) {
          consistency_info->num_inconsistent_spaces++;
        }
        if (language_model_debug_level > 1) {
          tprintf("spacing for %s(%d) %s(%d) col %d: expected %g actual %g\n",
                  unicharset.id_to_unichar(parent_b->unichar_id()),
                  parent_b->unichar_id(), unicharset.id_to_unichar(unichar_id),
                  unichar_id, curr_col, expected_gap, actual_gap);
        }
      }
    }
  }
}

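// Computes the adjusted cost of the path ending at vse. With a trained
// params model the cost is params_model_.ComputeCost() over the features
// from ExtractFeaturesFromPath(), scaled by the outline length. Otherwise
// a heuristic multiplier is accumulated from the penalty parameters, e.g.
// a non-dictionary path gets
//   adjustment = 1.0 + language_model_penalty_non_freq_dict_word
//              + language_model_penalty_non_dict_word + ...
// and the returned cost is vse->ratings_sum * adjustment (or the
// accumulated ngram cost times the adjustment when the ngram model is on).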
float LanguageModel::ComputeAdjustedPathCost(ViterbiStateEntry *vse) {
  ASSERT_HOST(vse != NULL);
  if (params_model_.Initialized()) {
    float features[PTRAIN_NUM_FEATURE_TYPES];
    ExtractFeaturesFromPath(*vse, features);
    float cost = params_model_.ComputeCost(features);
    if (language_model_debug_level > 3) {
      tprintf("ComputeAdjustedPathCost %g ParamsModel features:\n", cost);
      if (language_model_debug_level >= 5) {
        for (int f = 0; f < PTRAIN_NUM_FEATURE_TYPES; ++f) {
          tprintf("%s=%g\n", kParamsTrainingFeatureTypeName[f], features[f]);
        }
      }
    }
    return cost * vse->outline_length;
  } else {
    float adjustment = 1.0f;
    if (vse->dawg_info == NULL || vse->dawg_info->permuter != FREQ_DAWG_PERM) {
      adjustment += language_model_penalty_non_freq_dict_word;
    }
    if (vse->dawg_info == NULL) {
      adjustment += language_model_penalty_non_dict_word;
      if (vse->length > language_model_min_compound_length) {
        adjustment += ((vse->length - language_model_min_compound_length) *
                       language_model_penalty_increment);
      }
    }
    if (vse->associate_stats.shape_cost > 0) {
      adjustment += vse->associate_stats.shape_cost /
          static_cast<float>(vse->length);
    }
    if (language_model_ngram_on) {
      ASSERT_HOST(vse->ngram_info != NULL);
      return vse->ngram_info->ngram_and_classifier_cost * adjustment;
    } else {
      adjustment += ComputeConsistencyAdjustment(vse->dawg_info,
                                                 vse->consistency_info);
      return vse->ratings_sum * adjustment;
    }
  }
}

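// Constructs a WERD_CHOICE for the path ending at vse, records its features
// for blaming/training if needed, and hands it to word_res as a candidate
// raw/cooked choice, updating best_choice_bundle and the hyphen state when
// it becomes the new best word.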
void LanguageModel::UpdateBestChoice(
    ViterbiStateEntry *vse,
    LMPainPoints *pain_points,
    WERD_RES *word_res,
    BestChoiceBundle *best_choice_bundle,
    BlamerBundle *blamer_bundle) {
  bool truth_path;
  WERD_CHOICE *word = ConstructWord(vse, word_res, &best_choice_bundle->fixpt,
                                    blamer_bundle, &truth_path);
  ASSERT_HOST(word != NULL);
  if (dict_->stopper_debug_level >= 1) {
    STRING word_str;
    word->string_and_lengths(&word_str, NULL);
    vse->Print(word_str.string());
  }
  if (language_model_debug_level > 0) {
    word->print("UpdateBestChoice() constructed word");
  }
  // Record features from the current path if necessary.
  ParamsTrainingHypothesis curr_hyp;
  if (blamer_bundle != NULL) {
    if (vse->dawg_info != NULL) vse->dawg_info->permuter =
        static_cast<PermuterType>(word->permuter());
    ExtractFeaturesFromPath(*vse, curr_hyp.features);
    word->string_and_lengths(&(curr_hyp.str), NULL);
    curr_hyp.cost = vse->cost;  // record cost for error rate computations
    if (language_model_debug_level > 0) {
      tprintf("Raw features extracted from %s (cost=%g) [ ",
              curr_hyp.str.string(), curr_hyp.cost);
      for (int deb_i = 0; deb_i < PTRAIN_NUM_FEATURE_TYPES; ++deb_i) {
        tprintf("%g ", curr_hyp.features[deb_i]);
      }
      tprintf("]\n");
    }
    // Record the current hypothesis in params_training_bundle.
    blamer_bundle->AddHypothesis(curr_hyp);
    if (truth_path)
      blamer_bundle->UpdateBestRating(word->rating());
  }
  if (blamer_bundle != NULL && blamer_bundle->GuidedSegsearchStillGoing()) {
    // The word was constructed solely for blamer_bundle->AddHypothesis, so
    // we no longer need it.
    delete word;
    return;
  }
  if (word_res->chopped_word != NULL && !word_res->chopped_word->blobs.empty())
    word->SetScriptPositions(false, word_res->chopped_word);
  // Update and log new raw_choice if needed.
  if (word_res->raw_choice == NULL ||
      word->rating() < word_res->raw_choice->rating()) {
    if (word_res->LogNewRawChoice(word) && language_model_debug_level > 0)
      tprintf("Updated raw choice\n");
  }
  // Set the modified rating for best choice to vse->cost and log best choice.
  word->set_rating(vse->cost);
  // Call LogNewChoice() for best choice from Dict::adjust_word() since it
  // computes adjust_factor that is used by the adaption code (e.g. by
  // ClassifyAdaptableWord() to compute adaption acceptance thresholds).
  // Note: the rating of the word is not adjusted.
  dict_->adjust_word(word, vse->dawg_info == NULL,
                     vse->consistency_info.xht_decision, 0.0,
                     false, language_model_debug_level > 0);
  // Hand ownership of the word over to the word_res.
  if (!word_res->LogNewCookedChoice(dict_->tessedit_truncate_wordchoice_log,
                                    dict_->stopper_debug_level >= 1, word)) {
    // The word was so bad that it was deleted.
    return;
  }
  if (word_res->best_choice == word) {
    // Word was the new best.
    if (dict_->AcceptableChoice(*word, vse->consistency_info.xht_decision) &&
        AcceptablePath(*vse)) {
      acceptable_choice_found_ = true;
    }
    // Update best_choice_bundle.
    best_choice_bundle->updated = true;
    best_choice_bundle->best_vse = vse;
    if (language_model_debug_level > 0) {
      tprintf("Updated best choice\n");
      word->print_state("New state ");
    }
    // Update hyphen state if we are dealing with a dictionary word.
    if (vse->dawg_info != NULL) {
      if (dict_->has_hyphen_end(*word)) {
        dict_->set_hyphen_word(*word, *(dawg_args_->active_dawgs));
      } else {
        dict_->reset_hyphen_vars(true);
      }
    }

    if (blamer_bundle != NULL) {
      blamer_bundle->set_best_choice_is_dict_and_top_choice(
          vse->dawg_info != NULL && vse->top_choice_flags);
    }
  }
  if (wordrec_display_segmentations) {
    word->DisplaySegmentation(word_res->chopped_word);
  }
}

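// Extracts the path's features for the params training model: dictionary
// match type bucketed by word length, per-character shape and ngram costs,
// consistency counts, and the classifier rating per unit of outline length.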
void LanguageModel::ExtractFeaturesFromPath(
    const ViterbiStateEntry &vse, float features[]) {
  memset(features, 0, sizeof(float) * PTRAIN_NUM_FEATURE_TYPES);
  // Record dictionary match info.
  int len = vse.length <= kMaxSmallWordUnichars ? 0 :
      vse.length <= kMaxMediumWordUnichars ? 1 : 2;
  if (vse.dawg_info != NULL) {
    int permuter = vse.dawg_info->permuter;
    if (permuter == NUMBER_PERM || permuter == USER_PATTERN_PERM) {
      if (vse.consistency_info.num_digits == vse.length) {
        features[PTRAIN_DIGITS_SHORT + len] = 1.0;
      } else {
        features[PTRAIN_NUM_SHORT + len] = 1.0;
      }
    } else if (permuter == DOC_DAWG_PERM) {
      features[PTRAIN_DOC_SHORT + len] = 1.0;
    } else if (permuter == SYSTEM_DAWG_PERM || permuter == USER_DAWG_PERM ||
               permuter == COMPOUND_PERM) {
      features[PTRAIN_DICT_SHORT + len] = 1.0;
    } else if (permuter == FREQ_DAWG_PERM) {
      features[PTRAIN_FREQ_SHORT + len] = 1.0;
    }
  }
  // Record shape cost feature (normalized by path length).
  features[PTRAIN_SHAPE_COST_PER_CHAR] =
      vse.associate_stats.shape_cost / static_cast<float>(vse.length);
  // Record ngram cost (normalized by the path length).
  features[PTRAIN_NGRAM_COST_PER_CHAR] = 0.0;
  if (vse.ngram_info != NULL) {
    features[PTRAIN_NGRAM_COST_PER_CHAR] =
        vse.ngram_info->ngram_cost / static_cast<float>(vse.length);
  }
  // Record consistency-related features.
  // Disabled this feature due to its poor performance.
  // features[PTRAIN_NUM_BAD_PUNC] = vse.consistency_info.NumInconsistentPunc();
  features[PTRAIN_NUM_BAD_CASE] = vse.consistency_info.NumInconsistentCase();
  features[PTRAIN_XHEIGHT_CONSISTENCY] = vse.consistency_info.xht_decision;
  features[PTRAIN_NUM_BAD_CHAR_TYPE] = vse.dawg_info == NULL ?
      vse.consistency_info.NumInconsistentChartype() : 0.0;
  features[PTRAIN_NUM_BAD_SPACING] =
      vse.consistency_info.NumInconsistentSpaces();
  // Disabled this feature for now due to its poor performance.
  // features[PTRAIN_NUM_BAD_FONT] = vse.consistency_info.inconsistent_font;

  // Classifier-related features.
  features[PTRAIN_RATING_PER_CHAR] =
      vse.ratings_sum / static_cast<float>(vse.outline_length);
}

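// Builds a WERD_CHOICE for the path ending at vse by tracing parent
// pointers, re-computing the width-to-height ratio variance on the way, and
// sets the rating, certainty, x-height bounds and permuter of the result.
// If truth_path is non-NULL, it is set to whether the path matches the
// blamer's correct segmentation.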
WERD_CHOICE *LanguageModel::ConstructWord(
    ViterbiStateEntry *vse,
    WERD_RES *word_res,
    DANGERR *fixpt,
    BlamerBundle *blamer_bundle,
    bool *truth_path) {
  if (truth_path != NULL) {
    *truth_path =
        (blamer_bundle != NULL &&
         vse->length == blamer_bundle->correct_segmentation_length());
  }
  BLOB_CHOICE *curr_b = vse->curr_b;
  ViterbiStateEntry *curr_vse = vse;

  int i;
  bool compound = dict_->hyphenated();  // treat hyphenated words as compound

  // Re-compute the variance of the width-to-height ratios (since we now
  // can compute the mean over the whole word).
  float full_wh_ratio_mean = 0.0f;
  if (vse->associate_stats.full_wh_ratio_var != 0.0f) {
    vse->associate_stats.shape_cost -= vse->associate_stats.full_wh_ratio_var;
    full_wh_ratio_mean = (vse->associate_stats.full_wh_ratio_total /
                          static_cast<float>(vse->length));
    vse->associate_stats.full_wh_ratio_var = 0.0f;
  }

  // Construct a WERD_CHOICE by tracing parent pointers.
  WERD_CHOICE *word = new WERD_CHOICE(word_res->uch_set, vse->length);
  word->set_length(vse->length);
  int total_blobs = 0;
  for (i = (vse->length - 1); i >= 0; --i) {
    if (blamer_bundle != NULL && truth_path != NULL && *truth_path &&
        !blamer_bundle->MatrixPositionCorrect(i, curr_b->matrix_cell())) {
      *truth_path = false;
    }
    // The number of blobs used for this choice is row - col + 1.
    int num_blobs = curr_b->matrix_cell().row - curr_b->matrix_cell().col + 1;
    total_blobs += num_blobs;
    word->set_blob_choice(i, num_blobs, curr_b);
    // Update the width-to-height ratio variance. Useful for non-space
    // delimited languages to ensure that the blobs are of uniform width.
    // Skip leading and trailing punctuation when computing the variance.
    if ((full_wh_ratio_mean != 0.0f &&
         ((curr_vse != vse && curr_vse->parent_vse != NULL) ||
          !dict_->getUnicharset().get_ispunctuation(curr_b->unichar_id())))) {
      vse->associate_stats.full_wh_ratio_var +=
          pow(full_wh_ratio_mean - curr_vse->associate_stats.full_wh_ratio, 2);
      if (language_model_debug_level > 2) {
        tprintf("full_wh_ratio_var += (%g-%g)^2\n",
                full_wh_ratio_mean, curr_vse->associate_stats.full_wh_ratio);
      }
    }

    // Mark the word as compound if compound permuter was set for any of
    // the unichars on the path (usually this will happen for unichars
    // that are compounding operators, like "-" and "/").
    if (!compound && curr_vse->dawg_info &&
        curr_vse->dawg_info->permuter == COMPOUND_PERM) compound = true;

    // Update curr_* pointers.
    curr_vse = curr_vse->parent_vse;
    if (curr_vse == NULL) break;
    curr_b = curr_vse->curr_b;
  }
  ASSERT_HOST(i == 0);  // check that we recorded all the unichar ids
  ASSERT_HOST(total_blobs == word_res->ratings->dimension());
  // Re-adjust shape cost to include the updated width-to-height variance.
  if (full_wh_ratio_mean != 0.0f) {
    vse->associate_stats.shape_cost += vse->associate_stats.full_wh_ratio_var;
  }

  word->set_rating(vse->ratings_sum);
  word->set_certainty(vse->min_certainty);
  word->set_x_heights(vse->consistency_info.BodyMinXHeight(),
                      vse->consistency_info.BodyMaxXHeight());
  if (vse->dawg_info != NULL) {
    word->set_permuter(compound ? COMPOUND_PERM : vse->dawg_info->permuter);
  } else if (language_model_ngram_on && !vse->ngram_info->pruned) {
    word->set_permuter(NGRAM_PERM);
  } else if (vse->top_choice_flags) {
    word->set_permuter(TOP_CHOICE_PERM);
  } else {
    word->set_permuter(NO_PERM);
  }
  word->set_dangerous_ambig_found_(!dict_->NoDangerousAmbig(word, fixpt, true,
                                                            word_res->ratings));
  return word;
}

}  // namespace tesseract
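// Rough usage sketch (assumed; not part of the original file). The
// segmentation search in wordrec drives this class approximately like:
//   language_model->InitForWord(prev_word, fixed_pitch,
//                               max_char_wh_ratio, rating_cert_scale);
//   for each classified ratings-matrix cell (col, row):
//     language_model->UpdateState(just_classified, col, row,
//                                 word_res->ratings->get(col, row),
//                                 parent state for col (NULL at word start),
//                                 pain_points, word_res,
//                                 best_choice_bundle, blamer_bundle);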