tesseract
3.03
|
00001 // Copyright 2010 Google Inc. All Rights Reserved. 00002 // Author: rays@google.com (Ray Smith) 00004 // File: mastertrainer.h 00005 // Description: Trainer to build the MasterClassifier. 00006 // Author: Ray Smith 00007 // Created: Wed Nov 03 18:07:01 PDT 2010 00008 // 00009 // (C) Copyright 2010, Google Inc. 00010 // Licensed under the Apache License, Version 2.0 (the "License"); 00011 // you may not use this file except in compliance with the License. 00012 // You may obtain a copy of the License at 00013 // http://www.apache.org/licenses/LICENSE-2.0 00014 // Unless required by applicable law or agreed to in writing, software 00015 // distributed under the License is distributed on an "AS IS" BASIS, 00016 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00017 // See the License for the specific language governing permissions and 00018 // limitations under the License. 00019 // 00021 00022 #ifndef TESSERACT_TRAINING_MASTERTRAINER_H__ 00023 #define TESSERACT_TRAINING_MASTERTRAINER_H__ 00024 00028 #include "classify.h" 00029 #include "cluster.h" 00030 #include "intfx.h" 00031 #include "elst.h" 00032 #include "errorcounter.h" 00033 #include "featdefs.h" 00034 #include "fontinfo.h" 00035 #include "indexmapbidi.h" 00036 #include "intfeaturespace.h" 00037 #include "intfeaturemap.h" 00038 #include "intmatcher.h" 00039 #include "params.h" 00040 #include "shapetable.h" 00041 #include "trainingsample.h" 00042 #include "trainingsampleset.h" 00043 #include "unicharset.h" 00044 00045 namespace tesseract { 00046 00047 class ShapeClassifier; 00048 00049 // Simple struct to hold the distance between two shapes during clustering. 00050 struct ShapeDist { 00051 ShapeDist() : shape1(0), shape2(0), distance(0.0f) {} 00052 ShapeDist(int s1, int s2, float dist) 00053 : shape1(s1), shape2(s2), distance(dist) {} 00054 00055 // Sort operator to sort in ascending order of distance. 00056 bool operator<(const ShapeDist& other) const { 00057 return distance < other.distance; 00058 } 00059 00060 int shape1; 00061 int shape2; 00062 float distance; 00063 }; 00064 00065 // Class to encapsulate training processes that use the TrainingSampleSet. 00066 // Initially supports shape clustering and mftrainining. 00067 // Other important features of the MasterTrainer are conditioning the data 00068 // by outlier elimination, replication with perturbation, and serialization. 00069 class MasterTrainer { 00070 public: 00071 MasterTrainer(NormalizationMode norm_mode, bool shape_analysis, 00072 bool replicate_samples, int debug_level); 00073 ~MasterTrainer(); 00074 00075 // Writes to the given file. Returns false in case of error. 00076 bool Serialize(FILE* fp) const; 00077 // Reads from the given file. Returns false in case of error. 00078 // If swap is true, assumes a big/little-endian swap is needed. 00079 bool DeSerialize(bool swap, FILE* fp); 00080 00081 // Loads an initial unicharset, or sets one up if the file cannot be read. 00082 void LoadUnicharset(const char* filename); 00083 00084 // Sets the feature space definition. 00085 void SetFeatureSpace(const IntFeatureSpace& fs) { 00086 feature_space_ = fs; 00087 feature_map_.Init(fs); 00088 } 00089 00090 // Reads the samples and their features from the given file, 00091 // adding them to the trainer with the font_id from the content of the file. 00092 // If verification, then these are verification samples, not training. 00093 void ReadTrainingSamples(const char* page_name, 00094 const FEATURE_DEFS_STRUCT& feature_defs, 00095 bool verification); 00096 00097 // Adds the given single sample to the trainer, setting the classid 00098 // appropriately from the given unichar_str. 00099 void AddSample(bool verification, const char* unichar_str, 00100 TrainingSample* sample); 00101 00102 // Loads all pages from the given tif filename and append to page_images_. 00103 // Must be called after ReadTrainingSamples, as the current number of images 00104 // is used as an offset for page numbers in the samples. 00105 void LoadPageImages(const char* filename); 00106 00107 // Cleans up the samples after initial load from the tr files, and prior to 00108 // saving the MasterTrainer: 00109 // Remaps fragmented chars if running shape anaylsis. 00110 // Sets up the samples appropriately for class/fontwise access. 00111 // Deletes outlier samples. 00112 void PostLoadCleanup(); 00113 00114 // Gets the samples ready for training. Use after both 00115 // ReadTrainingSamples+PostLoadCleanup or DeSerialize. 00116 // Re-indexes the features and computes canonical and cloud features. 00117 void PreTrainingSetup(); 00118 00119 // Sets up the master_shapes_ table, which tells which fonts should stay 00120 // together until they get to a leaf node classifier. 00121 void SetupMasterShapes(); 00122 00123 // Adds the junk_samples_ to the main samples_ set. Junk samples are initially 00124 // fragments and n-grams (all incorrectly segmented characters). 00125 // Various training functions may result in incorrectly segmented characters 00126 // being added to the unicharset of the main samples, perhaps because they 00127 // form a "radical" decomposition of some (Indic) grapheme, or because they 00128 // just look the same as a real character (like rn/m) 00129 // This function moves all the junk samples, to the main samples_ set, but 00130 // desirable junk, being any sample for which the unichar already exists in 00131 // the samples_ unicharset gets the unichar-ids re-indexed to match, but 00132 // anything else gets re-marked as unichar_id 0 (space character) to identify 00133 // it as junk to the error counter. 00134 void IncludeJunk(); 00135 00136 // Replicates the samples and perturbs them if the enable_replication_ flag 00137 // is set. MUST be used after the last call to OrganizeByFontAndClass on 00138 // the training samples, ie after IncludeJunk if it is going to be used, as 00139 // OrganizeByFontAndClass will eat the replicated samples into the regular 00140 // samples. 00141 void ReplicateAndRandomizeSamplesIfRequired(); 00142 00143 // Loads the basic font properties file into fontinfo_table_. 00144 // Returns false on failure. 00145 bool LoadFontInfo(const char* filename); 00146 00147 // Loads the xheight font properties file into xheights_. 00148 // Returns false on failure. 00149 bool LoadXHeights(const char* filename); 00150 00151 // Reads spacing stats from filename and adds them to fontinfo_table. 00152 // Returns false on failure. 00153 bool AddSpacingInfo(const char *filename); 00154 00155 // Returns the font id corresponding to the given font name. 00156 // Returns -1 if the font cannot be found. 00157 int GetFontInfoId(const char* font_name); 00158 // Returns the font_id of the closest matching font name to the given 00159 // filename. It is assumed that a substring of the filename will match 00160 // one of the fonts. If more than one is matched, the longest is returned. 00161 int GetBestMatchingFontInfoId(const char* filename); 00162 00163 // Returns the filename of the tr file corresponding to the command-line 00164 // argument with the given index. 00165 const STRING& GetTRFileName(int index) const { 00166 return tr_filenames_[index]; 00167 } 00168 00169 // Sets up a flat shapetable with one shape per class/font combination. 00170 void SetupFlatShapeTable(ShapeTable* shape_table); 00171 00172 // Sets up a Clusterer for mftraining on a single shape_id. 00173 // Call FreeClusterer on the return value after use. 00174 CLUSTERER* SetupForClustering(const ShapeTable& shape_table, 00175 const FEATURE_DEFS_STRUCT& feature_defs, 00176 int shape_id, int* num_samples); 00177 00178 // Writes the given float_classes (produced by SetupForFloat2Int) as inttemp 00179 // to the given inttemp_file, and the corresponding pffmtable. 00180 // The unicharset is the original encoding of graphemes, and shape_set should 00181 // match the size of the shape_table, and may possibly be totally fake. 00182 void WriteInttempAndPFFMTable(const UNICHARSET& unicharset, 00183 const UNICHARSET& shape_set, 00184 const ShapeTable& shape_table, 00185 CLASS_STRUCT* float_classes, 00186 const char* inttemp_file, 00187 const char* pffmtable_file); 00188 00189 const UNICHARSET& unicharset() const { 00190 return samples_.unicharset(); 00191 } 00192 TrainingSampleSet* GetSamples() { 00193 return &samples_; 00194 } 00195 const ShapeTable& master_shapes() const { 00196 return master_shapes_; 00197 } 00198 00199 // Generates debug output relating to the canonical distance between the 00200 // two given UTF8 grapheme strings. 00201 void DebugCanonical(const char* unichar_str1, const char* unichar_str2); 00202 #ifndef GRAPHICS_DISABLED 00203 // Debugging for cloud/canonical features. 00204 // Displays a Features window containing: 00205 // If unichar_str2 is in the unicharset, and canonical_font is non-negative, 00206 // displays the canonical features of the char/font combination in red. 00207 // If unichar_str1 is in the unicharset, and cloud_font is non-negative, 00208 // displays the cloud feature of the char/font combination in green. 00209 // The canonical features are drawn first to show which ones have no 00210 // matches in the cloud features. 00211 // Until the features window is destroyed, each click in the features window 00212 // will display the samples that have that feature in a separate window. 00213 void DisplaySamples(const char* unichar_str1, int cloud_font, 00214 const char* unichar_str2, int canonical_font); 00215 #endif // GRAPHICS_DISABLED 00216 00217 void TestClassifierVOld(bool replicate_samples, 00218 ShapeClassifier* test_classifier, 00219 ShapeClassifier* old_classifier); 00220 00221 // Tests the given test_classifier on the internal samples. 00222 // See TestClassifier for details. 00223 void TestClassifierOnSamples(CountTypes error_mode, 00224 int report_level, 00225 bool replicate_samples, 00226 ShapeClassifier* test_classifier, 00227 STRING* report_string); 00228 // Tests the given test_classifier on the given samples 00229 // error_mode indicates what counts as an error. 00230 // report_levels: 00231 // 0 = no output. 00232 // 1 = bottom-line error rate. 00233 // 2 = bottom-line error rate + time. 00234 // 3 = font-level error rate + time. 00235 // 4 = list of all errors + short classifier debug output on 16 errors. 00236 // 5 = list of all errors + short classifier debug output on 25 errors. 00237 // If replicate_samples is true, then the test is run on an extended test 00238 // sample including replicated and systematically perturbed samples. 00239 // If report_string is non-NULL, a summary of the results for each font 00240 // is appended to the report_string. 00241 double TestClassifier(CountTypes error_mode, 00242 int report_level, 00243 bool replicate_samples, 00244 TrainingSampleSet* samples, 00245 ShapeClassifier* test_classifier, 00246 STRING* report_string); 00247 00248 // Returns the average (in some sense) distance between the two given 00249 // shapes, which may contain multiple fonts and/or unichars. 00250 // This function is public to facilitate testing. 00251 float ShapeDistance(const ShapeTable& shapes, int s1, int s2); 00252 00253 private: 00254 // Replaces samples that are always fragmented with the corresponding 00255 // fragment samples. 00256 void ReplaceFragmentedSamples(); 00257 00258 // Runs a hierarchical agglomerative clustering to merge shapes in the given 00259 // shape_table, while satisfying the given constraints: 00260 // * End with at least min_shapes left in shape_table, 00261 // * No shape shall have more than max_shape_unichars in it, 00262 // * Don't merge shapes where the distance between them exceeds max_dist. 00263 void ClusterShapes(int min_shapes, int max_shape_unichars, 00264 float max_dist, ShapeTable* shape_table); 00265 00266 private: 00267 NormalizationMode norm_mode_; 00268 // Character set we are training for. 00269 UNICHARSET unicharset_; 00270 // Original feature space. Subspace mapping is contained in feature_map_. 00271 IntFeatureSpace feature_space_; 00272 TrainingSampleSet samples_; 00273 TrainingSampleSet junk_samples_; 00274 TrainingSampleSet verify_samples_; 00275 // Master shape table defines what fonts stay together until the leaves. 00276 ShapeTable master_shapes_; 00277 // Flat shape table has each unichar/font id pair in a separate shape. 00278 ShapeTable flat_shapes_; 00279 // Font metrics gathered from multiple files. 00280 FontInfoTable fontinfo_table_; 00281 // Array of xheights indexed by font ids in fontinfo_table_; 00282 GenericVector<inT32> xheights_; 00283 00284 // Non-serialized data initialized by other means or used temporarily 00285 // during loading of training samples. 00286 // Number of different class labels in unicharset_. 00287 int charsetsize_; 00288 // Flag to indicate that we are running shape analysis and need fragments 00289 // fixing. 00290 bool enable_shape_anaylsis_; 00291 // Flag to indicate that sample replication is required. 00292 bool enable_replication_; 00293 // Array of classids of fragments that replace the correctly segmented chars. 00294 int* fragments_; 00295 // Classid of previous correctly segmented sample that was added. 00296 int prev_unichar_id_; 00297 // Debug output control. 00298 int debug_level_; 00299 // Feature map used to construct reduced feature spaces for compact 00300 // classifiers. 00301 IntFeatureMap feature_map_; 00302 // Vector of Pix pointers used for classifiers that need the image. 00303 // Indexed by page_num_ in the samples. 00304 // These images are owned by the trainer and need to be pixDestroyed. 00305 GenericVector<Pix*> page_images_; 00306 // Vector of filenames of loaded tr files. 00307 GenericVector<STRING> tr_filenames_; 00308 }; 00309 00310 } // namespace tesseract. 00311 00312 #endif