tesseract  3.03
/usr/local/google/home/jbreiden/tesseract-ocr-read-only/classify/mastertrainer.h
Go to the documentation of this file.
00001 // Copyright 2010 Google Inc. All Rights Reserved.
00002 // Author: rays@google.com (Ray Smith)
00004 // File:        mastertrainer.h
00005 // Description: Trainer to build the MasterClassifier.
00006 // Author:      Ray Smith
00007 // Created:     Wed Nov 03 18:07:01 PDT 2010
00008 //
00009 // (C) Copyright 2010, Google Inc.
00010 // Licensed under the Apache License, Version 2.0 (the "License");
00011 // you may not use this file except in compliance with the License.
00012 // You may obtain a copy of the License at
00013 // http://www.apache.org/licenses/LICENSE-2.0
00014 // Unless required by applicable law or agreed to in writing, software
00015 // distributed under the License is distributed on an "AS IS" BASIS,
00016 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00017 // See the License for the specific language governing permissions and
00018 // limitations under the License.
00019 //
00021 
00022 #ifndef TESSERACT_TRAINING_MASTERTRAINER_H__
00023 #define TESSERACT_TRAINING_MASTERTRAINER_H__
00024 
00028 #include "classify.h"
00029 #include "cluster.h"
00030 #include "intfx.h"
00031 #include "elst.h"
00032 #include "errorcounter.h"
00033 #include "featdefs.h"
00034 #include "fontinfo.h"
00035 #include "indexmapbidi.h"
00036 #include "intfeaturespace.h"
00037 #include "intfeaturemap.h"
00038 #include "intmatcher.h"
00039 #include "params.h"
00040 #include "shapetable.h"
00041 #include "trainingsample.h"
00042 #include "trainingsampleset.h"
00043 #include "unicharset.h"
00044 
00045 namespace tesseract {
00046 
00047 class ShapeClassifier;
00048 
// Pairs two shapes with the distance between them, for use while clustering
// shapes. Ordering is by distance alone (ascending), so sorting a collection
// of ShapeDist entries puts the closest pairs first.
struct ShapeDist {
  // Zeroed entry: both shape ids are 0 and the distance is 0.
  ShapeDist() {
    shape1 = 0;
    shape2 = 0;
    distance = 0.0f;
  }
  // Entry for the shapes first_shape and second_shape separated by d.
  ShapeDist(int first_shape, int second_shape, float d)
    : shape1(first_shape), shape2(second_shape), distance(d) {}

  // Ascending-distance ordering; the shape ids do not take part in it.
  bool operator<(const ShapeDist& rhs) const {
    return rhs.distance > distance;
  }

  int shape1;      // First shape of the pair.
  int shape2;      // Second shape of the pair.
  float distance;  // Distance between shape1 and shape2.
};
00064 
00065 // Class to encapsulate training processes that use the TrainingSampleSet.
00066 // Initially supports shape clustering and mftrainining.
00067 // Other important features of the MasterTrainer are conditioning the data
00068 // by outlier elimination, replication with perturbation, and serialization.
00069 class MasterTrainer {
00070  public:
00071   MasterTrainer(NormalizationMode norm_mode, bool shape_analysis,
00072                 bool replicate_samples, int debug_level);
00073   ~MasterTrainer();
00074 
00075   // Writes to the given file. Returns false in case of error.
00076   bool Serialize(FILE* fp) const;
00077   // Reads from the given file. Returns false in case of error.
00078   // If swap is true, assumes a big/little-endian swap is needed.
00079   bool DeSerialize(bool swap, FILE* fp);
00080 
00081   // Loads an initial unicharset, or sets one up if the file cannot be read.
00082   void LoadUnicharset(const char* filename);
00083 
00084   // Sets the feature space definition.
00085   void SetFeatureSpace(const IntFeatureSpace& fs) {
00086     feature_space_ = fs;
00087     feature_map_.Init(fs);
00088   }
00089 
00090   // Reads the samples and their features from the given file,
00091   // adding them to the trainer with the font_id from the content of the file.
00092   // If verification, then these are verification samples, not training.
00093   void ReadTrainingSamples(const char* page_name,
00094                            const FEATURE_DEFS_STRUCT& feature_defs,
00095                            bool verification);
00096 
00097   // Adds the given single sample to the trainer, setting the classid
00098   // appropriately from the given unichar_str.
00099   void AddSample(bool verification, const char* unichar_str,
00100                  TrainingSample* sample);
00101 
00102   // Loads all pages from the given tif filename and append to page_images_.
00103   // Must be called after ReadTrainingSamples, as the current number of images
00104   // is used as an offset for page numbers in the samples.
00105   void LoadPageImages(const char* filename);
00106 
00107   // Cleans up the samples after initial load from the tr files, and prior to
00108   // saving the MasterTrainer:
00109   // Remaps fragmented chars if running shape anaylsis.
00110   // Sets up the samples appropriately for class/fontwise access.
00111   // Deletes outlier samples.
00112   void PostLoadCleanup();
00113 
00114   // Gets the samples ready for training. Use after both
00115   // ReadTrainingSamples+PostLoadCleanup or DeSerialize.
00116   // Re-indexes the features and computes canonical and cloud features.
00117   void PreTrainingSetup();
00118 
00119   // Sets up the master_shapes_ table, which tells which fonts should stay
00120   // together until they get to a leaf node classifier.
00121   void SetupMasterShapes();
00122 
00123   // Adds the junk_samples_ to the main samples_ set. Junk samples are initially
00124   // fragments and n-grams (all incorrectly segmented characters).
00125   // Various training functions may result in incorrectly segmented characters
00126   // being added to the unicharset of the main samples, perhaps because they
00127   // form a "radical" decomposition of some (Indic) grapheme, or because they
00128   // just look the same as a real character (like rn/m)
00129   // This function moves all the junk samples, to the main samples_ set, but
00130   // desirable junk, being any sample for which the unichar already exists in
00131   // the samples_ unicharset gets the unichar-ids re-indexed to match, but
00132   // anything else gets re-marked as unichar_id 0 (space character) to identify
00133   // it as junk to the error counter.
00134   void IncludeJunk();
00135 
00136   // Replicates the samples and perturbs them if the enable_replication_ flag
00137   // is set. MUST be used after the last call to OrganizeByFontAndClass on
00138   // the training samples, ie after IncludeJunk if it is going to be used, as
00139   // OrganizeByFontAndClass will eat the replicated samples into the regular
00140   // samples.
00141   void ReplicateAndRandomizeSamplesIfRequired();
00142 
00143   // Loads the basic font properties file into fontinfo_table_.
00144   // Returns false on failure.
00145   bool LoadFontInfo(const char* filename);
00146 
00147   // Loads the xheight font properties file into xheights_.
00148   // Returns false on failure.
00149   bool LoadXHeights(const char* filename);
00150 
00151   // Reads spacing stats from filename and adds them to fontinfo_table.
00152   // Returns false on failure.
00153   bool AddSpacingInfo(const char *filename);
00154 
00155   // Returns the font id corresponding to the given font name.
00156   // Returns -1 if the font cannot be found.
00157   int GetFontInfoId(const char* font_name);
00158   // Returns the font_id of the closest matching font name to the given
00159   // filename. It is assumed that a substring of the filename will match
00160   // one of the fonts. If more than one is matched, the longest is returned.
00161   int GetBestMatchingFontInfoId(const char* filename);
00162 
00163   // Returns the filename of the tr file corresponding to the command-line
00164   // argument with the given index.
00165   const STRING& GetTRFileName(int index) const {
00166     return tr_filenames_[index];
00167   }
00168 
00169   // Sets up a flat shapetable with one shape per class/font combination.
00170   void SetupFlatShapeTable(ShapeTable* shape_table);
00171 
00172   // Sets up a Clusterer for mftraining on a single shape_id.
00173   // Call FreeClusterer on the return value after use.
00174   CLUSTERER* SetupForClustering(const ShapeTable& shape_table,
00175                                 const FEATURE_DEFS_STRUCT& feature_defs,
00176                                 int shape_id, int* num_samples);
00177 
00178   // Writes the given float_classes (produced by SetupForFloat2Int) as inttemp
00179   // to the given inttemp_file, and the corresponding pffmtable.
00180   // The unicharset is the original encoding of graphemes, and shape_set should
00181   // match the size of the shape_table, and may possibly be totally fake.
00182   void WriteInttempAndPFFMTable(const UNICHARSET& unicharset,
00183                                 const UNICHARSET& shape_set,
00184                                 const ShapeTable& shape_table,
00185                                 CLASS_STRUCT* float_classes,
00186                                 const char* inttemp_file,
00187                                 const char* pffmtable_file);
00188 
00189   const UNICHARSET& unicharset() const {
00190     return samples_.unicharset();
00191   }
00192   TrainingSampleSet* GetSamples() {
00193     return &samples_;
00194   }
00195   const ShapeTable& master_shapes() const {
00196     return master_shapes_;
00197   }
00198 
00199   // Generates debug output relating to the canonical distance between the
00200   // two given UTF8 grapheme strings.
00201   void DebugCanonical(const char* unichar_str1, const char* unichar_str2);
00202   #ifndef GRAPHICS_DISABLED
00203   // Debugging for cloud/canonical features.
00204   // Displays a Features window containing:
00205   // If unichar_str2 is in the unicharset, and canonical_font is non-negative,
00206   // displays the canonical features of the char/font combination in red.
00207   // If unichar_str1 is in the unicharset, and cloud_font is non-negative,
00208   // displays the cloud feature of the char/font combination in green.
00209   // The canonical features are drawn first to show which ones have no
00210   // matches in the cloud features.
00211   // Until the features window is destroyed, each click in the features window
00212   // will display the samples that have that feature in a separate window.
00213   void DisplaySamples(const char* unichar_str1, int cloud_font,
00214                       const char* unichar_str2, int canonical_font);
00215   #endif  // GRAPHICS_DISABLED
00216 
00217   void TestClassifierVOld(bool replicate_samples,
00218                           ShapeClassifier* test_classifier,
00219                           ShapeClassifier* old_classifier);
00220 
00221   // Tests the given test_classifier on the internal samples.
00222   // See TestClassifier for details.
00223   void TestClassifierOnSamples(CountTypes error_mode,
00224                                int report_level,
00225                                bool replicate_samples,
00226                                ShapeClassifier* test_classifier,
00227                                STRING* report_string);
00228   // Tests the given test_classifier on the given samples
00229   // error_mode indicates what counts as an error.
00230   // report_levels:
00231   // 0 = no output.
00232   // 1 = bottom-line error rate.
00233   // 2 = bottom-line error rate + time.
00234   // 3 = font-level error rate + time.
00235   // 4 = list of all errors + short classifier debug output on 16 errors.
00236   // 5 = list of all errors + short classifier debug output on 25 errors.
00237   // If replicate_samples is true, then the test is run on an extended test
00238   // sample including replicated and systematically perturbed samples.
00239   // If report_string is non-NULL, a summary of the results for each font
00240   // is appended to the report_string.
00241   double TestClassifier(CountTypes error_mode,
00242                         int report_level,
00243                         bool replicate_samples,
00244                         TrainingSampleSet* samples,
00245                         ShapeClassifier* test_classifier,
00246                         STRING* report_string);
00247 
00248   // Returns the average (in some sense) distance between the two given
00249   // shapes, which may contain multiple fonts and/or unichars.
00250   // This function is public to facilitate testing.
00251   float ShapeDistance(const ShapeTable& shapes, int s1, int s2);
00252 
00253  private:
00254   // Replaces samples that are always fragmented with the corresponding
00255   // fragment samples.
00256   void ReplaceFragmentedSamples();
00257 
00258   // Runs a hierarchical agglomerative clustering to merge shapes in the given
00259   // shape_table, while satisfying the given constraints:
00260   // * End with at least min_shapes left in shape_table,
00261   // * No shape shall have more than max_shape_unichars in it,
00262   // * Don't merge shapes where the distance between them exceeds max_dist.
00263   void ClusterShapes(int min_shapes, int max_shape_unichars,
00264                      float max_dist, ShapeTable* shape_table);
00265 
00266  private:
00267   NormalizationMode norm_mode_;
00268   // Character set we are training for.
00269   UNICHARSET unicharset_;
00270   // Original feature space. Subspace mapping is contained in feature_map_.
00271   IntFeatureSpace feature_space_;
00272   TrainingSampleSet samples_;
00273   TrainingSampleSet junk_samples_;
00274   TrainingSampleSet verify_samples_;
00275   // Master shape table defines what fonts stay together until the leaves.
00276   ShapeTable master_shapes_;
00277   // Flat shape table has each unichar/font id pair in a separate shape.
00278   ShapeTable flat_shapes_;
00279   // Font metrics gathered from multiple files.
00280   FontInfoTable fontinfo_table_;
00281   // Array of xheights indexed by font ids in fontinfo_table_;
00282   GenericVector<inT32> xheights_;
00283 
00284   // Non-serialized data initialized by other means or used temporarily
00285   // during loading of training samples.
00286   // Number of different class labels in unicharset_.
00287   int charsetsize_;
00288   // Flag to indicate that we are running shape analysis and need fragments
00289   // fixing.
00290   bool enable_shape_anaylsis_;
00291   // Flag to indicate that sample replication is required.
00292   bool enable_replication_;
00293   // Array of classids of fragments that replace the correctly segmented chars.
00294   int* fragments_;
00295   // Classid of previous correctly segmented sample that was added.
00296   int prev_unichar_id_;
00297   // Debug output control.
00298   int debug_level_;
00299   // Feature map used to construct reduced feature spaces for compact
00300   // classifiers.
00301   IntFeatureMap feature_map_;
00302   // Vector of Pix pointers used for classifiers that need the image.
00303   // Indexed by page_num_ in the samples.
00304   // These images are owned by the trainer and need to be pixDestroyed.
00305   GenericVector<Pix*> page_images_;
00306   // Vector of filenames of loaded tr files.
00307   GenericVector<STRING> tr_filenames_;
00308 };
00309 
00310 }  // namespace tesseract.
00311 
00312 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines