// errorcounter.cpp — Tesseract 3.03 (recovered from a generated source listing).
00001 // Copyright 2011 Google Inc. All Rights Reserved. 00002 // Author: rays@google.com (Ray Smith) 00003 // 00004 // Licensed under the Apache License, Version 2.0 (the "License"); 00005 // you may not use this file except in compliance with the License. 00006 // You may obtain a copy of the License at 00007 // http://www.apache.org/licenses/LICENSE-2.0 00008 // Unless required by applicable law or agreed to in writing, software 00009 // distributed under the License is distributed on an "AS IS" BASIS, 00010 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00011 // See the License for the specific language governing permissions and 00012 // limitations under the License. 00013 // 00015 #include <ctime> 00016 00017 #include "errorcounter.h" 00018 00019 #include "fontinfo.h" 00020 #include "ndminx.h" 00021 #include "sampleiterator.h" 00022 #include "shapeclassifier.h" 00023 #include "shapetable.h" 00024 #include "trainingsample.h" 00025 #include "trainingsampleset.h" 00026 #include "unicity_table.h" 00027 00028 namespace tesseract { 00029 00030 // Difference in result rating to be thought of as an "equal" choice. 00031 const double kRatingEpsilon = 1.0 / 32; 00032 00033 // Tests a classifier, computing its error rate. 00034 // See errorcounter.h for description of arguments. 00035 // Iterates over the samples, calling the classifier in normal/silent mode. 00036 // If the classifier makes a CT_UNICHAR_TOPN_ERR error, and the appropriate 00037 // report_level is set (4 or greater), it will then call the classifier again 00038 // with a debug flag and a keep_this argument to find out what is going on. 
00039 double ErrorCounter::ComputeErrorRate(ShapeClassifier* classifier, 00040 int report_level, CountTypes boosting_mode, 00041 const FontInfoTable& fontinfo_table, 00042 const GenericVector<Pix*>& page_images, SampleIterator* it, 00043 double* unichar_error, double* scaled_error, STRING* fonts_report) { 00044 int fontsize = it->sample_set()->NumFonts(); 00045 ErrorCounter counter(classifier->GetUnicharset(), fontsize); 00046 GenericVector<UnicharRating> results; 00047 00048 clock_t start = clock(); 00049 int total_samples = 0; 00050 double unscaled_error = 0.0; 00051 // Set a number of samples on which to run the classify debug mode. 00052 int error_samples = report_level > 3 ? report_level * report_level : 0; 00053 // Iterate over all the samples, accumulating errors. 00054 for (it->Begin(); !it->AtEnd(); it->Next()) { 00055 TrainingSample* mutable_sample = it->MutableSample(); 00056 int page_index = mutable_sample->page_num(); 00057 Pix* page_pix = 0 <= page_index && page_index < page_images.size() 00058 ? page_images[page_index] : NULL; 00059 // No debug, no keep this. 00060 classifier->UnicharClassifySample(*mutable_sample, page_pix, 0, 00061 INVALID_UNICHAR_ID, &results); 00062 bool debug_it = false; 00063 int correct_id = mutable_sample->class_id(); 00064 if (counter.unicharset_.has_special_codes() && 00065 (correct_id == UNICHAR_SPACE || correct_id == UNICHAR_JOINED || 00066 correct_id == UNICHAR_BROKEN)) { 00067 // This is junk so use the special counter. 00068 debug_it = counter.AccumulateJunk(report_level > 3, 00069 results, 00070 mutable_sample); 00071 } else { 00072 debug_it = counter.AccumulateErrors(report_level > 3, boosting_mode, 00073 fontinfo_table, 00074 results, mutable_sample); 00075 } 00076 if (debug_it && error_samples > 0) { 00077 // Running debug, keep the correct answer, and debug the classifier. 
00078 tprintf("Error on sample %d: %s Classifier debug output:\n", 00079 it->GlobalSampleIndex(), 00080 it->sample_set()->SampleToString(*mutable_sample).string()); 00081 classifier->DebugDisplay(*mutable_sample, page_pix, correct_id); 00082 --error_samples; 00083 } 00084 ++total_samples; 00085 } 00086 double total_time = 1.0 * (clock() - start) / CLOCKS_PER_SEC; 00087 // Create the appropriate error report. 00088 unscaled_error = counter.ReportErrors(report_level, boosting_mode, 00089 fontinfo_table, 00090 *it, unichar_error, fonts_report); 00091 if (scaled_error != NULL) *scaled_error = counter.scaled_error_; 00092 if (report_level > 1) { 00093 // It is useful to know the time in microseconds/char. 00094 tprintf("Errors computed in %.2fs at %.1f μs/char\n", 00095 total_time, 1000000.0 * total_time / total_samples); 00096 } 00097 return unscaled_error; 00098 } 00099 00100 // Tests a pair of classifiers, debugging errors of the new against the old. 00101 // See errorcounter.h for description of arguments. 00102 // Iterates over the samples, calling the classifiers in normal/silent mode. 00103 // If the new_classifier makes a boosting_mode error that the old_classifier 00104 // does not, it will then call the new_classifier again with a debug flag 00105 // and a keep_this argument to find out what is going on. 00106 void ErrorCounter::DebugNewErrors( 00107 ShapeClassifier* new_classifier, ShapeClassifier* old_classifier, 00108 CountTypes boosting_mode, 00109 const FontInfoTable& fontinfo_table, 00110 const GenericVector<Pix*>& page_images, SampleIterator* it) { 00111 int fontsize = it->sample_set()->NumFonts(); 00112 ErrorCounter old_counter(old_classifier->GetUnicharset(), fontsize); 00113 ErrorCounter new_counter(new_classifier->GetUnicharset(), fontsize); 00114 GenericVector<UnicharRating> results; 00115 00116 int total_samples = 0; 00117 int error_samples = 25; 00118 int total_new_errors = 0; 00119 // Iterate over all the samples, accumulating errors. 
00120 for (it->Begin(); !it->AtEnd(); it->Next()) { 00121 TrainingSample* mutable_sample = it->MutableSample(); 00122 int page_index = mutable_sample->page_num(); 00123 Pix* page_pix = 0 <= page_index && page_index < page_images.size() 00124 ? page_images[page_index] : NULL; 00125 // No debug, no keep this. 00126 old_classifier->UnicharClassifySample(*mutable_sample, page_pix, 0, 00127 INVALID_UNICHAR_ID, &results); 00128 int correct_id = mutable_sample->class_id(); 00129 if (correct_id != 0 && 00130 !old_counter.AccumulateErrors(true, boosting_mode, fontinfo_table, 00131 results, mutable_sample)) { 00132 // old classifier was correct, check the new one. 00133 new_classifier->UnicharClassifySample(*mutable_sample, page_pix, 0, 00134 INVALID_UNICHAR_ID, &results); 00135 if (correct_id != 0 && 00136 new_counter.AccumulateErrors(true, boosting_mode, fontinfo_table, 00137 results, mutable_sample)) { 00138 tprintf("New Error on sample %d: Classifier debug output:\n", 00139 it->GlobalSampleIndex()); 00140 ++total_new_errors; 00141 new_classifier->UnicharClassifySample(*mutable_sample, page_pix, 1, 00142 correct_id, &results); 00143 if (results.size() > 0 && error_samples > 0) { 00144 new_classifier->DebugDisplay(*mutable_sample, page_pix, correct_id); 00145 --error_samples; 00146 } 00147 } 00148 } 00149 ++total_samples; 00150 } 00151 tprintf("Total new errors = %d\n", total_new_errors); 00152 } 00153 00154 // Constructor is private. Only anticipated use of ErrorCounter is via 00155 // the static ComputeErrorRate. 
00156 ErrorCounter::ErrorCounter(const UNICHARSET& unicharset, int fontsize) 00157 : scaled_error_(0.0), rating_epsilon_(kRatingEpsilon), 00158 unichar_counts_(unicharset.size(), unicharset.size(), 0), 00159 ok_score_hist_(0, 101), bad_score_hist_(0, 101), 00160 unicharset_(unicharset) { 00161 Counts empty_counts; 00162 font_counts_.init_to_size(fontsize, empty_counts); 00163 multi_unichar_counts_.init_to_size(unicharset.size(), 0); 00164 } 00165 ErrorCounter::~ErrorCounter() { 00166 } 00167 00168 // Accumulates the errors from the classifier results on a single sample. 00169 // Returns true if debug is true and a CT_UNICHAR_TOPN_ERR error occurred. 00170 // boosting_mode selects the type of error to be used for boosting and the 00171 // is_error_ member of sample is set according to whether the required type 00172 // of error occurred. The font_table provides access to font properties 00173 // for error counting and shape_table is used to understand the relationship 00174 // between unichar_ids and shape_ids in the results 00175 bool ErrorCounter::AccumulateErrors(bool debug, CountTypes boosting_mode, 00176 const FontInfoTable& font_table, 00177 const GenericVector<UnicharRating>& results, 00178 TrainingSample* sample) { 00179 int num_results = results.size(); 00180 int answer_actual_rank = -1; 00181 int font_id = sample->font_id(); 00182 int unichar_id = sample->class_id(); 00183 sample->set_is_error(false); 00184 if (num_results == 0) { 00185 // Reject. We count rejects as a separate category, but still mark the 00186 // sample as an error in case any training module wants to use that to 00187 // improve the classifier. 00188 sample->set_is_error(true); 00189 ++font_counts_[font_id].n[CT_REJECT]; 00190 } else { 00191 // Find rank of correct unichar answer, using rating_epsilon_ to allow 00192 // different answers to score as equal. (Ignoring the font.) 
00193 int epsilon_rank = 0; 00194 int answer_epsilon_rank = -1; 00195 int num_top_answers = 0; 00196 double prev_rating = results[0].rating; 00197 bool joined = false; 00198 bool broken = false; 00199 int res_index = 0; 00200 while (res_index < num_results) { 00201 if (results[res_index].rating < prev_rating - rating_epsilon_) { 00202 ++epsilon_rank; 00203 prev_rating = results[res_index].rating; 00204 } 00205 if (results[res_index].unichar_id == unichar_id && 00206 answer_epsilon_rank < 0) { 00207 answer_epsilon_rank = epsilon_rank; 00208 answer_actual_rank = res_index; 00209 } 00210 if (results[res_index].unichar_id == UNICHAR_JOINED && 00211 unicharset_.has_special_codes()) 00212 joined = true; 00213 else if (results[res_index].unichar_id == UNICHAR_BROKEN && 00214 unicharset_.has_special_codes()) 00215 broken = true; 00216 else if (epsilon_rank == 0) 00217 ++num_top_answers; 00218 ++res_index; 00219 } 00220 if (answer_actual_rank != 0) { 00221 // Correct result is not absolute top. 00222 ++font_counts_[font_id].n[CT_UNICHAR_TOPTOP_ERR]; 00223 if (boosting_mode == CT_UNICHAR_TOPTOP_ERR) sample->set_is_error(true); 00224 } 00225 if (answer_epsilon_rank == 0) { 00226 ++font_counts_[font_id].n[CT_UNICHAR_TOP_OK]; 00227 // Unichar OK, but count if multiple unichars. 00228 if (num_top_answers > 1) { 00229 ++font_counts_[font_id].n[CT_OK_MULTI_UNICHAR]; 00230 ++multi_unichar_counts_[unichar_id]; 00231 } 00232 // Check to see if any font in the top choice has attributes that match. 00233 // TODO(rays) It is easy to add counters for individual font attributes 00234 // here if we want them. 00235 if (font_table.SetContainsFontProperties( 00236 font_id, results[answer_actual_rank].fonts)) { 00237 // Font attributes were matched. 00238 // Check for multiple properties. 00239 if (font_table.SetContainsMultipleFontProperties( 00240 results[answer_actual_rank].fonts)) 00241 ++font_counts_[font_id].n[CT_OK_MULTI_FONT]; 00242 } else { 00243 // Font attributes weren't matched. 
00244 ++font_counts_[font_id].n[CT_FONT_ATTR_ERR]; 00245 } 00246 } else { 00247 // This is a top unichar error. 00248 ++font_counts_[font_id].n[CT_UNICHAR_TOP1_ERR]; 00249 if (boosting_mode == CT_UNICHAR_TOP1_ERR) sample->set_is_error(true); 00250 // Count maps from unichar id to wrong unichar id. 00251 ++unichar_counts_(unichar_id, results[0].unichar_id); 00252 if (answer_epsilon_rank < 0 || answer_epsilon_rank >= 2) { 00253 // It is also a 2nd choice unichar error. 00254 ++font_counts_[font_id].n[CT_UNICHAR_TOP2_ERR]; 00255 if (boosting_mode == CT_UNICHAR_TOP2_ERR) sample->set_is_error(true); 00256 } 00257 if (answer_epsilon_rank < 0) { 00258 // It is also a top-n choice unichar error. 00259 ++font_counts_[font_id].n[CT_UNICHAR_TOPN_ERR]; 00260 if (boosting_mode == CT_UNICHAR_TOPN_ERR) sample->set_is_error(true); 00261 answer_epsilon_rank = epsilon_rank; 00262 } 00263 } 00264 // Compute mean number of return values and mean rank of correct answer. 00265 font_counts_[font_id].n[CT_NUM_RESULTS] += num_results; 00266 font_counts_[font_id].n[CT_RANK] += answer_epsilon_rank; 00267 if (joined) 00268 ++font_counts_[font_id].n[CT_OK_JOINED]; 00269 if (broken) 00270 ++font_counts_[font_id].n[CT_OK_BROKEN]; 00271 } 00272 // If it was an error for boosting then sum the weight. 
00273 if (sample->is_error()) { 00274 scaled_error_ += sample->weight(); 00275 if (debug) { 00276 tprintf("%d results for char %s font %d :", 00277 num_results, unicharset_.id_to_unichar(unichar_id), 00278 font_id); 00279 for (int i = 0; i < num_results; ++i) { 00280 tprintf(" %.3f : %s\n", 00281 results[i].rating, 00282 unicharset_.id_to_unichar(results[i].unichar_id)); 00283 } 00284 return true; 00285 } 00286 int percent = 0; 00287 if (num_results > 0) 00288 percent = IntCastRounded(results[0].rating * 100); 00289 bad_score_hist_.add(percent, 1); 00290 } else { 00291 int percent = 0; 00292 if (answer_actual_rank >= 0) 00293 percent = IntCastRounded(results[answer_actual_rank].rating * 100); 00294 ok_score_hist_.add(percent, 1); 00295 } 00296 return false; 00297 } 00298 00299 // Accumulates counts for junk. Counts only whether the junk was correctly 00300 // rejected or not. 00301 bool ErrorCounter::AccumulateJunk(bool debug, 00302 const GenericVector<UnicharRating>& results, 00303 TrainingSample* sample) { 00304 // For junk we accept no answer, or an explicit shape answer matching the 00305 // class id of the sample. 00306 int num_results = results.size(); 00307 int font_id = sample->font_id(); 00308 int unichar_id = sample->class_id(); 00309 int percent = 0; 00310 if (num_results > 0) 00311 percent = IntCastRounded(results[0].rating * 100); 00312 if (num_results > 0 && results[0].unichar_id != unichar_id) { 00313 // This is a junk error. 00314 ++font_counts_[font_id].n[CT_ACCEPTED_JUNK]; 00315 sample->set_is_error(true); 00316 // It counts as an error for boosting too so sum the weight. 00317 scaled_error_ += sample->weight(); 00318 bad_score_hist_.add(percent, 1); 00319 return debug; 00320 } else { 00321 // Correctly rejected. 00322 ++font_counts_[font_id].n[CT_REJECTED_JUNK]; 00323 sample->set_is_error(false); 00324 ok_score_hist_.add(percent, 1); 00325 } 00326 return false; 00327 } 00328 00329 // Creates a report of the error rate. 
The report_level controls the detail 00330 // that is reported to stderr via tprintf: 00331 // 0 -> no output. 00332 // >=1 -> bottom-line error rate. 00333 // >=3 -> font-level error rate. 00334 // boosting_mode determines the return value. It selects which (un-weighted) 00335 // error rate to return. 00336 // The fontinfo_table from MasterTrainer provides the names of fonts. 00337 // The it determines the current subset of the training samples. 00338 // If not NULL, the top-choice unichar error rate is saved in unichar_error. 00339 // If not NULL, the report string is saved in fonts_report. 00340 // (Ignoring report_level). 00341 double ErrorCounter::ReportErrors(int report_level, CountTypes boosting_mode, 00342 const FontInfoTable& fontinfo_table, 00343 const SampleIterator& it, 00344 double* unichar_error, 00345 STRING* fonts_report) { 00346 // Compute totals over all the fonts and report individual font results 00347 // when required. 00348 Counts totals; 00349 int fontsize = font_counts_.size(); 00350 for (int f = 0; f < fontsize; ++f) { 00351 // Accumulate counts over fonts. 00352 totals += font_counts_[f]; 00353 STRING font_report; 00354 if (ReportString(false, font_counts_[f], &font_report)) { 00355 if (fonts_report != NULL) { 00356 *fonts_report += fontinfo_table.get(f).name; 00357 *fonts_report += ": "; 00358 *fonts_report += font_report; 00359 *fonts_report += "\n"; 00360 } 00361 if (report_level > 2) { 00362 // Report individual font error rates. 00363 tprintf("%s: %s\n", fontinfo_table.get(f).name, font_report.string()); 00364 } 00365 } 00366 } 00367 // Report the totals. 00368 STRING total_report; 00369 bool any_results = ReportString(true, totals, &total_report); 00370 if (fonts_report != NULL && fonts_report->length() == 0) { 00371 // Make sure we return something even if there were no samples. 
00372 *fonts_report = "NoSamplesFound: "; 00373 *fonts_report += total_report; 00374 *fonts_report += "\n"; 00375 } 00376 if (report_level > 0) { 00377 // Report the totals. 00378 STRING total_report; 00379 if (any_results) { 00380 tprintf("TOTAL Scaled Err=%.4g%%, %s\n", 00381 scaled_error_ * 100.0, total_report.string()); 00382 } 00383 // Report the worst substitution error only for now. 00384 if (totals.n[CT_UNICHAR_TOP1_ERR] > 0) { 00385 int charsetsize = unicharset_.size(); 00386 int worst_uni_id = 0; 00387 int worst_result_id = 0; 00388 int worst_err = 0; 00389 for (int u = 0; u < charsetsize; ++u) { 00390 for (int v = 0; v < charsetsize; ++v) { 00391 if (unichar_counts_(u, v) > worst_err) { 00392 worst_err = unichar_counts_(u, v); 00393 worst_uni_id = u; 00394 worst_result_id = v; 00395 } 00396 } 00397 } 00398 if (worst_err > 0) { 00399 tprintf("Worst error = %d:%s -> %s with %d/%d=%.2f%% errors\n", 00400 worst_uni_id, unicharset_.id_to_unichar(worst_uni_id), 00401 unicharset_.id_to_unichar(worst_result_id), 00402 worst_err, totals.n[CT_UNICHAR_TOP1_ERR], 00403 100.0 * worst_err / totals.n[CT_UNICHAR_TOP1_ERR]); 00404 } 00405 } 00406 tprintf("Multi-unichar shape use:\n"); 00407 for (int u = 0; u < multi_unichar_counts_.size(); ++u) { 00408 if (multi_unichar_counts_[u] > 0) { 00409 tprintf("%d multiple answers for unichar: %s\n", 00410 multi_unichar_counts_[u], 00411 unicharset_.id_to_unichar(u)); 00412 } 00413 } 00414 tprintf("OK Score histogram:\n"); 00415 ok_score_hist_.print(); 00416 tprintf("ERROR Score histogram:\n"); 00417 bad_score_hist_.print(); 00418 } 00419 00420 double rates[CT_SIZE]; 00421 if (!ComputeRates(totals, rates)) 00422 return 0.0; 00423 // Set output values if asked for. 00424 if (unichar_error != NULL) 00425 *unichar_error = rates[CT_UNICHAR_TOP1_ERR]; 00426 return rates[boosting_mode]; 00427 } 00428 00429 // Sets the report string to a combined human and machine-readable report 00430 // string of the error rates. 
00431 // Returns false if there is no data, leaving report unchanged, unless 00432 // even_if_empty is true. 00433 bool ErrorCounter::ReportString(bool even_if_empty, const Counts& counts, 00434 STRING* report) { 00435 // Compute the error rates. 00436 double rates[CT_SIZE]; 00437 if (!ComputeRates(counts, rates) && !even_if_empty) 00438 return false; 00439 // Using %.4g%%, the length of the output string should exactly match the 00440 // length of the format string, but in case of overflow, allow for +eddd 00441 // on each number. 00442 const int kMaxExtraLength = 5; // Length of +eddd. 00443 // Keep this format string and the snprintf in sync with the CountTypes enum. 00444 const char* format_str = "Unichar=%.4g%%[1], %.4g%%[2], %.4g%%[n], %.4g%%[T] " 00445 "Mult=%.4g%%, Jn=%.4g%%, Brk=%.4g%%, Rej=%.4g%%, " 00446 "FontAttr=%.4g%%, Multi=%.4g%%, " 00447 "Answers=%.3g, Rank=%.3g, " 00448 "OKjunk=%.4g%%, Badjunk=%.4g%%"; 00449 int max_str_len = strlen(format_str) + kMaxExtraLength * (CT_SIZE - 1) + 1; 00450 char* formatted_str = new char[max_str_len]; 00451 snprintf(formatted_str, max_str_len, format_str, 00452 rates[CT_UNICHAR_TOP1_ERR] * 100.0, 00453 rates[CT_UNICHAR_TOP2_ERR] * 100.0, 00454 rates[CT_UNICHAR_TOPN_ERR] * 100.0, 00455 rates[CT_UNICHAR_TOPTOP_ERR] * 100.0, 00456 rates[CT_OK_MULTI_UNICHAR] * 100.0, 00457 rates[CT_OK_JOINED] * 100.0, 00458 rates[CT_OK_BROKEN] * 100.0, 00459 rates[CT_REJECT] * 100.0, 00460 rates[CT_FONT_ATTR_ERR] * 100.0, 00461 rates[CT_OK_MULTI_FONT] * 100.0, 00462 rates[CT_NUM_RESULTS], 00463 rates[CT_RANK], 00464 100.0 * rates[CT_REJECTED_JUNK], 00465 100.0 * rates[CT_ACCEPTED_JUNK]); 00466 *report = formatted_str; 00467 delete [] formatted_str; 00468 // Now append each field of counts with a tab in front so the result can 00469 // be loaded into a spreadsheet. 
00470 for (int ct = 0; ct < CT_SIZE; ++ct) 00471 report->add_str_int("\t", counts.n[ct]); 00472 return true; 00473 } 00474 00475 // Computes the error rates and returns in rates which is an array of size 00476 // CT_SIZE. Returns false if there is no data, leaving rates unchanged. 00477 bool ErrorCounter::ComputeRates(const Counts& counts, double rates[CT_SIZE]) { 00478 int ok_samples = counts.n[CT_UNICHAR_TOP_OK] + counts.n[CT_UNICHAR_TOP1_ERR] + 00479 counts.n[CT_REJECT]; 00480 int junk_samples = counts.n[CT_REJECTED_JUNK] + counts.n[CT_ACCEPTED_JUNK]; 00481 // Compute rates for normal chars. 00482 double denominator = static_cast<double>(MAX(ok_samples, 1)); 00483 for (int ct = 0; ct <= CT_RANK; ++ct) 00484 rates[ct] = counts.n[ct] / denominator; 00485 // Compute rates for junk. 00486 denominator = static_cast<double>(MAX(junk_samples, 1)); 00487 for (int ct = CT_REJECTED_JUNK; ct <= CT_ACCEPTED_JUNK; ++ct) 00488 rates[ct] = counts.n[ct] / denominator; 00489 return ok_samples != 0 || junk_samples != 0; 00490 } 00491 00492 ErrorCounter::Counts::Counts() { 00493 memset(n, 0, sizeof(n[0]) * CT_SIZE); 00494 } 00495 // Adds other into this for computing totals. 00496 void ErrorCounter::Counts::operator+=(const Counts& other) { 00497 for (int ct = 0; ct < CT_SIZE; ++ct) 00498 n[ct] += other.n[ct]; 00499 } 00500 00501 00502 } // namespace tesseract. 00503 00504 00505 00506 00507