tesseract  3.03
/usr/local/google/home/jbreiden/tesseract-ocr-read-only/classify/errorcounter.cpp
Go to the documentation of this file.
00001 // Copyright 2011 Google Inc. All Rights Reserved.
00002 // Author: rays@google.com (Ray Smith)
00003 //
00004 // Licensed under the Apache License, Version 2.0 (the "License");
00005 // you may not use this file except in compliance with the License.
00006 // You may obtain a copy of the License at
00007 // http://www.apache.org/licenses/LICENSE-2.0
00008 // Unless required by applicable law or agreed to in writing, software
00009 // distributed under the License is distributed on an "AS IS" BASIS,
00010 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00011 // See the License for the specific language governing permissions and
00012 // limitations under the License.
00013 //
00015 #include <ctime>
00016 
00017 #include "errorcounter.h"
00018 
00019 #include "fontinfo.h"
00020 #include "ndminx.h"
00021 #include "sampleiterator.h"
00022 #include "shapeclassifier.h"
00023 #include "shapetable.h"
00024 #include "trainingsample.h"
00025 #include "trainingsampleset.h"
00026 #include "unicity_table.h"
00027 
00028 namespace tesseract {
00029 
00030 // Difference in result rating to be thought of as an "equal" choice.
00031 const double kRatingEpsilon = 1.0 / 32;
00032 
00033 // Tests a classifier, computing its error rate.
00034 // See errorcounter.h for description of arguments.
00035 // Iterates over the samples, calling the classifier in normal/silent mode.
00036 // If the classifier makes a CT_UNICHAR_TOPN_ERR error, and the appropriate
00037 // report_level is set (4 or greater), it will then call the classifier again
00038 // with a debug flag and a keep_this argument to find out what is going on.
00039 double ErrorCounter::ComputeErrorRate(ShapeClassifier* classifier,
00040     int report_level, CountTypes boosting_mode,
00041     const FontInfoTable& fontinfo_table,
00042     const GenericVector<Pix*>& page_images, SampleIterator* it,
00043     double* unichar_error,  double* scaled_error, STRING* fonts_report) {
00044   int fontsize = it->sample_set()->NumFonts();
00045   ErrorCounter counter(classifier->GetUnicharset(), fontsize);
00046   GenericVector<UnicharRating> results;
00047 
00048   clock_t start = clock();
00049   int total_samples = 0;
00050   double unscaled_error = 0.0;
00051   // Set a number of samples on which to run the classify debug mode.
00052   int error_samples = report_level > 3 ? report_level * report_level : 0;
00053   // Iterate over all the samples, accumulating errors.
00054   for (it->Begin(); !it->AtEnd(); it->Next()) {
00055     TrainingSample* mutable_sample = it->MutableSample();
00056     int page_index = mutable_sample->page_num();
00057     Pix* page_pix = 0 <= page_index && page_index < page_images.size()
00058                   ? page_images[page_index] : NULL;
00059     // No debug, no keep this.
00060     classifier->UnicharClassifySample(*mutable_sample, page_pix, 0,
00061                                       INVALID_UNICHAR_ID, &results);
00062     bool debug_it = false;
00063     int correct_id = mutable_sample->class_id();
00064     if (counter.unicharset_.has_special_codes() &&
00065         (correct_id == UNICHAR_SPACE || correct_id == UNICHAR_JOINED ||
00066          correct_id == UNICHAR_BROKEN)) {
00067       // This is junk so use the special counter.
00068       debug_it = counter.AccumulateJunk(report_level > 3,
00069                                         results,
00070                                         mutable_sample);
00071     } else {
00072       debug_it = counter.AccumulateErrors(report_level > 3, boosting_mode,
00073                                           fontinfo_table,
00074                                           results, mutable_sample);
00075     }
00076     if (debug_it && error_samples > 0) {
00077       // Running debug, keep the correct answer, and debug the classifier.
00078       tprintf("Error on sample %d: %s Classifier debug output:\n",
00079               it->GlobalSampleIndex(),
00080               it->sample_set()->SampleToString(*mutable_sample).string());
00081       classifier->DebugDisplay(*mutable_sample, page_pix, correct_id);
00082       --error_samples;
00083     }
00084     ++total_samples;
00085   }
00086   double total_time = 1.0 * (clock() - start) / CLOCKS_PER_SEC;
00087   // Create the appropriate error report.
00088   unscaled_error = counter.ReportErrors(report_level, boosting_mode,
00089                                         fontinfo_table,
00090                                         *it, unichar_error, fonts_report);
00091   if (scaled_error != NULL) *scaled_error = counter.scaled_error_;
00092   if (report_level > 1) {
00093     // It is useful to know the time in microseconds/char.
00094     tprintf("Errors computed in %.2fs at %.1f μs/char\n",
00095             total_time, 1000000.0 * total_time / total_samples);
00096   }
00097   return unscaled_error;
00098 }
00099 
00100 // Tests a pair of classifiers, debugging errors of the new against the old.
00101 // See errorcounter.h for description of arguments.
00102 // Iterates over the samples, calling the classifiers in normal/silent mode.
00103 // If the new_classifier makes a boosting_mode error that the old_classifier
00104 // does not, it will then call the new_classifier again with a debug flag
00105 // and a keep_this argument to find out what is going on.
00106 void ErrorCounter::DebugNewErrors(
00107     ShapeClassifier* new_classifier, ShapeClassifier* old_classifier,
00108     CountTypes boosting_mode,
00109     const FontInfoTable& fontinfo_table,
00110     const GenericVector<Pix*>& page_images, SampleIterator* it) {
00111   int fontsize = it->sample_set()->NumFonts();
00112   ErrorCounter old_counter(old_classifier->GetUnicharset(), fontsize);
00113   ErrorCounter new_counter(new_classifier->GetUnicharset(), fontsize);
00114   GenericVector<UnicharRating> results;
00115 
00116   int total_samples = 0;
00117   int error_samples = 25;
00118   int total_new_errors = 0;
00119   // Iterate over all the samples, accumulating errors.
00120   for (it->Begin(); !it->AtEnd(); it->Next()) {
00121     TrainingSample* mutable_sample = it->MutableSample();
00122     int page_index = mutable_sample->page_num();
00123     Pix* page_pix = 0 <= page_index && page_index < page_images.size()
00124                   ? page_images[page_index] : NULL;
00125     // No debug, no keep this.
00126     old_classifier->UnicharClassifySample(*mutable_sample, page_pix, 0,
00127                                           INVALID_UNICHAR_ID, &results);
00128     int correct_id = mutable_sample->class_id();
00129     if (correct_id != 0 &&
00130         !old_counter.AccumulateErrors(true, boosting_mode, fontinfo_table,
00131                                       results, mutable_sample)) {
00132       // old classifier was correct, check the new one.
00133       new_classifier->UnicharClassifySample(*mutable_sample, page_pix, 0,
00134                                             INVALID_UNICHAR_ID, &results);
00135       if (correct_id != 0 &&
00136           new_counter.AccumulateErrors(true, boosting_mode, fontinfo_table,
00137                                         results, mutable_sample)) {
00138         tprintf("New Error on sample %d: Classifier debug output:\n",
00139                 it->GlobalSampleIndex());
00140         ++total_new_errors;
00141         new_classifier->UnicharClassifySample(*mutable_sample, page_pix, 1,
00142                                               correct_id, &results);
00143         if (results.size() > 0 && error_samples > 0) {
00144           new_classifier->DebugDisplay(*mutable_sample, page_pix, correct_id);
00145           --error_samples;
00146         }
00147       }
00148     }
00149     ++total_samples;
00150   }
00151   tprintf("Total new errors = %d\n", total_new_errors);
00152 }
00153 
// Constructor is private. Only anticipated use of ErrorCounter is via
// the static ComputeErrorRate.
ErrorCounter::ErrorCounter(const UNICHARSET& unicharset, int fontsize)
  : scaled_error_(0.0), rating_epsilon_(kRatingEpsilon),
    // Square matrix of substitution counts indexed by
    // (correct unichar id, reported unichar id).
    unichar_counts_(unicharset.size(), unicharset.size(), 0),
    // Score histograms bucket the 0-100 integer percent ratings.
    ok_score_hist_(0, 101), bad_score_hist_(0, 101),
    unicharset_(unicharset) {
  // One zeroed Counts struct per font (Counts() zero-initializes itself).
  Counts empty_counts;
  font_counts_.init_to_size(fontsize, empty_counts);
  // Per-unichar tally of correct-but-ambiguous (multi-answer) results,
  // incremented in AccumulateErrors under CT_OK_MULTI_UNICHAR.
  multi_unichar_counts_.init_to_size(unicharset.size(), 0);
}
// Empty destructor: all members release their own resources.
ErrorCounter::~ErrorCounter() {
}
00167 
// Accumulates the errors from the classifier results on a single sample.
// Returns true if debug is true and a CT_UNICHAR_TOPN_ERR error occurred.
// boosting_mode selects the type of error to be used for boosting and the
// is_error_ member of sample is set according to whether the required type
// of error occurred. The font_table provides access to font properties
// for error counting and shape_table is used to understand the relationship
// between unichar_ids and shape_ids in the results
bool ErrorCounter::AccumulateErrors(bool debug, CountTypes boosting_mode,
                                    const FontInfoTable& font_table,
                                    const GenericVector<UnicharRating>& results,
                                    TrainingSample* sample) {
  int num_results = results.size();
  int answer_actual_rank = -1;  // Index in results of the correct unichar.
  int font_id = sample->font_id();
  int unichar_id = sample->class_id();
  sample->set_is_error(false);
  if (num_results == 0) {
    // Reject. We count rejects as a separate category, but still mark the
    // sample as an error in case any training module wants to use that to
    // improve the classifier.
    sample->set_is_error(true);
    ++font_counts_[font_id].n[CT_REJECT];
  } else {
    // Find rank of correct unichar answer, using rating_epsilon_ to allow
    // different answers to score as equal. (Ignoring the font.)
    int epsilon_rank = 0;
    int answer_epsilon_rank = -1;
    int num_top_answers = 0;
    double prev_rating = results[0].rating;
    bool joined = false;
    bool broken = false;
    int res_index = 0;
    while (res_index < num_results) {
      // A rating drop of more than rating_epsilon_ starts a new epsilon-rank
      // group; anything closer ties with the previous result.
      if (results[res_index].rating < prev_rating - rating_epsilon_) {
        ++epsilon_rank;
        prev_rating = results[res_index].rating;
      }
      // Record the first occurrence of the correct answer: both its tied
      // (epsilon) rank and its actual index in results.
      if (results[res_index].unichar_id == unichar_id &&
          answer_epsilon_rank < 0) {
        answer_epsilon_rank = epsilon_rank;
        answer_actual_rank = res_index;
      }
      if (results[res_index].unichar_id == UNICHAR_JOINED &&
          unicharset_.has_special_codes())
        joined = true;
      else if (results[res_index].unichar_id == UNICHAR_BROKEN &&
               unicharset_.has_special_codes())
        broken = true;
      else if (epsilon_rank == 0)
        ++num_top_answers;  // Distinct non-special answers tied for top rank.
      ++res_index;
    }
    if (answer_actual_rank != 0) {
      // Correct result is not absolute top.
      ++font_counts_[font_id].n[CT_UNICHAR_TOPTOP_ERR];
      if (boosting_mode == CT_UNICHAR_TOPTOP_ERR) sample->set_is_error(true);
    }
    if (answer_epsilon_rank == 0) {
      ++font_counts_[font_id].n[CT_UNICHAR_TOP_OK];
      // Unichar OK, but count if multiple unichars.
      if (num_top_answers > 1) {
        ++font_counts_[font_id].n[CT_OK_MULTI_UNICHAR];
        ++multi_unichar_counts_[unichar_id];
      }
      // Check to see if any font in the top choice has attributes that match.
      // TODO(rays) It is easy to add counters for individual font attributes
      // here if we want them.
      if (font_table.SetContainsFontProperties(
          font_id, results[answer_actual_rank].fonts)) {
        // Font attributes were matched.
        // Check for multiple properties.
        if (font_table.SetContainsMultipleFontProperties(
            results[answer_actual_rank].fonts))
          ++font_counts_[font_id].n[CT_OK_MULTI_FONT];
      } else {
        // Font attributes weren't matched.
        ++font_counts_[font_id].n[CT_FONT_ATTR_ERR];
      }
    } else {
      // This is a top unichar error.
      ++font_counts_[font_id].n[CT_UNICHAR_TOP1_ERR];
      if (boosting_mode == CT_UNICHAR_TOP1_ERR) sample->set_is_error(true);
      // Count maps from unichar id to wrong unichar id.
      ++unichar_counts_(unichar_id, results[0].unichar_id);
      if (answer_epsilon_rank < 0 || answer_epsilon_rank >= 2) {
        // It is also a 2nd choice unichar error.
        ++font_counts_[font_id].n[CT_UNICHAR_TOP2_ERR];
        if (boosting_mode == CT_UNICHAR_TOP2_ERR) sample->set_is_error(true);
      }
      if (answer_epsilon_rank < 0) {
        // It is also a top-n choice unichar error.
        ++font_counts_[font_id].n[CT_UNICHAR_TOPN_ERR];
        if (boosting_mode == CT_UNICHAR_TOPN_ERR) sample->set_is_error(true);
        // Correct answer never appeared: use the worst epsilon-rank seen so
        // the mean-rank statistic below still gets a value.
        answer_epsilon_rank = epsilon_rank;
      }
    }
    // Compute mean number of return values and mean rank of correct answer.
    font_counts_[font_id].n[CT_NUM_RESULTS] += num_results;
    font_counts_[font_id].n[CT_RANK] += answer_epsilon_rank;
    if (joined)
      ++font_counts_[font_id].n[CT_OK_JOINED];
    if (broken)
      ++font_counts_[font_id].n[CT_OK_BROKEN];
  }
  // If it was an error for boosting then sum the weight.
  if (sample->is_error()) {
    scaled_error_ += sample->weight();
    if (debug) {
      // Dump the full result list; returning true tells the caller to re-run
      // the classifier on this sample in debug mode.
      tprintf("%d results for char %s font %d :",
              num_results, unicharset_.id_to_unichar(unichar_id),
              font_id);
      for (int i = 0; i < num_results; ++i) {
        tprintf(" %.3f : %s\n",
                results[i].rating,
                unicharset_.id_to_unichar(results[i].unichar_id));
      }
      return true;
    }
    int percent = 0;
    if (num_results > 0)
      percent = IntCastRounded(results[0].rating * 100);
    bad_score_hist_.add(percent, 1);
  } else {
    int percent = 0;
    if (answer_actual_rank >= 0)
      percent = IntCastRounded(results[answer_actual_rank].rating * 100);
    ok_score_hist_.add(percent, 1);
  }
  return false;
}
00298 
00299 // Accumulates counts for junk. Counts only whether the junk was correctly
00300 // rejected or not.
00301 bool ErrorCounter::AccumulateJunk(bool debug,
00302                                   const GenericVector<UnicharRating>& results,
00303                                   TrainingSample* sample) {
00304   // For junk we accept no answer, or an explicit shape answer matching the
00305   // class id of the sample.
00306   int num_results = results.size();
00307   int font_id = sample->font_id();
00308   int unichar_id = sample->class_id();
00309   int percent = 0;
00310   if (num_results > 0)
00311     percent = IntCastRounded(results[0].rating * 100);
00312   if (num_results > 0 && results[0].unichar_id != unichar_id) {
00313     // This is a junk error.
00314     ++font_counts_[font_id].n[CT_ACCEPTED_JUNK];
00315     sample->set_is_error(true);
00316     // It counts as an error for boosting too so sum the weight.
00317     scaled_error_ += sample->weight();
00318     bad_score_hist_.add(percent, 1);
00319     return debug;
00320   } else {
00321     // Correctly rejected.
00322     ++font_counts_[font_id].n[CT_REJECTED_JUNK];
00323     sample->set_is_error(false);
00324     ok_score_hist_.add(percent, 1);
00325   }
00326   return false;
00327 }
00328 
00329 // Creates a report of the error rate. The report_level controls the detail
00330 // that is reported to stderr via tprintf:
00331 // 0   -> no output.
00332 // >=1 -> bottom-line error rate.
00333 // >=3 -> font-level error rate.
00334 // boosting_mode determines the return value. It selects which (un-weighted)
00335 // error rate to return.
00336 // The fontinfo_table from MasterTrainer provides the names of fonts.
00337 // The it determines the current subset of the training samples.
00338 // If not NULL, the top-choice unichar error rate is saved in unichar_error.
00339 // If not NULL, the report string is saved in fonts_report.
00340 // (Ignoring report_level).
00341 double ErrorCounter::ReportErrors(int report_level, CountTypes boosting_mode,
00342                                   const FontInfoTable& fontinfo_table,
00343                                   const SampleIterator& it,
00344                                   double* unichar_error,
00345                                   STRING* fonts_report) {
00346   // Compute totals over all the fonts and report individual font results
00347   // when required.
00348   Counts totals;
00349   int fontsize = font_counts_.size();
00350   for (int f = 0; f < fontsize; ++f) {
00351     // Accumulate counts over fonts.
00352     totals += font_counts_[f];
00353     STRING font_report;
00354     if (ReportString(false, font_counts_[f], &font_report)) {
00355       if (fonts_report != NULL) {
00356         *fonts_report += fontinfo_table.get(f).name;
00357         *fonts_report += ": ";
00358         *fonts_report += font_report;
00359         *fonts_report += "\n";
00360       }
00361       if (report_level > 2) {
00362         // Report individual font error rates.
00363         tprintf("%s: %s\n", fontinfo_table.get(f).name, font_report.string());
00364       }
00365     }
00366   }
00367   // Report the totals.
00368   STRING total_report;
00369   bool any_results = ReportString(true, totals, &total_report);
00370   if (fonts_report != NULL && fonts_report->length() == 0) {
00371     // Make sure we return something even if there were no samples.
00372     *fonts_report = "NoSamplesFound: ";
00373     *fonts_report += total_report;
00374     *fonts_report += "\n";
00375   }
00376   if (report_level > 0) {
00377     // Report the totals.
00378     STRING total_report;
00379     if (any_results) {
00380       tprintf("TOTAL Scaled Err=%.4g%%, %s\n",
00381               scaled_error_ * 100.0, total_report.string());
00382     }
00383     // Report the worst substitution error only for now.
00384     if (totals.n[CT_UNICHAR_TOP1_ERR] > 0) {
00385       int charsetsize = unicharset_.size();
00386       int worst_uni_id = 0;
00387       int worst_result_id = 0;
00388       int worst_err = 0;
00389       for (int u = 0; u < charsetsize; ++u) {
00390         for (int v = 0; v < charsetsize; ++v) {
00391           if (unichar_counts_(u, v) > worst_err) {
00392             worst_err = unichar_counts_(u, v);
00393             worst_uni_id = u;
00394             worst_result_id = v;
00395           }
00396         }
00397       }
00398       if (worst_err > 0) {
00399         tprintf("Worst error = %d:%s -> %s with %d/%d=%.2f%% errors\n",
00400                 worst_uni_id, unicharset_.id_to_unichar(worst_uni_id),
00401                 unicharset_.id_to_unichar(worst_result_id),
00402                 worst_err, totals.n[CT_UNICHAR_TOP1_ERR],
00403                 100.0 * worst_err / totals.n[CT_UNICHAR_TOP1_ERR]);
00404       }
00405     }
00406     tprintf("Multi-unichar shape use:\n");
00407     for (int u = 0; u < multi_unichar_counts_.size(); ++u) {
00408       if (multi_unichar_counts_[u] > 0) {
00409         tprintf("%d multiple answers for unichar: %s\n",
00410                 multi_unichar_counts_[u],
00411                 unicharset_.id_to_unichar(u));
00412       }
00413     }
00414     tprintf("OK Score histogram:\n");
00415     ok_score_hist_.print();
00416     tprintf("ERROR Score histogram:\n");
00417     bad_score_hist_.print();
00418   }
00419 
00420   double rates[CT_SIZE];
00421   if (!ComputeRates(totals, rates))
00422     return 0.0;
00423   // Set output values if asked for.
00424   if (unichar_error != NULL)
00425     *unichar_error = rates[CT_UNICHAR_TOP1_ERR];
00426   return rates[boosting_mode];
00427 }
00428 
00429 // Sets the report string to a combined human and machine-readable report
00430 // string of the error rates.
00431 // Returns false if there is no data, leaving report unchanged, unless
00432 // even_if_empty is true.
00433 bool ErrorCounter::ReportString(bool even_if_empty, const Counts& counts,
00434                                 STRING* report) {
00435   // Compute the error rates.
00436   double rates[CT_SIZE];
00437   if (!ComputeRates(counts, rates) && !even_if_empty)
00438     return false;
00439   // Using %.4g%%, the length of the output string should exactly match the
00440   // length of the format string, but in case of overflow, allow for +eddd
00441   // on each number.
00442   const int kMaxExtraLength = 5;  // Length of +eddd.
00443   // Keep this format string and the snprintf in sync with the CountTypes enum.
00444   const char* format_str = "Unichar=%.4g%%[1], %.4g%%[2], %.4g%%[n], %.4g%%[T] "
00445                            "Mult=%.4g%%, Jn=%.4g%%, Brk=%.4g%%, Rej=%.4g%%, "
00446                            "FontAttr=%.4g%%, Multi=%.4g%%, "
00447                            "Answers=%.3g, Rank=%.3g, "
00448                            "OKjunk=%.4g%%, Badjunk=%.4g%%";
00449   int max_str_len = strlen(format_str) + kMaxExtraLength * (CT_SIZE - 1) + 1;
00450   char* formatted_str = new char[max_str_len];
00451   snprintf(formatted_str, max_str_len, format_str,
00452            rates[CT_UNICHAR_TOP1_ERR] * 100.0,
00453            rates[CT_UNICHAR_TOP2_ERR] * 100.0,
00454            rates[CT_UNICHAR_TOPN_ERR] * 100.0,
00455            rates[CT_UNICHAR_TOPTOP_ERR] * 100.0,
00456            rates[CT_OK_MULTI_UNICHAR] * 100.0,
00457            rates[CT_OK_JOINED] * 100.0,
00458            rates[CT_OK_BROKEN] * 100.0,
00459            rates[CT_REJECT] * 100.0,
00460            rates[CT_FONT_ATTR_ERR] * 100.0,
00461            rates[CT_OK_MULTI_FONT] * 100.0,
00462            rates[CT_NUM_RESULTS],
00463            rates[CT_RANK],
00464            100.0 * rates[CT_REJECTED_JUNK],
00465            100.0 * rates[CT_ACCEPTED_JUNK]);
00466   *report = formatted_str;
00467   delete [] formatted_str;
00468   // Now append each field of counts with a tab in front so the result can
00469   // be loaded into a spreadsheet.
00470   for (int ct = 0; ct < CT_SIZE; ++ct)
00471     report->add_str_int("\t", counts.n[ct]);
00472   return true;
00473 }
00474 
00475 // Computes the error rates and returns in rates which is an array of size
00476 // CT_SIZE. Returns false if there is no data, leaving rates unchanged.
00477 bool ErrorCounter::ComputeRates(const Counts& counts, double rates[CT_SIZE]) {
00478   int ok_samples = counts.n[CT_UNICHAR_TOP_OK] + counts.n[CT_UNICHAR_TOP1_ERR] +
00479       counts.n[CT_REJECT];
00480   int junk_samples = counts.n[CT_REJECTED_JUNK] + counts.n[CT_ACCEPTED_JUNK];
00481   // Compute rates for normal chars.
00482   double denominator = static_cast<double>(MAX(ok_samples, 1));
00483   for (int ct = 0; ct <= CT_RANK; ++ct)
00484     rates[ct] = counts.n[ct] / denominator;
00485   // Compute rates for junk.
00486   denominator = static_cast<double>(MAX(junk_samples, 1));
00487   for (int ct = CT_REJECTED_JUNK; ct <= CT_ACCEPTED_JUNK; ++ct)
00488     rates[ct] = counts.n[ct] / denominator;
00489   return ok_samples != 0 || junk_samples != 0;
00490 }
00491 
00492 ErrorCounter::Counts::Counts() {
00493   memset(n, 0, sizeof(n[0]) * CT_SIZE);
00494 }
00495 // Adds other into this for computing totals.
00496 void ErrorCounter::Counts::operator+=(const Counts& other) {
00497   for (int ct = 0; ct < CT_SIZE; ++ct)
00498     n[ct] += other.n[ct];
00499 }
00500 
00501 
00502 }  // namespace tesseract.
00503 
00504 
00505 
00506 
00507 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines